diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index bc2fa4c..3a7cf98 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -34,43 +34,6 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_enum { - ($name:ident { $($variant:ident),+ $(,)? }) => { - sub_enum!{ $name : ScalarType { $($variant),+ } } - }; - ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub enum $name { - $( - $variant, - )+ - } - - impl From<$name> for $base_type { - fn from(t: $name) -> $base_type { - match t { - $( - $name::$variant => $base_type::$variant, - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $name { - type Error = (); - - fn try_from(t: $base_type) -> Result { - match t { - $( - $base_type::$variant => Ok($name::$variant), - )+ - _ => Err(()), - } - } - } - }; -} - macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { sub_type! { $type_name : Type { @@ -118,12 +81,12 @@ macro_rules! sub_type { sub_type! { VariableRegType { Scalar(ScalarType), - Vector(SizedScalarType, u8), + Vector(ScalarType, u8), // Array type is used when emiting SSA statements at the start of a method Array(ScalarType, VecU32), // Pointer variant is used when passing around SLM pointer between // function calls for dynamic SLM - Pointer(SizedScalarType, PointerStateSpace) + Pointer(ScalarType, LdStateSpace) } } @@ -131,9 +94,9 @@ type VecU32 = Vec; sub_type! { VariableLocalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), + Scalar(ScalarType), + Vector(ScalarType, u8), + Array(ScalarType, VecU32), } } @@ -152,10 +115,10 @@ impl TryFrom for VariableLocalType { sub_type! { VariableGlobalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), + Scalar(ScalarType), + Vector(ScalarType, u8), + Array(ScalarType, VecU32), + Pointer(ScalarType, LdStateSpace), } } @@ -167,49 +130,12 @@ sub_type! { // .param .b32 foobar[] sub_type! { VariableParamType { - Scalar(LdStScalarType), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), + Scalar(ScalarType), + Array(ScalarType, VecU32), + Pointer(ScalarType, LdStateSpace), } } -sub_enum!(SizedScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F16x2, - F32, - F64, -}); - -sub_enum!(LdStScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, -}); - #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { SyncAligned, @@ -345,16 +271,6 @@ impl FnArgumentType { } } -sub_enum!( - PointerStateSpace : LdStateSpace { - Generic, - Global, - Const, - Shared, - Param, - } -); - #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), @@ -371,18 +287,18 @@ pub enum PointerType { Pointer(ScalarType, LdStateSpace), } -impl From for PointerType { - fn from(t: SizedScalarType) -> Self { +impl From for PointerType { + fn from(t: ScalarType) -> Self { PointerType::Scalar(t.into()) } } -impl TryFrom for SizedScalarType { +impl TryFrom for ScalarType { type Error = (); fn try_from(value: PointerType) -> Result { match value { - PointerType::Scalar(t) => Ok(t.try_into()?), + PointerType::Scalar(t) => Ok(t), PointerType::Vector(_, _) => Err(()), PointerType::Array(_, _) => Err(()), PointerType::Pointer(_, _) => Err(()), @@ -685,8 +601,8 @@ pub struct LdDetails { sub_type! { LdStType { - Scalar(LdStScalarType), - Vector(LdStScalarType, u8), + Scalar(ScalarType), + Vector(ScalarType, u8), // Used in generated code Pointer(PointerType, LdStateSpace), } @@ -1135,7 +1051,7 @@ pub struct NegDetails { } impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result, PtxError> { + pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result, PtxError> { self.normalize_dimensions(dimensions)?; let sizeof_t = ScalarType::from(typ).size_of() as usize; let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); @@ -1166,7 +1082,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy( &self, - t: SizedScalarType, + t: ScalarType, size_of_t: usize, dimensions: &[u32], result: &mut [u8], @@ -1206,47 +1122,48 @@ impl<'a> NumsOrArrays<'a> { } fn parse_and_copy_single( - t: SizedScalarType, + t: ScalarType, idx: usize, str_val: &str, radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match t { - SizedScalarType::B8 | SizedScalarType::U8 => { + ScalarType::B8 | ScalarType::U8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B16 | SizedScalarType::U16 => { + ScalarType::B16 | ScalarType::U16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B32 | SizedScalarType::U32 => { + ScalarType::B32 | ScalarType::U32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::B64 | SizedScalarType::U64 => { + ScalarType::B64 | ScalarType::U64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S8 => { + ScalarType::S8 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S16 => { + ScalarType::S16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S32 => { + ScalarType::S32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::S64 => { + ScalarType::S64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16 => { + ScalarType::F16 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F16x2 => todo!(), - SizedScalarType::F32 => { + ScalarType::F16x2 => todo!(), + ScalarType::F32 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } - SizedScalarType::F64 => { + ScalarType::F64 => { Self::parse_and_copy_single_t::(idx, str_val, radix, output)?; } + ScalarType::Pred => todo!() } Ok(()) } @@ -1334,13 +1251,13 @@ mod tests { #[test] fn array_fails_multiple_0_dmiensions() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); } #[test] fn array_fails_on_empty() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); } #[test] @@ -1352,7 +1269,7 @@ mod tests { let mut dimensions = vec![0u32, 2]; assert_eq!( vec![1u8, 2, 3, 4], - inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() ); assert_eq!(dimensions, vec![2u32, 2]); } @@ -1364,7 +1281,7 @@ mod tests { NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } #[test] @@ -1374,6 +1291,6 @@ mod tests { NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 41c1d73..7bd9c4f 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -611,9 +611,9 @@ ModuleVariable: ast::Variable = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new()) + (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new()) + (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) } } }; @@ -635,7 +635,7 @@ ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { (ast::VariableParamType::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new()) + (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) } }; (align, array_init, v_type, name) @@ -667,42 +667,42 @@ GlobalVariableDefinitionNoArray: (Option, ast::VariableGlobalType, &'input } #[inline] -SizedScalarType: ast::SizedScalarType = { - ".b8" => ast::SizedScalarType::B8, - ".b16" => ast::SizedScalarType::B16, - ".b32" => ast::SizedScalarType::B32, - ".b64" => ast::SizedScalarType::B64, - ".u8" => ast::SizedScalarType::U8, - ".u16" => ast::SizedScalarType::U16, - ".u32" => ast::SizedScalarType::U32, - ".u64" => ast::SizedScalarType::U64, - ".s8" => ast::SizedScalarType::S8, - ".s16" => ast::SizedScalarType::S16, - ".s32" => ast::SizedScalarType::S32, - ".s64" => ast::SizedScalarType::S64, - ".f16" => ast::SizedScalarType::F16, - ".f16x2" => ast::SizedScalarType::F16x2, - ".f32" => ast::SizedScalarType::F32, - ".f64" => ast::SizedScalarType::F64, +SizedScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } #[inline] -LdStScalarType: ast::LdStScalarType = { - ".b8" => ast::LdStScalarType::B8, - ".b16" => ast::LdStScalarType::B16, - ".b32" => ast::LdStScalarType::B32, - ".b64" => ast::LdStScalarType::B64, - ".u8" => ast::LdStScalarType::U8, - ".u16" => ast::LdStScalarType::U16, - ".u32" => ast::LdStScalarType::U32, - ".u64" => ast::LdStScalarType::U64, - ".s8" => ast::LdStScalarType::S8, - ".s16" => ast::LdStScalarType::S16, - ".s32" => ast::LdStScalarType::S32, - ".s64" => ast::LdStScalarType::S64, - ".f16" => ast::LdStScalarType::F16, - ".f32" => ast::LdStScalarType::F32, - ".f64" => ast::LdStScalarType::F64, +LdStScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } Instruction: ast::Instruction> = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 51b1dc6..7eec085 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -97,18 +97,6 @@ impl ast::Type { } } -impl Into for ast::PointerStateSpace { - fn into(self) -> spirv::StorageClass { - match self { - ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::PointerStateSpace::Param => spirv::StorageClass::Function, - ast::PointerStateSpace::Generic => spirv::StorageClass::Generic, - } - } -} - impl From for SpirvType { fn from(t: ast::ScalarType) -> Self { SpirvType::Base(t.into()) @@ -824,8 +812,8 @@ fn convert_dynamic_shared_memory_usage<'input>( name: shared_var_id, array_init: Vec::new(), v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::B8, - ast::PointerStateSpace::Shared, + ast::ScalarType::B8, + ast::LdStateSpace::Shared, )), }); let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { @@ -863,7 +851,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>( result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, shared_var_id: spirv::Word, @@ -884,7 +872,7 @@ fn replace_uses_of_shared_memory<'a>( statement => { let new_statement = statement.map_id(&mut |id, _| { if let Some(typ) = extern_shared_decls.get(&id) { - if *typ == ast::SizedScalarType::B8 { + if *typ == ast::ScalarType::B8 { return shared_var_id; } let replacement_id = new_id(); @@ -1505,7 +1493,7 @@ fn extract_globals<'input, 'b>( d, a, "inc", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1527,7 +1515,7 @@ fn extract_globals<'input, 'b>( d, a, "dec", - ast::SizedScalarType::U32, + ast::ScalarType::U32, )); } Statement::Instruction(ast::Instruction::Atom( @@ -1553,8 +1541,8 @@ fn extract_globals<'input, 'b>( space, }; let (op, typ) = match typ { - ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32), - ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64), + ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32), + ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64), _ => unreachable!(), }; local.push(to_ptx_impl_atomic_call( @@ -1734,7 +1722,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails, arg: ast::Arg3, op: &'static str, - typ: ast::SizedScalarType, + typ: ast::ScalarType, ) -> ExpandedStatement { let semantics = ptx_semantics_name(details.semantics); let scope = ptx_scope_name(details.scope); @@ -1745,9 +1733,9 @@ fn to_ptx_impl_atomic_call( ); // TODO: extract to a function let ptr_space = match details.space { - ast::AtomSpace::Generic => ast::PointerStateSpace::Generic, - ast::AtomSpace::Global => ast::PointerStateSpace::Global, - ast::AtomSpace::Shared => ast::PointerStateSpace::Shared, + ast::AtomSpace::Generic => ast::LdStateSpace::Generic, + ast::AtomSpace::Global => ast::LdStateSpace::Global, + ast::AtomSpace::Shared => ast::LdStateSpace::Shared, }; let scalar_typ = ast::ScalarType::from(typ); let fn_id = match ptx_impl_imports.entry(fn_name) { @@ -4565,7 +4553,7 @@ fn convert_to_stateful_memory_access<'a>( Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::U64), + typ: ast::LdStType::Scalar(ast::ScalarType::U64), .. }, arg, @@ -4573,7 +4561,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::S64), + typ: ast::LdStType::Scalar(ast::ScalarType::S64), .. }, arg, @@ -4581,7 +4569,7 @@ fn convert_to_stateful_memory_access<'a>( | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, - typ: ast::LdStType::Scalar(ast::LdStScalarType::B64), + typ: ast::LdStType::Scalar(ast::ScalarType::B64), .. }, arg, @@ -4672,8 +4660,8 @@ fn convert_to_stateful_memory_access<'a>( name: new_id, array_init: Vec::new(), v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( - ast::SizedScalarType::U8, - ast::PointerStateSpace::Global, + ast::ScalarType::U8, + ast::LdStateSpace::Global, )), })); remapped_ids.insert(reg, new_id);