diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 4f32860..35251ee 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -5,7 +5,8 @@ edition = "2021" [dependencies] logos = "0.14" -winnow = { version = "0.6.18", features = ["debug"] } +winnow = { version = "0.6.18" } gen = { path = "../gen" } thiserror = "1.0" bitflags = "1.2" +rustc-hash = "2.0.0" diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ab0fc58..0dabd5d 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -557,7 +557,7 @@ pub struct SetpData { impl SetpData { pub(crate) fn try_parse( - errors: &mut PtxParserState, + state: &mut PtxParserState, cmp_op: super::RawSetpCompareOp, ftz: bool, type_: ScalarType, @@ -565,7 +565,7 @@ impl SetpData { let flush_to_zero = match (ftz, type_) { (_, ScalarType::F32) => Some(ftz), _ => { - errors.push(PtxError::NonF32Ftz); + state.errors.push(PtxError::NonF32Ftz); None } }; @@ -576,7 +576,7 @@ impl SetpData { match SetpCompareInt::try_from(cmp_op) { Ok(op) => SetpCompareOp::Integer(op), Err(err) => { - errors.push(err); + state.errors.push(err); SetpCompareOp::Integer(SetpCompareInt::Eq) } } @@ -682,36 +682,52 @@ impl From for SetpCompareFloat { } pub struct CallDetails { - uniform: bool, - ret_params: Vec<(Type, StateSpace)>, - param_list: Vec<(Type, StateSpace)>, + pub uniform: bool, + pub return_arguments: Vec<(Type, StateSpace)>, + pub input_arguments: Vec<(Type, StateSpace)>, } pub struct CallArgs { - pub ret_params: Vec, + pub return_arguments: Vec, pub func: T::Ident, - pub param_list: Vec, + pub input_arguments: Vec, } impl CallArgs { #[allow(dead_code)] // Used by generated code fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor) { - for (param, (type_, space)) in self.ret_params.iter().zip(details.ret_params.iter()) { + for (param, (type_, space)) in self + .return_arguments + .iter() + .zip(details.return_arguments.iter()) + { visitor.visit_ident(param, Some((type_, *space)), true); } visitor.visit_ident(&self.func, None, false); - for (param, (type_, space)) in self.param_list.iter().zip(details.param_list.iter()) { + for (param, (type_, space)) in self + .input_arguments + .iter() + .zip(details.input_arguments.iter()) + { visitor.visit(param, Some((type_, *space)), true); } } #[allow(dead_code)] // Used by generated code fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut) { - for (param, (type_, space)) in self.ret_params.iter_mut().zip(details.ret_params.iter()) { + for (param, (type_, space)) in self + .return_arguments + .iter_mut() + .zip(details.return_arguments.iter()) + { visitor.visit_ident(param, Some((type_, *space)), true); } visitor.visit_ident(&mut self.func, None, false); - for (param, (type_, space)) in self.param_list.iter_mut().zip(details.param_list.iter()) { + for (param, (type_, space)) in self + .input_arguments + .iter_mut() + .zip(details.input_arguments.iter()) + { visitor.visit(param, Some((type_, *space)), true); } } @@ -722,23 +738,23 @@ impl CallArgs { details: &CallDetails, visitor: &mut impl VisitorMap, ) -> CallArgs { - let ret_params = self - .ret_params + let return_arguments = self + .return_arguments .into_iter() - .zip(details.ret_params.iter()) + .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) .collect::>(); let func = visitor.visit_ident(self.func, None, false); - let param_list = self - .param_list + let input_arguments = self + .input_arguments .into_iter() - .zip(details.param_list.iter()) + .zip(details.input_arguments.iter()) .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) .collect::>(); CallArgs { - ret_params, + return_arguments, func, - param_list, + input_arguments, } } } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index a6a2381..2c602d5 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1,5 +1,7 @@ use gen::derive_parser; use logos::Logos; +use rustc_hash::FxHashMap; +use std::fmt::Debug; use std::mem; use std::num::{ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; @@ -69,8 +71,49 @@ impl From for ast::RoundingMode { } } -type PtxParserState = Vec; -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>; +struct PtxParserState<'input> { + errors: Vec, + function_declarations: + FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, +} + +impl<'input> PtxParserState<'input> { + fn new() -> Self { + Self { + errors: Vec::new(), + function_declarations: FxHashMap::default(), + } + } + + fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) { + let name = match function_decl.name { + MethodName::Kernel(name) => name, + MethodName::Func(name) => name, + }; + let return_arguments = Self::get_type_space(&*function_decl.return_arguments); + let input_arguments = Self::get_type_space(&*function_decl.input_arguments); + // TODO: check if declarations match + self.function_declarations + .insert(name, (return_arguments, input_arguments)); + } + + fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { + input_arguments + .iter() + .map(|var| (var.v_type.clone(), var.state_space)) + .collect::>() + } +} + +impl<'input> Debug for PtxParserState<'input> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PtxParserState") + .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ + .finish() + } +} + +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { @@ -127,7 +170,7 @@ fn take_error<'a, 'input: 'a, O, E>( Ok(match parser.parse_next(input)? { Ok(x) => x, Err((x, err)) => { - input.state.push(err); + input.state.errors.push(err); x } }) @@ -353,7 +396,7 @@ fn function<'a, 'input>( ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>, )> { - ( + let (linking, function) = ( linking_directives, method_declaration, repeat(0.., tuning_directive), @@ -369,7 +412,9 @@ fn function<'a, 'input>( }, ) }) - .parse_next(stream) + .parse_next(stream)?; + stream.state.record_function(&function.func_directive); + Ok((linking, function)) } fn linking_directives<'a, 'input>( @@ -771,6 +816,10 @@ pub enum PtxError { #[error("")] WrongType, #[error("")] + UnknownFunction, + #[error("")] + MalformedCall, + #[error("")] WrongArrayType, #[error("")] WrongVectorElement, @@ -903,6 +952,74 @@ fn bra<'a, 'input>( .parse_next(stream) } +fn call<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + let (uni, return_arguments, name, input_arguments) = ( + opt(Token::DotUni), + opt(( + Token::LParen, + separated(1.., ident, Token::Comma).map(|x: Vec<_>| x), + Token::RParen, + Token::Comma, + ) + .map(|(_, arguments, _, _)| arguments)), + ident, + opt(( + Token::Comma.void(), + Token::LParen.void(), + separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x), + Token::RParen.void(), + ) + .map(|(_, _, arguments, _)| arguments)), + ) + .parse_next(stream)?; + let uniform = uni.is_some(); + let recorded_fn = match stream.state.function_declarations.get(name) { + Some(decl) => decl, + None => { + stream.state.errors.push(PtxError::UnknownFunction); + return Ok(empty_call(uniform, name)); + } + }; + let return_arguments = return_arguments.unwrap_or(Vec::new()); + let input_arguments = input_arguments.unwrap_or(Vec::new()); + if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len() + { + stream.state.errors.push(PtxError::MalformedCall); + return Ok(empty_call(uniform, name)); + } + let data = CallDetails { + uniform, + return_arguments: recorded_fn.0.clone(), + input_arguments: recorded_fn.1.clone(), + }; + let arguments = CallArgs { + return_arguments, + func: name, + input_arguments, + }; + Ok(ast::Instruction::Call { data, arguments }) +} + +fn empty_call<'input>( + uniform: bool, + name: &'input str, +) -> ast::Instruction> { + ast::Instruction::Call { + data: CallDetails { + uniform, + return_arguments: Vec::new(), + input_arguments: Vec::new(), + }, + arguments: CallArgs { + return_arguments: Vec::new(), + func: name, + input_arguments: Vec::new(), + }, + } +} + // Modifiers are turned into arguments to the blocks, with type: // * If it is an alternative: // * If it is mandatory then its type is Foo (as defined by the relevant rule) @@ -1033,7 +1150,7 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1058,7 +1175,7 @@ derive_parser!( } st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1072,7 +1189,7 @@ derive_parser!( } st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::St { data: StData { @@ -1085,7 +1202,7 @@ derive_parser!( } } st.mmio.relaxed.sys{.global}.type [a], b => { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); Instruction::St { data: ast::StData { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), @@ -1114,7 +1231,7 @@ derive_parser!( ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { let (a, unified) = a; if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1129,7 +1246,7 @@ derive_parser!( } ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { if level_prefetch_size.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1144,7 +1261,7 @@ derive_parser!( } ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1159,7 +1276,7 @@ derive_parser!( } ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); } Instruction::Ld { data: LdDetails { @@ -1173,7 +1290,7 @@ derive_parser!( } } ld.mmio.relaxed.sys{.global}.type d, [a] => { - state.push(PtxError::Todo); + state.errors.push(PtxError::Todo); Instruction::Ld { data: LdDetails { qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), @@ -1506,6 +1623,9 @@ derive_parser!( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra bra <= { bra(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call + call <= { call(stream) } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } } @@ -1558,7 +1678,7 @@ fn main() { println!("{:?}", &tokens); let stream = PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; let _module = module.parse(stream).unwrap(); println!("{}", mem::size_of::()); @@ -1567,6 +1687,7 @@ fn main() { #[cfg(test)] mod tests { use super::target; + use super::PtxParserState; use super::Token; use logos::Logos; use winnow::prelude::*; @@ -1578,7 +1699,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); } @@ -1590,7 +1711,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); } @@ -1602,7 +1723,7 @@ mod tests { .unwrap(); let stream = super::PtxParser { input: &tokens[..], - state: Vec::new(), + state: PtxParserState::new(), }; assert!(target.parse(stream).is_err()); }