diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..86df120 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1 @@ +21ef5f60a3a5efa17855a30f6b5c7d1968cd46ba diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 01c5073..84a104f 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index c3643a5..e8e3206 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -562,4 +562,19 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return ballot(value, true); } + +#define REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type) \ + cpp_type __ockl_wfred_##reducer##_##amd_type(cpp_type) __device__; \ + cpp_type FUNC(redux_sync_##reducer##_##ptx_type)(cpp_type src, uint32_t membermask __attribute__((unused))) \ + { \ + return __ockl_wfred_##reducer##_##amd_type(src); \ + } + +#define REDUX_SYNC_IMPL(reducer) \ + REDUX_SYNC_TYPE_IMPL(reducer, u32, u32, uint32_t) \ + REDUX_SYNC_TYPE_IMPL(reducer, s32, i32, int32_t) + + REDUX_SYNC_IMPL(add); + REDUX_SYNC_IMPL(min); + REDUX_SYNC_IMPL(max); } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 0fd15b8..c0f537b 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -195,7 +195,8 @@ fn run_instruction<'input>( | ast::Instruction::Tanh { .. } | ast::Instruction::Trap {} | ast::Instruction::Xor { .. } - | ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 4e82b6a..504d36b 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1853,7 +1853,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Mul24 { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::AtomCas { .. } - | ast::Instruction::Vote { .. } => InstructionModes::none(), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index d3b81cd..d0e826c 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -526,7 +526,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Activemask { .. } | ast::Instruction::ShflSync { .. } | ast::Instruction::Vote { .. } - | ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), + | ast::Instruction::Nanosleep { .. } + | ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()), } } @@ -1645,9 +1646,39 @@ impl<'a> MethodEmitContext<'a> { } }; let src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst_type, dst) - }); + if let Some(src2) = arguments.src2 { + let packed_type = get_scalar_type( + self.context, + data.to + .packed_type() + .ok_or_else(|| error_mismatched_type())?, + ); + let src2 = self.resolver.value(src2)?; + self.resolver.with_result(arguments.dst, |dst| { + let vec = unsafe { + LLVMBuildInsertElement( + self.builder, + LLVMGetPoison(dst_type), + llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32), + LLVM_UNNAMED.as_ptr(), + ) + }; + unsafe { + LLVMBuildInsertElement( + self.builder, + vec, + llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32), + dst, + ) + } + }) + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }) + }; Ok(()) } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index b92ab95..2a939fd 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -377,6 +377,7 @@ fn run_instruction<'input>( let name = match data.pred_reduction { ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::Or => "bar_red_or_pred", + _ => return Err(error_unreachable()), }; to_call( resolver, @@ -400,6 +401,25 @@ fn run_instruction<'input>( ptx_parser::Instruction::Vote { data, arguments }, )? } + ptx_parser::Instruction::ReduxSync { data, arguments } => { + let op = match data.reduction { + ptx_parser::Reduction::Add => "add", + ptx_parser::Reduction::Min => "min", + ptx_parser::Reduction::Max => "max", + _ => return Err(error_unreachable()), + }; + let name = format!( + "redux_sync_{}_{}", + op, + data.type_.to_string().replace(".", "") + ); + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::ReduxSync { data, arguments }, + )? + } ptx_parser::Instruction::ShflSync { data, arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. }, diff --git a/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll new file mode 100644 index 0000000..1e19037 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll @@ -0,0 +1,41 @@ +define amdgpu_kernel void @cvt_rn_bf16x2_f32(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca float, align 4, addrspace(5) + %"42" = alloca float, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"44" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"44", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"45", ptr addrspace(5) %"40", align 8 + %"47" = load i64, ptr addrspace(5) %"39", align 8 + %"55" = inttoptr i64 %"47" to ptr + %"46" = load float, ptr %"55", align 4 + store float %"46", ptr addrspace(5) %"41", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"56" = inttoptr i64 %"48" to ptr + %"35" = getelementptr inbounds i8, ptr %"56", i64 4 + %"49" = load float, ptr %"35", align 4 + store float %"49", ptr addrspace(5) %"42", align 4 + %"51" = load float, ptr addrspace(5) %"41", align 4 + %"52" = load float, ptr addrspace(5) %"42", align 4 + %2 = fptrunc float %"51" to bfloat + %3 = insertelement <2 x bfloat> poison, bfloat %2, i32 1 + %4 = fptrunc float %"52" to bfloat + %"57" = insertelement <2 x bfloat> %3, bfloat %4, i32 0 + %"50" = bitcast <2 x bfloat> %"57" to i32 + store i32 %"50", ptr addrspace(5) %"43", align 4 + %"53" = load i64, ptr addrspace(5) %"40", align 8 + %"54" = load i32, ptr addrspace(5) %"43", align 4 + %"58" = inttoptr i64 %"53" to ptr + store i32 %"54", ptr %"58", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_add_u32_partial.ll b/ptx/src/test/ll/redux_sync_add_u32_partial.ll new file mode 100644 index 0000000..ab55c15 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_add_u32_partial.ll @@ -0,0 +1,58 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_add_u32_partial(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i1, align 1, addrspace(5) + %"62" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"52" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"52", ptr addrspace(5) %"49", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"55" = load i32, ptr addrspace(5) %"47", align 4 + %"54" = urem i32 %"55", 2 + store i32 %"54", ptr addrspace(5) %"50", align 4 + %"57" = load i32, ptr addrspace(5) %"50", align 4 + %2 = icmp eq i32 %"57", 0 + store i1 %2, ptr addrspace(5) %"51", align 1 + store i32 0, ptr addrspace(5) %"48", align 4 + %"59" = load i1, ptr addrspace(5) %"51", align 1 + br i1 %"59", label %"16", label %"17" + +"16": ; preds = %"44" + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"61", i32 1431655765) + store i32 %"60", ptr addrspace(5) %"48", align 4 + br label %"17" + +"17": ; preds = %"16", %"44" + %"64" = load i32, ptr addrspace(5) %"47", align 4 + %3 = zext i32 %"64" to i64 + %"63" = mul i64 %3, 4 + store i64 %"63", ptr addrspace(5) %"62", align 8 + %"66" = load i64, ptr addrspace(5) %"49", align 8 + %"67" = load i64, ptr addrspace(5) %"62", align 8 + %"65" = add i64 %"66", %"67" + store i64 %"65", ptr addrspace(5) %"49", align 8 + %"68" = load i64, ptr addrspace(5) %"49", align 8 + %"69" = load i32, ptr addrspace(5) %"48", align 4 + %"70" = inttoptr i64 %"68" to ptr + store i32 %"69", ptr %"70", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_s32.ll b/ptx/src/test/ll/redux_sync_op_s32.ll new file mode 100644 index 0000000..a84624b --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_s32.ll @@ -0,0 +1,67 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_s32(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i32, align 4, addrspace(5) + %"52" = alloca i32, align 4, addrspace(5) + %"53" = alloca i64, align 8, addrspace(5) + %"70" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"54" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"54", ptr addrspace(5) %"53", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"57" = load i32, ptr addrspace(5) %"47", align 4 + %"56" = sub i32 %"57", 5 + store i32 %"56", ptr addrspace(5) %"48", align 4 + %"59" = load i32, ptr addrspace(5) %"48", align 4 + %"58" = call i32 @__zluda_ptx_impl_redux_sync_add_s32(i32 %"59", i32 -1) + store i32 %"58", ptr addrspace(5) %"49", align 4 + %"61" = load i32, ptr addrspace(5) %"48", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_min_s32(i32 %"61", i32 -1) + store i32 %"60", ptr addrspace(5) %"50", align 4 + %"63" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = call i32 @__zluda_ptx_impl_redux_sync_max_s32(i32 %"63", i32 -1) + store i32 %"62", ptr addrspace(5) %"51", align 4 + %"65" = load i32, ptr addrspace(5) %"49", align 4 + %"66" = load i32, ptr addrspace(5) %"50", align 4 + %"64" = add i32 %"65", %"66" + store i32 %"64", ptr addrspace(5) %"52", align 4 + %"68" = load i32, ptr addrspace(5) %"52", align 4 + %"69" = load i32, ptr addrspace(5) %"51", align 4 + %"67" = add i32 %"68", %"69" + store i32 %"67", ptr addrspace(5) %"52", align 4 + %"72" = load i32, ptr addrspace(5) %"47", align 4 + %2 = zext i32 %"72" to i64 + %"71" = mul i64 %2, 4 + store i64 %"71", ptr addrspace(5) %"70", align 8 + %"74" = load i64, ptr addrspace(5) %"53", align 8 + %"75" = load i64, ptr addrspace(5) %"70", align 8 + %"73" = add i64 %"74", %"75" + store i64 %"73", ptr addrspace(5) %"53", align 8 + %"76" = load i64, ptr addrspace(5) %"53", align 8 + %"77" = load i32, ptr addrspace(5) %"52", align 4 + %"79" = inttoptr i64 %"76" to ptr + store i32 %"77", ptr %"79", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_u32.ll b/ptx/src/test/ll/redux_sync_op_u32.ll new file mode 100644 index 0000000..3629939 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_u32.ll @@ -0,0 +1,63 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_u32(ptr addrspace(4) byref(i64) %"44") #1 { + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"65" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"41" + +"41": ; preds = %1 + %"51" = load i64, ptr addrspace(4) %"44", align 8 + store i64 %"51", ptr addrspace(5) %"50", align 8 + %"36" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"42" + +"42": ; preds = %"41" + store i32 %"36", ptr addrspace(5) %"45", align 4 + %"54" = load i32, ptr addrspace(5) %"45", align 4 + %"53" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"54", i32 -1) + store i32 %"53", ptr addrspace(5) %"46", align 4 + %"56" = load i32, ptr addrspace(5) %"45", align 4 + %"55" = call i32 @__zluda_ptx_impl_redux_sync_min_u32(i32 %"56", i32 -1) + store i32 %"55", ptr addrspace(5) %"47", align 4 + %"58" = load i32, ptr addrspace(5) %"45", align 4 + %"57" = call i32 @__zluda_ptx_impl_redux_sync_max_u32(i32 %"58", i32 -1) + store i32 %"57", ptr addrspace(5) %"48", align 4 + %"60" = load i32, ptr addrspace(5) %"46", align 4 + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"59" = add i32 %"60", %"61" + store i32 %"59", ptr addrspace(5) %"49", align 4 + %"63" = load i32, ptr addrspace(5) %"49", align 4 + %"64" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = add i32 %"63", %"64" + store i32 %"62", ptr addrspace(5) %"49", align 4 + %"67" = load i32, ptr addrspace(5) %"45", align 4 + %2 = zext i32 %"67" to i64 + %"66" = mul i64 %2, 4 + store i64 %"66", ptr addrspace(5) %"65", align 8 + %"69" = load i64, ptr addrspace(5) %"50", align 8 + %"70" = load i64, ptr addrspace(5) %"65", align 8 + %"68" = add i64 %"69", %"70" + store i64 %"68", ptr addrspace(5) %"50", align 8 + %"71" = load i64, ptr addrspace(5) %"50", align 8 + %"72" = load i32, ptr addrspace(5) %"49", align 4 + %"73" = inttoptr i64 %"71" to ptr + store i32 %"72", ptr %"73", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx new file mode 100644 index 0000000..2bad276 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_bf16x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.bf16x2.f32 result, in_a, in_b; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f413a23..6e1b27e 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -200,6 +200,7 @@ test_ptx!( ); test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]); test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]); +test_ptx!(cvt_rn_bf16x2_f32, [0.40625, 12.9f32], [0x3ED0414Eu32]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( @@ -452,6 +453,40 @@ test_ptx_warp!( 4294967292, 4294967292, 4294967292, 4294967292, 4294967292 ] ); +test_ptx_warp!( + redux_sync_op_s32, + [ + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, + ] +); +test_ptx_warp!( + redux_sync_op_u32, + [ + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, + ] +); +test_ptx_warp!( + redux_sync_add_u32_partial, + [ + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, + 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, + 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, + 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32 + ] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx new file mode 100644 index 0000000..18485ec --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx @@ -0,0 +1,31 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_add_u32_partial( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 result; + .reg .u64 out_ptr; + + .reg .u32 tid_rem_2; + .reg .pred p; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + rem.u32 tid_rem_2, tid, 2; + setp.eq.u32 p, tid_rem_2, 0; + + mov.u32 result, 0; + @p redux.sync.add.u32 result, tid, 0x55555555; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_s32.ptx b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx new file mode 100644 index 0000000..7e3ec1d --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx @@ -0,0 +1,34 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_s32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .s32 in; + .reg .s32 add_out; + .reg .s32 min_out; + .reg .s32 max_out; + .reg .s32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + sub.s32 in, tid, 5; + + redux.sync.add.s32 add_out, in, 0xFFFFFFFF; + redux.sync.min.s32 min_out, in, 0xFFFFFFFF; + redux.sync.max.s32 max_out, in, 0xFFFFFFFF; + + add.s32 result, add_out, min_out; + add.s32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.s32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_u32.ptx b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx new file mode 100644 index 0000000..03292ee --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx @@ -0,0 +1,32 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_u32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 add_out; + .reg .u32 min_out; + .reg .u32 max_out; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + redux.sync.add.u32 add_out, tid, 0xFFFFFFFF; + redux.sync.min.u32 min_out, tid, 0xFFFFFFFF; + redux.sync.max.u32 max_out, tid, 0xFFFFFFFF; + + add.u32 result, add_out, min_out; + add.u32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 0570eb4..37b5f6b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -695,6 +695,18 @@ ptx_parser_macros::generate_instruction_type!( } } + }, + ReduxSync { + type: Type::Scalar(data.type_), + data: ReduxSyncData, + arguments: { + dst: T, + src: T, + src_membermask: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + } + } } } ); @@ -1162,6 +1174,35 @@ impl ScalarType { ScalarType::Pred => ScalarKind::Pred, } } + + pub fn packed_type(&self) -> Option { + match self { + ScalarType::E4m3x2 => Some(ScalarType::B8), + ScalarType::E5m2x2 => Some(ScalarType::B8), + ScalarType::F16x2 => Some(ScalarType::F16), + ScalarType::BF16x2 => Some(ScalarType::BF16), + ScalarType::U16x2 => Some(ScalarType::U16), + ScalarType::S16x2 => Some(ScalarType::S16), + ScalarType::S16 + | ScalarType::BF16 + | ScalarType::U32 + | ScalarType::S8 + | ScalarType::S32 + | ScalarType::Pred + | ScalarType::B8 + | ScalarType::U64 + | ScalarType::B16 + | ScalarType::S64 + | ScalarType::B32 + | ScalarType::U8 + | ScalarType::F32 + | ScalarType::B64 + | ScalarType::B128 + | ScalarType::U16 + | ScalarType::F64 + | ScalarType::F16 => None, + } + } } #[derive(Clone, Copy, PartialEq, Eq)] @@ -1933,8 +1974,13 @@ impl CvtDetails { (RoundingMode::NearestEven, false) } }; + let dst_size = if dst.packed_type().is_some() { + dst.size_of() / 2 + } else { + dst.size_of() + }; let mode = match (dst.kind(), src.kind()) { - (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + (ScalarKind::Float, ScalarKind::Float) => match dst_size.cmp(&src.size_of()) { Ordering::Less => { let (rounding, is_integer_rounding) = unwrap_rounding(); CvtMode::FPTruncate { @@ -2272,3 +2318,8 @@ impl VoteMode { } } } + +pub struct ReduxSyncData { + pub type_: ScalarType, + pub reduction: Reduction, +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 106ba69..3518965 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2442,7 +2442,16 @@ derive_parser!( // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; - // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b => { + if relu || satfinite { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(frnd2), false, false, ScalarType::BF16x2, ScalarType::F32); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) } + } + } // cvt.rna{.satfinite}.tf32.f32 d, a; // cvt.frnd2{.relu}.tf32.f32 d, a; cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => { @@ -3844,6 +3853,23 @@ derive_parser!( // .mode: VoteMode = { .all, .any, .uni }; .mode: VoteMode = { .all, .any }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync + + redux.sync.op.type dst, src, membermask => { + Instruction::ReduxSync { + data: ReduxSyncData { type_, reduction: op }, + arguments: ReduxSyncArgs { dst, src, src_membermask: membermask } + } + } + .op: Reduction = {.add, .min, .max}; + .type: ScalarType = {.u32, .s32}; + + // redux.sync.op.b32 dst, src, membermask; + // .op = {.and, .or, .xor} + + // redux.sync.op{.abs.}{.NaN}.f32 dst, src, membermask; + // .op = { .min, .max } + ); #[cfg(test)]