[BROKEN] Start implementing better support for addressable arguments

This commit is contained in:
Andrzej Janik 2020-09-18 02:25:20 +02:00
parent 42bad8fcc2
commit 952ed5d504
8 changed files with 351 additions and 80 deletions

View file

@ -354,6 +354,7 @@ pub struct CallInst<P: ArgParams> {
pub trait ArgParams {
type ID;
type Operand;
type MemoryOperand;
type CallOperand;
type VecOperand;
}
@ -365,6 +366,7 @@ pub struct ParsedArgParams<'a> {
impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
type MemoryOperand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>;
type VecOperand = (&'a str, u8);
}
@ -378,8 +380,13 @@ pub struct Arg2<P: ArgParams> {
pub src: P::Operand,
}
pub struct Arg2Ld<P: ArgParams> {
pub dst: P::ID,
pub src: P::MemoryOperand,
}
pub struct Arg2St<P: ArgParams> {
pub src1: P::Operand,
pub src1: P::MemoryOperand,
pub src2: P::Operand,
}
@ -416,13 +423,13 @@ pub struct Arg5<P: ArgParams> {
pub enum Operand<ID> {
Reg(ID),
RegOffset(ID, i32),
Imm(i128),
Imm(u32),
}
#[derive(Copy, Clone)]
pub enum CallOperand<ID> {
Reg(ID),
Imm(i128),
Imm(u32),
}
pub enum VectorPrefix {

View file

@ -446,7 +446,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => {
ast::Instruction::Ld(
ast::LdData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@ -899,7 +899,7 @@ ShlType: ast::ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@ -912,6 +912,11 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors
MemoryOperand: ast::Operand<&'input str> = {
"[" <o:Operand> "]" => o
}
StStateSpace: ast::StStateSpace = {
".global" => ast::StStateSpace::Global,
".local" => ast::StStateSpace::Local,
@ -1006,7 +1011,7 @@ Operand: ast::Operand<&'input str> = {
// TODO: start parsing whole constants sub-language:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
<o:Num> => {
let offset = o.parse::<i128>();
let offset = o.parse::<u32>();
let offset = offset.unwrap_with(errors);
ast::Operand::Imm(offset)
}
@ -1015,7 +1020,7 @@ Operand: ast::Operand<&'input str> = {
CallOperand: ast::CallOperand<&'input str> = {
<r:ExtendedID> => ast::CallOperand::Reg(r),
<o:Num> => {
let offset = o.parse::<i128>();
let offset = o.parse::<u32>();
let offset = offset.unwrap_with(errors);
ast::CallOperand::Imm(offset)
}

View file

@ -59,6 +59,8 @@ test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_slm, [12u64], [12u64]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,23 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry ntid(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 in_val;
.reg .u32 global_count;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 in_val, [in_addr];
mov.u32 global_count, %ntid.x;
add.u32 in_val, in_val, global_count;
st.u32 [out_addr], in_val;
ret;
}

View file

@ -0,0 +1,56 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%29 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add" %GlobalSize
OpDecorate %GlobalSize BuiltIn GlobalSize
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint
%GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant
%ulong = OpTypeInt 64 0
%35 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %35
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%27 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %2 %9
OpStore %3 %10
%12 = OpLoad %ulong %2
%11 = OpCopyObject %ulong %12
OpStore %4 %11
%14 = OpLoad %ulong %3
%13 = OpCopyObject %ulong %14
OpStore %5 %13
%16 = OpLoad %ulong %4
%25 = OpConvertUToPtr %_ptr_Generic_uint %16
%15 = OpLoad %uint %25
OpStore %6 %15
%18 = OpLoad %v3uint %GlobalSize
%24 = OpCompositeExtract %uint %18 0
%17 = OpCopyObject %uint %24
OpStore %7 %17
%20 = OpLoad %uint %6
%21 = OpLoad %uint %7
%19 = OpIAdd %uint %20 %21
OpStore %6 %19
%22 = OpLoad %ulong %5
%23 = OpLoad %uint %6
%26 = OpConvertUToPtr %_ptr_Generic_uint %22
OpStore %26 %23
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,26 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry reg_slm(
.param .u64 input,
.param .u64 output
)
{
.local .align 8 .b8 slm[8];
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b64 temp;
.reg .s64 unused;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
mov.s64 unused, slm;
ld.global.u64 temp, [in_addr];
st.u64 [slm], temp;
ld.u64 temp, [slm];
st.global.u64 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,46 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpIAdd %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View file

@ -286,7 +286,7 @@ fn expand_kernel_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
args.map(|a| ast::KernelArgument {
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
v_type: a.v_type,
align: a.align,
})
@ -297,10 +297,16 @@ fn expand_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
args.map(|a| ast::FnArgument {
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
v_type: a.v_type,
align: a.align,
args.map(|a| {
let ss = match a.v_type {
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
ast::FnArgumentType::Param(_) => StateSpace::Param,
};
ast::FnArgument {
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
v_type: a.v_type,
align: a.align,
}
})
.collect()
}
@ -325,6 +331,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let unadorned_statements =
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
todo!()
/*
let (f_args, ssa_statements) =
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
@ -336,6 +344,7 @@ fn to_ssa<'input, 'b>(
func_directive: f_args,
body: Some(sorted_statements),
}
*/
}
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
@ -350,7 +359,7 @@ fn add_types_to_statements(
func: Vec<UnadornedStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
) -> Vec<UnadornedStatement> {
) -> Vec<TypedStatement> {
func.into_iter()
.map(|s| {
match s {
@ -359,7 +368,7 @@ fn add_types_to_statements(
let fn_def = fn_defs.get_fn_decl(call.func);
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
let resolved_call = ResolvedCall {
let resolved_call: ResolvedCall<TypedArgParams> = ResolvedCall {
uniform: call.uniform,
ret_params,
func: call.func,
@ -367,18 +376,13 @@ fn add_types_to_statements(
};
Statement::Call(resolved_call)
}
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
// TODO fail on type mismatch
let new_dets = match id_defs.get_type(*args.dst()) {
Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
length: len,
..dets
},
_ => dets,
};
Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
Statement::Instruction(ast::Instruction::Ld(d, arg)) => {
todo!()
}
s => s,
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
todo!()
}
s => todo!(),
}
})
.collect()
@ -485,7 +489,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ));
let new_id = id_def.new_id(Some((StateSpace::Param, typ)));
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: ast::VariableType::Param(p.v_type),
@ -504,8 +508,12 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
let ss = match p.v_type {
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
ast::FnArgumentType::Param(_) => StateSpace::Param,
};
let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ));
let new_id = id_def.new_id(Some((ss, typ)));
let var_typ = ast::VariableType::from(p.v_type);
result.push(Statement::Variable(ast::Variable {
align: p.align,
@ -548,7 +556,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
dst: new_id,
src: out_param,
},
typ.unwrap(),
typ.unwrap().1,
));
result.push(Statement::RetValue(d, new_id));
} else {
@ -558,7 +566,10 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
},
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
let generated_id = id_def.new_id(Some((
StateSpace::Reg,
ast::Type::Scalar(ast::ScalarType::Pred),
)));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
@ -607,11 +618,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
let mut post_statements = Vec::new();
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
let id_type = match (id_def.get_type(desc.op), desc.sema) {
(Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t,
(Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
(Some((_, t)), ArgumentSemantics::ParamPtr)
| (Some((_, t)), ArgumentSemantics::Default) => t,
(Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
};
let generated_id = id_def.new_id(Some(id_type));
let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type)));
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
@ -716,11 +728,13 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
} else {
todo!()
};
let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
let id = self
.id_def
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value: x,
value: x as i64,
}));
id
}
@ -732,13 +746,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
} else {
todo!()
};
let id_constant_stmt =
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
let result_id = self.id_def.new_id(Some(typ));
let id_constant_stmt = self
.id_def
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
value: offset as i64,
}));
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
@ -758,13 +773,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}
ArgumentSemantics::Ptr => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt =
self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
let result_id = self.id_def.new_id(Some(typ));
let id_constant_stmt = self
.id_def
.new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
value: offset as i64,
}));
let int_type = ast::IntType::U64;
self.func.push(Statement::Instruction(
@ -810,9 +826,10 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word {
let new_id = self
.id_def
.new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
let new_id = self.id_def.new_id(Some((
StateSpace::Reg,
ast::Type::Vector(scalar_type.into(), vec_len),
)));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
@ -821,6 +838,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}));
new_id
}
fn mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: ast::Type,
) -> spirv::Word {
self.operand(desc, typ)
}
}
/*
@ -911,7 +936,7 @@ fn insert_implicit_conversions(
let mut did_vector_implicit = false;
let mut post_conv = None;
if inst_typ_is_bit {
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1;
if let ast::Type::Vector(_, _) = src_type {
arg.src = insert_conversion_src(
&mut result,
@ -923,7 +948,7 @@ fn insert_implicit_conversions(
);
did_vector_implicit = true;
}
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1;
if let ast::Type::Vector(_, _) = src_type {
post_conv = Some(get_conversion_dst(
id_def,
@ -1615,25 +1640,32 @@ fn expand_map_variables<'a, 'b>(
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
i.map_variable(&mut |id| id_defs.get_id(id)),
))),
ast::Statement::Variable(var) => match var.count {
Some(count) => {
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
ast::Statement::Variable(var) => {
let ss = match var.var.v_type {
ast::VariableType::Reg(_) => StateSpace::Reg,
ast::VariableType::Local(_) => StateSpace::Local,
ast::VariableType::Param(_) => StateSpace::ParamReg,
};
match var.count {
Some(count) => {
for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
}))
}
}
None => {
let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
}))
}));
}
}
None => {
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
}));
}
},
}
}
}
@ -1766,7 +1798,7 @@ struct FnStringIdResolver<'input, 'b> {
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, ast::Type>,
type_check: HashMap<u32, (StateSpace, ast::Type)>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@ -1809,7 +1841,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
}
}
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
@ -1827,6 +1859,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
&mut self,
base_id: &'a str,
count: u32,
ss: StateSpace,
typ: ast::Type,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@ -1835,7 +1868,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
self.type_check.insert(numeric_id + i, typ);
self.type_check.insert(numeric_id + i, (ss, typ));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@ -1844,15 +1877,15 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
type_check: HashMap<u32, ast::Type>,
type_check: HashMap<u32, (StateSpace, ast::Type)>,
}
impl<'b> NumericIdResolver<'b> {
fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> {
self.type_check.get(&id).map(|x| *x)
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
let new_id = *self.current_id;
if let Some(typ) = typ {
self.type_check.insert(new_id, typ);
@ -1982,16 +2015,48 @@ type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, Norma
impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type MemoryOperand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
type VecOperand = (spirv::Word, u8);
}
enum TypedArgParams {}
impl ast::ArgParams for TypedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type MemoryOperand = MemoryOperand;
type CallOperand = ast::CallOperand<spirv::Word>;
type VecOperand = (spirv::Word, u8);
}
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
impl ArgParamsEx for NormalizedArgParams {
fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
decl.get_fn_decl(*id)
}
}
#[derive(Copy, Clone)]
pub enum StateSpace {
Reg,
Sreg,
Const,
Global,
Local,
Shared,
Param,
ParamReg,
}
#[derive(Copy, Clone)]
pub enum MemoryOperand {
Reg(spirv::Word),
Address(spirv::Word),
RegOffset(spirv::Word, i32),
AddressOffset(spirv::Word, i32),
Imm(u32),
}
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
@ -1999,6 +2064,7 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type MemoryOperand = spirv::Word;
type CallOperand = spirv::Word;
type VecOperand = spirv::Word;
}
@ -2012,6 +2078,11 @@ impl ArgParamsEx for ExpandedArgParams {
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
fn mov_operand(
&mut self,
desc: ArgumentDescriptor<T::MemoryOperand>,
typ: ast::Type,
) -> U::MemoryOperand;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
@ -2035,9 +2106,15 @@ where
) -> spirv::Word {
self(desc, t)
}
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
self(desc, Some(t))
}
fn mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
self(desc, Some(t))
}
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@ -2045,6 +2122,7 @@ where
) -> spirv::Word {
self(desc, Some(t))
}
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@ -2095,6 +2173,14 @@ where
) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
fn mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: ast::Type,
) -> ast::Operand<spirv::Word> {
self.operand(desc, typ)
}
}
struct ArgumentDescriptor<Op> {
@ -2260,6 +2346,16 @@ where
desc.op.1,
)
}
fn mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: ast::Type,
) -> ast::Operand<spirv::Word> {
<Self as ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams>>::operand(
self, desc, typ,
)
}
}
impl ast::Type {
@ -2365,7 +2461,7 @@ struct CompositeRead {
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
pub value: i128,
pub value: i64,
}
struct BrachCondition {
@ -2534,7 +2630,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
is_param: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
src1: visitor.operand(
src1: visitor.mov_operand(
ArgumentDescriptor {
op: self.src1,
is_dst: is_param,
@ -3012,6 +3108,16 @@ impl From<ast::FnArgumentType> for ast::VariableType {
}
}
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
ast::Operand::Reg(r) => Some(r),
ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@ -3053,7 +3159,7 @@ fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<Expan
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
if post_conv.len() > 0 {
let new_id = id_def.new_id(Some(post_conv[0].from));
let new_id = id_def.new_id(Some((StateSpace::Reg, post_conv[0].from)));
post_conv[0].src = new_id;
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
*dst(&mut instr) = new_id;
@ -3078,7 +3184,7 @@ fn insert_with_conversions_pre_conv<T>(
conv.src = *original_src;
}
if i == pre_conv_len - 1 {
let new_id = id_def.new_id(Some(conv.to));
let new_id = id_def.new_id(Some((StateSpace::Reg, conv.to)));
conv.dst = new_id;
*original_src = new_id;
}
@ -3095,7 +3201,7 @@ fn get_implicit_conversions_ld_dst<
should_convert: ShouldConvert,
in_reverse: bool,
) -> Option<ImplicitConversion> {
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1;
if let Some(conv) = should_convert(dst_type, instr_type) {
Some(ImplicitConversion {
src: u32::max_value(),
@ -3115,7 +3221,7 @@ fn get_implicit_conversions_ld_src(
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> Vec<ImplicitConversion> {
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
match state_space {
ast::LdStateSpace::Param => {
if src_type != instr_type {
@ -3162,7 +3268,7 @@ fn get_implicit_conversions_ld_src(
kind: ConversionKind::Ptr(state_space),
});
if result.len() == 2 {
let new_id = id_def.new_id(Some(new_src_type));
let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type)));
result[0].dst = new_id;
result[1].src = new_id;
result[1].from = new_src_type;
@ -3221,9 +3327,9 @@ fn insert_implicit_conversions_ld_src_impl<
src: spirv::Word,
should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = id_def.get_type(src);
if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
if let Some(conv) = should_convert(src_type, instr_type) {
insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
} else {
src
}
@ -3263,7 +3369,7 @@ fn insert_conversion_src(
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
let temp_src = id_def.new_id(Some(instr_type));
let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type)));
func.push(Statement::Conversion(ImplicitConversion {
src: src,
dst: temp_src,
@ -3309,7 +3415,7 @@ fn get_conversion_dst(
kind: ConversionKind,
) -> ExpandedStatement {
let original_dst = *dst;
let temp_dst = id_def.new_id(Some(instr_type));
let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type)));
*dst = temp_dst;
Statement::Conversion(ImplicitConversion {
src: temp_dst,
@ -3428,8 +3534,8 @@ fn insert_implicit_bitcasts(
Some(t) => t,
None => return desc.op,
};
let id_actual_type = id_def.get_type(desc.op).unwrap();
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
let id_actual_type = id_def.get_type(desc.op).unwrap().1;
if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,