From 5b2352723fb251b64317737167b609a0a11651a6 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 17 Sep 2021 16:24:25 +0000 Subject: [PATCH] Implement function pointers and activemask --- ptx/lib/zluda_ptx_impl.bc | Bin 30788 -> 31224 bytes ptx/lib/zluda_ptx_impl.cl | 5 + ptx/src/test/spirv_run/activemask.spvtxt | 20 ++-- ptx/src/test/spirv_run/func_ptr.ptx | 31 ++++++ ptx/src/test/spirv_run/func_ptr.spvtxt | 73 ++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 116 +++++++++++++++++++---- 7 files changed, 215 insertions(+), 31 deletions(-) create mode 100644 ptx/src/test/spirv_run/func_ptr.ptx create mode 100644 ptx/src/test/spirv_run/func_ptr.spvtxt diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6a2a51c8cbfd4bd5c99470835713d6285bfac043..175f4df8c13942df8c53e9608518b38a9e6eab69 100644 GIT binary patch delta 3142 zcmX@|f$_&@#tABnHz%rQGa5{ssKEGP;-&)gEi4Rcl^L2j5(0RX7+ejbuGhv+f6&;V z)L?Wrn4y@ZQIKJ_gL1<+1ttlWW{v~{1_p*e2?mCzn*$jund%RFOZ@-;|Gzy$Hp2l% z1xC>e>@F5tm@Vfp+gx#$zS3-c#@S+qv*ngX%LC5REzGuioUNA}wm8G=!2cwzzQE z`h>Ie9cQZv%(e@d8Q2->AM(Aw!1hIf{~-hahY5V|3ixs_@I7?kf5C8}+2)F~CBqS$ z3(gijhpo0S+MGCSeZkqXhuL;dv(1)9i!05RGn_3?I7^>#wtC@g$xy+Nqrk@)VJ2{b zK}Wiyg<%?ROAy0pw$z^>|1nNr;99}FfZt5w42a9$z!D}nIFBjFcr5ban8wLyteBz3 za<1h8<7O=u9p3taYz77gMg~R&1_lNO#s_Q%CS+~p<>zByFknz%U}0cjP++*jAQ7~5 zyJMvTgA7O!80RtYto(I*fsFzKAD9m!%o*B34D%w28CV#YKztB>z}9ett#MJfFatxR z0s{j(0|SGR+zbUtL7xddoNDgJHm?<|WEOAm>rr5^Hf&<}aD;>Dg0M12gMW`g!@{m$aC}o8BDI~U)3POAp5?e|YAx^4#E+l)0lY5ZEWR`v` zqrl`J`tghclOqg3YcCBwGO{Z-7Y)Q;^6EF!=*aMwo#_ zc7REo;uat^Kfq*!C5XKPOiEaR*cD*%2AH(428m!57ZczHB}WDZhC~Gh1~E_uVR7dA zF6t6u_UAefW&f=D$pIyIjhvBcuIK%MQWSn96YckF- z{52V;zkW@|8HT?m;|#-JlW~S2J?;99I}Cs04#VHL!|*rmF#L@>41ePe!{4~W@Hg%- zq>WwTD0vxFM+vhW>^L}a@;fih$=AKa*e13pNK~Ae{MoCGalzy&Z_oOKZq^nCmQ+S? z5dbOxV455kwkSxPd(v2#o1S68<8r7PFDjsPzYU~o4k)+$G^+{;QBa23zRl+xtY#Q|* z3IBLx-)JyPL?YR{(dd$h#bdullP=MSl zif2C>GnUoM&v^E`G3%QAj%WWGbJ`ScJmYN2+otg28BbHeGX;s~{7psE6fK@hHI*Dw zjCd~FRQ64=;<-js#WJNC&$XMXt|{$!uGdu4rhMbMWmDZYco5y>)5UmAb%(@*xTYS)FRC{r z9>h2GF}A4vka&>MG=XuMnuO$n?54?#2h<`Y6LOlSFe7?_{G<$LF)%QkfVR9Mr5P9m85kI{ptK%XGeRE3owa$UA2VM)sQtEHih%({ zgIbV`Dhv##r5G6G7!Wi}{G}8Fg9rly1B?cVbLcQI{Dq3cXqdR5G^AdI(J*n4^|3q* z43F3tSV3O-z4=4TQ7Nv(-po4*pnSmitf`MV}=19g$rUd3hjsu$-E!YxnIw-Jg$WfTY6k(=u z>i_@$|M?qO0vu!+YBw}S{QLj^e?7wzUarp|3FZxKASrnU1_nhu5VdkAcB~k%3Wxfq{X6@d4X`39Gj9o?>P? z0OEsj0ds4~;2#IZ^h7c!J zy%LfehLd}c!(@{|EF;5Y0mFDkhRFqB@&cGNFan880FwnKAa;N;h{P#w3X&Bt*PCo- zXvoMg`2(0AU@LF@uBc>zq~6t@Ja6tDu31zf_T-5_#=FQV~ngV-Cs3dSD|5>xnRH2!Lk+QYwt@o$676aE{G{~P3b z1b#4bH7aZokZ9s*RJtQz!NlLF(jpkaB-N<4N3fzvwo&7Y;0z|sMy)kMJDPMFb)E>_ zV6te`n27F$GTFvm7py%F8f zoY82rMC=B0PNUr&u^-L(jSf@9C0ME(oz93`wA3}aG)P3WbT+zekf>HB#AitySS7Xir`5SG28}mNM|7hcADp;T((azmebV0$Qov*2+ zK{29TvZ-u?Vnw@5Q^kXN#To7LO;r<=cC>3X)f`Z|(XQK6_d)4LyG2t2;{s&~i3aPY zCdLcO77`71O)ZQKDiIP5o=t6x8&oPJ8hn~M7$2z2kZAC4>SCOrx3&G7*;YzGcX7-PqDCNU|;|V zf0bfjFk)a}@B!6{3=9krV5Jax1_J}b3_gf}x-qS`ZM!I z$T2W5te0Y70MVeTno)&;;j9z`g9rnH28na%FfcrXioznUm3~7zyJVK { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e015062..39bd07e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; -use rspirv::binary::Assemble; +use rspirv::binary::{Assemble, Disassemble}; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc"); @@ -607,6 +607,7 @@ fn emit_directives<'input>( } } emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; + builder.select_block(None)?; builder.end_function()?; if let ( ast::MethodDeclaration { @@ -988,6 +989,7 @@ fn compute_denorm_information<'input>( Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} Statement::RepackVector(_) => {} + Statement::FunctionPointer(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1411,6 +1413,15 @@ fn extract_globals<'input, 'b>( fn_name, )?); } + Statement::Instruction(ast::Instruction::Activemask { arg }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Activemask { arg }, + fn_name, + )?); + } Statement::Instruction(ast::Instruction::Atom( details @ @@ -1596,6 +1607,21 @@ fn convert_to_typed_statements( for s in func { match s { Statement::Instruction(inst) => match inst { + ast::Instruction::Mov( + mov, + ast::Arg2Mov { + dst: ast::Operand::Reg(dst_reg), + src: ast::Operand::Reg(src_reg), + }, + ) if fn_defs.fns.contains_key(&src_reg) => { + if mov.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(TranslateError::MismatchedType); + } + result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + })); + } ast::Instruction::Call(call) => { let resolver = fn_defs.get_fn_sig_resolver(call.func)?; let resolved_call = resolver.resolve_in_spirv_repr(call)?; @@ -1724,7 +1750,7 @@ fn instruction_to_fn_call( let return_arguments_count = arguments .iter() .position(|(desc, _, _)| !desc.is_dst) - .unwrap_or(0); + .unwrap_or(arguments.len()); let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); let fn_id = register_external_fn_call( id_defs, @@ -1826,7 +1852,8 @@ fn normalize_labels( | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } - | Statement::RepackVector(..) => {} + | Statement::RepackVector(..) + | Statement::FunctionPointer(..) => {} } } iter::once(Statement::Label(id_def.register_intermediate(None))) @@ -1984,6 +2011,9 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::RepackVector(repack) => { insert_mem_ssa_statement_default(id_def, &mut result, repack)? } + Statement::FunctionPointer(func_ptr) => { + insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), _ => return Err(error_unreachable()), } @@ -2235,6 +2265,7 @@ fn expand_arguments<'a, 'b>( Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), Statement::Constant(c) => result.push(Statement::Constant(c)), + Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), } } Ok(result) @@ -2421,7 +2452,8 @@ fn insert_implicit_conversions( | s @ Statement::Variable(_) | s @ Statement::LoadVar(..) | s @ Statement::StoreVar(..) - | s @ Statement::RetValue(_, _) => result.push(s), + | s @ Statement::RetValue(..) + | s @ Statement::FunctionPointer(..) => result.push(s), } } Ok(result) @@ -2653,6 +2685,16 @@ fn emit_function_body_ops<'input>( iter::empty(), )?; } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + // TODO: implement properly + let zero = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U64), + &vec_repr(0u64), + )?; + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); + builder.copy_object(result_type, Some(*dst), zero)?; + } Statement::Instruction(inst) => match inst { ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?, ast::Instruction::Call(_) => unreachable!(), @@ -2975,14 +3017,13 @@ fn emit_function_body_ops<'input>( let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } - ast::Instruction::Bfe { .. } => { - // Should have beeen replaced with a funciton call earlier - return Err(error_unreachable()); - } - ast::Instruction::Bfi { .. } => { + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. } => { // Should have beeen replaced with a funciton call earlier return Err(error_unreachable()); } + ast::Instruction::Rem { typ, arg } => { let builder_fn = if typ.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod @@ -3017,18 +3058,6 @@ fn emit_function_body_ops<'input>( )?; builder.bitcast(b32_type, Some(arg.dst), dst_vector)?; } - ast::Instruction::Activemask { arg } => { - let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let vec4_b32_type = - map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4)); - let pred_true = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::Pred), - &[1], - )?; - let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?; - builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?; - } ast::Instruction::Membar { level } => { let (scope, semantics) = match level { ast::MemScope::Cta => ( @@ -5293,6 +5322,44 @@ impl<'b> MutableNumericIdResolver<'b> { } } +struct FunctionPointerDetails { + dst: spirv::Word, + src: spirv::Word, +} + +impl, U: ArgParamsEx> Visitable + for FunctionPointerDetails +{ + fn visit( + self, + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::FunctionPointer(FunctionPointerDetails { + dst: visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + Some(( + &ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + )?, + src: visitor.id( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + None, + )?, + })) + } +} + enum Statement { Label(u32), Variable(ast::Variable), @@ -5307,6 +5374,7 @@ enum Statement { RetValue(ast::RetData, spirv::Word), PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), } impl ExpandedStatement { @@ -5399,6 +5467,12 @@ impl ExpandedStatement { ..repack }) } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + Statement::FunctionPointer(FunctionPointerDetails { + dst: f(dst, true), + src: f(src, false), + }) + } } } }