diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2ac1f68..d485286 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [lib] [dependencies] -lalrpop-util = "0.19" +ptx_parser = { path = "../ptx_parser" } regex = "1" rspirv = "0.7" spirv_headers = "1.5" @@ -17,8 +17,12 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" +[dependencies.lalrpop-util] +version = "0.19.12" +features = ["lexer"] + [build-dependencies.lalrpop] -version = "0.19" +version = "0.19.12" features = ["lexer"] [dev-dependencies] diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f1323be..358b8ce 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -34,15 +34,9 @@ pub enum PtxError { #[error("")] NonExternPointer, #[error("{start}:{end}")] - UnrecognizedStatement { - start: usize, - end: usize, - }, + UnrecognizedStatement { start: usize, end: usize }, #[error("{start}:{end}")] - UnrecognizedDirective { - start: usize, - end: usize, - }, + UnrecognizedDirective { start: usize, end: usize }, } // For some weird reson this is illegal: @@ -578,11 +572,15 @@ impl CvtDetails { if saturate { if src.kind() == ScalarKind::Signed { if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } else { if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } } @@ -598,7 +596,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::FloatFromInt(CvtDesc { dst, @@ -618,7 +618,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::IntFromFloat(CvtDesc { dst, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1cb9630..b70019e 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,9 +24,11 @@ lalrpop_mod!( ); pub mod ast; +mod pass; #[cfg(test)] mod test; mod translate; +mod translate2; use std::fmt; diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs new file mode 100644 index 0000000..7b794d6 --- /dev/null +++ b/ptx/src/pass/mod.rs @@ -0,0 +1,531 @@ +use ptx_parser as ast; +use std::{ + borrow::Cow, + cell::RefCell, + collections::{hash_map, HashMap}, + rc::Rc, +}; + +mod normalize; + +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] +enum PtxSpecialRegister { + Tid, + Ntid, + Ctaid, + Nctaid, + Clock, + LanemaskLt, +} + +impl PtxSpecialRegister { + fn try_parse(s: &str) -> Option { + match s { + "%tid" => Some(Self::Tid), + "%ntid" => Some(Self::Ntid), + "%ctaid" => Some(Self::Ctaid), + "%nctaid" => Some(Self::Nctaid), + "%clock" => Some(Self::Clock), + "%lanemask_lt" => Some(Self::LanemaskLt), + _ => None, + } + } + + fn get_type(self) -> ast::Type { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4), + _ => ast::Type::Scalar(self.get_function_return_type()), + } + } + + fn get_function_return_type(self) -> ast::ScalarType { + match self { + PtxSpecialRegister::Tid => ast::ScalarType::U32, + PtxSpecialRegister::Ntid => ast::ScalarType::U32, + PtxSpecialRegister::Ctaid => ast::ScalarType::U32, + PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Clock => ast::ScalarType::U32, + PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, + } + } + + fn get_function_input_type(self) -> Option { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), + PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, + } + } + + fn get_unprefixed_function_name(self) -> &'static str { + match self { + PtxSpecialRegister::Tid => "sreg_tid", + PtxSpecialRegister::Ntid => "sreg_ntid", + PtxSpecialRegister::Ctaid => "sreg_ctaid", + PtxSpecialRegister::Nctaid => "sreg_nctaid", + PtxSpecialRegister::Clock => "sreg_clock", + PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", + } + } +} + +struct SpecialRegistersMap { + reg_to_id: HashMap, + id_to_reg: HashMap, +} + +impl SpecialRegistersMap { + fn new() -> Self { + SpecialRegistersMap { + reg_to_id: HashMap::new(), + id_to_reg: HashMap::new(), + } + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = SpirvWord(current_id.0); + current_id.0 += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } +} + +struct FnStringIdResolver<'input, 'b> { + current_id: &'b mut SpirvWord, + global_variables: &'b HashMap, SpirvWord>, + global_type_check: &'b HashMap>, + special_registers: &'b mut SpecialRegistersMap, + variables: Vec, SpirvWord>>, + type_check: HashMap>, +} + +impl<'a, 'b> FnStringIdResolver<'a, 'b> { + fn finish(self) -> NumericIdResolver<'b> { + NumericIdResolver { + current_id: self.current_id, + global_type_check: self.global_type_check, + type_check: self.type_check, + special_registers: self.special_registers, + } + } + + fn start_block(&mut self) { + self.variables.push(HashMap::new()) + } + + fn end_block(&mut self) { + self.variables.pop(); + } + + fn get_id(&mut self, id: &str) -> Result { + for scope in self.variables.iter().rev() { + match scope.get(id) { + Some(id) => return Ok(*id), + None => continue, + } + } + match self.global_variables.get(id) { + Some(id) => Ok(*id), + None => { + let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?; + Ok(self.special_registers.get_or_add(self.current_id, sreg)) + } + } + } + + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> SpirvWord { + let numeric_id = *self.current_id; + self.variables + .last_mut() + .unwrap() + .insert(Cow::Borrowed(id), numeric_id); + self.type_check.insert( + numeric_id.0, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); + self.current_id.0 += 1; + numeric_id + } + + #[must_use] + fn add_defs( + &mut self, + base_id: &'a str, + count: u32, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> impl Iterator { + let numeric_id = *self.current_id; + for i in 0..count { + self.variables.last_mut().unwrap().insert( + Cow::Owned(format!("{}{}", base_id, i)), + SpirvWord(numeric_id.0 + i), + ); + self.type_check.insert( + numeric_id.0 + i, + Some((typ.clone(), state_space, is_variable)), + ); + } + self.current_id.0 += count; + (0..count) + .into_iter() + .map(move |i| SpirvWord(i + numeric_id.0)) + } +} + +struct NumericIdResolver<'b> { + current_id: &'b mut SpirvWord, + global_type_check: &'b HashMap>, + type_check: HashMap>, + special_registers: &'b mut SpecialRegistersMap, +} + +impl<'b> NumericIdResolver<'b> { + fn finish(self) -> MutableNumericIdResolver<'b> { + MutableNumericIdResolver { base: self } + } + + fn get_typed( + &self, + id: SpirvWord, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { + match self.type_check.get(&id.0) { + Some(Some(x)) => Ok(x.clone()), + Some(None) => Err(TranslateError::UntypedSymbol), + None => match self.special_registers.get(id) { + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), + None => match self.global_type_check.get(&id.0) { + Some(Some(result)) => Ok(result.clone()), + Some(None) | None => Err(TranslateError::UntypedSymbol), + }, + }, + } + } + + // This is for identifiers which will be emitted later as OpVariable + // They are candidates for insertion of LoadVar/StoreVar + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id.0, Some((typ, state_space, true))); + self.current_id.0 += 1; + new_id + } + + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id.0, typ.map(|(t, space)| (t, space, false))); + self.current_id.0 += 1; + new_id + } +} + +struct MutableNumericIdResolver<'b> { + base: NumericIdResolver<'b>, +} + +impl<'b> MutableNumericIdResolver<'b> { + fn unmut(self) -> NumericIdResolver<'b> { + self.base + } + + fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) + } + + fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + self.base.register_intermediate(Some((typ, state_space))) + } +} + +quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol {} + UntypedSymbol {} + MismatchedType {} + Spirv(err: rspirv::dr::Error) { + from() + display("{}", err) + cause(err) + } + Unreachable {} + Todo {} + } +} + +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + +fn error_unknown_symbol() -> TranslateError { + TranslateError::UnknownSymbol +} + +pub struct GlobalFnDeclResolver<'input, 'a> { + fns: &'a HashMap>, +} + +impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { + fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> { + self.fns.get(&id).ok_or_else(error_unknown_symbol) + } +} + +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + /* + fn resolve_in_spirv_repr( + &self, + call_inst: ast::CallInst, + ) -> Result, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut return_arguments = Vec::new(); + let mut input_arguments = call_inst + .param_list + .into_iter() + .zip(func_decl.input_arguments.iter()) + .map(|(id, var)| (id, var.v_type.clone(), var.state_space)) + .collect::>(); + let mut func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter(); + for (idx, id) in call_inst.ret_params.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = func_decl_return_iter.next() { + return_arguments.push((*id, var.v_type.clone(), var.state_space)); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + input_arguments.push(( + ast::Operand::Reg(*id), + var.v_type.clone(), + var.state_space, + )); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if return_arguments.len() != func_decl.return_arguments.len() + || input_arguments.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + Ok(ResolvedCall { + return_arguments, + input_arguments, + uniform: call_inst.uniform, + name: call_inst.func, + }) + } + */ +} + +enum Statement { + Label(SpirvWord), + Variable(ast::Variable), + Instruction(I), + // SPIR-V compatible replacement for PTX predicates + Conditional(BrachCondition), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), + Conversion(ImplicitConversion), + Constant(ConstantDefinition), + RetValue(ast::RetData, SpirvWord), + PtrAccess(PtrAccess

), + RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), +} + +struct BrachCondition { + predicate: SpirvWord, + if_true: SpirvWord, + if_false: SpirvWord, +} +struct LoadVarDetails { + arg: ast::LdArgs, + typ: ast::Type, + state_space: ast::StateSpace, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, +} + +struct StoreVarDetails { + arg: ast::StArgs, + typ: ast::Type, + member_index: Option, +} + +#[derive(Clone)] +struct ImplicitConversion { + src: SpirvWord, + dst: SpirvWord, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, + kind: ConversionKind, +} + +#[derive(PartialEq, Clone)] +enum ConversionKind { + Default, + // zero-extend/chop/bitcast depending on types + SignExtend, + BitToPtr, + PtrToPtr, + AddressOf, +} + +struct ConstantDefinition { + pub dst: SpirvWord, + pub typ: ast::ScalarType, + pub value: ast::ImmediateValue, +} + +pub struct PtrAccess { + underlying_type: ast::Type, + state_space: ast::StateSpace, + dst: SpirvWord, + ptr_src: SpirvWord, + offset_src: T, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: SpirvWord, + unpacked: Vec, + non_default_implicit_conversion: Option< + fn( + (ast::StateSpace, &ast::Type), + (ast::StateSpace, &ast::Type), + ) -> Result, TranslateError>, + >, +} + +struct FunctionPointerDetails { + dst: SpirvWord, + src: SpirvWord, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct SpirvWord(spirv::Word); + +impl From for SpirvWord { + fn from(value: spirv::Word) -> Self { + Self(value) + } +} +impl From for spirv::Word { + fn from(value: SpirvWord) -> Self { + value.0 + } +} + +impl ast::Operand for SpirvWord { + type Ident = Self; +} + +fn pred_map_variable Result>( + this: ast::PredAt, + f: &mut F, +) -> Result, TranslateError> { + let new_label = f(this.label)?; + Ok(ast::PredAt { + not: this.not, + label: new_label, + }) +} + +impl Result, Err> ast::VisitorMap for X { + fn visit( + &mut self, + args: T, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> U { + todo!() + } + + fn visit_ident( + &mut self, + args: ::Ident, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + ) -> ::Ident { + todo!() + } +} + +fn op_map_variable<'a, F: FnMut(&str) -> Result>( + this: ast::Instruction>, + f: &mut F, +) -> Result>, TranslateError> { + ast::visit_map(this , f) +} diff --git a/ptx/src/pass/normalize.rs b/ptx/src/pass/normalize.rs new file mode 100644 index 0000000..3832685 --- /dev/null +++ b/ptx/src/pass/normalize.rs @@ -0,0 +1,83 @@ +use super::*; +use ptx_parser as ast; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +fn run<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, +) -> Result, TranslateError> { + for s in func.iter() { + match s { + ast::Statement::Label(id) => { + id_defs.add_def(*id, None, false); + } + _ => (), + } + } + let mut result = Vec::new(); + for s in func { + expand_map_variables(id_defs, fn_defs, &mut result, s)?; + } + Ok(result) +} + +fn expand_map_variables<'a, 'b>( + id_defs: &mut FnStringIdResolver<'a, 'b>, + fn_defs: &GlobalFnDeclResolver<'a, 'b>, + result: &mut Vec, + s: ast::Statement>, +) -> Result<(), TranslateError> { + match s { + ast::Statement::Block(block) => { + id_defs.start_block(); + for s in block { + expand_map_variables(id_defs, fn_defs, result, s)?; + } + id_defs.end_block(); + } + ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), + ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( + p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) + .transpose()?, + op_map_variable(i, &mut |id| id_defs.get_id(id))?, + ))), + ast::Statement::Variable(var) => { + let var_type = var.var.v_type.clone(); + match var.count { + Some(count) => { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init.clone(), + })) + } + } + None => { + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init, + })); + } + } + } + }; + Ok(()) +} diff --git a/ptx/src/translate2.rs b/ptx/src/translate2.rs new file mode 100644 index 0000000..4ac5dea --- /dev/null +++ b/ptx/src/translate2.rs @@ -0,0 +1,60 @@ +use std::collections::HashMap; +use half::f16; +use ptx_parser as ast; + +fn to_ssa<'input, 'b>( + ptx_impl_imports: &'b mut HashMap>, + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + func_decl: Rc>>, + f_body: Option>>>, + tuning: Vec, + linkage: ast::LinkingDirective, +) -> Result, TranslateError> { + //deparamize_function_decl(&func_decl)?; + let f_body = match f_body { + Some(vec) => vec, + None => { + return Ok(Function { + func_decl: func_decl, + body: None, + globals: Vec::new(), + import_as: None, + tuning, + linkage, + }) + } + }; + let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; + /* + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &mut (*func_decl).borrow_mut(), + )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; + let expanded_statements = + insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.unmut(); + let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); + let (f_body, globals) = + extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + Ok(Function { + func_decl: func_decl, + globals: globals, + body: Some(f_body), + import_as: None, + tuning, + linkage, + }) + */ +} \ No newline at end of file diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index af3058b..a4df14f 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -4,6 +4,8 @@ version = "0.0.0" authors = ["Andrzej Janik "] edition = "2021" +[lib] + [dependencies] logos = "0.14" winnow = { version = "0.6.18" } @@ -11,3 +13,4 @@ ptx_parser_macros = { path = "../ptx_parser_macros" } thiserror = "1.0" bitflags = "1.2" rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 6cf1264..87a2f6b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -147,9 +147,9 @@ ptx_parser_macros::generate_instruction_type!( Call { data: CallDetails, arguments: CallArgs, - visit: arguments.visit(data, visitor), - visit_mut: arguments.visit_mut(data, visitor), - map: Instruction::Call{ arguments: arguments.map(&data, visitor), data } + visit: arguments.visit(data, visitor)?, + visit_mut: arguments.visit_mut(data, visitor)?, + map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } }, Cvt { data: CvtDetails, @@ -488,93 +488,185 @@ ptx_parser_macros::generate_instruction_type!( } ); -pub trait Visitor { - fn visit(&mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); - fn visit_ident(&self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool); +pub trait Visitor { + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; } -pub trait VisitorMut { - fn visit(&mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool); +impl, bool) -> Result<(), Err>> + Visitor for Fn +{ + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err> { + (self)(args, type_space, is_dst) + } + + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err> { + (self)(&T::from_ident(*args), type_space, is_dst) + } +} + +pub trait VisitorMut { + fn visit( + &mut self, + args: &mut T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result<(), Err>; fn visit_ident( &mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ); + ) -> Result<(), Err>; } -pub trait VisitorMap { - fn visit(&mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool) -> To; +pub trait VisitorMap { + fn visit( + &mut self, + args: From, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result; fn visit_ident( &mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, - ) -> To::Ident; + ) -> Result; } -trait VisitOperand { +impl< + T: Operand, + U: Operand, + Err, + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool) -> Result, + > VisitorMap for Fn +{ + fn visit( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + (self)(args, type_space, is_dst) + } + + fn visit_ident( + &mut self, + args: T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + ) -> Result { + let value: U = (self)(T::from_ident(args), type_space, is_dst)?; + Ok(value) + } +} + +trait VisitOperand { type Operand: Operand; #[allow(unused)] // Used by generated code - fn visit(&self, fn_: impl FnMut(&Self::Operand)); + fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; #[allow(unused)] // Used by generated code - fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)); + fn visit_mut( + &mut self, + fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err>; } -impl VisitOperand for T { +impl VisitOperand for T { type Operand = Self; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { fn_(self) } - fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { fn_(self) } } -impl VisitOperand for Option { +impl VisitOperand for Option { type Operand = T; - fn visit(&self, fn_: impl FnMut(&Self::Operand)) { - self.as_ref().map(fn_); + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) } - fn visit_mut(&mut self, fn_: impl FnMut(&mut Self::Operand)) { - self.as_mut().map(fn_); + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) } } -impl VisitOperand for Vec { +impl VisitOperand for Vec { type Operand = T; - fn visit(&self, mut fn_: impl FnMut(&Self::Operand)) { + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { for o in self { - fn_(o) + fn_(o)?; } + Ok(()) } - fn visit_mut(&mut self, mut fn_: impl FnMut(&mut Self::Operand)) { + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { for o in self { - fn_(o) + fn_(o)?; } + Ok(()) } } -trait MapOperand: Sized { +trait MapOperand: Sized { type Input; type Output; #[allow(unused)] // Used by generated code - fn map(self, fn_: impl FnOnce(Self::Input) -> U) -> Self::Output; + fn map( + self, + fn_: impl FnOnce(Self::Input) -> Result, + ) -> Result, Err>; } -impl MapOperand for T { +impl MapOperand for T { type Input = Self; type Output = U; - fn map(self, fn_: impl FnOnce(T) -> U) -> U { + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { fn_(self) } } -impl MapOperand for Option { +impl MapOperand for Option { type Input = T; type Output = Option; - fn map(self, fn_: impl FnOnce(T) -> U) -> Option { - self.map(|x| fn_(x)) + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { + self.map(|x| fn_(x)).transpose() } } @@ -715,10 +807,16 @@ pub enum ParsedOperand { impl Operand for ParsedOperand { type Ident = Ident; + + fn from_ident(ident: Self::Ident) -> Self { + ParsedOperand::Reg(ident) + } } -pub trait Operand { +pub trait Operand: Sized { type Ident: Copy; + + fn from_ident(ident: Self::Ident) -> Self; } #[derive(Copy, Clone)] @@ -1048,67 +1146,77 @@ pub struct CallArgs { impl CallArgs { #[allow(dead_code)] // Used by generated code - fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor) { + fn visit( + &self, + details: &CallDetails, + visitor: &mut impl Visitor, + ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true); + visitor.visit_ident(param, Some((type_, *space)), true)?; } - visitor.visit_ident(&self.func, None, false); + visitor.visit_ident(&self.func, None, false)?; for (param, (type_, space)) in self .input_arguments .iter() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true); + visitor.visit(param, Some((type_, *space)), true)?; } + Ok(()) } #[allow(dead_code)] // Used by generated code - fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut) { + fn visit_mut( + &mut self, + details: &CallDetails, + visitor: &mut impl VisitorMut, + ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter_mut() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true); + visitor.visit_ident(param, Some((type_, *space)), true)?; } - visitor.visit_ident(&mut self.func, None, false); + visitor.visit_ident(&mut self.func, None, false)?; for (param, (type_, space)) in self .input_arguments .iter_mut() .zip(details.input_arguments.iter()) { - visitor.visit(param, Some((type_, *space)), true); + visitor.visit(param, Some((type_, *space)), true)?; } + Ok(()) } #[allow(dead_code)] // Used by generated code - fn map( + fn map( self, details: &CallDetails, - visitor: &mut impl VisitorMap, - ) -> CallArgs { + visitor: &mut impl VisitorMap, + ) -> Result, Err> { let return_arguments = self .return_arguments .into_iter() .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true)) - .collect::>(); - let func = visitor.visit_ident(self.func, None, false); + .collect::, _>>()?; + let func = visitor.visit_ident(self.func, None, false)?; let input_arguments = self .input_arguments .into_iter() .zip(details.input_arguments.iter()) .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true)) - .collect::>(); - CallArgs { + .collect::, _>>()?; + Ok(CallArgs { return_arguments, func, input_arguments, - } + }) } } diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/lib.rs similarity index 98% rename from ptx_parser/src/main.rs rename to ptx_parser/src/lib.rs index 5db94f2..cfb8793 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/lib.rs @@ -1,8 +1,8 @@ +use derive_more::Display; use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::mem; use std::num::{ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; @@ -81,16 +81,16 @@ impl VectorPrefix { } } -struct PtxParserState<'input> { - errors: Vec, +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec, function_declarations: FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, } -impl<'input> PtxParserState<'input> { - fn new() -> Self { +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec) -> Self { Self { - errors: Vec::new(), + errors, function_declarations: FxHashMap::default(), } } @@ -115,7 +115,7 @@ impl<'input> PtxParserState<'input> { } } -impl<'input> Debug for PtxParserState<'input> { +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PtxParserState") .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ @@ -123,7 +123,7 @@ impl<'input> Debug for PtxParserState<'input> { } } -type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>; +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { any.verify_map(|t| { @@ -277,6 +277,18 @@ fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult(text: &'input str) -> Option> { + let lexer = Token::lexer(text); + let input = lexer.collect::, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + module.parse(parser).ok() +} + fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { ( version, @@ -818,6 +830,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Lexer(#[from] TokenError), + #[error("")] Todo, #[error("")] SyntaxError, @@ -1042,9 +1056,15 @@ fn empty_call<'input>( type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; +#[derive(Clone, PartialEq, Default, Debug, Display)] +pub struct TokenError; + +impl std::error::Error for TokenError {} + derive_parser!( #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] #[logos(skip r"\s+")] + #[logos(error = TokenError)] enum Token<'input> { #[token(",")] Comma, @@ -1134,6 +1154,7 @@ derive_parser!( pub enum StateSpace { Reg, Generic, + Sreg, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -2825,57 +2846,6 @@ derive_parser!( ); -fn main() { - use winnow::Parser; - - let lexer = Token::lexer( - " - .version 6.5 - .target sm_30 - .address_size 64 - - .const .align 8 .b32 constparams; - - .visible .entry const( - .param .u64 input, - .param .u64 output - ) - { - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .b16 temp1; - .reg .b16 temp2; - .reg .b16 temp3; - .reg .b16 temp4; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - ld.const.b16 temp1, [constparams]; - ld.const.b16 temp2, [constparams+2]; - ld.const.b16 temp3, [constparams+4]; - ld.const.b16 temp4, [constparams+6]; - st.u16 [out_addr], temp1; - st.u16 [out_addr+2], temp2; - st.u16 [out_addr+4], temp3; - st.u16 [out_addr+6], temp4; - ret; - } - - ", - ); - let tokens = lexer.clone().collect::>(); - println!("{:?}", &tokens); - let tokens = lexer.map(|t| t.unwrap()).collect::>(); - println!("{:?}", &tokens); - let stream = PtxParser { - input: &tokens[..], - state: PtxParserState::new(), - }; - let _module = module.parse(stream).unwrap(); - println!("{}", mem::size_of::()); -} - #[cfg(test)] mod tests { use super::target; diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index a2f8396..4502c95 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -1017,7 +1017,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro: input.emit_arg_types(&mut result); input.emit_instruction_type(&mut result); input.emit_visit(&mut result); - input.emit_visit_mut(&mut result); + //input.emit_visit_mut(&mut result); input.emit_visit_map(&mut result); result.into() } diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs index 4532964..3e53607 100644 --- a/ptx_parser_macros_impl/src/lib.rs +++ b/ptx_parser_macros_impl/src/lib.rs @@ -67,37 +67,29 @@ impl GenerateInstructionType { let visit_ref = kind.reference(); let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); let visit_fn = format_ident!("visit{}", kind.fn_suffix()); - let visit_slice_fn = format_ident!("visit{}_slice", kind.fn_suffix()); let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { ( - quote! { <#type_parameters, To: Operand> }, - quote! { <#short_parameters, To> }, - quote! { #type_name }, + quote! { <#type_parameters, To: Operand, Err> }, + quote! { <#short_parameters, To, Err> }, + quote! { std::result::Result<#type_name, Err> }, ) } else { ( - quote! { <#type_parameters> }, - quote! { <#short_parameters> }, - quote! { () }, + quote! { <#type_parameters, Err> }, + quote! { <#short_parameters, Err> }, + quote! { std::result::Result<(), Err> }, ) }; quote! { - fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { - match i { + pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + Ok(match i { #inner_tokens - } + }) } }.to_tokens(tokens); if kind == VisitKind::Map { return; } - quote! { - fn #visit_slice_fn #type_parameters (instructions: #visit_ref [#type_name<#short_parameters>], visitor: &mut impl #visitor_type #visitor_parameters) { - for i in instructions { - #visit_fn(i, visitor) - } - } - }.to_tokens(tokens); } } @@ -630,14 +622,14 @@ impl ArgumentField { quote! { { #type_space - visitor.visit_ident(&mut arguments.#name, type_space, #is_dst); + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst)?; } } } else { quote! { { #type_space - visitor.visit_ident(& arguments.#name, type_space, #is_dst); + visitor.visit_ident(& arguments.#name, type_space, #is_dst)?; } } } @@ -663,7 +655,7 @@ impl ArgumentField { }; quote! {{ #type_space - #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst)); + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst))?; }} } } @@ -701,11 +693,11 @@ impl ArgumentField { }; let map_call = if is_ident { quote! { - visitor.visit_ident(arguments.#name, type_space, #is_dst) + visitor.visit_ident(arguments.#name, type_space, #is_dst)? } } else { quote! { - MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst)) + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst))? } }; quote! {