Implement fn pointers in global initializers

This commit is contained in:
Andrzej Janik 2024-03-01 00:41:23 +01:00
parent a1c265b7c2
commit 4b4f33e29e
7 changed files with 240 additions and 8 deletions

View file

@ -1383,12 +1383,19 @@ pub enum TextureGeometry {
#[derive(Clone)]
pub enum Initializer<ID> {
Constant(ImmediateValue),
Global(ID, Type),
GenericGlobal(ID, Type),
Global(ID, InitializerType),
GenericGlobal(ID, InitializerType),
Add(Box<(Initializer<ID>, Initializer<ID>)>),
Array(Vec<Initializer<ID>>),
}
/// Type of the symbol referenced from a global-variable initializer.
///
/// A referenced symbol is either an ordinary variable or a function whose
/// address is stored as a function pointer.
#[derive(Clone)]
pub enum InitializerType {
    // Not resolved yet; must be replaced before LLVM emission (emitting an
    // `Unknown` is treated as unreachable).
    Unknown,
    // The referenced symbol is a variable of the given type.
    Value(Type),
    // The referenced symbol is a function: (return types, input types).
    Function(Vec<Type>, Vec<Type>),
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -406,7 +406,7 @@ unsafe fn get_llvm_const(
let name = ctx.names.value(id)?;
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
let mut zero = LLVMConstInt(b64, 0, 0);
let src_type = get_llvm_type(ctx, &type_)?;
let src_type = get_initializer_llvm_type(ctx, type_)?;
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
LLVMConstPtrToInt(global_ptr, b64)
}
@ -414,7 +414,7 @@ unsafe fn get_llvm_const(
let name = ctx.names.value(id)?;
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
let mut zero = LLVMConstInt(b64, 0, 0);
let src_type = get_llvm_type(ctx, &type_)?;
let src_type = get_initializer_llvm_type(ctx, type_)?;
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
// void pointers are illegal in LLVM IR
let b8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?;
@ -430,6 +430,28 @@ unsafe fn get_llvm_const(
Ok(const_value)
}
/// Maps the resolved type of an initializer symbol to an LLVM type.
///
/// `Unknown` must have been resolved away by an earlier pass; reaching it
/// here is a translation bug.
fn get_initializer_llvm_type(
    ctx: &mut EmitContext,
    type_: ast::InitializerType,
) -> Result<LLVMTypeRef, TranslateError> {
    Ok(match type_ {
        ast::InitializerType::Unknown => return Err(TranslateError::unreachable()),
        ast::InitializerType::Value(value_type) => get_llvm_type(ctx, &value_type)?,
        ast::InitializerType::Function(return_args, input_args) => {
            // Zero return values -> void, one -> that type, several -> an
            // anonymous struct of all of them.
            let return_type = if return_args.is_empty() {
                llvm::void_type(&ctx.context)
            } else if return_args.len() == 1 {
                get_llvm_type(ctx, &return_args[0])?
            } else {
                get_llvm_type_struct(ctx, return_args.into_iter().map(Cow::Owned))?
            };
            // Every input argument is paired with the .reg state space.
            get_llvm_function_type(
                ctx,
                return_type,
                input_args.iter().map(|arg_type| (arg_type, ast::StateSpace::Reg)),
            )?
        }
    })
}
unsafe fn get_llvm_const_scalar(
ctx: &mut EmitContext,
scalar_type: ast::ScalarType,

View file

@ -652,8 +652,8 @@ Initializer: ast::Initializer<&'input str> = {
InitializerNoAdd: ast::Initializer<&'input str> = {
    <val:ImmediateValue> => ast::Initializer::Constant(val),
    // The type of a referenced global is not known during parsing; it is
    // resolved by a later pass, so start as `Unknown`.
    <id:ExtendedID> => ast::Initializer::Global(id, ast::InitializerType::Unknown),
    "generic" "(" <id:ExtendedID> ")" => ast::Initializer::GenericGlobal(id, ast::InitializerType::Unknown),
    "{" <array_init:Comma<Initializer>> "}" => ast::Initializer::Array(array_init)
}

View file

@ -0,0 +1,71 @@
; Expected LLVM IR for the call_global_ptr PTX test: a global i64 array whose
; second element is initialized with the address of @incr.
; NOTE(review): if this fixture is compared textually against compiler output,
; added comments would need to be stripped first — confirm how it is consumed.
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"
; fn_ptrs[0] = 0, fn_ptrs[1] = ptrtoint of @incr
@fn_ptrs = protected addrspace(1) externally_initialized global [2 x i64] [i64 0, i64 ptrtoint (ptr @incr to i64)], align 8
; incr: loads its i64 argument, adds 1, and returns the result. Values are
; shuttled through addrspace(5) allocas (private/scratch memory).
define private i64 @incr(i64 %"36") #0 {
"60":
%"21" = alloca i64, align 8, addrspace(5)
%"20" = alloca i64, align 8, addrspace(5)
%"24" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"24", align 1
%"25" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"25", align 1
%"51" = alloca i64, align 8, addrspace(5)
%"52" = alloca i64, align 8, addrspace(5)
%"17" = alloca i64, align 8, addrspace(5)
store i64 %"36", ptr addrspace(5) %"21", align 8
%"37" = load i64, ptr addrspace(5) %"21", align 8
store i64 %"37", ptr addrspace(5) %"52", align 8
%"38" = load i64, ptr addrspace(5) %"52", align 8
store i64 %"38", ptr addrspace(5) %"17", align 8
%"40" = load i64, ptr addrspace(5) %"17", align 8
; the actual increment
%"39" = add i64 %"40", 1
store i64 %"39", ptr addrspace(5) %"17", align 8
%"41" = load i64, ptr addrspace(5) %"17", align 8
store i64 %"41", ptr addrspace(5) %"51", align 8
%"42" = load i64, ptr addrspace(5) %"51", align 8
store i64 %"42", ptr addrspace(5) %"20", align 8
%"43" = load i64, ptr addrspace(5) %"20", align 8
ret i64 %"43"
}
; Kernel: reads *input, calls incr through the function pointer stored in
; fn_ptrs[1], and writes the result to *output.
define protected amdgpu_kernel void @call_global_ptr(ptr addrspace(4) byref(i64) %"47", ptr addrspace(4) byref(i64) %"48") #0 {
"59":
%"22" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"22", align 1
%"23" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"23", align 1
%"8" = alloca i64, align 8, addrspace(5)
%"9" = alloca i64, align 8, addrspace(5)
%"10" = alloca i64, align 8, addrspace(5)
%"11" = alloca i64, align 8, addrspace(5)
%"49" = alloca i64, align 8, addrspace(5)
%"50" = alloca i64, align 8, addrspace(5)
%"26" = load i64, ptr addrspace(4) %"47", align 8
store i64 %"26", ptr addrspace(5) %"8", align 8
%"27" = load i64, ptr addrspace(4) %"48", align 8
store i64 %"27", ptr addrspace(5) %"9", align 8
%"29" = load i64, ptr addrspace(5) %"8", align 8
%"53" = inttoptr i64 %"29" to ptr addrspace(1)
%"28" = load i64, ptr addrspace(1) %"53", align 8
store i64 %"28", ptr addrspace(5) %"10", align 8
%"30" = load i64, ptr addrspace(5) %"10", align 8
store i64 %"30", ptr addrspace(5) %"49", align 8
; load the function pointer from fn_ptrs[1] (byte offset 8), going through a
; generic-address-space cast of the global
%"31" = load i64, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(1) @fn_ptrs to ptr), i64 8), align 8
store i64 %"31", ptr addrspace(5) %"11", align 8
%"18" = load i64, ptr addrspace(5) %"49", align 8
%"32" = load i64, ptr addrspace(5) %"11", align 8
; indirect call through the loaded pointer
%0 = inttoptr i64 %"32" to ptr
%"19" = call i64 %0(i64 %"18")
store i64 %"19", ptr addrspace(5) %"50", align 8
%"33" = load i64, ptr addrspace(5) %"50", align 8
store i64 %"33", ptr addrspace(5) %"10", align 8
%"34" = load i64, ptr addrspace(5) %"9", align 8
%"35" = load i64, ptr addrspace(5) %"10", align 8
%"58" = inttoptr i64 %"34" to ptr addrspace(1)
store i64 %"35", ptr addrspace(1) %"58", align 8
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -0,0 +1,43 @@
// PTX test input: a function pointer stored in a global initializer and
// called indirectly.
.version 6.5
.target sm_30
.address_size 64
// Forward declaration so `incr` can be referenced in the initializer below.
.weak .func (.param.u64 output) incr (.param.u64 input);
// fn_ptrs[1] holds the address of incr.
.weak .global .align 8 .u64 fn_ptrs[2] = {0, incr};
// Kernel: *output = incr(*input), calling incr through the pointer stored
// in fn_ptrs[1].
.visible .entry call_global_ptr(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 fn_ptr;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.global.u64 temp, [in_addr];
.param.u64 incr_in;
.param.u64 incr_out;
st.param.b64 [incr_in], temp;
// Prototype for the indirect call (one .u64 in, one .u64 out).
incr_fn_ptr: .callprototype (.param .u64 _) _ (.param .u64 _);
// Load the function pointer from the global array: offset 8 = element 1.
ld.u64 fn_ptr, [fn_ptrs+8];
call (incr_out), fn_ptr, (incr_in), incr_fn_ptr;
ld.param.u64 temp, [incr_out];
st.global.u64 [out_addr], temp;
ret;
}
// incr: returns input + 1; definition matching the forward declaration above.
.weak .func (.param .u64 output) incr(
.param .u64 input
)
{
.reg .u64 temp;
ld.param.u64 temp, [input];
add.u64 temp, temp, 1;
st.param.u64 [output], temp;
ret;
}

View file

@ -127,6 +127,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
// Calls a function through a pointer stored in a global initializer
// (call_global_ptr.ptx): incr(12) == 13.
test_ptx!(call_global_ptr, [12u64], [13u64]);
// In certain situations LLVM will miscompile AMDGPU binaries.
// This happens if the return type of a function is a .b8 array.
// This test checks if our workaround for this bug works

View file

@ -1483,8 +1483,11 @@ fn resolve_initializers<'input>(
ast::Initializer::Constant(_) => {}
ast::Initializer::Global(name, type_)
| ast::Initializer::GenericGlobal(name, type_) => {
let (src_type, _, _, _) = id_defs.get_typed(*name)?;
*type_ = src_type;
*type_ = if let Some((src_type, _, _, _)) = id_defs.try_get_typed(*name)? {
ast::InitializerType::Value(src_type)
} else {
ast::InitializerType::Unknown
};
}
ast::Initializer::Add(subinit) => {
resolve_initializer_impl(id_defs, &mut (*subinit).0)?;
@ -3367,6 +3370,7 @@ fn to_llvm_module_impl2<'a, 'input>(
// raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions
translation_module = raytracing::postprocess(translation_module, raytracing_state)?;
}
let translation_module = resolve_type_of_global_fnptrs(translation_module)?;
let translation_module = insert_implicit_conversions(translation_module)?;
let translation_module = insert_compilation_mode_prologue(translation_module);
let translation_module = normalize_labels(translation_module)?;
@ -3398,6 +3402,76 @@ fn to_llvm_module_impl2<'a, 'input>(
})
}
/// Resolves the `InitializerType` of global initializers that reference
/// functions, turning them into `InitializerType::Function`.
///
/// Runs in two passes so the result does not depend on the order of
/// directives in the module: first every non-kernel method signature is
/// collected, then all variable initializers are patched. (The previous
/// single pass missed initializers that appeared before the method they
/// reference.)
fn resolve_type_of_global_fnptrs(
    mut translation_module: TranslationModule<ExpandedArgParams>,
) -> Result<TranslationModule<ExpandedArgParams>, TranslateError> {
    let mut functions: FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)> = FxHashMap::default();
    // Pass 1: record (return types, input types) of every non-kernel method.
    for directive in translation_module.directives.iter() {
        if let TranslationDirective::Method(method) = directive {
            if method.is_kernel {
                continue;
            }
            // Keep the first signature seen for a given name (declaration and
            // definition share it).
            if let hash_map::Entry::Vacant(entry) = functions.entry(method.name) {
                entry.insert((
                    extract_argument_types(&method.return_arguments)?,
                    extract_argument_types(&method.input_arguments)?,
                ));
            }
        }
    }
    // Pass 2: rewrite initializers of module-scope variables.
    for directive in translation_module.directives.iter_mut() {
        if let TranslationDirective::Variable(_, _, variable) = directive {
            if let Some(ref mut initializer) = variable.initializer {
                set_iniitalizer_type(&mut functions, initializer);
            }
        }
    }
    Ok(translation_module)
}
/// Collects the types of a method's formal arguments.
///
/// Returns an `unreachable` error if any argument lives outside the .reg
/// state space — at this stage all formal arguments are expected to be
/// registers.
fn extract_argument_types(
    args: &[ast::VariableDeclaration<Id>],
) -> Result<Vec<ast::Type>, TranslateError> {
    let mut types = Vec::with_capacity(args.len());
    for arg in args {
        if arg.state_space != ast::StateSpace::Reg {
            return Err(TranslateError::unreachable());
        }
        types.push(arg.type_.clone());
    }
    Ok(types)
}
/// Recursively walks an initializer tree and upgrades every reference to a
/// known function into `InitializerType::Function`. References to
/// non-function symbols are left untouched.
// NOTE(review): the name has a typo ("iniitalizer"); kept as-is because
// existing callers use it.
fn set_iniitalizer_type(
    functions: &mut FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)>,
    initializer: &mut ast::Initializer<Id>,
) {
    match initializer {
        ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => {
            // Only references that resolve to a recorded function signature
            // get their type rewritten.
            if let Some((return_args, input_args)) = functions.get(name) {
                *type_ = ast::InitializerType::Function(return_args.clone(), input_args.clone());
            }
        }
        ast::Initializer::Add(operands) => {
            set_iniitalizer_type(functions, &mut operands.0);
            set_iniitalizer_type(functions, &mut operands.1);
        }
        ast::Initializer::Array(elements) => {
            elements
                .iter_mut()
                .for_each(|element| set_iniitalizer_type(functions, element));
        }
        ast::Initializer::Constant(_) => {}
    }
}
// In PTX it's legal to have a function like this:
// .func noreturn(.param .b64 noreturn_0)
// .noreturn
@ -5207,6 +5281,20 @@ impl<'input> IdNameMapBuilder<'input> {
}
}
/// Looks up type information for `id`, distinguishing "known but untyped"
/// from "completely unknown".
///
/// Returns `Ok(Some(..))` for typed symbols and special registers,
/// `Ok(None)` for symbols registered without a type, and an
/// `untyped_symbol` error when the id is not known at all.
pub(crate) fn try_get_typed(
    &self,
    id: Id,
) -> Result<Option<(ast::Type, ast::StateSpace, Option<u32>, bool)>, TranslateError> {
    // A present entry is returned as-is: `Some(None)` means the symbol is
    // known but carries no type information.
    if let Some(entry) = self.type_check.get(&id) {
        return Ok(entry.clone());
    }
    // Fall back to special registers, which always report a type.
    match self.globals.special_registers.get(id) {
        Some(sreg) => Ok(Some((sreg.get_type(), ast::StateSpace::Sreg, None, true))),
        None => Err(TranslateError::untyped_symbol()),
    }
}
pub(crate) fn get_typed(
&self,
id: Id,