mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-22 10:19:31 +00:00
Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build
This commit is contained in:
parent
4a71fefb8a
commit
b4de21fbc5
7 changed files with 278 additions and 81 deletions
|
@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]);
|
||||||
test_ptx!(local_align, [1u64], [1u64]);
|
test_ptx!(local_align, [1u64], [1u64]);
|
||||||
test_ptx!(call, [1u64], [2u64]);
|
test_ptx!(call, [1u64], [2u64]);
|
||||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||||
|
test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
|
||||||
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
||||||
test_ptx!(ntid, [3u32], [4u32]);
|
test_ptx!(ntid, [3u32], [4u32]);
|
||||||
test_ptx!(reg_local, [12u64], [13u64]);
|
test_ptx!(reg_local, [12u64], [13u64]);
|
||||||
|
|
|
@ -7,24 +7,27 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
%28 = OpExtInstImport "OpenCL.std"
|
%31 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize
|
OpEntryPoint Kernel %1 "ntid"
|
||||||
OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
|
OpExecutionMode %1 ContractionOff
|
||||||
|
OpDecorate %24 LinkageAttributes "get_local_size" Import
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%v3ulong = OpTypeVector %ulong 3
|
|
||||||
%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
|
|
||||||
%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input
|
|
||||||
%33 = OpTypeFunction %void %ulong %ulong
|
|
||||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
|
||||||
%uint = OpTypeInt 32 0
|
%uint = OpTypeInt 32 0
|
||||||
|
%35 = OpTypeFunction %ulong %uint
|
||||||
|
%36 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||||
%1 = OpFunction %void None %33
|
%uint_0 = OpConstant %uint 0
|
||||||
|
%24 = OpFunction %ulong None %35
|
||||||
|
%26 = OpFunctionParameter %uint
|
||||||
|
OpFunctionEnd
|
||||||
|
%1 = OpFunction %void None %36
|
||||||
%9 = OpFunctionParameter %ulong
|
%9 = OpFunctionParameter %ulong
|
||||||
%10 = OpFunctionParameter %ulong
|
%10 = OpFunctionParameter %ulong
|
||||||
%26 = OpLabel
|
%29 = OpLabel
|
||||||
%2 = OpVariable %_ptr_Function_ulong Function
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
%3 = OpVariable %_ptr_Function_ulong Function
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
%4 = OpVariable %_ptr_Function_ulong Function
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
@ -38,13 +41,12 @@
|
||||||
%12 = OpLoad %ulong %3 Aligned 8
|
%12 = OpLoad %ulong %3 Aligned 8
|
||||||
OpStore %5 %12
|
OpStore %5 %12
|
||||||
%14 = OpLoad %ulong %4
|
%14 = OpLoad %ulong %4
|
||||||
%24 = OpConvertUToPtr %_ptr_Generic_uint %14
|
%27 = OpConvertUToPtr %_ptr_Generic_uint %14
|
||||||
%13 = OpLoad %uint %24 Aligned 4
|
%13 = OpLoad %uint %27 Aligned 4
|
||||||
OpStore %6 %13
|
OpStore %6 %13
|
||||||
%38 = OpLoad %v3ulong %gl_WorkGroupSize
|
%23 = OpFunctionCall %ulong %24 %uint_0
|
||||||
%23 = OpCompositeExtract %ulong %38 0
|
%40 = OpBitcast %ulong %23
|
||||||
%39 = OpBitcast %ulong %23
|
%16 = OpUConvert %uint %40
|
||||||
%16 = OpUConvert %uint %39
|
|
||||||
%15 = OpCopyObject %uint %16
|
%15 = OpCopyObject %uint %16
|
||||||
OpStore %7 %15
|
OpStore %7 %15
|
||||||
%18 = OpLoad %uint %6
|
%18 = OpLoad %uint %6
|
||||||
|
@ -53,7 +55,7 @@
|
||||||
OpStore %6 %17
|
OpStore %6 %17
|
||||||
%20 = OpLoad %ulong %5
|
%20 = OpLoad %ulong %5
|
||||||
%21 = OpLoad %uint %6
|
%21 = OpLoad %uint %6
|
||||||
%25 = OpConvertUToPtr %_ptr_Generic_uint %20
|
%28 = OpConvertUToPtr %_ptr_Generic_uint %20
|
||||||
OpStore %25 %21 Aligned 4
|
OpStore %28 %21 Aligned 4
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
22
ptx/src/test/spirv_run/vector4.ptx
Normal file
22
ptx/src/test/spirv_run/vector4.ptx
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
.version 6.5
|
||||||
|
.target sm_60
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
.visible .entry vector4(
|
||||||
|
.param .u64 input_p,
|
||||||
|
.param .u64 output_p
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .u64 in_addr;
|
||||||
|
.reg .u64 out_addr;
|
||||||
|
.reg .v4 .u32 temp;
|
||||||
|
.reg .u32 temp_scalar;
|
||||||
|
|
||||||
|
ld.param.u64 in_addr, [input_p];
|
||||||
|
ld.param.u64 out_addr, [output_p];
|
||||||
|
|
||||||
|
ld.v4.u32 temp, [in_addr];
|
||||||
|
mov.b32 temp_scalar, temp.w;
|
||||||
|
st.u32 [out_addr], temp_scalar;
|
||||||
|
ret;
|
||||||
|
}
|
99
ptx/src/test/spirv_run/vector4.spvtxt
Normal file
99
ptx/src/test/spirv_run/vector4.spvtxt
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
OpCapability GenericPointer
|
||||||
|
OpCapability Linkage
|
||||||
|
OpCapability Addresses
|
||||||
|
OpCapability Kernel
|
||||||
|
OpCapability Int8
|
||||||
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
%51 = OpExtInstImport "OpenCL.std"
|
||||||
|
OpMemoryModel Physical64 OpenCL
|
||||||
|
OpEntryPoint Kernel %25 "vector"
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%uint = OpTypeInt 32 0
|
||||||
|
%v2uint = OpTypeVector %uint 2
|
||||||
|
%55 = OpTypeFunction %v2uint %v2uint
|
||||||
|
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
|
||||||
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
|
%uint_0 = OpConstant %uint 0
|
||||||
|
%uint_1 = OpConstant %uint 1
|
||||||
|
%ulong = OpTypeInt 64 0
|
||||||
|
%67 = OpTypeFunction %void %ulong %ulong
|
||||||
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
|
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
|
||||||
|
%1 = OpFunction %v2uint None %55
|
||||||
|
%7 = OpFunctionParameter %v2uint
|
||||||
|
%24 = OpLabel
|
||||||
|
%3 = OpVariable %_ptr_Function_v2uint Function
|
||||||
|
%2 = OpVariable %_ptr_Function_v2uint Function
|
||||||
|
%4 = OpVariable %_ptr_Function_v2uint Function
|
||||||
|
%5 = OpVariable %_ptr_Function_uint Function
|
||||||
|
%6 = OpVariable %_ptr_Function_uint Function
|
||||||
|
OpStore %3 %7
|
||||||
|
%59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
|
||||||
|
%9 = OpLoad %uint %59
|
||||||
|
%8 = OpCopyObject %uint %9
|
||||||
|
OpStore %5 %8
|
||||||
|
%61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
|
||||||
|
%11 = OpLoad %uint %61
|
||||||
|
%10 = OpCopyObject %uint %11
|
||||||
|
OpStore %6 %10
|
||||||
|
%13 = OpLoad %uint %5
|
||||||
|
%14 = OpLoad %uint %6
|
||||||
|
%12 = OpIAdd %uint %13 %14
|
||||||
|
OpStore %6 %12
|
||||||
|
%16 = OpLoad %uint %6
|
||||||
|
%15 = OpCopyObject %uint %16
|
||||||
|
%62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
|
||||||
|
OpStore %62 %15
|
||||||
|
%18 = OpLoad %uint %6
|
||||||
|
%17 = OpCopyObject %uint %18
|
||||||
|
%63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
|
||||||
|
OpStore %63 %17
|
||||||
|
%64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
|
||||||
|
%20 = OpLoad %uint %64
|
||||||
|
%19 = OpCopyObject %uint %20
|
||||||
|
%65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
|
||||||
|
OpStore %65 %19
|
||||||
|
%22 = OpLoad %v2uint %4
|
||||||
|
%21 = OpCopyObject %v2uint %22
|
||||||
|
OpStore %2 %21
|
||||||
|
%23 = OpLoad %v2uint %2
|
||||||
|
OpReturnValue %23
|
||||||
|
OpFunctionEnd
|
||||||
|
%25 = OpFunction %void None %67
|
||||||
|
%34 = OpFunctionParameter %ulong
|
||||||
|
%35 = OpFunctionParameter %ulong
|
||||||
|
%49 = OpLabel
|
||||||
|
%26 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%27 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%28 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%29 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
%30 = OpVariable %_ptr_Function_v2uint Function
|
||||||
|
%31 = OpVariable %_ptr_Function_uint Function
|
||||||
|
%32 = OpVariable %_ptr_Function_uint Function
|
||||||
|
%33 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %26 %34
|
||||||
|
OpStore %27 %35
|
||||||
|
%36 = OpLoad %ulong %26 Aligned 8
|
||||||
|
OpStore %28 %36
|
||||||
|
%37 = OpLoad %ulong %27 Aligned 8
|
||||||
|
OpStore %29 %37
|
||||||
|
%39 = OpLoad %ulong %28
|
||||||
|
%46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
|
||||||
|
%38 = OpLoad %v2uint %46 Aligned 8
|
||||||
|
OpStore %30 %38
|
||||||
|
%41 = OpLoad %v2uint %30
|
||||||
|
%40 = OpFunctionCall %v2uint %1 %41
|
||||||
|
OpStore %30 %40
|
||||||
|
%43 = OpLoad %v2uint %30
|
||||||
|
%47 = OpBitcast %ulong %43
|
||||||
|
%42 = OpCopyObject %ulong %47
|
||||||
|
OpStore %33 %42
|
||||||
|
%44 = OpLoad %ulong %29
|
||||||
|
%45 = OpLoad %v2uint %30
|
||||||
|
%48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
|
||||||
|
OpStore %48 %45 Aligned 8
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
|
@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
|
||||||
let opencl_id = emit_opencl_import(&mut builder);
|
let opencl_id = emit_opencl_import(&mut builder);
|
||||||
emit_memory_model(&mut builder);
|
emit_memory_model(&mut builder);
|
||||||
let mut map = TypeWordMap::new(&mut builder);
|
let mut map = TypeWordMap::new(&mut builder);
|
||||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
//emit_builtins(&mut builder, &mut map, &id_defs);
|
||||||
let mut kernel_info = HashMap::new();
|
let mut kernel_info = HashMap::new();
|
||||||
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
|
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
|
||||||
emit_directives(
|
emit_directives(
|
||||||
|
@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>(
|
||||||
&mut numeric_id_defs,
|
&mut numeric_id_defs,
|
||||||
&mut (*func_decl).borrow_mut(),
|
&mut (*func_decl).borrow_mut(),
|
||||||
)?;
|
)?;
|
||||||
let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
|
let ssa_statements =
|
||||||
|
fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
|
||||||
let mut numeric_id_defs = numeric_id_defs.finish();
|
let mut numeric_id_defs = numeric_id_defs.finish();
|
||||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
|
||||||
let expanded_statements =
|
let expanded_statements =
|
||||||
|
@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fix_special_registers(
|
fn fix_special_registers(
|
||||||
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
typed_statements: Vec<TypedStatement>,
|
typed_statements: Vec<TypedStatement>,
|
||||||
numeric_id_defs: &mut NumericIdResolver,
|
numeric_id_defs: &mut NumericIdResolver,
|
||||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
@ -1276,7 +1278,6 @@ fn fix_special_registers(
|
||||||
for s in typed_statements {
|
for s in typed_statements {
|
||||||
match s {
|
match s {
|
||||||
Statement::LoadVar(
|
Statement::LoadVar(
|
||||||
mut
|
|
||||||
details
|
details
|
||||||
@
|
@
|
||||||
LoadVarDetails {
|
LoadVarDetails {
|
||||||
|
@ -1285,48 +1286,53 @@ fn fix_special_registers(
|
||||||
},
|
},
|
||||||
) => {
|
) => {
|
||||||
let index = details.member_index.unwrap().0;
|
let index = details.member_index.unwrap().0;
|
||||||
if index == 3 {
|
let sreg = numeric_id_defs
|
||||||
result.push(Statement::Constant(ConstantDefinition {
|
.special_registers
|
||||||
dst: details.arg.dst,
|
.get(details.arg.src)
|
||||||
typ: ast::ScalarType::U32,
|
.ok_or_else(|| error_unreachable())?;
|
||||||
value: ast::ImmediateValue::U64(0),
|
let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
|
||||||
}));
|
let index_constant = numeric_id_defs.register_intermediate(Some((
|
||||||
} else {
|
ast::Type::Scalar(ast::ScalarType::U32),
|
||||||
let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
|
ast::StateSpace::Reg,
|
||||||
{
|
)));
|
||||||
Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
|
result.push(Statement::Constant(ConstantDefinition {
|
||||||
None => None,
|
dst: index_constant,
|
||||||
};
|
typ: ast::ScalarType::U32,
|
||||||
let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
|
value: ast::ImmediateValue::U64(index as u64),
|
||||||
Some(sreg_and_type) => sreg_and_type,
|
}));
|
||||||
None => {
|
let fn_result = numeric_id_defs.register_intermediate(Some((
|
||||||
result.push(Statement::LoadVar(details));
|
ast::Type::Scalar(ocl_type),
|
||||||
continue;
|
ast::StateSpace::Reg,
|
||||||
}
|
)));
|
||||||
};
|
let return_arguments =
|
||||||
let temp_id = numeric_id_defs
|
vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
|
||||||
.register_intermediate(Some((details.typ.clone(), details.state_space)));
|
let input_arguments = vec![(
|
||||||
let real_dst = details.arg.dst;
|
TypedOperand::Reg(index_constant),
|
||||||
details.arg.dst = temp_id;
|
ast::Type::Scalar(ast::ScalarType::U32),
|
||||||
result.push(Statement::LoadVar(LoadVarDetails {
|
ast::StateSpace::Reg,
|
||||||
arg: Arg2 {
|
)];
|
||||||
src: sreg_src,
|
let fn_call = register_external_fn_call(
|
||||||
dst: temp_id,
|
numeric_id_defs,
|
||||||
},
|
ptx_impl_imports,
|
||||||
state_space: ast::StateSpace::Sreg,
|
ocl_name.to_string(),
|
||||||
typ: ast::Type::Scalar(scalar_typ),
|
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
member_index: Some((index, Some(vector_width))),
|
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
}));
|
)?;
|
||||||
result.push(Statement::Conversion(ImplicitConversion {
|
result.push(Statement::Call(ResolvedCall {
|
||||||
src: temp_id,
|
uniform: false,
|
||||||
dst: real_dst,
|
return_arguments,
|
||||||
from_type: ast::Type::Scalar(scalar_typ),
|
name: fn_call,
|
||||||
from_space: ast::StateSpace::Sreg,
|
input_arguments,
|
||||||
to_type: ast::Type::Scalar(ast::ScalarType::U32),
|
}));
|
||||||
to_space: ast::StateSpace::Sreg,
|
result.push(Statement::Conversion(ImplicitConversion {
|
||||||
kind: ConversionKind::Default,
|
src: fn_result,
|
||||||
}));
|
dst: details.arg.dst,
|
||||||
}
|
from_type: ast::Type::Scalar(ocl_type),
|
||||||
|
from_space: ast::StateSpace::Reg,
|
||||||
|
to_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||||
|
to_space: ast::StateSpace::Reg,
|
||||||
|
kind: ConversionKind::Default,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
s => result.push(s),
|
s => result.push(s),
|
||||||
}
|
}
|
||||||
|
@ -1721,8 +1727,8 @@ fn instruction_to_fn_call(
|
||||||
id_defs,
|
id_defs,
|
||||||
ptx_impl_imports,
|
ptx_impl_imports,
|
||||||
fn_name,
|
fn_name,
|
||||||
return_arguments,
|
return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
|
||||||
input_arguments,
|
input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
|
||||||
)?;
|
)?;
|
||||||
Ok(Statement::Call(ResolvedCall {
|
Ok(Statement::Call(ResolvedCall {
|
||||||
uniform: false,
|
uniform: false,
|
||||||
|
@ -1732,12 +1738,12 @@ fn instruction_to_fn_call(
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn register_external_fn_call(
|
fn register_external_fn_call<'a>(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
name: String,
|
name: String,
|
||||||
return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
) -> Result<spirv::Word, TranslateError> {
|
) -> Result<spirv::Word, TranslateError> {
|
||||||
match ptx_impl_imports.entry(name) {
|
match ptx_impl_imports.entry(name) {
|
||||||
hash_map::Entry::Vacant(entry) => {
|
hash_map::Entry::Vacant(entry) => {
|
||||||
|
@ -1770,19 +1776,18 @@ fn register_external_fn_call(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fn_arguments_to_variables(
|
fn fn_arguments_to_variables<'a>(
|
||||||
id_defs: &mut NumericIdResolver,
|
id_defs: &mut NumericIdResolver,
|
||||||
args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
|
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
|
||||||
) -> Vec<ast::Variable<spirv::Word>> {
|
) -> Vec<ast::Variable<spirv::Word>> {
|
||||||
args.iter()
|
args.map(|(typ, space)| ast::Variable {
|
||||||
.map(|(_, typ, space)| ast::Variable {
|
align: None,
|
||||||
align: None,
|
v_type: typ.clone(),
|
||||||
v_type: typ.clone(),
|
state_space: space,
|
||||||
state_space: *space,
|
name: id_defs.register_intermediate(None),
|
||||||
name: id_defs.register_intermediate(None),
|
array_init: Vec::new(),
|
||||||
array_init: Vec::new(),
|
})
|
||||||
})
|
.collect::<Vec<_>>()
|
||||||
.collect::<Vec<_>>()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn arguments_to_resolved_arguments(
|
fn arguments_to_resolved_arguments(
|
||||||
|
@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>(
|
||||||
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||||
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||||
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||||
Statement::Constant(_) => return Err(error_unreachable()),
|
Statement::Constant(c) => result.push(Statement::Constant(c)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
|
@ -4686,6 +4691,19 @@ impl PtxSpecialRegister {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_scalar_type(self) -> ast::ScalarType {
|
||||||
|
match self {
|
||||||
|
PtxSpecialRegister::Tid
|
||||||
|
| PtxSpecialRegister::Ntid
|
||||||
|
| PtxSpecialRegister::Ctaid
|
||||||
|
| PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
|
||||||
|
PtxSpecialRegister::Tid64
|
||||||
|
| PtxSpecialRegister::Ntid64
|
||||||
|
| PtxSpecialRegister::Ctaid64
|
||||||
|
| PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_builtin(self) -> spirv::BuiltIn {
|
fn get_builtin(self) -> spirv::BuiltIn {
|
||||||
match self {
|
match self {
|
||||||
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
||||||
|
@ -4701,6 +4719,23 @@ impl PtxSpecialRegister {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
|
||||||
|
match self {
|
||||||
|
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
|
||||||
|
("_Z12get_local_idj", ast::ScalarType::U64)
|
||||||
|
}
|
||||||
|
PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
|
||||||
|
("_Z14get_local_sizej", ast::ScalarType::U64)
|
||||||
|
}
|
||||||
|
PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
|
||||||
|
("_Z12get_group_idj", ast::ScalarType::U64)
|
||||||
|
}
|
||||||
|
PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
|
||||||
|
("_Z14get_num_groupsj", ast::ScalarType::U64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
|
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
|
||||||
match self {
|
match self {
|
||||||
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
|
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
|
||||||
|
@ -4743,6 +4778,8 @@ impl SpecialRegistersMap {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn interface(&self) -> Vec<spirv::Word> {
|
fn interface(&self) -> Vec<spirv::Word> {
|
||||||
|
return Vec::new();
|
||||||
|
/*
|
||||||
self.reg_to_id
|
self.reg_to_id
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|(sreg, id)| {
|
.filter_map(|(sreg, id)| {
|
||||||
|
@ -4753,6 +4790,7 @@ impl SpecialRegistersMap {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
|
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
|
||||||
|
|
|
@ -14,6 +14,7 @@ level_zero-sys = { path = "../level_zero-sys" }
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4"
|
||||||
num_enum = "0.4"
|
num_enum = "0.4"
|
||||||
lz4-sys = "1.9"
|
lz4-sys = "1.9"
|
||||||
|
tempfile = "3"
|
||||||
|
|
||||||
[dependencies.ocl-core]
|
[dependencies.ocl-core]
|
||||||
version = "0.11"
|
version = "0.11"
|
||||||
|
|
|
@ -4,8 +4,10 @@ use std::{
|
||||||
ffi::c_void,
|
ffi::c_void,
|
||||||
ffi::CStr,
|
ffi::CStr,
|
||||||
ffi::CString,
|
ffi::CString,
|
||||||
|
io::{self, Write},
|
||||||
mem,
|
mem,
|
||||||
os::raw::{c_char, c_int, c_uint},
|
os::raw::{c_char, c_int, c_uint},
|
||||||
|
process::{Command, Stdio},
|
||||||
ptr, slice,
|
ptr, slice,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,6 +22,7 @@ use super::{
|
||||||
CUresult, GlobalState, HasLivenessCookie, LiveCheck,
|
CUresult, GlobalState, HasLivenessCookie, LiveCheck,
|
||||||
};
|
};
|
||||||
use ptx;
|
use ptx;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
pub type Module = LiveCheck<ModuleData>;
|
pub type Module = LiveCheck<ModuleData>;
|
||||||
|
|
||||||
|
@ -88,6 +91,36 @@ impl SpirvModule {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const LLVM_SPIRV: &'static str = "/home/vosen/amd/llvm-project/build/bin/llvm-spirv";
|
||||||
|
const AMDGPU: &'static str = "/opt/amdgpu-pro/";
|
||||||
|
const AMDGPU_BITCODE: [&'static str; 8] = [
|
||||||
|
"opencl",
|
||||||
|
"ocml",
|
||||||
|
"ockl",
|
||||||
|
"oclc_correctly_rounded_sqrt_off",
|
||||||
|
"oclc_daz_opt_on",
|
||||||
|
"oclc_finite_only_off",
|
||||||
|
"oclc_unsafe_math_off",
|
||||||
|
"oclc_wavefrontsize64_off",
|
||||||
|
];
|
||||||
|
const AMDGPU_BITCODE_DEVICE_PREFIX: &'static str = "oclc_isa_version_";
|
||||||
|
const AMDGPU_DEVICE: &'static str = "gfx1010";
|
||||||
|
|
||||||
|
fn compile_amd(spirv_il: &[u8]) -> io::Result<()> {
|
||||||
|
let dir = tempfile::tempdir()?;
|
||||||
|
let mut spirv = NamedTempFile::new_in(&dir)?;
|
||||||
|
let llvm = NamedTempFile::new_in(&dir)?;
|
||||||
|
spirv.write_all(spirv_il)?;
|
||||||
|
let mut cmd = Command::new(Self::LLVM_SPIRV)
|
||||||
|
.arg("-r")
|
||||||
|
.arg("-o")
|
||||||
|
.arg(llvm.path())
|
||||||
|
.arg(spirv.path())
|
||||||
|
.status()?;
|
||||||
|
assert!(cmd.success());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn compile<'a>(
|
pub fn compile<'a>(
|
||||||
&self,
|
&self,
|
||||||
ctx: &ocl_core::Context,
|
ctx: &ocl_core::Context,
|
||||||
|
@ -99,6 +132,7 @@ impl SpirvModule {
|
||||||
self.binaries.len() * mem::size_of::<u32>(),
|
self.binaries.len() * mem::size_of::<u32>(),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
Self::compile_amd(byte_il).unwrap();
|
||||||
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
|
let main_module = ocl_core::create_program_with_il(ctx, byte_il, None)?;
|
||||||
let main_module = match self.should_link_ptx_impl {
|
let main_module = match self.should_link_ptx_impl {
|
||||||
None => {
|
None => {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue