diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 6a2a51c..175f4df 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl index 9171ef9..aca9327 100644 --- a/ptx/lib/zluda_ptx_impl.cl +++ b/ptx/lib/zluda_ptx_impl.cl @@ -291,6 +291,11 @@ atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_a ulong FUNC(brev_b64)(ulong base) { return __llvm_bitreverse_i64(base); } + + // Taken from __ballot definition in hipamd/include/hip/amd_detail/amd_device_functions.h + uint FUNC(activemask)() { + return (uint)__builtin_amdgcn_uicmp(1, 0, 33); + } #endif void FUNC(__assertfail)( diff --git a/ptx/src/test/spirv_run/activemask.spvtxt b/ptx/src/test/spirv_run/activemask.spvtxt index c4ad55d..0753c95 100644 --- a/ptx/src/test/spirv_run/activemask.spvtxt +++ b/ptx/src/test/spirv_run/activemask.spvtxt @@ -7,21 +7,22 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %16 = OpExtInstImport "OpenCL.std" + %18 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "activemask" OpExecutionMode %1 ContractionOff + OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %19 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 + %21 = OpTypeFunction %uint + %ulong = OpTypeInt 64 0 + %23 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint - %v4uint = OpTypeVector %uint 4 - %bool = OpTypeBool - %true = OpConstantTrue %bool %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %19 + %15 = OpFunction %uint None %21 + OpFunctionEnd + %1 = OpFunction %void None %23 %6 = OpFunctionParameter %ulong %7 = OpFunctionParameter %ulong %14 = OpLabel @@ -33,8 +34,7 @@ OpStore %3 %7 %8 = OpLoad %ulong %3 Aligned 8 OpStore %4 %8 - %26 = OpSubgroupBallotKHR %v4uint %true - %9 = OpCompositeExtract %uint %26 0 + %9 = OpFunctionCall %uint %15 OpStore %5 %9 %10 = OpLoad %ulong %4 %11 = OpLoad %uint %5 diff --git a/ptx/src/test/spirv_run/func_ptr.ptx b/ptx/src/test/spirv_run/func_ptr.ptx new file mode 100644 index 0000000..aa94f2b --- /dev/null +++ b/ptx/src/test/spirv_run/func_ptr.ptx @@ -0,0 +1,31 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .f32 out) foobar(.reg .f32 x, .reg .f32 y) +{ + add.f32 out, x, y; + ret; +} + +.visible .entry func_ptr( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + .reg .u64 f_addr; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + add.u64 temp2, temp, 1; + mov.u64 f_addr, foobar; + add.u64 temp2, temp2, f_addr; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/func_ptr.spvtxt b/ptx/src/test/spirv_run/func_ptr.spvtxt new file mode 100644 index 0000000..adc71eb --- /dev/null +++ b/ptx/src/test/spirv_run/func_ptr.spvtxt @@ -0,0 +1,73 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %38 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %11 "func_ptr" + OpExecutionMode %11 ContractionOff + %void = OpTypeVoid + %float = OpTypeFloat 32 + %41 = OpTypeFunction %float %float %float +%_ptr_Function_float = OpTypePointer Function %float + %ulong = OpTypeInt 64 0 + %44 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %ulong_0 = OpConstant %ulong 0 + %1 = OpFunction %float None %41 + %5 = OpFunctionParameter %float + %6 = OpFunctionParameter %float + %10 = OpLabel + %3 = OpVariable %_ptr_Function_float Function + %4 = OpVariable %_ptr_Function_float Function + %2 = OpVariable %_ptr_Function_float Function + OpStore %3 %5 + OpStore %4 %6 + %8 = OpLoad %float %3 + %9 = OpLoad %float %4 + %7 = OpFAdd %float %8 %9 + OpStore %2 %7 + OpFunctionEnd + %11 = OpFunction %void None %44 + %19 = OpFunctionParameter %ulong + %20 = OpFunctionParameter %ulong + %36 = OpLabel + %12 = OpVariable %_ptr_Function_ulong Function + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function + %15 = OpVariable %_ptr_Function_ulong Function + %16 = OpVariable %_ptr_Function_ulong Function + %17 = OpVariable %_ptr_Function_ulong Function + %18 = OpVariable %_ptr_Function_ulong Function + OpStore %12 %19 + OpStore %13 %20 + %21 = OpLoad %ulong %12 Aligned 8 + OpStore %14 %21 + %22 = OpLoad %ulong %13 Aligned 8 + OpStore %15 %22 + %24 = OpLoad %ulong %14 + %34 = OpConvertUToPtr %_ptr_Generic_ulong %24 + %23 = OpLoad %ulong %34 Aligned 8 + OpStore %16 %23 + %26 = OpLoad %ulong %16 + %25 = OpIAdd %ulong %26 %ulong_1 + OpStore %17 %25 + %27 = OpCopyObject %ulong %ulong_0 + OpStore %18 %27 + %29 = OpLoad %ulong %17 + %30 = OpLoad %ulong %18 + %28 = OpIAdd %ulong %29 %30 + OpStore %17 %28 + %31 = OpLoad %ulong %15 + %32 = OpLoad %ulong %17 + %35 = OpConvertUToPtr %_ptr_Generic_ulong %31 + OpStore %35 %32 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f6b556e..0dcd0bb 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -209,6 +209,7 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]); test_ptx!(membar, [152731u32], [152731u32]); +test_ptx!(func_ptr, [152731u64], [152732u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e015062..39bd07e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; -use rspirv::binary::Assemble; +use rspirv::binary::{Assemble, Disassemble}; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc"); @@ -607,6 +607,7 @@ fn emit_directives<'input>( } } emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; + builder.select_block(None)?; builder.end_function()?; if let ( ast::MethodDeclaration { @@ -988,6 +989,7 @@ fn compute_denorm_information<'input>( Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} Statement::RepackVector(_) => {} + Statement::FunctionPointer(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1411,6 +1413,15 @@ fn extract_globals<'input, 'b>( fn_name, )?); } + Statement::Instruction(ast::Instruction::Activemask { arg }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Activemask { arg }, + fn_name, + )?); + } Statement::Instruction(ast::Instruction::Atom( details @ @@ -1596,6 +1607,21 @@ fn convert_to_typed_statements( for s in func { match s { Statement::Instruction(inst) => match inst { + ast::Instruction::Mov( + mov, + ast::Arg2Mov { + dst: ast::Operand::Reg(dst_reg), + src: ast::Operand::Reg(src_reg), + }, + ) if fn_defs.fns.contains_key(&src_reg) => { + if mov.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(TranslateError::MismatchedType); + } + result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + })); + } ast::Instruction::Call(call) => { let resolver = fn_defs.get_fn_sig_resolver(call.func)?; let resolved_call = resolver.resolve_in_spirv_repr(call)?; @@ -1724,7 +1750,7 @@ fn instruction_to_fn_call( let return_arguments_count = arguments .iter() .position(|(desc, _, _)| !desc.is_dst) - .unwrap_or(0); + .unwrap_or(arguments.len()); let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); let fn_id = register_external_fn_call( id_defs, @@ -1826,7 +1852,8 @@ fn normalize_labels( | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } - | Statement::RepackVector(..) => {} + | Statement::RepackVector(..) + | Statement::FunctionPointer(..) => {} } } iter::once(Statement::Label(id_def.register_intermediate(None))) @@ -1984,6 +2011,9 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::RepackVector(repack) => { insert_mem_ssa_statement_default(id_def, &mut result, repack)? } + Statement::FunctionPointer(func_ptr) => { + insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), _ => return Err(error_unreachable()), } @@ -2235,6 +2265,7 @@ fn expand_arguments<'a, 'b>( Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), Statement::Constant(c) => result.push(Statement::Constant(c)), + Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), } } Ok(result) @@ -2421,7 +2452,8 @@ fn insert_implicit_conversions( | s @ Statement::Variable(_) | s @ Statement::LoadVar(..) | s @ Statement::StoreVar(..) - | s @ Statement::RetValue(_, _) => result.push(s), + | s @ Statement::RetValue(..) + | s @ Statement::FunctionPointer(..) => result.push(s), } } Ok(result) @@ -2653,6 +2685,16 @@ fn emit_function_body_ops<'input>( iter::empty(), )?; } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + // TODO: implement properly + let zero = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U64), + &vec_repr(0u64), + )?; + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); + builder.copy_object(result_type, Some(*dst), zero)?; + } Statement::Instruction(inst) => match inst { ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?, ast::Instruction::Call(_) => unreachable!(), @@ -2975,14 +3017,13 @@ fn emit_function_body_ops<'input>( let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } - ast::Instruction::Bfe { .. } => { - // Should have beeen replaced with a funciton call earlier - return Err(error_unreachable()); - } - ast::Instruction::Bfi { .. } => { + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. } => { // Should have beeen replaced with a funciton call earlier return Err(error_unreachable()); } + ast::Instruction::Rem { typ, arg } => { let builder_fn = if typ.kind() == ast::ScalarKind::Signed { dr::Builder::s_mod @@ -3017,18 +3058,6 @@ fn emit_function_body_ops<'input>( )?; builder.bitcast(b32_type, Some(arg.dst), dst_vector)?; } - ast::Instruction::Activemask { arg } => { - let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let vec4_b32_type = - map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4)); - let pred_true = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::Pred), - &[1], - )?; - let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?; - builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?; - } ast::Instruction::Membar { level } => { let (scope, semantics) = match level { ast::MemScope::Cta => ( @@ -5293,6 +5322,44 @@ impl<'b> MutableNumericIdResolver<'b> { } } +struct FunctionPointerDetails { + dst: spirv::Word, + src: spirv::Word, +} + +impl, U: ArgParamsEx> Visitable + for FunctionPointerDetails +{ + fn visit( + self, + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::FunctionPointer(FunctionPointerDetails { + dst: visitor.id( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + Some(( + &ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + )?, + src: visitor.id( + ArgumentDescriptor { + op: self.src, + is_dst: false, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + None, + )?, + })) + } +} + enum Statement { Label(u32), Variable(ast::Variable), @@ -5307,6 +5374,7 @@ enum Statement { RetValue(ast::RetData, spirv::Word), PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), } impl ExpandedStatement { @@ -5399,6 +5467,12 @@ impl ExpandedStatement { ..repack }) } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + Statement::FunctionPointer(FunctionPointerDetails { + dst: f(dst, true), + src: f(src, false), + }) + } } } }