diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl index a0d487b..4249f2b 100644 --- a/ptx/lib/notcuda_ptx_impl.cl +++ b/ptx/lib/notcuda_ptx_impl.cl @@ -1,5 +1,5 @@ // Every time this file changes it must te rebuilt: -// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0" -out_dir . -device kbl -output_no_suffix -spv_only +// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0 -Dcl_intel_bit_instructions" -out_dir . -device kbl -output_no_suffix -spv_only // Additionally you should strip names: // spirv-opt --strip-debug notcuda_ptx_impl.spv -o notcuda_ptx_impl.spv @@ -119,3 +119,23 @@ atomic_dec(atom_relaxed_sys_shared_dec, memory_order_relaxed, memory_order_relax atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local); atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local); atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local); + +uint FUNC(bfe_u32)(uint base, uint pos, uint len) +{ + return intel_ubfe(base, pos, len); +} + +ulong FUNC(bfe_u64)(ulong base, uint pos, uint len) +{ + return intel_ubfe(base, pos, len); +} + +int FUNC(bfe_s32)(int base, uint pos, uint len) +{ + return intel_sbfe(base, pos, len); +} + +long FUNC(bfe_s64)(long base, uint pos, uint len) +{ + return intel_sbfe(base, pos, len); +} \ No newline at end of file diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv index 36f37bb..1ef470f 100644 Binary files a/ptx/lib/notcuda_ptx_impl.spv and b/ptx/lib/notcuda_ptx_impl.spv differ diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index b6ac3db..5a5f6be 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -558,7 +558,7 @@ pub enum Instruction { Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), SetpBool(SetpBoolData, Arg5

), - Not(NotType, Arg2

), + Not(BooleanType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), Cvta(CvtaDetails, Arg2

), @@ -569,12 +569,12 @@ pub enum Instruction { Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), - Or(OrAndType, Arg3

), + Or(BooleanType, Arg3

), Sub(ArithDetails, Arg3

), Min(MinMaxDetails, Arg3

), Max(MinMaxDetails, Arg3

), Rcp(RcpDetails, Arg2

), - And(OrAndType, Arg3

), + And(BooleanType, Arg3

), Selp(SelpType, Arg4

), Bar(BarDetails, Arg1Bar

), Atom(AtomDetails, Arg3

), @@ -590,6 +590,9 @@ pub enum Instruction { Clz { typ: BitType, arg: Arg2

}, Brev { typ: BitType, arg: Arg2

}, Popc { typ: BitType, arg: Arg2

}, + Xor { typ: BooleanType, arg: Arg3

}, + Bfe { typ: IntType, arg: Arg4

}, + Rem { typ: IntType, arg: Arg3

}, } #[derive(Copy, Clone)] @@ -896,14 +899,6 @@ pub struct SetpBoolData { pub bool_op: SetpBoolPostOp, } -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum NotType { - Pred, - B16, - B32, - B64, -} - pub struct BraData { pub uniform: bool, } @@ -1058,7 +1053,7 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(OrAndType { +sub_enum!(BooleanType { Pred, B16, B32, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index cd1c642..6c231b2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -142,6 +142,7 @@ match { "atom", "bar", "barrier", + "bfe", "bra", "brev", "call", @@ -166,6 +167,7 @@ match { "or", "popc", "rcp", + "rem", "ret", "rsqrt", "selp", @@ -179,6 +181,7 @@ match { "sub", "texmode_independent", "texmode_unified", + "xor", } else { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID, @@ -192,6 +195,7 @@ ExtendedID : &'input str = { "atom", "bar", "barrier", + "bfe", "bra", "brev", "call", @@ -216,6 +220,7 @@ ExtendedID : &'input str = { "or", "popc", "rcp", + "rem", "ret", "rsqrt", "selp", @@ -229,6 +234,7 @@ ExtendedID : &'input str = { "sub", "texmode_independent", "texmode_unified", + "xor", ID } @@ -708,6 +714,9 @@ Instruction: ast::Instruction> = { InstClz, InstBrev, InstPopc, + InstXor, + InstRem, + InstBfe, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -874,6 +883,13 @@ IntType : ast::IntType = { ".s64" => ast::IntType::S64, }; +IntType3264: ast::IntType = { + ".u32" => ast::IntType::U32, + ".u64" => ast::IntType::U64, + ".s32" => ast::IntType::S32, + ".s64" => ast::IntType::S64, +} + UIntType: ast::UIntType = { ".u16" => ast::UIntType::U16, ".u32" => ast::UIntType::U32, @@ -979,14 +995,14 @@ SetpTypeNoF32: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not InstNot: ast::Instruction> = { - "not" => ast::Instruction::Not(t, a) + "not" => ast::Instruction::Not(t, a) }; -NotType: ast::NotType = { - ".pred" => ast::NotType::Pred, - ".b16" => ast::NotType::B16, - ".b32" => ast::NotType::B32, - ".b64" => ast::NotType::B64, +BooleanType: ast::BooleanType = { + ".pred" => ast::BooleanType::Pred, + ".b16" => ast::BooleanType::B16, + ".b32" => ast::BooleanType::B32, + ".b64" => ast::BooleanType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at @@ -1294,19 +1310,12 @@ SignedIntType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or InstOr: ast::Instruction> = { - "or" => ast::Instruction::Or(d, a), + "or" => ast::Instruction::Or(d, a), }; -OrAndType: ast::OrAndType = { - ".pred" => ast::OrAndType::Pred, - ".b16" => ast::OrAndType::B16, - ".b32" => ast::OrAndType::B32, - ".b64" => ast::OrAndType::B64, -} - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and InstAnd: ast::Instruction> = { - "and" => ast::Instruction::And(d, a), + "and" => ast::Instruction::And(d, a), }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp @@ -1447,7 +1456,7 @@ InstAtom: ast::Instruction> = { }; ast::Instruction::Atom(details,a) }, - "atom" => { + "atom" => { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1456,7 +1465,7 @@ InstAtom: ast::Instruction> = { }; ast::Instruction::Atom(details,a) }, - "atom" => { + "atom" => { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1515,12 +1524,12 @@ BitType: ast::BitType = { ".b64" => ast::BitType::B64, } -AtomUIntType: ast::UIntType = { +UIntType3264: ast::UIntType = { ".u32" => ast::UIntType::U32, ".u64" => ast::UIntType::U64, } -AtomSIntType: ast::SIntType = { +SIntType3264: ast::SIntType = { ".s32" => ast::SIntType::S32, ".s64" => ast::SIntType::S64, } @@ -1664,6 +1673,22 @@ InstPopc: ast::Instruction> = { "popc" => ast::Instruction::Popc{ <> } } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor +InstXor: ast::Instruction> = { + "xor" => ast::Instruction::Xor{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe +InstBfe: ast::Instruction> = { + "bfe" => ast::Instruction::Bfe{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem +InstRem: ast::Instruction> = { + "rem" => ast::Instruction::Rem{ <> } +} + + NegTypeFtz: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, diff --git a/ptx/src/test/spirv_run/bfe.ptx b/ptx/src/test/spirv_run/bfe.ptx new file mode 100644 index 0000000..60ee8a6 --- /dev/null +++ b/ptx/src/test/spirv_run/bfe.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry bfe( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp<3>; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp0, [in_addr]; + ld.u32 temp1, [in_addr+4]; + ld.u32 temp2, [in_addr+8]; + bfe.u32 temp0, temp0, temp1, temp2; + st.u32 [out_addr], temp0; + ret; +} diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt new file mode 100644 index 0000000..edcf138 --- /dev/null +++ b/ptx/src/test/spirv_run/bfe.spvtxt @@ -0,0 +1,70 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %40 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "bfe" + OpDecorate %34 LinkageAttributes "__notcuda_ptx_impl__bfe_u32" Import + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %43 = OpTypeFunction %uint %uint %uint %uint + %ulong = OpTypeInt 64 0 + %45 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %ulong_8 = OpConstant %ulong 8 + %34 = OpFunction %uint None %43 + %36 = OpFunctionParameter %uint + %37 = OpFunctionParameter %uint + %38 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %45 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %33 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 + OpStore %4 %11 + %12 = OpLoad %ulong %3 + OpStore %5 %12 + %14 = OpLoad %ulong %4 + %29 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %29 + OpStore %6 %13 + %16 = OpLoad %ulong %4 + %26 = OpIAdd %ulong %16 %ulong_4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %26 + %15 = OpLoad %uint %30 + OpStore %7 %15 + %18 = OpLoad %ulong %4 + %28 = OpIAdd %ulong %18 %ulong_8 + %31 = OpConvertUToPtr %_ptr_Generic_uint %28 + %17 = OpLoad %uint %31 + OpStore %8 %17 + %20 = OpLoad %uint %6 + %21 = OpLoad %uint %7 + %22 = OpLoad %uint %8 + %19 = OpFunctionCall %uint %34 %20 %21 %22 + OpStore %6 %19 + %23 = OpLoad %ulong %5 + %24 = OpLoad %uint %6 + %32 = OpConvertUToPtr %_ptr_Generic_uint %23 + OpStore %32 %24 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a7ef75b..5bbe45a 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -116,6 +116,20 @@ test_ptx!( [0b11000111_01011100_10101110_11111011u32], [0b11011111_01110101_00111010_11100011u32] ); +test_ptx!( + xor, + [ + 0b01010010_00011010_01000000_00001101u32, + 0b11100110_10011011_00001100_00100011u32 + ], + [0b10110100100000010100110000101110u32] +); +test_ptx!(rem, [21692i32, 13i32], [8i32]); +test_ptx!( + bfe, + [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32], + [0b11000001u32] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/rem.ptx b/ptx/src/test/spirv_run/rem.ptx new file mode 100644 index 0000000..2ac482d --- /dev/null +++ b/ptx/src/test/spirv_run/rem.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry rem( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .s32 temp1; + .reg .s32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.s32 temp1, [in_addr]; + ld.s32 temp2, [in_addr+4]; + rem.s32 temp1, temp1, temp2; + st.s32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt new file mode 100644 index 0000000..72d0965 --- /dev/null +++ b/ptx/src/test/spirv_run/rem.spvtxt @@ -0,0 +1,55 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %28 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "rem" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %31 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %31 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %26 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 + OpStore %4 %10 + %11 = OpLoad %ulong %3 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_Generic_uint %13 + %12 = OpLoad %uint %23 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %22 = OpIAdd %ulong %15 %ulong_4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %22 + %14 = OpLoad %uint %24 + OpStore %7 %14 + %17 = OpLoad %uint %6 + %18 = OpLoad %uint %7 + %16 = OpSMod %uint %17 %18 + OpStore %6 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %25 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %25 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/xor.ptx b/ptx/src/test/spirv_run/xor.ptx new file mode 100644 index 0000000..a28b321 --- /dev/null +++ b/ptx/src/test/spirv_run/xor.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry xor( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp1; + .reg .b32 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.b32 temp1, [in_addr]; + ld.b32 temp2, [in_addr+4]; + xor.b32 temp1, temp1, temp2; + st.b32 [out_addr], temp1; + ret; +} diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt new file mode 100644 index 0000000..ee09898 --- /dev/null +++ b/ptx/src/test/spirv_run/xor.spvtxt @@ -0,0 +1,55 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %28 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "xor" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %31 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %31 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %26 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 + OpStore %4 %10 + %11 = OpLoad %ulong %3 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_Generic_uint %13 + %12 = OpLoad %uint %23 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %22 = OpIAdd %ulong %15 %ulong_4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %22 + %14 = OpLoad %uint %24 + OpStore %7 %14 + %17 = OpLoad %uint %6 + %18 = OpLoad %uint %7 + %16 = OpBitwiseXor %uint %17 %18 + OpStore %6 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %25 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %25 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 23a63be..365d1e8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1289,6 +1289,9 @@ fn extract_globals<'input, 'b>( .. }, ) => global.push(var), + Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => { + local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg)); + } Statement::Instruction(ast::Instruction::Atom( d @ @@ -1591,6 +1594,24 @@ fn convert_to_typed_statements( arg: arg.cast(), })) } + ast::Instruction::Xor { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Xor { + typ, + arg: arg.cast(), + })) + } + ast::Instruction::Bfe { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Bfe { + typ, + arg: arg.cast(), + })) + } + ast::Instruction::Rem { typ, arg } => { + result.push(Statement::Instruction(ast::Instruction::Rem { + typ, + arg: arg.cast(), + })) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -1610,6 +1631,7 @@ fn convert_to_typed_statements( Ok(result) } +//TODO: share common code between this and to_ptx_impl_bfe_call fn to_ptx_impl_atomic_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, @@ -1705,6 +1727,100 @@ fn to_ptx_impl_atomic_call( }) } +fn to_ptx_impl_bfe_call( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + typ: ast::IntType, + arg: ast::Arg4, +) -> ExpandedStatement { + let prefix = "__notcuda_ptx_impl__"; + let suffix = match typ { + ast::IntType::U32 => "bfe_u32", + ast::IntType::U64 => "bfe_u64", + ast::IntType::S32 => "bfe_s32", + ast::IntType::S64 => "bfe_s64", + _ => unreachable!(), + }; + let fn_name = format!("{}{}", prefix, suffix); + let fn_id = match ptx_impl_imports.entry(fn_name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.new_id(None); + let func_decl = ast::MethodDecl::Func::( + vec![ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + name: id_defs.new_id(None), + array_init: Vec::new(), + }], + fn_id, + vec![ + ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + name: id_defs.new_id(None), + array_init: Vec::new(), + }, + ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( + ast::ScalarType::U32, + )), + name: id_defs.new_id(None), + array_init: Vec::new(), + }, + ast::FnArgument { + align: None, + v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( + ast::ScalarType::U32, + )), + name: id_defs.new_id(None), + array_init: Vec::new(), + }, + ], + ); + let spirv_decl = SpirvMethodDecl::new(&func_decl); + let func = Function { + func_decl, + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + spirv_decl, + }; + entry.insert(Directive::Method(func)); + fn_id + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { + func_decl: ast::MethodDecl::Func(_, name, _), + .. + }) => *name, + _ => unreachable!(), + }, + }; + Statement::Call(ResolvedCall { + uniform: false, + func: fn_id, + ret_params: vec![( + arg.dst, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + )], + param_list: vec![ + ( + arg.src1, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), + ), + ( + arg.src2, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ), + ( + arg.src3, + ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), + ), + ], + }) +} + fn to_resolved_fn_args( params: Vec, params_decl: &[ast::FnArgumentType], @@ -2803,7 +2919,7 @@ fn emit_function_body_ops( let result_id = Some(a.dst); let operand = a.src; match t { - ast::NotType::Pred => { + ast::BooleanType::Pred => { // HACK ALERT // Temporary workaround until IGC gets its shit together // Currently IGC carries two copies of SPIRV-LLVM translator @@ -2854,7 +2970,7 @@ fn emit_function_body_ops( }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::OrAndType::Pred { + if *t == ast::BooleanType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; @@ -2882,7 +2998,7 @@ fn emit_function_body_ops( } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); - if *t == ast::OrAndType::Pred { + if *t == ast::BooleanType::Pred { builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; @@ -3033,6 +3149,39 @@ fn emit_function_body_ops( let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder.bit_count(result_type, Some(arg.dst), arg.src)?; } + ast::Instruction::Xor { typ, arg } => { + let builder_fn = match typ { + ast::BooleanType::Pred => emit_logical_xor_spirv, + _ => dr::Builder::bitwise_xor, + }; + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; + } + ast::Instruction::Bfe { typ, arg } => { + let builder_fn = if typ.is_signed() { + dr::Builder::bit_field_s_extract + } else { + dr::Builder::bit_field_u_extract + }; + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder_fn( + builder, + result_type, + Some(arg.dst), + arg.src1, + arg.src2, + arg.src3, + )?; + } + ast::Instruction::Rem { typ, arg } => { + let builder_fn = if typ.is_signed() { + dr::Builder::s_mod + } else { + dr::Builder::u_mod + }; + let result_type = map.get_or_add_scalar(builder, (*typ).into()); + builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); @@ -3079,6 +3228,20 @@ fn emit_function_body_ops( Ok(()) } +// TODO: check what kind of assembly do we emit +fn emit_logical_xor_spirv( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + op1: spirv::Word, + op2: spirv::Word, +) -> Result { + let temp_or = builder.logical_or(result_type, None, op1, op2)?; + let temp_and = builder.logical_and(result_type, None, op1, op2)?; + let temp_neg = builder.logical_not(result_type, None, temp_and)?; + builder.logical_and(result_type, result_id, temp_or, temp_neg) +} + fn emit_sqrt( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -5039,6 +5202,27 @@ impl ast::Instruction { arg: arg.map_different_types(visitor, &dst_type, &src_type)?, } } + ast::Instruction::Xor { typ, arg } => { + let full_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Xor { + typ, + arg: arg.map_non_shift(visitor, &full_type, false)?, + } + } + ast::Instruction::Bfe { typ, arg } => { + let full_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Bfe { + typ, + arg: arg.map_bfe(visitor, &full_type)?, + } + } + ast::Instruction::Rem { typ, arg } => { + let full_type = ast::Type::Scalar(typ.into()); + ast::Instruction::Rem { + typ, + arg: arg.map_non_shift(visitor, &full_type, false)?, + } + } }) } } @@ -5351,6 +5535,9 @@ impl ast::Instruction { ast::Instruction::Clz { .. } => None, ast::Instruction::Brev { .. } => None, ast::Instruction::Popc { .. } => None, + ast::Instruction::Xor { .. } => None, + ast::Instruction::Bfe { .. } => None, + ast::Instruction::Rem { .. } => None, ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) @@ -6192,6 +6379,52 @@ impl ast::Arg4 { src3, }) } + + fn map_bfe>( + self, + visitor: &mut V, + typ: &ast::Type, + ) -> Result, TranslateError> { + let dst = visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(typ), + )?; + let src1 = visitor.operand( + ArgumentDescriptor { + op: self.src1, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + typ, + )?; + let u32_type = ast::Type::Scalar(ast::ScalarType::U32); + let src2 = visitor.operand( + ArgumentDescriptor { + op: self.src2, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &u32_type, + )?; + let src3 = visitor.operand( + ArgumentDescriptor { + op: self.src3, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + &u32_type, + )?; + Ok(ast::Arg4 { + dst, + src1, + src2, + src3, + }) + } } impl ast::Arg4Setp { @@ -6437,13 +6670,13 @@ impl ast::ScalarType { } } -impl ast::NotType { +impl ast::BooleanType { fn to_type(self) -> ast::Type { match self { - ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), - ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16), - ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32), - ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64), + ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), + ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16), + ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32), + ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64), } } }