From d342e1a06e4bb45123932050699299c600557006 Mon Sep 17 00:00:00 2001 From: Violet Date: Mon, 8 Sep 2025 16:13:28 -0700 Subject: [PATCH] Implement redux.sync for u32 and s32 (#500) --- ptx/lib/zluda_ptx_impl.bc | Bin 16612 -> 17636 bytes ptx/lib/zluda_ptx_impl.cpp | 15 ++++ ptx/src/pass/insert_post_saturation.rs | 3 +- .../instruction_mode_to_global_mode/mod.rs | 3 +- ptx/src/pass/llvm/emit.rs | 3 +- .../replace_instructions_with_functions.rs | 20 ++++++ ptx/src/test/ll/redux_sync_add_u32_partial.ll | 58 +++++++++++++++ ptx/src/test/ll/redux_sync_op_s32.ll | 67 ++++++++++++++++++ ptx/src/test/ll/redux_sync_op_u32.ll | 63 ++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 34 +++++++++ .../spirv_run/redux_sync_add_u32_partial.ptx | 31 ++++++++ ptx/src/test/spirv_run/redux_sync_op_s32.ptx | 34 +++++++++ ptx/src/test/spirv_run/redux_sync_op_u32.ptx | 32 +++++++++ ptx_parser/src/ast.rs | 17 +++++ ptx_parser/src/lib.rs | 17 +++++ 15 files changed, 394 insertions(+), 3 deletions(-) create mode 100644 ptx/src/test/ll/redux_sync_add_u32_partial.ll create mode 100644 ptx/src/test/ll/redux_sync_op_s32.ll create mode 100644 ptx/src/test/ll/redux_sync_op_u32.ll create mode 100644 ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx create mode 100644 ptx/src/test/spirv_run/redux_sync_op_s32.ptx create mode 100644 ptx/src/test/spirv_run/redux_sync_op_u32.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 01c5073887bef7e39abe7bd5840fd4e744c39e04..84a104fad401d174f65c6c39300948dfd31107ac 100644 GIT binary patch delta 2900 zcmaFT$oQm_ae@kyC(lN;2^_90iBf7TsRoV`&52BjZ0QYbnFfjh&76m%HXLkOz^f*b zYaq~Sn3&nzsL|7!Ifs3g!j^*uiW}4fat#EIwk+gjOFHy`?b_zw9Mc&YEjCZ)if64) zPGCF`!o~mtk_Jnf{x@7|US6+ps_pir41YQW6VzzOp|14vzl zwu=~=ai79g++yHhFahZU;}Z>vQ(trSva%e2N7@8N=T&z_mAFxCRWaib5lduXIKjlo z01i+ixfwE&f<6;^IMv*bu{aA|;q@wxUAg(6fGi86#bz1Na7NVxzNH392}>9`ni?1J zxf>)ktYGA5a+G#yw3a#6!MKr4k%P_KYVr&*UG_t43LI=awv*3@>4Z}JrJWa*bq z8yFc96&M&;85kIZSsdA%dn82i1ey{Qo=(=25Mu)~?;@C25zMob(g{rEW{%g7|cjW*rBL#AYq!LM3Z1M6NjV0ft1OD^5VuBjk8oO{AQNy zP-I{>n33u+A>qgZ1}6a?gXjfmoRgB(Wb-(VFz}>Fim6H%YBUJ2F!P8{j+S@g&Jk!% zD426tL4uKalA`M5v+~@c(7j7pMY(O_X}Vp38NV>r>l z_=!!CLo8tfTZDmzGTRAhu#=@yoDF&;6d=yYULfj}$k>43A7FI?nFz7QQCxtTt-ry? zoTbS)DS<(egDs(fYw|iJtF)X1ks}#S3_TJu4jfI4k~6;08q88)kz_FhD-}!l z@X|4n4aw*UtW5}`8^w5-4+t_MWF5tM4s|eon(U`+$QV7jLD`ZqfAcP7e|En3 zhyWfZhChdxnk8BoR!kPx?q_VByj*)3W6@+komR#Ro3H4IF|r>8rB#PLlRxU3uz^C1 z@yTRuJ#7J~utAT62*}k}Ca3A?z}3#wQ(*rH%Ci^rCLhq#tOx76!hMiU(GZ;Q`kEwK z7;b>_9JpWsWk67YFU)eV!|@?V-z_yz35Y2kg+u(7Dh~Bv@p?>yZmHnV02VJ)U|`T= zU|=xf6HsblKE~alaJqqOH_r)$GYxzvc{mu)HV9qku@F4hAo`6bW8#eliAB6K9^P({ zzQKFqz_39iD^)MPF4pMIo zRS%;<;vn^LQ1x*f4E11(KoTGgMNkbe8YB+VFcqo+MuWsb>Q_S5!)TB=Nc|b8dKe87 z2dRGmRS%>8|F4Jmn3)se0~ig{APyCW(I9aS9R>z%P6h^0r3|BC;ucVG7!4D5zhpvF8 z1Q-ny-^t0qAPUM3Fa}KG9MlIe8YX@lYS3-Y$%o9{I6J|qaEhtUW7_xD>nRyt- zfyF?^DIpmL5=1vHH#2YYM=Mbj^AamC%>#*n%u_=)FR>yX!@R_b_)KG?%@6IF8JQRu E02@sT6#xJL delta 1965 zcmaFT$@rv^ae@ld6Yh;_6F6L37VxNXBpdiBoO!@=L4j9|CCk7ug4t|_z?Oq8i+I@< z8K~*x8VIy*H0b^8Alzamp~iBJ1tdPhV9P-VABj5$q!<_&Hvi_B&d8{{c`{c#YyA{f z76(ZN7)UuVLF-TRGiIg)W|+{42F0nr+0U}G9AJY9O<;6h^_PE!DFdpqtDIaC(hQ6n z7-8yC4orB(-kQVPz;FX5#9)-*(c0d$n2|AojX{8csh)v>L4ZNhV2P0{qj3Yn3}kgr zyPZ-P(TtOAUB%nLU;%es1CK<|#3dIZ92j!oS|tsZba^UnLb25$so5d%AuB_JjsSxI z0|SGR+zc5>L7xem9|&r*FzRmR5esKjUBJE6ASq!1BS%x?BW`ztq=qGo98HeWE{zs4 z$2u4nvMF+~dFxN^5Z7hj#HPT(#$z>kkGPKb8b*#LMrD3436{tT3KIsMp36sy0 z|BG9)gXJFcPqvpZ6allFg}CHcN(_=57!)|z+`K2ZNF+<|blSkkkf^}Gz{aWY+_oIt)-_ky=a_#MOvBNRk}q%BBOTlKWR-UyLz&oj0NM4 z$>}ocj4LPi%UCn+n0#7Bon@g@K<;ES83V>0lhtJrQr)sBQqN`87*|f_le1ymG1*&A z4Wi3b#*lHx8b|`8bNSNj*(InW+#NlXgz-Ka}g1B)&<1AGRznLXF6d9NeW~6#dNI0^9!AXF} zAap?*=cHsc**uOT3_NL)VyY5`8Vv$0%srx$y%n6evjv(H3f>%2kYHq<{6SH5@?HgQ z(QF5v2p%S*y8;r2IvN>U6e1c9UT2$pSHT*a%yzcPx{BUfxd|edd07Q}BqZ_#ni3e< z;#+y#Sy-9bF7rAICT!q10W;`1+vG)x4vaG=KUTEGW?VSiWK*R~u^o;aO{@}Y4vL8! z6Xi{Lu$GmKqf-0aTFI|X6tY8F=uHqPD)@<N*#%-g(5#h;xoG9rM-iQ&v4re=v229L?? zI{l2Ulc(z}V=SC(r`yU{u=$9t7$bWtD6u+hnEX)Bgbfs8j4G3*^|b|{!UjDOA|O|b zOb*l6fvfG*S73h#$`lv0CvVW#tOx58;XcTwXb4VueN7TA3>BbE2QE4a6&M&aK}7?f zfKm%{7k5X({syl9+$R!FH1PHEa5$c95Sq z|2G&<5{P)r)o6Y|V1_eqqt$n8%fh zM%Ovw8B?kn-8YJZ-OG^5!@wZO#K0iLJf%W|fq?j}FfcG!a6m*>a4;|kGB7ag1VKhwqH#-9Z4+8@OjE0G`b1*PSFfcH{XqdP>2gFJ8 z91Qhfx4|Sp?&E{{14M(wIaL@K+@J=*XplI_p>a@u!f2RyJ_iGXC<6lnjE0G~a6mi; mqyPUGfcSVK)S!tRlgq5!IL#mdzglziQL78=o9{a_GXek$v=op4 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index c3643a5..e8e3206 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -562,4 +562,19 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return ballot(value, true); } + +#define REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type) \ + cpp_type __ockl_wfred_##reducer##_##amd_type(cpp_type) __device__; \ + cpp_type FUNC(redux_sync_##reducer##_##ptx_type)(cpp_type src, uint32_t membermask __attribute__((unused))) \ + { \ + return __ockl_wfred_##reducer##_##amd_type(src); \ + } + +#define REDUX_SYNC_IMPL(reducer) \ + REDUX_SYNC_TYPE_IMPL(reducer, u32, u32, uint32_t) \ + REDUX_SYNC_TYPE_IMPL(reducer, s32, i32, int32_t) + + REDUX_SYNC_IMPL(add); + REDUX_SYNC_IMPL(min); + REDUX_SYNC_IMPL(max); } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 0fd15b8..c0f537b 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -195,7 +195,8 @@ fn run_instruction<'input>( | ast::Instruction::Tanh { .. } | ast::Instruction::Trap {} | ast::Instruction::Xor { .. } - | ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 4e82b6a..504d36b 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1853,7 +1853,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Mul24 { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::AtomCas { .. } - | ast::Instruction::Vote { .. } => InstructionModes::none(), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index d3b81cd..08532e3 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -526,7 +526,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Activemask { .. } | ast::Instruction::ShflSync { .. } | ast::Instruction::Vote { .. } - | ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), + | ast::Instruction::Nanosleep { .. } + | ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index b92ab95..2a939fd 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -377,6 +377,7 @@ fn run_instruction<'input>( let name = match data.pred_reduction { ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::Or => "bar_red_or_pred", + _ => return Err(error_unreachable()), }; to_call( resolver, @@ -400,6 +401,25 @@ fn run_instruction<'input>( ptx_parser::Instruction::Vote { data, arguments }, )? } + ptx_parser::Instruction::ReduxSync { data, arguments } => { + let op = match data.reduction { + ptx_parser::Reduction::Add => "add", + ptx_parser::Reduction::Min => "min", + ptx_parser::Reduction::Max => "max", + _ => return Err(error_unreachable()), + }; + let name = format!( + "redux_sync_{}_{}", + op, + data.type_.to_string().replace(".", "") + ); + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::ReduxSync { data, arguments }, + )? + } ptx_parser::Instruction::ShflSync { data, arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. }, diff --git a/ptx/src/test/ll/redux_sync_add_u32_partial.ll b/ptx/src/test/ll/redux_sync_add_u32_partial.ll new file mode 100644 index 0000000..ab55c15 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_add_u32_partial.ll @@ -0,0 +1,58 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_add_u32_partial(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i1, align 1, addrspace(5) + %"62" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"52" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"52", ptr addrspace(5) %"49", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"55" = load i32, ptr addrspace(5) %"47", align 4 + %"54" = urem i32 %"55", 2 + store i32 %"54", ptr addrspace(5) %"50", align 4 + %"57" = load i32, ptr addrspace(5) %"50", align 4 + %2 = icmp eq i32 %"57", 0 + store i1 %2, ptr addrspace(5) %"51", align 1 + store i32 0, ptr addrspace(5) %"48", align 4 + %"59" = load i1, ptr addrspace(5) %"51", align 1 + br i1 %"59", label %"16", label %"17" + +"16": ; preds = %"44" + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"61", i32 1431655765) + store i32 %"60", ptr addrspace(5) %"48", align 4 + br label %"17" + +"17": ; preds = %"16", %"44" + %"64" = load i32, ptr addrspace(5) %"47", align 4 + %3 = zext i32 %"64" to i64 + %"63" = mul i64 %3, 4 + store i64 %"63", ptr addrspace(5) %"62", align 8 + %"66" = load i64, ptr addrspace(5) %"49", align 8 + %"67" = load i64, ptr addrspace(5) %"62", align 8 + %"65" = add i64 %"66", %"67" + store i64 %"65", ptr addrspace(5) %"49", align 8 + %"68" = load i64, ptr addrspace(5) %"49", align 8 + %"69" = load i32, ptr addrspace(5) %"48", align 4 + %"70" = inttoptr i64 %"68" to ptr + store i32 %"69", ptr %"70", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_s32.ll b/ptx/src/test/ll/redux_sync_op_s32.ll new file mode 100644 index 0000000..a84624b --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_s32.ll @@ -0,0 +1,67 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_s32(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i32, align 4, addrspace(5) + %"52" = alloca i32, align 4, addrspace(5) + %"53" = alloca i64, align 8, addrspace(5) + %"70" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"54" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"54", ptr addrspace(5) %"53", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"57" = load i32, ptr addrspace(5) %"47", align 4 + %"56" = sub i32 %"57", 5 + store i32 %"56", ptr addrspace(5) %"48", align 4 + %"59" = load i32, ptr addrspace(5) %"48", align 4 + %"58" = call i32 @__zluda_ptx_impl_redux_sync_add_s32(i32 %"59", i32 -1) + store i32 %"58", ptr addrspace(5) %"49", align 4 + %"61" = load i32, ptr addrspace(5) %"48", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_min_s32(i32 %"61", i32 -1) + store i32 %"60", ptr addrspace(5) %"50", align 4 + %"63" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = call i32 @__zluda_ptx_impl_redux_sync_max_s32(i32 %"63", i32 -1) + store i32 %"62", ptr addrspace(5) %"51", align 4 + %"65" = load i32, ptr addrspace(5) %"49", align 4 + %"66" = load i32, ptr addrspace(5) %"50", align 4 + %"64" = add i32 %"65", %"66" + store i32 %"64", ptr addrspace(5) %"52", align 4 + %"68" = load i32, ptr addrspace(5) %"52", align 4 + %"69" = load i32, ptr addrspace(5) %"51", align 4 + %"67" = add i32 %"68", %"69" + store i32 %"67", ptr addrspace(5) %"52", align 4 + %"72" = load i32, ptr addrspace(5) %"47", align 4 + %2 = zext i32 %"72" to i64 + %"71" = mul i64 %2, 4 + store i64 %"71", ptr addrspace(5) %"70", align 8 + %"74" = load i64, ptr addrspace(5) %"53", align 8 + %"75" = load i64, ptr addrspace(5) %"70", align 8 + %"73" = add i64 %"74", %"75" + store i64 %"73", ptr addrspace(5) %"53", align 8 + %"76" = load i64, ptr addrspace(5) %"53", align 8 + %"77" = load i32, ptr addrspace(5) %"52", align 4 + %"79" = inttoptr i64 %"76" to ptr + store i32 %"77", ptr %"79", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_u32.ll b/ptx/src/test/ll/redux_sync_op_u32.ll new file mode 100644 index 0000000..3629939 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_u32.ll @@ -0,0 +1,63 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_u32(ptr addrspace(4) byref(i64) %"44") #1 { + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"65" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"41" + +"41": ; preds = %1 + %"51" = load i64, ptr addrspace(4) %"44", align 8 + store i64 %"51", ptr addrspace(5) %"50", align 8 + %"36" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"42" + +"42": ; preds = %"41" + store i32 %"36", ptr addrspace(5) %"45", align 4 + %"54" = load i32, ptr addrspace(5) %"45", align 4 + %"53" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"54", i32 -1) + store i32 %"53", ptr addrspace(5) %"46", align 4 + %"56" = load i32, ptr addrspace(5) %"45", align 4 + %"55" = call i32 @__zluda_ptx_impl_redux_sync_min_u32(i32 %"56", i32 -1) + store i32 %"55", ptr addrspace(5) %"47", align 4 + %"58" = load i32, ptr addrspace(5) %"45", align 4 + %"57" = call i32 @__zluda_ptx_impl_redux_sync_max_u32(i32 %"58", i32 -1) + store i32 %"57", ptr addrspace(5) %"48", align 4 + %"60" = load i32, ptr addrspace(5) %"46", align 4 + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"59" = add i32 %"60", %"61" + store i32 %"59", ptr addrspace(5) %"49", align 4 + %"63" = load i32, ptr addrspace(5) %"49", align 4 + %"64" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = add i32 %"63", %"64" + store i32 %"62", ptr addrspace(5) %"49", align 4 + %"67" = load i32, ptr addrspace(5) %"45", align 4 + %2 = zext i32 %"67" to i64 + %"66" = mul i64 %2, 4 + store i64 %"66", ptr addrspace(5) %"65", align 8 + %"69" = load i64, ptr addrspace(5) %"50", align 8 + %"70" = load i64, ptr addrspace(5) %"65", align 8 + %"68" = add i64 %"69", %"70" + store i64 %"68", ptr addrspace(5) %"50", align 8 + %"71" = load i64, ptr addrspace(5) %"50", align 8 + %"72" = load i32, ptr addrspace(5) %"49", align 4 + %"73" = inttoptr i64 %"71" to ptr + store i32 %"72", ptr %"73", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f413a23..e6d9a58 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -452,6 +452,40 @@ test_ptx_warp!( 4294967292, 4294967292, 4294967292, 4294967292, 4294967292 ] ); +test_ptx_warp!( + redux_sync_op_s32, + [ + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, + ] +); +test_ptx_warp!( + redux_sync_op_u32, + [ + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, + ] +); +test_ptx_warp!( + redux_sync_add_u32_partial, + [ + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, + 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, + 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, + 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32 + ] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx new file mode 100644 index 0000000..18485ec --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx @@ -0,0 +1,31 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_add_u32_partial( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 result; + .reg .u64 out_ptr; + + .reg .u32 tid_rem_2; + .reg .pred p; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + rem.u32 tid_rem_2, tid, 2; + setp.eq.u32 p, tid_rem_2, 0; + + mov.u32 result, 0; + @p redux.sync.add.u32 result, tid, 0x55555555; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_s32.ptx b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx new file mode 100644 index 0000000..7e3ec1d --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx @@ -0,0 +1,34 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_s32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .s32 in; + .reg .s32 add_out; + .reg .s32 min_out; + .reg .s32 max_out; + .reg .s32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + sub.s32 in, tid, 5; + + redux.sync.add.s32 add_out, in, 0xFFFFFFFF; + redux.sync.min.s32 min_out, in, 0xFFFFFFFF; + redux.sync.max.s32 max_out, in, 0xFFFFFFFF; + + add.s32 result, add_out, min_out; + add.s32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.s32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_u32.ptx b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx new file mode 100644 index 0000000..03292ee --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx @@ -0,0 +1,32 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_u32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 add_out; + .reg .u32 min_out; + .reg .u32 max_out; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + redux.sync.add.u32 add_out, tid, 0xFFFFFFFF; + redux.sync.min.u32 min_out, tid, 0xFFFFFFFF; + redux.sync.max.u32 max_out, tid, 0xFFFFFFFF; + + add.u32 result, add_out, min_out; + add.u32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 0570eb4..86719ef 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -695,6 +695,18 @@ ptx_parser_macros::generate_instruction_type!( } } + }, + ReduxSync { + type: Type::Scalar(data.type_), + data: ReduxSyncData, + arguments: { + dst: T, + src: T, + src_membermask: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + } + } } } ); @@ -2272,3 +2284,8 @@ impl VoteMode { } } } + +pub struct ReduxSyncData { + pub type_: ScalarType, + pub reduction: Reduction, +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 29b348a..2701127 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3844,6 +3844,23 @@ derive_parser!( // .mode: VoteMode = { .all, .any, .uni }; .mode: VoteMode = { .all, .any }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync + + redux.sync.op.type dst, src, membermask => { + Instruction::ReduxSync { + data: ReduxSyncData { type_, reduction: op }, + arguments: ReduxSyncArgs { dst, src, src_membermask: membermask } + } + } + .op: Reduction = {.add, .min, .max}; + .type: ScalarType = {.u32, .s32}; + + // redux.sync.op.b32 dst, src, membermask; + // .op = {.and, .or, .xor} + + // redux.sync.op{.abs.}{.NaN}.f32 dst, src, membermask; + // .op = { .min, .max } + ); #[cfg(test)]