From 9ca1c2da5a1fcbcaab059ee190b74d90e6575007 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 5 Dec 2024 05:43:20 +0100 Subject: [PATCH] Resolve crashes --- ptx/lib/zluda_ptx_impl.bc | Bin 5360 -> 7524 bytes ptx/lib/zluda_ptx_impl.cpp | 11 ++++++ ptx/src/pass/insert_explicit_load_store.rs | 42 +++++++++++++++++++++ ptx/src/pass/mod.rs | 4 +- ptx/src/pass/replace_known_functions.rs | 38 +++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 6 +-- zluda/src/impl/memory.rs | 4 ++ zluda/src/impl/mod.rs | 1 + zluda/src/lib.rs | 1 + 9 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 ptx/src/pass/replace_known_functions.rs diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 24c20d8b1bd94be0ef590b36498b2ddb325a1a31..6cefc813eb60ce7f302e265090ab069682291723 100644 GIT binary patch delta 4634 zcmbuBe^is#8OQJYgYXix_2*?{m21;q9l?IXeCP9lfZG&QuuH7L)g2jSHpgpeb z%o`FADRpts4Xy3KqGi^e;^6c!k2_67UFw*lt@h~J?VzHwTdmZh>yMpv&l7lU?b-g> z?n(08`@Nt0+~+>`=6&x(&iLYBCjC~-uB?qz`mSB+-XrfcId_sCjWCzY6B}5kL!A*w zY+$=h>l~fF?sB?Z9bog_mWUV-tdY*pNFuw;^d^_3r*>Nn>+MD%|A79aVf`s*drpis z(v=~s)-(kzPl#g0F+IUjW*{VwA|xd^_6?L+StE17f}no_TTH)KCFc-cppq>7xlNo1V8EEFG*VPUnlZ6o%|ICe`8ejd64&h1?U`-(rP57 zF2=%3%D7YIQ*c&w$db!ssm5K~M@OUxv*TT9cA zk&dPGw#) z}2|JcejxcB~m|;g0@1~vxu!Iz!VMo3PZOyY&@v_RN$`#`AZ?y zw{;?XI>FaT+8Y8tL9`!_@qZmv!Qq&cc>*~r?qsA9rZ_NcK4eP?`Ij7JiaUqRqv$`S z)oUy6@)vgcmkfsCXx!h)$gv^Sm9H=)YArLSCrfed1+C1jZ=*~U6UQAqOnC#&U5Kx9 z@)Nh`D;)P11#CsV!{&bK+d2sp9kr#5`j-TlB_Z45q2a(9oox8_*~Ln zF!6sMRZWQ6>yq}eiJ!y?R2LoED^C70oJN!O<8C0wsD9(A4uPHN75;=VY30H2IVJbO z0wI+IX|e+TO6l;I(lt4)Moo#!Zs0R_3)wogTU4(&CyK9mNU*w-m9qamA0Yg>Kh-i`iVLW>a7ROx4&rS8~2 z_=q3c0&0{qHWb{b&yC?^cnD0D5HB0oYiYUWj>v8noMu_V5K07UDzbx{O=@yLg>5X+ zoLzJ3h&>+WpDLKNgSWP1iIAc1Chco~88I!IhDbDN)!ggES$|u?$llnso}wo1i0std zPiQ#~Qyndsvl$O2tH|ox3Bui6>3#Vi&5<|-izaR9Q^~}xQ>5s2zS5Q}LV}c1L^2CF z{kFRZiO^Iu9r<&*49SW1V0&EF11DLVdM?Rqr$~ZEvd)bTIt=IiZK&PD$*IRLq*m*x z;#QkEFg%dTy-ype%^PI?GCUmMg%RSfn|OGd!&9CNTvvxFZzkkd$F9i5LzoBOs+F?i z$7R|3^R?n?R^PVMo~);vS3yI4Lh6wtoi&cGjXZ6q- z_s~)wru{P$A<1Af9*6|XcyWG>>xI2dvWG8PSUGnjZC=x{mcz ziHEH4$`uYEdcXwevK&o{SMWvIGGraddVB+_6b&M3 zdKJ-E(d`S_h-M?ILsUR?LPoDLqLLTU*D%<3)4jlk(~%vIMYr)7?U!L-yO4EkLv|aw zU6YOri$-u3Q3s-)n}|+Y0Chfx=-U`}Sd8{Evg63g6awM9^^(?V7(;ms<<@W-$iBs#q&;T%2<1K-$ib)#KR_p+Lpg|Y zmoG}EA)R~%tU^BMF+W_Z*`c;`50NO)RGDaz9f^;N3vTV=Y^_uDfi7r6+%G zoI_Xt-lTo+u=|Vc^7;HgAW)m^eKf=QmgX+wJYrl(2tHBc+u(A^;J=p;mqRwExt1?2bNoO zFMQCc=`^%Tg4IH$Osxu^bUJofl*eRMEhjMhm%DO8IDZVDm=c&jHGs#?=84+{VYd1X zL(FR*X@WDg-Ozpbc&O3V&&|5mvqDJj9?BkC-xeV!5-g{a1SXBk;+PZzNr|R#R&=uE z-w7_v{~#lb&SsX406s%y6sZV9ct}t@AtB@Y$_`lRZykQ zY|?ej6K_bqfZP^UDHeRMAdwu(i8Bh1BTQKbj#hXlUWyca+5924RLKgb%}U} zoP&T|B3wG6frW!UCpTBc!jW)@YSi&I0NXm#Iv{xW1V-Utw^>Z)QK!%jm#ZxR@BrlV zLPmgz#I}yTBCG!q7a9Ad%r~pKaG#vip)TEWP@Zu_z18@JY`CWGa}CKUJ@MzZ9D#9I z+!{j8gH2*Ohj237f;V8_ux^Jh9@Mq)#XC^6g4a2!6J%P_%}hv zbEH^>|LrtCWnur70cWQ>6szdhrX49(arE^xuULiG!MVYJ6szb9K*xX-8};Lb6&Fqq zP^>bb40PNY#VYzn&@t7DMLoKgpr?WEowf?U7utBB6|3+dkq&#(y~YBpLHPhcJgr!M zq#@)eSojNFis|S{GEH5@8tYx2iV9algS)=5qS{qcJ9R;;sdhKk)VcE-8eR2`d7ecT z6`R*TSzA% c6&26aK2_za@HB3%z^h_~5)8)D(W|%r8x@x*H2?qr delta 2442 zcmcIlZ){Ul6hE);^}W}oov&-Rbz2LsY>aM>uq|OMe89IqY;`m`D$%JR9c%$%Xb}Q# zZuGSjS{4Ew9cDoihFJuG=$82b^n)!MWFutiGzJMK1Ex-m2_PS)G3vQp?*KK%FP`S+ z_ulXR&bhyP?`b=={7zvsU-JXoU9g48?Czew?#x)wdk9V%q#c~MK~q*ISpvNaq9t{p zU+Lm?qG70AxpI)60>7NRL#gH3sdMM?99DdCg-W?6^5`|rx*&S-Qd~?{> zC+vh#$NU=(+jWQia=`fDnUQbYv|m@}O*w3n%)E(&ef*^Hy2Ccj*r7cdZg&c5uh1@r zw>O5%)pDWjb)nrE-rh$RD3`+RCBhD;(6)JRn@i|)3h{yo+5h@Qrnzi70s{mfL8r~{8uG6@E`n#s-S-$VSAH! zo3AbfH*<%DG|)}6v7gvImXfPv;<0!RoPrrfW;zEPTVf})nQN`%bYg*E z&6%Sb7^~;&;eKEgm{ls!rzz?8=!}f|@|;T1dR;3sk7_tzfRPE!sQP%cXh}Kehj$o4 zBqKw0WNdTmCE$Ql3By4VRG;~IQi5KdyeES=%kvZ}AmN7?mXgkV+>78XN@)Oyv&4ZR zV7!kwyJtSgsNc`26hB&PdHk9No|&^jtX)}mTq^<#8ihzTK%9N-HGo?n@5w-CMaL33 z&?`{}h_lw)6$<1(;_N>E26*x6QbC-JJ(bDCKR{yrEJ+USg%kA}jXGz89GbSOo8=Ui zW~zF1gb!T&28rioAXdue^;)t`?&v{XLEUApTAv5`2N^IL%5ZKFJLSSqi~d{9&_-i| z@e{q7t1Kpd$tulT0P=XP6!iL<6L`E$GJqC&)bRPW2K^!~p(KI^u@$+0nj8F=xs!jN zE0`12JQ{*mdVojC;|J{&mE&j?Q4RIriP$^!VgxpyV$}j8_81QW z3>z=W{E(qqGAy%%tW({$j=-jDzR8uZotMz9RqvPr2yDt$))Qas^#}` zeI4+eA;xDE`=1triQv+3=_;?GM~Wy?RV9-$LE7KU&3fjD=UPA;znBW+P@jkI!^T<$ z&D5Y7d7t%anON(P9{d;JFDYjR)w!#%I>cyJioPTRK8@a z98Z^6yUtc1?iJ8#lMIlBMa^981E;aEj75rx6tk?pit-H3Ar5QB^^?<;V|LTDF)^NQ zVKYKXlo?=>6|(hpu<-i7i#SD@CX$+8_^)nyL8wNU1Z(l_`*fv%9% zc2l!|N#uUPtwZAd{oYNwu0U6;Bh&7ev-R82n*l4ArAZ4ynzOu1*{bp%@e#RI)sOL` zuuV;y;%9uXs;0w&bFf{tj0lZ)J23&H1ry&#OlNz1r0ZB5ChP$6J$K_E9omS41VL=( zDB5(m^g9BnC(erGg}oSd0TNZVt&0s+2#QLt2335+uA>4Db>@ljEXD z^}CTX<4B=DnV`c$T}YCaphy$?k;}kAk?=%+0XbrbBGrF^99e)O)qjVa4F|3B*h}bq?Ak%aJt_7GoW@ Y;Za0v!&6+WEvID2nfo~tv)bK%0y#uXXaE2J diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 329a810..7af9729 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -4,6 +4,7 @@ #include #include +#include #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME @@ -155,4 +156,14 @@ extern "C" __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup"); __builtin_amdgcn_s_barrier(); } + + void FUNC(__assertfail)(uint64_t message, + uint64_t file, + uint32_t line, + uint64_t function, + uint64_t char_size) + { + (void)char_size; + __assert_fail((const char *)message, (const char *)file, line, (const char *)function); + } } diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 60c4a14..702f733 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -122,6 +122,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction)); } + Statement::PtrAccess(ptr_access) => { + let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); + let statement = statement.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } s => { let new_statement = s.visit_map(visitor)?; result.extend(visitor.pre.drain(..).map(Statement::Instruction)); @@ -259,6 +266,41 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { Ok(ast::Instruction::Ld { data, arguments }) } + fn visit_ptr_access( + &mut self, + ptr_access: PtrAccess, + ) -> Result, TranslateError> { + let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) { + Some(RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + }) => (*old_space, *new_space, *name), + Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access), + }; + if ptr_access.state_space != old_space { + return Err(error_mismatched_type()); + } + // Propagate space changes in dst + let new_dst = self + .resolver + .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space))); + self.variables.insert( + ptr_access.dst, + RemapAction::LDStSpaceChange { + old_space, + new_space, + name: new_dst, + }, + ); + Ok(PtrAccess { + ptr_src: name, + dst: new_dst, + state_space: new_space, + ..ptr_access + }) + } + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { let old_space = match var.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ef131b4..c32cc39 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -22,6 +22,7 @@ mod normalize_identifiers2; mod normalize_predicates2; mod replace_instructions_with_function_calls; mod resolve_function_pointers; +mod replace_known_functions; static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; @@ -42,9 +43,10 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result>, ptx_parser::ParsedOperand>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs new file mode 100644 index 0000000..56bb7e6 --- /dev/null +++ b/ptx/src/pass/replace_known_functions.rs @@ -0,0 +1,38 @@ +use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord}; + +pub(crate) fn run<'input>( + resolver: &GlobalStringIdentResolver2<'input>, + mut directives: Vec>, +) -> Vec> { + for directive in directives.iter_mut() { + match directive { + NormalizedDirective2::Method(func) => { + func.import_as = + replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take()); + } + _ => {} + } + } + directives +} + +fn replace_with_ptx_impl<'input>( + resolver: &GlobalStringIdentResolver2<'input>, + fn_name: &ptx_parser::MethodName<'input, SpirvWord>, + name: Option, +) -> Option { + let known_names = ["__assertfail"]; + match name { + Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)), + Some(name) => Some(name), + None => match fn_name { + ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) { + Some(super::IdentEntry { + name: Some(name), .. + }) => Some(format!("__zluda_ptx_impl_{}", name)), + _ => None, + }, + ptx_parser::MethodName::Kernel(..) => None, + }, + } +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f4b7921..e4171cd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -298,7 +298,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def let mut result = vec![0u8.into(); output.len()]; { let dev = 0; - let mut stream = ptr::null_mut(); + let mut stream = unsafe { mem::zeroed() }; unsafe { hipStreamCreate(&mut stream) }.unwrap(); let mut dev_props = unsafe { mem::zeroed() }; unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); @@ -308,9 +308,9 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def module.linked_bitcode(), ) .unwrap(); - let mut module = ptr::null_mut(); + let mut module = unsafe { mem::zeroed() }; unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap(); - let mut kernel = ptr::null_mut(); + let mut kernel = unsafe { mem::zeroed() }; unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap(); let mut inp_b = ptr::null_mut(); unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::()) }.unwrap(); diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 33d5a4e..18e58e7 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -38,3 +38,7 @@ pub(crate) fn get_address_range_v2( pub(crate) fn set_d32_v2(dst: hipDeviceptr_t, ui: ::core::ffi::c_uint, n: usize) -> hipError_t { unsafe { hipMemsetD32(dst, mem::transmute(ui), n) } } + +pub(crate) fn set_d8_v2(dst: hipDeviceptr_t, value: ::core::ffi::c_uchar, n: usize) -> hipError_t { + unsafe { hipMemsetD8(dst, value, n) } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 766b4a5..282f8d5 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -107,6 +107,7 @@ from_cuda_nop!( *const ::core::ffi::c_char, *mut ::core::ffi::c_void, *mut *mut ::core::ffi::c_void, + u8, i32, u32, usize, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 1f6a7ff..8efbd26 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -73,6 +73,7 @@ cuda_base::cuda_function_declarations!( cuPointerGetAttribute, cuMemGetAddressRange_v2, cuMemsetD32_v2, + cuMemsetD8_v2 ], implemented_in_function <= [ cuLaunchKernel,