From 9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 2 Oct 2020 00:11:28 +0200 Subject: [PATCH] Add sub, min, max --- ptx/src/ast.rs | 88 ++++-- ptx/src/ptx.lalrpop | 169 ++++++++---- ptx/src/test/spirv_run/max.ptx | 23 ++ ptx/src/test/spirv_run/max.spvtxt | 57 ++++ ptx/src/test/spirv_run/min.ptx | 23 ++ ptx/src/test/spirv_run/min.spvtxt | 57 ++++ ptx/src/test/spirv_run/mod.rs | 3 + ptx/src/test/spirv_run/or.ptx | 23 ++ ptx/src/test/spirv_run/or.spvtxt | 58 ++++ ptx/src/test/spirv_run/sub.ptx | 22 ++ ptx/src/test/spirv_run/sub.spvtxt | 49 ++++ ptx/src/translate.rs | 429 +++++++++++++++++++++++------- 12 files changed, 820 insertions(+), 181 deletions(-) create mode 100644 ptx/src/test/spirv_run/max.ptx create mode 100644 ptx/src/test/spirv_run/max.spvtxt create mode 100644 ptx/src/test/spirv_run/min.ptx create mode 100644 ptx/src/test/spirv_run/min.spvtxt create mode 100644 ptx/src/test/spirv_run/or.ptx create mode 100644 ptx/src/test/spirv_run/or.spvtxt create mode 100644 ptx/src/test/spirv_run/sub.ptx create mode 100644 ptx/src/test/spirv_run/sub.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 8c64ebf..048d43a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -241,6 +241,10 @@ sub_scalar_type!(IntType { S64 }); +sub_scalar_type!(UIntType { U8, U16, U32, U64 }); + +sub_scalar_type!(SIntType { S8, S16, S32, S64 }); + impl IntType { pub fn is_signed(self) -> bool { match self { @@ -331,7 +335,7 @@ pub enum Instruction { Ld(LdDetails, Arg2Ld

), Mov(MovDetails, Arg2Mov

), Mul(MulDetails, Arg3

), - Add(AddDetails, Arg3

), + Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), SetpBool(SetpBoolData, Arg5

), Not(NotType, Arg2

), @@ -346,6 +350,9 @@ pub enum Instruction { Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), Or(OrType, Arg3

), + Sub(ArithDetails, Arg3

), + Min(MinMaxDetails, Arg3

), + Max(MinMaxDetails, Arg3

), } #[derive(Copy, Clone)] @@ -554,11 +561,6 @@ impl MovDetails { } } -pub enum MulDetails { - Int(MulIntDesc), - Float(MulFloatDesc), -} - #[derive(Copy, Clone)] pub struct MulIntDesc { pub typ: IntType, @@ -572,14 +574,6 @@ pub enum MulIntControl { Wide, } -#[derive(Copy, Clone)] -pub struct MulFloatDesc { - pub typ: FloatType, - pub rounding: Option, - pub flush_to_zero: bool, - pub saturate: bool, -} - #[derive(PartialEq, Eq, Copy, Clone)] pub enum RoundingMode { NearestEven, @@ -588,23 +582,11 @@ pub enum RoundingMode { PositiveInf, } -pub enum AddDetails { - Int(AddIntDesc), - Float(AddFloatDesc), -} - pub struct AddIntDesc { pub typ: IntType, pub saturate: bool, } -pub struct AddFloatDesc { - pub typ: FloatType, - pub rounding: Option, - pub flush_to_zero: bool, - pub saturate: bool, -} - pub struct SetpData { pub typ: ScalarType, pub flush_to_zero: bool, @@ -810,3 +792,57 @@ sub_scalar_type!(OrType { B32, B64, }); + +#[derive(Copy, Clone)] +pub enum MulDetails { + Unsigned(MulUInt), + Signed(MulSInt), + Float(ArithFloat), +} + +#[derive(Copy, Clone)] +pub struct MulUInt { + pub typ: UIntType, + pub control: MulIntControl, +} + +#[derive(Copy, Clone)] +pub struct MulSInt { + pub typ: SIntType, + pub control: MulIntControl, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Unsigned(UIntType), + Signed(ArithSInt), + Float(ArithFloat), +} + +#[derive(Copy, Clone)] +pub struct ArithSInt { + pub typ: SIntType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub typ: FloatType, + pub rounding: Option, + pub flush_to_zero: bool, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub enum MinMaxDetails { + Signed(SIntType), + Unsigned(UIntType), + Float(MinMaxFloat), +} + +#[derive(Copy, Clone)] +pub struct MinMaxFloat { + pub ftz: bool, + pub nan: bool, + pub typ: FloatType, +} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2d5be8..2c0e365 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -70,6 +70,7 @@ match { ".ltu", ".lu", ".nan", + ".NaN", ".ne", ".neu", ".num", @@ -124,6 +125,8 @@ match { "ld", "mad", "map_f64_to_f32", + "max", + "min", "mov", "mul", "not", @@ -134,6 +137,7 @@ match { "shr", r"sm_[0-9]+" => ShaderModel, "st", + "sub", "texmode_independent", "texmode_unified", } else { @@ -153,6 +157,8 @@ ExtendedID : &'input str = { "ld", "mad", "map_f64_to_f32", + "max", + "min", "mov", "mul", "not", @@ -163,6 +169,7 @@ ExtendedID : &'input str = { "shr", ShaderModel, "st", + "sub", "texmode_independent", "texmode_unified", ID @@ -448,7 +455,10 @@ Instruction: ast::Instruction> = { InstCall, InstAbs, InstMad, - InstOr + InstOr, + InstSub, + InstMin, + InstMax, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul InstMul: ast::Instruction> = { - "mul" => ast::Instruction::Mul(d, a) + "mul" => ast::Instruction::Mul(d, a) }; -InstMulMode: ast::MulDetails = { - => ast::MulDetails::Int(ast::MulIntDesc { +MulDetails: ast::MulDetails = { + => ast::MulDetails::Unsigned(ast::MulUInt{ typ: t, control: ctr }), - ".f32" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F32, - rounding: r, - flush_to_zero: ftz.is_some(), - saturate: s.is_some() + => ast::MulDetails::Signed(ast::MulSInt{ + typ: t, + control: ctr }), - ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F64, - rounding: r, - flush_to_zero: false, - saturate: false - }), - ".f16" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F16, - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }), - ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F16x2, - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }) + => ast::MulDetails::Float(f) }; MulIntControl: ast::MulIntControl = { @@ -634,41 +625,23 @@ IntType : ast::IntType = { ".s64" => ast::IntType::S64, }; +UIntType: ast::UIntType = { + ".u16" => ast::UIntType::U16, + ".u32" => ast::UIntType::U32, + ".u64" => ast::UIntType::U64, +}; + +SIntType: ast::SIntType = { + ".s16" => ast::SIntType::S16, + ".s32" => ast::SIntType::S32, + ".s64" => ast::SIntType::S64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add InstAdd: ast::Instruction> = { - "add" => ast::Instruction::Add(d, a) -}; - -InstAddMode: ast::AddDetails = { - => ast::AddDetails::Int(ast::AddIntDesc { - typ: t, - saturate: false, - }), - ".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc { - typ: ast::IntType::S32, - saturate: true, - }), - ".f32" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F32, - rounding: rn, - flush_to_zero: ftz.is_some(), - saturate: sat.is_some(), - }), - ".f64" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F64, - rounding: rn, - flush_to_zero: false, - saturate: false, - }), - ".f16" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F16, - rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: sat.is_some(), - }), - ".rn"? ".ftz"? ".sat"? ".f16x2" => todo!() + "add" => ast::Instruction::Add(d, a) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp @@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad InstMad: ast::Instruction> = { - "mad" => ast::Instruction::Mad(d, a), + "mad" => ast::Instruction::Mad(d, a), "mad" ".hi" ".sat" ".s32" => todo!() }; @@ -1063,6 +1036,84 @@ OrType: ast::OrType = { ".b64" => ast::OrType::B64, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub +InstSub: ast::Instruction> = { + "sub" => ast::Instruction::Sub(d, a), +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min +InstMin: ast::Instruction> = { + "min" => ast::Instruction::Min(d, a), +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max +InstMax: ast::Instruction> = { + "max" => ast::Instruction::Max(d, a), +}; + +MinMaxDetails: ast::MinMaxDetails = { + => ast::MinMaxDetails::Unsigned(t), + => ast::MinMaxDetails::Signed(t), + ".f32" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 } + ), + ".f64" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 } + ), + ".f16" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 } + ), + ".f16x2" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ) +} + +ArithDetails: ast::ArithDetails = { + => ast::ArithDetails::Unsigned(t), + => ast::ArithDetails::Signed(ast::ArithSInt { + typ: t, + saturate: false, + }), + ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S32, + saturate: true, + }), + => ast::ArithDetails::Float(f) +} + +ArithFloat: ast::ArithFloat = { + ".f32" => ast::ArithFloat { + typ: ast::FloatType::F32, + rounding: rn, + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, + ".f64" => ast::ArithFloat { + typ: ast::FloatType::F64, + rounding: rn, + flush_to_zero: false, + saturate: false, + }, + ".f16" => ast::ArithFloat { + typ: ast::FloatType::F16, + rounding: rn.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, + ".f16x2" => ast::ArithFloat { + typ: ast::FloatType::F16x2, + rounding: rn.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, +} + Operand: ast::Operand<&'input str> = { => ast::Operand::Reg(r), "+" => { diff --git a/ptx/src/test/spirv_run/max.ptx b/ptx/src/test/spirv_run/max.ptx new file mode 100644 index 0000000..8c72fe2 --- /dev/null +++ b/ptx/src/test/spirv_run/max.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry max( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp1; + .reg .s32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp1, [in_addr]; + ld.s32 temp2, [in_addr+4]; + max.s32 temp1, temp1, temp2; + st.s32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt new file mode 100644 index 0000000..cab9a9a --- /dev/null +++ b/ptx/src/test/spirv_run/max.spvtxt @@ -0,0 +1,57 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "max" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %14 = OpLoad %uint %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %24 + %16 = OpLoad %uint %26 + OpStore %7 %16 + %19 = OpLoad %uint %6 + %20 = OpLoad %uint %7 + %18 = OpExtInst %uint %30 s_max %19 %20 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %6 + %27 = OpConvertUToPtr %_ptr_Generic_uint %21 + OpStore %27 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/min.ptx b/ptx/src/test/spirv_run/min.ptx new file mode 100644 index 0000000..0311cdb --- /dev/null +++ b/ptx/src/test/spirv_run/min.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry min( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp1; + .reg .s32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp1, [in_addr]; + ld.s32 temp2, [in_addr+4]; + min.s32 temp1, temp1, temp2; + st.s32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt new file mode 100644 index 0000000..119cd15 --- /dev/null +++ b/ptx/src/test/spirv_run/min.spvtxt @@ -0,0 +1,57 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "min" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %14 = OpLoad %uint %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %24 + %16 = OpLoad %uint %26 + OpStore %7 %16 + %19 = OpLoad %uint %6 + %20 = OpLoad %uint %7 + %18 = OpExtInst %uint %30 s_min %19 %20 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %6 + %27 = OpConvertUToPtr %_ptr_Generic_uint %21 + OpStore %27 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 99785a6..8caf540 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]) test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); test_ptx!(shr, [-2i32], [-1i32]); test_ptx!(or, [1u64, 2u64], [3u64]); +test_ptx!(sub, [2u64], [1u64]); +test_ptx!(min, [555i32, 444i32], [444i32]); +test_ptx!(max, [555i32, 444i32], [555i32]); struct DisplayError { diff --git a/ptx/src/test/spirv_run/or.ptx b/ptx/src/test/spirv_run/or.ptx new file mode 100644 index 0000000..1deb3c8 --- /dev/null +++ b/ptx/src/test/spirv_run/or.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry or( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp1, [in_addr]; + ld.u64 temp2, [in_addr+8]; + or.b64 temp1, temp1, temp2; + st.u64 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt new file mode 100644 index 0000000..fbf80c5 --- /dev/null +++ b/ptx/src/test/spirv_run/or.spvtxt @@ -0,0 +1,58 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %33 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "or" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %36 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_8 = OpConstant %ulong 8 + %1 = OpFunction %void None %36 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %31 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_8 + %26 = OpConvertUToPtr %_ptr_Generic_ulong %24 + %16 = OpLoad %ulong %26 + OpStore %7 %16 + %19 = OpLoad %ulong %6 + %20 = OpLoad %ulong %7 + %28 = OpCopyObject %ulong %19 + %29 = OpCopyObject %ulong %20 + %27 = OpBitwiseOr %ulong %28 %29 + %18 = OpCopyObject %ulong %27 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %ulong %6 + %30 = OpConvertUToPtr %_ptr_Generic_ulong %21 + OpStore %30 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sub.ptx b/ptx/src/test/spirv_run/sub.ptx new file mode 100644 index 0000000..6cce9dc --- /dev/null +++ b/ptx/src/test/spirv_run/sub.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry sub( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + sub.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/sub.spvtxt b/ptx/src/test/spirv_run/sub.spvtxt new file mode 100644 index 0000000..8520168 --- /dev/null +++ b/ptx/src/test/spirv_run/sub.spvtxt @@ -0,0 +1,49 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "sub" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpISub %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index fb1b843..7c15744 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -595,6 +595,15 @@ fn convert_to_typed_statements( ast::Instruction::Or(d, a) => { result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) } + ast::Instruction::Sub(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast()))) + } + ast::Instruction::Min(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast()))) + } + ast::Instruction::Max(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg_offset( &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, - typ: ast::Type, + mut typ: ast::Type, ) -> Result { let (reg, offset) = desc.op; match desc.sema { - ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() + ArgumentSemantics::Default + | ArgumentSemantics::DefaultRelaxed + | ArgumentSemantics::PhysicalPointer => { + if desc.sema == ArgumentSemantics::PhysicalPointer { + typ = ast::Type::Scalar(ast::ScalarType::U64); + } + let (width, kind) = match typ { + ast::Type::Scalar(scalar_t) => { + let kind = match scalar_t.kind() { + kind @ ScalarKind::Bit + | kind @ ScalarKind::Unsigned + | kind @ ScalarKind::Signed => kind, + ScalarKind::Float => return Err(TranslateError::MismatchedType), + ScalarKind::Float2 => return Err(TranslateError::MismatchedType), + ScalarKind::Pred => return Err(TranslateError::MismatchedType), + }; + (scalar_t.width(), kind) + } + _ => return Err(TranslateError::MismatchedType), }; - let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + let arith_detail = if kind == ScalarKind::Signed { + ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::from_size(width), + saturate: false, + }) + } else { + ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) + }; + let id_constant_stmt = self.id_def.new_id(typ); let result_id = self.id_def.new_id(typ); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - Ok(result_id) - } - ArgumentSemantics::PhysicalPointer => { - let scalar_t = ast::ScalarType::U64; - let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::U64; - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); + // TODO: check for edge cases around min value/max value/wrapping + if offset < 0 && kind != ScalarKind::Signed { + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::from_parts(width, kind), + value: -(offset as i64), + })); + self.func.push(Statement::Instruction( + ast::Instruction::::Sub( + arith_detail, + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + } else { + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::from_parts(width, kind), + value: offset as i64, + })); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + arith_detail, + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + } Ok(result_id) } ArgumentSemantics::RegisterPointer => { @@ -1522,14 +1543,22 @@ fn emit_function_body_ops( } }, ast::Instruction::Mul(mul, arg) => match mul { - ast::MulDetails::Int(ref ctr) => { - emit_mul_int(builder, map, opencl, ctr, arg)?; + ast::MulDetails::Signed(ref ctr) => { + emit_mul_sint(builder, map, opencl, ctr, arg)? + } + ast::MulDetails::Unsigned(ref ctr) => { + emit_mul_uint(builder, map, opencl, ctr, arg)? } ast::MulDetails::Float(_) => todo!(), }, ast::Instruction::Add(add, arg) => match add { - ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?, - ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, + ast::ArithDetails::Signed(ref desc) => { + emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)? + } + ast::ArithDetails::Unsigned(ref desc) => { + emit_add_int(builder, map, (*desc).into(), false, arg)? + } + ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, }, ast::Instruction::Setp(setp, arg) => { if arg.dst2.is_some() { @@ -1581,8 +1610,11 @@ fn emit_function_body_ops( } ast::Instruction::SetpBool(_, _) => todo!(), ast::Instruction::Mad(mad, arg) => match mad { - ast::MulDetails::Int(ref desc) => { - emit_mad_int(builder, map, opencl, desc, arg)? + ast::MulDetails::Signed(ref desc) => { + emit_mad_sint(builder, map, opencl, desc, arg)? + } + ast::MulDetails::Unsigned(ref desc) => { + emit_mad_uint(builder, map, opencl, desc, arg)? } ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?, }, @@ -1594,6 +1626,23 @@ fn emit_function_body_ops( builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; } } + ast::Instruction::Sub(d, arg) => match d { + ast::ArithDetails::Signed(desc) => { + emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?; + } + ast::ArithDetails::Unsigned(desc) => { + emit_sub_int(builder, map, (*desc).into(), false, arg)?; + } + ast::ArithDetails::Float(desc) => { + emit_sub_float(builder, map, desc, arg)?; + } + }, + ast::Instruction::Min(d, a) => { + emit_min(builder, map, opencl, d, a)?; + } + ast::Instruction::Max(d, a) => { + emit_max(builder, map, opencl, d, a)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(*typ)); @@ -1624,11 +1673,11 @@ fn emit_function_body_ops( Ok(()) } -fn emit_mad_int( +fn emit_mad_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, - desc: &ast::MulIntDesc, + desc: &ast::MulUInt, arg: &ast::Arg4, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); @@ -1638,16 +1687,38 @@ fn emit_mad_int( builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; } ast::MulIntControl::High => { - let cl_op = if desc.typ.is_signed() { - spirv::CLOp::s_mad_hi - } else { - spirv::CLOp::u_mad_hi - }; builder.ext_inst( inst_type, Some(arg.dst), opencl, - cl_op as spirv::Word, + spirv::CLOp::u_mad_hi as spirv::Word, + [arg.src1, arg.src2, arg.src3], + )?; + } + ast::MulIntControl::Wide => todo!(), + }; + Ok(()) +} + +fn emit_mad_sint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MulSInt, + arg: &ast::Arg4, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + match desc.control { + ast::MulIntControl::Low => { + let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?; + builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst), + opencl, + spirv::CLOp::s_mad_hi as spirv::Word, [arg.src1, arg.src2, arg.src3], )?; } @@ -1659,7 +1730,7 @@ fn emit_mad_int( fn emit_mad_float( builder: &mut dr::Builder, map: &mut TypeWordMap, - desc: &ast::MulFloatDesc, + desc: &ast::ArithFloat, arg: &ast::Arg4, ) -> Result<(), dr::Error> { todo!() @@ -1668,7 +1739,7 @@ fn emit_mad_float( fn emit_add_float( builder: &mut dr::Builder, map: &mut TypeWordMap, - desc: &ast::AddFloatDesc, + desc: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { if desc.flush_to_zero { @@ -1680,6 +1751,67 @@ fn emit_add_float( Ok(()) } +fn emit_sub_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + if desc.flush_to_zero { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_min( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + builder.ext_inst( + inst_type, + Some(arg.dst), + opencl, + cl_op as spirv::Word, + [arg.src1, arg.src2], + )?; + Ok(()) +} + +fn emit_max( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + builder.ext_inst( + inst_type, + Some(arg.dst), + opencl, + cl_op as spirv::Word, + [arg.src1, arg.src2], + )?; + Ok(()) +} + fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1880,11 +2012,11 @@ fn emit_setp( Ok(()) } -fn emit_mul_int( +fn emit_mul_sint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, - desc: &ast::MulIntDesc, + desc: &ast::MulSInt, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let instruction_type = ast::ScalarType::from(desc.typ); @@ -1894,16 +2026,11 @@ fn emit_mul_int( builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::MulIntControl::High => { - let ocl_mul_hi = if desc.typ.is_signed() { - spirv::CLOp::s_mul_hi - } else { - spirv::CLOp::u_mul_hi - }; builder.ext_inst( inst_type, Some(arg.dst), opencl, - ocl_mul_hi as spirv::Word, + spirv::CLOp::s_mul_hi as spirv::Word, [arg.src1, arg.src2], )?; } @@ -1913,11 +2040,54 @@ fn emit_mul_int( SpirvScalarKey::from(instruction_type), ]); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); - let mul = if desc.typ.is_signed() { - builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)? - } else { - builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)? - }; + let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; + let instr_width = instruction_type.width(); + let instr_kind = instruction_type.kind(); + let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); + let dst_type_id = map.get_or_add_scalar(builder, dst_type); + struct2_bitcast_to_wide( + builder, + map, + SpirvScalarKey::from(instruction_type), + inst_type, + arg.dst, + dst_type_id, + mul, + )?; + } + } + Ok(()) +} + +fn emit_mul_uint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MulUInt, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + let instruction_type = ast::ScalarType::from(desc.typ); + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); + match desc.control { + ast::MulIntControl::Low => { + builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst), + opencl, + spirv::CLOp::u_mul_hi as spirv::Word, + [arg.src1, arg.src2], + )?; + } + ast::MulIntControl::Wide => { + let mul_ext_type = SpirvType::Struct(vec![ + SpirvScalarKey::from(instruction_type), + SpirvScalarKey::from(instruction_type), + ]); + let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); + let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; let instr_width = instruction_type.width(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); @@ -1981,14 +2151,33 @@ fn emit_abs( fn emit_add_int( builder: &mut dr::Builder, map: &mut TypeWordMap, - ctr: &ast::AddIntDesc, + typ: ast::ScalarType, + saturate: bool, arg: &ast::Arg3, ) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ))); + if saturate { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; Ok(()) } +fn emit_sub_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::Arg3, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); + builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; + Ok(()) +} + fn emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -2920,6 +3109,18 @@ impl ast::Instruction { t, a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?, ), + ast::Instruction::Sub(d, a) => { + let typ = d.get_type(); + ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?) + } + ast::Instruction::Min(d, a) => { + let typ = d.get_type(); + ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?) + } + ast::Instruction::Max(d, a) => { + let typ = d.get_type(); + ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?) + } }) } } @@ -3129,6 +3330,9 @@ impl ast::Instruction { | ast::Instruction::Abs(_, _) | ast::Instruction::Call(_) | ast::Instruction::Or(_, _) + | ast::Instruction::Sub(_, _) + | ast::Instruction::Min(_, _) + | ast::Instruction::Max(_, _) | ast::Instruction::Mad(_, _) => None, } } @@ -4049,25 +4253,33 @@ impl ast::ShrType { } } -impl ast::AddDetails { +impl ast::ArithDetails { fn get_type(&self) -> ast::Type { - match self { - ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), - ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => { - ast::Type::Scalar((*typ).into()) - } - } + ast::Type::Scalar(match self { + ast::ArithDetails::Unsigned(t) => (*t).into(), + ast::ArithDetails::Signed(d) => d.typ.into(), + ast::ArithDetails::Float(d) => d.typ.into(), + }) } } impl ast::MulDetails { fn get_type(&self) -> ast::Type { - match self { - ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), - ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => { - ast::Type::Scalar((*typ).into()) - } - } + ast::Type::Scalar(match self { + ast::MulDetails::Unsigned(d) => d.typ.into(), + ast::MulDetails::Signed(d) => d.typ.into(), + ast::MulDetails::Float(d) => d.typ.into(), + }) + } +} + +impl ast::MinMaxDetails { + fn get_type(&self) -> ast::Type { + ast::Type::Scalar(match self { + ast::MinMaxDetails::Signed(t) => (*t).into(), + ast::MinMaxDetails::Unsigned(t) => (*t).into(), + ast::MinMaxDetails::Float(d) => d.typ.into(), + }) } } @@ -4085,6 +4297,30 @@ impl ast::IntType { } } +impl ast::SIntType { + fn from_size(width: u8) -> Self { + match width { + 1 => ast::SIntType::S8, + 2 => ast::SIntType::S16, + 4 => ast::SIntType::S32, + 8 => ast::SIntType::S64, + _ => unreachable!(), + } + } +} + +impl ast::UIntType { + fn from_size(width: u8) -> Self { + match width { + 1 => ast::UIntType::U8, + 2 => ast::UIntType::U16, + 4 => ast::UIntType::U32, + 8 => ast::UIntType::U64, + _ => unreachable!(), + } + } +} + impl ast::LdStateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { @@ -4128,7 +4364,8 @@ impl ast::OperandOrVector { impl ast::MulDetails { fn is_wide(&self) -> bool { match self { - ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide, + ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide, + ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide, ast::MulDetails::Float(_) => false, } }