diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index a0bb023..5432207 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -82,8 +82,8 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable), - Method(Function<'a, &'a str, Statement

>), + Variable(LinkingDirective, Variable), + Method(LinkingDirective, Function<'a, &'a str, Statement

>), } #[derive(Hash, PartialEq, Eq, Copy, Clone)] @@ -96,7 +96,7 @@ pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec>, pub name: MethodName<'input, ID>, pub input_arguments: Vec>, - pub shared_mem: Option>, + pub shared_mem: Option, } pub struct Function<'a, ID, S> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index e8370cd..b697317 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -343,10 +343,16 @@ TargetSpecifier = { Directive: Option>> = { AddressSize => None, - => Some(ast::Directive::Method(f)), + => { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + }, File => None, Section => None, - ";" => Some(ast::Directive::Variable(v)), + ";" => { + let (linking, var) = v; + Some(ast::Directive::Variable(linking, var)) + }, ! => { let err = <>; errors.push(err.error); @@ -358,11 +364,13 @@ AddressSize = { ".address_size" U8Num }; -Function: ast::Function<'input, &'input str, ast::Statement>> = { - LinkingDirectives +Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement>>) = { + - => ast::Function{<>} + => { + (linking, ast::Function{func_directive, tuning, body}) + } }; LinkingDirective: ast::LinkingDirective = { @@ -598,18 +606,18 @@ SharedVariable: ast::Variable<&'input str> = { } } -ModuleVariable: ast::Variable<&'input str> = { - LinkingDirectives ".global" => { +ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { + ".global" => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Global; - ast::Variable { align, v_type, state_space, name, array_init } + (linking, ast::Variable { align, v_type, state_space, name, array_init }) }, - LinkingDirectives ".shared" => { + ".shared" => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Shared; - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } + (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }) }, - > > =>? { + > > =>? { let (align, t, name, arr_or_ptr) = var; let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { @@ -620,17 +628,17 @@ ModuleVariable: ast::Variable<&'input str> = { } } ast::ArrayOrPointer::Pointer => { - if !ldirs.contains(ast::LinkingDirective::EXTERN) { + if !linking.contains(ast::LinkingDirective::EXTERN) { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, v_type, state_space, name, array_init }) + Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init })) } } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a0b5077..6b9dcfb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -172,14 +172,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32)) } SpirvType::Array(typ, array_dimensions) => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self.get_or_add_spirv_scalar(b, typ); let len_const = b.constant_u32(u32_type, None, len); (base, len_const) } array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); let base = self .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); let len_const = b.constant_u32(u32_type, None, array_dimensions[0]); @@ -221,7 +225,7 @@ impl TypeWordMap { fn get_or_add_fn( &mut self, b: &mut dr::Builder, - in_params: impl ExactSizeIterator, + in_params: impl Iterator, mut out_params: impl ExactSizeIterator, ) -> (spirv::Word, spirv::Word) { let (out_args, out_spirv_type) = if out_params.len() == 0 { @@ -233,6 +237,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key), ) } else { + // TODO: support multiple return values todo!() }; ( @@ -436,7 +441,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( let empty_body = Vec::new(); for d in directives.iter() { match d { - Directive::Variable(var) => { + Directive::Variable(_, var) => { emit_variable(builder, map, &var)?; } Directive::Method(f) => { @@ -699,7 +704,6 @@ fn multi_hash_map_append(m: &mut MultiHashMap, transformation has a semantical meaning - we emit additional "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") */ -/* fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -707,13 +711,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { - Directive::Variable(ast::Variable { - v_type: ast::Type::Pointer(p_type), - state_space: ast::StateSpace::Shared, - name, - .. - }) => { - extern_shared_decls.insert(*name, p_type.clone()); + Directive::Variable( + linking, + ast::Variable { + v_type: ast::Type::Array(p_type, dims), + state_space: ast::StateSpace::Shared, + name, + .. + }, + ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => { + extern_shared_decls.insert(*name, *p_type); } _ => {} } @@ -732,14 +739,13 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) => { let call_key = (*func_decl).borrow().name; let statements = statements .into_iter() .map(|statement| match statement { Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call.func, call_key); + multi_hash_map_append(&mut directly_called_by, call.name, call_key); Statement::Call(call) } statement => statement.map_id(&mut |id, _| { @@ -756,7 +762,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) } directive => directive, @@ -775,7 +780,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }) => { if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) { return Directive::Method(Function { @@ -784,21 +788,12 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem, }); } let shared_id_param = new_id(); { let mut func_decl = (*func_decl).borrow_mut(); - func_decl.input_arguments.push({ - ast::Variable { - name: shared_id_param, - align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()), - state_space: ast::StateSpace::Shared, - array_init: Vec::new(), - } - }); + func_decl.shared_mem = Some(shared_id_param); } let statements = replace_uses_of_shared_memory( new_id, @@ -813,7 +808,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), import_as, tuning, - uses_shared_mem: true, }) } directive => directive, @@ -835,8 +829,8 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) { - call.param_list.push(( + if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) { + call.input_arguments.push(( shared_id_param, ast::Type::Scalar(ast::ScalarType::B8), ast::StateSpace::Shared, @@ -854,13 +848,11 @@ fn replace_uses_of_shared_memory<'a>( result.push(Statement::Conversion(ImplicitConversion { src: shared_id_param, dst: replacement_id, - from_type: ast::Type::Pointer(ast::ScalarType::B8), + from_type: ast::Type::Scalar(ast::ScalarType::B8), from_space: ast::StateSpace::Shared, - to_type: ast::Type::Pointer((*scalar_type).into()), + to_type: ast::Type::Scalar(*scalar_type), to_space: ast::StateSpace::Shared, - kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_ - dst_ + kind: ConversionKind::PtrToPtr, })); replacement_id } else { @@ -912,7 +904,6 @@ fn get_callers_of_extern_shared_single<'a>( } } } -*/ type DenormCountMap = HashMap; @@ -948,7 +939,7 @@ fn compute_denorm_information<'input>( let mut denorm_methods = HashMap::new(); for directive in module { match directive { - Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {} + Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} Directive::Method(Function { func_decl, body: Some(statements), @@ -1158,14 +1149,17 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { - ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable { - align: var.align, - v_type: var.v_type.clone(), - state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), - array_init: var.array_init, - })), - ast::Directive::Method(f) => { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(_, f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) @@ -2576,7 +2570,7 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: impl ExactSizeIterator, + spirv_input: impl Iterator, spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( @@ -5597,7 +5591,7 @@ impl ast::ArgParams for ExpandedArgParams { impl ArgParamsEx for ExpandedArgParams {} enum Directive<'input> { - Variable(ast::Variable), + Variable(ast::LinkingDirective, ast::Variable), Method(Function<'input>), } @@ -7582,19 +7576,28 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { } impl<'a> ast::MethodDeclaration<'a, spirv::Word> { - fn effective_input_arguments( - &self, - ) -> impl ExactSizeIterator + '_ { + fn effective_input_arguments(&self) -> impl Iterator + '_ { let is_kernel = self.name.is_kernel(); - self.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) + self.input_arguments + .iter() + .map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv()); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) + .chain(self.shared_mem.iter().map(|id| { + ( + *id, + SpirvType::Pointer( + Box::new(SpirvType::Base(SpirvScalarKey::B8)), + spirv::StorageClass::Workgroup, + ), + ) + })) } }