From 4b4f33e29ee9e64852273f9ed7aa3ae9c2e06bd5 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 1 Mar 2024 00:41:23 +0100 Subject: [PATCH] Implement fn pointers in global initializers --- ptx/src/ast.rs | 11 ++- ptx/src/emit.rs | 26 +++++- ptx/src/ptx.lalrpop | 4 +- ptx/src/test/spirv_run/call_global_ptr.ll | 71 +++++++++++++++++ ptx/src/test/spirv_run/call_global_ptr.ptx | 43 ++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 92 +++++++++++++++++++++- 7 files changed, 240 insertions(+), 8 deletions(-) create mode 100644 ptx/src/test/spirv_run/call_global_ptr.ll create mode 100644 ptx/src/test/spirv_run/call_global_ptr.ptx diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index df6c75e..994cf7c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1383,12 +1383,19 @@ pub enum TextureGeometry { #[derive(Clone)] pub enum Initializer { Constant(ImmediateValue), - Global(ID, Type), - GenericGlobal(ID, Type), + Global(ID, InitializerType), + GenericGlobal(ID, InitializerType), Add(Box<(Initializer, Initializer)>), Array(Vec>), } +#[derive(Clone)] +pub enum InitializerType { + Unknown, + Value(Type), + Function(Vec, Vec), +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index ddfddae..dbbed6e 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -406,7 +406,7 @@ unsafe fn get_llvm_const( let name = ctx.names.value(id)?; let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?; let mut zero = LLVMConstInt(b64, 0, 0); - let src_type = get_llvm_type(ctx, &type_)?; + let src_type = get_initializer_llvm_type(ctx, type_)?; let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1); LLVMConstPtrToInt(global_ptr, b64) } @@ -414,7 +414,7 @@ unsafe fn get_llvm_const( let name = ctx.names.value(id)?; let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?; let mut zero = LLVMConstInt(b64, 0, 0); - let src_type = get_llvm_type(ctx, &type_)?; + let src_type = get_initializer_llvm_type(ctx, type_)?; let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1); // void pointers are illegal in LLVM IR let b8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?; @@ -430,6 +430,28 @@ unsafe fn get_llvm_const( Ok(const_value) } +fn get_initializer_llvm_type( + ctx: &mut EmitContext, + type_: ast::InitializerType, +) -> Result { + Ok(match type_ { + ast::InitializerType::Unknown => return Err(TranslateError::unreachable()), + ast::InitializerType::Value(type_) => get_llvm_type(ctx, &type_)?, + ast::InitializerType::Function(return_args, input_args) => { + let return_type = match &*return_args { + [] => llvm::void_type(&ctx.context), + [type_] => get_llvm_type(ctx, type_)?, + [..] => get_llvm_type_struct(ctx, return_args.into_iter().map(Cow::Owned))?, + }; + get_llvm_function_type( + ctx, + return_type, + input_args.iter().map(|type_| (type_, ast::StateSpace::Reg)), + )? + } + }) +} + unsafe fn get_llvm_const_scalar( ctx: &mut EmitContext, scalar_type: ast::ScalarType, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 08fe495..5345066 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -652,8 +652,8 @@ Initializer: ast::Initializer<&'input str> = { InitializerNoAdd: ast::Initializer<&'input str> = { => ast::Initializer::Constant(val), - => ast::Initializer::Global(id, ast::Type::Struct(Vec::new())), - "generic" "(" ")" => ast::Initializer::GenericGlobal(id, ast::Type::Struct(Vec::new())), + => ast::Initializer::Global(id, ast::InitializerType::Unknown), + "generic" "(" ")" => ast::Initializer::GenericGlobal(id, ast::InitializerType::Unknown), "{" > "}" => ast::Initializer::Array(array_init) } diff --git a/ptx/src/test/spirv_run/call_global_ptr.ll b/ptx/src/test/spirv_run/call_global_ptr.ll new file mode 100644 index 0000000..edd07eb --- /dev/null +++ b/ptx/src/test/spirv_run/call_global_ptr.ll @@ -0,0 +1,71 @@ +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" +target triple = "amdgcn-amd-amdhsa" + +@fn_ptrs = protected addrspace(1) externally_initialized global [2 x i64] [i64 0, i64 ptrtoint (ptr @incr to i64)], align 8 + +define private i64 @incr(i64 %"36") #0 { +"60": + %"21" = alloca i64, align 8, addrspace(5) + %"20" = alloca i64, align 8, addrspace(5) + %"24" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"24", align 1 + %"25" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"25", align 1 + %"51" = alloca i64, align 8, addrspace(5) + %"52" = alloca i64, align 8, addrspace(5) + %"17" = alloca i64, align 8, addrspace(5) + store i64 %"36", ptr addrspace(5) %"21", align 8 + %"37" = load i64, ptr addrspace(5) %"21", align 8 + store i64 %"37", ptr addrspace(5) %"52", align 8 + %"38" = load i64, ptr addrspace(5) %"52", align 8 + store i64 %"38", ptr addrspace(5) %"17", align 8 + %"40" = load i64, ptr addrspace(5) %"17", align 8 + %"39" = add i64 %"40", 1 + store i64 %"39", ptr addrspace(5) %"17", align 8 + %"41" = load i64, ptr addrspace(5) %"17", align 8 + store i64 %"41", ptr addrspace(5) %"51", align 8 + %"42" = load i64, ptr addrspace(5) %"51", align 8 + store i64 %"42", ptr addrspace(5) %"20", align 8 + %"43" = load i64, ptr addrspace(5) %"20", align 8 + ret i64 %"43" +} + +define protected amdgpu_kernel void @call_global_ptr(ptr addrspace(4) byref(i64) %"47", ptr addrspace(4) byref(i64) %"48") #0 { +"59": + %"22" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"22", align 1 + %"23" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"23", align 1 + %"8" = alloca i64, align 8, addrspace(5) + %"9" = alloca i64, align 8, addrspace(5) + %"10" = alloca i64, align 8, addrspace(5) + %"11" = alloca i64, align 8, addrspace(5) + %"49" = alloca i64, align 8, addrspace(5) + %"50" = alloca i64, align 8, addrspace(5) + %"26" = load i64, ptr addrspace(4) %"47", align 8 + store i64 %"26", ptr addrspace(5) %"8", align 8 + %"27" = load i64, ptr addrspace(4) %"48", align 8 + store i64 %"27", ptr addrspace(5) %"9", align 8 + %"29" = load i64, ptr addrspace(5) %"8", align 8 + %"53" = inttoptr i64 %"29" to ptr addrspace(1) + %"28" = load i64, ptr addrspace(1) %"53", align 8 + store i64 %"28", ptr addrspace(5) %"10", align 8 + %"30" = load i64, ptr addrspace(5) %"10", align 8 + store i64 %"30", ptr addrspace(5) %"49", align 8 + %"31" = load i64, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(1) @fn_ptrs to ptr), i64 8), align 8 + store i64 %"31", ptr addrspace(5) %"11", align 8 + %"18" = load i64, ptr addrspace(5) %"49", align 8 + %"32" = load i64, ptr addrspace(5) %"11", align 8 + %0 = inttoptr i64 %"32" to ptr + %"19" = call i64 %0(i64 %"18") + store i64 %"19", ptr addrspace(5) %"50", align 8 + %"33" = load i64, ptr addrspace(5) %"50", align 8 + store i64 %"33", ptr addrspace(5) %"10", align 8 + %"34" = load i64, ptr addrspace(5) %"9", align 8 + %"35" = load i64, ptr addrspace(5) %"10", align 8 + %"58" = inttoptr i64 %"34" to ptr addrspace(1) + store i64 %"35", ptr addrspace(1) %"58", align 8 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/call_global_ptr.ptx b/ptx/src/test/spirv_run/call_global_ptr.ptx new file mode 100644 index 0000000..59b1d26 --- /dev/null +++ b/ptx/src/test/spirv_run/call_global_ptr.ptx @@ -0,0 +1,43 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.weak .func (.param.u64 output) incr (.param.u64 input); + +.weak .global .align 8 .u64 fn_ptrs[2] = {0, incr}; + +.visible .entry call_global_ptr( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 fn_ptr; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp, [in_addr]; + .param.u64 incr_in; + .param.u64 incr_out; + st.param.b64 [incr_in], temp; +incr_fn_ptr: .callprototype (.param .u64 _) _ (.param .u64 _); + ld.u64 fn_ptr, [fn_ptrs+8]; + call (incr_out), fn_ptr, (incr_in), incr_fn_ptr; + ld.param.u64 temp, [incr_out]; + st.global.u64 [out_addr], temp; + ret; +} + +.weak .func (.param .u64 output) incr( + .param .u64 input +) +{ + .reg .u64 temp; + ld.param.u64 temp, [input]; + add.u64 temp, temp, 1; + st.param.u64 [output], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 36d82d2..1ad0cb2 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -127,6 +127,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]); test_ptx!(block, [1u64], [2u64]); test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); +test_ptx!(call_global_ptr, [12u64], [13u64]); // In certain situations LLVM will miscompile AMDGPU binaries. // This happens if the return type of a function is a .b8 array. // This test checks if our workaround for this bug works diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index b2e3e9a..a907e5e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1483,8 +1483,11 @@ fn resolve_initializers<'input>( ast::Initializer::Constant(_) => {} ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => { - let (src_type, _, _, _) = id_defs.get_typed(*name)?; - *type_ = src_type; + *type_ = if let Some((src_type, _, _, _)) = id_defs.try_get_typed(*name)? { + ast::InitializerType::Value(src_type) + } else { + ast::InitializerType::Unknown + }; } ast::Initializer::Add(subinit) => { resolve_initializer_impl(id_defs, &mut (*subinit).0)?; @@ -3367,6 +3370,7 @@ fn to_llvm_module_impl2<'a, 'input>( // raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions translation_module = raytracing::postprocess(translation_module, raytracing_state)?; } + let translation_module = resolve_type_of_global_fnptrs(translation_module)?; let translation_module = insert_implicit_conversions(translation_module)?; let translation_module = insert_compilation_mode_prologue(translation_module); let translation_module = normalize_labels(translation_module)?; @@ -3398,6 +3402,76 @@ fn to_llvm_module_impl2<'a, 'input>( }) } +fn resolve_type_of_global_fnptrs( + mut translation_module: TranslationModule, +) -> Result, TranslateError> { + let mut functions: FxHashMap, Vec)> = FxHashMap::default(); + for directive in translation_module.directives.iter_mut() { + match directive { + TranslationDirective::Variable(_, _, variable) => { + if let Some(ref mut initializer) = variable.initializer { + set_iniitalizer_type(&mut functions, initializer); + } + } + TranslationDirective::Method(method) => { + if method.is_kernel { + continue; + } + match functions.entry(method.name) { + hash_map::Entry::Occupied(_) => {} + hash_map::Entry::Vacant(entry) => { + entry.insert(( + extract_argument_types(&method.return_arguments)?, + extract_argument_types(&method.input_arguments)?, + )); + } + } + } + } + } + Ok(translation_module) +} + +fn extract_argument_types( + args: &[ast::VariableDeclaration], +) -> Result, TranslateError> { + args.iter() + .map(|var| { + if var.state_space != ast::StateSpace::Reg { + return Err(TranslateError::unreachable()); + } + Ok(var.type_.clone()) + }) + .collect() +} + +fn set_iniitalizer_type( + functions: &mut FxHashMap, Vec)>, + initializer: &mut ast::Initializer, +) { + match initializer { + ast::Initializer::Constant(_) => {} + ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => { + if let Some((return_arguments, input_arguments)) = functions.get(name) { + *type_ = ast::InitializerType::Function( + return_arguments.clone(), + input_arguments.clone(), + ); + } + } + ast::Initializer::Add(add) => { + let (add1, add2) = &mut **add; + set_iniitalizer_type(functions, add1); + set_iniitalizer_type(functions, add2); + } + ast::Initializer::Array(array) => { + for initializer in array.iter_mut() { + set_iniitalizer_type(functions, initializer); + } + } + } +} + // In PTX it's legal to have a function like this: // .func noreturn(.param .b64 noreturn_0) // .noreturn @@ -5207,6 +5281,20 @@ impl<'input> IdNameMapBuilder<'input> { } } + pub(crate) fn try_get_typed( + &self, + id: Id, + ) -> Result, bool)>, TranslateError> { + match self.type_check.get(&id) { + Some(Some(x)) => Ok(Some(x.clone())), + Some(None) => Ok(None), + None => match self.globals.special_registers.get(id) { + Some(x) => Ok(Some((x.get_type(), ast::StateSpace::Sreg, None, true))), + None => Err(TranslateError::untyped_symbol()), + }, + } + } + pub(crate) fn get_typed( &self, id: Id,