From b5f41c7cd07d3cd18376d3024c2bce1450f9f3f6 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 18 Sep 2025 20:15:22 +0200 Subject: [PATCH] More runtime fixes, add mma instruction (#509) --- .github/ISSUE_TEMPLATE/zluda_dump.yml | 2 +- Cargo.lock | 1 - comgr/src/lib.rs | 7 +- compiler/src/main.rs | 22 ++- docs/src/troubleshooting.md | 2 +- ptx/Cargo.toml | 1 - ptx/lib/zluda_ptx_impl.bc | Bin 18824 -> 24456 bytes ptx/lib/zluda_ptx_impl.cpp | 158 +++++++++++++++++- ptx/src/pass/insert_post_saturation.rs | 4 +- .../instruction_mode_to_global_mode/mod.rs | 4 +- ptx/src/pass/llvm/emit.rs | 7 +- ptx/src/pass/mod.rs | 2 +- .../replace_instructions_with_functions.rs | 29 ++++ ptx_parser/src/ast.rs | 52 +++++- ptx_parser/src/lib.rs | 34 ++++ ptxas/src/main.rs | 58 +++---- zluda/src/impl/device.rs | 13 +- zluda/src/impl/driver.rs | 4 +- zluda/src/impl/hipfix.rs | 12 ++ zluda/src/impl/mod.rs | 1 + zluda/src/impl/module.rs | 33 +++- zluda/src/impl/pointer.rs | 54 +++++- zluda_ml/src/impl_unix.rs | 113 +++++++++++++ zluda_ml/src/impl_win.rs | 11 +- zluda_ml/src/lib.rs | 1 + zluda_trace/src/log.rs | 144 ++++++++-------- zluda_trace/src/trace.rs | 24 +-- 27 files changed, 639 insertions(+), 154 deletions(-) create mode 100644 zluda/src/impl/hipfix.rs diff --git a/.github/ISSUE_TEMPLATE/zluda_dump.yml b/.github/ISSUE_TEMPLATE/zluda_dump.yml index ee2738a..a199cf4 100644 --- a/.github/ISSUE_TEMPLATE/zluda_dump.yml +++ b/.github/ISSUE_TEMPLATE/zluda_dump.yml @@ -45,7 +45,7 @@ body: ./train_gpt2fp32cu 4. Build and run the tests: make test_gpt2fp32cu - LD_LIBRARY_PATH= ./test_gpt2fp32cu + LD_LIBRARY_PATH= ./test_gpt2fp32cu validations: required: true - type: input diff --git a/Cargo.lock b/Cargo.lock index baddd3a..cfe4cff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2573,7 +2573,6 @@ dependencies = [ "ptx_parser", "quick-error", "rustc-hash 2.0.0", - "serde", "smallvec", "strum 0.26.3", "strum_macros 0.26.4", diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 9c5671b..8546203 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -219,6 +219,10 @@ pub fn compile_bitcode( compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_language(Language::LlvmIr)?; let common_options = [ + c"-mllvm", + c"-ignore-tti-inline-compatible", + // c"-mllvm", + // c"-amdgpu-early-inline-all=true", // This makes no sense, but it makes ockl linking work c"-Xclang", c"-mno-link-builtin-bitcode-postopt", @@ -237,8 +241,7 @@ pub fn compile_bitcode( ] .into_iter(); let opt_options = if cfg!(debug_assertions) { - //[c"-g", c"-mllvm", c"-print-before-all", c"", c""] - [c"-g", c"", c"", c"", c""] + [c"-g", c"-mamdgpu-precise-memory-op", c"", c"", c""] } else { [ c"-g0", diff --git a/compiler/src/main.rs b/compiler/src/main.rs index 9d1a5d1..a58ad98 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -21,9 +21,14 @@ pub struct Options { output_dir: Option, #[bpaf(long("arch"))] - /// Target architecture + /// Target GPU architecture arch: Option, + #[bpaf(long("ignore-errors"))] + /// Try to ignore errors. This will try and produce output even if there are + /// parsing errors (e.g. an unimplemented instruction) + ignore_errors: bool, + #[bpaf(positional("filename"))] /// PTX file ptx_path: String, @@ -48,7 +53,10 @@ fn main_core() -> Result<(), CompilerError> { .unwrap_or("output"); let mut output_path = match opts.output_dir { - Some(value) => value, + Some(value) => { + std::fs::create_dir_all(&value)?; + value + } None => match ptx_path.parent() { Some(dir) => dir.to_path_buf(), None => env::current_dir()?, @@ -68,7 +76,7 @@ fn main_core() -> Result<(), CompilerError> { let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?; - let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?; + let llvm = ptx_to_llvm(opts.ignore_errors, ptx).map_err(CompilerError::from)?; write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?; @@ -92,8 +100,12 @@ fn main_core() -> Result<(), CompilerError> { Ok(()) } -fn ptx_to_llvm(ptx: &str) -> Result { - let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; +fn ptx_to_llvm(ignore_errors: bool, ptx: &str) -> Result { + let ast = if ignore_errors { + ptx_parser::parse_module_unchecked(ptx) + } else { + ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)? + }; let mut start = Instant::now(); let module = ptx::to_llvm_module( ast, diff --git a/docs/src/troubleshooting.md b/docs/src/troubleshooting.md index ce1189b..cc75399 100644 --- a/docs/src/troubleshooting.md +++ b/docs/src/troubleshooting.md @@ -116,7 +116,7 @@ in order to demonstrate all of zluda_trace's features. ```bash nvcc add.cu -o add -arch sm_80 -LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_TRACE_DIR=/tmp/zluda ./add +LD_LIBRARY_PATH=~/ZLUDA/target/release/trace/ ZLUDA_LOG_DIR=/tmp/zluda ./add ``` The last few lines should look something like: diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index c9a5a6b..7ee6e43 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -22,7 +22,6 @@ microlp = "0.2.11" int-enum = "1.1" unwrap_or = "1.0.1" smallvec = "1.15.1" -serde = { version = "1.0.219", features = ["derive"] } [dev-dependencies] hip_runtime-sys = { path = "../ext/hip_runtime-sys" } diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2b62aebd51921f051f18eeed20e95795f515aee4..bc375c3b14ff5032dff6fc592f683e1b081a93be 100644 GIT binary patch literal 24456 zcmZ>AK5)-egn@yTfq@~3$3Vp^a$nls-&@vm2R8UHGB7YG@-Q%z2{ABmGbu7KF)%QQ zF)%Q&H!4qbJi+40CSYVaiLr%8m6bzMj-@-PLB^HIX%e^769FR;3l0fyE~BCjkAfqL z$}U3fJpzuCm{Lw~Y!p#&cI!|$$!Jv6BB7zd~YA|qT z;F`?pD#_L1!YQ!SL4l*gF;bv|k*U!^L7>BtldYJ=QIM@)iHFJ6frrabf`_fuL4mC~ zh=;BD-~yiJg9}7BSdt7hoDu^%jz9RQaVWu{M}cSQ4IZ{;M?)@-BMgl*7&V%X2e-AS z7tA$ronr7)HfZfx^VIvf%N*eT>tT4uKweXE2k`) zz@Q*^n4`m-%bG-$Muf7Yq}e?RGd@Y;m^gVU|W>JM1`Yy@k>CgtNr~W}79=)^nVtw>Vo|V76lT zqQL((fRDkD|BC?M;|+Yd4E&!y@O=s3d+fmfUV%SXf$wVq-?IR|PZ!uePvH9%z{eoR zAl;$J;E~Ae$ibkQqtL~u(8}B5$gpI?L(6~v|NpON@MHYI{D7Z<^F{Lku{K^7A%>a_ zjUFrsHy?1H0(pY}fE2?(yMHOj{RbsL?w`dBa(^?!k`0X&LJcBj0uz`f9BM1+fcWJa zg8}OSeg+E$l?D7v(F`0#8tmm2?3Eeq1qtnD2J96X&1D?yRSxV8GM^609$F}SDN*JI z5}VY(gBN4Z;yGB+0To?(=O*{jfAQ_xC%onSBLXfN|@V^t_dpm*e@c}jjn?c~+2fkbf{!apY9~k&wKHz(Pf$gmT-ww%&zv*n=8l4iRp%nj@Z*wlW5T*Mf` z{6K`k9+dda1f&=`4zeW+GXyalNoYRE%b>}Sw4qUE7f78vLomZ1h6Yfs?^I|nDrhgC z&|b#DUY5aJ+yGBVCluxHD9YSYls(qSi(oS<+<7Q_t3mFGqRazExl0FS&%qNz)r9sc z4fci^>{T4?g%Ry$4eXT_>{S}=g$qFDn+dd68L$^?uvca@S9-A5YOpi#H}Lj+2f2s2 zfU7~Cp`O9RK|V*I{XIws646QyN&}kqoLSgFjmn zb~hY44l;Z=$d=5*(;&iNcq3r~BP<{BpyY9Aanv!u#nFs|AkXk0&}67%_ybC`0ymg8 zEVOI}iOM&Ck_0%vC*EjS;K-nIfa$_PHWgkL1xB$2w;Wb5GJZJ7mc$G3*AiYXX%+@e z7DkaJ2c>fYZaFY`C~_=N;Jug7{16&q*$f`i)69Q>%w~*Wc3}FYumRN0@swa-uwh_e z5N=9vQsi+-^pW6GnJJ(l%$2IZm3(Lsk4vJC&tsblEef*)JcL_1XDRp`W?8sNP2jqN zh%iU0LKmltsld^eM^YV!W*BHN8!lYIw#Yz@1djG;6%eL zg9Itl!~;56O$EHdEk+V*0+$^`gqe(IILu*`TGA-x!*MO4=>S`G{SkK( z6u6QTrJgscapW8LaIE2Mk>+tROe|<_oT0Geph02i!9m;3);;%+RGFm zBzwt>_JSJ@a!)U?eOSO==D=Q+z+Q5J?SlZA0T#(+;4fOhUaY`gWx!r^f$f6}4>rSb;rv0w3Ik-vmHXg$eAr3H(J4V0$mHeKP>*0;x>kFKGaK;R4&Y31F29>=hT- zz6r2bDnM-qX@r}9f$duWd!Yh*6`0IT;4frg2N^5CUZ4O9kZ%g?1rF>L3)qVo*h?DN zb06@50t6gZo58+V7{H!ez*TsG?cD^vssi?^2KK56>|XpFfJt35EoZr|g}KEDj7X zFoDr|Rd5C$iv$Y;Ljy>Vqk%^v=wj1S21WxA7mQO5Oz@K3#>&CK#J~jB%A~;D7NYvl z(1AgOK>*y07hsSySR%!}mXYZILj!{V1IV-!4T_=hIeZL0>|j$E7#JMbcvco1Zfalv zn*tU!O7M6p7s(T=W2$8Dwqh^N#1>|EDS6Rj0zx8#tDqh zTE&O-7-lGexF8I2($C;F#sey#eilds2w3s;>;NAr}DMyq|lEFs=tODe6HlCGH z-Ip2|Rv?FfXk5Ys#seU~!~n(v zDab~7-Mi4h+;9>ez$Y3Mr=II|WM!BPDo?>?fs(;KXJ$SIHEu8;6gA9kAy3WOMA2*v zFtET-mM_tPVOj)_4G+UI6w}^I_(%vaz+)BUeix$`Rtz(YV0MB+^=}FTn>>S$6if`H z?4Qw$0}LDuyf86Pz;2QZkQZQ(feC?>oiXNU=4VhtQKl~CgRZQ^5h1dpxj7{4+y9zYJ;SMkm7nHxNrVah;Zd!FqwI|~Cm zVS>VT2k#Dahpm-vILOWrjH2utuZm&zcjEyML&7JR^FV2P>19UA1&jbsWsJxD>;GjKnjXqR(Up} zXWOaM4q-YkA{fJYjmxn(Pf!%YdFEPE7^N8wl%UxAD0m^y9R`MF#&Ac2iuxT~*$>zl z4CbPQOhlLfs|7;?Cp-gyqWkGCCd~tk3_clf{h;WU6w~15Xn^Md22i?{cT47=amJCb{@U#PpWlQcz83xqMvT8o7J2P4Z;l(|Jk>voW zE&#g^6mc8Hl(^9>Ezu1rW?*T67Yd+ssI|5A7!#_cQ**f|vZA_e)n~DD5BSgu$)}H) zq@FMefa?;lwGM0^IRYAMR2Ug1ure`#t5hSo88VWBJ`;LCqgO1>LMx=FJlho%%dT~{ zR3Yi+Jh|!7pG|i1aGigiYH9notogvwpPVUt%WtM`@^}F z$U~Vl5k}h%XN#7@Jx<_mP4SFEo@Wf8;hZlGe8xOZa)&@YA(=D-RW=`+lrZd?zu z?U{<6k^=T>4)#iq_JRWTssiS+iuMA9cF8k>d~YW3Jv89YGvI&xf$u{ApRvFZ*?WpI zHxy+LCCa54bXxCew!YwOxrNzgPqR%6v-Oz|B4r8e)fw#O8SOUW7tLs|oYAQB?gQVu2Yg=-@Vzl$PC5bV^h3J-GItVXPASS9 zX_P&c$aZalv*ihA%NbyD$6<>p&K4ca)>97a?`gLA(`0>x*?Plap5vMA1rhB<1=g;^B@*o=3)&^m1@pgS;QtiB_hbTS{OO?r-b z%wqw*HwXBh2Jn40P(FCcQSQk@@f#0i?=i}yc}QAsXtv$bEZuV0qQ}|t%wd}y2Q5z= zw!Xq_bH-y!AqRU|0(&92^D24!7~h)&{s#x1OCqs z*j_&1dlJC+!$A4bn?#vIjj}f$ioa2meWoaTL{avJqU;Mr`80*O)_a_7cQ9K{X|_BQ zv8SK_G!|9G!CvLjUS7~HdHOludjtN55BOd>@V{KZ|J2|xBXjWz_Hv2#_!;d51)vZH zjU>%zY;m-FaM)slv+V?D%QFr~7Ra7jC<_`mIVh8+l5D{D5EQ@sFCF-w8aOL9mrAe~ zd$h+dXs>B#FPgD&3ZqTWVH?okfwR?~X6rMDdyevWD2h5I$~{w*J$Fzxs8J-%BE5m{ zsQ}*xaJaoU;QKrQG@!`9P^rMc;KsnfP$cJ};M~C@$ysFNp`hH@GjZZYmjwbEDK>^i zPJI&;7e^>8XVmy;a}AUi9erICl5-OC(iO^5i;6Sz^Arp%^$he3(lbjGG%`v`3W}}t z_0uy;GD?&5lJj%*gA!9x^Yi=*QuCbia|=pKQuT9k%5rrJit@8klS>qe@{@JV^i1@O z6^u-cj7$_#jLpo_5-m*9jMGw*lTu8Q3=J$y%u|z6OjD9f&5aUM4UH0wP12Gy^KMn%+uT;Va#%b$9bZJG0*9w6CMdoL5azmvx_~IwTl-^s~Rw@lwe@cWnf_7 z<|@)?FK=M4iYS=jY`X&#RhAbHH$9U%1{q5{2uc<&z-d~{VUc|pi0g}gVAY$7%#FZaMVtg#S%7>21yM&7&)34H}jtnwAsvZ z*nyRi{gApmV*tQoT^EXb@yR!s{TwGh=SR2ZaUAg3L?|js^w?nwXOq7{Kup-KN~I zji>WSj^rwaTShTKn^Ps48EzR(X?&6+c~7BX8&9EvEhs(}4H>|lOYSa54hB$!F)%QI z;tbTD6lQS*#~G-#@nr=ggEAYdKvM!ktMdj%uo_m-_ymh1TeF0r#sdYGgAPpX3eAi; z*yTR8;+A`aNA40r4%7#Pxe3&*F~^p^jsK0G zBpDgPd8$|!9_D3IWZTT+BEVzL(F#f$&IMeJn^}%@98_diV zmlu4YbI{;IBL@?^#z8hn*ouSQ=Z<97|3<-vPmECaF*34k=5Z6?G3RN4xX)|iW|m_G z%!wTbb(jNLj!Rg2C-g|j6bUpX82oiAme|a4%;97M2aBSJ2*SP6U^^gT0BUBLNHA*H zJ57=2<${bk9Oh*OSpy0N8K^a!6G6eyC~v}14>7rGz01{2(!4M`0LY3)Npr2gnI*PArsg}Ly(Qv z^aPITFd7CJ3N#tSG<;y~lbdo!u<@S28y!c(07DKINd}OYz;PjlE61cuq)wS+v7o0wW^q zAnE$xgvJC8Wi}Rpri6qoO(1(fegu_$urvrt93~Qs4CWkj4l20ua9MFU8eB+Fl*r@g zFt`EDM+DW!psELTQeftT++m{8@o8nFV1v?vvl)<5l7V5mOvS;4+ZC*m5{@`>G&Md| zJkHuIx9|ffv~}ZmtksscO3h|q>5cUE%J(C#^Z$t(7%~`N}0cptKz1vIQ=qH(~*)@e#Sr=5r^qtb;14jh3@CK?=0 zA_|PWT#W*(8V=1$3NDOn7Z$L9+Y_KXs@>woSi;|2c+gl^$0d7a6C#iaARmv=qP4@ z6b=k*bq+icJWNJ+1tbo2G%~g*L^K>^=VD}QZSXN?aoPx}>mhY2#5_nF4&A(NP*CGD z^Eo@3nV^0>to-3*U|=xOKsU2KnMIMK>4kicV9z`0eR2#8aPt==h+O7n73`6aC=_T) zU}U@B%Hz($%FK3|*I6)O1IG!NN5r`p`5GG_?E+By6J{T%PJAM%%-|?ce>ky+fyGha zAH!k6W`_nwp(c>OTMd#D7!*0g5(r?Hlz~cCaCnJDa)nT4s1DIR0*pX2k)br>lD8vr(h5&nPfJC^-gvNATh6#eqDtt}?ItLW^ z4k>s{dDy|Ebs)jQse{E}MxKX6;;w@nEIewAGxIoDGLIZ!a1!7-us}>H%gRZJ30$p# z(-)}!0t-tl?mfogAi(l49~`jvGXf5PlrS6?baQNAlxqdKw*%o`c?}i~7N#a9B?T1s zMj2=*vw6sXlB1(cj6i2{fh#b2}(+e34*W0m~fnt!~h`V;5*ja9}G=5IK_J#Ly!l z1IqQ1R~U;#7IQKh%u-;HWJv;tj?RGx+^vUr7$glH9x1RKWa4?0z|E3+grT!pfQ8$j zgIQ04&4!5)A$vfXDFrUutS`}mh>ZgWIHufeQ)ft%W>TJ!%fXSjO);U3*9p|72f34b zmaFH4!x|lq91L+01}qHnph5y#mxBiXV15E+L=g!_hQqv00&L1`oZt*naD+8uLHYv5xCXW=59tT2nG3|19K1J?&FYT83Pw=# z&!wPM>P>-j5qpBc-2e_pHpUsM8S65cW%D?0Iq;-Oaw(rsVc5)awt)F5s6@UXbFiTK zGg|^AlC+LnJ_FNb-h2ar%e+?u7BH|IOkh+p;A;}z;lTH!%i-e;W{DFF zf*$M#S$WI_7$u%Gbwo4G$=kuKe&_&WqYPWYK`jM7MvjCH>7`^B&b}I9SlUfiJ<}uE2^WgVP%f z9N#%Iq;bv?WM<}YG}!P=v9Q%R#xXFXRd~k(elunl8NnQIn6PCr&E)uyIHBE1hbKdD zhF~)jhoiv{k417u5^oBe1Nc)W{4lr?$@W}OZv*=o1qSgu&F6$NJQW@>)m$+Cz^H$K z-FAumho+hf+@BKPpJ12GM~stNfBy@WXj#q_=(+MM(TwX2|Et{WH*p7G-wcDVa`_j_t5gCyx8J;hQ^m|^%5N~ z+p;tQ7&t+pEg^J}Ge$wE#&tnO4+~F91BXiI1Ex0V4hEhg$0NFoe&%ypgcBb$IPfv= zNKhAOXJ9di;1t-AtXi@|(SX}vhw23$fhKPTChXs%>yB4XS7Jz39Ojd>Ujy&cZ;u2?!BpeG4EtOPw#NnfVpkT+t0|yeO9R!t1iOkKB zY{?vL(j7fKX_6gE5{4SDKtpxG&5xOQjU0}s9+Y)I!a3)hkO;GAheN_)hocDs*BTC$ zGau$X)Xt_})1Bzh;=tCiL-D}NwgVb`lFhplco?>t?r7vb;jmo#L_&a`0H`Jr%3)KR zp=!Y9b%sHVA)F(^QK0KIv+)5QrU{1>+N4`-8JZ*6o^klyU_Q$r=>DNG?f|>Roq`Xb zTH3(F;Eur`Fw4{6j>8v5$umY8-UdklS3qMM&$y1W)-gybJeu@?Po%?wQKV#Ag5(MZ z7FKUIEjH(kLLzPt+1jLA0v(!#-4Zv7bO?Y4r!4LYNE}sMULZb)X~8!x^8!Y#7bVPE zJf~O>r8RUoH=RAevYF+KL#O0}mPBE8Uxh{)++Q#q$ z7J;DnvjoSVH#q)y7#JA9@du5&AB@TjA}bh;L>2xxusTgq_~YQ%!w{mN;K(W@kZ;l; z09rA_$*Qmz6kmd%UZ!J#(~`ryE-Y-Dd0GWz1VR16#)CSC9as+Ywy>}r=IuyfyUZIT z;$zNo(ZD<4j@Vt$I9L)}5<`Qd0}oT--4;m}g;tPh@)Bo^G~5afEkDeA1k@EhB`EXp zWq~sTs7=$La{gbkN@51OA5(wNRR=bVrRvufG{c2;=_7A2Kpp1X(H zig_9@q#E4dg|v}DeQo6SBxsxqIv8f4%;pLjZ*_QVz+=wh1Zq`7^nk(yrU%rH`?7*j zst4R>0}rB^xO3Pf1%R6Aj3Imn3!DvF4zhvkHuSbDWULvi2gxpI2it}3P(iam18gjn zXm>Ip+zHC7!Yl_n7(xASiPIg9;33Zm797$LfAE9j8PeAURlvyZF=v7K!G-T2o1!78 zv%uZgB+=*$PPzG(BS4Q37FZa;!-%oXQ{*bnrP~Ah|f~Pp&l*{>MOwF z29%~vRx}DWcxf~p(one7!zQo@WDjf{U9u>qAVaTN2GSDq7HfIX@kGz#OB16}geGfq zAqO+0rlP+}J7zpO zg&ecE#ZDwF72!GNDbXV#;lR<cqjTu53pjEEzPLLwHfsHl9 zgH>QfI zW+Zc11kFh1SYS6JnWMw*hJpep--5=|quW#pw(XV$n0^&*%lM^ure}=2(}g~ zbZ#=}{X8L;&qMx;qx_eLa_=0$3^>_XlkDiAz!Q=x|A{xvKNP+TFW;0)6ue|_9 zJKC!u3<#-kKSDr&gE=rle3AzDWsCC=o`WEm$A9x#1IRv@HSZqE!MP{Q#aLJlIfxu< zOg(kT;Id8L6fZ7MK=0bb*4zTqX*gy5MGBz3rfIPL$u=jHhjK?Wrh~!0> zl_2{d-i3P6B@yJQnCCWNFG6?$Pi&CA2(=Py-;VZbm=`73L7u9BdKkiU5d7-KoDrb( zWHxgW`-B(lvsbiNp`itif>RiOK5Mwas8XP${_CLHtA%pNC?kLiqWeJSoPzebXl5*M zG@QZ=*L^2Zz3LE(85sdWD7rc4&tTsGH)DZg;1pIg-Om_xjwIsF=$QsfBDrpRj z8LrG`Udg^SfE$Vw&T}v*axE0;S}2ndU~okWRNaE+azS;tM@Au+B+mn;kN}>5hirNd ztg_A(tPv8<4Xq&(&K)3PLTiXbo<&e2j~h>q1!x$db>LMe zgCvysiIZ{XQ_->`2RfPsSeUtWno^-mZJm~UFwt!RD-2nR9oh|BiaFY$f@`)v zXt8XI&u7H4{{nY2kJklm(7J(ESq>(VJ5oE~M*MYEIBO(wM3Tpd=fMZ=%`C?_ zpEnfqJU94qSjJi5hBKSZi3i)8c^a)299C!+nsQjD(J&y9-HA~!o&CoKSfRg4uw*mO z2}Sm19%idKhjko_nrt4(G~4iK7#f^z`>XM+<9Sl$|VPv8l;(b(W{ zyaO_UE0){RBzeY2hfAKvh~xbOzRfIWIR7*h^ZYS5aYV*hVT}va5eMzR99D1^x^P&B z(KsNH{ScET%mMp0vK1Jdeqgf9;Do{rXWKJEHtI82&Kb@)azMD4r!hbn7806;Ld`s$ z8~8V~oOfuI<>0&00}aT9?NToe>Nv9;JfzTkf&Y(zx58WoVTKr>IB z^$CG9E9G>0B;D-0>e%=(*VN{N1upaaX8u_J~=LcS(G`}r_7Nzgr~^x zBCBbD;en$nAjt<}!Z9GpXrDYs-jExN59OSUT#%=c#`bF|z((ZDYGr)+{^ zMDG=e6bq)U7A6(EbCN51O*Hx~+Z;Z?wcl(2X)ow5X1;yElI?Z{^EZ(8PcZE#dT&dl zSTtRNt<W)o=pYfT8}B&doIR}VcyUYO=k zrZi|nCQNV!R7Ff{Bh-wlscce^h6#6=7<2|=O+?}nsKON2g;0gFrnW*Ac83XNLrshd zT=O}>xm~(Rn6a6;`R|2?Z47P)5_Ub@)MVLqlOsok`PMl@mMx9Smw6ite9Tj3Ol(Yl zF45ND_DrJb&WedI+Z+;>Nt@-KJg~W#$GJhdn5QxS$%ztY1((abvh1l3Zd$e}KA33P zmS9trvP)55%}X|eAFFt5tdAS8JTp8lu;6&5k2%XBv%3atNedvV1Ntr77T8`(-lZr2 z8cgIl_24wv><4niJdHJbj+ZzqOgUKMEVY19G6ih>gnrAm2X@au8m=T9VPFJJ)I8vG zXchgZxFd9iuyxo8VOK?lMwO~=r=B?nB$^DgpiSVVJpw(70<}hvS&wU{1P)3xJyeHE zZtW55QB2$nsC&V5pVHTIT%p3oKub9V5Xu|1|!>Njwc_SD2 znCBd{Ib$fmxWmUhK*;!R0qePvX0bbmWtatfPRlr7=<#M-e;3x?s~&CdA=+ZF#!$p) zdvCP8H`?ACZSRe?_eR@$qwT$cZSN^WGH&M-+97M`&;VLIa-umvu<2pcLccdLiT)z3 z3@npg=yDksF(q?2t%NqGwi-DZv^o4a32jK}ID4XlrOW1(h^xW|=4_3l@}-YRkD`x0O!B0%uINF>s{NRzGcbRv3tB-j~qtV?BtX~a|JFq-5y6fP1rsS}w z#bw^83H-%8E&+PQJdVM7l6%US85jx`7#KineT?`7lv-aY~WhRd!pf11K)aH zj>OvyLVI~F3hy+C-r>#oc&|aCgKtLQs|M-)d?yOuHpt!M<5={rLFpTxhtSUk)#dyd zN`D&EZ}4{*{cq6h6jLNf285kHqbUKL51tJ(2Y`7U1Kx^53xfvKhYXm`SM8Pspc@P_{o`HeE zj)8#zv~UYXgU)2&&|zS3XJB9e4I;s4kT`<@149V(tSJ}`5(lZz0*z}kFfhPqkT|Cb z14AnV0|RJHAB={HFJ@q10Idy#(J=A7P;+23NL)aLfdRDE8nh-HMuWsb_P%3aU;z0K zMuWsb>e(3?7(i?IU^GY^z^+y;P!279TG)(*s zBLjHTCX5D&gUopaRS%;<;vjRFm>}lBXqdPZ6T}=C4H5^LqXAV9qe0>za~z=JFd8Nv z0W}9kgTz7Rq(RlgXplI_oH`~*e8Fg#_;e;n_`_(BI4FMCFflMFGB7Z}XplI_oU=>} z450PMFd8HdviBhq=;#3k1{e(z2dQUdhNK4=4H5^b7hz^#Py@~LLTQjVNWB#^0|RIe z2aE=ZgWMCy3~>*P28n~z7cnz1Xn^*^L1~!yEM^7R~iY zoP&jd0kj?zM#IGQSr`~VM;E|on0O!yB;Ubkn0P)5BphHgNF3y^CKd(;Sq26M7!48! z`D;E4Bs^g>Ong7o92gA}2bpt;1rjeX8YcdX1rlyB8YB)fpPd!r9vBT02bm+q%D@1c z&xFw+agcgTsCpO;5(laGfvShmAaRiTY*vVSU^GY^q`n@i9!7)2LF(s0&4z^=DZj;RB;#;%`|Y;SZx>;@oVI^bey!;vn-?*dXd*G)Np|z7rb*1E~Cm z(J=8ms5p#lsE5%o z@gS&r7!4DzVP{|ft&4!sF!6=#5P!jFnD|+!`7jzL{(&75k1!e}4$8kQ91!=wXplH4 z+|)TB>R~iU9HibFsvbsz#6jxgpz2{XNF1cT2&x`NgTz7Vr$W`kXplHa{Yt2M7!48! zsXqf%52HciAoUNR>R~iU9HgF^6XJdt4H5^b7w3e8CyWM(gVftV#bGo|Jdu-uK^rtT z52az^J)8^-P7DkTFd8O)j+24Gg@J(qM#IE~xFGI<(I9b9x>e(Xv=d-7OxzeM4x?e> z4qOZjp!swd4HNg_g49zm8YZ5?1<5Zk8YW)L1*yMaG)Np2KHX4rU^GY^V6mv6Tb#E2S&rhpL0RVXBZ6==i`Rd3osfc4r+se+AS~|Bo38*V9$XpLv~_Hv0hR^YEf=! z36^w#t{FYiB1})oFVBOUkKtP^>TnyCl$KLTj8SlPxQ)t8sUXHAm@;&O@)Gm%i*r&_ z3qT10InavrAj0SxU}XuYB*F}Y!lIJ+#DapN{EGNAV5F2db&4xDdOZHzzs zC=V1G7~X`-l~iGsPsuC;ix(sol^BA<5*h|6nPqy#$%#3sSgl0XhFPX2mzBg9<;52# zmZW9oWtODIr<&v%R~UgULXT@qwWhg7_|>KvnpGIV%*0_HR1L%|bTwu9C8?0yoS0V$ zEd$VsCoD?xQqvQ$8P((X=5d1B5ng#Uzrp+{`=-+u&ljZ38Jmw+$pl$hO1^Oxs{$cx+3oh{v!k zu>xxqkdu;|SW=W(5ua<3XAxgv6rW^hhLPRS6qyjGsH7+{uNbS++}uQHMnH8nb0u$UoNzSPPMIYEGsIv0JvQ${L4;C*? zOi9rzP6W5&zyi6cxyc2UdIhF>1txl#W+q_4;@r$UP&*IIDJ?EdLeT;lx^oLCn)l%uNB|jN(Mpz}HJI0IdLKU;qH1gXU)d delta 9156 zcmeC!&)6}Uae@kyGw(z-bs-5?CZ|cbx^m9VK{hO)ya z)fvhR%u39pU3hYds)J4|47 zOk?PJP?=h>=b=?*mqy|z>m^aM^H1%&?bn#GD%@$&G)}>J&skGBY|nGt3+}l1d0tV) zyyum12JQ?~4}RIB#FHjDFJ0sC2cF5S&Vt)YB^o8VF05<)U~^u(^?4=#z2Gh-#tG>P z30w&hjt0z|4>KAu@v{}o`1k++|9XaN3`U2Yr1^)LA{7((|zJM%XdlbO`EP?N-0)MUn|H}pZ zPZ=ICZfJZ7@sS_n2j&O-44f~Tc^^Ia3vwgF8pZ@J2Nn+&%PGy)dzx*wI9qIRmY%_E zv&GRGVjbVl2Ye4N@VzwPtIu&ja2Ou)y}!WrMS=ez1OJB!eD4bQaxd^bbl`u%aG}}e zinArd5t|Fn7CncpwlLb9IBb2v*|LY(7Gy!A#g%5u8P1j`oTbk=TfK0$WT;@s>DUKy zjC=zqGJY^SFuh`U(R@IxftN*sgF!w=;ZOp@kA!A%MK(5YI9s$E{s5_Gj9@+>!eBo+ znz?oJ1$GTaM$5@RIdtkZ4zn!Wq$Y6LK}47%RiTU1#Z=&EOOjN_p&14m%!UhBuq`rB zW4YeKFx$ar<~Co2MU5J29Cr*HUrgX^H5X_#bX?KKd5E1&xGDcoK;+?;g;E`yhs4>0 zC!CV%I23Ut!O4J4xGC#U!9+&W#2rRihYBLEup~;UaikjfaGXhKGGMbg%<_omf&%ZA z1g8yb!YpPP&76(WJ{ort6u6QPJz`@Hl;FN>!EfSN(46R$;Hz*aL7|JYS(@ia%R}BZ zjGNzZPGe-W+}zA8k-D*kLcvDyejn}j38kU zPB}2a>)wS1=7yC_AT9`>Xi%JbuGf*3A($1!1>p&d&a3u0GxISl;{tI(7-ZyAb2d>1 zA5jn&gdNy;Rt6YYU?|I%=s-1%lhJuq1dk05gWBYUe8TnOAk82QGQh>?g%!gL0}vO4 z8+asw{-!Xn$usy!g7_c|Qufbi#sLNn23`;!gh9^PBpDztz#t9cf-p$g8DoxSeuiZ{ zAU+6#l&MSkpeUovtg0ZB*lBBB-elsMT0&_ux zQG&>f?X-Ze?@_pAwdZ4Oi&Q+tu$cEWMH@~0apr2Zr7MiJ(wA09DxbegT4MoZc3vR z!vPm~7%@m1EaCdyziC1ATjVeS@! z;lR5q0@!jG7$&1w+~wbDRKUzI;|y3QhyW#^hgxeIr5O&SpoC3QWaFj=1`Cv^X@p@U}J&W!;_83hl#e6<)96*p=&; zr#`b1*qGzre|l}D)T|h`*5{u$&6!uX+wsArKP*P=m){uOT&j@5vU#TPBxVBx1tU{K zV`GJ6^R!eG)8sUhBnv~!WK&B^!?aW*1B=AOL`$>Olr%#VBXeVeG{cn1(c(6YmYZjY zOE9WtF)%O)vv}%&QU@qfOr|g>K6O3M!N97_&>|whtIW_cg<_yB zG&Nr0b2mt8Si{KCHQ?9!LHb2Sr7W!@Q16Y@2zU z1eQ1!2(hF$r0{vn6F9#AAn60U`sGKogm`K=kZKZq0&h}e8R!R05-UJmxWGAv7{P! z7x07#>`UNr=WjW{1J<_zq_087kaQ>`D6_E&G$k;6b=ttlkf^}G0II5lSsdA%dn82i1)35RJ~~Y}$RLef?j;_% zhj`>}Lgbu57IJ_i))5@BpsLg4$^noNgEAYNOpAg!1u?1Joy9Bjh+ljEe-gQv(Vnldmj9Oi9NVB5^&EWl$f&@8c}@rWYhIU}A!6WE0^ zSkwy+9&~51@;2ZpQD{2A=mn~EKu!VGpTaEe$W|QSaQMkE`G&L}TcJQxg2CI#JTej) zQeY=ToB}EgOnMlm$SXc|cV=W`ILyn+$hMisMS#a#rbXh65r=aD=fuq{M>-lAL5>ug zqHvh^pdwp_j-kO_0i8pvyfH#9Y#uQpcMM%q8jkQfo)Z#amgP|6V3uZ(o;*QDk`3g| zD-4s@%P6udf=sxkJNdedW<98}fEpYq{s!4E&@6Grh{tIHtB?nadV%A?%`AraJ5jE(17i5oafol~>5H(n6jh zD`&_n3ZVItfuJ}4$SYbyyh)xvV-(1+^nik5J(8u);N0Q^P5|5t3=AeB35(0z4!b8}={6GuIp)`Y|b%?U>qFgOYD zY=~T7Fe4#hhoZ)TglUcvO@hr#9F7JDvKnWpYVgf0*`fI1K*F?#1~Bow##yQsejsrM zW`h~29ur{3#4JeToRq94o5yj4fhSE;OjW{Aqd|a$IY)d&WRoj1s38X`4x-rW40;>o(LW$qq_nUhdLS=TNEN14jyM`WNU5kF=uhw$T(SENlYI@dlEkF%O~$p zQVhk=b6>)dqeG)T2A+(*t~7aA#*^V+9%Lv=LHSX@Zj$q^#g%U|@Ixas*MRZbPSFp4G^$SArnN%)wn^cdVtU}0jL402qP zNfOAnYzYF~f~K;HoQx7Vlj~K)>QUVWj`t7dpwNOljp4AMn_~l`Tr0?J9SEn%%d=>( zFf}nLDTpzg=wM`G2ZtqFgn@=K8;48_C~-)qI2-gxC_rK=e}Sk|B4Y!Be}L7=NEYrV zM{xmWwtlcvjiFB6FxgO5p}vRV3N%6faDSKqZZj&gu?sXMII!g>h#bjqV(5{Oao}iT zlsv*%B(j*3(O{MWizJI7*buRV4__S<*^q3Vz}kce$VM?9X1PWr*#m-%2w6vQov^Lcm8x zR{8n>jg~76ys1teoEHM#Gn#_s*rxpdFVCwAW;==aEW8jP*D{6KR++(M3L~pBLy8y+ ztnu~XB%@@C!J7i-0Jf9~=>~6foE134XE3G-9AHgfu$jeI$RnnKt%^nI&vKvfb6!Ks{$jW0bz$o#g>4+|)pZT3e;lu~^4Gw(FF&nr# zoEexT-Y|09*~o5nN8kpN#2ujwwvGov4F_237f6^G7~E;t!X$CVNWpW$AqMd~OmPNW zR^AI@47sfC2(&Ou_?TPxH#lo198hIEC#1nF%G@k*(t$Ne(WQt{>p{Wc2Cg)NyA51U z26qD9FiG&3bG!!mwo!{?b3G5Y@Dn)}1Mdqx5)uv^O^u@bo!WCAa&`i!wUSZ5fT@a3?Pl)51lTyknrbEnefBljZX7t4)Ghzpz!~`U^7b; zDE#@W*z6v#{#eldg*k2lf0d)%8`+hk2*iHyqSyc4Ixvd(Ivn{)Qj<&wKvgAn})(!^iy3VQBb!fWsfua%=Pe zhrbXw{Dn}%e@+V`{Dr~czlMe5&PE}tI|4IUk-}epAte02ut=OS3JC*+{~Xph17R!g zf>=Xgt2+WmSfSzX3km-iW>Mx~$!iX*Nue#JjDDc-7Y2pDFev=ju!6$B8sytXKT!Bb zh=Id@K`%7?r8@oRJQVJ7@<$E-1-w^Sj2=EO@xh6?{GO7Of{DkJ2)W))SFOX_-Yq;ZG=! zLc*UmS+t&rJjz3jEb8dN!@vM))^c|_ax@e~1n@X9bR1@?ZvT}GBFgZNq zaA0K;nxgQ>!BM~#+~H%$WSwwb;4tqYR<_MNZUS=(GTA(28eg(X@R`d>1z&vr{K^^8m7G<`hpy7jzjDna2j%*$=LU#&X<{Wsb<7gPr$iZU96T-um z$g5FfZg5ATCZpicb_q5Y4p0Z2(a)U4t?-1z9oA%q6>AUj#yGH9-F0AVE@}i-BOV(h zXBc=V@JyM(sKc;MU_+CF$aMC!`UXjs35P&Vlmk@<#VqP?8q*q%2-YkR|G>y+E>pw6 zwwZ@Vh{v3#9>nqy;xQMfbzs}fb4sX3LZn!rDPh6JLxqj<4q$gSG6||TG@TRDVHWLJ zkZ{uBXo5yq(;HKkGij3JAYR*JzJHbq0wqy_^6 z19uz4i^)0qvh|>7f<#fPyaAJgBkzRBh7Zi59Dz+Am{|>(0va1wL_Hk7WH2bIg5sK! zb;3n~!@N!`Y@2yp1?C96EMOLFJgDQW$l~r|bw|J>!Qixkz`+7$#m0jL%t|v4mNbj7 zUgmW)IU^*p-PPPXfM?2#MhS)l0fAlVEUq zfy_a{`F!Rq382vP0EJ!x16w-VtT+ZX>$?Iqhjki@G7r4eX*3E*Vyg@I(CDb(WOV$3 zZH~cl0ghOaX9Rs+-abobcKOenE@`TaWQ2@z0>m}T-m19|LL7#hkZp6qaB1P$Iy(C3hb z3^{-@1#FlDRP!O5Zq5R80Sn(jHbp~Fr;@v`Nuq_}14sslcU3T&3zDgX9{og9dzw2}cfqI^k=gI3{rOoJa`eYBM~} zAbCT^S;0Y?D(p!bh;3;!GUdckU`I60|`R|P^UdxhUp~}hf&1Tro=O1h6NKO z&ls}sm>(BVxFh52Pylin!{j()CEI*(Ne1a-g63~P){d3a;WV<&H}-c`d>HCAKg;sT@ z61rCJlZAL^wQSR?CQhS>sZ7CV4Cgd5vKR9-7O+Y3RL`xKd;@Lhe-$VZ&zW#=mB@@o zM;v&|An~G8%W>oZlR*Pd*wOCh{GUac1~>P{!He%Rx3q!^yu) z9P6V&g$D>1DljmB3Lzst0i_n^7M>1<^9@`-c}_6gXy9w-<#>3bL1-$k#mAcsqFZ<~ z3hy*XJm8&i@P32zG`)1Y*Q&m-wwgX$l?jHI6p>TCErlKwPk-r--7 z^uIy7N8pAKZ=>E#0g20ejmC|F5tk(!>&@2-&M=l~w0bAF?IocX!g`JNCqy_( zvKt+jiED_|H9Ef&_c&7D=-Mrj(bC!IzCfZwq`T2;yTpkr-Hm>SB{-(`GzOiQ@CfZ~ z4E-jN(K-v1Rv8!=Qh68{vY8kdgqWu!XfrS{fbwo47o^toVStQgMc9D6&cMJB!N9z!SVM8YaGrlYs%0q+v8n{1_(#187JTM#IGKaYCGQkCUMu z>^7LhcTNTdJ_ZH`7!4BVRAFFX<$@Rlqe0>zhsr|z38P`+dRz<)pn+`|4HI|ff_MN% r|Nk%i|3Ao}0H{G=dh*p^XU=7iu-DSvtQgYIIoT}k>g1?+Jw^rqG?F!v diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index cc1d973..6174ec1 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -1,17 +1,21 @@ // Every time this file changes it must te rebuilt, you need `rocm-llvm-dev` and `llvm-17` // `fdenormal-fp-math=dynamic` is required to make functions eligible for inlining -// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc +// /opt/rocm/llvm/bin/clang -std=c++20 -Xclang -fdenormal-fp-math=dynamic -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1100 && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1100\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | sed 's/ nneg / /g' | sed 's/ disjoint / /g' | sed '/__hip_cuid/d' | sed 's/external protected/external hidden/g' | sed 's/trunc nuw/trunc/' | sed 's/trunc nsw/trunc/' | llvm-as-17 - -o zluda_ptx_impl.bc && /opt/rocm/llvm/bin/llvm-dis zluda_ptx_impl.bc #include #include #include #include +#include #include #include #define SHARED_SPACE __attribute__((address_space(3))) #define CONSTANT_SPACE __attribute__((address_space(4))) +typedef _Float16 half16 __attribute__((ext_vector_type(16))); +typedef float float8 __attribute__((ext_vector_type(8))); + #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME #define DECLARE_ATTR(TYPE, NAME) \ @@ -624,4 +628,156 @@ typedef uint32_t ShflSyncResult __attribute__((ext_vector_type(2))); uint32_t x3 = load_single_matrix_trans(address, 24); return uint4::Native_vec_{x0, x1, x2, x3}; } + + static inline __device__ _Float16 top16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast((value >> 16) & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + static inline __device__ _Float16 bottom16_as_fp16(uint32_t value) { + uint16_t half_bits = static_cast(value & 0xFFFF); + return *reinterpret_cast<_Float16*>(&half_bits); + } + + static inline __device__ float bpermute_lane(int lane, float x) { + return __hip_ds_bpermutef(4 * lane, x); + } + static inline __device__ uint32_t bpermute_lane(int lane, uint32_t x) { + return __hip_ds_bpermute(4 * lane, x); + } + + static __device__ half16 shuffle_a(uint4::Native_vec_ a_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; // Lanes 0-15 (the other 16 lanes are a duplicate in w32 mode) + half16 aFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2 + int cudaTID = (vGPR % 4 + lane * 4) % 32; + uint32_t reg0, reg1; + // Select the two consecutive elements from a_reg: + if (cudaChunk == 0) { + reg0 = a_reg.x; + reg1 = a_reg.y; + } else { // cudaChunk==2 + reg0 = a_reg.z; + reg1 = a_reg.w; + } + uint32_t a_tmp0 = bpermute_lane(cudaTID, reg0); + uint32_t a_tmp1 = bpermute_lane(cudaTID, reg1); + uint32_t a_Frag_reg = (lane < 8) ? a_tmp0 : a_tmp1; + aFrag[2 * vGPR] = bottom16_as_fp16(a_Frag_reg); + aFrag[2 * vGPR + 1] = top16_as_fp16(a_Frag_reg); + } + return aFrag; + } + + static __device__ half16 shuffle_b(uint2::Native_vec_ b_reg) { + const unsigned lIdx = threadIdx.x; + const int lane = lIdx % 16; + half16 bFrag; + + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = vGPR / 4; // will be 0 or 1 + int cudaTID = vGPR % 4 + (lane * 4) % 64; + uint32_t reg = (cudaChunk == 0) ? b_reg.x : b_reg.y; + uint32_t b_Frag_reg = bpermute_lane(cudaTID, reg); + if (lane < 8) { + bFrag[2 * vGPR] = bottom16_as_fp16(b_Frag_reg); + bFrag[2 * vGPR + 1] = top16_as_fp16(b_Frag_reg); + } else { + bFrag[2 * vGPR] = 0.0f; + bFrag[2 * vGPR + 1] = 0.0f; + } + } + return bFrag; + } + + static __device__ float8 shuffle_c(float4::Native_vec_ c_reg) { + const int lIdx = (int)threadIdx.x; + float8 cFrag; + + // Loop over the eight vector GPRs. + for (int vGPR = 0; vGPR < 8; ++vGPR) { + int cudaChunk = (vGPR / 4) * 2; // will be 0 or 2: selects which pair of components to use. + int lIdx8 = (lIdx < 8) ? lIdx : lIdx - 8; + int cudaTID = (vGPR % 4) * 8 + lIdx8 / 2; + float ctmp0, ctmp1; + + if (cudaChunk == 0) { + ctmp0 = bpermute_lane(cudaTID, c_reg.x); + ctmp1 = bpermute_lane(cudaTID, c_reg.y); + } else { // cudaChunk == 2 + ctmp0 = bpermute_lane(cudaTID, c_reg.z); + ctmp1 = bpermute_lane(cudaTID, c_reg.w); + } + + // Select one of the two values based on the thread index's LSB. + cFrag[vGPR] = (lIdx & 1) ? ctmp1 : ctmp0; + + // Zero out for specific thread indices. + if ((lIdx > 7 && lIdx < 16) || (lIdx > 23 && lIdx < 32)) + cFrag[vGPR] = 0.0f; + } + return cFrag; + } + + static inline __device__ float4::Native_vec_ shuffle_d(float8 dFrag) { + const int lIdx = (int)threadIdx.x; + float4::Native_vec_ d_out; + + for (int cChunk = 0; cChunk < 4; ++cChunk) { + int r_vGPR = (cChunk / 2) * 4; + int add8 = (lIdx & 0x4) ? 8 : 0; + int r_lIdx = (cChunk % 2) + (lIdx % 8) * 2 + add8; + float d_tmp0 = bpermute_lane(r_lIdx, dFrag[r_vGPR]); + float d_tmp1 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 1]); + float d_tmp2 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 2]); + float d_tmp3 = bpermute_lane(r_lIdx, dFrag[r_vGPR + 3]); + float val; + if (lIdx < 8) { + val = d_tmp0; + } else if (lIdx < 16) { + val = d_tmp1; + } else if (lIdx < 24) { + val = d_tmp2; + } else { + val = d_tmp3; + } + if (cChunk == 0) d_out.x = val; + else if (cChunk == 1) d_out.y = val; + else if (cChunk == 2) d_out.z = val; + else d_out.w = val; + } + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_f16_f16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } + + float4::Native_vec_ FUNC(mma_sync_aligned_m16n8k16_row_col_f32_bf16_bf16_f32)(uint4::Native_vec_ a_reg, uint2::Native_vec_ b_reg, float4::Native_vec_ c_reg) { + // Reshuffle from Nvidia-like register layout to AMD layout: + half16 aFrag = shuffle_a(a_reg); + half16 bFrag = shuffle_b(b_reg); + float8 cFrag = shuffle_c(c_reg); + + // Call the (built‐in) 16x16 MMA instruction. It returns a float8. + float8 dFrag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aFrag, bFrag, cFrag); + + // Unshuffle back into Nvidia expected float4 result + float4::Native_vec_ d_out = shuffle_d(dFrag); + + return d_out; + } } diff --git a/ptx/src/pass/insert_post_saturation.rs b/ptx/src/pass/insert_post_saturation.rs index 904bf37..525ae15 100644 --- a/ptx/src/pass/insert_post_saturation.rs +++ b/ptx/src/pass/insert_post_saturation.rs @@ -197,7 +197,9 @@ fn run_instruction<'input>( | ast::Instruction::Xor { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => result.push(Statement::Instruction(instruction)), + | ast::Instruction::GridDepControl { .. } + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => result.push(Statement::Instruction(instruction)), ast::Instruction::Add { data: ast::ArithDetails::Float(ast::ArithFloat { diff --git a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs index a4c2dc4..d365e29 100644 --- a/ptx/src/pass/instruction_mode_to_global_mode/mod.rs +++ b/ptx/src/pass/instruction_mode_to_global_mode/mod.rs @@ -1855,7 +1855,9 @@ fn get_modes(inst: &ast::Instruction) -> InstructionModes { | ast::Instruction::AtomCas { .. } | ast::Instruction::Vote { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => InstructionModes::none(), + | ast::Instruction::GridDepControl { .. } + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => InstructionModes::none(), ast::Instruction::Add { data: ast::ArithDetails::Integer(_), .. diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index c811a53..144f5e6 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -153,6 +153,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { for (i, param) in method.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); + if let Some(align) = param.align { + unsafe { LLVMSetParamAlignment(value, align) }; + } unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); if method.is_kernel { @@ -519,6 +522,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::CpAsyncCommitGroup {} => Ok(()), // nop ast::Instruction::CpAsyncWaitGroup { .. } => Ok(()), // nop ast::Instruction::CpAsyncWaitAll { .. } => Ok(()), // nop + ast::Instruction::GridDepControl { .. } => Ok(()), // nop // replaced by a function call ast::Instruction::Bfe { .. } | ast::Instruction::Bar { .. } @@ -529,7 +533,8 @@ impl<'a> MethodEmitContext<'a> { | ast::Instruction::Vote { .. } | ast::Instruction::Nanosleep { .. } | ast::Instruction::ReduxSync { .. } - | ast::Instruction::LdMatrix { .. } => return Err(error_unreachable()), + | ast::Instruction::LdMatrix { .. } + | ast::Instruction::Mma { .. } => return Err(error_unreachable()), } } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0b9ef79..4f87dc3 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -51,7 +51,7 @@ quick_error! { } /// GPU attributes needed at compile time. -#[derive(serde::Serialize)] +#[derive(Copy, Clone)] pub struct Attributes { /// Clock frequency in kHz. pub clock_rate: u32, diff --git a/ptx/src/pass/replace_instructions_with_functions.rs b/ptx/src/pass/replace_instructions_with_functions.rs index 19e16e7..f7c976e 100644 --- a/ptx/src/pass/replace_instructions_with_functions.rs +++ b/ptx/src/pass/replace_instructions_with_functions.rs @@ -351,6 +351,35 @@ fn run_instruction<'input>( let name = "sqrt_rn_ftz_f32"; to_call(resolver, fn_declarations, name.into(), i)? } + i @ ptx_parser::Instruction::Mma { + data: + ast::MmaDetails { + alayout, + blayout, + dtype_scalar, + atype_scalar, + btype_scalar, + ctype_scalar, + }, + .. + } => { + let name = format!( + "mma_sync_aligned_m16n8k16_{}_{}_{}_{}_{}_{}", + match alayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + match blayout { + ast::MatrixLayout::Row => "row", + ast::MatrixLayout::Col => "col", + }, + scalar_to_ptx_name(dtype_scalar), + scalar_to_ptx_name(atype_scalar), + scalar_to_ptx_name(btype_scalar), + scalar_to_ptx_name(ctype_scalar), + ); + to_call(resolver, fn_declarations, name.into(), i)? + } i @ ptx_parser::Instruction::Sqrt { data: ast::RcpData { diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 9fecba3..1bc622c 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -3,8 +3,8 @@ use super::{ StateSpace, VectorPrefix, }; use crate::{ - FunnelShiftMode, MatrixNumber, MatrixShape, Mul24Control, PtxError, PtxParserState, Reduction, - ShiftDirection, ShuffleMode, VoteMode, + FunnelShiftMode, MatrixLayout, MatrixNumber, MatrixShape, Mul24Control, PtxError, + PtxParserState, Reduction, ShiftDirection, ShuffleMode, VoteMode, }; use bitflags::bitflags; use derive_more::Display; @@ -721,6 +721,30 @@ ptx_parser_macros::generate_instruction_type!( space: { data.state_space }, } } + }, + GridDepControl { + data: crate::GridDepControlAction, + }, + Mma { + data: MmaDetails, + arguments: { + dst: { + repr: T, + type: { data.dtype() }, + }, + src1: { + repr: T, + type: { data.atype() }, + }, + src2: { + repr: T, + type: { data.btype() }, + }, + src3: { + repr: T, + type: { data.ctype() }, + } + } } } ); @@ -2378,3 +2402,27 @@ pub struct ReduxSyncData { pub type_: ScalarType, pub reduction: Reduction, } + +pub struct MmaDetails { + pub alayout: MatrixLayout, + pub blayout: MatrixLayout, + pub dtype_scalar: ScalarType, + pub atype_scalar: ScalarType, + pub btype_scalar: ScalarType, + pub ctype_scalar: ScalarType, +} + +impl MmaDetails { + pub fn dtype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } + pub fn atype(&self) -> Type { + Type::Vector(4, ScalarType::U32) + } + pub fn btype(&self) -> Type { + Type::Vector(2, ScalarType::U32) + } + pub fn ctype(&self) -> Type { + Type::Vector(4, ScalarType::F32) + } +} diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 4253ae6..a4f9080 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1862,6 +1862,9 @@ derive_parser!( #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum MatrixNumber { } + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] + pub enum MatrixLayout { } + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov mov{.vec}.type d, a => { Instruction::Mov { @@ -3897,6 +3900,37 @@ derive_parser!( .type: ScalarType = {.b16, .b8}; // .dst_fmt = { .b8x16 }; // .src_fmt = { .b6x16_p32, .b4x16_p64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol + griddepcontrol.action => { + Instruction::GridDepControl { + data: action + } + } + .action: GridDepControlAction = { .launch_dependents, .wait }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma + mma.sync.aligned.m16n8k16.alayout.blayout.dtype.bf16.bf16.ctype d, a, b, c => { + if dtype != ScalarType::F32 || ctype != ScalarType::F32 { + state.errors.push(PtxError::Todo); + } + Instruction::Mma { + data: MmaDetails { + alayout, + blayout, + dtype_scalar: dtype, + atype_scalar: ScalarType::BF16, + btype_scalar: ScalarType::BF16, + ctype_scalar: ctype, + }, + arguments: MmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + + .alayout: MatrixLayout = {.row}; + .blayout: MatrixLayout = {.col}; + .ctype: ScalarType = {.f16, .f32}; + .dtype: ScalarType = {.f16, .f32}; ); #[cfg(test)] diff --git a/ptxas/src/main.rs b/ptxas/src/main.rs index 0ffe841..f8a52aa 100644 --- a/ptxas/src/main.rs +++ b/ptxas/src/main.rs @@ -1,6 +1,4 @@ -use bpaf::{any, doc::Style, Bpaf, Parser}; -use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600}; -use std::{ffi::CStr, mem}; +use bpaf::{any, choice, doc::Style, literal, Bpaf, Parser}; #[derive(Debug, Clone, Bpaf)] #[allow(dead_code)] @@ -12,6 +10,8 @@ pub struct Options { #[bpaf(short, long)] verbose: bool, #[bpaf(external)] + lineinfo: bool, + #[bpaf(external)] gpu_name: String, #[bpaf(long, short('O'), fallback(3))] opt_level: usize, @@ -19,48 +19,32 @@ pub struct Options { input: String, } +fn lineinfo() -> impl Parser { + choice(["-lineinfo", "--lineinfo"].into_iter().map(|s| { + literal(s) + .anywhere() + .optional() + .map(|_| true) + .fallback(false) + .boxed() + })) +} + // #[bpaf(long, long("gpu_name"), fallback_with(default_arch))] fn gpu_name() -> impl Parser { any("", move |s: String| { - Some(s.strip_prefix("-arch=")?.to_owned()) + Some( + s.strip_prefix("-arch=") + .or_else(|| s.strip_prefix("--gpu-name="))? + .to_owned(), + ) }) - .metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)]) + .metavar(&[("--gpu-name=", Style::Literal), ("SM", Style::Metavar)]) .anywhere() .fallback_with(|| Ok::("sm_52".to_string())) } fn main() { let options = options().run(); - let comgr = comgr::Comgr::new().unwrap(); - unsafe { hip_runtime_sys::hipInit(0) }.unwrap(); - let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; - let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props); - let input = std::fs::read_to_string(options.input).unwrap(); - let ast = ptx_parser::parse_module_checked(&input).unwrap(); - let llvm = ptx::to_llvm_module( - ast, - ptx::Attributes { - clock_rate: clock_rate as u32, - }, - |_| {}, - ) - .unwrap(); - let elf_binary = comgr::compile_bitcode( - &comgr, - gpu_arch, - &*llvm.llvm_ir.write_bitcode_to_memory(), - &*llvm.linked_bitcode(), - &*llvm.attributes_ir.write_bitcode_to_memory(), - None, - ) - .unwrap(); - std::fs::write(options.output, elf_binary).unwrap(); -} - -fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) { - unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap(); - let gcn_arch_name = &dev_props.gcnArchName; - let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; - let gcn_arch_name = gcn_arch_name.to_str(); - (gcn_arch_name.unwrap(), dev_props.clockRate) + std::fs::copy(&options.input, &options.output).unwrap(); } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 6816994..ed8bb8c 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -89,7 +89,15 @@ pub(crate) fn get_attribute( *pi = 32; return Ok(()); } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER => { + // TODO: maintain a table, certain RDNAs are 1/16, some are 1/32 + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => { + *pi = 32; + return Ok(()); + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_TCC_DRIVER + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED => { *pi = 0; return Ok(()); } @@ -211,9 +219,6 @@ pub(crate) fn get_attribute( CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE => { return get_device_prop(pi, dev_idx, |props| props.persistingL2CacheMaxSize) } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO => { - return get_device_prop(pi, dev_idx, |props| props.singleToDoublePrecisionPerfRatio) - } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE => { return get_device_prop(pi, dev_idx, |props| props.accessPolicyMaxWindowSize) } diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 737f5c3..ad8310e 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -487,9 +487,9 @@ pub(crate) unsafe fn occupancy_max_active_blocks_per_multiprocessor_with_flags( dynamic_smem_size: usize, flags: ::core::ffi::c_uint, ) -> hipError_t { - hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( num_blocks, - func.0.cast(), + func, block_size, dynamic_smem_size, flags, diff --git a/zluda/src/impl/hipfix.rs b/zluda/src/impl/hipfix.rs new file mode 100644 index 0000000..f957849 --- /dev/null +++ b/zluda/src/impl/hipfix.rs @@ -0,0 +1,12 @@ +// There's a bug in hipDrvPointerGetAttributes where it returns +// HIP_ERROR_INVALID_VALUE if the pointer is null. It works correctly for any +// other invalid pointer +pub(crate) fn get_attributes( + ptr: hip_runtime_sys::hipDeviceptr_t, +) -> hip_runtime_sys::hipDeviceptr_t { + if ptr.0.is_null() { + hip_runtime_sys::hipDeviceptr_t(usize::MAX as _) + } else { + ptr + } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index f73a972..60ecb80 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -7,6 +7,7 @@ pub(super) mod driver; pub(super) mod event; pub(super) mod function; pub(super) mod graph; +pub(super) mod hipfix; pub(super) mod kernel; pub(super) mod library; pub(super) mod memory; diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 506f824..da7c145 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -20,6 +20,10 @@ impl ZludaObject for Module { } } +static EMPTY_PTX: &str = ".version 6.5 +.target sm_30 +.address_size 64"; + // get_ptx takes an `image` that can be anything we support and returns a // String containing a ptx extracted from `image`. fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result, CUerror> { @@ -58,11 +62,17 @@ fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option> { pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result { let global_state = driver::global_state()?; - let text = get_ptx(library)?; + let maybe_ptx = get_ptx(library); + let text = if cfg!(debug_assertions) { + maybe_ptx? + } else { + maybe_ptx.unwrap_or_else(|_| Cow::Borrowed(EMPTY_PTX)) + }; let hip_properties = get_hip_properties()?; let gcn_arch = get_gcn_arch(&hip_properties)?; - let attributes = ptx::Attributes { + let attributes = ExtraCacheAttributes { clock_rate: hip_properties.clockRate as u32, + is_debug: cfg!(debug_assertions), }; let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| { let cache = zluda_cache::ModuleCache::open(p)?; @@ -84,6 +94,12 @@ pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result() -> Result { let hip_dev = super::context::get_current_device()?; let mut props = unsafe { mem::zeroed() }; @@ -100,7 +116,7 @@ fn get_cache_key<'a, 'b>( global_state: &'static driver::GlobalState, isa: &'a str, text: &str, - attributes: &ptx::Attributes, + attributes: &impl serde::Serialize, ) -> Option> { // Serialization here is deterministic. When marking a type with // #[derive(serde::Serialize)] the derived implementation will just write @@ -129,7 +145,7 @@ fn load_cached_binary( fn compile_from_ptx_and_cache( comgr: &comgr::Comgr, gcn_arch: &str, - attributes: ptx::Attributes, + attributes: ExtraCacheAttributes, text: &str, cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>, ) -> Result, CUerror> { @@ -138,7 +154,14 @@ fn compile_from_ptx_and_cache( } else { ptx_parser::parse_module_unchecked(text) }; - let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?; + let llvm_module = ptx::to_llvm_module( + ast, + ptx::Attributes { + clock_rate: attributes.clock_rate, + }, + |_| {}, + ) + .map_err(|_| CUerror::UNKNOWN)?; let elf_module = comgr::compile_bitcode( comgr, gcn_arch, diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 8eda15e..6541fce 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -2,7 +2,7 @@ use cuda_types::cuda::*; use hip_runtime_sys::*; use std::{ffi::c_void, ptr}; -use crate::r#impl::driver; +use crate::r#impl::{driver, hipfix}; // TODO: handlehipMemoryTypeUnregistered fn to_cu_memory_type(cu: hipMemoryType) -> Result { @@ -59,7 +59,12 @@ pub(crate) unsafe fn get_attributes( data: &mut *mut ::core::ffi::c_void, ptr: hipDeviceptr_t, ) -> CUresult { - hipDrvPointerGetAttributes(num_attributes, attributes, data, ptr)?; + hipDrvPointerGetAttributes( + num_attributes, + attributes, + data, + hipfix::get_attributes(ptr), + )?; let attributes = std::slice::from_raw_parts_mut(attributes, num_attributes as usize); let data = std::slice::from_raw_parts_mut(data, num_attributes as usize); for (attr, data_ptr) in attributes.iter().copied().zip(data.iter().copied()) { @@ -88,7 +93,7 @@ mod tests { use crate::tests::CudaApi; use cuda_macros::test_cuda; use cuda_types::cuda::*; - use std::{ffi::c_void, mem, ptr}; + use std::{ffi::c_void, i32, mem, ptr, usize}; #[test_cuda] pub unsafe fn unknown_ptr_attribute(api: impl CudaApi) { @@ -162,4 +167,47 @@ mod tests { ); assert_eq!(context, CUcontext(ptr::null_mut())); } + + #[test_cuda] + pub unsafe fn null_ptr_attributes_success(api: impl CudaApi) { + api.cuInit(0); + api.cuCtxCreate_v2(&mut mem::zeroed(), 0, 0); + let mut context = CUcontext(1 as _); + let mut mem_type = mem::transmute::<_, CUmemorytype>(u32::MAX); + let mut dev_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut host_ptr = mem::transmute::<_, *mut c_void>(usize::MAX); + let mut is_managed = true; + let mut ordinal = i32::MAX; + let mut attrs = [ + CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED, + CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + ]; + let mut values = [ + std::ptr::from_mut(&mut context).cast::(), + std::ptr::from_mut(&mut mem_type).cast::(), + std::ptr::from_mut(&mut dev_ptr).cast::(), + std::ptr::from_mut(&mut host_ptr).cast::(), + std::ptr::from_mut(&mut is_managed).cast::(), + std::ptr::from_mut(&mut ordinal).cast::(), + ]; + assert_eq!( + CUresult::SUCCESS, + api.cuPointerGetAttributes_unchecked( + attrs.len() as u32, + attrs.as_mut_ptr(), + values.as_mut_ptr(), + CUdeviceptr_v2(ptr::null_mut()) + ) + ); + assert_eq!(context, CUcontext(ptr::null_mut())); + assert_eq!(mem_type, CUmemorytype(0)); + assert_eq!(dev_ptr, ptr::null_mut()); + assert_eq!(host_ptr, ptr::null_mut()); + assert_eq!(is_managed, false); + assert_eq!(ordinal, -2); + } } diff --git a/zluda_ml/src/impl_unix.rs b/zluda_ml/src/impl_unix.rs index 93d04e3..55437a6 100644 --- a/zluda_ml/src/impl_unix.rs +++ b/zluda_ml/src/impl_unix.rs @@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint) rsmi_num_monitor_devices(device_count) } +pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2( + pci_bus_id: &std::ffi::CStr, + device: &mut cuda_types::nvml::nvmlDevice_t, +) -> nvmlReturn_t { + let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?; + let bdfid = pci.to_bdfid(); + let mut device_count = 0; + rsmi_num_monitor_devices(&mut device_count)?; + for dv_ind in 0..device_count { + let mut curr_bdfid = 0; + rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?; + if curr_bdfid == bdfid { + *device = Device { _index: dv_ind }.wrap(); + return nvmlReturn_t::SUCCESS; + } + } + nvmlReturn_t::ERROR_NOT_FOUND +} + +#[derive(Clone, Copy)] +struct PciBusId { + domain: u16, + bus: u8, + device: u8, + function: u8, +} +impl PciBusId { + fn to_bdfid(self) -> u64 { + ((self.domain as u64) << 32) + | ((self.bus as u64) << 8) + | ((self.device as u64) << 3) + | (self.function as u64) + } +} + +fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option { + let s = id.to_str().ok()?.trim(); + let mut domain: u16 = 0; + let mut rest = s; + if let Some(colon1) = s.find(':') { + if colon1 == 4 { + domain = hex_u16(&s[..4])?; + rest = &s[5..]; + } + } + let mut parts = rest.split(':'); + let bus_part = parts.next()?; + let tail = parts.next()?; + if parts.next().is_some() { + return None; + } + let mut dev_func = tail.split('.'); + let dev_part = dev_func.next()?; + let func_part = dev_func.next(); + let function = match func_part { + Some(f) => hex_u8(f)?, + None => 0, + }; + Some(PciBusId { + domain, + bus: hex_u8(bus_part)?, + device: hex_u8(dev_part)?, + function, + }) +} + +fn hex_u16(s: &str) -> Option { + if s.len() > 4 { + return None; + } + u16::from_str_radix(s, 16).ok() +} + +fn hex_u8(s: &str) -> Option { + if s.len() > 2 { + return None; + } + u8::from_str_radix(s, 16).ok() +} + pub(crate) unsafe fn device_get_field_values( _device: &Device, values_count: ::core::ffi::c_int, @@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2( *device = Device { _index: index }.wrap(); nvmlReturn_t::SUCCESS } + +#[cfg(test)] +mod tests { + #[test] + fn parse_pci_bus_id_full() { + let id = std::ffi::CString::new("0100:65:a0.f").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0x0100); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0xf); + } + + #[test] + fn parse_pci_bus_id_no_func() { + let id = std::ffi::CString::new("0100:65:a0").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0x0100); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0); + } + + #[test] + fn parse_pci_bus_id_no_domain() { + let id = std::ffi::CString::new("65:a0.f").unwrap(); + let parsed = super::parse_pci_bus_id(&id).unwrap(); + assert_eq!(parsed.domain, 0); + assert_eq!(parsed.bus, 0x65); + assert_eq!(parsed.device, 0xa0); + assert_eq!(parsed.function, 0xf); + } +} diff --git a/zluda_ml/src/impl_win.rs b/zluda_ml/src/impl_win.rs index 35f0dfc..205e792 100644 --- a/zluda_ml/src/impl_win.rs +++ b/zluda_ml/src/impl_win.rs @@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint crate::impl_common::unimplemented() } +pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2( + pci_bus_id: &std::ffi::CStr, + device: &mut cuda_types::nvml::nvmlDevice_t, +) -> nvmlReturn_t { + crate::impl_common::unimplemented() +} + pub(crate) unsafe fn device_get_field_values( _device: cuda_types::nvml::nvmlDevice_t, _values_count: ::core::ffi::c_int, @@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values( crate::impl_common::unimplemented() } -unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> { - crate::impl_common::unimplemented() -} - pub(crate) unsafe fn device_get_gpu_fabric_info( _device: cuda_types::nvml::nvmlDevice_t, _gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t, diff --git a/zluda_ml/src/lib.rs b/zluda_ml/src/lib.rs index fe8271c..40a7e30 100644 --- a/zluda_ml/src/lib.rs +++ b/zluda_ml/src/lib.rs @@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!( nvmlDeviceGetFieldValues, nvmlDeviceGetGpuFabricInfo, nvmlDeviceGetHandleByIndex_v2, + nvmlDeviceGetHandleByPciBusId_v2, nvmlInit, nvmlInitWithFlags, nvmlInit_v2, diff --git a/zluda_trace/src/log.rs b/zluda_trace/src/log.rs index b3f9716..9cbb9cc 100644 --- a/zluda_trace/src/log.rs +++ b/zluda_trace/src/log.rs @@ -303,6 +303,7 @@ pub(crate) enum ErrorEntry { }, NullPointer(&'static str), UnknownLibrary(CUlibrary), + SavedModule(String), } unsafe impl Send for ErrorEntry {} @@ -344,93 +345,94 @@ impl Display for ErrorEntry { match self { ErrorEntry::IoError(e) => e.fmt(f), ErrorEntry::CreatedDumpDirectory(dir) => { - write!( - f, - "Created trace directory {} ", - dir.as_os_str().to_string_lossy() - ) - } + write!( + f, + "Created trace directory {} ", + dir.as_os_str().to_string_lossy() + ) + } ErrorEntry::ErrorBox(e) => e.fmt(f), ErrorEntry::UnsupportedModule { - module, - raw_image, - kind, - } => { - write!( - f, - "Unsupported {} module {:?} loaded from module image {:?}", - kind, module, raw_image - ) - } + module, + raw_image, + kind, + } => { + write!( + f, + "Unsupported {} module {:?} loaded from module image {:?}", + kind, module, raw_image + ) + } ErrorEntry::MalformedModulePath(e) => e.fmt(f), ErrorEntry::NonUtf8ModuleText(e) => e.fmt(f), ErrorEntry::ModuleParsingError(file_name) => { - write!( - f, - "Error parsing module, log has been written to {}", - file_name - ) - } + write!( + f, + "Error parsing module, log has been written to {}", + file_name + ) + } ErrorEntry::NulInsideModuleText(e) => e.fmt(f), ErrorEntry::Lz4DecompressionFailure => write!(f, "LZ4 decompression failure"), ErrorEntry::ZstdDecompressionFailure(err_code) => write!(f, "Zstd decompression failure: {}", zstd_safe::get_error_name(*err_code)), ErrorEntry::UnexpectedBinaryField { - field_name, - expected, - observed, - } => write!( - f, - "Unexpected field {}. Expected one of: [{}], observed: {}", - field_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + field_name, + expected, + observed, + } => write!( + f, + "Unexpected field {}. Expected one of: [{}], observed: {}", + field_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::UnexpectedArgument { - arg_name, - expected, - observed, - } => write!( - f, - "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", - arg_name, - expected - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - observed - ), + arg_name, + expected, + observed, + } => write!( + f, + "Unexpected argument {}. Expected one of: {{{}}}, observed: {}", + arg_name, + expected + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + observed + ), ErrorEntry::InvalidEnvVar { - var, - pattern, - value, - } => write!( - f, - "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" - ), + var, + pattern, + value, + } => write!( + f, + "Unexpected value of environment variable {var}. Expected pattern: {pattern}, got value: {value}" + ), ErrorEntry::FunctionNotFound(cuda_function_name) => write!( - f, - "No function {cuda_function_name} in the underlying library" - ), + f, + "No function {cuda_function_name} in the underlying library" + ), ErrorEntry::UnexpectedExportTableSize { expected, computed } => { - write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") - } + write!(f, "Table length mismatch. Expected: {expected}, got: {computed}") + } ErrorEntry::IntegrityCheck { original, overriden } => { - write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") - } + write!(f, "Overriding integrity check hash. Original: {original:?}, overriden: {overriden:?}") + } ErrorEntry::NullPointer(type_) => { - write!(f, "Null pointer of type {type_} encountered") - } + write!(f, "Null pointer of type {type_} encountered") + } ErrorEntry::UnknownLibrary(culibrary) => { - write!(f, "Unknown library: ")?; - let mut temp_buffer = Vec::new(); - CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); - f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) - } + write!(f, "Unknown library: ")?; + let mut temp_buffer = Vec::new(); + CudaDisplay::write(culibrary, "", 0, &mut temp_buffer).ok(); + f.write_str(&unsafe { String::from_utf8_unchecked(temp_buffer) }) + } + ErrorEntry::SavedModule(file) => write!(f, "Saved module to {file}"), } } } diff --git a/zluda_trace/src/trace.rs b/zluda_trace/src/trace.rs index e71aacd..f397d34 100644 --- a/zluda_trace/src/trace.rs +++ b/zluda_trace/src/trace.rs @@ -128,12 +128,11 @@ impl StateTracker { fn_logger: &mut FnCallLog, type_: &'static str, ) { - fn_logger.log_io_error(self.writer.save_module( - self.library_counter, - index, - submodule, - type_, - )); + fn_logger.try_(|fn_logger| { + self.writer + .save_module(fn_logger, self.library_counter, index, submodule, type_) + .map_err(ErrorEntry::IoError) + }); if type_ == "ptx" { match CString::new(submodule) { Err(e) => fn_logger.log(log::ErrorEntry::NulInsideModuleText(e)), @@ -323,6 +322,7 @@ impl DumpWriter { fn save_module( &self, + fn_logger: &mut FnCallLog, module_index: usize, submodule_index: Option<(usize, Option)>, buffer: &[u8], @@ -332,9 +332,13 @@ impl DumpWriter { None => return Ok(()), Some(d) => d.clone(), }; - dump_file.push(Self::get_file_name(module_index, submodule_index, kind)); - let mut file = File::create_new(dump_file)?; - file.write_all(buffer)?; + let file_name = Self::get_file_name(module_index, submodule_index, kind); + dump_file.push(&file_name); + { + let mut file = File::create_new(dump_file)?; + file.write_all(buffer)?; + } + fn_logger.log(ErrorEntry::SavedModule(file_name)); Ok(()) } @@ -349,7 +353,7 @@ impl DumpWriter { Some(d) => d.clone(), }; log_file.push(Self::get_file_name(module_index, submodule_index, "log")); - let mut file = File::create(log_file)?; + let mut file = File::create_new(log_file)?; for error in errors { writeln!(file, "{}", error)?; }