From 5cb0a9b8e83298ecf17b7aa18360866003c842e1 Mon Sep 17 00:00:00 2001 From: Violet Date: Thu, 3 Jul 2025 11:56:20 -0700 Subject: [PATCH] Add support for `bar.red.and.pred` (#402) Implements bar.red.and.pred and bar.red.or.pred, using the undocument __ockl_wgred functions. Doesn't yet add support for numbered barriers and threadcount, as these are not needed for llm.c. --- ptx/lib/zluda_ptx_impl.bc | Bin 7524 -> 7496 bytes ptx/lib/zluda_ptx_impl.cpp | 13 ++ ptx/src/pass/emit_llvm.rs | 1 + ptx/src/pass/insert_post_saturation.rs | 1 + .../instruction_mode_to_global_mode/mod.rs | 1 + ...eplace_instructions_with_function_calls.rs | 10 ++ ptx/src/test/ll/bar_red_and_pred.ll | 121 ++++++++++++++++++ ptx/src/test/spirv_run/bar_red_and_pred.ptx | 60 +++++++++ ptx/src/test/spirv_run/mod.rs | 7 + ptx_parser/src/ast.rs | 28 +++- ptx_parser/src/lib.rs | 31 ++++- ptx_parser_macros/src/lib.rs | 2 +- 12 files changed, 269 insertions(+), 6 deletions(-) create mode 100644 ptx/src/test/ll/bar_red_and_pred.ll create mode 100644 ptx/src/test/spirv_run/bar_red_and_pred.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6cefc813eb60ce7f302e265090ab069682291723..84d2f2d734e8c76faeaa25d9778b95755e71d944 100644 GIT binary patch delta 3784 zcmaE2b;4?b3S-(tRSh8tS0<-P+)hUXj6^IrB)qwdk|xGW*C(||XlZCoFs#`SGg0B# z2?hp+P6Y-ARR#tIr8E|SO-@anfIFzudVTQ89C)F7` z92(3Wf(=3v#ViQ{T+IpsNqr0=YTR8$5~qTgmWp&aX?R46um*~_28yt1G8(xHh%k0` zHL$n_OmI@*Xc94T3KUo(!qKGQAt}J3*}6!BBZ!4r)1}2FP*7lrMB5^Tt{|>t#wi&J z30w>y5GcXG;JR6q@i%k5fCNLwK{jP!h9HI`3C#<688jJ^HaN;``~Uy{fB6Pe1|P-_ zhE9d{qJsAF3GHPZ>}46u#SQId8SMoN*h?5rD9Ybal)0rSd#sTc!Ddvr^HBCygWMBE znFow=mk!FFTgZD$QRW4sEJM|V_9_kbh8gTt9PNb>?Pm22?3ESlRT}Mu3qTf_3A9%k zuor8vS7tOhtpfdR?Ymj3H=jkn+N6iFU z4HX}#3C#W8c(6r*M~x%NfK8amXhy(@R;+KC}79q0et-H2BOoqoB#az`?-4 z(8R+qW1-+<9)<$9MQ&0YGC@J>%3EBFK$SQ%IvK#Eu#*mzc6ajI@$Fknz% zU||3mXO!UKDkaUraG4RL0)!_pID4YU|3=|mTkll1tXd2@I7KqvP3=AMgY0Wj_WmpD^X^=JsP~<6v z$~rK3fVm(-(qPH1l?FE$5@vyTU~H7&(Hbto&eFidzzB|XMo<`BV`}DQSOyO}kWp_; z3mq6V7z7wVfhE8oX|Uwi`gP0=olFc2Ap02@7)~@OiuSuo2r!6(O=f_E`K>jd8W=Xf zL>P=5lA0YJu47_o&=C-3U|=wkn;{`7=rf^*Q_cMti?h%bUa#WVmFt+NKC=?onB(7n zdTr&b7`E2upEu2!SGU{o!KFVeM(vm17~NdDAc3WU1t^4ULTzlFidnO-z&1Op+`NEt5?xEe+FBjSMUj6B8}XQd80lO^nQq4blu# zHcRtyGAdRxFfa(Sc32FD4i5*mg{4{{VZ*!(p);vLvhAZWu`ArWTL z77YRCi_D4-JU9*&IB#G}FgTqc;>hQ*`JRyIdA zjvNN8a!*+$6*e$pmph6_ZYe}g8f>T#I6b(6(*r2gw@mr(@Uv%fnuuck#a0E0ki-ol zM|2nsFEk1?8CV=(1M3k4>v2KSBd-Z66b|#YD6nnjaTee)7igAPvhau^<2fUqLlfGC zGFa3L4jy!8vGO+HDbZ*;z~}{YxER#nnN3uuK9-}&k(KRi;u#?c zW>pR)ftG*(s3Ro78X*n`l_M=Z3{&JapSpW6vN0UyWo2aB%;O@!V=mJoamI+lxqx%x zW|kuzjSpKDBoYFfc+6SaB^GQva-hU{1BbXtfv13^MqyyVagm9QLgp-%-VQt^8Z#7l za%35Dz&;iSTjqxBN{63UCZ7{ks{aTwQ6}(=kN~qNha$&eR*v`wFC0(G3clbycn*{_ zI~*o3LSq+HXb7`F+zbh9lz;=dSD;zqj1iC1gjOLB7WD$hgPU0nJG3fs_)O>$QfDc5 zJS1*W#kxdXq}+MKW|kuktxEMgDNr|p>L6j3gA*DPIF#8~1ey{GEVw{P3Y18NSscM` z2bHBQJq!%y9CH|jl-XFtn-UJV9xQNnXgSCRD!zzOZ;PV7o`C_RQ<%k3tR>AcL#U-s zLLrK$$?*|ea{>zwL!kl#xb8FJ6Hsbl>E@WxFr$HcBIk{UnGO8&IVBQjH3+Zej1ZjD zAa;ebxVk~%B$r0eT9Cs)IF*Ni;V%ONgAntS7wilS44_zd zs%M5cz=weWT$x3HR5CCyfM`&@yTA+)_h4pV0HxbVW(EdOng=BxkkS_*Q$P#`NUIaX zbY);*0M!dH8q~Vs&|zTkV_;z5WkAp%@p=XW28L(`1_n?z4RwqR3^EK13@{ob4pKi0Y5|M}iG$RyXJlZIU|?W?(I9b<`s0iY z415d>3^4jX%m4rXIVbB%%4>p@zlCapDQ9JZD2LG?an8x5lJfN+)!IxDOF*I^!=0h3 zVKhh_)a(-TMiC%nsVsUY5QAvDSVrGtBT5d^v zd`4zLd~#`KN_n+a delta 3783 zcmX?M^~7p|3gfnksv1Hfu1rpoxSf^=7>QVLNO*G@1x<{Xt`BOF(9+PFU|6#uW}?E; z6ATOtoeB&L@(c_NN@*+to1B_F1%sTLJSP?&3h?0KRC7JX>J+5LvFl`0LLb9I9sy+r zRfS@fga9t37Qtec1qP7;QyLe9Fokfnuo$H}o=|8Ja5^#Nz!^anSA+E&Y9fp?og|Wd z6a|cyi5!ziK9$fTqsi)YVuq4HFi+D2#|#C9_(LKLAP^|Qz@WETm+?1qeFyWm|NsC0 zXWqcppvfSpAowJJ@7)K!PZ#(e9^lJO;QMfa?X3Xcmj*rt%O}m&EzZ_U4qI$sM6elJ ztoJxeUud>zVYWHoY_;RC`U2Yr1^)LA{7((| zzJSbUdlbO`EP>DMsRDnl0sqSd{7)GkFm6yh_3!`x|Md)hj31aE@H23(U|zt_6wP2J za0}!}{)R0FZ$K1WW0+tjAj2?$iNRBXfkA$@xP~*Anz`(Hi4qG2%J%>@s5|#-I1P^Tp+7`h4!T$1B zcAkuAi%%`HnEI>XXTBgVul1OkTx)W!ypkfw~dvN5o{7z=mw*~l2^StObwbKlfXh? zm&s3NVQFAuU<5lD>@qFaEgTHXSinjc7#Kh9>r-kH2*H$rUAEcC z(}6()#bsw&yBQL^;Hp3_>-14&WL&`Dz`()4z~I0Da@j7)9Hxd$6qkK9JIBJpz>Ms& z`c?S`Tnx+DV5Whb7Wy_pn8Ak|CIkw!Ro@EL6c{9+fd&b)-_3yx2Bt7o44^Pe<5^Q|Ga0+Xy6OTf$gou}AgHL9F#y19| zqmB!gFrS#h-fG0B5Uil;v~bG_5zm<-T&+%g3c*@OgY-naEE;?cPVu=og|}6SUm@5+ zgv++U$L*yEf9oecgd?AyPF_XgzmdSwvf%PstVI1w7ogC(fIz3t=Ad>FvkR#D{ z^oD^!&qhTJ24lh5g=|X=k`h7`IM}?bmpB>ZIIye;V6f$A)Z&pAVANu9;BPA6I?U!$ zz{PZ-*w%i*fYBQ4MfRvHLW8URvSFae}=Pd!Mj!V-|$ z^$rZS9Gqaa+8|?{Kx!v|)NTN)od8m5uo$FtG050n4Un-KAf>_}r2-(O9~wZ$Hh`2m z)Ps~R2RZhlJdd=%L3xnUC+vsWTqbaX9Q*U&2F7~986dR*bs)8MAY(&8Y8k;UxC&Am z08(4<_233Z!7q*sUlbB*KuY(694iS{Y5_9#7)YrANa==;Af+EbN(-t%N~=Mp&XofN zmkvnja*)ym>`WJ$KRi8HzkyNk2}o@NBS^(TeGvzKCy56jBOZW+I}XZ%d}+nuz<)^M z4oG-~Gs72!31H!avOEs_PLel3!V^HkGaA9FWkISXuYiQNfPA;W5hQ$2Pr!lSNeV1{ z1!VRLMp>SEX#qzIkaDRDAmvv;$~Q2Alrw@2mp%g$J_Ay|qfrLr2M>+|9gK;s3KBY2 zhTZ}^Au10TcpUki7VxXI7~Bz#(7{;L zs32jGdQiqu{)`Y0vnW#!NB!Xo%>q9dITd-#SqvE2HuD@2c+y}L%>>exz}0?G6Rs_P z;SrYtE=gexHy#-tmV*l#xf#`(X07OippY4=8f;oDMv85V*cAgu<%51CxO$iLS2iX`H5)~L2SQ$V?j3b+x8&8)G%Rz@rO*}j~4A|uk z;gQ>fM{W^BP8w_>s1g!paRrxYpprV}RQ-Y)x5?*(6yq1RDoFU;4iebRa>SwCk>k(< z25E>x1i_AULDJJ7z{|kEaG1A6fo(I7vjC5|K(oY>#zzwv(k0mvPcXAe9psED_#oJL zPKbwDwnL$yK`4h|vY)URd&;T$1eOIgle2~OMWZ;H4lXfmH}p>6nIh5QAmHT4D9tc= zy|5^EiW|eJ{(zqoH4iXMJ})fIwy{+~BEo<22VoJT1v@l$BzK))ILhmIPDp@RmP3t$ zS(*XjH*v5#-H_Z_zn~^-vYCid{XvikVu5Fb1eitH6gdvFa>PG);doM3@CEO|b3!7_ zqHPWn7&kD2gGds?7!>ciLcDv%h{JgT8<$Fx@C;RjPY#W41x(^HbGW;Nn`WyjuxSY- z6|^*RwMfW3WNOpzsQ;q&sE^}ITS7W#S0k5;!WD)a&`1W=jKVAjCp0E-D6_E$G$jNm zaDfsAsNxo8aTIgfEHJC(f*|vgE=L243kd=}4U7zh3JeT<3=9lLd;&@>%qQ489?Wat z`p$Xc!TbikRxXZ%D;k6*b9of5Y7pJURiU`HLE>#a*NVg)4bls_Zxrr=)~KmG3=B6J z7#M_@r|bY(#>~LLu$GB|0mSxUU|;|>Fe0F8K{Tjk?Z6BX-^|3o07?``m>3v93Cawn z4rBM}4eaDrL`Ubi(xdV+5I1+L79<(L6m`k0Y<|# zSTI6-2BSgZAoanZ_6Gw41B?cVgDgm6WMGhCU|@jJAaRiTCaC!^8YB)Bt?N}V8f zEDr<29!OSs3@WhW<1;b~;*(1=Q{pWQOcG5Jlaoyo6D`tAEmJrDmn>zPd`CuhGM}s! GBLe_Y)7!BC diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 7af9729..638ef1e 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -157,6 +157,19 @@ extern "C" __builtin_amdgcn_s_barrier(); } + int32_t __ockl_wgred_and_i32(int32_t) __device__; + int32_t __ockl_wgred_or_i32(int32_t) __device__; + + #define BAR_RED_IMPL(reducer) \ + bool FUNC(bar_red_##reducer##_pred)(uint32_t barrier __attribute__((unused)), bool predicate, bool invert_predicate) \ + { \ + /* TODO: handle barrier */ \ + return __ockl_wgred_##reducer##_i32(predicate ^ invert_predicate); \ + } + + BAR_RED_IMPL(and); + BAR_RED_IMPL(or); + void FUNC(__assertfail)(uint64_t message, uint64_t file, uint32_t line, diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 58341e4..1c2b52c 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -639,6 +639,7 @@ impl<'a> MethodEmitContext<'a> { // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. } => return Err(error_unreachable()), } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index cc9afa7..92be749 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -75,6 +75,7 @@ fn run_instruction<'input>( | ast::Instruction::Atom { .. } | ast::Instruction::AtomCas { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Bfe { .. } | ast::Instruction::Bfi { .. } | ast::Instruction::Bra { .. } diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index fdaafd1..3d56dd0 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1804,6 +1804,7 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Selp { .. } | ast::Instruction::Ret { .. } | ast::Instruction::Bar { .. } + | ast::Instruction::BarRed { .. } | ast::Instruction::Cvta { .. } | ast::Instruction::Atom { .. } | ast::Instruction::Mul24 { .. } diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs index 0f9311a..db6b473 100644 --- a/ptx/src/pass/replace_instructions_with_function_calls.rs +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -108,6 +108,16 @@ fn run_instruction<'input>( i @ ptx_parser::Instruction::Bar { .. } => { to_call(resolver, fn_declarations, "bar_sync".into(), i)? } + ptx_parser::Instruction::BarRed { data, arguments } => { + if arguments.src_threadcount.is_some() { + return Err(error_todo()); + } + let name = match data.pred_reduction { + ptx_parser::Reduction::And => "bar_red_and_pred", + ptx_parser::Reduction::Or => "bar_red_or_pred", + }; + to_call(resolver, fn_declarations, name.into(), ptx_parser::Instruction::BarRed { data, arguments })? + } i => i, }) } diff --git a/ptx/src/test/ll/bar_red_and_pred.ll b/ptx/src/test/ll/bar_red_and_pred.ll new file mode 100644 index 0000000..649efc0 --- /dev/null +++ b/ptx/src/test/ll/bar_red_and_pred.ll @@ -0,0 +1,121 @@ +declare i1 @__zluda_ptx_impl_bar_red_and_pred(i32, i1, i1) #0 + +declare i1 @__zluda_ptx_impl_bar_red_or_pred(i32, i1, i1) #0 + +declare i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @bar_red_and_pred(ptr addrspace(4) byref(i64) %"73", ptr addrspace(4) byref(i64) %"74") #1 { + %"75" = alloca i64, align 8, addrspace(5) + %"76" = alloca i64, align 8, addrspace(5) + %"77" = alloca i32, align 4, addrspace(5) + %"78" = alloca i32, align 4, addrspace(5) + %"79" = alloca i1, align 1, addrspace(5) + %"80" = alloca i1, align 1, addrspace(5) + %"81" = alloca i32, align 4, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"70" + +"70": ; preds = %1 + %"82" = load i64, ptr addrspace(4) %"74", align 4 + store i64 %"82", ptr addrspace(5) %"75", align 4 + %"44" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"71" + +"71": ; preds = %"70" + store i32 %"44", ptr addrspace(5) %"77", align 4 + %"85" = load i32, ptr addrspace(5) %"77", align 4 + %"84" = urem i32 %"85", 2 + store i32 %"84", ptr addrspace(5) %"78", align 4 + %"87" = load i32, ptr addrspace(5) %"78", align 4 + %"86" = icmp eq i32 %"87", 0 + store i1 %"86", ptr addrspace(5) %"80", align 1 + store i32 0, ptr addrspace(5) %"81", align 4 + %"90" = load i1, ptr addrspace(5) %"80", align 1 + %"89" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"90", i1 false) + store i1 %"89", ptr addrspace(5) %"79", align 1 + %"91" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"91", label %"17", label %"18" + +"17": ; preds = %"71" + %"93" = load i32, ptr addrspace(5) %"81", align 4 + %"92" = add i32 %"93", 1 + store i32 %"92", ptr addrspace(5) %"81", align 4 + br label %"18" + +"18": ; preds = %"17", %"71" + %"95" = load i1, ptr addrspace(5) %"80", align 1 + %"94" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"95", i1 false) + store i1 %"94", ptr addrspace(5) %"79", align 1 + %"96" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"96", label %"19", label %"20" + +"19": ; preds = %"18" + %"98" = load i32, ptr addrspace(5) %"81", align 4 + %"97" = add i32 %"98", 1 + store i32 %"97", ptr addrspace(5) %"81", align 4 + br label %"20" + +"20": ; preds = %"19", %"18" + store i1 true, ptr addrspace(5) %"80", align 1 + %"101" = load i1, ptr addrspace(5) %"80", align 1 + %"100" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"101", i1 false) + store i1 %"100", ptr addrspace(5) %"79", align 1 + %"102" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"102", label %"21", label %"22" + +"21": ; preds = %"20" + %"104" = load i32, ptr addrspace(5) %"81", align 4 + %"103" = add i32 %"104", 1 + store i32 %"103", ptr addrspace(5) %"81", align 4 + br label %"22" + +"22": ; preds = %"21", %"20" + store i1 false, ptr addrspace(5) %"80", align 1 + %"107" = load i1, ptr addrspace(5) %"80", align 1 + %"106" = call i1 @__zluda_ptx_impl_bar_red_or_pred(i32 1, i1 %"107", i1 false) + store i1 %"106", ptr addrspace(5) %"79", align 1 + %"108" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"108", label %"23", label %"24" + +"23": ; preds = %"22" + %"110" = load i32, ptr addrspace(5) %"81", align 4 + %"109" = add i32 %"110", 1 + store i32 %"109", ptr addrspace(5) %"81", align 4 + br label %"24" + +"24": ; preds = %"23", %"22" + store i1 true, ptr addrspace(5) %"80", align 1 + %"113" = load i1, ptr addrspace(5) %"80", align 1 + %"112" = call i1 @__zluda_ptx_impl_bar_red_and_pred(i32 1, i1 %"113", i1 true) + store i1 %"112", ptr addrspace(5) %"79", align 1 + %"114" = load i1, ptr addrspace(5) %"79", align 1 + br i1 %"114", label %"25", label %"26" + +"25": ; preds = %"24" + %"116" = load i32, ptr addrspace(5) %"81", align 4 + %"115" = add i32 %"116", 1 + store i32 %"115", ptr addrspace(5) %"81", align 4 + br label %"26" + +"26": ; preds = %"25", %"24" + %"118" = load i32, ptr addrspace(5) %"77", align 4 + %"117" = zext i32 %"118" to i64 + store i64 %"117", ptr addrspace(5) %"76", align 4 + %"120" = load i64, ptr addrspace(5) %"76", align 4 + %"119" = mul i64 %"120", 4 + store i64 %"119", ptr addrspace(5) %"76", align 4 + %"122" = load i64, ptr addrspace(5) %"75", align 4 + %"123" = load i64, ptr addrspace(5) %"76", align 4 + %"121" = add i64 %"122", %"123" + store i64 %"121", ptr addrspace(5) %"75", align 4 + %"124" = load i64, ptr addrspace(5) %"75", align 4 + %"125" = load i32, ptr addrspace(5) %"81", align 4 + %"126" = inttoptr i64 %"124" to ptr + store i32 %"125", ptr %"126", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/bar_red_and_pred.ptx b/ptx/src/test/spirv_run/bar_red_and_pred.ptx new file mode 100644 index 0000000..777b771 --- /dev/null +++ b/ptx/src/test/spirv_run/bar_red_and_pred.ptx @@ -0,0 +1,60 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry bar_red_and_pred( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 out_addr; + .reg .u64 out_index; + .reg .u32 thread_id; + .reg .u32 thread_mod_2; + .reg .pred pred; + .reg .pred cond; + .reg .u32 result; + + ld.param.u64 out_addr, [output]; + + mov.u32 thread_id, %tid.x; + rem.u32 thread_mod_2, thread_id, 2; + setp.eq.u32 cond, thread_mod_2, 0; + + mov.u32 result, 0; + + // Basic functionality + + // result += AND(tid.x % 2 == 0) forall threads + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(tid.x % 2 == 0) forall threads + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // result += AND(true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, cond; + @pred add.u32 result, result, 1; + // result += OR(false) forall threads + setp.eq.u32 cond, 1, 0; + bar.red.or.pred pred, 1, cond; + @pred add.u32 result, result, 1; + + // Negated condition + // result += AND(!true) forall threads + setp.eq.u32 cond, 1, 1; + bar.red.and.pred pred, 1, !cond; + @pred add.u32 result, result, 1; + + // Return result + + cvt.u64.u32 out_index, thread_id; + mul.lo.u64 out_index, out_index, 4; + add.u64 out_addr, out_addr, out_index; + st.u32 [out_addr], result; + + // result should be 2 + + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 84e0731..c594ebb 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -307,6 +307,13 @@ test_ptx_warp!(tid, [ 32u8, 33u8, 34u8, 35u8, 36u8, 37u8, 38u8, 39u8, 40u8, 41u8, 42u8, 43u8, 44u8, 45u8, 46u8, 47u8, 48u8, 49u8, 50u8, 51u8, 52u8, 53u8, 54u8, 55u8, 56u8, 57u8, 58u8, 59u8, 60u8, 61u8, 62u8, 63u8, ]); +test_ptx_warp!(bar_red_and_pred, [ + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, + 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, 2u32, +]); + struct DisplayError { err: T, } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index ca7b9df..6e42871 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -2,7 +2,7 @@ use super::{ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, StateSpace, VectorPrefix, }; -use crate::{Mul24Control, PtxError, PtxParserState}; +use crate::{Mul24Control, Reduction, PtxError, PtxParserState}; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8}; @@ -95,6 +95,26 @@ ptx_parser_macros::generate_instruction_type!( src2: Option, } }, + BarRed { + type: Type::Scalar(ScalarType::U32), + data: BarRedData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_barrier: T, + src_threadcount: Option, + src_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + src_negate_predicate: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + } + }, Bfe { type: Type::Scalar(data.clone()), data: ScalarType, @@ -1745,6 +1765,12 @@ pub struct BarData { pub aligned: bool, } +#[derive(Copy, Clone)] +pub struct BarRedData { + pub aligned: bool, + pub pred_reduction: Reduction, +} + pub struct AtomDetails { pub type_: Type, pub semantics: AtomSemantics, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 6dedbbb..da14406 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1705,6 +1705,9 @@ derive_parser!( #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Mul24Control { } + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum Reduction { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -2987,7 +2990,9 @@ derive_parser!( barrier{.cta}.sync{.aligned} a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned }, + data: ast::BarData { + aligned, + }, arguments: BarArgs { src1: a, src2: b } } } @@ -2997,14 +3002,32 @@ derive_parser!( bar{.cta}.sync a{, b} => { let _ = cta; ast::Instruction::Bar { - data: ast::BarData { aligned: true }, + data: ast::BarData { + aligned: true, + }, arguments: BarArgs { src1: a, src2: b } } } //bar{.cta}.arrive a, b; //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; - //bar{.cta}.red.op.pred p, a{, b}, {!}c; - //.op = { .and, .or }; + bar{.cta}.red.op.pred p, a{, b}, {!}c => { + let _ = cta; + let (negate_src3, c) = c; + ast::Instruction::BarRed { + data: ast::BarRedData { + aligned: true, + pred_reduction: op, + }, + arguments: BarRedArgs { + dst1: p, + src_barrier: a, + src_threadcount: b, + src_predicate: c, + src_negate_predicate: ParsedOperand::Imm(ImmediateValue::U64(negate_src3 as u64)) + } + } + } + .op: Reduction = { .and, .or }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index f88395d..0e916b4 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -784,7 +784,7 @@ fn emit_definition_parser( }; let can_be_negated = if arg.can_be_negated { quote! { - opt(any.verify(|(t, _)| *t == #token_type::Not)).map(|o| o.is_some()) + opt(any.verify(|(t, _)| *t == #token_type::Exclamation)).map(|o| o.is_some()) } } else { quote! {