mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
[BROKEN] Start implementing better support for addressable arguments
This commit is contained in:
parent
42bad8fcc2
commit
952ed5d504
8 changed files with 351 additions and 80 deletions
|
@ -354,6 +354,7 @@ pub struct CallInst<P: ArgParams> {
|
|||
pub trait ArgParams {
|
||||
type ID;
|
||||
type Operand;
|
||||
type MemoryOperand;
|
||||
type CallOperand;
|
||||
type VecOperand;
|
||||
}
|
||||
|
@ -365,6 +366,7 @@ pub struct ParsedArgParams<'a> {
|
|||
impl<'a> ArgParams for ParsedArgParams<'a> {
|
||||
type ID = &'a str;
|
||||
type Operand = Operand<&'a str>;
|
||||
type MemoryOperand = Operand<&'a str>;
|
||||
type CallOperand = CallOperand<&'a str>;
|
||||
type VecOperand = (&'a str, u8);
|
||||
}
|
||||
|
@ -378,8 +380,13 @@ pub struct Arg2<P: ArgParams> {
|
|||
pub src: P::Operand,
|
||||
}
|
||||
|
||||
pub struct Arg2Ld<P: ArgParams> {
|
||||
pub dst: P::ID,
|
||||
pub src: P::MemoryOperand,
|
||||
}
|
||||
|
||||
pub struct Arg2St<P: ArgParams> {
|
||||
pub src1: P::Operand,
|
||||
pub src1: P::MemoryOperand,
|
||||
pub src2: P::Operand,
|
||||
}
|
||||
|
||||
|
@ -416,13 +423,13 @@ pub struct Arg5<P: ArgParams> {
|
|||
pub enum Operand<ID> {
|
||||
Reg(ID),
|
||||
RegOffset(ID, i32),
|
||||
Imm(i128),
|
||||
Imm(u32),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum CallOperand<ID> {
|
||||
Reg(ID),
|
||||
Imm(i128),
|
||||
Imm(u32),
|
||||
}
|
||||
|
||||
pub enum VectorPrefix {
|
||||
|
|
|
@ -446,7 +446,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
|
||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => {
|
||||
ast::Instruction::Ld(
|
||||
ast::LdData {
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
|
@ -899,7 +899,7 @@ ShlType: ast::ShlType = {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
||||
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
|
||||
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => {
|
||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => {
|
||||
ast::Instruction::St(
|
||||
ast::StData {
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
|
@ -912,6 +912,11 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
|||
}
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors
|
||||
MemoryOperand: ast::Operand<&'input str> = {
|
||||
"[" <o:Operand> "]" => o
|
||||
}
|
||||
|
||||
StStateSpace: ast::StStateSpace = {
|
||||
".global" => ast::StStateSpace::Global,
|
||||
".local" => ast::StStateSpace::Local,
|
||||
|
@ -1006,7 +1011,7 @@ Operand: ast::Operand<&'input str> = {
|
|||
// TODO: start parsing whole constants sub-language:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
|
||||
<o:Num> => {
|
||||
let offset = o.parse::<i128>();
|
||||
let offset = o.parse::<u32>();
|
||||
let offset = offset.unwrap_with(errors);
|
||||
ast::Operand::Imm(offset)
|
||||
}
|
||||
|
@ -1015,7 +1020,7 @@ Operand: ast::Operand<&'input str> = {
|
|||
CallOperand: ast::CallOperand<&'input str> = {
|
||||
<r:ExtendedID> => ast::CallOperand::Reg(r),
|
||||
<o:Num> => {
|
||||
let offset = o.parse::<i128>();
|
||||
let offset = o.parse::<u32>();
|
||||
let offset = offset.unwrap_with(errors);
|
||||
ast::CallOperand::Imm(offset)
|
||||
}
|
||||
|
|
|
@ -59,6 +59,8 @@ test_ptx!(local_align, [1u64], [1u64]);
|
|||
test_ptx!(call, [1u64], [2u64]);
|
||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
|
||||
test_ptx!(ntid, [3u32], [4u32]);
|
||||
test_ptx!(reg_slm, [12u64], [12u64]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
23
ptx/src/test/spirv_run/ntid.ptx
Normal file
23
ptx/src/test/spirv_run/ntid.ptx
Normal file
|
@ -0,0 +1,23 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry ntid(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u32 in_val;
|
||||
.reg .u32 global_count;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u32 in_val, [in_addr];
|
||||
mov.u32 global_count, %ntid.x;
|
||||
add.u32 in_val, in_val, global_count;
|
||||
st.u32 [out_addr], in_val;
|
||||
ret;
|
||||
}
|
56
ptx/src/test/spirv_run/ntid.spvtxt
Normal file
56
ptx/src/test/spirv_run/ntid.spvtxt
Normal file
|
@ -0,0 +1,56 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%29 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "add" %GlobalSize
|
||||
OpDecorate %GlobalSize BuiltIn GlobalSize
|
||||
%void = OpTypeVoid
|
||||
%uint = OpTypeInt 32 0
|
||||
%v3uint = OpTypeVector %uint 3
|
||||
%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint
|
||||
%GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant
|
||||
%ulong = OpTypeInt 64 0
|
||||
%35 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%1 = OpFunction %void None %35
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%10 = OpFunctionParameter %ulong
|
||||
%27 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_uint Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %2 %9
|
||||
OpStore %3 %10
|
||||
%12 = OpLoad %ulong %2
|
||||
%11 = OpCopyObject %ulong %12
|
||||
OpStore %4 %11
|
||||
%14 = OpLoad %ulong %3
|
||||
%13 = OpCopyObject %ulong %14
|
||||
OpStore %5 %13
|
||||
%16 = OpLoad %ulong %4
|
||||
%25 = OpConvertUToPtr %_ptr_Generic_uint %16
|
||||
%15 = OpLoad %uint %25
|
||||
OpStore %6 %15
|
||||
%18 = OpLoad %v3uint %GlobalSize
|
||||
%24 = OpCompositeExtract %uint %18 0
|
||||
%17 = OpCopyObject %uint %24
|
||||
OpStore %7 %17
|
||||
%20 = OpLoad %uint %6
|
||||
%21 = OpLoad %uint %7
|
||||
%19 = OpIAdd %uint %20 %21
|
||||
OpStore %6 %19
|
||||
%22 = OpLoad %ulong %5
|
||||
%23 = OpLoad %uint %6
|
||||
%26 = OpConvertUToPtr %_ptr_Generic_uint %22
|
||||
OpStore %26 %23
|
||||
OpReturn
|
||||
OpFunctionEnd
|
26
ptx/src/test/spirv_run/reg_slm.ptx
Normal file
26
ptx/src/test/spirv_run/reg_slm.ptx
Normal file
|
@ -0,0 +1,26 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry reg_slm(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.local .align 8 .b8 slm[8];
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .b64 temp;
|
||||
.reg .s64 unused;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
mov.s64 unused, slm;
|
||||
|
||||
ld.global.u64 temp, [in_addr];
|
||||
st.u64 [slm], temp;
|
||||
ld.u64 temp, [slm];
|
||||
st.global.u64 [out_addr], temp;
|
||||
ret;
|
||||
}
|
46
ptx/src/test/spirv_run/reg_slm.spvtxt
Normal file
46
ptx/src/test/spirv_run/reg_slm.spvtxt
Normal file
|
@ -0,0 +1,46 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%25 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "add"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%28 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%ulong_1 = OpConstant %ulong 1
|
||||
%1 = OpFunction %void None %28
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%23 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %21
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %6
|
||||
%16 = OpIAdd %ulong %17 %ulong_1
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %ulong %5
|
||||
%19 = OpLoad %ulong %7
|
||||
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||
OpStore %22 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -286,7 +286,7 @@ fn expand_kernel_params<'a, 'b>(
|
|||
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
||||
args.map(|a| ast::KernelArgument {
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||
name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
|
||||
v_type: a.v_type,
|
||||
align: a.align,
|
||||
})
|
||||
|
@ -297,10 +297,16 @@ fn expand_fn_params<'a, 'b>(
|
|||
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
|
||||
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
||||
args.map(|a| ast::FnArgument {
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||
v_type: a.v_type,
|
||||
align: a.align,
|
||||
args.map(|a| {
|
||||
let ss = match a.v_type {
|
||||
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
||||
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
||||
};
|
||||
ast::FnArgument {
|
||||
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
|
||||
v_type: a.v_type,
|
||||
align: a.align,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -325,6 +331,8 @@ fn to_ssa<'input, 'b>(
|
|||
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
|
||||
let unadorned_statements =
|
||||
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
|
||||
todo!()
|
||||
/*
|
||||
let (f_args, ssa_statements) =
|
||||
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
|
||||
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
|
||||
|
@ -336,6 +344,7 @@ fn to_ssa<'input, 'b>(
|
|||
func_directive: f_args,
|
||||
body: Some(sorted_statements),
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
|
||||
|
@ -350,7 +359,7 @@ fn add_types_to_statements(
|
|||
func: Vec<UnadornedStatement>,
|
||||
fn_defs: &GlobalFnDeclResolver,
|
||||
id_defs: &NumericIdResolver,
|
||||
) -> Vec<UnadornedStatement> {
|
||||
) -> Vec<TypedStatement> {
|
||||
func.into_iter()
|
||||
.map(|s| {
|
||||
match s {
|
||||
|
@ -359,7 +368,7 @@ fn add_types_to_statements(
|
|||
let fn_def = fn_defs.get_fn_decl(call.func);
|
||||
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
|
||||
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
|
||||
let resolved_call = ResolvedCall {
|
||||
let resolved_call: ResolvedCall<TypedArgParams> = ResolvedCall {
|
||||
uniform: call.uniform,
|
||||
ret_params,
|
||||
func: call.func,
|
||||
|
@ -367,18 +376,13 @@ fn add_types_to_statements(
|
|||
};
|
||||
Statement::Call(resolved_call)
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
|
||||
// TODO fail on type mismatch
|
||||
let new_dets = match id_defs.get_type(*args.dst()) {
|
||||
Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
|
||||
length: len,
|
||||
..dets
|
||||
},
|
||||
_ => dets,
|
||||
};
|
||||
Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
|
||||
Statement::Instruction(ast::Instruction::Ld(d, arg)) => {
|
||||
todo!()
|
||||
}
|
||||
s => s,
|
||||
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
|
||||
todo!()
|
||||
}
|
||||
s => todo!(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
|
@ -485,7 +489,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
ast::MethodDecl::Kernel(_, in_params) => {
|
||||
for p in in_params.iter_mut() {
|
||||
let typ = ast::Type::from(p.v_type);
|
||||
let new_id = id_def.new_id(Some(typ));
|
||||
let new_id = id_def.new_id(Some((StateSpace::Param, typ)));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: p.align,
|
||||
v_type: ast::VariableType::Param(p.v_type),
|
||||
|
@ -504,8 +508,12 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
}
|
||||
ast::MethodDecl::Func(out_params, _, in_params) => {
|
||||
for p in in_params.iter_mut() {
|
||||
let ss = match p.v_type {
|
||||
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
|
||||
ast::FnArgumentType::Param(_) => StateSpace::Param,
|
||||
};
|
||||
let typ = ast::Type::from(p.v_type);
|
||||
let new_id = id_def.new_id(Some(typ));
|
||||
let new_id = id_def.new_id(Some((ss, typ)));
|
||||
let var_typ = ast::VariableType::from(p.v_type);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: p.align,
|
||||
|
@ -548,7 +556,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
dst: new_id,
|
||||
src: out_param,
|
||||
},
|
||||
typ.unwrap(),
|
||||
typ.unwrap().1,
|
||||
));
|
||||
result.push(Statement::RetValue(d, new_id));
|
||||
} else {
|
||||
|
@ -558,7 +566,10 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
||||
},
|
||||
Statement::Conditional(mut bra) => {
|
||||
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
|
||||
let generated_id = id_def.new_id(Some((
|
||||
StateSpace::Reg,
|
||||
ast::Type::Scalar(ast::ScalarType::Pred),
|
||||
)));
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
dst: generated_id,
|
||||
|
@ -607,11 +618,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
|
|||
let mut post_statements = Vec::new();
|
||||
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
|
||||
let id_type = match (id_def.get_type(desc.op), desc.sema) {
|
||||
(Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t,
|
||||
(Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
(Some((_, t)), ArgumentSemantics::ParamPtr)
|
||||
| (Some((_, t)), ArgumentSemantics::Default) => t,
|
||||
(Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
(None, _) => return desc.op,
|
||||
};
|
||||
let generated_id = id_def.new_id(Some(id_type));
|
||||
let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type)));
|
||||
if !desc.is_dst {
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
|
@ -716,11 +728,13 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
} else {
|
||||
todo!()
|
||||
};
|
||||
let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
||||
let id = self
|
||||
.id_def
|
||||
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value: x,
|
||||
value: x as i64,
|
||||
}));
|
||||
id
|
||||
}
|
||||
|
@ -732,13 +746,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
} else {
|
||||
todo!()
|
||||
};
|
||||
let id_constant_stmt =
|
||||
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
||||
let result_id = self.id_def.new_id(Some(typ));
|
||||
let id_constant_stmt = self
|
||||
.id_def
|
||||
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
value: offset as i128,
|
||||
value: offset as i64,
|
||||
}));
|
||||
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
||||
self.func.push(Statement::Instruction(
|
||||
|
@ -758,13 +773,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
}
|
||||
ArgumentSemantics::Ptr => {
|
||||
let scalar_t = ast::ScalarType::U64;
|
||||
let id_constant_stmt =
|
||||
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
||||
let result_id = self.id_def.new_id(Some(typ));
|
||||
let id_constant_stmt = self
|
||||
.id_def
|
||||
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
|
||||
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
value: offset as i128,
|
||||
value: offset as i64,
|
||||
}));
|
||||
let int_type = ast::IntType::U64;
|
||||
self.func.push(Statement::Instruction(
|
||||
|
@ -810,9 +826,10 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
desc: ArgumentDescriptor<(spirv::Word, u8)>,
|
||||
(scalar_type, vec_len): (ast::MovVectorType, u8),
|
||||
) -> spirv::Word {
|
||||
let new_id = self
|
||||
.id_def
|
||||
.new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
|
||||
let new_id = self.id_def.new_id(Some((
|
||||
StateSpace::Reg,
|
||||
ast::Type::Vector(scalar_type.into(), vec_len),
|
||||
)));
|
||||
self.func.push(Statement::Composite(CompositeRead {
|
||||
typ: scalar_type,
|
||||
dst: new_id,
|
||||
|
@ -821,6 +838,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
}));
|
||||
new_id
|
||||
}
|
||||
|
||||
fn mov_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
|
||||
typ: ast::Type,
|
||||
) -> spirv::Word {
|
||||
self.operand(desc, typ)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -911,7 +936,7 @@ fn insert_implicit_conversions(
|
|||
let mut did_vector_implicit = false;
|
||||
let mut post_conv = None;
|
||||
if inst_typ_is_bit {
|
||||
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
|
||||
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1;
|
||||
if let ast::Type::Vector(_, _) = src_type {
|
||||
arg.src = insert_conversion_src(
|
||||
&mut result,
|
||||
|
@ -923,7 +948,7 @@ fn insert_implicit_conversions(
|
|||
);
|
||||
did_vector_implicit = true;
|
||||
}
|
||||
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
|
||||
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1;
|
||||
if let ast::Type::Vector(_, _) = src_type {
|
||||
post_conv = Some(get_conversion_dst(
|
||||
id_def,
|
||||
|
@ -1615,25 +1640,32 @@ fn expand_map_variables<'a, 'b>(
|
|||
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
|
||||
i.map_variable(&mut |id| id_defs.get_id(id)),
|
||||
))),
|
||||
ast::Statement::Variable(var) => match var.count {
|
||||
Some(count) => {
|
||||
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
|
||||
ast::Statement::Variable(var) => {
|
||||
let ss = match var.var.v_type {
|
||||
ast::VariableType::Reg(_) => StateSpace::Reg,
|
||||
ast::VariableType::Local(_) => StateSpace::Local,
|
||||
ast::VariableType::Param(_) => StateSpace::ParamReg,
|
||||
};
|
||||
match var.count {
|
||||
Some(count) => {
|
||||
for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type,
|
||||
name: new_id,
|
||||
}))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type,
|
||||
name: new_id,
|
||||
}))
|
||||
}));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type,
|
||||
name: new_id,
|
||||
}));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1766,7 +1798,7 @@ struct FnStringIdResolver<'input, 'b> {
|
|||
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
||||
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
||||
type_check: HashMap<u32, ast::Type>,
|
||||
type_check: HashMap<u32, (StateSpace, ast::Type)>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
||||
|
@ -1809,7 +1841,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
||||
fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
|
||||
let numeric_id = *self.current_id;
|
||||
self.variables
|
||||
.last_mut()
|
||||
|
@ -1827,6 +1859,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
&mut self,
|
||||
base_id: &'a str,
|
||||
count: u32,
|
||||
ss: StateSpace,
|
||||
typ: ast::Type,
|
||||
) -> impl Iterator<Item = spirv::Word> {
|
||||
let numeric_id = *self.current_id;
|
||||
|
@ -1835,7 +1868,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
.last_mut()
|
||||
.unwrap()
|
||||
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
|
||||
self.type_check.insert(numeric_id + i, typ);
|
||||
self.type_check.insert(numeric_id + i, (ss, typ));
|
||||
}
|
||||
*self.current_id += count;
|
||||
(0..count).into_iter().map(move |i| i + numeric_id)
|
||||
|
@ -1844,15 +1877,15 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
|
||||
struct NumericIdResolver<'b> {
|
||||
current_id: &'b mut spirv::Word,
|
||||
type_check: HashMap<u32, ast::Type>,
|
||||
type_check: HashMap<u32, (StateSpace, ast::Type)>,
|
||||
}
|
||||
|
||||
impl<'b> NumericIdResolver<'b> {
|
||||
fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
|
||||
fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> {
|
||||
self.type_check.get(&id).map(|x| *x)
|
||||
}
|
||||
|
||||
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
|
||||
fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
|
||||
let new_id = *self.current_id;
|
||||
if let Some(typ) = typ {
|
||||
self.type_check.insert(new_id, typ);
|
||||
|
@ -1982,16 +2015,48 @@ type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, Norma
|
|||
impl ast::ArgParams for NormalizedArgParams {
|
||||
type ID = spirv::Word;
|
||||
type Operand = ast::Operand<spirv::Word>;
|
||||
type MemoryOperand = ast::Operand<spirv::Word>;
|
||||
type CallOperand = ast::CallOperand<spirv::Word>;
|
||||
type VecOperand = (spirv::Word, u8);
|
||||
}
|
||||
|
||||
enum TypedArgParams {}
|
||||
impl ast::ArgParams for TypedArgParams {
|
||||
type ID = spirv::Word;
|
||||
type Operand = ast::Operand<spirv::Word>;
|
||||
type MemoryOperand = MemoryOperand;
|
||||
type CallOperand = ast::CallOperand<spirv::Word>;
|
||||
type VecOperand = (spirv::Word, u8);
|
||||
}
|
||||
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
|
||||
|
||||
impl ArgParamsEx for NormalizedArgParams {
|
||||
fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
|
||||
decl.get_fn_decl(*id)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
Sreg,
|
||||
Const,
|
||||
Global,
|
||||
Local,
|
||||
Shared,
|
||||
Param,
|
||||
ParamReg,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum MemoryOperand {
|
||||
Reg(spirv::Word),
|
||||
Address(spirv::Word),
|
||||
RegOffset(spirv::Word, i32),
|
||||
AddressOffset(spirv::Word, i32),
|
||||
Imm(u32),
|
||||
}
|
||||
|
||||
enum ExpandedArgParams {}
|
||||
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
|
||||
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
|
||||
|
@ -1999,6 +2064,7 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme
|
|||
impl ast::ArgParams for ExpandedArgParams {
|
||||
type ID = spirv::Word;
|
||||
type Operand = spirv::Word;
|
||||
type MemoryOperand = spirv::Word;
|
||||
type CallOperand = spirv::Word;
|
||||
type VecOperand = spirv::Word;
|
||||
}
|
||||
|
@ -2012,6 +2078,11 @@ impl ArgParamsEx for ExpandedArgParams {
|
|||
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|
||||
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
|
||||
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
|
||||
fn mov_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<T::MemoryOperand>,
|
||||
typ: ast::Type,
|
||||
) -> U::MemoryOperand;
|
||||
fn src_call_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<T::CallOperand>,
|
||||
|
@ -2035,9 +2106,15 @@ where
|
|||
) -> spirv::Word {
|
||||
self(desc, t)
|
||||
}
|
||||
|
||||
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
|
||||
self(desc, Some(t))
|
||||
}
|
||||
|
||||
fn mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
|
||||
self(desc, Some(t))
|
||||
}
|
||||
|
||||
fn src_call_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<spirv::Word>,
|
||||
|
@ -2045,6 +2122,7 @@ where
|
|||
) -> spirv::Word {
|
||||
self(desc, Some(t))
|
||||
}
|
||||
|
||||
fn src_vec_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<spirv::Word>,
|
||||
|
@ -2095,6 +2173,14 @@ where
|
|||
) -> (spirv::Word, u8) {
|
||||
(self(desc.op.0), desc.op.1)
|
||||
}
|
||||
|
||||
fn mov_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::Operand<&str>>,
|
||||
typ: ast::Type,
|
||||
) -> ast::Operand<spirv::Word> {
|
||||
self.operand(desc, typ)
|
||||
}
|
||||
}
|
||||
|
||||
struct ArgumentDescriptor<Op> {
|
||||
|
@ -2260,6 +2346,16 @@ where
|
|||
desc.op.1,
|
||||
)
|
||||
}
|
||||
|
||||
fn mov_operand(
|
||||
&mut self,
|
||||
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
|
||||
typ: ast::Type,
|
||||
) -> ast::Operand<spirv::Word> {
|
||||
<Self as ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams>>::operand(
|
||||
self, desc, typ,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::Type {
|
||||
|
@ -2365,7 +2461,7 @@ struct CompositeRead {
|
|||
struct ConstantDefinition {
|
||||
pub dst: spirv::Word,
|
||||
pub typ: ast::ScalarType,
|
||||
pub value: i128,
|
||||
pub value: i64,
|
||||
}
|
||||
|
||||
struct BrachCondition {
|
||||
|
@ -2534,7 +2630,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
|||
is_param: bool,
|
||||
) -> ast::Arg2St<U> {
|
||||
ast::Arg2St {
|
||||
src1: visitor.operand(
|
||||
src1: visitor.mov_operand(
|
||||
ArgumentDescriptor {
|
||||
op: self.src1,
|
||||
is_dst: is_param,
|
||||
|
@ -3012,6 +3108,16 @@ impl From<ast::FnArgumentType> for ast::VariableType {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> ast::Operand<T> {
|
||||
fn underlying(&self) -> Option<&T> {
|
||||
match self {
|
||||
ast::Operand::Reg(r) => Some(r),
|
||||
ast::Operand::RegOffset(r, _) => Some(r),
|
||||
ast::Operand::Imm(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
|
@ -3053,7 +3159,7 @@ fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<Expan
|
|||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
|
||||
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
|
||||
if post_conv.len() > 0 {
|
||||
let new_id = id_def.new_id(Some(post_conv[0].from));
|
||||
let new_id = id_def.new_id(Some((StateSpace::Reg, post_conv[0].from)));
|
||||
post_conv[0].src = new_id;
|
||||
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
|
||||
*dst(&mut instr) = new_id;
|
||||
|
@ -3078,7 +3184,7 @@ fn insert_with_conversions_pre_conv<T>(
|
|||
conv.src = *original_src;
|
||||
}
|
||||
if i == pre_conv_len - 1 {
|
||||
let new_id = id_def.new_id(Some(conv.to));
|
||||
let new_id = id_def.new_id(Some((StateSpace::Reg, conv.to)));
|
||||
conv.dst = new_id;
|
||||
*original_src = new_id;
|
||||
}
|
||||
|
@ -3095,7 +3201,7 @@ fn get_implicit_conversions_ld_dst<
|
|||
should_convert: ShouldConvert,
|
||||
in_reverse: bool,
|
||||
) -> Option<ImplicitConversion> {
|
||||
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
|
||||
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1;
|
||||
if let Some(conv) = should_convert(dst_type, instr_type) {
|
||||
Some(ImplicitConversion {
|
||||
src: u32::max_value(),
|
||||
|
@ -3115,7 +3221,7 @@ fn get_implicit_conversions_ld_src(
|
|||
state_space: ast::LdStateSpace,
|
||||
src: spirv::Word,
|
||||
) -> Vec<ImplicitConversion> {
|
||||
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
|
||||
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
|
||||
match state_space {
|
||||
ast::LdStateSpace::Param => {
|
||||
if src_type != instr_type {
|
||||
|
@ -3162,7 +3268,7 @@ fn get_implicit_conversions_ld_src(
|
|||
kind: ConversionKind::Ptr(state_space),
|
||||
});
|
||||
if result.len() == 2 {
|
||||
let new_id = id_def.new_id(Some(new_src_type));
|
||||
let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type)));
|
||||
result[0].dst = new_id;
|
||||
result[1].src = new_id;
|
||||
result[1].from = new_src_type;
|
||||
|
@ -3221,9 +3327,9 @@ fn insert_implicit_conversions_ld_src_impl<
|
|||
src: spirv::Word,
|
||||
should_convert: ShouldConvert,
|
||||
) -> spirv::Word {
|
||||
let src_type = id_def.get_type(src);
|
||||
if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
|
||||
insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
|
||||
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
|
||||
if let Some(conv) = should_convert(src_type, instr_type) {
|
||||
insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
|
||||
} else {
|
||||
src
|
||||
}
|
||||
|
@ -3263,7 +3369,7 @@ fn insert_conversion_src(
|
|||
instr_type: ast::Type,
|
||||
conv: ConversionKind,
|
||||
) -> spirv::Word {
|
||||
let temp_src = id_def.new_id(Some(instr_type));
|
||||
let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type)));
|
||||
func.push(Statement::Conversion(ImplicitConversion {
|
||||
src: src,
|
||||
dst: temp_src,
|
||||
|
@ -3309,7 +3415,7 @@ fn get_conversion_dst(
|
|||
kind: ConversionKind,
|
||||
) -> ExpandedStatement {
|
||||
let original_dst = *dst;
|
||||
let temp_dst = id_def.new_id(Some(instr_type));
|
||||
let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type)));
|
||||
*dst = temp_dst;
|
||||
Statement::Conversion(ImplicitConversion {
|
||||
src: temp_dst,
|
||||
|
@ -3428,8 +3534,8 @@ fn insert_implicit_bitcasts(
|
|||
Some(t) => t,
|
||||
None => return desc.op,
|
||||
};
|
||||
let id_actual_type = id_def.get_type(desc.op).unwrap();
|
||||
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
|
||||
let id_actual_type = id_def.get_type(desc.op).unwrap().1;
|
||||
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) {
|
||||
if desc.is_dst {
|
||||
dst_coercion = Some(get_conversion_dst(
|
||||
id_def,
|
||||
|
|
Loading…
Add table
Reference in a new issue