From 432f0bb2ece993e7e7efdc219677cb78d9e51c3a Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 28 Aug 2025 00:18:26 +0000 Subject: [PATCH] Add initial code for vote --- ptx/lib/zluda_ptx_impl.bc | Bin 12264 -> 13332 bytes ptx/lib/zluda_ptx_impl.cpp | 43 ++++++++++++ ...registers2.rs => fix_special_registers.rs} | 10 +-- ptx/src/pass/insert_post_saturation.rs | 3 +- .../instruction_mode_to_global_mode/mod.rs | 3 +- ptx/src/pass/llvm/emit.rs | 1 + ptx/src/pass/mod.rs | 20 ++++-- .../replace_instructions_with_functions.rs | 15 ++++ ptx/src/test/ll/vote_all.ll | 66 ++++++++++++++++++ ptx/src/test/ll/vote_any.ll | 50 +++++++++++++ ptx/src/test/ll/vote_ballot.ll | 46 ++++++++++++ ptx/src/test/spirv_run/mod.rs | 27 +++++++ ptx/src/test/spirv_run/vote_all.ptx | 40 +++++++++++ ptx/src/test/spirv_run/vote_any.ptx | 29 ++++++++ ptx/src/test/spirv_run/vote_ballot.ptx | 27 +++++++ ptx_parser/src/ast.rs | 33 ++++++++- ptx_parser/src/lib.rs | 32 +++++++++ 17 files changed, 430 insertions(+), 15 deletions(-) rename ptx/src/pass/{fix_special_registers2.rs => fix_special_registers.rs} (95%) create mode 100644 ptx/src/test/ll/vote_all.ll create mode 100644 ptx/src/test/ll/vote_any.ll create mode 100644 ptx/src/test/ll/vote_ballot.ll create mode 100644 ptx/src/test/spirv_run/vote_all.ptx create mode 100644 ptx/src/test/spirv_run/vote_any.ptx create mode 100644 ptx/src/test/spirv_run/vote_ballot.ptx diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 853b9e152ba2034b0e8e646bb2c4fc6075e64387..8c77c54274fab097c796caaf180aafe5f803f030 100644 GIT binary patch delta 5356 zcmaD6KP6*=3KK8qM71nN^NADX8E;M8xLC36x#TGW8)rqE^PjyQR?hoAWswKN6G@I4 zObiSRff5W1)thq}!nYry|<0sm8m2aFpUUqXE3$M}Ky0Y3xhi)P+O5B`GO$gqYnfy;r#gT-=6 zv-O^4n=Q^38=R$QFxzZ#w1!y6_wxbY!wY;b4fyJF91t9ahkWlZuzgYBf5^c9VFKT~ z0>0b}d=DM?Uoc!~wz=YL$#BHxg0n@>VXG~SHYW~SUvReUVYUTX&}ea`*>Z-n5X|t0;Rmw=(<_D-%?HF9cv&Pk802#l4ka-BNNBECWMhMb z^Dzbn4@QAzUVh0xAbsHII3Uj;bb-B5mqB`#(GQR)qXj7TI6(%r@peeEI56oj$mTel zb6^lkU|hw^A|(Ja`qBfCT6Tw93syKXOmGr#aFp6}kS&#$r9q^Lf$P!(UNecD4QB^Ve?CP%XC*Du_pCUDt7M3^I0p^MYSRN!b!l2pf`83r26h6`7)EizDJ zx!%Gs+rek%HeZEBjT&kkcMKd~OyF!a7icwfT+zmPh@DNiDgRJF6!LD}gb8*`w9 zu$Y-d$DxQB$0`?aGD-6YGi5V@#C;@gBq%U!mg1Pf$f&(}Gp8$K{Tyb7%b*0$00I*j zowY6*v$3=Rw_2PSyc8gMbFv4a$VumW>ih*FC$ zgMlH42gZ^HOD-vhdqk$8| z17VPxazxoA8GJ-QTo86(<5?NieW`&VfI)$Qg@J)Vfx)Og!9z4IVFKd;7i6=vj;Qf+ zFt9Q(Ie_#rfy^>VW0PU<5e4x<7-ZI$H=7z5z$!o@Ag7CFCk8Mc2thW=>)wS1=7yE< zAU@HcIQ3kwBP&BND@a#8h?v0WylS5_Gath;E)W-lK~_FBXA?!UGQhwBLs`B=2i!Q2 zR*-2CJT^QGYTO_$2!l*}FX1C0z#tCdf-uMfE=Divtr%t)fcPK`3fR9X3~cfYK9V3l z2!qu9Gn#RLfrEh;#0O!JahoIq%fo$5o2IrkTh75boIe+hJ;gaA)^G3t?L-SGBR#J4&GPs z&F`5TG?`(_K*4*S?J_&d0XCQrC_HxX?m&0hTIq&^>u1I^21Wzq;Ng1Le1Zwpw9psqQ>7Rf9gw`Cu+eeB26k&E1_v)L6vrZjjV41H_cFIjARdtb?(TO_77GTX%Ac zh%S2*n*s-$kHzF2B0Ay=7&)34H}iQ(utc^o3MsSsNi`@)n7o<%Tf~wbELY4w*;>?4 z1k4r|;*w)2F-US?P~c!|^PJounk+NdX#*oeq5=a0C}#?@II=nSNQh(!G$km^bejBM zRFoaeo`}Tm1hb97;*Cxd4l+oC^@GYuVHQUnP{9Yv_$F5lfP@&7*^WuIC`d%~3gk=< z7qekM2=d8f&B?RGR);| zo-|1@RS83l1_2i4H6oMcW!01#AO#dC20_V8n8i`7CCxEIsHIOrAw!@!q2R*gbXimR z6bGIN9wwu^0uqNh8W~#@A{q{+voW%@Hu#vcIBjH{yj#`~n;LgyHQaKJ`e_Lwmw8zQ zdn6>X1)35V+2*(KxU;Y_vt8zO7EIW{aRO#|IU6HiV?)m5YB_7h+R5AHY~`@H;5n;= z2!qjG0gj_=lZE61b>&4<4BoJEr94>6%3ze#0E%2zi3N;eBF98|(FMgPPmot-+%S2Y zd}RH6Nfr$jrY0sO1u=#b9gLgU6gk8a8rVV&G?dv+NVO43PhfRI*w-i~z|7VU3YaG2qyz><4z`2@i>hM;WS7>KiI*@*rFgIalFHEBdF96(BhDWn8MG%z<|pM<}5I0 zFX1}Krf3Mxo_$RcEervm5)f3t*E19HHF#e>E8I=a1O@x54}^{|sS{MyvWY z0y{SIHrne}|oc0c6lyb_NC^P}2jd4lEBYwiy^0 zd^Z;=bu-q>FfcHzU}Rtb(V*tPfC>Y{5k>|EP@M~-Vd8ff85lrKT^J1#=Tu=}cm-7t zqe0>zbC{SQ=D=u}xD*q_94RJ-da!{Y36Mb=Pz^8|Bn~pj0V)ooVd4=`gJ3jB9Ar)! zR6UFaiG$3kV`5+cH3MPv|Nr&>|AQo^GchoL8p1Fdq=7?+fng031A`(1f(D6$)SqQy zU;s73VKhh_Wbs2LP!pVi0Y-zwLFyTq85o2?u?D3<;vn@R^~{j8C<3KH5+Ds$%naa$ zEsO?;Q28n~z7cnz1XfQA^z-XBGEM^7RXMqGMjE0HthZ+QrG)%)rR!BC5(I9b<`m?N%5Q5P#@wcpyD1gy0ac(w9E&$Q>pzHvW0J#X% zA_v78j0TB=EO26jBqA6M6HjGhV9;V#{?F6h?!@K^|~qX8?EUU^GnJ4=N6$VdBy3^$ZN4sxX=z zLc=7|*&!ttjE0HVvqORoM#IEsK-I%&nD_>0NerW5;s@Cw<;_7x_Q{(x+&Gz;z$N{E z`OS=)dl)C{Yq7HBB<7`Nrc4&l7TtVME0d)@KE5iaG$k>~Fa`oUL1!t+zB<9&nc4;%sq& z*^1%G1papdd~YZ4JwCvOU^583`@om$!2d~r?*jw>%LjbVFR;B8;QP|R$6$NH*>(!E z!x?7VmctfnoTU#m+w5VsopRXX0?2&nj(TU?8O@ecnr*fmv{}+@H-))@{Q%puUmzzj zMle4RVX$XVS-@{5Ai>aakWE>bA&B8fLi0ji22F;f4UIC}KE5o|_<`a2J0Z#Bq0QIvVW zD0k_g?74-!w-jYwFv>DiO=z#uU~ibgUd7Q~7}0Llz+PFwUZv4qxBz5KQy7V$zPBnhB-_R*bm4v2nDb^$g^x?$WgFke9+9xBEi6*nWJzh zf#F9&vmzUt16PAYPQ|eYAT8|R5camtAjjd90)tQj z^3g;Raek3qAvVqjvGhEVTQ{`o95NTrIy7Yk8Oe0HyVTk}szy!AZ z4F@BdL2Be1Oc_)d7#KVy7#IvDXR_GAHwrECx4|AY|Fk7mEn#ffL zkz*|fW(mXycW|EG!s%inaFiucT8%@-XyFUCMFwgdR}CCL?Bi@SPh4>zVTM5nv!UaM zLkUbljRze>jw!JTHyO@wnBXWmu<0yTxMc`f(eYyS{Dk{6c{8J zI2afi7#KJjcqD>$^0_e{xB^wez>soa!l{}KoD9oYVL}SbZ6W_2`#3NdFbFV!3=?3G zG+2_P*u&JY5@ZTk-H8T8QOQJ61|L2!mw|!7fsJQn{f`qH8yG4W6c|_-7#I{7j1oMS zdfi}U5C(C<3PElvJY~QzLkY}fU|;~b>1}W-1ET`m@RS1+xRxDKU~b3+Sr1ltqCqh< zCQwR%fs28q0VKi#vd@c~$B1Eu8i)_V4fQ+{K~t|UVPssu;K0BEaumpaT>H;7FgJ7} zn>DpRR!jh_0;CV@_E$PQS`0HZz65^~;=l$I1BL%?o)!EY z47@NQkg`>hEeF{dgprl4`pUBmUD-d$8hHjE37AfhtyLy7>=YPGAVT$!=scMDp_aiQ z5hBLG01CINs|~*z7%sqtj1oLd*B$)Mc%TIoZeZb*0~5SbnqM+AIWU9y3=9mQSU<~l z2i;{mcn>^aZpcJY_J(bq1VjCSBgn?Rs%`G!X<&E(GY*sOloGc?DE>azDpG9Z-2f|J`E<~KZOer-MPn8xP7kO8+s z(qKu~ONC8n%2cm8wz4@eG{CKOVDrck(9mgPV3@$lw0Ro8B@3hG<_p5%jH(;BmKr1_ zv@misHGbl9H%MyeVB~0Gl(IRfBXg{S(U47%gRNV8vX7`PyAzuN2b+)Hni3eYgrkLk{TdZ%s412a!iyLT~KUt zgMuof!{lWOk@fe*7YJ~0aCJ09FeoUqaY(f&L^L!?1{_?_BcYH1_H|kS??D5&iH)Hg zt;jYoigBFmU~FPj&InD_zn|T=(&zP~CnJld&%+3%H zVZd@gZnCM8PJQ|V306Ubo(C)*n?S`(p#lSgJOcxR5ubok3v(-1hrs#s= zb#QZh+|eL3joV}6t_IP4+!cj;8zeq+uef-kL3#<#jg6-o))dILbeR`G14e ze1Q|jydZZlr1CH@Y-eO(5MrLPL6(7m0hC|1voSD$QnL>O1GoVb0WCj4v?~Jx!v&Be z0|UceHU#qgTz7R@GwE_<6&Z`N3mE1DgmQG z8bB7zgS?>XVKhh_q&|s>fkBdifdNLt#H*PY7(^Ku7-01O|Mma>bLcQI zOkiSQP-H;RFbx}-7#Ktt7#LtQNE~GGQ6^A(j)4J2gTz4=KV@QI5N2RtfYBgvkoq4? z3=E)l8H@&rgVf8_Gcz!Nnk{lr8YBU7kSR07K` zgV7*ykom8n=D=u>IEPL>1H(TSNS24uAPJBLSyqS-U^GY^WU&D&L_LfKiG$SpvqBsM zqe0>z^~q56Fd8HdQr`qJAJjw!GZ+}aApz1bm6d@(hJk?rM#D60WQBwfjE0F{WQ9Zl zjE0H7XNBYf7!48!na{xnQ4gZ)L2(C?0J%tw4dMeB4H5@wuwr8Xw~t{oOxztR4x?e> zA#4l`pk_CWhKVPzLGnF}{{R2KJvjSSut8i3qhT5*LJKk&4HI9)2Fb^Z*e37OcH`U) bDJLJuZsyh5!?^jWZZ6AYTjPSsM~(deeK~;2 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 9de6f61..93b1ec9 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -47,6 +47,11 @@ extern "C" return (uint32_t)__ockl_get_num_groups(member); } + uint32_t FUNC(sreg_laneid)() + { + return __lane_id(); + } + uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __device__; uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32) { @@ -476,4 +481,42 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); { return div_f32_part2(x, y, {fma_4, fma_1, fma_3, numerator_scaled_flag}); } + + __device__ static inline uint32_t ballot(bool value, bool negate) + { + __builtin_amdgcn_wave_barrier(); + return __builtin_amdgcn_ballot_w32(negate ? !value : value); + } + + bool FUNC(vote_sync_any_pred)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, false) != 0; + } + + bool FUNC(vote_sync_any_pred_negate)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, true) != 0; + } + + // IMPORTANT: exec mask must be a subset of membermask, the behavior is undefined otherwise + bool FUNC(vote_sync_all_pred)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, false) == __builtin_amdgcn_read_exec_lo(); + } + + // also known as "none" + bool FUNC(vote_sync_all_pred_negate)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, false) == 0; + } + + uint32_t FUNC(vote_sync_ballot_b32)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, false); + } + + uint32_t FUNC(vote_sync_ballot_b32_negate)(bool value, uint32_t membermask __attribute__((unused))) + { + return ballot(value, true); + } } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers.rs similarity index 95% rename from ptx/src/pass/fix_special_registers2.rs rename to ptx/src/pass/fix_special_registers.rs index fc9e028..70b468d 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -2,13 +2,13 @@ use super::*; pub(super) fn run<'a, 'input>( resolver: &'a mut GlobalStringIdentResolver2<'input>, - special_registers: &'a SpecialRegistersMap2, + special_registers: &'a SpecialRegistersMap, directives: Vec, ) -> Result, TranslateError> { - let mut result = Vec::with_capacity(SpecialRegistersMap2::len() + directives.len()); + let mut result = Vec::with_capacity(SpecialRegistersMap::len() + directives.len()); let mut sreg_to_function = - FxHashMap::with_capacity_and_hasher(SpecialRegistersMap2::len(), Default::default()); - SpecialRegistersMap2::foreach_declaration( + FxHashMap::with_capacity_and_hasher(SpecialRegistersMap::len(), Default::default()); + SpecialRegistersMap::foreach_declaration( resolver, |sreg, (return_arguments, name, input_arguments)| { result.push(UnconditionalDirective::Method(UnconditionalFunction { @@ -80,7 +80,7 @@ fn run_statement<'a, 'input>( struct SpecialRegisterResolver<'a, 'input> { resolver: &'a mut GlobalStringIdentResolver2<'input>, - special_registers: &'a SpecialRegistersMap2, + special_registers: &'a SpecialRegistersMap, sreg_to_function: FxHashMap, result: Vec, } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index c46149f..0fd15b8 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -194,7 +194,8 @@ fn run_instruction<'input>( } | ast::Instruction::Tanh { .. } | ast::Instruction::Trap {} - | ast::Instruction::Xor { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::Xor { .. } + | ast::Instruction::Vote { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 1981473..4e82b6a 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1852,7 +1852,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Atom { .. } | ast::Instruction::Mul24 { .. } | ast::Instruction::Nanosleep { .. } - | ast::Instruction::AtomCas { .. } => InstructionModes::none(), + | ast::Instruction::AtomCas { .. } + | ast::Instruction::Vote { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index f6b8ca0..d716a77 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -540,6 +540,7 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Bfi { .. } | ast::Instruction::Activemask { .. } | ast::Instruction::ShflSync { .. } + | ast::Instruction::Vote { .. } | ast::Instruction::Nanosleep { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 79f5e99..bbec9b0 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,7 +13,7 @@ use strum_macros::EnumIter; mod deparamize_functions; mod expand_operands; -mod fix_special_registers2; +mod fix_special_registers; mod hoist_globals; mod insert_explicit_load_store; mod insert_implicit_conversions2; @@ -63,12 +63,12 @@ pub fn to_llvm_module<'input>( ) -> Result { let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); - let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; + let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = replace_known_functions::run(&mut flat_resolver, directives); let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = resolve_function_pointers::run(directives)?; - let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; + let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; @@ -119,6 +119,7 @@ enum PtxSpecialRegister { Nctaid, Clock, LanemaskLt, + Laneid, } impl PtxSpecialRegister { @@ -130,6 +131,7 @@ impl PtxSpecialRegister { Self::Nctaid => "%nctaid", Self::Clock => "%clock", Self::LanemaskLt => "%lanemask_lt", + Self::Laneid => "%laneid", } } @@ -151,6 +153,7 @@ impl PtxSpecialRegister { PtxSpecialRegister::Nctaid => ast::ScalarType::U32, PtxSpecialRegister::Clock => ast::ScalarType::U32, PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, + PtxSpecialRegister::Laneid => ast::ScalarType::U32, } } @@ -160,7 +163,9 @@ impl PtxSpecialRegister { | PtxSpecialRegister::Ntid | PtxSpecialRegister::Ctaid | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), - PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, + PtxSpecialRegister::Clock + | PtxSpecialRegister::LanemaskLt + | PtxSpecialRegister::Laneid => None, } } @@ -172,6 +177,7 @@ impl PtxSpecialRegister { PtxSpecialRegister::Nctaid => "sreg_nctaid", PtxSpecialRegister::Clock => "sreg_clock", PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", + PtxSpecialRegister::Laneid => "sreg_laneid", } } } @@ -885,14 +891,14 @@ impl<'input> ScopeMarker<'input> { } } -struct SpecialRegistersMap2 { +struct SpecialRegistersMap { reg_to_id: FxHashMap, id_to_reg: FxHashMap, } -impl SpecialRegistersMap2 { +impl SpecialRegistersMap { fn new(resolver: &mut ScopedResolver) -> Result { - let mut result = SpecialRegistersMap2 { + let mut result = SpecialRegistersMap { reg_to_id: FxHashMap::default(), id_to_reg: FxHashMap::default(), }; diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index edcaaa1..34c7b15 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -312,6 +312,21 @@ fn run_instruction<'input>( ptx_parser::Instruction::BarRed { data, arguments }, )? } + ptx_parser::Instruction::Vote { data, arguments } => { + let mode = match data.mode { + ptx_parser::VoteMode::Any => "any_pred", + ptx_parser::VoteMode::All => "all_pred", + ptx_parser::VoteMode::Ballot => "ballot_b32", + }; + let negate = if data.negate { "_negate" } else { "" }; + let name = format!("vote_sync_{mode}{negate}"); + to_call( + resolver, + fn_declarations, + name.into(), + ptx_parser::Instruction::Vote { data, arguments }, + )? + } ptx_parser::Instruction::ShflSync { data, arguments: orig_arguments @ ast::ShflSyncArgs { dst_pred: None, .. }, diff --git a/ptx/src/test/ll/vote_all.ll b/ptx/src/test/ll/vote_all.ll new file mode 100644 index 0000000..175edb8 --- /dev/null +++ b/ptx/src/test/ll/vote_all.ll @@ -0,0 +1,66 @@ +declare hidden i1 @__zluda_ptx_impl_vote_sync_all_pred(i1, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_laneid() #0 + +define amdgpu_kernel void @vote_all(ptr addrspace(4) byref(i64) %"51") #1 { + %"52" = alloca i32, align 4, addrspace(5) + %"53" = alloca i32, align 4, addrspace(5) + %"54" = alloca i1, align 1, addrspace(5) + %"55" = alloca i1, align 1, addrspace(5) + %"56" = alloca i32, align 4, addrspace(5) + %"57" = alloca i64, align 8, addrspace(5) + %"69" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"46" + +"46": ; preds = %1 + %"58" = load i64, ptr addrspace(4) %"51", align 8 + store i64 %"58", ptr addrspace(5) %"57", align 8 + %"37" = call i32 @__zluda_ptx_impl_sreg_laneid() + br label %"47" + +"47": ; preds = %"46" + store i32 %"37", ptr addrspace(5) %"52", align 4 + %"39" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"48" + +"48": ; preds = %"47" + store i32 %"39", ptr addrspace(5) %"53", align 4 + %"62" = load i32, ptr addrspace(5) %"52", align 4 + %2 = icmp ne i32 %"62", 0 + store i1 %2, ptr addrspace(5) %"54", align 1 + store i1 false, ptr addrspace(5) %"55", align 1 + %"64" = load i1, ptr addrspace(5) %"54", align 1 + br i1 %"64", label %"17", label %"18" + +"17": ; preds = %"48" + %"66" = load i1, ptr addrspace(5) %"54", align 1 + %"65" = call i1 @__zluda_ptx_impl_vote_sync_all_pred(i1 %"66", i32 -2) + store i1 %"65", ptr addrspace(5) %"55", align 1 + br label %"18" + +"18": ; preds = %"17", %"48" + %"68" = load i1, ptr addrspace(5) %"55", align 1 + %"67" = select i1 %"68", i32 1, i32 0 + store i32 %"67", ptr addrspace(5) %"56", align 4 + %"71" = load i32, ptr addrspace(5) %"53", align 4 + %3 = zext i32 %"71" to i64 + %"70" = mul i64 %3, 4 + store i64 %"70", ptr addrspace(5) %"69", align 8 + %"73" = load i64, ptr addrspace(5) %"57", align 8 + %"74" = load i64, ptr addrspace(5) %"69", align 8 + %"72" = add i64 %"73", %"74" + store i64 %"72", ptr addrspace(5) %"57", align 8 + %"75" = load i64, ptr addrspace(5) %"57", align 8 + %"76" = load i32, ptr addrspace(5) %"56", align 4 + %"77" = inttoptr i64 %"75" to ptr + store i32 %"76", ptr %"77", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/vote_any.ll b/ptx/src/test/ll/vote_any.ll new file mode 100644 index 0000000..bc24522 --- /dev/null +++ b/ptx/src/test/ll/vote_any.ll @@ -0,0 +1,50 @@ +declare hidden i1 @__zluda_ptx_impl_vote_sync_any_pred_negate(i1, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @vote_any(ptr addrspace(4) byref(i64) %"44") #1 { + %"45" = alloca i32, align 4, addrspace(5) + %"46" = alloca i1, align 1, addrspace(5) + %"47" = alloca i1, align 1, addrspace(5) + %"48" = alloca i32, align 4, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"58" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"41" + +"41": ; preds = %1 + %"50" = load i64, ptr addrspace(4) %"44", align 8 + store i64 %"50", ptr addrspace(5) %"49", align 8 + %"35" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"42" + +"42": ; preds = %"41" + store i32 %"35", ptr addrspace(5) %"45", align 4 + %"53" = load i32, ptr addrspace(5) %"45", align 4 + %2 = icmp uge i32 %"53", 32 + store i1 %2, ptr addrspace(5) %"46", align 1 + %"55" = load i1, ptr addrspace(5) %"46", align 1 + %"54" = call i1 @__zluda_ptx_impl_vote_sync_any_pred_negate(i1 %"55", i32 -1) + store i1 %"54", ptr addrspace(5) %"47", align 1 + %"57" = load i1, ptr addrspace(5) %"47", align 1 + %"56" = select i1 %"57", i32 1, i32 0 + store i32 %"56", ptr addrspace(5) %"48", align 4 + %"60" = load i32, ptr addrspace(5) %"45", align 4 + %3 = zext i32 %"60" to i64 + %"59" = mul i64 %3, 4 + store i64 %"59", ptr addrspace(5) %"58", align 8 + %"62" = load i64, ptr addrspace(5) %"49", align 8 + %"63" = load i64, ptr addrspace(5) %"58", align 8 + %"61" = add i64 %"62", %"63" + store i64 %"61", ptr addrspace(5) %"49", align 8 + %"64" = load i64, ptr addrspace(5) %"49", align 8 + %"65" = load i32, ptr addrspace(5) %"48", align 4 + %"66" = inttoptr i64 %"64" to ptr + store i32 %"65", ptr %"66", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/ll/vote_ballot.ll b/ptx/src/test/ll/vote_ballot.ll new file mode 100644 index 0000000..350d837 --- /dev/null +++ b/ptx/src/test/ll/vote_ballot.ll @@ -0,0 +1,46 @@ +declare hidden i32 @__zluda_ptx_impl_vote_sync_ballot_b32(i1, i32) #0 + +declare hidden i32 @__zluda_ptx_impl_sreg_tid(i8) #0 + +define amdgpu_kernel void @vote_ballot(ptr addrspace(4) byref(i64) %"41") #1 { + %"42" = alloca i32, align 4, addrspace(5) + %"43" = alloca i1, align 1, addrspace(5) + %"44" = alloca i32, align 4, addrspace(5) + %"45" = alloca i64, align 8, addrspace(5) + %"52" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"38" + +"38": ; preds = %1 + %"46" = load i64, ptr addrspace(4) %"41", align 8 + store i64 %"46", ptr addrspace(5) %"45", align 8 + %"34" = call i32 @__zluda_ptx_impl_sreg_tid(i8 0) + br label %"39" + +"39": ; preds = %"38" + store i32 %"34", ptr addrspace(5) %"42", align 4 + %"49" = load i32, ptr addrspace(5) %"42", align 4 + %2 = icmp uge i32 %"49", 34 + store i1 %2, ptr addrspace(5) %"43", align 1 + %"51" = load i1, ptr addrspace(5) %"43", align 1 + %"60" = call i32 @__zluda_ptx_impl_vote_sync_ballot_b32(i1 %"51", i32 -1) + store i32 %"60", ptr addrspace(5) %"44", align 4 + %"54" = load i32, ptr addrspace(5) %"42", align 4 + %3 = zext i32 %"54" to i64 + %"53" = mul i64 %3, 4 + store i64 %"53", ptr addrspace(5) %"52", align 8 + %"56" = load i64, ptr addrspace(5) %"45", align 8 + %"57" = load i64, ptr addrspace(5) %"52", align 8 + %"55" = add i64 %"56", %"57" + store i64 %"55", ptr addrspace(5) %"45", align 8 + %"58" = load i64, ptr addrspace(5) %"45", align 8 + %"59" = load i32, ptr addrspace(5) %"44", align 4 + %"61" = inttoptr i64 %"58" to ptr + store i32 %"59", ptr %"61", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="dynamic" "denormal-fp-math-f32"="dynamic" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="preserve-sign" "denormal-fp-math-f32"="preserve-sign" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index ca412be..106a792 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -401,6 +401,33 @@ test_ptx_warp!( 225u32, 237u32, 235u32, 236u32, 237u32, ] ); +test_ptx_warp!( + vote_all, + [ + 0u32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1 + ] +); +test_ptx_warp!( + vote_any, + [ + 1u32, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0 + ] +); +test_ptx_warp!( + vote_ballot, + [ + 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, + 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, + 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, + 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, 4294967292, + 4294967292, 4294967292, 4294967292, 4294967292, 4294967292 + ] +); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/vote_all.ptx b/ptx/src/test/spirv_run/vote_all.ptx new file mode 100644 index 0000000..a6cf25f --- /dev/null +++ b/ptx/src/test/spirv_run/vote_all.ptx @@ -0,0 +1,40 @@ +.version 7.0 +.target sm_70 +.address_size 64 + +.visible .entry vote_all( + .param .u64 output +) +{ + .reg .u32 laneid; + .reg .u32 tid; + .reg .pred not_first_lane; + .reg .pred result_pred; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + + mov.u32 laneid, %laneid; + mov.u32 tid, %tid.x; + setp.ne.u32 not_first_lane, laneid, 0; + + mov.pred result_pred, 0; + // IMPORTANT: + // PTX documentation states: + // "The behavior of vote.sync is undefined if the executing thread is not in the membermask." + // You might think that means: + // "The value produced by vote.sync is undefined if the if the executing thread is not in the membermask." + // But it actually means: + // "The instruction `vote.sync` is _undefined behavior_ (in C/C++ sense) for _all threads in the warp_ if the executing thread is not in the membermask." + // Compiler _can_ and _does_ skip vote.sync entirely if it can prove that the membermask does not match execution mask + @not_first_lane vote.sync.all.pred result_pred, not_first_lane, 0xFFFFFFFE; + selp.u32 result, 1, 0, result_pred; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/vote_any.ptx b/ptx/src/test/spirv_run/vote_any.ptx new file mode 100644 index 0000000..c43aa15 --- /dev/null +++ b/ptx/src/test/spirv_run/vote_any.ptx @@ -0,0 +1,29 @@ +.version 7.0 +.target sm_70 +.address_size 64 + +.visible .entry vote_any( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .pred tid_is_greater_equal_32; + .reg .pred result_pred; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + + mov.u32 tid, %tid.x; + setp.ge.u32 tid_is_greater_equal_32, tid, 32; + + vote.sync.any.pred result_pred, !tid_is_greater_equal_32, 0xFFFFFFFF; + selp.u32 result, 1, 0, result_pred; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/vote_ballot.ptx b/ptx/src/test/spirv_run/vote_ballot.ptx new file mode 100644 index 0000000..52b31c3 --- /dev/null +++ b/ptx/src/test/spirv_run/vote_ballot.ptx @@ -0,0 +1,27 @@ +.version 7.0 +.target sm_70 +.address_size 64 + +.visible .entry vote_ballot( + .param .u64 output +) +{ + .reg .u32 tid; + .reg .pred tid_is_greater_equal_34; + .reg .u32 result; + .reg .u64 out_ptr; + + ld.param.u64 out_ptr, [output]; + + mov.u32 tid, %tid.x; + setp.ge.u32 tid_is_greater_equal_34, tid, 34; + + vote.sync.ballot.b32 result, tid_is_greater_equal_34, 0xFFFFFFFF; + + .reg .u64 out_offset; + mul.wide.u32 out_offset, tid, 4; + add.u64 out_ptr, out_ptr, out_offset; + st.u32 [out_ptr], result; + + ret; +} \ No newline at end of file diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index f198795..d5122bf 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -3,7 +3,8 @@ use super::{ StateSpace, VectorPrefix, }; use crate::{ - FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection, ShuffleMode, + FunnelShiftMode, Mul24Control, PtxError, PtxParserState, Reduction, ShiftDirection, + ShuffleMode, VoteMode, }; use bitflags::bitflags; use std::{alloc::Layout, cmp::Ordering, fmt::Write, num::NonZeroU8}; @@ -673,6 +674,22 @@ ptx_parser_macros::generate_instruction_type!( src: T } }, + Vote { + type: Type::Scalar(data.mode.type_()), + data: VoteDetails, + arguments: { + dst: T, + src1: { + repr: T, + type: { Type::Scalar(ScalarType::Pred) }, + }, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + } + } + + } } ); @@ -2190,3 +2207,17 @@ pub enum DivFloatKind { pub struct FlushToZero { pub flush_to_zero: bool, } + +pub struct VoteDetails { + pub mode: VoteMode, + pub negate: bool, +} + +impl VoteMode { + fn type_(self) -> ScalarType { + match self { + VoteMode::All | VoteMode::Any => ScalarType::Pred, + VoteMode::Ballot => ScalarType::B32, + } + } +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 9c08f95..db0edd1 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1792,6 +1792,11 @@ derive_parser!( #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum FunnelShiftMode { } + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] + pub enum VoteMode { + Ballot + } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3717,6 +3722,33 @@ derive_parser!( .atype: ScalarType = { .u32, .s32 }; .btype: ScalarType = { .u32, .s32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync + + vote.sync.mode.pred d, {!}a, membermask => { + let (negate, a) = a; + Instruction::Vote { + data: VoteDetails { + mode, + negate + }, + arguments: VoteArgs { dst: d, src1: a, src2: membermask } + } + } + vote.sync.ballot.b32 d, {!}a, membermask => { + let (negate, a) = a; + Instruction::Vote { + data: VoteDetails { + mode: VoteMode::Ballot, + negate + }, + arguments: VoteArgs { dst: d, src1: a, src2: membermask } + } + } + + // .mode: VoteMode = { .all, .any, .uni }; + .mode: VoteMode = { .all, .any }; + ); #[cfg(test)]