diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3e62cb1..c7b9563 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,5 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::convert::TryInto; use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; @@ -34,107 +33,12 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(ScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(ScalarType, LdStateSpace) - } -} - -type VecU32 = Vec; - -sub_type! { - VariableLocalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - } -} - -impl TryFrom for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! { - VariableGlobalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} - // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; // even more interestingly this is legal, but only in .func (not in .entry): // .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(ScalarType), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { @@ -178,7 +82,7 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable), + Variable(Variable), Method(Function<'a, &'a str, Statement

>), } @@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> { }, } -pub type FnArgument = Variable; -pub type KernelArgument = Variable; +pub type FnArgument = Variable; +pub type KernelArgument = Variable; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, @@ -201,76 +105,6 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; -#[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), - FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), @@ -283,7 +117,7 @@ pub enum Type { pub enum PointerType { Scalar(ScalarType), Vector(ScalarType, u8), - Array(ScalarType, VecU32), + Array(ScalarType, Vec), // Instances of this variant are generated during stateful conversion Pointer(ScalarType, LdStateSpace), } @@ -366,51 +200,19 @@ pub enum Statement { } pub struct MultiVariable { - pub var: Variable, + pub var: Variable, pub count: Option, } #[derive(Clone)] -pub struct Variable { +pub struct Variable { pub align: Option, - pub v_type: T, + pub v_type: Type, + pub state_space: StateSpace, pub name: ID, pub array_init: Vec, } -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -419,6 +221,7 @@ pub enum StateSpace { Local, Shared, Param, + Generic, } pub struct PredAt { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 44852a2..dc439b7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -404,28 +404,29 @@ FnArguments: Vec> = { "(" > ")" => args }; -KernelInput: ast::Variable = { +KernelInput: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; ast::Variable { align, - v_type: ast::KernelArgumentType::Normal(v_type), + v_type, + state_space: ast::StateSpace::Param, name, array_init: Vec::new() } } } -FnInput: ast::Variable = { +FnInput: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Reg; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } }, => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Param; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } } } @@ -508,102 +509,109 @@ VariableParam: u32 = { "<" ">" => n } -Variable: ast::Variable = { +Variable: ast::Variable<&'input str> = { => { let (align, v_type, name) = v; - let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name, array_init: Vec::new()} + let state_space = ast::StateSpace::Reg; + ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} }, LocalVariable, => { let (align, array_init, v_type, name) = v; - let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name, array_init} + let state_space = ast::StateSpace::Param; + ast::Variable {align, v_type, state_space, name, array_init} }, SharedVariable, }; -RegVariable: (Option, ast::VariableRegType, &'input str) = { +RegVariable: (Option, ast::Type, &'input str) = { ".reg" > => { let (align, t, name) = var; - let v_type = ast::VariableRegType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name) }, ".reg" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableRegType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name) } } -LocalVariable: ast::Variable = { +LocalVariable: ast::Variable<&'input str> = { ".local" > => { let (align, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Scalar(t); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Vector(t, v_len); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" > =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Local; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableLocalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } -SharedVariable: ast::Variable = { +SharedVariable: ast::Variable<&'input str> = { ".shared" > => { let (align, t, name) = var; - let v_type = ast::VariableGlobalType::Scalar(t); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Scalar(t); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Vector(t, v_len); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" > =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Shared; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableGlobalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } - -ModuleVariable: ast::Variable = { +ModuleVariable: ast::Variable<&'input str> = { LinkingDirectives ".global" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + let state_space = ast::StateSpace::Global; + ast::Variable { align, v_type, state_space, name, array_init } }, LinkingDirectives ".shared" => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, > > =>? { let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { + let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init) } } ast::ArrayOrPointer::Pointer => { @@ -611,38 +619,38 @@ ModuleVariable: ast::Variable = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, array_init, v_type, name }) + Ok(ast::Variable{ align, v_type, state_space, name, array_init }) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { +ParamVariable: (Option, Vec, ast::Type, &'input str) = { ".param" > => { let (align, t, name) = var; - let v_type = ast::VariableParamType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, Vec::new(), v_type, name) }, ".param" > => { let (align, t, name, arr_or_ptr) = var; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableParamType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new()) } }; (align, array_init, v_type, name) } } -ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { +ParamDeclaration: (Option, ast::Type, &'input str) = { =>? { let (align, array_init, v_type, name) = var; if array_init.len() > 0 { @@ -653,15 +661,15 @@ ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { } } -GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input str, Vec) = { +GlobalVariableDefinitionNoArray: (Option, ast::Type, &'input str, Vec) = { > => { let (align, t, name) = scalar; - let v_type = ast::VariableGlobalType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name, Vec::new()) }, > => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name, Vec::new()) }, } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1f647bd..4ba5729 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { - Directive::Variable(var) => { - if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) = - var.v_type - { - extern_shared_decls.insert(var.name, p_type); - } + Directive::Variable(ast::Variable { + v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared), + state_space: ast::StateSpace::Shared, + name, + .. + }) => { + extern_shared_decls.insert(*name, p_type.clone()); } _ => {} } @@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>( let shared_id_param = new_id(); spirv_decl.input.push({ ast::Variable { + name: shared_id_param, align: None, v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), + ast::PointerType::Scalar(ast::ScalarType::B8), ast::LdStateSpace::Shared, ), + state_space: ast::StateSpace::Param, array_init: Vec::new(), - name: shared_id_param, } }); spirv_decl.uses_shared_mem = true; let shared_var_id = new_id(); let shared_var = ExpandedStatement::Variable(ast::Variable { - align: None, name: shared_var_id, - array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::ScalarType::B8, + align: None, + v_type: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::B8), ast::LdStateSpace::Shared, - )), + ), + state_space: ast::StateSpace::Reg, + array_init: Vec::new(), }); let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { arg: ast::Arg2St { @@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, shared_var_id: spirv::Word, @@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>( // because there's simply no way to pass shared ptr // without converting it to .b64 first if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { - call.param_list - .push((shared_id_param, ast::FnArgumentType::Shared)); + call.param_list.push(( + shared_id_param, + ast::Type::Scalar(ast::ScalarType::B8), + ast::StateSpace::Shared, + )); } result.push(Statement::Call(call)) } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some(typ) = extern_shared_decls.get(&id) { + if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) { if *typ == ast::ScalarType::B8 { return shared_var_id; } @@ -1067,7 +1073,7 @@ fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, - synthetic_globals: &[ast::Variable], + synthetic_globals: &[ast::Variable], func_decl: &SpirvMethodDecl<'a>, _denorm_information: &HashMap, HashMap>, call_map: &HashMap<&'a str, HashSet>, @@ -1204,9 +1210,9 @@ fn translate_directive<'input>( fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>, - var: ast::Variable, -) -> Result, TranslateError> { - let (space, var_type) = var.v_type.to_type(); + var: ast::Variable<&'a str>, +) -> Result, TranslateError> { + let (space, var_type) = (var.state_space, var.v_type.clone()); let mut is_variable = false; let var_type = match space { ast::StateSpace::Reg => { @@ -1226,10 +1232,12 @@ fn translate_variable<'a>( } } ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + ast::StateSpace::Generic => todo!(), }; Ok(ast::Variable { align: var.align, v_type: var.v_type, + state_space: var.state_space, name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), array_init: var.array_init, }) @@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>( false, ), v_type: a.v_type.clone(), + state_space: a.state_space, align: a.align, array_init: Vec::new(), }) @@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { - let is_variable = match a.v_type { - ast::FnArgumentType::Reg(_) => true, - _ => false, - }; - let var_type = a.v_type.to_func_type(); + let is_variable = a.state_space == ast::StateSpace::Reg; Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(var_type), is_variable), + name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable), v_type: a.v_type.clone(), + state_space: a.state_space, align: a.align, array_init: Vec::new(), }) @@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, -) -> ( - Vec, - Vec>, -) { +) -> (Vec, Vec>) { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { @@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Shared(_), + state_space: ast::StateSpace::Shared, .. }, ) @@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>( var @ ast::Variable { - v_type: ast::VariableType::Global(_), + state_space: ast::StateSpace::Global, .. }, ) => global.push(var), @@ -1592,10 +1595,10 @@ fn convert_to_typed_statements( let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args .into_iter() - .partition(|(_, arg_type)| arg_type.is_param()); + .partition(|(_, _, space)| *space == ast::StateSpace::Param); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::Operand::Reg(id), typ)) + .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( - typ, ptr_space, - )), + v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + v_type: ast::Type::Scalar(scalar_typ), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)), + ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), + ast::Type::Scalar(scalar_typ), + ast::StateSpace::Reg, ), ], }) @@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) @@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call( let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }], @@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call( vec![ ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + v_type: ast::Type::Scalar(typ.into()), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, - v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( - ast::ScalarType::U32, - )), + v_type: ast::Type::Scalar(ast::ScalarType::U32), + state_space: ast::StateSpace::Reg, name: id_defs.new_non_variable(None), array_init: Vec::new(), }, @@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call( Statement::Call(ResolvedCall { uniform: false, func: fn_id, - ret_params: vec![( - arg.dst, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), - )], + ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)], param_list: vec![ ( arg.src1, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src2, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ast::Type::Scalar(typ.into()), + ast::StateSpace::Reg, ), ( arg.src3, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ( arg.src4, - ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, ), ], }) @@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call( fn to_resolved_fn_args( params: Vec, - params_decl: &[ast::FnArgumentType], -) -> Vec<(T, ast::FnArgumentType)> { + params_decl: &[(ast::Type, ast::StateSpace)], +) -> Vec<(T, ast::Type, ast::StateSpace)> { params .into_iter() .zip(params_decl.iter()) - .map(|(id, typ)| (id, typ.clone())) + .map(|(id, (typ, space))| (id, typ.clone(), *space)) .collect::>() } @@ -2096,50 +2101,38 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, - ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, + _: &'a ast::MethodDecl<'b, spirv::Word>, fn_decl: &mut SpirvMethodDecl, ) -> Result, TranslateError> { - let is_func = match ast_fn_decl { - ast::MethodDecl::Func(..) => true, - ast::MethodDecl::Kernel { .. } => false, - }; let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { - match type_to_variable_type(&arg.v_type, is_func)? { - Some(var_type) => { - result.push(Statement::Variable(ast::Variable { - align: arg.align, - v_type: var_type, - name: arg.name, - array_init: arg.array_init.clone(), - })); - } - None => return Err(error_unreachable()), - } + result.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); } for spirv_arg in fn_decl.input.iter_mut() { - match type_to_variable_type(&spirv_arg.v_type, is_func)? { - Some(var_type) => { - let typ = spirv_arg.v_type.clone(); - let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::Variable(ast::Variable { - align: spirv_arg.align, - v_type: var_type, - name: spirv_arg.name, - array_init: spirv_arg.array_init.clone(), - })); - result.push(Statement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: spirv_arg.name, - src2: new_id, - }, - typ, - member_index: None, - })); - spirv_arg.name = new_id; - } - None => {} - } + let typ = spirv_arg.v_type.clone(); + let new_id = id_def.new_non_variable(Some(typ.clone())); + result.push(Statement::Variable(ast::Variable { + align: spirv_arg.align, + v_type: spirv_arg.v_type.clone(), + state_space: spirv_arg.state_space, + name: spirv_arg.name, + array_init: spirv_arg.array_init.clone(), + })); + result.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { + src1: spirv_arg.name, + src2: new_id, + }, + typ, + member_index: None, + })); + spirv_arg.name = new_id; } for s in func { match s { @@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } -fn type_to_variable_type( - t: &ast::Type, - is_func: bool, -) -> Result, TranslateError> { - Ok(match t { - ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), - ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - *len, - ))), - ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array( - (*typ) - .try_into() - .map_err(|_| TranslateError::MismatchedType)?, - len.clone(), - ))), - ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { - if is_func { - return Ok(None); - } - Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( - scalar_type - .clone() - .try_into() - .map_err(|_| error_unreachable())?, - (*space).try_into().map_err(|_| error_unreachable())?, - ))) - } - ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(error_unreachable()), - }) -} - trait Visitable: Sized { fn visit( self, @@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, }) => result.push(Statement::Variable(ast::Variable { align, v_type, + state_space, name, array_init, })), @@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - spirv_input: &[ast::Variable], - spirv_output: &[ast::Variable], + spirv_input: &[ast::Variable], + spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, @@ -2822,8 +2782,8 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { - [(id, typ)] => ( - map.get_or_add(builder, SpirvType::from(typ.to_func_type())), + [(id, typ, _)] => ( + map.get_or_add(builder, SpirvType::from(typ.clone())), Some(*id), ), [] => (map.void(), None), @@ -2832,7 +2792,7 @@ fn emit_function_body_ops( let arg_list = call .param_list .iter() - .map(|(id, _)| *id) + .map(|(id, _, _)| *id) .collect::>(); builder.function_call(result_type, result_id, call.func, arg_list)?; } @@ -3602,14 +3562,16 @@ fn vec_repr(t: T) -> Vec { fn emit_variable( builder: &mut dr::Builder, map: &mut TypeWordMap, - var: &ast::Variable, + var: &ast::Variable, ) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.v_type { - ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { (false, spirv::StorageClass::Function) } - ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), - ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => todo!(), + ast::StateSpace::Generic => todo!(), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( @@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>( ast::Statement::Variable(var) => { let mut var_type = ast::Type::from(var.var.v_type.clone()); let mut is_variable = false; - var_type = match var.var.v_type { - ast::VariableType::Reg(_) => { + var_type = match var.var.state_space { + ast::StateSpace::Reg => { is_variable = true; var_type } - ast::VariableType::Shared(_) => { + ast::StateSpace::Shared => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; @@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>( var_type.param_pointer_to(ast::LdStateSpace::Shared)? } } - ast::VariableType::Global(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Global)? - } - ast::VariableType::Param(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Param)? - } - ast::VariableType::Local(_) => { - var_type.param_pointer_to(ast::LdStateSpace::Local)? - } + ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, + ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, + ast::StateSpace::Const => todo!(), + ast::StateSpace::Generic => todo!(), }; match var.count { Some(count) => { @@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init.clone(), })) @@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), + state_space: var.var.state_space, name: new_id, array_init: var.var.array_init, })); @@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>( align: None, name: new_id, array_init: Vec::new(), - v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::ScalarType::U8, + v_type: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, - )), + ), + state_space: ast::StateSpace::Reg, })); remapped_ids.insert(reg, new_id); } @@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> { } pub struct FnDecl { - ret_vals: Vec, - params: Vec, + ret_vals: Vec<(ast::Type, ast::StateSpace)>, + params: Vec<(ast::Type, ast::StateSpace)>, } impl<'a> GlobalStringIdResolver<'a> { @@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert( name_id, FnDecl { - ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), - params: params_ids.iter().map(|p| p.v_type.clone()).collect(), + ret_vals: ret_params_ids + .iter() + .map(|p| (p.v_type.clone(), p.state_space)) + .collect(), + params: params_ids + .iter() + .map(|p| (p.v_type.clone(), p.state_space)) + .collect(), }, ); ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) @@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> { enum Statement { Label(u32), - Variable(ast::Variable), + Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), @@ -5352,16 +5319,17 @@ impl ExpandedStatement { Statement::StoreVar(details) } Statement::Call(mut call) => { - for (id, typ) in call.ret_params.iter_mut() { - let is_dst = match typ { - ast::FnArgumentType::Reg(_) => true, - ast::FnArgumentType::Param(_) => false, - ast::FnArgumentType::Shared => false, + for (id, _, space) in call.ret_params.iter_mut() { + let is_dst = match space { + ast::StateSpace::Reg => true, + ast::StateSpace::Param => false, + ast::StateSpace::Shared => false, + _ => todo!(), }; *id = f(*id, is_dst); } call.func = f(call.func, false); - for (id, _) in call.param_list.iter_mut() { + for (id, _, _) in call.param_list.iter_mut() { *id = f(*id, false); } Statement::Call(call) @@ -5502,9 +5470,9 @@ impl, U: ArgParamsEx> Visitab struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, + pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>, pub func: P::Id, - pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, + pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>, } impl ResolvedCall { @@ -5526,16 +5494,16 @@ impl> ResolvedCall { let ret_params = self .ret_params .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.id( ArgumentDescriptor { op: id, - is_dst: !typ.is_param(), - sema: typ.semantics(), + is_dst: space != ast::StateSpace::Param, + sema: space.semantics(), }, - Some(&typ.to_func_type()), + Some(&typ), )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; let func = visitor.id( @@ -5549,16 +5517,16 @@ impl> ResolvedCall { let param_list = self .param_list .into_iter() - .map::, _>(|(id, typ)| { + .map::, _>(|(id, typ, space)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, - sema: typ.semantics(), + sema: space.semantics(), }, - &typ.to_func_type(), + &typ, )?; - Ok((new_id, typ)) + Ok((new_id, typ, space)) }) .collect::, _>>()?; Ok(ResolvedCall { @@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams { } enum Directive<'input> { - Variable(ast::Variable), + Variable(ast::Variable), Method(Function<'input>), } struct Function<'input> { pub func_decl: ast::MethodDecl<'input, spirv::Word>, pub spirv_decl: SpirvMethodDecl<'input>, - pub globals: Vec>, + pub globals: Vec>, pub body: Option>, import_as: Option, tuning: Vec, @@ -7300,16 +7268,6 @@ impl ast::LdStateSpace { } } -impl From for ast::VariableType { - fn from(t: ast::FnArgumentType) -> Self { - match t { - ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), - ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), - ast::FnArgumentType::Shared => todo!(), - } - } -} - impl ast::Operand { fn underlying(&self) -> Option<&T> { match self { @@ -7362,12 +7320,13 @@ impl ast::AtomSemantics { } } -impl ast::FnArgumentType { - fn semantics(&self) -> ArgumentSemantics { +impl ast::StateSpace { + fn semantics(self) -> ArgumentSemantics { match self { - ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default, - ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer, - ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer, + ast::StateSpace::Reg => ArgumentSemantics::Default, + ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, + ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer, + _ => todo!(), } } } @@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> { } struct SpirvMethodDecl<'input> { - input: Vec>, - output: Vec>, + input: Vec>, + output: Vec>, name: MethodName<'input>, uses_shared_mem: bool, } @@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> { ast::MethodDecl::Kernel { in_args, .. } => { let spirv_input = in_args .iter() - .map(|var| { - let v_type = match &var.v_type { - ast::KernelArgumentType::Normal(t) => { - ast::FnArgumentType::Param(t.clone()) - } - ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared, - }; - ast::Variable { - name: var.name, - align: var.align, - v_type: v_type.to_kernel_type(), - array_init: var.array_init.clone(), - } + .map(|var| ast::Variable { + name: var.name, + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + array_init: var.array_init.clone(), }) .collect(); (spirv_input, Vec::new()) } ast::MethodDecl::Func(out_args, _, in_args) => { - let (param_output, non_param_output): (Vec<_>, Vec<_>) = - out_args.iter().partition(|var| var.v_type.is_param()); + let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args + .iter() + .partition(|var| var.state_space == ast::StateSpace::Param); let spirv_output = non_param_output .into_iter() .cloned() .map(|var| ast::Variable { name: var.name, align: var.align, - v_type: var.v_type.to_func_type(), + v_type: var.v_type.clone(), + state_space: var.state_space, array_init: var.array_init.clone(), }) .collect(); @@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> { .map(|var| ast::Variable { name: var.name, align: var.align, - v_type: var.v_type.to_func_type(), + v_type: var.v_type.clone(), + state_space: var.state_space, array_init: var.array_init.clone(), }) .collect();