From d342e1a06e4bb45123932050699299c600557006 Mon Sep 17 00:00:00 2001 From: Violet Date: Mon, 8 Sep 2025 16:13:28 -0700 Subject: [PATCH 1/3] Implement redux.sync for u32 and s32 (#500) --- ptx/lib/zluda_ptx_impl.bc | Bin 16612 -> 17636 bytes ptx/lib/zluda_ptx_impl.cpp | 15 ++++ ptx/src/pass/insert_post_saturation.rs | 3 +- .../instruction_mode_to_global_mode/mod.rs | 3 +- ptx/src/pass/llvm/emit.rs | 3 +- .../replace_instructions_with_functions.rs | 20 ++++++ ptx/src/test/ll/redux_sync_add_u32_partial.ll | 58 +++++++++++++++ ptx/src/test/ll/redux_sync_op_s32.ll | 67 ++++++++++++++++++ ptx/src/test/ll/redux_sync_op_u32.ll | 63 ++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 34 +++++++++ .../spirv_run/redux_sync_add_u32_partial.ptx | 31 ++++++++ ptx/src/test/spirv_run/redux_sync_op_s32.ptx | 34 +++++++++ ptx/src/test/spirv_run/redux_sync_op_u32.ptx | 32 +++++++++ ptx_parser/src/ast.rs | 17 +++++ ptx_parser/src/lib.rs | 17 +++++ 15 files changed, 394 insertions(+), 3 deletions(-) create mode 100644 ptx/src/test/ll/redux_sync_add_u32_partial.ll create mode 100644 ptx/src/test/ll/redux_sync_op_s32.ll create mode 100644 ptx/src/test/ll/redux_sync_op_u32.ll create mode 100644 ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx create mode 100644 ptx/src/test/spirv_run/redux_sync_op_s32.ptx create mode 100644 ptx/src/test/spirv_run/redux_sync_op_u32.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 01c5073887bef7e39abe7bd5840fd4e744c39e04..84a104fad401d174f65c6c39300948dfd31107ac 100644 GIT binary patch delta 2900 zcmaFT$oQm_ae@kyC(lN;2^_90iBf7TsRoV`&52BjZ0QYbnFfjh&76m%HXLkOz^f*b zYaq~Sn3&nzsL|7!Ifs3g!j^*uiW}4fat#EIwk+gjOFHy`?b_zw9Mc&YEjCZ)if64) zPGCF`!o~mtk_Jnf{x@7|US6+ps_pir41YQW6VzzOp|14vzl zwu=~=ai79g++yHhFahZU;}Z>vQ(trSva%e2N7@8N=T&z_mAFxCRWaib5lduXIKjlo z01i+ixfwE&f<6;^IMv*bu{aA|;q@wxUAg(6fGi86#bz1Na7NVxzNH392}>9`ni?1J zxf>)ktYGA5a+G#yw3a#6!MKr4k%P_KYVr&*UG_t43LI=awv*3@>4Z}JrJWa*bq z8yFc96&M&;85kIZSsdA%dn82i1ey{Qo=(=25Mu)~?;@C25zMob(g{rEW{%g7|cjW*rBL#AYq!LM3Z1M6NjV0ft1OD^5VuBjk8oO{AQNy zP-I{>n33u+A>qgZ1}6a?gXjfmoRgB(Wb-(VFz}>Fim6H%YBUJ2F!P8{j+S@g&Jk!% zD426tL4uKalA`M5v+~@c(7j7pMY(O_X}Vp38NV>r>l z_=!!CLo8tfTZDmzGTRAhu#=@yoDF&;6d=yYULfj}$k>43A7FI?nFz7QQCxtTt-ry? zoTbS)DS<(egDs(fYw|iJtF)X1ks}#S3_TJu4jfI4k~6;08q88)kz_FhD-}!l z@X|4n4aw*UtW5}`8^w5-4+t_MWF5tM4s|eon(U`+$QV7jLD`ZqfAcP7e|En3 zhyWfZhChdxnk8BoR!kPx?q_VByj*)3W6@+komR#Ro3H4IF|r>8rB#PLlRxU3uz^C1 z@yTRuJ#7J~utAT62*}k}Ca3A?z}3#wQ(*rH%Ci^rCLhq#tOx76!hMiU(GZ;Q`kEwK z7;b>_9JpWsWk67YFU)eV!|@?V-z_yz35Y2kg+u(7Dh~Bv@p?>yZmHnV02VJ)U|`T= zU|=xf6HsblKE~alaJqqOH_r)$GYxzvc{mu)HV9qku@F4hAo`6bW8#eliAB6K9^P({ zzQKFqz_39iD^)MPF4pMIo zRS%;<;vn^LQ1x*f4E11(KoTGgMNkbe8YB+VFcqo+MuWsb>Q_S5!)TB=Nc|b8dKe87 z2dRGmRS%>8|F4Jmn3)se0~ig{APyCW(I9aS9R>z%P6h^0r3|BC;ucVG7!4D5zhpvF8 z1Q-ny-^t0qAPUM3Fa}KG9MlIe8YX@lYS3-Y$%o9{I6J|qaEhtUW7_xD>nRyt- zfyF?^DIpmL5=1vHH#2YYM=Mbj^AamC%>#*n%u_=)FR>yX!@R_b_)KG?%@6IF8JQRu E02@sT6#xJL delta 1965 zcmaFT$@rv^ae@ld6Yh;_6F6L37VxNXBpdiBoO!@=L4j9|CCk7ug4t|_z?Oq8i+I@< z8K~*x8VIy*H0b^8Alzamp~iBJ1tdPhV9P-VABj5$q!<_&Hvi_B&d8{{c`{c#YyA{f z76(ZN7)UuVLF-TRGiIg)W|+{42F0nr+0U}G9AJY9O<;6h^_PE!DFdpqtDIaC(hQ6n z7-8yC4orB(-kQVPz;FX5#9)-*(c0d$n2|AojX{8csh)v>L4ZNhV2P0{qj3Yn3}kgr zyPZ-P(TtOAUB%nLU;%es1CK<|#3dIZ92j!oS|tsZba^UnLb25$so5d%AuB_JjsSxI z0|SGR+zc5>L7xem9|&r*FzRmR5esKjUBJE6ASq!1BS%x?BW`ztq=qGo98HeWE{zs4 z$2u4nvMF+~dFxN^5Z7hj#HPT(#$z>kkGPKb8b*#LMrD3436{tT3KIsMp36sy0 z|BG9)gXJFcPqvpZ6allFg}CHcN(_=57!)|z+`K2ZNF+<|blSkkkf^}Gz{aWY+_oIt)-_ky=a_#MOvBNRk}q%BBOTlKWR-UyLz&oj0NM4 z$>}ocj4LPi%UCn+n0#7Bon@g@K<;ES83V>0lhtJrQr)sBQqN`87*|f_le1ymG1*&A z4Wi3b#*lHx8b|`8bNSNj*(InW+#NlXgz-Ka}g1B)&<1AGRznLXF6d9NeW~6#dNI0^9!AXF} zAap?*=cHsc**uOT3_NL)VyY5`8Vv$0%srx$y%n6evjv(H3f>%2kYHq<{6SH5@?HgQ z(QF5v2p%S*y8;r2IvN>U6e1c9UT2$pSHT*a%yzcPx{BUfxd|edd07Q}BqZ_#ni3e< z;#+y#Sy-9bF7rAICT!q10W;`1+vG)x4vaG=KUTEGW?VSiWK*R~u^o;aO{@}Y4vL8! z6Xi{Lu$GmKqf-0aTFI|X6tY8F=uHqPD)@<N*#%-g(5#h;xoG9rM-iQ&v4re=v229L?? zI{l2Ulc(z}V=SC(r`yU{u=$9t7$bWtD6u+hnEX)Bgbfs8j4G3*^|b|{!UjDOA|O|b zOb*l6fvfG*S73h#$`lv0CvVW#tOx58;XcTwXb4VueN7TA3>BbE2QE4a6&M&aK}7?f zfKm%{7k5X({syl9+$R!FH1PHEa5$c95Sq z|2G&<5{P)r)o6Y|V1_eqqt$n8%fh zM%Ovw8B?kn-8YJZ-OG^5!@wZO#K0iLJf%W|fq?j}FfcG!a6m*>a4;|kGB7ag1VKhwqH#-9Z4+8@OjE0G`b1*PSFfcH{XqdP>2gFJ8 z91Qhfx4|Sp?&E{{14M(wIaL@K+@J=*XplI_p>a@u!f2RyJ_iGXC<6lnjE0G~a6mi; mqyPUGfcSVK)S!tRlgq5!IL#mdzglziQL78=o9{a_GXek$v=op4 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index c3643a5..e8e3206 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -562,4 +562,19 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return ballot(value, true); } + +#define REDUX_SYNC_TYPE_IMPL(reducer, ptx_type, amd_type, cpp_type) \ + cpp_type __ockl_wfred_##reducer##_##amd_type(cpp_type) __device__; \ + cpp_type FUNC(redux_sync_##reducer##_##ptx_type)(cpp_type src, uint32_t membermask __attribute__((unused))) \ + { \ + return __ockl_wfred_##reducer##_##amd_type(src); \ + } + +#define REDUX_SYNC_IMPL(reducer) \ + REDUX_SYNC_TYPE_IMPL(reducer, u32, u32, uint32_t) \ + REDUX_SYNC_TYPE_IMPL(reducer, s32, i32, int32_t) + + REDUX_SYNC_IMPL(add); + REDUX_SYNC_IMPL(min); + REDUX_SYNC_IMPL(max); } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 0fd15b8..c0f537b 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -195,7 +195,8 @@ fn run_instruction<'input>( | ast::Instruction::Tanh { .. } | ast::Instruction::Trap {} | ast::Instruction::Xor { .. } - | ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 4e82b6a..504d36b 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1853,7 +1853,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Mul24 { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::AtomCas { .. } - | ast::Instruction::Vote { .. } => InstructionModes::none(), + | ast::Instruction::Vote { .. } + | ast::Instruction::ReduxSync { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index d3b81cd..08532e3 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -526,7 +526,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Activemask { .. } | ast::Instruction::ShflSync { .. } | ast::Instruction::Vote { .. } - | ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), + | ast::Instruction::Nanosleep { .. } + | ast::Instruction::ReduxSync { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index b92ab95..2a939fd 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -377,6 +377,7 @@ fn run_instruction<'input>( let name = match data.pred_reduction { ptx_parser::Reduction::And => "bar_red_and_pred", ptx_parser::Reduction::Or => "bar_red_or_pred", + _ => return Err(error_unreachable()), }; to_call( resolver, @@ -400,6 +401,25 @@ fn run_instruction<'input>( ptx_parser::Instruction::Vote { data, arguments }, )? } + ptx_parser::Instruction::ReduxSync { data, arguments } => { + let op = match data.reduction { + ptx_parser::Reduction::Add => "add", + ptx_parser::Reduction::Min => "min", + ptx_parser::Reduction::Max => "max", + _ => return Err(error_unreachable()), + }; + let name = format!( + "redux_sync_{}_{}", + op, + data.type_.to_string().replace(".", "") + ); + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::ReduxSync { data, arguments }, + )? + } ptx_parser::Instruction::ShflSync { data, arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. }, diff --git a/ptx/src/test/ll/redux_sync_add_u32_partial.ll b/ptx/src/test/ll/redux_sync_add_u32_partial.ll new file mode 100644 index 0000000..ab55c15 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_add_u32_partial.ll @@ -0,0 +1,58 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_add_u32_partial(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i1, align 1, addrspace(5) + %"62" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"52" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"52", ptr addrspace(5) %"49", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"55" = load i32, ptr addrspace(5) %"47", align 4 + %"54" = urem i32 %"55", 2 + store i32 %"54", ptr addrspace(5) %"50", align 4 + %"57" = load i32, ptr addrspace(5) %"50", align 4 + %2 = icmp eq i32 %"57", 0 + store i1 %2, ptr addrspace(5) %"51", align 1 + store i32 0, ptr addrspace(5) %"48", align 4 + %"59" = load i1, ptr addrspace(5) %"51", align 1 + br i1 %"59", label %"16", label %"17" + +"16": ; preds = %"44" + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"61", i32 1431655765) + store i32 %"60", ptr addrspace(5) %"48", align 4 + br label %"17" + +"17": ; preds = %"16", %"44" + %"64" = load i32, ptr addrspace(5) %"47", align 4 + %3 = zext i32 %"64" to i64 + %"63" = mul i64 %3, 4 + store i64 %"63", ptr addrspace(5) %"62", align 8 + %"66" = load i64, ptr addrspace(5) %"49", align 8 + %"67" = load i64, ptr addrspace(5) %"62", align 8 + %"65" = add i64 %"66", %"67" + store i64 %"65", ptr addrspace(5) %"49", align 8 + %"68" = load i64, ptr addrspace(5) %"49", align 8 + %"69" = load i32, ptr addrspace(5) %"48", align 4 + %"70" = inttoptr i64 %"68" to ptr + store i32 %"69", ptr %"70", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_s32.ll b/ptx/src/test/ll/redux_sync_op_s32.ll new file mode 100644 index 0000000..a84624b --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_s32.ll @@ -0,0 +1,67 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_s32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_s32(ptr addrspace(4) byref(i64) %"46") #1 { + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i32, align 4, addrspace(5) + %"51" = alloca i32, align 4, addrspace(5) + %"52" = alloca i32, align 4, addrspace(5) + %"53" = alloca i64, align 8, addrspace(5) + %"70" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"43" + +"43": ; preds = %1 + %"54" = load i64, ptr addrspace(4) %"46", align 8 + store i64 %"54", ptr addrspace(5) %"53", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"44" + +"44": ; preds = %"43" + store i32 %"37", ptr addrspace(5) %"47", align 4 + %"57" = load i32, ptr addrspace(5) %"47", align 4 + %"56" = sub i32 %"57", 5 + store i32 %"56", ptr addrspace(5) %"48", align 4 + %"59" = load i32, ptr addrspace(5) %"48", align 4 + %"58" = call i32 @__zluda_ptx_impl_redux_sync_add_s32(i32 %"59", i32 -1) + store i32 %"58", ptr addrspace(5) %"49", align 4 + %"61" = load i32, ptr addrspace(5) %"48", align 4 + %"60" = call i32 @__zluda_ptx_impl_redux_sync_min_s32(i32 %"61", i32 -1) + store i32 %"60", ptr addrspace(5) %"50", align 4 + %"63" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = call i32 @__zluda_ptx_impl_redux_sync_max_s32(i32 %"63", i32 -1) + store i32 %"62", ptr addrspace(5) %"51", align 4 + %"65" = load i32, ptr addrspace(5) %"49", align 4 + %"66" = load i32, ptr addrspace(5) %"50", align 4 + %"64" = add i32 %"65", %"66" + store i32 %"64", ptr addrspace(5) %"52", align 4 + %"68" = load i32, ptr addrspace(5) %"52", align 4 + %"69" = load i32, ptr addrspace(5) %"51", align 4 + %"67" = add i32 %"68", %"69" + store i32 %"67", ptr addrspace(5) %"52", align 4 + %"72" = load i32, ptr addrspace(5) %"47", align 4 + %2 = zext i32 %"72" to i64 + %"71" = mul i64 %2, 4 + store i64 %"71", ptr addrspace(5) %"70", align 8 + %"74" = load i64, ptr addrspace(5) %"53", align 8 + %"75" = load i64, ptr addrspace(5) %"70", align 8 + %"73" = add i64 %"74", %"75" + store i64 %"73", ptr addrspace(5) %"53", align 8 + %"76" = load i64, ptr addrspace(5) %"53", align 8 + %"77" = load i32, ptr addrspace(5) %"52", align 4 + %"79" = inttoptr i64 %"76" to ptr + store i32 %"77", ptr %"79", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/ll/redux_sync_op_u32.ll b/ptx/src/test/ll/redux_sync_op_u32.ll new file mode 100644 index 0000000..3629939 --- /dev/null +++ b/ptx/src/test/ll/redux_sync_op_u32.ll @@ -0,0 +1,63 @@ +declare hidden i32 @__zluda_ptx_impl_redux_sync_max_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_add_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_redux_sync_min_u32(i32, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @redux_sync_op_u32(ptr addrspace(4) byref(i64) %"44") #1 { + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i32, align 4, addrspace(5) + %"47" = alloca i32, align 4, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i32, align 4, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"65" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"41" + +"41": ; preds = %1 + %"51" = load i64, ptr addrspace(4) %"44", align 8 + store i64 %"51", ptr addrspace(5) %"50", align 8 + %"36" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"42" + +"42": ; preds = %"41" + store i32 %"36", ptr addrspace(5) %"45", align 4 + %"54" = load i32, ptr addrspace(5) %"45", align 4 + %"53" = call i32 @__zluda_ptx_impl_redux_sync_add_u32(i32 %"54", i32 -1) + store i32 %"53", ptr addrspace(5) %"46", align 4 + %"56" = load i32, ptr addrspace(5) %"45", align 4 + %"55" = call i32 @__zluda_ptx_impl_redux_sync_min_u32(i32 %"56", i32 -1) + store i32 %"55", ptr addrspace(5) %"47", align 4 + %"58" = load i32, ptr addrspace(5) %"45", align 4 + %"57" = call i32 @__zluda_ptx_impl_redux_sync_max_u32(i32 %"58", i32 -1) + store i32 %"57", ptr addrspace(5) %"48", align 4 + %"60" = load i32, ptr addrspace(5) %"46", align 4 + %"61" = load i32, ptr addrspace(5) %"47", align 4 + %"59" = add i32 %"60", %"61" + store i32 %"59", ptr addrspace(5) %"49", align 4 + %"63" = load i32, ptr addrspace(5) %"49", align 4 + %"64" = load i32, ptr addrspace(5) %"48", align 4 + %"62" = add i32 %"63", %"64" + store i32 %"62", ptr addrspace(5) %"49", align 4 + %"67" = load i32, ptr addrspace(5) %"45", align 4 + %2 = zext i32 %"67" to i64 + %"66" = mul i64 %2, 4 + store i64 %"66", ptr addrspace(5) %"65", align 8 + %"69" = load i64, ptr addrspace(5) %"50", align 8 + %"70" = load i64, ptr addrspace(5) %"65", align 8 + %"68" = add i64 %"69", %"70" + store i64 %"68", ptr addrspace(5) %"50", align 8 + %"71" = load i64, ptr addrspace(5) %"50", align 8 + %"72" = load i32, ptr addrspace(5) %"49", align 4 + %"73" = inttoptr i64 %"71" to ptr + store i32 %"72", ptr %"73", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f413a23..e6d9a58 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -452,6 +452,40 @@ test_ptx_warp!( 4294967292, 4294967292, 4294967292, 4294967292, 4294967292 ] ); +test_ptx_warp!( + redux_sync_op_s32, + [ + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, + 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 357i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, 1445i32, + 1445i32, + ] +); +test_ptx_warp!( + redux_sync_op_u32, + [ + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, + 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 527u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, 1615u32, + 1615u32, + ] +); +test_ptx_warp!( + redux_sync_add_u32_partial, + [ + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, + 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, + 240u32, 0u32, 240u32, 0u32, 240u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, + 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, + 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32, 752u32, 0u32 + ] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx new file mode 100644 index 0000000..18485ec --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_add_u32_partial.ptx @@ -0,0 +1,31 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_add_u32_partial( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 result; + .reg .u64 out_ptr; + + .reg .u32 tid_rem_2; + .reg .pred p; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + rem.u32 tid_rem_2, tid, 2; + setp.eq.u32 p, tid_rem_2, 0; + + mov.u32 result, 0; + @p redux.sync.add.u32 result, tid, 0x55555555; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_s32.ptx b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx new file mode 100644 index 0000000..7e3ec1d --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_s32.ptx @@ -0,0 +1,34 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_s32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .s32 in; + .reg .s32 add_out; + .reg .s32 min_out; + .reg .s32 max_out; + .reg .s32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + sub.s32 in, tid, 5; + + redux.sync.add.s32 add_out, in, 0xFFFFFFFF; + redux.sync.min.s32 min_out, in, 0xFFFFFFFF; + redux.sync.max.s32 max_out, in, 0xFFFFFFFF; + + add.s32 result, add_out, min_out; + add.s32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.s32 [out_ptr], result; + + ret; +} diff --git a/ptx/src/test/spirv_run/redux_sync_op_u32.ptx b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx new file mode 100644 index 0000000..03292ee --- /dev/null +++ b/ptx/src/test/spirv_run/redux_sync_op_u32.ptx @@ -0,0 +1,32 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry redux_sync_op_u32( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .u32 add_out; + .reg .u32 min_out; + .reg .u32 max_out; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + mov.u32 tid, %tid.x; + + redux.sync.add.u32 add_out, tid, 0xFFFFFFFF; + redux.sync.min.u32 min_out, tid, 0xFFFFFFFF; + redux.sync.max.u32 max_out, tid, 0xFFFFFFFF; + + add.u32 result, add_out, min_out; + add.u32 result, result, max_out; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 0570eb4..86719ef 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -695,6 +695,18 @@ ptx_parser_macros::generate_instruction_type!( } } + }, + ReduxSync { + type: Type::Scalar(data.type_), + data: ReduxSyncData, + arguments: { + dst: T, + src: T, + src_membermask: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + } + } } } ); @@ -2272,3 +2284,8 @@ impl VoteMode { } } } + +pub struct ReduxSyncData { + pub type_: ScalarType, + pub reduction: Reduction, +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 29b348a..2701127 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3844,6 +3844,23 @@ derive_parser!( // .mode: VoteMode = { .all, .any, .uni }; .mode: VoteMode = { .all, .any }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync + + redux.sync.op.type dst, src, membermask => { + Instruction::ReduxSync { + data: ReduxSyncData { type_, reduction: op }, + arguments: ReduxSyncArgs { dst, src, src_membermask: membermask } + } + } + .op: Reduction = {.add, .min, .max}; + .type: ScalarType = {.u32, .s32}; + + // redux.sync.op.b32 dst, src, membermask; + // .op = {.and, .or, .xor} + + // redux.sync.op{.abs.}{.NaN}.f32 dst, src, membermask; + // .op = { .min, .max } + ); #[cfg(test)] From d81456a549394be05d40378f5ea576abbe26fdaf Mon Sep 17 00:00:00 2001 From: Violet Date: Mon, 8 Sep 2025 17:41:24 -0700 Subject: [PATCH 2/3] Add support for cvt_rn_bf16x2_f32 (#501) --- ptx/src/pass/llvm/emit.rs | 36 +++++++++++++++-- ptx/src/test/ll/cvt_rn_bf16x2_f32.ll | 41 ++++++++++++++++++++ ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx | 25 ++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx_parser/src/ast.rs | 36 ++++++++++++++++- ptx_parser/src/lib.rs | 11 +++++- 6 files changed, 145 insertions(+), 5 deletions(-) create mode 100644 ptx/src/test/ll/cvt_rn_bf16x2_f32.ll create mode 100644 ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 08532e3..d0e826c 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -1646,9 +1646,39 @@ impl<'a> MethodEmitContext<'a> { } }; let src = self.resolver.value(arguments.src)?; - self.resolver.with_result(arguments.dst, |dst| unsafe { - llvm_fn(self.builder, src, dst_type, dst) - }); + if let Some(src2) = arguments.src2 { + let packed_type = get_scalar_type( + self.context, + data.to + .packed_type() + .ok_or_else(|| error_mismatched_type())?, + ); + let src2 = self.resolver.value(src2)?; + self.resolver.with_result(arguments.dst, |dst| { + let vec = unsafe { + LLVMBuildInsertElement( + self.builder, + LLVMGetPoison(dst_type), + llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32), + LLVM_UNNAMED.as_ptr(), + ) + }; + unsafe { + LLVMBuildInsertElement( + self.builder, + vec, + llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()), + LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32), + dst, + ) + } + }) + } else { + self.resolver.with_result(arguments.dst, |dst| unsafe { + llvm_fn(self.builder, src, dst_type, dst) + }) + }; Ok(()) } diff --git a/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll new file mode 100644 index 0000000..1e19037 --- /dev/null +++ b/ptx/src/test/ll/cvt_rn_bf16x2_f32.ll @@ -0,0 +1,41 @@ +define amdgpu_kernel void @cvt_rn_bf16x2_f32(ptr addrspace(4) byref(i64) %"37", ptr addrspace(4) byref(i64) %"38") #0 { + %"39" = alloca i64, align 8, addrspace(5) + %"40" = alloca i64, align 8, addrspace(5) + %"41" = alloca float, align 4, addrspace(5) + %"42" = alloca float, align 4, addrspace(5) + %"43" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"36" + +"36": ; preds = %1 + %"44" = load i64, ptr addrspace(4) %"37", align 8 + store i64 %"44", ptr addrspace(5) %"39", align 8 + %"45" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"45", ptr addrspace(5) %"40", align 8 + %"47" = load i64, ptr addrspace(5) %"39", align 8 + %"55" = inttoptr i64 %"47" to ptr + %"46" = load float, ptr %"55", align 4 + store float %"46", ptr addrspace(5) %"41", align 4 + %"48" = load i64, ptr addrspace(5) %"39", align 8 + %"56" = inttoptr i64 %"48" to ptr + %"35" = getelementptr inbounds i8, ptr %"56", i64 4 + %"49" = load float, ptr %"35", align 4 + store float %"49", ptr addrspace(5) %"42", align 4 + %"51" = load float, ptr addrspace(5) %"41", align 4 + %"52" = load float, ptr addrspace(5) %"42", align 4 + %2 = fptrunc float %"51" to bfloat + %3 = insertelement <2 x bfloat> poison, bfloat %2, i32 1 + %4 = fptrunc float %"52" to bfloat + %"57" = insertelement <2 x bfloat> %3, bfloat %4, i32 0 + %"50" = bitcast <2 x bfloat> %"57" to i32 + store i32 %"50", ptr addrspace(5) %"43", align 4 + %"53" = load i64, ptr addrspace(5) %"40", align 8 + %"54" = load i32, ptr addrspace(5) %"43", align 4 + %"58" = inttoptr i64 %"53" to ptr + store i32 %"54", ptr %"58", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx new file mode 100644 index 0000000..2bad276 --- /dev/null +++ b/ptx/src/test/spirv_run/cvt_rn_bf16x2_f32.ptx @@ -0,0 +1,25 @@ +.version 7.8 +.target sm_90 +.address_size 64 + +.visible .entry cvt_rn_bf16x2_f32( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .f32 in_a; + .reg .f32 in_b; + .reg .b32 result; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.f32 in_a, [in_addr]; + ld.f32 in_b, [in_addr + 4]; + + cvt.rn.bf16x2.f32 result, in_a, in_b; + st.b32 [out_addr], result; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e6d9a58..6e1b27e 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -200,6 +200,7 @@ test_ptx!( ); test_ptx!(cvt_rn_f16x2_e4m3x2, [0x2D55u16], [0x36804a80u32]); test_ptx!(cvt_rn_f16x2_e5m2x2, [0x36EDu16], [0x3600ED00u32]); +test_ptx!(cvt_rn_bf16x2_f32, [0.40625, 12.9f32], [0x3ED0414Eu32]); test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]); test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]); test_ptx!( diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 86719ef..37b5f6b 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1174,6 +1174,35 @@ impl ScalarType { ScalarType::Pred => ScalarKind::Pred, } } + + pub fn packed_type(&self) -> Option { + match self { + ScalarType::E4m3x2 => Some(ScalarType::B8), + ScalarType::E5m2x2 => Some(ScalarType::B8), + ScalarType::F16x2 => Some(ScalarType::F16), + ScalarType::BF16x2 => Some(ScalarType::BF16), + ScalarType::U16x2 => Some(ScalarType::U16), + ScalarType::S16x2 => Some(ScalarType::S16), + ScalarType::S16 + | ScalarType::BF16 + | ScalarType::U32 + | ScalarType::S8 + | ScalarType::S32 + | ScalarType::Pred + | ScalarType::B8 + | ScalarType::U64 + | ScalarType::B16 + | ScalarType::S64 + | ScalarType::B32 + | ScalarType::U8 + | ScalarType::F32 + | ScalarType::B64 + | ScalarType::B128 + | ScalarType::U16 + | ScalarType::F64 + | ScalarType::F16 => None, + } + } } #[derive(Clone, Copy, PartialEq, Eq)] @@ -1945,8 +1974,13 @@ impl CvtDetails { (RoundingMode::NearestEven, false) } }; + let dst_size = if dst.packed_type().is_some() { + dst.size_of() / 2 + } else { + dst.size_of() + }; let mode = match (dst.kind(), src.kind()) { - (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + (ScalarKind::Float, ScalarKind::Float) => match dst_size.cmp(&src.size_of()) { Ordering::Less => { let (rounding, is_integer_rounding) = unwrap_rounding(); CvtMode::FPTruncate { diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 2701127..3dec840 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -2442,7 +2442,16 @@ derive_parser!( // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; - // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b => { + if relu || satfinite { + state.errors.push(PtxError::Todo); + } + let data = ast::CvtDetails::new(&mut state.errors, Some(frnd2), false, false, ScalarType::BF16x2, ScalarType::F32); + ast::Instruction::Cvt { + data, + arguments: ast::CvtArgs { dst: d, src: a, src2: Some(b) } + } + } // cvt.rna{.satfinite}.tf32.f32 d, a; // cvt.frnd2{.relu}.tf32.f32 d, a; cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b => { From 3da39364e031cabf6c4406fb301e3bca1b425af0 Mon Sep 17 00:00:00 2001 From: Violet Date: Tue, 9 Sep 2025 13:12:31 -0700 Subject: [PATCH 3/3] Make blame ignore formatting commit (#502) --- .git-blame-ignore-revs | 1 + 1 file changed, 1 insertion(+) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..86df120 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1 @@ +21ef5f60a3a5efa17855a30f6b5c7d1968cd46ba