From 81baecf2c821fcae29465ab9f0af85d810754182 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 25 Sep 2024 02:46:08 +0200 Subject: [PATCH] Add ptx_impl bitcode module --- comgr/src/lib.rs | 68 +++++++++++++++++-------- ptx/lib/zluda_ptx_impl.bc | Bin 34052 -> 2660 bytes ptx/lib/zluda_ptx_impl.cpp | 18 +++++++ ptx/src/pass/deparamize_functions.rs | 63 ++++++++++++++++++++++- ptx/src/pass/fix_special_registers2.rs | 8 +-- ptx/src/pass/hoist_globals.rs | 2 +- ptx/src/pass/mod.rs | 32 +++++++----- ptx/src/test/spirv_run/mod.rs | 1 + 8 files changed, 152 insertions(+), 40 deletions(-) create mode 100644 ptx/lib/zluda_ptx_impl.cpp diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index f27a127..bdec0fb 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -79,6 +79,10 @@ impl ActionInfo { unsafe { amd_comgr_action_info_set_isa_name(self.get(), full_isa.as_ptr().cast()) } } + fn set_language(&self, language: amd_comgr_language_t) -> Result<(), amd_comgr_status_s> { + unsafe { amd_comgr_action_info_set_language(self.get(), language) } + } + fn get(&self) -> amd_comgr_action_info_t { self.0 } @@ -90,36 +94,56 @@ impl Drop for ActionInfo { } } -pub fn compile_bitcode(gcn_arch: &CStr, buffer: &[u8]) -> Result, amd_comgr_status_s> { +pub fn compile_bitcode( + gcn_arch: &CStr, + main_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, amd_comgr_status_s> { use amd_comgr_sys::*; let bitcode_data_set = DataSet::new()?; - let bitcode_data = Data::new( + let main_bitcode_data = Data::new( amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, c"zluda.bc", - buffer, + main_buffer, + )?; + bitcode_data_set.add(&main_bitcode_data)?; + let stdlib_bitcode_data = Data::new( + amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_BC, + c"ptx_impl.bc", + ptx_impl, + )?; + bitcode_data_set.add(&stdlib_bitcode_data)?; + let lang_action_info = ActionInfo::new()?; + lang_action_info.set_isa_name(gcn_arch)?; + lang_action_info.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?; + let linked_data_set = do_action( + &bitcode_data_set, + &lang_action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, )?; - bitcode_data_set.add(&bitcode_data)?; - let reloc_data_set = DataSet::new()?; let action_info = ActionInfo::new()?; action_info.set_isa_name(gcn_arch)?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, - action_info.get(), - bitcode_data_set.get(), - reloc_data_set.get(), - ) - }?; - let exec_data_set = DataSet::new()?; - unsafe { - amd_comgr_do_action( - amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, - action_info.get(), - reloc_data_set.get(), - exec_data_set.get(), - ) - }?; + let reloc_data_set = do_action( + &linked_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, + )?; + let exec_data_set = do_action( + &reloc_data_set, + &action_info, + amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, + )?; let executable = exec_data_set.get_data(amd_comgr_data_kind_t::AMD_COMGR_DATA_KIND_EXECUTABLE, 0)?; executable.copy_content() } + +fn do_action( + data_set: &DataSet, + action: &ActionInfo, + kind: amd_comgr_action_kind_t, +) -> Result { + let result = DataSet::new()?; + unsafe { amd_comgr_do_action(kind, action.get(), data_set.get(), result.get()) }?; + Ok(result) +} diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 2d194c40c4406fc81c0a2b832e37afad1928fce4..cbbf2dc61f1365e90155feebfacc3088b43d0d05 100644 GIT binary patch delta 1512 zcmX|BeP~-%6u&RY-c8#qA6+}b(qu1ft?khGlD27LGnc$$bEZOqa zE)9dcHr?LTz*1ANh=08L0aFmWO~h~3rb^Mx#s9__qs47H5UnuAKyaR$E*?1N_wG5r z^KsA1%gh^!TTLCD|I*FgCWN*igoZV2`wOppH~Q^ApMBaeBo0#GvuY6PQOtHc-l*TF z;NEKBX7r9Gm)f|$-oefCrW`k8+|{|OXg|65>MXOe3K_362x$RZM^*X@+;pAcAeh<2 za$umou4C^HTR5&m-Skcy$~Uffy@ugC>SF8OD+ZA>*q((fgoca=y##1FV14iqggT_W z)P(k+e^F6#!fA?9IDa)GGd}F=kx@=yavQ#J;Po9&K=**-n#f{SVm@-q8PRLTyh_e|y zIi^gU!zoru7A5Drl&%1C16}j#DTX#ZkZa>9y{#pU;(+7t>KKc(sqky9bsa*_!+pqh zw@n8*W>8NLbAds1p>Bz6w=_n#D3GainhxeR;yA0v69{J+JRYoB7AGt8hDbumnvE5; z6JDGR;&g?+4=FNE1aT~=OMLM(wWsjJ6vrrM9eaWBpQ8DS_LThl_D%?1~%y%z#F7OziK=s%0vEL=&%JFJ?n za6>P8*nN1eCdUbbY-EFo^DO2b038Tc% zOYyjrmaFe^erm?&UMss6{kw~P_hs*jZ2ob3)@Dwy=_mz$sJuf?Lc3U^&=&`%I8K=`dGVF5o zxceO8KBv2E?|)a>Kd0O+=RI?NSFucfvI2t;{3oDoDxZfS;kG`vS{$NR6d0bvnQ(}Y z9!I8W3boO>eTM0vSF;2xfqC>DsC=}cjXnhdQRAa26C@=drUZ0o^{q4LeO0G^Kb)W# zbg%<9D(5n!IpG9BNz*}4;yVONgrExa3vvJ%1vTWroB(YEC5?#G2;`Y0)T9)~_+;#6(EpJH>Yxyj!*H7pK)6j8nVjdd(eV+1kEb(y zBo>eILMR=X2*-rf(P(rcW=n>J5k46{5=o`QNn503NN;OT!I?jnj*Rg9@zBvIeRUq8D~ZuW=26BV4w@Ap(zevI0&>ej24;I2|35qsgZiQ z=DM&{M}4k+PkiFSNuCC^@qSTJxBR(C)^v3}`x=w#sME|!n$?(cN9(ft*!|rQ_imPu zY|zw%?(uBVddUNcJL+I_w63Kwa?+Ab4HN2)Em602Y8j2XS>Ys3iCbDd{b$AIu=TUu@|(@lunjqXFiKV? ziEXC%8{*-Z&a5g)vstc|lVe>iVVAN$<){WTvQj(X*4B z@^q1|o=EqyB6FlE{o%Kkr#*SoQwrz%_uAd^GQ(!WHLs*XkWa%Uk<3tHe`LGXXpyux z8m=bcvWe}iW+8D~H<9Rtx#J4d3=1W1(hxF(AtdHk3Fp3t2Qw%pN^-Z{vxfClIX$(- z?pnKNom^IxTzgfwKcMJ#qQg4WsYW9HG%9MJ`#J#83#f7nYgz* zV&7~DKNPPWU}8yaqo)>vG|6h^p2|E=x>L5UMpm2TsRR#=o^-XRRtLstv~{_%JCkHY zvnn^L|)bCu@`$B#p0S8Fb{{5ts~&@Ftqg1Sg?1<8GGP}wv&rJ)E;-hsvS5I zd#ok=%lO!CF0Pkxik)`Ij~JZWowAx-*;>6Pz1U-?E#=cy5H}F7-Di!3bR9bqJJ6!t z2Of-ZhYZ?JJE$>;_Hlgd5F9Rxr0D8nFe}T5>Z?ZOmj=~wR^`mx?atgR zHbBWXQ%tKVW1C6eYKrMIWi*Jhce}H8b!6`lXY`5}5Ia$SIjAf%s36E7QGL>{Jkcxt zLaiKfDoJ*?DI0?Hh_mau^iP{&+TB^Z#o3#?sE2M-OuZ?)35?NbSv%ZUZ5BIa?b63_ zfc2}$_-5w<@{OhIIt^{!FUTs=&alknUfDwZE^)>-an?ao%t3c%k4fKT%Gkw0dt>Ur zo+)!jm%c{~mdvW-MEx&k31WsrrLJKUMOn9R*x8Fk{! zUEB}e4x1Je4;vmzu(Th+M{qT(Fg_41wLjaSCE&X6A*7xAdQSO&9C zB}_x!NzJ>JRX2I;ed?)n=rqgL7-Y2;Sv5RYd!Kj-oHH;#Rk1RdV;rzc=Y8TOu&a7x zE%T^dYev-i1i}tSv#2U9vRbulO|SF}OCdxqv#8F;KVy*9V99mfk`5;=P}xYQNr#!Q zxEgk`GT6ed)oQwnZXu>xX7`4zBr=DA?sX7ZwUZXML*!oRm(Ysg___)7(CV<&z0$|k zva){VC(W|8%@aEtZ+gir9{j~SrxSZ#IMj zeE^ev<^U#>!~9L=T-)M5epb)iDG1aKG%Ot zF8*@8|B3VOzq2xUVKeH)c~YZ2=JY*ruH5L>>U>X}Pb(Y6Qr{EjR;8oe_ry8yhEWOE z&CzG~6$`oSqB=w#Ei9`APtji)#Xyj8#<*`kKr+5wDE%HJOpU zr^DCe;uk&o8efxNoK?$IKAD`Ae%VL;zL|XYGm|#UH{Y~CGZK7z6;&G0w*?wbKj-mx`+$yE^!i?7MB${givUz1fgbaUf; zOg?#dd`7FU$q%yPWv~P=`ouYS`LTB2OuqPioi@)mlZRRx#S?rpId_V`$&IfCFuDEK z04A@R;BWGeB@<;)J|<@^oAVJX_UY~ipI@Xc@-;d1QoUQ@Yx0{Zm5sic{OCx65?f1Q-J|dJ&E%`A1DG64_BT2CXaJK7UkzY#_teUEUy}*nCMfm3CST0o z!^nJ1O77M7`$o5|{VA36Op`M}j$yKg3Mey3in^v&d5(f%e2ewpZR^3|0A zOxgmN{P#2Wt)s?A-sRfg>+U4fxX1u9eEWLKchP_FPBsinnfLZ}@%npb!-YT%149CA z__q1p=kB|gp*fQ{fsx_+;j7Z0ht4)YbJW0w@566-eZv|lA&OCl{_tN-Z-N^Utx6x> zJLlY(Uq=(7gZgm423zkPdr9G9G3y21g5F#4iZc5uD8LYgrO<0~<%{8Ogwxt1O+NVY zbAJVsQX+9eIWlZgJzqZ`jyTHX-lJEa0F#5Od7(%Tce#DqjbgTNPT=_`xv6_(qHq+K)|kdS_O z$*}Tl_>(Fpu^{ulRC)Cu#?y=%QprE)Ymg@=W+kcQPB1_GpqYF(JZ`g0Vt@k;Kl*9F za8+5jlN2NK27K_NmS-R!q;VOrvAJSJG_{B@b}leBsGwE$h@zTicWR^ zllUI`*Kz-fCQR@7f;2w~Z*+;qv9KtFkNIR;x;*mVPY}{Ag5$uok)*~ou~8G3G&YQ% zam_x7NwnDzzu7dgfvAfmPBKnzd~H&aQ7)QnjE^)KBNtgCb6X;p9Emg=C+BrA&%yn} zR)~=%cWU<>(Jqf1PQ>Ef#q8eTeN>k*-b<%Y^?Q-2XZ!vcatPH^t z5iBY7)HJ{<+U;`Lx?GPPmecIwH8mbP>xt@_sQkQ9c|@lw*Qt&UDhJKVH1(d?5976c z@v)uOxaGRL=}HwWy6h5X!MYTz;`9tguCd5&FP6=x>vOXCPLCay*(|a(NwPKP#zEM% zPES6p6ly$}uE>1}*ih1wV4SROx8qmK{|qKW|qaH!Hu^C3d`RgjL3=eMe$H1OvHInXT^ZUG5lI zv}-VB^mJvlb!5QG)J?Gr9jUo&BG+)=eQzyn zRwor@rKIL;YK-6Br<)~OJZVx_#ANAzsV9iZ+QMS1BfqMo)@Hue#4fR|v6U8T28E z($Y0$bL@_amA1la)=^P7XI?T?XH!#BGp=Npr)JDcTb{0;H!p2|QEFjI%Dgn|mFc#5 zDf*(+jOB&t1!-w1^K|o5XO@&#SCo~P*ej}S<<%@T_{)+qd#Q~z%{5(d1$%Rj(YU;{ z!dhKavAUqtR$OJxUuAQY+e-6`N~-ME>cZlDyR~qYtvbJ|s<^^YU0PAT!dkwLwO2r3 zn}e3GDp^r(EzPeY;|A|-z(3V)lD}&^v7Ph zIY>8sbaR+)`swBv-3-vp=X5hjH$!xDj&8oA8|OaQ)a_HaBeK#Ijr~3B!nYM2I+LYG z7g?-(L#Hl*%}y?QNt)ssx+!g!zM~M|4;yz(TgN=N;g&R#rdUjOT}L-d=;jUWlC(SS z)-G{aUh9}HYtj}unxE)+Ou8+)dS(9;9rI-tgUqUv6_{m(x-N@sW$PQdJ}%mxrdZP~ z{opx0u3fhZ|O}MXO+9)E+*fJK!zUo~7Th6B123{3Sai?a29p@9bGx%(LTl_|?2U&_dRXOef|YuhR=bQ2)mNZfb#Fhk3F={o~Ycky$@t`54y4sSU0 zba5w@U&8wfYUS5<^1;aXrWpD*1FYJ^yAf8F5O^6xA|&GR&Xnuaikpe3SJD;UVUS*K z*L|Qa>TEMQ)C{W;=^AE8I*MJSCsoerw(3ZXWDnG%SMl4VtQjx9R#s3vokkqec!m)B zGl$q88yVvKMTR(kA@*kuvA;|C`)TLfrStbE<}bwOhc3kDhatrNE}g$WZM=j$e<8lU zUAq4MwEn$x{$9!--1!pk{WGlM!TqyoHr(QSm$@W68CW{nH`}N)Gjc=3GF`jfWNz%J zc3UJzrQ>rXPK&%6Ondp{} zcX43`vXvF@{jJ#`b;i!mRU{CJiEMbQj@4YRcHSzDS6>|=h5hB{UxRI6E%db)y+@+2 z$>22@i@!5qWrzRX>Br+s3&%FTe2o2vjj{i_oUzT%hJvw;Uq8nFH;%FYOZort^EWj4 zuN&j}CwTnO%s;{7hi3i>9zQhmPw@Dm8UMlKhi3i>9zQhmPw@C-J^nA{pUyvHoqx|5 z&p*NAhi3i>9zQhmPw@DmnSX-E56$=w9zQhmPw@DmnSX-E56%1&JpQHp)A=WOenOM~ zy?-VPy|2G?YUhk@=krq|8=2z$;gj5Q(xD~=^Wx=KM@TGNjPkm->KNEhf1}L%9@qxf zF68xB@c7OS&6vg~on!33bBz7BjIsZgG4}uH82f*8jQwBAKg>Tra{h)S|NF;y{s|sG zH1kjJ_@SA9g2#7;Wc~>rpM+%m2aoRz$@~*MJ_*VE6Fk0itjGVQ{KNb+*7@%qszUC-0DYe@{{F<;EX=Pci8iWWHo8lk}!sH^ID3JpjLBNZ;QYUB6IB z{nU~5=^yMyNQRF?U(ZRB{zQLek<6EG%(086D(<+Z<#Nedu4C%#TGRH+C9lvwy|uE0 zOavWW?-!zXfz6kLN4xxnevwO+GEZzitUf$Wp5HqV{`*JNpG)IY`1k1SBj2OLHW-=p ztY({0D~TE1HZR}N5rdrfoJEy>?qR|;-HkU%B_bNfR%0uS5oMD^>7q!Ng@kDq7ZE2M z%*qmy2>T->i@Z6dJ`P4RNn8RUCw0TWzZoH6lhft*!-r`Q(lxEjpx!>Qaf3O!tDd2q z^K{4j>kJJV^^I~@NAw*pNcN1tip`|cAivrGpU%dlc>CnG4bEQooG!ya$Ceg$=!f?w zBQ_e(|82RICBz8F3gjtKS_l4HWh-<`-j`=-<>cw|dO8;1!=A`(8(lSGvGOHF zu0?j+#tp;I&31Ucark+jQwGJn@6pmenxK4CEoF80j^E2lyuXAC%*ql1pZO3l7XkW$ z)AUSICrbB3K#rsZx*u9s-!Yp>F^lBlcny5;VNBf~h)Mg}H}CLg_5qI-Fz-WK)cUB| z4~g1~=Wshy&B^0lOS#T4w8g%JLQ*ov_!foNdSS1HFUfp^C22-=de@XTk`R!)L;9uIV4Vk2-&Ma;Fl_;c(vUgPiutY`VPjPYsf( zYB;M~Q4x>$R%&*SzrZOQp742TD7^4&R%jf%RKG+n{B_zd#r5P zi90Q7X#wk!iHl2*)qTQ3)Dra)qrzZ8Cyn>5+XVSUKWRCrIlIE12OV}Qv zcVu$YUXRFgT*b|BN7kIWmP>jivdM8hH|y)jwo`fB?58I8IhJwBM<<^_@f~M{31}^I}OQ5iyZr&f8clrbFS7&2uP7iQ-kHNhe%QR=c=L#NT^=HCdMf=&jGvTh~XW5!F8dve-@l9u9Tr1y; zZ#xs`D)l7votfz>e>vgYnb|J;i3H8rxh}^;iTbk(T-Ey%^Uh|w)|^eOIh*UMd3;*a z+3Q@j?@envdxPr^&-A{tx4Q0ndHT7t%Un+8#B_~jFyG~39-5)o46bz5G5craX$H$& z8~In6}g?~ zsK_7TIV$oVo}(iFhUcirlOXZ*3`a$N3(rxJyLgU@{4YF5MShs)sL02QhmVVjJe}vL z$P0OnioB8MsK{UCIV$o`d5(&FDvU*%2UO$>d5(&_lIN(%AK^JF^4EEeiu@eUQISuF z3kvmvihL>0QIY?S=cve^<~b_zfAJg@xfJeo>2Xn!r|}#Wxdk{K^Z@f;QTZl0qe z|CHyb$Te{FrSVabU(It=w(jJ zq9T8h=cvd(;yEgE)%f8!sK~Q;j*7g5=cvfH@EjF+7tc|V5Aqxpc`V#K(RQICU(9n< zv@if{CS?EBL5%Y^xA}qd?H-O zC`UzZ;yEhv)jUT<{y5K3k^hV5sK_UT4<8p5c{a~ckyr2>6?qHKQIWsHb5!Kxl*7kG zMLr)m+*%S;! z=cvdhz=xrvz;RKLXYm{rc^S`9kw3z7ROJ5+JQd=jB4;9oIV$oro}(hS@f;QTLp(=C z{u<9wk$=l`ROE9ahy9=;&*wQR^80v>ioA>GsK`(A92NQH@P8#}{!x+V@f;QT-8@G{ z{x_bZBL9r%sK~YOOXKvosK~G9IV$qIc#ewvFFZ#@-p_MXc^1!6k>3Iwz6dJv3V?9x5md_M zvk7tWG98~05;PCR3pd!!U%cCN{ z1vs2y1eJ2yu9d(OfTLph2HI zxR<&S?(1RAp}ZD;&d;wcU0r0&w^!HX(+|7L&$q&d;N?4PrPdl-QGQ{yHGhSz+~z1L z%r7Y~3{u)!Sh>2yVG~#yV%w}$LQ6wz>PC2MukRK1)feGgSZVYvtTg%-RvPnNwXW(S zdsv@3yu84AQ2GxhcR-Y!Z@_O-FTtG=vh!< zOthfFm}hUVT*wA`7E~BL3o4AB1rvk*)Hb#61+6J799Kh z@W3nGkH$VfJ}z7ulkGjlFW5@-Evz)ATUcq#H|-qZzJ--W-@;0xZ(*g;x3@2c=e^*a zgXtDj7+ni0jHU$@M$g`!5$ah`Ve~AhFnShL7(IJ?MW|;%h0(L1!suC0Vf0Kpg`f9? zjg1QyM$bYcpPxX4L_R-%T(B^XXWD~83VUtP)LpRf$kgF0=B{1dl%CN!4WI|9~l*w)Z0^dB`P7nD@PkMMt~ycr4@y%yvD6%MJo!+ldF=;Y-QCYWwszD;GbR% zyDzQwF9GkyRaMy>)yu6V@Xnm?J9TU94*1Vf%SRp!{}WZRz1lI7mlYJ2S0|TN__N}Y zk*o>^1MQTNY;^#(w!*P$g`;A%J-MVPxn^YPQG3<}-Batkr=;3e7BCp@@QZ~zhC?}Q x)*_g!R+Lo1|Jq6}NuM|T55(Z#h6(-$V8cB&2ZVmZ|LYo^@Z`c$`VJ}~{|9JiUZVg2 diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp new file mode 100644 index 0000000..937bda1 --- /dev/null +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -0,0 +1,18 @@ +// Every time this file changes it must te rebuilt, you need llvm-17: +// /opt/rocm/llvm/bin/clang -Wall -Wextra -Wsign-compare -Wconversion -x hip zluda_ptx_impl.cpp -nogpulib -O3 -mno-wavefrontsize64 -o zluda_ptx_impl.bc -emit-llvm -c --offload-device-only --offload-arch=gfx1010 && llvm-dis-17 zluda_ptx_impl.bc -o - | sed '/@llvm.used/d' | sed '/wchar_size/d' | sed '/llvm.module.flags/d' | sed 's/define hidden/define linkonce_odr/g' | sed 's/\"target-cpu\"=\"gfx1010\"//g' | sed -E 's/\"target-features\"=\"[^\"]+\"//g' | llvm-as-17 - -o zluda_ptx_impl.bc && llvm-dis-17 zluda_ptx_impl.bc + +#include +#include + +#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_ ## NAME + +extern "C" { + uint32_t FUNC(activemask)() { + return __builtin_amdgcn_read_exec_lo(); + } + + size_t __ockl_get_local_size(uint32_t) __device__; + uint32_t FUNC(sreg_ntid)(uint8_t member) { + return (uint32_t)__ockl_get_local_size(member); + } +} diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..6e0beab 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -94,7 +94,7 @@ fn run_method<'input>( .body .map(|statements| { for statement in statements { - run_statement(&remap_returns, &mut body, statement)?; + run_statement(resolver, &remap_returns, &mut body, statement)?; } Ok::<_, TranslateError>(body) }) @@ -110,6 +110,7 @@ fn run_method<'input>( } fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, result: &mut Vec, SpirvWord>>, statement: Statement, SpirvWord>, @@ -133,6 +134,66 @@ fn run_statement<'input>( } result.push(statement); } + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + let mut post_st = Vec::new(); + for ((type_, space), ident) in data + .input_arguments + .iter_mut() + .zip(arguments.input_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: *ident, + src: old_name, + }, + })); + } + } + for ((type_, space), ident) in data + .return_arguments + .iter_mut() + .zip(arguments.return_arguments.iter_mut()) + { + if *space == ptx_parser::StateSpace::Param { + *space = ptx_parser::StateSpace::Reg; + let old_name = *ident; + *ident = resolver + .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg))); + post_st.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: *ident, + }, + })); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + result.extend(post_st.into_iter()); + } statement => { result.push(statement); } diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 97f6356..3553139 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>( sreg_to_function, result: Vec::new(), }; - directives - .into_iter() - .map(|directive| run_directive(&mut visitor, directive)) - .collect::, _>>() + for directive in directives.into_iter() { + result.push(run_directive(&mut visitor, directive)?); + } + Ok(result) } fn run_directive<'a, 'input>( diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 753172a..718c052 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -5,7 +5,7 @@ pub(super) fn run<'input>( ) -> Result, SpirvWord>>, TranslateError> { let mut result = Vec::with_capacity(directives.len()); for mut directive in directives.into_iter() { - run_directive(&mut result, &mut directive); + run_directive(&mut result, &mut directive)?; result.push(directive); } Ok(result) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0e233ed..7ba9ed0 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -39,9 +39,8 @@ mod normalize_predicates; mod normalize_predicates2; mod resolve_function_pointers; -static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); -static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); -const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result { let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); @@ -220,6 +219,12 @@ pub struct Module { pub kernel_info: HashMap, } +impl Module { + pub fn linked_bitcode(&self) -> &[u8] { + ZLUDA_PTX_IMPL + } +} + struct GlobalStringIdResolver<'input> { current_id: SpirvWord, variables: HashMap, SpirvWord>, @@ -1975,7 +1980,7 @@ impl SpecialRegistersMap2 { let name = ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); let return_type = sreg.get_function_return_type(); - let input_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_input_type(); ( sreg, ast::MethodDeclaration { @@ -1988,14 +1993,17 @@ impl SpecialRegistersMap2 { array_init: Vec::new(), }], name: name, - input_arguments: vec![ast::Variable { - align: None, - v_type: input_type.into(), - state_space: ast::StateSpace::Reg, - name: resolver - .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), - array_init: Vec::new(), - }], + input_arguments: input_type + .into_iter() + .map(|type_| ast::Variable { + align: None, + v_type: type_.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }) + .collect::>(), shared_mem: None, }, ) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e15d6ea..60f5052 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -326,6 +326,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, &*module.llvm_ir, + module.linked_bitcode(), ) .unwrap(); let mut module = ptr::null_mut();