From 1238796dfd58adad813cc8411d299ef7457fd22a Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 11 Sep 2020 00:40:13 +0200 Subject: [PATCH] Be more precise about types admitted in register definitions and method arguments --- ptx/src/ast.rs | 256 +++++++++++----- ptx/src/ptx.lalrpop | 205 ++++++++----- ptx/src/test/mod.rs | 26 +- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/vector.ptx | 44 +++ ptx/src/test/spirv_run/vector.spvtxt | 46 +++ ptx/src/translate.rs | 420 ++++++++++++++------------- 7 files changed, 647 insertions(+), 351 deletions(-) create mode 100644 ptx/src/test/spirv_run/vector.ptx create mode 100644 ptx/src/test/spirv_run/vector.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index cfbdad5..2fed9ff 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -12,9 +12,117 @@ quick_error! { SyntaxError {} NonF32Ftz {} WrongArrayType {} + WrongVectorElement {} + MultiArrayVariable {} } } +macro_rules! sub_scalar_type { + ($name:ident { $($variant:ident),+ $(,)? }) => { + #[derive(PartialEq, Eq, Clone, Copy)] + pub enum $name { + $( + $variant, + )+ + } + + impl From<$name> for ScalarType { + fn from(t: $name) -> ScalarType { + match t { + $( + $name::$variant => ScalarType::$variant, + )+ + } + } + } + }; +} + +macro_rules! sub_type { + ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { + #[derive(PartialEq, Eq, Clone, Copy)] + pub enum $type_name { + $( + $variant ($($field_type),+), + )+ + } + + impl From<$type_name> for Type { + #[allow(non_snake_case)] + fn from(t: $type_name) -> Type { + match t { + $( + $type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+), + )+ + } + } + } + }; +} + +sub_type! { + VariableRegType { + Scalar(ScalarType), + Vector(SizedScalarType, u8), + } +} + +sub_type! { + VariableLocalType { + Scalar(SizedScalarType), + Vector(SizedScalarType, u8), + Array(SizedScalarType, u32), + } +} + +// For some weird reson this is illegal: +// .param .f16x2 foobar; +// but this is legal: +// .param .f16x2 foobar[1]; +sub_type! { + VariableParamType { + Scalar(ParamScalarType), + Array(SizedScalarType, u32), + } +} + +sub_scalar_type!(SizedScalarType { + B8, + B16, + B32, + B64, + U8, + U16, + U32, + U64, + S8, + S16, + S32, + S64, + F16, + F16x2, + F32, + F64, +}); + +sub_scalar_type!(ParamScalarType { + B8, + B16, + B32, + B64, + U8, + U16, + U32, + U64, + S8, + S16, + S32, + S64, + F16, + F32, + F64, +}); + pub trait UnwrapWithVec { fn unwrap_with(self, errs: &mut Vec) -> To; } @@ -56,6 +164,9 @@ pub enum MethodDecl<'a, P: ArgParams> { Kernel(&'a str, Vec>), } +pub type FnArgument = Variable; +pub type KernelArgument = Variable; + pub struct Function<'a, P: ArgParams, S> { pub func_directive: MethodDecl<'a, P>, pub body: Option>, @@ -63,43 +174,28 @@ pub struct Function<'a, P: ArgParams, S> { pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement>>; -pub struct FnArgument { - pub base: KernelArgument

, - pub state_space: FnArgStateSpace, +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum FnArgumentType { + Reg(VariableRegType), + Param(VariableParamType), } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum FnArgStateSpace { - Reg, - Param, -} - -#[derive(Default, Copy, Clone)] -pub struct KernelArgument { - pub name: P::ID, - pub a_type: ScalarType, - // TODO: turn length into part of type definition - pub length: u32, +impl From for Type { + fn from(t: FnArgumentType) -> Self { + match t { + FnArgumentType::Reg(x) => x.into(), + FnArgumentType::Param(x) => x.into(), + } + } } #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum Type { Scalar(ScalarType), - ExtendedScalar(ExtendedScalarType), + Vector(ScalarType, u8), Array(ScalarType, u32), } -impl From for Type { - fn from(t: FloatType) -> Self { - match t { - FloatType::F16 => Type::Scalar(ScalarType::F16), - FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2), - FloatType::F32 => Type::Scalar(ScalarType::F32), - FloatType::F64 => Type::Scalar(ScalarType::F64), - } - } -} - #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum ScalarType { B8, @@ -117,25 +213,11 @@ pub enum ScalarType { F16, F32, F64, + F16x2, + Pred, } -impl From for ScalarType { - fn from(t: IntType) -> Self { - match t { - IntType::S8 => ScalarType::S8, - IntType::S16 => ScalarType::S16, - IntType::S32 => ScalarType::S32, - IntType::S64 => ScalarType::S64, - IntType::U8 => ScalarType::U8, - IntType::U16 => ScalarType::U16, - IntType::U32 => ScalarType::U32, - IntType::U64 => ScalarType::U64, - } - } -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum IntType { +sub_scalar_type!(IntType { U8, U16, U32, @@ -143,8 +225,8 @@ pub enum IntType { S8, S16, S32, - S64, -} + S64 +}); impl IntType { pub fn is_signed(self) -> bool { @@ -168,19 +250,12 @@ impl IntType { } } -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum FloatType { +sub_scalar_type!(FloatType { F16, F16x2, F32, - F64, -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum ExtendedScalarType { - F16x2, - Pred, -} + F64 +}); impl Default for ScalarType { fn default() -> Self { @@ -190,19 +265,39 @@ impl Default for ScalarType { pub enum Statement { Label(P::ID), - Variable(Variable

), + Variable(MultiVariable

), Instruction(Option>, Instruction

), Block(Vec>), } -pub struct Variable { - pub space: StateSpace, - pub align: Option, - pub v_type: Type, - pub name: P::ID, +pub struct MultiVariable { + pub var: Variable, pub count: Option, } +pub struct Variable { + pub align: Option, + pub v_type: T, + pub name: P::ID, +} + +#[derive(Eq, PartialEq, Copy, Clone)] +pub enum VariableType { + Reg(VariableRegType), + Local(VariableLocalType), + Param(VariableParamType), +} + +impl From for Type { + fn from(t: VariableType) -> Self { + match t { + VariableType::Reg(t) => t.into(), + VariableType::Local(t) => t.into(), + VariableType::Param(t) => t.into(), + } + } +} + #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -322,7 +417,7 @@ pub enum CallOperand { pub enum MovOperand { Op(Operand), - Vec(String, String), + Vec(ID, u8), } pub enum VectorPrefix { @@ -334,7 +429,7 @@ pub struct LdData { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub vector: Option, + pub vector: Option, pub typ: ScalarType, } @@ -376,6 +471,37 @@ pub struct MovData { pub typ: Type, } +sub_scalar_type!(MovScalarType { + B16, + B32, + B64, + U16, + U32, + U64, + S16, + S32, + S64, + F32, + F64, + Pred, +}); + +enum MovType { + Scalar(MovScalarType), + Vector(MovScalarType, u8), + Array(MovScalarType, u32), +} + +impl From for Type { + fn from(t: MovType) -> Self { + match t { + MovType::Scalar(t) => Type::Scalar(t.into()), + MovType::Vector(t, len) => Type::Vector(t.into(), len), + MovType::Array(t, len) => Type::Array(t.into(), len), + } + } +} + pub enum MulDetails { Int(MulIntDesc), Float(MulFloatDesc), @@ -587,7 +713,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub vector: Option, + pub vector: Option, pub typ: ScalarType, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 53bb296..bb77b62 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -21,6 +21,7 @@ match { "@", "[", "]", "{", "}", + "<", ">", "|", ".acquire", ".address_size", @@ -133,8 +134,6 @@ match { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID, r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID, -} else { - r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID, } ExtendedID : &'input str = { @@ -214,7 +213,9 @@ LinkingDirective = { MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = { ".entry" => ast::MethodDecl::Kernel(name, params), - ".func" => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) + ".func" => { + ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) + } }; KernelArguments: Vec>> = { @@ -225,32 +226,25 @@ FnArguments: Vec>> = { "(" > ")" => args }; -FnInput: ast::FnArgument> = { - ".reg" <_type:ScalarType> => { - ast::FnArgument { - base: ast::KernelArgument {a_type: _type, name: name, length: 1 }, - state_space: ast::FnArgStateSpace::Reg, - } - }, - => { - ast::FnArgument { - base: p, - state_space: ast::FnArgStateSpace::Param, - } +KernelInput: ast::Variable> = { + => { + let (align, v_type, name) = v; + ast::Variable{ align, v_type, name } } -}; +} -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -KernelInput: ast::KernelArgument> = { - ".param" <_type:ScalarType> => { - ast::KernelArgument {a_type: _type, name: name, length: 1 } +FnInput: ast::Variable> = { + => { + let (align, v_type, name) = v; + let v_type = ast::FnArgumentType::Reg(v_type); + ast::Variable{ align, v_type, name } }, - ".param" "[" "]" => { - let length = length.parse::(); - let length = length.unwrap_with(errors); - ast::KernelArgument { a_type: a_type, name: name, length: length } + => { + let (align, v_type, name) = v; + let v_type = ast::FnArgumentType::Param(v_type); + ast::Variable{ align, v_type, name } } -}; +} pub(crate) FunctionBody: Option>>> = { "{" "}" => { Some(without_none(s)) }, @@ -267,22 +261,13 @@ StateSpaceSpecifier: ast::StateSpace = { ".param" => ast::StateSpace::Param, // used to prepare function call }; - -Type: ast::Type = { - => ast::Type::Scalar(t), - => ast::Type::ExtendedScalar(t), -}; - ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".pred" => ast::ScalarType::Pred, MemoryType }; -ExtendedScalarType: ast::ExtendedScalarType = { - ".f16x2" => ast::ExtendedScalarType::F16x2, - ".pred" => ast::ExtendedScalarType::Pred, -}; - MemoryType: ast::ScalarType = { ".b8" => ast::ScalarType::B8, ".b16" => ast::ScalarType::B16, @@ -303,7 +288,7 @@ MemoryType: ast::ScalarType = { Statement: Option>> = { => Some(ast::Statement::Label(l)), DebugDirective => None, - ";" => Some(ast::Statement::Variable(v)), + ";" => Some(ast::Statement::Variable(v)), ";" => Some(ast::Statement::Instruction(p, i)), "{" "}" => Some(ast::Statement::Block(without_none(s))) }; @@ -328,21 +313,109 @@ Align: u32 = { } }; -Variable: ast::Variable> = { - => { - let (name, count) = v; - let t = match (t, arr) { - (ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size), - (t, Some(_)) => { - errors.push(ast::PtxError::WrongArrayType); - t - }, - (t, None) => t, - }; - ast::Variable { space: s, align: a, v_type: t, name: name, count: count } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names +MultiVariable: ast::MultiVariable> = { + => ast::MultiVariable{<>} +} + +VariableParam: u32 = { + "<" ">" => { + let size = n.parse::(); + size.unwrap_with(errors) } +} + +Variable: ast::Variable> = { + => { + let (align, v_type, name) = v; + let v_type = ast::VariableType::Reg(v_type); + ast::Variable {align, v_type, name} + }, + LocalVariable, + => { + let (align, v_type, name) = v; + let v_type = ast::VariableType::Param(v_type); + ast::Variable {align, v_type, name} + }, }; +RegVariable: (Option, ast::VariableRegType, &'input str) = { + ".reg" => { + let v_type = ast::VariableRegType::Scalar(t); + (align, v_type, name) + }, + ".reg" => { + let v_type = ast::VariableRegType::Vector(t, v_len); + (align, v_type, name) + } +} + +LocalVariable: ast::Variable> = { + ".local" => { + let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); + ast::Variable {align, v_type, name} + }, + ".local" => { + let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); + ast::Variable {align, v_type, name} + }, + ".local" => { + let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr)); + ast::Variable {align, v_type, name} + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space +ParamVariable: (Option, ast::VariableParamType, &'input str) = { + ".param" => { + let v_type = ast::VariableParamType::Scalar(t); + (align, v_type, name) + }, + ".param" => { + let v_type = ast::VariableParamType::Array(t, arr); + (align, v_type, name) + } +} + +#[inline] +SizedScalarType: ast::SizedScalarType = { + ".b8" => ast::SizedScalarType::B8, + ".b16" => ast::SizedScalarType::B16, + ".b32" => ast::SizedScalarType::B32, + ".b64" => ast::SizedScalarType::B64, + ".u8" => ast::SizedScalarType::U8, + ".u16" => ast::SizedScalarType::U16, + ".u32" => ast::SizedScalarType::U32, + ".u64" => ast::SizedScalarType::U64, + ".s8" => ast::SizedScalarType::S8, + ".s16" => ast::SizedScalarType::S16, + ".s32" => ast::SizedScalarType::S32, + ".s64" => ast::SizedScalarType::S64, + ".f16" => ast::SizedScalarType::F16, + ".f16x2" => ast::SizedScalarType::F16x2, + ".f32" => ast::SizedScalarType::F32, + ".f64" => ast::SizedScalarType::F64, +} + +#[inline] +ParamScalarType: ast::ParamScalarType = { + ".b8" => ast::ParamScalarType::B8, + ".b16" => ast::ParamScalarType::B16, + ".b32" => ast::ParamScalarType::B32, + ".b64" => ast::ParamScalarType::B64, + ".u8" => ast::ParamScalarType::U8, + ".u16" => ast::ParamScalarType::U16, + ".u32" => ast::ParamScalarType::U32, + ".u64" => ast::ParamScalarType::U64, + ".s8" => ast::ParamScalarType::S8, + ".s16" => ast::ParamScalarType::S16, + ".s32" => ast::ParamScalarType::S32, + ".s64" => ast::ParamScalarType::S64, + ".f16" => ast::ParamScalarType::F16, + ".f32" => ast::ParamScalarType::F32, + ".f64" => ast::ParamScalarType::F64, +} + ArraySpecifier: u32 = { "[" "]" => { let size = n.parse::(); @@ -350,20 +423,6 @@ ArraySpecifier: u32 = { } }; -VariableName: (&'input str, Option) = { - => (id, None), - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names - => { - let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap(); - let count = id[left_angle+1..id.len()-1].parse::(); - let count = match count { - Ok(c) => Some(c), - Err(e) => { errors.push(e.into()); None }, - }; - (&id[0..left_angle], count) - } -}; - Instruction: ast::Instruction> = { InstLd, InstMov, @@ -445,7 +504,7 @@ MovType: ast::Type = { ".s64" => ast::Type::Scalar(ast::ScalarType::S64), ".f32" => ast::Type::Scalar(ast::ScalarType::F32), ".f64" => ast::Type::Scalar(ast::ScalarType::F64), - ".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred) + ".pred" => ast::Type::Scalar(ast::ScalarType::Pred) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul @@ -934,7 +993,17 @@ MovOperand: ast::MovOperand<&'input str> = { => ast::MovOperand::Op(o), => { let (pref, suf) = o; - ast::MovOperand::Vec(pref.to_string(), suf.to_string()) + let suf_idx = match suf { + "x" | "r" => 0, + "y" | "g" => 1, + "z" | "b" => 2, + "w" | "a" => 3, + _ => { + errors.push(ast::PtxError::WrongVectorElement); + 0 + } + }; + ast::MovOperand::Vec(pref, suf_idx) } }; @@ -980,9 +1049,9 @@ OptionalDst: &'input str = { "|" => dst2 } -VectorPrefix: ast::VectorPrefix = { - ".v2" => ast::VectorPrefix::V2, - ".v4" => ast::VectorPrefix::V4 +VectorPrefix: u8 = { + ".v2" => 2, + ".v4" => 4 }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 3252b50..f40fc02 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -8,16 +8,16 @@ fn parse_and_assert(s: &str) { assert!(errors.len() == 0); } -#[test] -fn empty() { - parse_and_assert(".version 6.5 .target sm_30, debug"); +fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> { + let mut errors = Vec::new(); + let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); + crate::to_spirv(ast)?; + Ok(()) } #[test] -#[allow(non_snake_case)] -fn vectorAdd_kernel64_ptx() { - let vector_add = include_str!("vectorAdd_kernel64.ptx"); - parse_and_assert(vector_add); +fn empty() { + parse_and_assert(".version 6.5 .target sm_30, debug"); } #[test] @@ -28,8 +28,14 @@ fn operands_ptx() { #[test] #[allow(non_snake_case)] -fn _Z9vectorAddPKfS0_Pfi_ptx() { - let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx"); - parse_and_assert(vector_add); +fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> { + let vector_add = include_str!("vectorAdd_kernel64.ptx"); + compile_and_assert(vector_add) } +#[test] +#[allow(non_snake_case)] +fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> { + let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx"); + compile_and_assert(vector_add) +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a72c453..a04f0eb 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -54,6 +54,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); +test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/vector.ptx b/ptx/src/test/spirv_run/vector.ptx new file mode 100644 index 0000000..dea9543 --- /dev/null +++ b/ptx/src/test/spirv_run/vector.ptx @@ -0,0 +1,44 @@ +// Excersise as many features of vector types as possible + +.version 6.5 +.target sm_53 +.address_size 64 + +.func (.reg .v2 .u32 output) impl( + .reg .v2 .u32 input +) +{ + .reg .v2 .u32 temp_v; + .reg .u32 temp1; + .reg .u32 temp2; + + mov.u32 temp1, input.x; + mov.u32 temp2, input.y; + add.u32 temp2, temp1, temp2; + mov.u32 temp_v.x, temp2; + mov.u32 temp_v.y, temp2; + mov.v2.u32 output, temp_v; + ret; +} + +.visible .entry vector( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .v2 .u32 temp; + .reg .u32 temp1; + .reg .u32 temp2; + .reg .b64 packed; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.v2.u32 temp, [in_addr]; + call (temp), impl, (temp); + mov.b64 packed, temp; + st.v2.u32 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt new file mode 100644 index 0000000..6810fec --- /dev/null +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -0,0 +1,46 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index bd37b14..1f58db8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -8,6 +8,7 @@ use rspirv::binary::Assemble; #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), + Vector(SpirvScalarKey, u8), Array(SpirvScalarKey, u32), Pointer(Box, spirv::StorageClass), Func(Option>, Vec), @@ -17,7 +18,7 @@ impl SpirvType { fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { let key = match t { ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), - ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)), + ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len), ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len), }; SpirvType::Pointer(Box::new(key), sc) @@ -28,7 +29,7 @@ impl From for SpirvType { fn from(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()), + ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), } } @@ -77,15 +78,8 @@ impl From for SpirvScalarKey { ast::ScalarType::F16 => SpirvScalarKey::F16, ast::ScalarType::F32 => SpirvScalarKey::F32, ast::ScalarType::F64 => SpirvScalarKey::F64, - } - } -} - -impl From for SpirvScalarKey { - fn from(t: ast::ExtendedScalarType) -> Self { - match t { - ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred, - ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2, + ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, + ast::ScalarType::Pred => SpirvScalarKey::Pred, } } } @@ -135,6 +129,13 @@ impl TypeWordMap { .entry(t) .or_insert_with(|| b.type_pointer(None, storage, base)) } + SpirvType::Vector(typ, len) => { + let base = self.get_or_add_spirv_scalar(b, typ); + *self + .complex + .entry(t) + .or_insert_with(|| b.type_vector(base, len as u32)) + } SpirvType::Array(typ, len) => { let base = self.get_or_add_spirv_scalar(b, typ); *self @@ -232,8 +233,8 @@ fn emit_function_header<'a>( spirv::FunctionControl::NONE, func_type, )?; - func_directive.visit_args(|arg| { - let result_type = map.get_or_add_scalar(builder, arg.a_type); + func_directive.visit_args(&mut |arg| { + let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into()); let inst = dr::Instruction::new( spirv::Op::FunctionParameter, Some(result_type), @@ -285,9 +286,9 @@ fn expand_kernel_params<'a, 'b>( args: impl Iterator>>, ) -> Vec> { args.map(|a| ast::KernelArgument { - name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))), - a_type: a.a_type, - length: a.length, + name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))), + v_type: a.v_type, + align: a.align, }) .collect() } @@ -297,12 +298,9 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator>>, ) -> Vec> { args.map(|a| ast::FnArgument { - state_space: a.state_space, - base: ast::KernelArgument { - name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))), - a_type: a.base.a_type, - length: a.base.length, - }, + name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))), + v_type: a.v_type, + align: a.align, }) .collect() } @@ -375,16 +373,12 @@ fn resolve_fn_calls( fn to_resolved_fn_args( params: Vec, - params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)], -) -> Vec> { + params_decl: &[ast::FnArgumentType], +) -> Vec<(T, ast::FnArgumentType)> { params .into_iter() .zip(params_decl.iter()) - .map(|(id, &(space, typ))| ArgCall { - id, - typ: ast::Type::Scalar(typ), - space: space, - }) + .map(|(id, typ)| (id, *typ)) .collect::>() } @@ -476,12 +470,11 @@ fn insert_mem_ssa_statements<'a, 'b>( let out_param = match &mut f_args { ast::MethodDecl::Kernel(_, in_params) => { for p in in_params.iter_mut() { - let typ = ast::Type::Scalar(p.a_type); + let typ = ast::Type::from(p.v_type); let new_id = id_def.new_id(Some(typ)); - result.push(Statement::Variable(VariableDecl { - space: ast::StateSpace::Reg, - align: None, - v_type: typ, + result.push(Statement::Variable(ast::Variable { + align: p.align, + v_type: ast::VariableType::Param(p.v_type), name: p.name, })); result.push(Statement::StoreVar( @@ -497,32 +490,31 @@ fn insert_mem_ssa_statements<'a, 'b>( } ast::MethodDecl::Func(out_params, _, in_params) => { for p in in_params.iter_mut() { - let typ = ast::Type::Scalar(p.base.a_type); + let typ = ast::Type::from(p.v_type); let new_id = id_def.new_id(Some(typ)); - result.push(Statement::Variable(VariableDecl { - space: ast::StateSpace::Reg, - align: None, - v_type: typ, - name: p.base.name, + let var_typ = ast::VariableType::from(p.v_type); + result.push(Statement::Variable(ast::Variable { + align: p.align, + v_type: var_typ, + name: p.name, })); result.push(Statement::StoreVar( ast::Arg2St { - src1: p.base.name, + src1: p.name, src2: new_id, }, typ, )); - p.base.name = new_id; + p.name = new_id; } match &mut **out_params { [p] => { - result.push(Statement::Variable(VariableDecl { - space: ast::StateSpace::Reg, - align: None, - v_type: ast::Type::Scalar(p.base.a_type), - name: p.base.name, + result.push(Statement::Variable(ast::Variable { + align: p.align, + v_type: ast::VariableType::from(p.v_type), + name: p.name, })); - Some(p.base.name) + Some(p.name) } [] => None, _ => todo!(), @@ -552,15 +544,13 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst), }, Statement::Conditional(mut bra) => { - let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar( - ast::ExtendedScalarType::Pred, - ))); + let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred))); result.push(Statement::LoadVar( Arg2 { dst: generated_id, src: bra.predicate, }, - ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), + ast::Type::Scalar(ast::ScalarType::Pred), )); bra.predicate = generated_id; result.push(Statement::Conditional(bra)); @@ -642,7 +632,15 @@ fn expand_arguments<'a, 'b>( let new_inst = inst.map(&mut visitor); result.push(Statement::Instruction(new_inst)); } - Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)), + Statement::Variable(ast::Variable { + align, + v_type, + name, + }) => result.push(Statement::Variable(ast::Variable { + align, + v_type, + name, + })), Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), @@ -745,7 +743,7 @@ impl<'a, 'b> ArgumentMapVisitor ) -> spirv::Word { match &desc.op { ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)), - ast::MovOperand::Vec(_, _) => todo!(), + ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)), } } } @@ -835,13 +833,19 @@ fn get_function_type( match method_decl { ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn( builder, - out_params.iter().map(|p| SpirvType::from(p.base.a_type)), - in_params.iter().map(|p| SpirvType::from(p.base.a_type)), + out_params + .iter() + .map(|p| SpirvType::from(ast::Type::from(p.v_type))), + in_params + .iter() + .map(|p| SpirvType::from(ast::Type::from(p.v_type))), ), ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( builder, iter::empty(), - params.iter().map(|p| SpirvType::from(p.a_type)), + params + .iter() + .map(|p| SpirvType::from(ast::Type::from(p.v_type))), ), } } @@ -870,31 +874,38 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { - [p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id), + [(id, typ)] => ( + map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))), + *id, + ), _ => todo!(), }; - let arg_list = call.param_list.iter().map(|p| p.id).collect::>(); + let arg_list = call + .param_list + .iter() + .map(|(id, _)| *id) + .collect::>(); builder.function_call(result_type, Some(result_id), call.func, arg_list)?; } - Statement::Variable(VariableDecl { - name: id, - v_type: typ, - space: ss, + Statement::Variable(ast::Variable { align, + v_type, + name, }) => { let type_id = map.get_or_add( builder, - SpirvType::new_pointer(*typ, spirv::StorageClass::Function), + SpirvType::new_pointer(ast::Type::from(*v_type), spirv::StorageClass::Function), ); - let st_class = match ss { - ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function, - ast::StateSpace::Local => spirv::StorageClass::Workgroup, - _ => todo!(), + let st_class = match v_type { + ast::VariableType::Reg(_) | ast::VariableType::Param(_) => { + spirv::StorageClass::Function + } + ast::VariableType::Local(_) => spirv::StorageClass::Workgroup, }; - builder.variable(type_id, Some(*id), st_class, None); + builder.variable(type_id, Some(*name), st_class, None); if let Some(align) = align { builder.decorate( - *id, + *name, spirv::Decoration::Alignment, &[dr::Operand::LiteralInt32(*align)], ); @@ -1051,7 +1062,7 @@ fn emit_cvt( if desc.saturate || desc.flush_to_zero { todo!() } - let dest_t: ast::Type = desc.dst.into(); + let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); builder.f_convert(result_type, Some(arg.dst), arg.src)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); @@ -1060,7 +1071,7 @@ fn emit_cvt( if desc.saturate || desc.flush_to_zero { todo!() } - let dest_t: ast::Type = desc.dst.into(); + let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.src.is_signed() { builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; @@ -1367,7 +1378,7 @@ fn normalize_identifiers<'a, 'b>( fn expand_map_variables<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, - fn_defs: &GlobalFnDeclResolver, + fn_defs: &GlobalFnDeclResolver<'a, 'b>, result: &mut Vec, s: ast::Statement>, ) { @@ -1386,21 +1397,19 @@ fn expand_map_variables<'a, 'b>( ))), ast::Statement::Variable(var) => match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.name, count, var.v_type) { - result.push(Statement::Variable(VariableDecl { - space: var.space, - align: var.align, - v_type: var.v_type, + for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) { + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type, name: new_id, })) } } None => { - let new_id = id_defs.add_def(var.name, Some(var.v_type)); - result.push(Statement::Variable(VariableDecl { - space: var.space, - align: var.align, - v_type: var.v_type, + let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into())); + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type, name: new_id, })); } @@ -1408,15 +1417,38 @@ fn expand_map_variables<'a, 'b>( } } +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)] +enum PtxSpecialRegister { + Tid, + Ntid, + Ctaid, + Nctaid, + Gridid, +} + +impl PtxSpecialRegister { + fn try_parse(s: &str) -> Option { + match s { + "%tid" => Some(Self::Tid), + "%ntid" => Some(Self::Ntid), + "%ctaid" => Some(Self::Ctaid), + "%nctaid" => Some(Self::Nctaid), + "%gridid" => Some(Self::Gridid), + _ => None, + } + } +} + struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, + special_registers: HashMap, fns: HashMap, } pub struct FnDecl { - ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, - params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, + ret_vals: Vec, + params: Vec, } impl<'a> GlobalStringIdResolver<'a> { @@ -1424,6 +1456,7 @@ impl<'a> GlobalStringIdResolver<'a> { Self { current_id: start_id, variables: HashMap::new(), + special_registers: HashMap::new(), fns: HashMap::new(), } } @@ -1461,6 +1494,7 @@ impl<'a> GlobalStringIdResolver<'a> { let mut fn_resolver = FnStringIdResolver { current_id: &mut self.current_id, global_variables: &self.variables, + special_registers: &mut self.special_registers, variables: vec![HashMap::new(); 1], type_check: HashMap::new(), }; @@ -1474,14 +1508,8 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert( name_id, FnDecl { - ret_vals: ret_params_ids - .iter() - .map(|p| (p.state_space, p.base.a_type)) - .collect(), - params: params_ids - .iter() - .map(|p| (p.state_space, p.base.a_type)) - .collect(), + ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(), + params: params_ids.iter().map(|p| p.v_type).collect(), }, ); ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) @@ -1516,7 +1544,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, - //global: &'b mut GlobalStringIdResolver<'a>, + special_registers: &'b mut HashMap, variables: Vec, spirv::Word>>, type_check: HashMap, } @@ -1537,14 +1565,28 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { self.variables.pop(); } - fn get_id(&self, id: &str) -> spirv::Word { + fn get_id(&mut self, id: &str) -> spirv::Word { for scope in self.variables.iter().rev() { match scope.get(id) { Some(id) => return *id, None => continue, } } - self.global_variables[id] + match self.global_variables.get(id) { + Some(id) => *id, + None => { + let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!()); + match self.special_registers.entry(sreg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = *self.current_id; + *self.current_id += 1; + e.insert(numeric_id); + numeric_id + } + } + } + } } fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { @@ -1602,7 +1644,7 @@ impl<'b> NumericIdResolver<'b> { enum Statement { Label(u32), - Variable(VariableDecl), + Variable(ast::Variable), Instruction(I), LoadVar(ast::Arg2, ast::Type), StoreVar(ast::Arg2St, ast::Type), @@ -1614,18 +1656,11 @@ enum Statement { RetValue(ast::RetData, spirv::Word), } -struct VariableDecl { - pub space: ast::StateSpace, - pub align: Option, - pub v_type: ast::Type, - pub name: spirv::Word, -} - struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec>, + pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>, pub func: spirv::Word, - pub param_list: Vec>, + pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>, } impl> ResolvedCall { @@ -1636,18 +1671,14 @@ impl> ResolvedCall { let ret_params = self .ret_params .into_iter() - .map(|p| { + .map(|(id, typ)| { let new_id = visitor.variable(ArgumentDescriptor { - op: p.id, - typ: Some(p.typ), + op: id, + typ: Some(typ.into()), is_dst: true, is_pointer: false, }); - ArgCall { - id: new_id, - typ: p.typ, - space: p.space, - } + (new_id, typ) }) .collect(); let func = visitor.variable(ArgumentDescriptor { @@ -1659,18 +1690,14 @@ impl> ResolvedCall { let param_list = self .param_list .into_iter() - .map(|p| { + .map(|(id, typ)| { let new_id = visitor.src_call_operand(ArgumentDescriptor { - op: p.id, - typ: Some(p.typ), + op: id, + typ: Some(typ.into()), is_dst: false, is_pointer: false, }); - ArgCall { - id: new_id, - typ: p.typ, - space: p.space, - } + (new_id, typ) }) .collect(); ResolvedCall { @@ -1700,12 +1727,6 @@ impl VisitVariableExpanded for ResolvedCall { } } -struct ArgCall { - id: ID, - typ: ast::Type, - space: ast::FnArgStateSpace, -} - pub trait ArgParamsEx: ast::ArgParams { fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl; } @@ -1817,7 +1838,9 @@ where ) -> ast::MovOperand { match desc.op { ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))), - ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), + ast::MovOperand::Vec(reg, x2) => { + ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2) + } } } } @@ -1881,13 +1904,18 @@ impl ast::Instruction { } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { - ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()), - ast::CvtDetails::FloatFromInt(desc) => { - (desc.dst.into(), ast::Type::Scalar(desc.src.into())) - } - ast::CvtDetails::IntFromFloat(desc) => { - (ast::Type::Scalar(desc.dst.into()), desc.src.into()) - } + ast::CvtDetails::FloatFromFloat(desc) => ( + ast::Type::Scalar(desc.dst.into()), + ast::Type::Scalar(desc.src.into()), + ), + ast::CvtDetails::FloatFromInt(desc) => ( + ast::Type::Scalar(desc.dst.into()), + ast::Type::Scalar(desc.src.into()), + ), + ast::CvtDetails::IntFromFloat(desc) => ( + ast::Type::Scalar(desc.dst.into()), + ast::Type::Scalar(desc.src.into()), + ), ast::CvtDetails::IntFromInt(desc) => ( ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.src.into()), @@ -2261,14 +2289,14 @@ impl ast::Arg4 { ast::Arg4 { dst1: visitor.variable(ArgumentDescriptor { op: self.dst1, - typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), is_dst: true, is_pointer: false, }), dst2: self.dst2.map(|dst2| { visitor.variable(ArgumentDescriptor { op: dst2, - typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), is_dst: true, is_pointer: false, }) @@ -2298,14 +2326,14 @@ impl ast::Arg5 { ast::Arg5 { dst1: visitor.variable(ArgumentDescriptor { op: self.dst1, - typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), is_dst: true, is_pointer: false, }), dst2: self.dst2.map(|dst2| { visitor.variable(ArgumentDescriptor { op: dst2, - typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), is_dst: true, is_pointer: false, }) @@ -2324,7 +2352,7 @@ impl ast::Arg5 { }), src3: visitor.operand(ArgumentDescriptor { op: self.src3, - typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), + typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)), is_dst: false, is_pointer: false, }), @@ -2332,65 +2360,6 @@ impl ast::Arg5 { } } -/* -impl ast::ArgCall { - fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor>( - self, - visitor: &mut V, - fn_resolve: &GlobalFnDeclResolver<'a>, - ) -> ast::ArgCall { - // TODO: error out if lengths don't match - let fn_decl = T::get_fn_decl(&self.func, fn_resolve); - let ret_params = self - .ret_params - .into_iter() - .zip(fn_decl.ret_vals.iter().copied()) - .map(|(a, (space, typ))| { - visitor.variable(ArgumentDescriptor { - op: a, - typ: Some(ast::Type::Scalar(typ)), - is_dst: true, - is_pointer: if space == ast::FnArgStateSpace::Reg { - false - } else { - true - }, - }) - }) - .collect(); - let func = visitor.variable(ArgumentDescriptor { - op: self.func, - typ: None, - is_dst: false, - is_pointer: false, - }); - let param_list = self - .param_list - .into_iter() - .zip(fn_decl.params.iter().copied()) - .map(|(a, (space, typ))| { - visitor.src_call_operand(ArgumentDescriptor { - op: a, - typ: Some(ast::Type::Scalar(typ)), - is_dst: false, - is_pointer: if space == ast::FnArgStateSpace::Reg { - false - } else { - true - }, - }) - }) - .collect(); - ast::ArgCall { - uniform: false, - ret_params, - func: func, - param_list: param_list, - } - } -} -*/ - impl ast::CallOperand { fn map_variable U>(self, f: &mut F) -> ast::CallOperand { match self { @@ -2418,6 +2387,8 @@ enum ScalarKind { Unsigned, Signed, Float, + Float2, + Pred, } impl ast::ScalarType { @@ -2438,6 +2409,8 @@ impl ast::ScalarType { ast::ScalarType::S64 => 8, ast::ScalarType::B64 => 8, ast::ScalarType::F64 => 8, + ast::ScalarType::F16x2 => 4, + ast::ScalarType::Pred => 1, } } @@ -2458,6 +2431,8 @@ impl ast::ScalarType { ast::ScalarType::F16 => ScalarKind::Float, ast::ScalarType::F32 => ScalarKind::Float, ast::ScalarType::F64 => ScalarKind::Float, + ast::ScalarType::F16x2 => ScalarKind::Float, + ast::ScalarType::Pred => ScalarKind::Pred, } } @@ -2490,6 +2465,11 @@ impl ast::ScalarType { 8 => ast::ScalarType::U64, _ => unreachable!(), }, + ScalarKind::Float2 => match width { + 4 => ast::ScalarType::F16x2, + _ => unreachable!(), + }, + ScalarKind::Pred => ast::ScalarType::Pred, } } } @@ -2497,7 +2477,7 @@ impl ast::ScalarType { impl ast::NotType { fn to_type(self) -> ast::Type { match self { - ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), + ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16), ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32), ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64), @@ -2519,7 +2499,9 @@ impl ast::AddDetails { fn get_type(&self) -> ast::Type { match self { ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), - ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(), + ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => { + ast::Type::Scalar((*typ).into()) + } } } } @@ -2528,7 +2510,9 @@ impl ast::MulDetails { fn get_type(&self) -> ast::Type { match self { ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), - ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(), + ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => { + ast::Type::Scalar((*typ).into()) + } } } } @@ -2560,6 +2544,15 @@ impl ast::LdStateSpace { } } +impl From for ast::VariableType { + fn from(t: ast::FnArgumentType) -> Self { + match t { + ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), + ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), + } + } +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { @@ -2575,6 +2568,8 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { ScalarKind::Unsigned => { operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed } + ScalarKind::Float2 => todo!(), + ScalarKind::Pred => false, } } _ => false, @@ -2758,6 +2753,8 @@ fn should_convert_relaxed_src( None } } + ScalarKind::Float2 => todo!(), + ScalarKind::Pred => None, }, _ => None, } @@ -2807,6 +2804,8 @@ fn should_convert_relaxed_dst( None } } + ScalarKind::Float2 => todo!(), + ScalarKind::Pred => None, }, _ => None, } @@ -2862,16 +2861,21 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { } } -impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> { - fn visit_args(&self, f: impl FnMut(&ast::KernelArgument

)) { +impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> { + fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument

)) { match self { - ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f), - ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f), + ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f), + ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| { + f(&ast::FnArgument { + align: arg.align, + name: arg.name, + v_type: ast::FnArgumentType::Param(arg.v_type), + }) + }), } } } -// CFGs below taken from "Modern Compiler Implementation in Java" #[cfg(test)] mod tests { use super::*;