mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add more tests
This commit is contained in:
parent
17f2d09cc7
commit
dcaea507ba
7 changed files with 175 additions and 139 deletions
|
@ -13,8 +13,10 @@
|
|||
%25 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uchar = OpTypeInt 8 0
|
||||
%_arr_uchar_8 = OpTypeArray %uchar %8
|
||||
%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_8 = OpConstant %uint 8
|
||||
%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8
|
||||
%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%1 = OpFunction %void None %25
|
||||
%8 = OpFunctionParameter %ulong
|
||||
|
@ -22,7 +24,7 @@
|
|||
%20 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
|
||||
%4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
|
|
|
@ -8,7 +8,7 @@ use spirv_headers::Word;
|
|||
use spirv_tools_sys::{
|
||||
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
|
||||
};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::{collections::hash_map::Entry, cmp};
|
||||
use std::error;
|
||||
use std::ffi::{c_void, CStr, CString};
|
||||
use std::fmt;
|
||||
|
@ -59,8 +59,9 @@ test_ptx!(local_align, [1u64], [1u64]);
|
|||
test_ptx!(call, [1u64], [2u64]);
|
||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
||||
//test_ptx!(ntid, [3u32], [4u32]);
|
||||
//test_ptx!(reg_slm, [12u64], [12u64]);
|
||||
test_ptx!(ntid, [3u32], [4u32]);
|
||||
test_ptx!(reg_local, [12u64], [12u64]);
|
||||
test_ptx!(mov_address, [0xDEADu64], [0u64]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
@ -123,8 +124,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
|||
kernel.set_indirect_access(
|
||||
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
|
||||
)?;
|
||||
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, input.len())?;
|
||||
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, output.len())?;
|
||||
let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
|
||||
let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
|
||||
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
|
||||
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
|
||||
let ev0 = ze::Event::new(&event_pool, 0)?;
|
||||
|
|
15
ptx/src/test/spirv_run/mov_address.ptx
Normal file
15
ptx/src/test/spirv_run/mov_address.ptx
Normal file
|
@ -0,0 +1,15 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mov_address(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.local .b8 __local_depot0[8];
|
||||
.reg .u64 temp;
|
||||
|
||||
mov.u64 temp, __local_depot0;
|
||||
ret;
|
||||
}
|
|
@ -2,12 +2,12 @@
|
|||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry reg_slm(
|
||||
.visible .entry reg_local(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.local .align 8 .b8 slm[8];
|
||||
.local .align 8 .b8 local_x[8];
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b64 temp;
|
||||
|
@ -16,11 +16,9 @@
|
|||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
mov.s64 unused, slm;
|
||||
|
||||
ld.global.u64 temp, [in_addr];
|
||||
st.u64 [slm], temp;
|
||||
ld.u64 temp, [slm];
|
||||
st.u64 [local_x], temp;
|
||||
ld.u64 temp, [local_x];
|
||||
st.global.u64 [out_addr], temp;
|
||||
ret;
|
||||
}
|
46
ptx/src/test/spirv_run/reg_local.spvtxt
Normal file
46
ptx/src/test/spirv_run/reg_local.spvtxt
Normal file
|
@ -0,0 +1,46 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%25 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "add"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%28 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%ulong_1 = OpConstant %ulong 1
|
||||
%1 = OpFunction %void None %28
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%23 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %21
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %6
|
||||
%16 = OpIAdd %ulong %17 %ulong_1
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %ulong %5
|
||||
%19 = OpLoad %ulong %7
|
||||
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||
OpStore %22 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -217,11 +217,13 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
|
|||
let opencl_id = emit_opencl_import(&mut builder);
|
||||
emit_memory_model(&mut builder);
|
||||
let mut map = TypeWordMap::new(&mut builder);
|
||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
||||
for f in ssa_functions {
|
||||
let f_body = match f.body {
|
||||
Some(f) => f,
|
||||
None => continue,
|
||||
};
|
||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
||||
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
|
||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
||||
builder.end_function()?;
|
||||
|
@ -229,6 +231,33 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
|
|||
Ok(builder.module())
|
||||
}
|
||||
|
||||
fn emit_builtins(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
id_defs: &GlobalStringIdResolver,
|
||||
) {
|
||||
for (reg, id) in id_defs.special_registers.iter() {
|
||||
let result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(
|
||||
Box::new(SpirvType::from(reg.get_type())),
|
||||
spirv::StorageClass::UniformConstant,
|
||||
),
|
||||
);
|
||||
builder.variable(
|
||||
result_type,
|
||||
Some(*id),
|
||||
spirv::StorageClass::UniformConstant,
|
||||
None,
|
||||
);
|
||||
builder.decorate(
|
||||
*id,
|
||||
spirv::Decoration::BuiltIn,
|
||||
&[dr::Operand::BuiltIn(reg.get_builtin())],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_function_header<'a>(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -239,7 +268,12 @@ fn emit_function_header<'a>(
|
|||
let fn_id = match func_directive {
|
||||
ast::MethodDecl::Kernel(name, _) => {
|
||||
let fn_id = global.get_id(name)?;
|
||||
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
|
||||
let interface = global
|
||||
.special_registers
|
||||
.iter()
|
||||
.map(|(_, id)| *id)
|
||||
.collect::<Vec<_>>();
|
||||
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
|
||||
fn_id
|
||||
}
|
||||
ast::MethodDecl::Func(_, name, _) => name,
|
||||
|
@ -293,7 +327,7 @@ fn emit_memory_model(builder: &mut dr::Builder) {
|
|||
fn to_ssa_function<'a>(
|
||||
id_defs: &mut GlobalStringIdResolver<'a>,
|
||||
f: ast::ParsedFunction<'a>,
|
||||
) -> Result<ExpandedFunction<'a>, TranslateError> {
|
||||
) -> Result<Function<'a>, TranslateError> {
|
||||
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
|
||||
to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
|
||||
}
|
||||
|
@ -333,13 +367,14 @@ fn to_ssa<'input, 'b>(
|
|||
fn_defs: GlobalFnDeclResolver<'input, 'b>,
|
||||
f_args: ast::MethodDecl<'input, ExpandedArgParams>,
|
||||
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
|
||||
) -> Result<ExpandedFunction<'input>, TranslateError> {
|
||||
) -> Result<Function<'input>, TranslateError> {
|
||||
let f_body = match f_body {
|
||||
Some(vec) => vec,
|
||||
None => {
|
||||
return Ok(ExpandedFunction {
|
||||
return Ok(Function {
|
||||
func_directive: f_args,
|
||||
body: None,
|
||||
globals: Vec::new(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
@ -357,12 +392,21 @@ fn to_ssa<'input, 'b>(
|
|||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||
let sorted_statements = normalize_variable_decls(labeled_statements);
|
||||
Ok(ExpandedFunction {
|
||||
let (f_body, globals) = extract_globals(sorted_statements);
|
||||
Ok(Function {
|
||||
func_directive: f_args,
|
||||
body: Some(sorted_statements),
|
||||
globals: globals,
|
||||
body: Some(f_body),
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_globals(
|
||||
sorted_statements: Vec<ExpandedStatement>,
|
||||
) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
|
||||
// This fn will be used for SLM
|
||||
(sorted_statements, Vec::new())
|
||||
}
|
||||
|
||||
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
|
||||
func[1..].sort_by_key(|s| match s {
|
||||
Statement::Variable(_) => 0,
|
||||
|
@ -477,7 +521,9 @@ fn add_types_to_statements(
|
|||
},
|
||||
_ => dets,
|
||||
};
|
||||
Ok(Statement::Instruction(ast::Instruction::MovVector(new_dets, args)))
|
||||
Ok(Statement::Instruction(ast::Instruction::MovVector(
|
||||
new_dets, args,
|
||||
)))
|
||||
}
|
||||
s => Ok(s),
|
||||
}
|
||||
|
@ -724,7 +770,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
|
|||
(_, ArgumentSemantics::Address) => return Ok(desc.op),
|
||||
(t, ArgumentSemantics::RegisterPointer)
|
||||
| (t, ArgumentSemantics::Default)
|
||||
| (t, ArgumentSemantics::Ptr) => t,
|
||||
| (t, ArgumentSemantics::PhysicalPointer) => t,
|
||||
};
|
||||
let generated_id = id_def.new_id(id_type);
|
||||
if !desc.is_dst {
|
||||
|
@ -873,7 +919,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
));
|
||||
Ok(result_id)
|
||||
}
|
||||
ArgumentSemantics::Ptr => {
|
||||
ArgumentSemantics::PhysicalPointer => {
|
||||
let scalar_t = ast::ScalarType::U64;
|
||||
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
|
||||
let result_id = self.id_def.new_id(typ);
|
||||
|
@ -1137,7 +1183,7 @@ fn emit_function_body_ops(
|
|||
builder.begin_block(Some(*id))?;
|
||||
}
|
||||
_ => {
|
||||
if builder.block.is_none() {
|
||||
if builder.block.is_none() && builder.function.is_some() {
|
||||
builder.begin_block(None)?;
|
||||
}
|
||||
}
|
||||
|
@ -1166,10 +1212,9 @@ fn emit_function_body_ops(
|
|||
name,
|
||||
}) => {
|
||||
let st_class = match v_type {
|
||||
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
|
||||
spirv::StorageClass::Function
|
||||
}
|
||||
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
|
||||
ast::VariableType::Reg(_)
|
||||
| ast::VariableType::Param(_)
|
||||
| ast::VariableType::Local(_) => spirv::StorageClass::Function,
|
||||
};
|
||||
let type_id = map.get_or_add(
|
||||
builder,
|
||||
|
@ -1234,7 +1279,7 @@ fn emit_function_body_ops(
|
|||
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
|
||||
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
|
||||
}
|
||||
ast::LdStateSpace::Param => {
|
||||
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
|
||||
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
||||
}
|
||||
|
@ -1242,18 +1287,20 @@ fn emit_function_body_ops(
|
|||
}
|
||||
}
|
||||
ast::Instruction::St(data, arg) => {
|
||||
if data.qualifier != ast::LdStQualifier::Weak
|
||||
|| (data.state_space != ast::StStateSpace::Generic
|
||||
&& data.state_space != ast::StStateSpace::Param
|
||||
&& data.state_space != ast::StStateSpace::Global)
|
||||
{
|
||||
if data.qualifier != ast::LdStQualifier::Weak {
|
||||
todo!()
|
||||
}
|
||||
if data.state_space == ast::StStateSpace::Param {
|
||||
if data.state_space == ast::StStateSpace::Param
|
||||
|| data.state_space == ast::StStateSpace::Local
|
||||
{
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
|
||||
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
|
||||
} else {
|
||||
} else if data.state_space == ast::StStateSpace::Generic
|
||||
|| data.state_space == ast::StStateSpace::Global
|
||||
{
|
||||
builder.store(arg.src1, arg.src2, None, &[])?;
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
|
@ -1643,7 +1690,7 @@ fn emit_implicit_conversion(
|
|||
let from_parts = cv.from.to_parts();
|
||||
let to_parts = cv.to.to_parts();
|
||||
match (from_parts.kind, to_parts.kind, cv.kind) {
|
||||
(_, _, ConversionKind::Ptr(space)) => {
|
||||
(_, _, ConversionKind::BitToPtr(space)) => {
|
||||
let dst_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
|
||||
|
@ -1699,14 +1746,11 @@ fn emit_implicit_conversion(
|
|||
}
|
||||
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
|
||||
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
|
||||
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => {
|
||||
| (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default)
|
||||
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
|
||||
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
|
||||
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
(TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
|
||||
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
|
||||
builder.convert_ptr_to_u(into_type, Some(cv.dst), cv.src)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
|
@ -2181,7 +2225,7 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
|
|||
decl.get_fn_decl_str(id)
|
||||
}
|
||||
|
||||
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
|
||||
fn get_src_semantics(_: &Self::MovOperand) -> ArgumentSemantics {
|
||||
ArgumentSemantics::Default
|
||||
}
|
||||
}
|
||||
|
@ -2230,7 +2274,12 @@ pub enum StateSpace {
|
|||
|
||||
enum ExpandedArgParams {}
|
||||
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
|
||||
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
|
||||
|
||||
struct Function<'input> {
|
||||
pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
|
||||
pub globals: Vec<ExpandedStatement>,
|
||||
pub body: Option<Vec<ExpandedStatement>>,
|
||||
}
|
||||
|
||||
impl ast::ArgParams for ExpandedArgParams {
|
||||
type ID = spirv::Word;
|
||||
|
@ -2248,7 +2297,7 @@ impl ArgParamsEx for ExpandedArgParams {
|
|||
decl.get_fn_decl(*id)
|
||||
}
|
||||
|
||||
fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
|
||||
fn get_src_semantics(_: &spirv::Word) -> ArgumentSemantics {
|
||||
ArgumentSemantics::Default
|
||||
}
|
||||
}
|
||||
|
@ -2398,12 +2447,12 @@ struct ArgumentDescriptor<Op> {
|
|||
sema: ArgumentSemantics,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
pub enum ArgumentSemantics {
|
||||
// normal register access
|
||||
Default,
|
||||
// st/ld global
|
||||
Ptr,
|
||||
PhysicalPointer,
|
||||
// st/ld .param, .local
|
||||
RegisterPointer,
|
||||
// mov of .local/.global variables
|
||||
|
@ -2720,7 +2769,8 @@ enum ConversionKind {
|
|||
Default,
|
||||
// zero-extend/chop/bitcast depending on types
|
||||
SignExtend,
|
||||
Ptr(ast::LdStateSpace),
|
||||
BitToPtr(ast::LdStateSpace),
|
||||
PtrToBit,
|
||||
}
|
||||
|
||||
impl<T> ast::PredAt<T> {
|
||||
|
@ -2831,7 +2881,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
|
|||
sema: if is_param {
|
||||
ArgumentSemantics::RegisterPointer
|
||||
} else {
|
||||
ArgumentSemantics::Ptr
|
||||
ArgumentSemantics::PhysicalPointer
|
||||
},
|
||||
},
|
||||
t,
|
||||
|
@ -2919,7 +2969,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
|||
sema: if is_param {
|
||||
ArgumentSemantics::RegisterPointer
|
||||
} else {
|
||||
ArgumentSemantics::Ptr
|
||||
ArgumentSemantics::PhysicalPointer
|
||||
},
|
||||
},
|
||||
t,
|
||||
|
@ -3518,7 +3568,7 @@ fn get_implicit_conversions_ld_src(
|
|||
) -> Result<Vec<ImplicitConversion>, TranslateError> {
|
||||
let src_type = id_def.get_typed(src)?;
|
||||
match state_space {
|
||||
ast::LdStateSpace::Param => {
|
||||
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
|
||||
if src_type != instr_type {
|
||||
Ok(vec![
|
||||
ImplicitConversion {
|
||||
|
@ -3560,7 +3610,7 @@ fn get_implicit_conversions_ld_src(
|
|||
dst: u32::max_value(),
|
||||
from: src_type,
|
||||
to: instr_type,
|
||||
kind: ConversionKind::Ptr(state_space),
|
||||
kind: ConversionKind::BitToPtr(state_space),
|
||||
});
|
||||
if result.len() == 2 {
|
||||
let new_id = id_def.new_id(new_src_type);
|
||||
|
@ -3570,92 +3620,9 @@ fn get_implicit_conversions_ld_src(
|
|||
}
|
||||
Ok(result)
|
||||
}
|
||||
_ => todo!(),
|
||||
_ => Err(TranslateError::Todo),
|
||||
}
|
||||
}
|
||||
fn insert_implicit_conversions_ld_src(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
instr_type: ast::Type,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
state_space: ast::LdStateSpace,
|
||||
src: spirv::Word,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
match state_space {
|
||||
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
|
||||
func,
|
||||
id_def,
|
||||
instr_type,
|
||||
src,
|
||||
should_convert_ld_param_src,
|
||||
),
|
||||
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
|
||||
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
|
||||
mem::size_of::<usize>() as u8,
|
||||
ScalarKind::Bit,
|
||||
));
|
||||
let new_src = insert_implicit_conversions_ld_src_impl(
|
||||
func,
|
||||
id_def,
|
||||
new_src_type,
|
||||
src,
|
||||
should_convert_ld_generic_src_to_bitcast,
|
||||
)?;
|
||||
Ok(insert_conversion_src(
|
||||
func,
|
||||
id_def,
|
||||
new_src,
|
||||
new_src_type,
|
||||
instr_type,
|
||||
ConversionKind::Ptr(state_space),
|
||||
))
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_ld_src_impl<
|
||||
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
|
||||
>(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
instr_type: ast::Type,
|
||||
src: spirv::Word,
|
||||
should_convert: ShouldConvert,
|
||||
) -> Result<spirv::Word, TranslateError> {
|
||||
let src_type = id_def.get_typed(src)?;
|
||||
if let Some(conv) = should_convert(src_type, instr_type) {
|
||||
Ok(insert_conversion_src(
|
||||
func, id_def, src, src_type, instr_type, conv,
|
||||
))
|
||||
} else {
|
||||
Ok(src)
|
||||
}
|
||||
}
|
||||
|
||||
fn should_convert_ld_param_src(
|
||||
src_type: ast::Type,
|
||||
instr_type: ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type != instr_type {
|
||||
return Some(ConversionKind::Default);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// HACK ALERT
|
||||
// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
|
||||
// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
|
||||
fn should_convert_ld_generic_src_to_bitcast(
|
||||
src_type: ast::Type,
|
||||
_instr_type: ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if let ast::Type::Scalar(src_type) = src_type {
|
||||
if src_type.kind() == ScalarKind::Signed {
|
||||
return Some(ConversionKind::Default);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn insert_conversion_src(
|
||||
|
@ -3832,14 +3799,21 @@ fn insert_implicit_bitcasts(
|
|||
None => return Ok(desc.op),
|
||||
};
|
||||
let id_actual_type = id_def.get_typed(desc.op)?;
|
||||
if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
|
||||
let conv_kind = if desc.sema == ArgumentSemantics::Address {
|
||||
Some(ConversionKind::PtrToBit)
|
||||
} else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(conv_kind) = conv_kind {
|
||||
if desc.is_dst {
|
||||
dst_coercion = Some(get_conversion_dst(
|
||||
id_def,
|
||||
&mut desc.op,
|
||||
id_type_from_instr,
|
||||
id_actual_type,
|
||||
ConversionKind::Default,
|
||||
conv_kind,
|
||||
));
|
||||
Ok(desc.op)
|
||||
} else {
|
||||
|
@ -3849,7 +3823,7 @@ fn insert_implicit_bitcasts(
|
|||
desc.op,
|
||||
id_actual_type,
|
||||
id_type_from_instr,
|
||||
ConversionKind::Default,
|
||||
conv_kind,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
|
|
Loading…
Add table
Reference in a new issue