diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 979bedf..6bb099a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -68,6 +68,17 @@ pub enum Type { ExtendedScalar(ExtendedScalarType), } +impl From for Type { + fn from(t: FloatType) -> Self { + match t { + FloatType::F16 => Type::Scalar(ScalarType::F16), + FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2), + FloatType::F32 => Type::Scalar(ScalarType::F32), + FloatType::F64 => Type::Scalar(ScalarType::F64), + } + } +} + #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum ScalarType { B8, @@ -87,6 +98,37 @@ pub enum ScalarType { F64, } +impl From for ScalarType { + fn from(t: IntType) -> Self { + match t { + IntType::S16 => ScalarType::S16, + IntType::S32 => ScalarType::S32, + IntType::S64 => ScalarType::S64, + IntType::U16 => ScalarType::U16, + IntType::U32 => ScalarType::U32, + IntType::U64 => ScalarType::U64, + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub enum IntType { + U16, + U32, + U64, + S16, + S32, + S64, +} + +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub enum FloatType { + F16, + F16x2, + F32, + F64, +} + #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum ExtendedScalarType { F16x2, @@ -130,8 +172,8 @@ pub struct PredAt { pub enum Instruction { Ld(LdData, Arg2), Mov(MovData, Arg2Mov), - Mul(MulData, Arg3), - Add(AddData, Arg3), + Mul(MulDetails, Arg3), + Add(AddDetails, Arg3), Setp(SetpData, Arg4), SetpBool(SetpBoolData, Arg5), Not(NotData, Arg2), @@ -244,23 +286,24 @@ pub struct MovData { pub typ: Type, } -pub struct MulData { - pub typ: Type, - pub desc: MulDescriptor, +pub enum MulDetails { + Int(MulIntDesc), + Float(MulFloatDesc), } -pub enum MulDescriptor { - Int(MulIntControl), - Float(MulFloatDesc), +pub struct MulIntDesc { + pub typ: IntType, + pub control: MulIntControl, } pub enum MulIntControl { Low, High, - Wide + Wide, } pub struct MulFloatDesc { + pub typ: FloatType, pub rounding: Option, pub flush_to_zero: bool, pub saturate: bool, @@ -270,11 +313,24 @@ pub enum RoundingMode { NearestEven, Zero, NegativeInf, - PositiveInf + PositiveInf, } -pub struct AddData { - pub typ: ScalarType, +pub enum AddDetails { + Int(AddIntDesc), + Float(AddFloatDesc), +} + +pub struct AddIntDesc { + pub typ: IntType, + pub saturate: bool, +} + +pub struct AddFloatDesc { + pub typ: FloatType, + pub rounding: Option, + pub flush_to_zero: bool, + pub saturate: bool, } pub struct SetpData { @@ -310,7 +366,7 @@ pub struct SetpBoolData { pub typ: ScalarType, pub flush_to_zero: bool, pub cmp_op: SetpCompareOp, - pub bool_op: SetpBoolPostOp + pub bool_op: SetpBoolPostOp, } pub struct NotData {} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index b44702d..cc58cf2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -398,43 +398,35 @@ InstMul: ast::Instruction<&'input str> = { "mul" => ast::Instruction::Mul(d, a) }; -InstMulMode: ast::MulData = { - => ast::MulData{ - typ: ast::Type::Scalar(t), - desc: ast::MulDescriptor::Int(ctr) - }, - ".f32" => ast::MulData{ - typ: ast::Type::Scalar(ast::ScalarType::F32), - desc: ast::MulDescriptor::Float(ast::MulFloatDesc { - rounding: r, - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }) - }, - ".f64" => ast::MulData{ - typ: ast::Type::Scalar(ast::ScalarType::F64), - desc: ast::MulDescriptor::Float(ast::MulFloatDesc { - rounding: r, - flush_to_zero: false, - saturate: false - }) - }, - ".f16" => ast::MulData{ - typ: ast::Type::Scalar(ast::ScalarType::F16), - desc: ast::MulDescriptor::Float(ast::MulFloatDesc { - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }) - }, - ".f16x2" => ast::MulData{ - typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2), - desc: ast::MulDescriptor::Float(ast::MulFloatDesc { - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }) - } +InstMulMode: ast::MulDetails = { + => ast::MulDetails::Int(ast::MulIntDesc { + typ: t, + control: ctr + }), + ".f32" => ast::MulDetails::Float(ast::MulFloatDesc { + typ: ast::FloatType::F32, + rounding: r, + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }), + ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { + typ: ast::FloatType::F64, + rounding: r, + flush_to_zero: false, + saturate: false + }), + ".f16" => ast::MulDetails::Float(ast::MulFloatDesc { + typ: ast::FloatType::F16, + rounding: r.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }), + ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc { + typ: ast::FloatType::F16x2, + rounding: r.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }) }; MulIntControl: ast::MulIntControl = { @@ -451,13 +443,13 @@ RoundingMode : ast::RoundingMode = { ".rp" => ast::RoundingMode::PositiveInf, }; -IntType : ast::ScalarType = { - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, +IntType : ast::IntType = { + ".u16" => ast::IntType::U16, + ".u32" => ast::IntType::U32, + ".u64" => ast::IntType::U64, + ".s16" => ast::IntType::S16, + ".s32" => ast::IntType::S32, + ".s64" => ast::IntType::S64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -467,12 +459,33 @@ InstAdd: ast::Instruction<&'input str> = { "add" => ast::Instruction::Add(d, a) }; -InstAddMode: ast::AddData = { - => ast::AddData{ typ: t }, - ".sat" ".s32" => ast::AddData{ typ: ast::ScalarType::S32 }, - RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{ typ: ast::ScalarType::F32 }, - RoundingMode? ".f64" => ast::AddData{ typ: ast::ScalarType::F64 }, - ".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{ typ: ast::ScalarType::F16 }, +InstAddMode: ast::AddDetails = { + => ast::AddDetails::Int(ast::AddIntDesc { + typ: t, + saturate: false, + }), + ".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc { + typ: ast::IntType::S32, + saturate: true, + }), + ".f32" => ast::AddDetails::Float(ast::AddFloatDesc { + typ: ast::FloatType::F32, + rounding: rn, + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }), + ".f64" => ast::AddDetails::Float(ast::AddFloatDesc { + typ: ast::FloatType::F64, + rounding: rn, + flush_to_zero: false, + saturate: false, + }), + ".f16" => ast::AddDetails::Float(ast::AddFloatDesc { + typ: ast::FloatType::F16, + rounding: rn.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }), ".rn"? ".ftz"? ".sat"? ".f16x2" => todo!() }; diff --git a/ptx/src/test/spirv_run/add.ptx b/ptx/src/test/spirv_run/add.ptx new file mode 100644 index 0000000..6762eae --- /dev/null +++ b/ptx/src/test/spirv_run/add.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry add( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/add.spvtxt b/ptx/src/test/spirv_run/add.spvtxt new file mode 100644 index 0000000..465a74e --- /dev/null +++ b/ptx/src/test/spirv_run/add.spvtxt @@ -0,0 +1,38 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %21 = OpLabel + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + OpStore %8 %6 + OpStore %9 %7 + %12 = OpLoad %ulong %8 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %12 + %13 = OpLoad %ulong %19 + OpStore %10 %13 + %14 = OpLoad %ulong %10 + %15 = OpIAdd %ulong %14 %ulong_1 + OpStore %11 %15 + %16 = OpLoad %ulong %9 + %17 = OpLoad %ulong %11 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %20 %17 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c1ef574..32e46ce 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -43,6 +43,7 @@ test_ptx!(ld_st, [1u64], [1u64]); test_ptx!(mov, [1u64], [1u64]); test_ptx!(mul_lo, [1u64], [2u64]); test_ptx!(mul_hi, [u64::max_value()], [1u64]); +test_ptx!(add, [1u64], [2u64]); struct DisplayError { err: T, @@ -233,6 +234,9 @@ fn is_instr_equal( instr2: &Instruction, map: &mut HashMap, ) -> bool { + if instr1.class.opcode != instr2.class.opcode { + return false; + } if !is_option_equal(&instr1.result_type, &instr2.result_type, map, is_word_equal) { return false; } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7512545..b2831a0 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -355,11 +355,11 @@ fn normalize_insert_instruction( Instruction::Mov(d, arg) } Instruction::Mul(d, a) => { - let arg = normalize_expand_arg3(func, id_def, &|| d.typ.try_as_scalar(), a); + let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); Instruction::Mul(d, arg) } Instruction::Add(d, a) => { - let arg = normalize_expand_arg3(func, id_def, &|| Some(d.typ), a); + let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a); Instruction::Add(d, arg) } Instruction::Setp(d, a) => { @@ -731,11 +731,17 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } - Instruction::Mul(mul, arg) => match mul.desc { - ast::MulDescriptor::Int(ref ctr) => { - emit_mul_int(builder, map, opencl, mul.typ, ctr, arg)?; + Instruction::Mul(mul, arg) => match mul { + ast::MulDetails::Int(ref ctr) => { + emit_mul_int(builder, map, opencl, ctr, arg)?; } - ast::MulDescriptor::Float(_) => todo!(), + ast::MulDetails::Float(_) => todo!(), + }, + Instruction::Add(add, arg) => match add { + ast::AddDetails::Int(ref desc) => { + emit_add_int(builder, map, desc, arg)?; + } + ast::AddDetails::Float(_) => todo!(), }, _ => todo!(), }, @@ -755,26 +761,24 @@ fn emit_mul_int( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, - typ: ast::Type, - ctr: &ast::MulIntControl, + desc: &ast::MulIntDesc, arg: &Arg3, ) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(typ)); - match ctr { + let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into())); + match desc.control { ast::MulIntControl::Low => { builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::MulIntControl::High => { - let ocl_mul_hi = match typ.try_as_scalar().unwrap().kind() { - ScalarKind::Signed => spirv::CLOp::s_mul_hi, - ScalarKind::Unsigned => spirv::CLOp::u_mul_hi, - ScalarKind::Float => unreachable!(), - ScalarKind::Byte => unreachable!(), + let ocl_mul_hi = if desc.typ.is_signed() { + spirv::CLOp::s_mul_hi + } else { + spirv::CLOp::u_mul_hi }; builder.ext_inst( inst_type, Some(arg.dst), - 1, + opencl, ocl_mul_hi as spirv::Word, [arg.src1, arg.src2], )?; @@ -784,6 +788,17 @@ fn emit_mul_int( Ok(()) } +fn emit_add_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ctr: &ast::AddIntDesc, + arg: &Arg3, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into())); + builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; + Ok(()) +} + fn emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1059,8 +1074,8 @@ type ExpandedStatement = Statement; enum Instruction { Ld(ast::LdData, A::Arg2), Mov(ast::MovData, A::Arg2Mov), - Mul(ast::MulData, A::Arg3), - Add(ast::AddData, A::Arg3), + Mul(ast::MulDetails, A::Arg3), + Add(ast::AddDetails, A::Arg3), Setp(ast::SetpData, A::Arg4), SetpBool(ast::SetpBoolData, A::Arg5), Not(ast::NotData, A::Arg2), @@ -1091,12 +1106,22 @@ impl Instruction { fn get_type(&self) -> Option { match self { - Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)), + Instruction::Add(add, _) => match add { + ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => { + Some(ast::Type::Scalar((*typ).into())) + } + ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => Some((*typ).into()), + }, Instruction::Ret(_) => None, Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), Instruction::Mov(mov, _) => Some(mov.typ), - Instruction::Mul(mul, _) => Some(mul.typ), + Instruction::Mul(mul, _) => match mul { + ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => { + Some(ast::Type::Scalar((*typ).into())) + } + ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => Some((*typ).into()), + }, _ => todo!(), } } @@ -1437,12 +1462,12 @@ impl ast::Instruction { fn get_type(&self) -> Option { match self { - ast::Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)), + ast::Instruction::Add(add, _) => Some(add.get_type()), ast::Instruction::Ret(_) => None, ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), ast::Instruction::Mov(mov, _) => Some(mov.typ), - ast::Instruction::Mul(mul, _) => Some(mul.typ), + ast::Instruction::Mul(mul, _) => Some(mul.get_type()), _ => todo!(), } } @@ -1800,6 +1825,33 @@ impl ast::ScalarType { } } +impl ast::AddDetails { + fn get_type(&self) -> ast::Type { + match self { + ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), + ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(), + } + } +} + +impl ast::MulDetails { + fn get_type(&self) -> ast::Type { + match self { + ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()), + ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(), + } + } +} + +impl ast::IntType { + fn is_signed(self) -> bool { + match self { + ast::IntType::S16 | ast::IntType::S32 | ast::IntType::S64 => true, + ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false, + } + } +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {