mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Implement fn pointers in global initializers
This commit is contained in:
parent
a1c265b7c2
commit
4b4f33e29e
7 changed files with 240 additions and 8 deletions
|
@ -1383,12 +1383,19 @@ pub enum TextureGeometry {
|
|||
#[derive(Clone)]
|
||||
pub enum Initializer<ID> {
|
||||
Constant(ImmediateValue),
|
||||
Global(ID, Type),
|
||||
GenericGlobal(ID, Type),
|
||||
Global(ID, InitializerType),
|
||||
GenericGlobal(ID, InitializerType),
|
||||
Add(Box<(Initializer<ID>, Initializer<ID>)>),
|
||||
Array(Vec<Initializer<ID>>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum InitializerType {
|
||||
Unknown,
|
||||
Value(Type),
|
||||
Function(Vec<Type>, Vec<Type>),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -406,7 +406,7 @@ unsafe fn get_llvm_const(
|
|||
let name = ctx.names.value(id)?;
|
||||
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
|
||||
let mut zero = LLVMConstInt(b64, 0, 0);
|
||||
let src_type = get_llvm_type(ctx, &type_)?;
|
||||
let src_type = get_initializer_llvm_type(ctx, type_)?;
|
||||
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
|
||||
LLVMConstPtrToInt(global_ptr, b64)
|
||||
}
|
||||
|
@ -414,7 +414,7 @@ unsafe fn get_llvm_const(
|
|||
let name = ctx.names.value(id)?;
|
||||
let b64 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B64))?;
|
||||
let mut zero = LLVMConstInt(b64, 0, 0);
|
||||
let src_type = get_llvm_type(ctx, &type_)?;
|
||||
let src_type = get_initializer_llvm_type(ctx, type_)?;
|
||||
let global_ptr = LLVMConstInBoundsGEP2(src_type, name, &mut zero, 1);
|
||||
// void pointers are illegal in LLVM IR
|
||||
let b8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::B8))?;
|
||||
|
@ -430,6 +430,28 @@ unsafe fn get_llvm_const(
|
|||
Ok(const_value)
|
||||
}
|
||||
|
||||
fn get_initializer_llvm_type(
|
||||
ctx: &mut EmitContext,
|
||||
type_: ast::InitializerType,
|
||||
) -> Result<LLVMTypeRef, TranslateError> {
|
||||
Ok(match type_ {
|
||||
ast::InitializerType::Unknown => return Err(TranslateError::unreachable()),
|
||||
ast::InitializerType::Value(type_) => get_llvm_type(ctx, &type_)?,
|
||||
ast::InitializerType::Function(return_args, input_args) => {
|
||||
let return_type = match &*return_args {
|
||||
[] => llvm::void_type(&ctx.context),
|
||||
[type_] => get_llvm_type(ctx, type_)?,
|
||||
[..] => get_llvm_type_struct(ctx, return_args.into_iter().map(Cow::Owned))?,
|
||||
};
|
||||
get_llvm_function_type(
|
||||
ctx,
|
||||
return_type,
|
||||
input_args.iter().map(|type_| (type_, ast::StateSpace::Reg)),
|
||||
)?
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
unsafe fn get_llvm_const_scalar(
|
||||
ctx: &mut EmitContext,
|
||||
scalar_type: ast::ScalarType,
|
||||
|
|
|
@ -652,8 +652,8 @@ Initializer: ast::Initializer<&'input str> = {
|
|||
|
||||
InitializerNoAdd: ast::Initializer<&'input str> = {
|
||||
<val:ImmediateValue> => ast::Initializer::Constant(val),
|
||||
<id:ExtendedID> => ast::Initializer::Global(id, ast::Type::Struct(Vec::new())),
|
||||
"generic" "(" <id:ExtendedID> ")" => ast::Initializer::GenericGlobal(id, ast::Type::Struct(Vec::new())),
|
||||
<id:ExtendedID> => ast::Initializer::Global(id, ast::InitializerType::Unknown),
|
||||
"generic" "(" <id:ExtendedID> ")" => ast::Initializer::GenericGlobal(id, ast::InitializerType::Unknown),
|
||||
"{" <array_init:Comma<Initializer>> "}" => ast::Initializer::Array(array_init)
|
||||
}
|
||||
|
||||
|
|
71
ptx/src/test/spirv_run/call_global_ptr.ll
Normal file
71
ptx/src/test/spirv_run/call_global_ptr.ll
Normal file
|
@ -0,0 +1,71 @@
|
|||
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
|
||||
target triple = "amdgcn-amd-amdhsa"
|
||||
|
||||
@fn_ptrs = protected addrspace(1) externally_initialized global [2 x i64] [i64 0, i64 ptrtoint (ptr @incr to i64)], align 8
|
||||
|
||||
define private i64 @incr(i64 %"36") #0 {
|
||||
"60":
|
||||
%"21" = alloca i64, align 8, addrspace(5)
|
||||
%"20" = alloca i64, align 8, addrspace(5)
|
||||
%"24" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"24", align 1
|
||||
%"25" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"25", align 1
|
||||
%"51" = alloca i64, align 8, addrspace(5)
|
||||
%"52" = alloca i64, align 8, addrspace(5)
|
||||
%"17" = alloca i64, align 8, addrspace(5)
|
||||
store i64 %"36", ptr addrspace(5) %"21", align 8
|
||||
%"37" = load i64, ptr addrspace(5) %"21", align 8
|
||||
store i64 %"37", ptr addrspace(5) %"52", align 8
|
||||
%"38" = load i64, ptr addrspace(5) %"52", align 8
|
||||
store i64 %"38", ptr addrspace(5) %"17", align 8
|
||||
%"40" = load i64, ptr addrspace(5) %"17", align 8
|
||||
%"39" = add i64 %"40", 1
|
||||
store i64 %"39", ptr addrspace(5) %"17", align 8
|
||||
%"41" = load i64, ptr addrspace(5) %"17", align 8
|
||||
store i64 %"41", ptr addrspace(5) %"51", align 8
|
||||
%"42" = load i64, ptr addrspace(5) %"51", align 8
|
||||
store i64 %"42", ptr addrspace(5) %"20", align 8
|
||||
%"43" = load i64, ptr addrspace(5) %"20", align 8
|
||||
ret i64 %"43"
|
||||
}
|
||||
|
||||
define protected amdgpu_kernel void @call_global_ptr(ptr addrspace(4) byref(i64) %"47", ptr addrspace(4) byref(i64) %"48") #0 {
|
||||
"59":
|
||||
%"22" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"22", align 1
|
||||
%"23" = alloca i1, align 1, addrspace(5)
|
||||
store i1 false, ptr addrspace(5) %"23", align 1
|
||||
%"8" = alloca i64, align 8, addrspace(5)
|
||||
%"9" = alloca i64, align 8, addrspace(5)
|
||||
%"10" = alloca i64, align 8, addrspace(5)
|
||||
%"11" = alloca i64, align 8, addrspace(5)
|
||||
%"49" = alloca i64, align 8, addrspace(5)
|
||||
%"50" = alloca i64, align 8, addrspace(5)
|
||||
%"26" = load i64, ptr addrspace(4) %"47", align 8
|
||||
store i64 %"26", ptr addrspace(5) %"8", align 8
|
||||
%"27" = load i64, ptr addrspace(4) %"48", align 8
|
||||
store i64 %"27", ptr addrspace(5) %"9", align 8
|
||||
%"29" = load i64, ptr addrspace(5) %"8", align 8
|
||||
%"53" = inttoptr i64 %"29" to ptr addrspace(1)
|
||||
%"28" = load i64, ptr addrspace(1) %"53", align 8
|
||||
store i64 %"28", ptr addrspace(5) %"10", align 8
|
||||
%"30" = load i64, ptr addrspace(5) %"10", align 8
|
||||
store i64 %"30", ptr addrspace(5) %"49", align 8
|
||||
%"31" = load i64, ptr getelementptr inbounds (i8, ptr addrspacecast (ptr addrspace(1) @fn_ptrs to ptr), i64 8), align 8
|
||||
store i64 %"31", ptr addrspace(5) %"11", align 8
|
||||
%"18" = load i64, ptr addrspace(5) %"49", align 8
|
||||
%"32" = load i64, ptr addrspace(5) %"11", align 8
|
||||
%0 = inttoptr i64 %"32" to ptr
|
||||
%"19" = call i64 %0(i64 %"18")
|
||||
store i64 %"19", ptr addrspace(5) %"50", align 8
|
||||
%"33" = load i64, ptr addrspace(5) %"50", align 8
|
||||
store i64 %"33", ptr addrspace(5) %"10", align 8
|
||||
%"34" = load i64, ptr addrspace(5) %"9", align 8
|
||||
%"35" = load i64, ptr addrspace(5) %"10", align 8
|
||||
%"58" = inttoptr i64 %"34" to ptr addrspace(1)
|
||||
store i64 %"35", ptr addrspace(1) %"58", align 8
|
||||
ret void
|
||||
}
|
||||
|
||||
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
|
43
ptx/src/test/spirv_run/call_global_ptr.ptx
Normal file
43
ptx/src/test/spirv_run/call_global_ptr.ptx
Normal file
|
@ -0,0 +1,43 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.weak .func (.param.u64 output) incr (.param.u64 input);
|
||||
|
||||
.weak .global .align 8 .u64 fn_ptrs[2] = {0, incr};
|
||||
|
||||
.visible .entry call_global_ptr(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 fn_ptr;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.global.u64 temp, [in_addr];
|
||||
.param.u64 incr_in;
|
||||
.param.u64 incr_out;
|
||||
st.param.b64 [incr_in], temp;
|
||||
incr_fn_ptr: .callprototype (.param .u64 _) _ (.param .u64 _);
|
||||
ld.u64 fn_ptr, [fn_ptrs+8];
|
||||
call (incr_out), fn_ptr, (incr_in), incr_fn_ptr;
|
||||
ld.param.u64 temp, [incr_out];
|
||||
st.global.u64 [out_addr], temp;
|
||||
ret;
|
||||
}
|
||||
|
||||
.weak .func (.param .u64 output) incr(
|
||||
.param .u64 input
|
||||
)
|
||||
{
|
||||
.reg .u64 temp;
|
||||
ld.param.u64 temp, [input];
|
||||
add.u64 temp, temp, 1;
|
||||
st.param.u64 [output], temp;
|
||||
ret;
|
||||
}
|
|
@ -127,6 +127,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
|
|||
test_ptx!(block, [1u64], [2u64]);
|
||||
test_ptx!(local_align, [1u64], [1u64]);
|
||||
test_ptx!(call, [1u64], [2u64]);
|
||||
test_ptx!(call_global_ptr, [12u64], [13u64]);
|
||||
// In certain situations LLVM will miscompile AMDGPU binaries.
|
||||
// This happens if the return type of a function is a .b8 array.
|
||||
// This test checks if our workaround for this bug works
|
||||
|
|
|
@ -1483,8 +1483,11 @@ fn resolve_initializers<'input>(
|
|||
ast::Initializer::Constant(_) => {}
|
||||
ast::Initializer::Global(name, type_)
|
||||
| ast::Initializer::GenericGlobal(name, type_) => {
|
||||
let (src_type, _, _, _) = id_defs.get_typed(*name)?;
|
||||
*type_ = src_type;
|
||||
*type_ = if let Some((src_type, _, _, _)) = id_defs.try_get_typed(*name)? {
|
||||
ast::InitializerType::Value(src_type)
|
||||
} else {
|
||||
ast::InitializerType::Unknown
|
||||
};
|
||||
}
|
||||
ast::Initializer::Add(subinit) => {
|
||||
resolve_initializer_impl(id_defs, &mut (*subinit).0)?;
|
||||
|
@ -3367,6 +3370,7 @@ fn to_llvm_module_impl2<'a, 'input>(
|
|||
// raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions
|
||||
translation_module = raytracing::postprocess(translation_module, raytracing_state)?;
|
||||
}
|
||||
let translation_module = resolve_type_of_global_fnptrs(translation_module)?;
|
||||
let translation_module = insert_implicit_conversions(translation_module)?;
|
||||
let translation_module = insert_compilation_mode_prologue(translation_module);
|
||||
let translation_module = normalize_labels(translation_module)?;
|
||||
|
@ -3398,6 +3402,76 @@ fn to_llvm_module_impl2<'a, 'input>(
|
|||
})
|
||||
}
|
||||
|
||||
fn resolve_type_of_global_fnptrs(
|
||||
mut translation_module: TranslationModule<ExpandedArgParams>,
|
||||
) -> Result<TranslationModule<ExpandedArgParams>, TranslateError> {
|
||||
let mut functions: FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)> = FxHashMap::default();
|
||||
for directive in translation_module.directives.iter_mut() {
|
||||
match directive {
|
||||
TranslationDirective::Variable(_, _, variable) => {
|
||||
if let Some(ref mut initializer) = variable.initializer {
|
||||
set_iniitalizer_type(&mut functions, initializer);
|
||||
}
|
||||
}
|
||||
TranslationDirective::Method(method) => {
|
||||
if method.is_kernel {
|
||||
continue;
|
||||
}
|
||||
match functions.entry(method.name) {
|
||||
hash_map::Entry::Occupied(_) => {}
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
entry.insert((
|
||||
extract_argument_types(&method.return_arguments)?,
|
||||
extract_argument_types(&method.input_arguments)?,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(translation_module)
|
||||
}
|
||||
|
||||
fn extract_argument_types(
|
||||
args: &[ast::VariableDeclaration<Id>],
|
||||
) -> Result<Vec<ast::Type>, TranslateError> {
|
||||
args.iter()
|
||||
.map(|var| {
|
||||
if var.state_space != ast::StateSpace::Reg {
|
||||
return Err(TranslateError::unreachable());
|
||||
}
|
||||
Ok(var.type_.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn set_iniitalizer_type(
|
||||
functions: &mut FxHashMap<Id, (Vec<ast::Type>, Vec<ast::Type>)>,
|
||||
initializer: &mut ast::Initializer<Id>,
|
||||
) {
|
||||
match initializer {
|
||||
ast::Initializer::Constant(_) => {}
|
||||
ast::Initializer::Global(name, type_) | ast::Initializer::GenericGlobal(name, type_) => {
|
||||
if let Some((return_arguments, input_arguments)) = functions.get(name) {
|
||||
*type_ = ast::InitializerType::Function(
|
||||
return_arguments.clone(),
|
||||
input_arguments.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
ast::Initializer::Add(add) => {
|
||||
let (add1, add2) = &mut **add;
|
||||
set_iniitalizer_type(functions, add1);
|
||||
set_iniitalizer_type(functions, add2);
|
||||
}
|
||||
ast::Initializer::Array(array) => {
|
||||
for initializer in array.iter_mut() {
|
||||
set_iniitalizer_type(functions, initializer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In PTX it's legal to have a function like this:
|
||||
// .func noreturn(.param .b64 noreturn_0)
|
||||
// .noreturn
|
||||
|
@ -5207,6 +5281,20 @@ impl<'input> IdNameMapBuilder<'input> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_get_typed(
|
||||
&self,
|
||||
id: Id,
|
||||
) -> Result<Option<(ast::Type, ast::StateSpace, Option<u32>, bool)>, TranslateError> {
|
||||
match self.type_check.get(&id) {
|
||||
Some(Some(x)) => Ok(Some(x.clone())),
|
||||
Some(None) => Ok(None),
|
||||
None => match self.globals.special_registers.get(id) {
|
||||
Some(x) => Ok(Some((x.get_type(), ast::StateSpace::Sreg, None, true))),
|
||||
None => Err(TranslateError::untyped_symbol()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_typed(
|
||||
&self,
|
||||
id: Id,
|
||||
|
|
Loading…
Add table
Reference in a new issue