From 00d7cd131b79e7e407b2a262f28c93f3242ec840 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 17 Sep 2025 01:51:29 +0000 Subject: [PATCH] Add mma --- ptx/lib/zluda_ptx_impl.bc | Bin 18824 -> 24456 bytes ptx/lib/zluda_ptx_impl.cpp | 158 +++++++++++++++++- ptx/src/pass/insert_post_saturation.rs | 3 +- .../instruction_mode_to_global_mode/mod.rs | 3 +- ptx/src/pass/llvm/emit.rs | 3 +- .../replace_instructions_with_functions.rs | 29 ++++ ptx_parser/src/ast.rs | 49 +++++- ptx_parser/src/lib.rs | 26 +++ 8 files changed, 265 insertions(+), 6 deletions(-) diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2b62aebd51921f051f18eeed20e95795f515aee4..bc375c3b14ff5032dff6fc592f683e1b081a93be 100644 GIT binary patch literal 24456 zcmZ>AK5)-egn@yTfq@~3$3Vp^a$nls-&@vm2R8UHGB7YG@-Q%z2{ABmGbu7KF)%QQ zF)%Q&H!4qbJi+40CSYVaiLr%8m6bzMj-@-PLB^HIX%e^769FR;3l0fyE~BCjkAfqL z$}U3fJpzuCm{Lw~Y!p#&cI!|$$!Jv6BB7zd~YA|qT z;F`?pD#_L1!YQ!SL4l*gF;bv|k*U!^L7>BtldYJ=QIM@)iHFJ6frrabf`_fuL4mC~ zh=;BD-~yiJg9}7BSdt7hoDu^%jz9RQaVWu{M}cSQ4IZ{;M?)@-BMgl*7&V%X2e-AS z7tA$ronr7)HfZfx^VIvf%N*eT>tT4uKweXE2k`) zz@Q*^n4`m-%bG-$Muf7Yq}e?RGd@Y;m^gVU|W>JM1`Yy@k>CgtNr~W}79=)^nVtw>Vo|V76lT zqQL((fRDkD|BC?M;|+Yd4E&!y@O=s3d+fmfUV%SXf$wVq-?IR|PZ!uePvH9%z{eoR zAl;$J;E~Ae$ibkQqtL~u(8}B5$gpI?L(6~v|NpON@MHYI{D7Z<^F{Lku{K^7A%>a_ zjUFrsHy?1H0(pY}fE2?(yMHOj{RbsL?w`dBa(^?!k`0X&LJcBj0uz`f9BM1+fcWJa zg8}OSeg+E$l?D7v(F`0#8tmm2?3Eeq1qtnD2J96X&1D?yRSxV8GM^609$F}SDN*JI z5}VY(gBN4Z;yGB+0To?(=O*{jfAQ_xC%onSBLXfN|@V^t_dpm*e@c}jjn?c~+2fkbf{!apY9~k&wKHz(Pf$gmT-ww%&zv*n=8l4iRp%nj@Z*wlW5T*Mf` z{6K`k9+dda1f&=`4zeW+GXyalNoYRE%b>}Sw4qUE7f78vLomZ1h6Yfs?^I|nDrhgC z&|b#DUY5aJ+yGBVCluxHD9YSYls(qSi(oS<+<7Q_t3mFGqRazExl0FS&%qNz)r9sc z4fci^>{T4?g%Ry$4eXT_>{S}=g$qFDn+dd68L$^?uvca@S9-A5YOpi#H}Lj+2f2s2 zfU7~Cp`O9RK|V*I{XIws646QyN&}kqoLSgFjmn zb~hY44l;Z=$d=5*(;&iNcq3r~BP<{BpyY9Aanv!u#nFs|AkXk0&}67%_ybC`0ymg8 zEVOI}iOM&Ck_0%vC*EjS;K-nIfa$_PHWgkL1xB$2w;Wb5GJZJ7mc$G3*AiYXX%+@e z7DkaJ2c>fYZaFY`C~_=N;Jug7{16&q*$f`i)69Q>%w~*Wc3}FYumRN0@swa-uwh_e z5N=9vQsi+-^pW6GnJJ(l%$2IZm3(Lsk4vJC&tsblEef*)JcL_1XDRp`W?8sNP2jqN zh%iU0LKmltsld^eM^YV!W*BHN8!lYIw#Yz@1djG;6%eL zg9Itl!~;56O$EHdEk+V*0+$^`gqe(IILu*`TGA-x!*MO4=>S`G{SkK( z6u6QTrJgscapW8LaIE2Mk>+tROe|<_oT0Geph02i!9m;3);;%+RGFm zBzwt>_JSJ@a!)U?eOSO==D=Q+z+Q5J?SlZA0T#(+;4fOhUaY`gWx!r^f$f6}4>rSb;rv0w3Ik-vmHXg$eAr3H(J4V0$mHeKP>*0;x>kFKGaK;R4&Y31F29>=hT- zz6r2bDnM-qX@r}9f$duWd!Yh*6`0IT;4frg2N^5CUZ4O9kZ%g?1rF>L3)qVo*h?DN zb06@50t6gZo58+V7{H!ez*TsG?cD^vssi?^2KK56>|XpFfJt35EoZr|g}KEDj7X zFoDr|Rd5C$iv$Y;Ljy>Vqk%^v=wj1S21WxA7mQO5Oz@K3#>&CK#J~jB%A~;D7NYvl z(1AgOK>*y07hsSySR%!}mXYZILj!{V1IV-!4T_=hIeZL0>|j$E7#JMbcvco1Zfalv zn*tU!O7M6p7s(T=W2$8Dwqh^N#1>|EDS6Rj0zx8#tDqh zTE&O-7-lGexF8I2($C;F#sey#eilds2w3s;>;NAr}DMyq|lEFs=tODe6HlCGH z-Ip2|Rv?FfXk5Ys#seU~!~n(v zDab~7-Mi4h+;9>ez$Y3Mr=II|WM!BPDo?>?fs(;KXJ$SIHEu8;6gA9kAy3WOMA2*v zFtET-mM_tPVOj)_4G+UI6w}^I_(%vaz+)BUeix$`Rtz(YV0MB+^=}FTn>>S$6if`H z?4Qw$0}LDuyf86Pz;2QZkQZQ(feC?>oiXNU=4VhtQKl~CgRZQ^5h1dpxj7{4+y9zYJ;SMkm7nHxNrVah;Zd!FqwI|~Cm zVS>VT2k#Dahpm-vILOWrjH2utuZm&zcjEyML&7JR^FV2P>19UA1&jbsWsJxD>;GjKnjXqR(Up} zXWOaM4q-YkA{fJYjmxn(Pf!%YdFEPE7^N8wl%UxAD0m^y9R`MF#&Ac2iuxT~*$>zl z4CbPQOhlLfs|7;?Cp-gyqWkGCCd~tk3_clf{h;WU6w~15Xn^Md22i?{cT47=amJCb{@U#PpWlQcz83xqMvT8o7J2P4Z;l(|Jk>voW zE&#g^6mc8Hl(^9>Ezu1rW?*T67Yd+ssI|5A7!#_cQ**f|vZA_e)n~DD5BSgu$)}H) zq@FMefa?;lwGM0^IRYAMR2Ug1ure`#t5hSo88VWBJ`;LCqgO1>LMx=FJlho%%dT~{ zR3Yi+Jh|!7pG|i1aGigiYH9notogvwpPVUt%WtM`@^}F z$U~Vl5k}h%XN#7@Jx<_mP4SFEo@Wf8;hZlGe8xOZa)&@YA(=D-RW=`+lrZd?zu z?U{<6k^=T>4)#iq_JRWTssiS+iuMA9cF8k>d~YW3Jv89YGvI&xf$u{ApRvFZ*?WpI zHxy+LCCa54bXxCew!YwOxrNzgPqR%6v-Oz|B4r8e)fw#O8SOUW7tLs|oYAQB?gQVu2Yg=-@Vzl$PC5bV^h3J-GItVXPASS9 zX_P&c$aZalv*ihA%NbyD$6<>p&K4ca)>97a?`gLA(`0>x*?Plap5vMA1rhB<1=g;^B@*o=3)&^m1@pgS;QtiB_hbTS{OO?r-b z%wqw*HwXBh2Jn40P(FCcQSQk@@f#0i?=i}yc}QAsXtv$bEZuV0qQ}|t%wd}y2Q5z= zw!Xq_bH-y!AqRU|0(&92^D24!7~h)&{s#x1OCqs z*j_&1dlJC+!$A4bn?#vIjj}f$ioa2meWoaTL{avJqU;Mr`80*O)_a_7cQ9K{X|_BQ zv8SK_G!|9G!CvLjUS7~HdHOludjtN55BOd>@V{KZ|J2|xBXjWz_Hv2#_!;d51)vZH zjU>%zY;m-FaM)slv+V?D%QFr~7Ra7jC<_`mIVh8+l5D{D5EQ@sFCF-w8aOL9mrAe~ zd$h+dXs>B#FPgD&3ZqTWVH?okfwR?~X6rMDdyevWD2h5I$~{w*J$Fzxs8J-%BE5m{ zsQ}*xaJaoU;QKrQG@!`9P^rMc;KsnfP$cJ};M~C@$ysFNp`hH@GjZZYmjwbEDK>^i zPJI&;7e^>8XVmy;a}AUi9erICl5-OC(iO^5i;6Sz^Arp%^$he3(lbjGG%`v`3W}}t z_0uy;GD?&5lJj%*gA!9x^Yi=*QuCbia|=pKQuT9k%5rrJit@8klS>qe@{@JV^i1@O z6^u-cj7$_#jLpo_5-m*9jMGw*lTu8Q3=J$y%u|z6OjD9f&5aUM4UH0wP12Gy^KMn%+uT;Va#%b$9bZJG0*9w6CMdoL5azmvx_~IwTl-^s~Rw@lwe@cWnf_7 z<|@)?FK=M4iYS=jY`X&#RhAbHH$9U%1{q5{2uc<&z-d~{VUc|pi0g}gVAY$7%#FZaMVtg#S%7>21yM&7&)34H}jtnwAsvZ z*nyRi{gApmV*tQoT^EXb@yR!s{TwGh=SR2ZaUAg3L?|js^w?nwXOq7{Kup-KN~I zji>WSj^rwaTShTKn^Ps48EzR(X?&6+c~7BX8&9EvEhs(}4H>|lOYSa54hB$!F)%QI z;tbTD6lQS*#~G-#@nr=ggEAYdKvM!ktMdj%uo_m-_ymh1TeF0r#sdYGgAPpX3eAi; z*yTR8;+A`aNA40r4%7#Pxe3&*F~^p^jsK0G zBpDgPd8$|!9_D3IWZTT+BEVzL(F#f$&IMeJn^}%@98_diV zmlu4YbI{;IBL@?^#z8hn*ouSQ=Z<97|3<-vPmECaF*34k=5Z6?G3RN4xX)|iW|m_G z%!wTbb(jNLj!Rg2C-g|j6bUpX82oiAme|a4%;97M2aBSJ2*SP6U^^gT0BUBLNHA*H zJ57=2<${bk9Oh*OSpy0N8K^a!6G6eyC~v}14>7rGz01{2(!4M`0LY3)Npr2gnI*PArsg}Ly(Qv z^aPITFd7CJ3N#tSG<;y~lbdo!u<@S28y!c(07DKINd}OYz;PjlE61cuq)wS+v7o0wW^q zAnE$xgvJC8Wi}Rpri6qoO(1(fegu_$urvrt93~Qs4CWkj4l20ua9MFU8eB+Fl*r@g zFt`EDM+DW!psELTQeftT++m{8@o8nFV1v?vvl)<5l7V5mOvS;4+ZC*m5{@`>G&Md| zJkHuIx9|ffv~}ZmtksscO3h|q>5cUE%J(C#^Z$t(7%~`N}0cptKz1vIQ=qH(~*)@e#Sr=5r^qtb;14jh3@CK?=0 zA_|PWT#W*(8V=1$3NDOn7Z$L9+Y_KXs@>woSi;|2c+gl^$0d7a6C#iaARmv=qP4@ z6b=k*bq+icJWNJ+1tbo2G%~g*L^K>^=VD}QZSXN?aoPx}>mhY2#5_nF4&A(NP*CGD z^Eo@3nV^0>to-3*U|=xOKsU2KnMIMK>4kicV9z`0eR2#8aPt==h+O7n73`6aC=_T) zU}U@B%Hz($%FK3|*I6)O1IG!NN5r`p`5GG_?E+By6J{T%PJAM%%-|?ce>ky+fyGha zAH!k6W`_nwp(c>OTMd#D7!*0g5(r?Hlz~cCaCnJDa)nT4s1DIR0*pX2k)br>lD8vr(h5&nPfJC^-gvNATh6#eqDtt}?ItLW^ z4k>s{dDy|Ebs)jQse{E}MxKX6;;w@nEIewAGxIoDGLIZ!a1!7-us}>H%gRZJ30$p# z(-)}!0t-tl?mfogAi(l49~`jvGXf5PlrS6?baQNAlxqdKw*%o`c?}i~7N#a9B?T1s zMj2=*vw6sXlB1(cj6i2{fh#b2}(+e34*W0m~fnt!~h`V;5*ja9}G=5IK_J#Ly!l z1IqQ1R~U;#7IQKh%u-;HWJv;tj?RGx+^vUr7$glH9x1RKWa4?0z|E3+grT!pfQ8$j zgIQ04&4!5)A$vfXDFrUutS`}mh>ZgWIHufeQ)ft%W>TJ!%fXSjO);U3*9p|72f34b zmaFH4!x|lq91L+01}qHnph5y#mxBiXV15E+L=g!_hQqv00&L1`oZt*naD+8uLHYv5xCXW=59tT2nG3|19K1J?&FYT83Pw=# z&!wPM>P>-j5qpBc-2e_pHpUsM8S65cW%D?0Iq;-Oaw(rsVc5)awt)F5s6@UXbFiTK zGg|^AlC+LnJ_FNb-h2ar%e+?u7BH|IOkh+p;A;}z;lTH!%i-e;W{DFF zf*$M#S$WI_7$u%Gbwo4G$=kuKe&_&WqYPWYK`jM7MvjCH>7`^B&b}I9SlUfiJ<}uE2^WgVP%f z9N#%Iq;bv?WM<}YG}!P=v9Q%R#xXFXRd~k(elunl8NnQIn6PCr&E)uyIHBE1hbKdD zhF~)jhoiv{k417u5^oBe1Nc)W{4lr?$@W}OZv*=o1qSgu&F6$NJQW@>)m$+Cz^H$K z-FAumho+hf+@BKPpJ12GM~stNfBy@WXj#q_=(+MM(TwX2|Et{WH*p7G-wcDVa`_j_t5gCyx8J;hQ^m|^%5N~ z+p;tQ7&t+pEg^J}Ge$wE#&tnO4+~F91BXiI1Ex0V4hEhg$0NFoe&%ypgcBb$IPfv= zNKhAOXJ9di;1t-AtXi@|(SX}vhw23$fhKPTChXs%>yB4XS7Jz39Ojd>Ujy&cZ;u2?!BpeG4EtOPw#NnfVpkT+t0|yeO9R!t1iOkKB zY{?vL(j7fKX_6gE5{4SDKtpxG&5xOQjU0}s9+Y)I!a3)hkO;GAheN_)hocDs*BTC$ zGau$X)Xt_})1Bzh;=tCiL-D}NwgVb`lFhplco?>t?r7vb;jmo#L_&a`0H`Jr%3)KR zp=!Y9b%sHVA)F(^QK0KIv+)5QrU{1>+N4`-8JZ*6o^klyU_Q$r=>DNG?f|>Roq`Xb zTH3(F;Eur`Fw4{6j>8v5$umY8-UdklS3qMM&$y1W)-gybJeu@?Po%?wQKV#Ag5(MZ z7FKUIEjH(kLLzPt+1jLA0v(!#-4Zv7bO?Y4r!4LYNE}sMULZb)X~8!x^8!Y#7bVPE zJf~O>r8RUoH=RAevYF+KL#O0}mPBE8Uxh{)++Q#q$ z7J;DnvjoSVH#q)y7#JA9@du5&AB@TjA}bh;L>2xxusTgq_~YQ%!w{mN;K(W@kZ;l; z09rA_$*Qmz6kmd%UZ!J#(~`ryE-Y-Dd0GWz1VR16#)CSC9as+Ywy>}r=IuyfyUZIT z;$zNo(ZD<4j@Vt$I9L)}5<`Qd0}oT--4;m}g;tPh@)Bo^G~5afEkDeA1k@EhB`EXp zWq~sTs7=$La{gbkN@51OA5(wNRR=bVrRvufG{c2;=_7A2Kpp1X(H zig_9@q#E4dg|v}DeQo6SBxsxqIv8f4%;pLjZ*_QVz+=wh1Zq`7^nk(yrU%rH`?7*j zst4R>0}rB^xO3Pf1%R6Aj3Imn3!DvF4zhvkHuSbDWULvi2gxpI2it}3P(iam18gjn zXm>Ip+zHC7!Yl_n7(xASiPIg9;33Zm797$LfAE9j8PeAURlvyZF=v7K!G-T2o1!78 zv%uZgB+=*$PPzG(BS4Q37FZa;!-%oXQ{*bnrP~Ah|f~Pp&l*{>MOwF z29%~vRx}DWcxf~p(one7!zQo@WDjf{U9u>qAVaTN2GSDq7HfIX@kGz#OB16}geGfq zAqO+0rlP+}J7zpO zg&ecE#ZDwF72!GNDbXV#;lR<cqjTu53pjEEzPLLwHfsHl9 zgH>QfI zW+Zc11kFh1SYS6JnWMw*hJpep--5=|quW#pw(XV$n0^&*%lM^ure}=2(}g~ zbZ#=}{X8L;&qMx;qx_eLa_=0$3^>_XlkDiAz!Q=x|A{xvKNP+TFW;0)6ue|_9 zJKC!u3<#-kKSDr&gE=rle3AzDWsCC=o`WEm$A9x#1IRv@HSZqE!MP{Q#aLJlIfxu< zOg(kT;Id8L6fZ7MK=0bb*4zTqX*gy5MGBz3rfIPL$u=jHhjK?Wrh~!0> zl_2{d-i3P6B@yJQnCCWNFG6?$Pi&CA2(=Py-;VZbm=`73L7u9BdKkiU5d7-KoDrb( zWHxgW`-B(lvsbiNp`itif>RiOK5Mwas8XP${_CLHtA%pNC?kLiqWeJSoPzebXl5*M zG@QZ=*L^2Zz3LE(85sdWD7rc4&tTsGH)DZg;1pIg-Om_xjwIsF=$QsfBDrpRj z8LrG`Udg^SfE$Vw&T}v*axE0;S}2ndU~okWRNaE+azS;tM@Au+B+mn;kN}>5hirNd ztg_A(tPv8<4Xq&(&K)3PLTiXbo<&e2j~h>q1!x$db>LMe zgCvysiIZ{XQ_->`2RfPsSeUtWno^-mZJm~UFwt!RD-2nR9oh|BiaFY$f@`)v zXt8XI&u7H4{{nY2kJklm(7J(ESq>(VJ5oE~M*MYEIBO(wM3Tpd=fMZ=%`C?_ zpEnfqJU94qSjJi5hBKSZi3i)8c^a)299C!+nsQjD(J&y9-HA~!o&CoKSfRg4uw*mO z2}Sm19%idKhjko_nrt4(G~4iK7#f^z`>XM+<9Sl$|VPv8l;(b(W{ zyaO_UE0){RBzeY2hfAKvh~xbOzRfIWIR7*h^ZYS5aYV*hVT}va5eMzR99D1^x^P&B z(KsNH{ScET%mMp0vK1Jdeqgf9;Do{rXWKJEHtI82&Kb@)azMD4r!hbn7806;Ld`s$ z8~8V~oOfuI<>0&00}aT9?NToe>Nv9;JfzTkf&Y(zx58WoVTKr>IB z^$CG9E9G>0B;D-0>e%=(*VN{N1upaaX8u_J~=LcS(G`}r_7Nzgr~^x zBCBbD;en$nAjt<}!Z9GpXrDYs-jExN59OSUT#%=c#`bF|z((ZDYGr)+{^ zMDG=e6bq)U7A6(EbCN51O*Hx~+Z;Z?wcl(2X)ow5X1;yElI?Z{^EZ(8PcZE#dT&dl zSTtRNt<W)o=pYfT8}B&doIR}VcyUYO=k zrZi|nCQNV!R7Ff{Bh-wlscce^h6#6=7<2|=O+?}nsKON2g;0gFrnW*Ac83XNLrshd zT=O}>xm~(Rn6a6;`R|2?Z47P)5_Ub@)MVLqlOsok`PMl@mMx9Smw6ite9Tj3Ol(Yl zF45ND_DrJb&WedI+Z+;>Nt@-KJg~W#$GJhdn5QxS$%ztY1((abvh1l3Zd$e}KA33P zmS9trvP)55%}X|eAFFt5tdAS8JTp8lu;6&5k2%XBv%3atNedvV1Ntr77T8`(-lZr2 z8cgIl_24wv><4niJdHJbj+ZzqOgUKMEVY19G6ih>gnrAm2X@au8m=T9VPFJJ)I8vG zXchgZxFd9iuyxo8VOK?lMwO~=r=B?nB$^DgpiSVVJpw(70<}hvS&wU{1P)3xJyeHE zZtW55QB2$nsC&V5pVHTIT%p3oKub9V5Xu|1|!>Njwc_SD2 znCBd{Ib$fmxWmUhK*;!R0qePvX0bbmWtatfPRlr7=<#M-e;3x?s~&CdA=+ZF#!$p) zdvCP8H`?ACZSRe?_eR@$qwT$cZSN^WGH&M-+97M`&;VLIa-umvu<2pcLccdLiT)z3 z3@npg=yDksF(q?2t%NqGwi-DZv^o4a32jK}ID4XlrOW1(h^xW|=4_3l@}-YRkD`x0O!B0%uINF>s{NRzGcbRv3tB-j~qtV?BtX~a|JFq-5y6fP1rsS}w z#bw^83H-%8E&+PQJdVM7l6%US85jx`7#KineT?`7lv-aY~WhRd!pf11K)aH zj>OvyLVI~F3hy+C-r>#oc&|aCgKtLQs|M-)d?yOuHpt!M<5={rLFpTxhtSUk)#dyd zN`D&EZ}4{*{cq6h6jLNf285kHqbUKL51tJ(2Y`7U1Kx^53xfvKhYXm`SM8Pspc@P_{o`HeE zj)8#zv~UYXgU)2&&|zS3XJB9e4I;s4kT`<@149V(tSJ}`5(lZz0*z}kFfhPqkT|Cb z14AnV0|RJHAB={HFJ@q10Idy#(J=A7P;+23NL)aLfdRDE8nh-HMuWsb_P%3aU;z0K zMuWsb>e(3?7(i?IU^GY^z^+y;P!279TG)(*s zBLjHTCX5D&gUopaRS%;<;vjRFm>}lBXqdPZ6T}=C4H5^LqXAV9qe0>za~z=JFd8Nv z0W}9kgTz7Rq(RlgXplI_oH`~*e8Fg#_;e;n_`_(BI4FMCFflMFGB7Z}XplI_oU=>} z450PMFd8HdviBhq=;#3k1{e(z2dQUdhNK4=4H5^b7hz^#Py@~LLTQjVNWB#^0|RIe z2aE=ZgWMCy3~>*P28n~z7cnz1Xn^*^L1~!yEM^7R~iY zoP&jd0kj?zM#IGQSr`~VM;E|on0O!yB;Ubkn0P)5BphHgNF3y^CKd(;Sq26M7!48! z`D;E4Bs^g>Ong7o92gA}2bpt;1rjeX8YcdX1rlyB8YB)fpPd!r9vBT02bm+q%D@1c z&xFw+agcgTsCpO;5(laGfvShmAaRiTY*vVSU^GY^q`n@i9!7)2LF(s0&4z^=DZj;RB;#;%`|Y;SZx>;@oVI^bey!;vn-?*dXd*G)Np|z7rb*1E~Cm z(J=8ms5p#lsE5%o z@gS&r7!4DzVP{|ft&4!sF!6=#5P!jFnD|+!`7jzL{(&75k1!e}4$8kQ91!=wXplH4 z+|)TB>R~iU9HibFsvbsz#6jxgpz2{XNF1cT2&x`NgTz7Vr$W`kXplHa{Yt2M7!48! zsXqf%52HciAoUNR>R~iU9HgF^6XJdt4H5^b7w3e8CyWM(gVftV#bGo|Jdu-uK^rtT z52az^J)8^-P7DkTFd8O)j+24Gg@J(qM#IE~xFGI<(I9b9x>e(Xv=d-7OxzeM4x?e> z4qOZjp!swd4HNg_g49zm8YZ5?1<5Zk8YW)L1*yMaG)Np2KHX4rU^GY^V6mv6Tb#E2S&rhpL0RVXBZ6==i`Rd3osfc4r+se+AS~|Bo38*V9$XpLv~_Hv0hR^YEf=! z36^w#t{FYiB1})oFVBOUkKtP^>TnyCl$KLTj8SlPxQ)t8sUXHAm@;&O@)Gm%i*r&_ z3qT10InavrAj0SxU}XuYB*F}Y!lIJ+#DapN{EGNAV5F2db&4xDdOZHzzs zC=V1G7~X`-l~iGsPsuC;ix(sol^BA<5*h|6nPqy#$%#3sSgl0XhFPX2mzBg9<;52# zmZW9oWtODIr<&v%R~UgULXT@qwWhg7_|>KvnpGIV%*0_HR1L%|bTwu9C8?0yoS0V$ zEd$VsCoD?xQqvQ$8P((X=5d1B5ng#Uzrp+{`=-+u&ljZ38Jmw+$pl$hO1^Oxs{$cx+3oh{v!k zu>xxqkdu;|SW=W(5ua<3XAxgv6rW^hhLPRS6qyjGsH7+{uNbS++}uQHMnH8nb0u$UoNzSPPMIYEGsIv0JvQ${L4;C*? zOi9rzP6W5&zyi6cxyc2UdIhF>1txl#W+q_4;@r$UP&*IIDJ?EdLeT;lx^oLCn)l%uNB|jN(Mpz}HJI0IdLKU;qH1gXU)d delta 9156 zcmeC!&)6}Uae@kyGw(z-bs-5?CZ|cbx^m9VK{hO)ya z)fvhR%u39pU3hYds)J4|47 zOk?PJP?=h>=b=?*mqy|z>m^aM^H1%&?bn#GD%@$&G)}>J&skGBY|nGt3+}l1d0tV) zyyum12JQ?~4}RIB#FHjDFJ0sC2cF5S&Vt)YB^o8VF05<)U~^u(^?4=#z2Gh-#tG>P z30w&hjt0z|4>KAu@v{}o`1k++|9XaN3`U2Yr1^)LA{7((|zJM%XdlbO`EP?N-0)MUn|H}pZ zPZ=ICZfJZ7@sS_n2j&O-44f~Tc^^Ia3vwgF8pZ@J2Nn+&%PGy)dzx*wI9qIRmY%_E zv&GRGVjbVl2Ye4N@VzwPtIu&ja2Ou)y}!WrMS=ez1OJB!eD4bQaxd^bbl`u%aG}}e zinArd5t|Fn7CncpwlLb9IBb2v*|LY(7Gy!A#g%5u8P1j`oTbk=TfK0$WT;@s>DUKy zjC=zqGJY^SFuh`U(R@IxftN*sgF!w=;ZOp@kA!A%MK(5YI9s$E{s5_Gj9@+>!eBo+ znz?oJ1$GTaM$5@RIdtkZ4zn!Wq$Y6LK}47%RiTU1#Z=&EOOjN_p&14m%!UhBuq`rB zW4YeKFx$ar<~Co2MU5J29Cr*HUrgX^H5X_#bX?KKd5E1&xGDcoK;+?;g;E`yhs4>0 zC!CV%I23Ut!O4J4xGC#U!9+&W#2rRihYBLEup~;UaikjfaGXhKGGMbg%<_omf&%ZA z1g8yb!YpPP&76(WJ{ort6u6QPJz`@Hl;FN>!EfSN(46R$;Hz*aL7|JYS(@ia%R}BZ zjGNzZPGe-W+}zA8k-D*kLcvDyejn}j38kU zPB}2a>)wS1=7yC_AT9`>Xi%JbuGf*3A($1!1>p&d&a3u0GxISl;{tI(7-ZyAb2d>1 zA5jn&gdNy;Rt6YYU?|I%=s-1%lhJuq1dk05gWBYUe8TnOAk82QGQh>?g%!gL0}vO4 z8+asw{-!Xn$usy!g7_c|Qufbi#sLNn23`;!gh9^PBpDztz#t9cf-p$g8DoxSeuiZ{ zAU+6#l&MSkpeUovtg0ZB*lBBB-elsMT0&_ux zQG&>f?X-Ze?@_pAwdZ4Oi&Q+tu$cEWMH@~0apr2Zr7MiJ(wA09DxbegT4MoZc3vR z!vPm~7%@m1EaCdyziC1ATjVeS@! z;lR5q0@!jG7$&1w+~wbDRKUzI;|y3QhyW#^hgxeIr5O&SpoC3QWaFj=1`Cv^X@p@U}J&W!;_83hl#e6<)96*p=&; zr#`b1*qGzre|l}D)T|h`*5{u$&6!uX+wsArKP*P=m){uOT&j@5vU#TPBxVBx1tU{K zV`GJ6^R!eG)8sUhBnv~!WK&B^!?aW*1B=AOL`$>Olr%#VBXeVeG{cn1(c(6YmYZjY zOE9WtF)%O)vv}%&QU@qfOr|g>K6O3M!N97_&>|whtIW_cg<_yB zG&Nr0b2mt8Si{KCHQ?9!LHb2Sr7W!@Q16Y@2zU z1eQ1!2(hF$r0{vn6F9#AAn60U`sGKogm`K=kZKZq0&h}e8R!R05-UJmxWGAv7{P! z7x07#>`UNr=WjW{1J<_zq_087kaQ>`D6_E&G$k;6b=ttlkf^}G0II5lSsdA%dn82i1)35RJ~~Y}$RLef?j;_% zhj`>}Lgbu57IJ_i))5@BpsLg4$^noNgEAYNOpAg!1u?1Joy9Bjh+ljEe-gQv(Vnldmj9Oi9NVB5^&EWl$f&@8c}@rWYhIU}A!6WE0^ zSkwy+9&~51@;2ZpQD{2A=mn~EKu!VGpTaEe$W|QSaQMkE`G&L}TcJQxg2CI#JTej) zQeY=ToB}EgOnMlm$SXc|cV=W`ILyn+$hMisMS#a#rbXh65r=aD=fuq{M>-lAL5>ug zqHvh^pdwp_j-kO_0i8pvyfH#9Y#uQpcMM%q8jkQfo)Z#amgP|6V3uZ(o;*QDk`3g| zD-4s@%P6udf=sxkJNdedW<98}fEpYq{s!4E&@6Grh{tIHtB?nadV%A?%`AraJ5jE(17i5oafol~>5H(n6jh zD`&_n3ZVItfuJ}4$SYbyyh)xvV-(1+^nik5J(8u);N0Q^P5|5t3=AeB35(0z4!b8}={6GuIp)`Y|b%?U>qFgOYD zY=~T7Fe4#hhoZ)TglUcvO@hr#9F7JDvKnWpYVgf0*`fI1K*F?#1~Bow##yQsejsrM zW`h~29ur{3#4JeToRq94o5yj4fhSE;OjW{Aqd|a$IY)d&WRoj1s38X`4x-rW40;>o(LW$qq_nUhdLS=TNEN14jyM`WNU5kF=uhw$T(SENlYI@dlEkF%O~$p zQVhk=b6>)dqeG)T2A+(*t~7aA#*^V+9%Lv=LHSX@Zj$q^#g%U|@Ixas*MRZbPSFp4G^$SArnN%)wn^cdVtU}0jL402qP zNfOAnYzYF~f~K;HoQx7Vlj~K)>QUVWj`t7dpwNOljp4AMn_~l`Tr0?J9SEn%%d=>( zFf}nLDTpzg=wM`G2ZtqFgn@=K8;48_C~-)qI2-gxC_rK=e}Sk|B4Y!Be}L7=NEYrV zM{xmWwtlcvjiFB6FxgO5p}vRV3N%6faDSKqZZj&gu?sXMII!g>h#bjqV(5{Oao}iT zlsv*%B(j*3(O{MWizJI7*buRV4__S<*^q3Vz}kce$VM?9X1PWr*#m-%2w6vQov^Lcm8x zR{8n>jg~76ys1teoEHM#Gn#_s*rxpdFVCwAW;==aEW8jP*D{6KR++(M3L~pBLy8y+ ztnu~XB%@@C!J7i-0Jf9~=>~6foE134XE3G-9AHgfu$jeI$RnnKt%^nI&vKvfb6!Ks{$jW0bz$o#g>4+|)pZT3e;lu~^4Gw(FF&nr# zoEexT-Y|09*~o5nN8kpN#2ujwwvGov4F_237f6^G7~E;t!X$CVNWpW$AqMd~OmPNW zR^AI@47sfC2(&Ou_?TPxH#lo198hIEC#1nF%G@k*(t$Ne(WQt{>p{Wc2Cg)NyA51U z26qD9FiG&3bG!!mwo!{?b3G5Y@Dn)}1Mdqx5)uv^O^u@bo!WCAa&`i!wUSZ5fT@a3?Pl)51lTyknrbEnefBljZX7t4)Ghzpz!~`U^7b; zDE#@W*z6v#{#eldg*k2lf0d)%8`+hk2*iHyqSyc4Ixvd(Ivn{)Qj<&wKvgAn})(!^iy3VQBb!fWsfua%=Pe zhrbXw{Dn}%e@+V`{Dr~czlMe5&PE}tI|4IUk-}epAte02ut=OS3JC*+{~Xph17R!g zf>=Xgt2+WmSfSzX3km-iW>Mx~$!iX*Nue#JjDDc-7Y2pDFev=ju!6$B8sytXKT!Bb zh=Id@K`%7?r8@oRJQVJ7@<$E-1-w^Sj2=EO@xh6?{GO7Of{DkJ2)W))SFOX_-Yq;ZG=! zLc*UmS+t&rJjz3jEb8dN!@vM))^c|_ax@e~1n@X9bR1@?ZvT}GBFgZNq zaA0K;nxgQ>!BM~#+~H%$WSwwb;4tqYR<_MNZUS=(GTA(28eg(X@R`d>1z&vr{K^^8m7G<`hpy7jzjDna2j%*$=LU#&X<{Wsb<7gPr$iZU96T-um z$g5FfZg5ATCZpicb_q5Y4p0Z2(a)U4t?-1z9oA%q6>AUj#yGH9-F0AVE@}i-BOV(h zXBc=V@JyM(sKc;MU_+CF$aMC!`UXjs35P&Vlmk@<#VqP?8q*q%2-YkR|G>y+E>pw6 zwwZ@Vh{v3#9>nqy;xQMfbzs}fb4sX3LZn!rDPh6JLxqj<4q$gSG6||TG@TRDVHWLJ zkZ{uBXo5yq(;HKkGij3JAYR*JzJHbq0wqy_^6 z19uz4i^)0qvh|>7f<#fPyaAJgBkzRBh7Zi59Dz+Am{|>(0va1wL_Hk7WH2bIg5sK! zb;3n~!@N!`Y@2yp1?C96EMOLFJgDQW$l~r|bw|J>!Qixkz`+7$#m0jL%t|v4mNbj7 zUgmW)IU^*p-PPPXfM?2#MhS)l0fAlVEUq zfy_a{`F!Rq382vP0EJ!x16w-VtT+ZX>$?Iqhjki@G7r4eX*3E*Vyg@I(CDb(WOV$3 zZH~cl0ghOaX9Rs+-abobcKOenE@`TaWQ2@z0>m}T-m19|LL7#hkZp6qaB1P$Iy(C3hb z3^{-@1#FlDRP!O5Zq5R80Sn(jHbp~Fr;@v`Nuq_}14sslcU3T&3zDgX9{og9dzw2}cfqI^k=gI3{rOoJa`eYBM~} zAbCT^S;0Y?D(p!bh;3;!GUdckU`I60|`R|P^UdxhUp~}hf&1Tro=O1h6NKO z&ls}sm>(BVxFh52Pylin!{j()CEI*(Ne1a-g63~P){d3a;WV<&H}-c`d>HCAKg;sT@ z61rCJlZAL^wQSR?CQhS>sZ7CV4Cgd5vKR9-7O+Y3RL`xKd;@Lhe-$VZ&zW#=mB@@o zM;v&|An~G8%W>oZlR*Pd*wOCh{GUac1~>P{!He%Rx3q!^yu) z9P6V&g$D>1DljmB3Lzst0i_n^7M>1<^9@`-c}_6gXy9w-<#>3bL1-$k#mAcsqFZ<~ z3hy*XJm8&i@P32zG`)1Y*Q&m-wwgX$l?jHI6p>TCErlKwPk-r--7 z^uIy7N8pAKZ=>E#0g20ejmC|F5tk(!>&@2-&M=l~w0bAF?IocX!g`JNCqy_( zvKt+jiED_|H9Ef&_c&7D=-Mrj(bC!IzCfZwq`T2;yTpkr-Hm>SB{-(`GzOiQ@CfZ~ z4E-jN(K-v1Rv8!=Qh68{vY8kdgqWu!XfrS{fbwo47o^toVStQgMc9D6&cMJB!N9z!SVM8YaGrlYs%0q+v8n{1_(#187JTM#IGKaYCGQkCUMu z>^7LhcTNTdJ_ZH`7!4BVRAFFX<$@Rlqe0>zhsr|z38P`+dRz<)pn+`|4HI|ff_MN% r|Nk%i|3Ao}0H{G=dh*p^XU=7iu-DSvtQgYIIoT}k>g1?+Jw^rqG?F!v diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index cc1d973..6174ec1 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,17 +1,21 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include #include #include +#include #include #include #define SHARED_SPACE __attribute__((address_space(3))) #define CONSTANT_SPACE __attribute__((address_space(4))) +typedef _Float16 half16 __attribute__((ext_vector_type(16))); +typedef float float8 __attribute__((ext_vector_type(8))); + #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME #define DECLARE_ATTR(TYPE, NAME) \ @@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); uint32_t x3 = load_single_matrix_trans(address, 24); return uint4::Native_vec_{x0, x1, x2, x3}; } + + static inline __device__ _Float16 top16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast((value >> 16) & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast(value & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + + static inline __device__ float bpermute_lane(int lane, float x) { + return __hip_ds_bpermutef(4 * lane, x); + } + static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) { + return __hip_ds_bpermute(4 * lane, x); + } + + static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode) + half16 aFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2 + int cudaTID = (vGPR % 4 + lane * 4) % 32; + uint32_t reg0, reg1; + // Select the two consecutive elements from a_reg: + if (cudaChunk == 0) { + reg0 = a_reg.x; + reg1 = a_reg.y; + } else { // cudaChunk==2 + reg0 = a_reg.z; + reg1 = a_reg.w; + } + uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0); + uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1); + uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1; + aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg); + aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg); + } + return aFrag; + } + + static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; + half16 bFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = vGPR / 4; // will be 0 or 1 + int cudaTID = vGPR % 4 + (lane * 4) % 64; + uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y; + uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg); + if (lane < 8) { + bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg); + bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg); + } else { + bFrag[2 * vGPR] = 0.0f; + bFrag[2 * vGPR + 1] = 0.0f; + } + } + return bFrag; + } + + static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) { + const int lIdx = (int)threadIdx.x; + float8 cFrag; + + // Loop over the eight vector GPRs. + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use. + int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8; + int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2; + float ctmp0, ctmp1; + + if (cudaChunk == 0) { + ctmp0 = bpermute_lane(cudaTID, c_reg.x); + ctmp1 = bpermute_lane(cudaTID, c_reg.y); + } else { // cudaChunk == 2 + ctmp0 = bpermute_lane(cudaTID, c_reg.z); + ctmp1 = bpermute_lane(cudaTID, c_reg.w); + } + + // Select one of the two values based on the thread index's LSB. + cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0; + + // Zero out for specific thread indices. + if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32)) + cFrag[vGPR] = 0.0f; + } + return cFrag; + } + + static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) { + const int lIdx = (int)threadIdx.x; + float4::Native_vec_ d_out; + + for (int cChunk = 0; cChunk < 4; ++cChunk) { + int r_vGPR = (cChunk / 2) * 4; + int add8 = (lIdx & 0x4) ? 8 : 0; + int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8; + float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]); + float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]); + float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]); + float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]); + float val; + if (lIdx < 8) { + val = d_tmp0; + } else if (lIdx < 16) { + val = d_tmp1; + } else if (lIdx < 24) { + val = d_tmp2; + } else { + val = d_tmp3; + } + if (cChunk == 0) d_out.x = val; + else if (cChunk == 1) d_out.y = val; + else if (cChunk == 2) d_out.z = val; + else d_out.w = val; + } + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index d3e0b7b..525ae15 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -198,7 +198,8 @@ fn run_instruction<'input>( | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } | ast::Instruction::GridDepControl { .. } - | ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index 229e179..d365e29 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1856,7 +1856,8 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } | ast::Instruction::GridDepControl { .. } - | ast::Instruction::LdMatrix { .. } => InstructionModes::none(), + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index 0677345..144f5e6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -533,7 +533,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Vote { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()), + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index 19e16e7..f7c976e 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -351,6 +351,35 @@ fn run_instruction<'input>( let name = "sqrt_rn_ftz_f32"; to_call(resolver, fn_declarations, name.into(), i)? } + i @ ptx_parser::Instruction::Mma { + data: + ast::MmaDetails { + alayout, + blayout, + dtype_scalar, + atype_scalar, + btype_scalar, + ctype_scalar, + }, + .. + } => { + let name = format!( + "mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}", + match alayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + match blayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + scalar_to_ptx_name(dtype_scalar), + scalar_to_ptx_name(atype_scalar), + scalar_to_ptx_name(btype_scalar), + scalar_to_ptx_name(ctype_scalar), + ); + to_call(resolver, fn_declarations, name.into(), i)? + } i @ ptx_parser::Instruction::Sqrt { data: ast::RcpData { diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 04e1b6a..1bc622c 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -3,8 +3,8 @@ use super::{ StateSpace, VectorPrefix, }; use crate::{ - FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction, - ShiftDirection, ShuffleMode, VoteMode, + FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError, + PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode, }; use bitflags::bitflags; use derive_more::Display; @@ -724,6 +724,27 @@ ptx_parser_macros::generate_instruction_type!( }, GridDepControl { data: crate::GridDepControlAction, + }, + Mma { + data: MmaDetails, + arguments: { + dst: { + repr: T, + type: { data.dtype() }, + }, + src1: { + repr: T, + type: { data.atype() }, + }, + src2: { + repr: T, + type: { data.btype() }, + }, + src3: { + repr: T, + type: { data.ctype() }, + } + } } } ); @@ -2381,3 +2402,27 @@ pub struct ReduxSyncData { pub type_: ScalarType, pub reduction: Reduction, } + +pub struct MmaDetails { + pub alayout: MatrixLayout, + pub blayout: MatrixLayout, + pub dtype_scalar: ScalarType, + pub atype_scalar: ScalarType, + pub btype_scalar: ScalarType, + pub ctype_scalar: ScalarType, +} + +impl MmaDetails { + pub fn dtype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } + pub fn atype(&self) -> Type { + Type::Vector(4, ScalarType::U32) + } + pub fn btype(&self) -> Type { + Type::Vector(2, ScalarType::U32) + } + pub fn ctype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 5389118..a4f9080 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1862,6 +1862,9 @@ derive_parser!( #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum MatrixNumber { } + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] + pub enum MatrixLayout { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3905,6 +3908,29 @@ derive_parser!( } } .action: GridDepControlAction = { .launch_dependents, .wait }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma + mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => { + if dtype != ScalarType::F32 || ctype != ScalarType::F32 { + state.errors.push(PtxError::Todo); + } + Instruction::Mma { + data: MmaDetails { + alayout, + blayout, + dtype_scalar: dtype, + atype_scalar: ScalarType::BF16, + btype_scalar: ScalarType::BF16, + ctype_scalar: ctype, + }, + arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + + .alayout: MatrixLayout = {.row}; + .blayout: MatrixLayout = {.col}; + .ctype: ScalarType = {.f16, .f32}; + .dtype: ScalarType = {.f16, .f32}; ); #[cfg(test)]