diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 9214944..7ac9d18 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -354,6 +354,7 @@ pub struct CallInst { pub trait ArgParams { type ID; type Operand; + type MemoryOperand; type CallOperand; type VecOperand; } @@ -365,6 +366,7 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type ID = &'a str; type Operand = Operand<&'a str>; + type MemoryOperand = Operand<&'a str>; type CallOperand = CallOperand<&'a str>; type VecOperand = (&'a str, u8); } @@ -378,8 +380,13 @@ pub struct Arg2 { pub src: P::Operand, } +pub struct Arg2Ld { + pub dst: P::ID, + pub src: P::MemoryOperand, +} + pub struct Arg2St { - pub src1: P::Operand, + pub src1: P::MemoryOperand, pub src2: P::Operand, } @@ -416,13 +423,13 @@ pub struct Arg5 { pub enum Operand { Reg(ID), RegOffset(ID, i32), - Imm(i128), + Imm(u32), } #[derive(Copy, Clone)] pub enum CallOperand { Reg(ID), - Imm(i128), + Imm(u32), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 1ffbca2..44f29a5 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -446,7 +446,7 @@ Instruction: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction> = { - "ld" "," "[" "]" => { + "ld" "," => { ast::Instruction::Ld( ast::LdData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -899,7 +899,7 @@ ShlType: ast::ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { - "st" "[" "]" "," => { + "st" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -912,6 +912,11 @@ InstSt: ast::Instruction> = { } }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors +MemoryOperand: ast::Operand<&'input str> = { + "[" "]" => o +} + StStateSpace: ast::StStateSpace = { ".global" => ast::StStateSpace::Global, ".local" => ast::StStateSpace::Local, @@ -1006,7 +1011,7 @@ Operand: ast::Operand<&'input str> = { // TODO: start parsing whole constants sub-language: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants => { - let offset = o.parse::(); + let offset = o.parse::(); let offset = offset.unwrap_with(errors); ast::Operand::Imm(offset) } @@ -1015,7 +1020,7 @@ Operand: ast::Operand<&'input str> = { CallOperand: ast::CallOperand<&'input str> = { => ast::CallOperand::Reg(r), => { - let offset = o.parse::(); + let offset = o.parse::(); let offset = offset.unwrap_with(errors); ast::CallOperand::Imm(offset) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f1c3194..d251f77 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -59,6 +59,8 @@ test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); +test_ptx!(ntid, [3u32], [4u32]); +test_ptx!(reg_slm, [12u64], [12u64]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/ntid.ptx b/ptx/src/test/spirv_run/ntid.ptx new file mode 100644 index 0000000..2961197 --- /dev/null +++ b/ptx/src/test/spirv_run/ntid.ptx @@ -0,0 +1,23 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry ntid( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 in_val; + .reg .u32 global_count; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 in_val, [in_addr]; + mov.u32 global_count, %ntid.x; + add.u32 in_val, in_val, global_count; + st.u32 [out_addr], in_val; + ret; +} diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt new file mode 100644 index 0000000..ef308f0 --- /dev/null +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -0,0 +1,56 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %29 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" %GlobalSize + OpDecorate %GlobalSize BuiltIn GlobalSize + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 +%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint + %GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant + %ulong = OpTypeInt 64 0 + %35 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %35 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %27 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %12 = OpLoad %ulong %2 + %11 = OpCopyObject %ulong %12 + OpStore %4 %11 + %14 = OpLoad %ulong %3 + %13 = OpCopyObject %ulong %14 + OpStore %5 %13 + %16 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %16 + %15 = OpLoad %uint %25 + OpStore %6 %15 + %18 = OpLoad %v3uint %GlobalSize + %24 = OpCompositeExtract %uint %18 0 + %17 = OpCopyObject %uint %24 + OpStore %7 %17 + %20 = OpLoad %uint %6 + %21 = OpLoad %uint %7 + %19 = OpIAdd %uint %20 %21 + OpStore %6 %19 + %22 = OpLoad %ulong %5 + %23 = OpLoad %uint %6 + %26 = OpConvertUToPtr %_ptr_Generic_uint %22 + OpStore %26 %23 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/reg_slm.ptx b/ptx/src/test/spirv_run/reg_slm.ptx new file mode 100644 index 0000000..929d116 --- /dev/null +++ b/ptx/src/test/spirv_run/reg_slm.ptx @@ -0,0 +1,26 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry reg_slm( + .param .u64 input, + .param .u64 output +) +{ + .local .align 8 .b8 slm[8]; + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + .reg .s64 unused; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + mov.s64 unused, slm; + + ld.global.u64 temp, [in_addr]; + st.u64 [slm], temp; + ld.u64 temp, [slm]; + st.global.u64 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/reg_slm.spvtxt b/ptx/src/test/spirv_run/reg_slm.spvtxt new file mode 100644 index 0000000..6810fec --- /dev/null +++ b/ptx/src/test/spirv_run/reg_slm.spvtxt @@ -0,0 +1,46 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index f5c0ecb..45372f1 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -286,7 +286,7 @@ fn expand_kernel_params<'a, 'b>( args: impl Iterator>>, ) -> Vec> { args.map(|a| ast::KernelArgument { - name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))), + name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))), v_type: a.v_type, align: a.align, }) @@ -297,10 +297,16 @@ fn expand_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: impl Iterator>>, ) -> Vec> { - args.map(|a| ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))), - v_type: a.v_type, - align: a.align, + args.map(|a| { + let ss = match a.v_type { + ast::FnArgumentType::Reg(_) => StateSpace::Reg, + ast::FnArgumentType::Param(_) => StateSpace::Param, + }; + ast::FnArgument { + name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))), + v_type: a.v_type, + align: a.align, + } }) .collect() } @@ -325,6 +331,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); let unadorned_statements = add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs); + todo!() + /* let (f_args, ssa_statements) = insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); @@ -336,6 +344,7 @@ fn to_ssa<'input, 'b>( func_directive: f_args, body: Some(sorted_statements), } + */ } fn normalize_variable_decls(mut func: Vec) -> Vec { @@ -350,7 +359,7 @@ fn add_types_to_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, id_defs: &NumericIdResolver, -) -> Vec { +) -> Vec { func.into_iter() .map(|s| { match s { @@ -359,7 +368,7 @@ fn add_types_to_statements( let fn_def = fn_defs.get_fn_decl(call.func); let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params); - let resolved_call = ResolvedCall { + let resolved_call: ResolvedCall = ResolvedCall { uniform: call.uniform, ret_params, func: call.func, @@ -367,18 +376,13 @@ fn add_types_to_statements( }; Statement::Call(resolved_call) } - Statement::Instruction(ast::Instruction::MovVector(dets, args)) => { - // TODO fail on type mismatch - let new_dets = match id_defs.get_type(*args.dst()) { - Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails { - length: len, - ..dets - }, - _ => dets, - }; - Statement::Instruction(ast::Instruction::MovVector(new_dets, args)) + Statement::Instruction(ast::Instruction::Ld(d, arg)) => { + todo!() } - s => s, + Statement::Instruction(ast::Instruction::MovVector(dets, args)) => { + todo!() + } + s => todo!(), } }) .collect() @@ -485,7 +489,7 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::MethodDecl::Kernel(_, in_params) => { for p in in_params.iter_mut() { let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(Some(typ)); + let new_id = id_def.new_id(Some((StateSpace::Param, typ))); result.push(Statement::Variable(ast::Variable { align: p.align, v_type: ast::VariableType::Param(p.v_type), @@ -504,8 +508,12 @@ fn insert_mem_ssa_statements<'a, 'b>( } ast::MethodDecl::Func(out_params, _, in_params) => { for p in in_params.iter_mut() { + let ss = match p.v_type { + ast::FnArgumentType::Reg(_) => StateSpace::Reg, + ast::FnArgumentType::Param(_) => StateSpace::Param, + }; let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(Some(typ)); + let new_id = id_def.new_id(Some((ss, typ))); let var_typ = ast::VariableType::from(p.v_type); result.push(Statement::Variable(ast::Variable { align: p.align, @@ -548,7 +556,7 @@ fn insert_mem_ssa_statements<'a, 'b>( dst: new_id, src: out_param, }, - typ.unwrap(), + typ.unwrap().1, )); result.push(Statement::RetValue(d, new_id)); } else { @@ -558,7 +566,10 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst), }, Statement::Conditional(mut bra) => { - let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred))); + let generated_id = id_def.new_id(Some(( + StateSpace::Reg, + ast::Type::Scalar(ast::ScalarType::Pred), + ))); result.push(Statement::LoadVar( Arg2 { dst: generated_id, @@ -607,11 +618,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( let mut post_statements = Vec::new(); let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor, _| { let id_type = match (id_def.get_type(desc.op), desc.sema) { - (Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t, - (Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64), + (Some((_, t)), ArgumentSemantics::ParamPtr) + | (Some((_, t)), ArgumentSemantics::Default) => t, + (Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64), (None, _) => return desc.op, }; - let generated_id = id_def.new_id(Some(id_type)); + let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type))); if !desc.is_dst { result.push(Statement::LoadVar( Arg2 { @@ -716,11 +728,13 @@ impl<'a, 'b> ArgumentMapVisitor } else { todo!() }; - let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + let id = self + .id_def + .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, - value: x, + value: x as i64, })); id } @@ -732,13 +746,14 @@ impl<'a, 'b> ArgumentMapVisitor } else { todo!() }; - let id_constant_stmt = - self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - let result_id = self.id_def.new_id(Some(typ)); + let id_constant_stmt = self + .id_def + .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); + let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: scalar_t, - value: offset as i128, + value: offset as i64, })); let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); self.func.push(Statement::Instruction( @@ -758,13 +773,14 @@ impl<'a, 'b> ArgumentMapVisitor } ArgumentSemantics::Ptr => { let scalar_t = ast::ScalarType::U64; - let id_constant_stmt = - self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); - let result_id = self.id_def.new_id(Some(typ)); + let id_constant_stmt = self + .id_def + .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t)))); + let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ))); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: scalar_t, - value: offset as i128, + value: offset as i64, })); let int_type = ast::IntType::U64; self.func.push(Statement::Instruction( @@ -810,9 +826,10 @@ impl<'a, 'b> ArgumentMapVisitor desc: ArgumentDescriptor<(spirv::Word, u8)>, (scalar_type, vec_len): (ast::MovVectorType, u8), ) -> spirv::Word { - let new_id = self - .id_def - .new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len))); + let new_id = self.id_def.new_id(Some(( + StateSpace::Reg, + ast::Type::Vector(scalar_type.into(), vec_len), + ))); self.func.push(Statement::Composite(CompositeRead { typ: scalar_type, dst: new_id, @@ -821,6 +838,14 @@ impl<'a, 'b> ArgumentMapVisitor })); new_id } + + fn mov_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> spirv::Word { + self.operand(desc, typ) + } } /* @@ -911,7 +936,7 @@ fn insert_implicit_conversions( let mut did_vector_implicit = false; let mut post_conv = None; if inst_typ_is_bit { - let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()); + let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1; if let ast::Type::Vector(_, _) = src_type { arg.src = insert_conversion_src( &mut result, @@ -923,7 +948,7 @@ fn insert_implicit_conversions( ); did_vector_implicit = true; } - let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()); + let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1; if let ast::Type::Vector(_, _) = src_type { post_conv = Some(get_conversion_dst( id_def, @@ -1615,25 +1640,32 @@ fn expand_map_variables<'a, 'b>( p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))), i.map_variable(&mut |id| id_defs.get_id(id)), ))), - ast::Statement::Variable(var) => match var.count { - Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) { + ast::Statement::Variable(var) => { + let ss = match var.var.v_type { + ast::VariableType::Reg(_) => StateSpace::Reg, + ast::VariableType::Local(_) => StateSpace::Local, + ast::VariableType::Param(_) => StateSpace::ParamReg, + }; + match var.count { + Some(count) => { + for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) { + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type, + name: new_id, + })) + } + } + None => { + let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into()))); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type, name: new_id, - })) + })); } } - None => { - let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into())); - result.push(Statement::Variable(ast::Variable { - align: var.var.align, - v_type: var.var.v_type, - name: new_id, - })); - } - }, + } } } @@ -1766,7 +1798,7 @@ struct FnStringIdResolver<'input, 'b> { global_variables: &'b HashMap, spirv::Word>, special_registers: &'b mut HashMap, variables: Vec, spirv::Word>>, - type_check: HashMap, + type_check: HashMap, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -1809,7 +1841,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { } } - fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { + fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() @@ -1827,6 +1859,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { &mut self, base_id: &'a str, count: u32, + ss: StateSpace, typ: ast::Type, ) -> impl Iterator { let numeric_id = *self.current_id; @@ -1835,7 +1868,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check.insert(numeric_id + i, typ); + self.type_check.insert(numeric_id + i, (ss, typ)); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -1844,15 +1877,15 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - type_check: HashMap, + type_check: HashMap, } impl<'b> NumericIdResolver<'b> { - fn get_type(&self, id: spirv::Word) -> Option { + fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> { self.type_check.get(&id).map(|x| *x) } - fn new_id(&mut self, typ: Option) -> spirv::Word { + fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word { let new_id = *self.current_id; if let Some(typ) = typ { self.type_check.insert(new_id, typ); @@ -1982,16 +2015,48 @@ type UnadornedStatement = Statement, Norma impl ast::ArgParams for NormalizedArgParams { type ID = spirv::Word; type Operand = ast::Operand; + type MemoryOperand = ast::Operand; type CallOperand = ast::CallOperand; type VecOperand = (spirv::Word, u8); } +enum TypedArgParams {} +impl ast::ArgParams for TypedArgParams { + type ID = spirv::Word; + type Operand = ast::Operand; + type MemoryOperand = MemoryOperand; + type CallOperand = ast::CallOperand; + type VecOperand = (spirv::Word, u8); +} +type TypedStatement = Statement, TypedArgParams>; + impl ArgParamsEx for NormalizedArgParams { fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl { decl.get_fn_decl(*id) } } +#[derive(Copy, Clone)] +pub enum StateSpace { + Reg, + Sreg, + Const, + Global, + Local, + Shared, + Param, + ParamReg, +} + +#[derive(Copy, Clone)] +pub enum MemoryOperand { + Reg(spirv::Word), + Address(spirv::Word), + RegOffset(spirv::Word, i32), + AddressOffset(spirv::Word, i32), + Imm(u32), +} + enum ExpandedArgParams {} type ExpandedStatement = Statement, ExpandedArgParams>; type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>; @@ -1999,6 +2064,7 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word; type Operand = spirv::Word; + type MemoryOperand = spirv::Word; type CallOperand = spirv::Word; type VecOperand = spirv::Word; } @@ -2012,6 +2078,11 @@ impl ArgParamsEx for ExpandedArgParams { trait ArgumentMapVisitor { fn variable(&mut self, desc: ArgumentDescriptor, typ: Option) -> U::ID; fn operand(&mut self, desc: ArgumentDescriptor, typ: ast::Type) -> U::Operand; + fn mov_operand( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> U::MemoryOperand; fn src_call_operand( &mut self, desc: ArgumentDescriptor, @@ -2035,9 +2106,15 @@ where ) -> spirv::Word { self(desc, t) } + fn operand(&mut self, desc: ArgumentDescriptor, t: ast::Type) -> spirv::Word { self(desc, Some(t)) } + + fn mov_operand(&mut self, desc: ArgumentDescriptor, t: ast::Type) -> spirv::Word { + self(desc, Some(t)) + } + fn src_call_operand( &mut self, desc: ArgumentDescriptor, @@ -2045,6 +2122,7 @@ where ) -> spirv::Word { self(desc, Some(t)) } + fn src_vec_operand( &mut self, desc: ArgumentDescriptor, @@ -2095,6 +2173,14 @@ where ) -> (spirv::Word, u8) { (self(desc.op.0), desc.op.1) } + + fn mov_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> ast::Operand { + self.operand(desc, typ) + } } struct ArgumentDescriptor { @@ -2260,6 +2346,16 @@ where desc.op.1, ) } + + fn mov_operand( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> ast::Operand { + >::operand( + self, desc, typ, + ) + } } impl ast::Type { @@ -2365,7 +2461,7 @@ struct CompositeRead { struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, - pub value: i128, + pub value: i64, } struct BrachCondition { @@ -2534,7 +2630,7 @@ impl ast::Arg2St { is_param: bool, ) -> ast::Arg2St { ast::Arg2St { - src1: visitor.operand( + src1: visitor.mov_operand( ArgumentDescriptor { op: self.src1, is_dst: is_param, @@ -3012,6 +3108,16 @@ impl From for ast::VariableType { } } +impl ast::Operand { + fn underlying(&self) -> Option<&T> { + match self { + ast::Operand::Reg(r) => Some(r), + ast::Operand::RegOffset(r, _) => Some(r), + ast::Operand::Imm(_) => None, + } + } +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { @@ -3053,7 +3159,7 @@ fn insert_with_conversions ast::Instruction 0 { - let new_id = id_def.new_id(Some(post_conv[0].from)); + let new_id = id_def.new_id(Some((StateSpace::Reg, post_conv[0].from))); post_conv[0].src = new_id; post_conv.last_mut().unwrap().dst = *dst(&mut instr); *dst(&mut instr) = new_id; @@ -3078,7 +3184,7 @@ fn insert_with_conversions_pre_conv( conv.src = *original_src; } if i == pre_conv_len - 1 { - let new_id = id_def.new_id(Some(conv.to)); + let new_id = id_def.new_id(Some((StateSpace::Reg, conv.to))); conv.dst = new_id; *original_src = new_id; } @@ -3095,7 +3201,7 @@ fn get_implicit_conversions_ld_dst< should_convert: ShouldConvert, in_reverse: bool, ) -> Option { - let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()); + let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1; if let Some(conv) = should_convert(dst_type, instr_type) { Some(ImplicitConversion { src: u32::max_value(), @@ -3115,7 +3221,7 @@ fn get_implicit_conversions_ld_src( state_space: ast::LdStateSpace, src: spirv::Word, ) -> Vec { - let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()); + let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1; match state_space { ast::LdStateSpace::Param => { if src_type != instr_type { @@ -3162,7 +3268,7 @@ fn get_implicit_conversions_ld_src( kind: ConversionKind::Ptr(state_space), }); if result.len() == 2 { - let new_id = id_def.new_id(Some(new_src_type)); + let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type))); result[0].dst = new_id; result[1].src = new_id; result[1].from = new_src_type; @@ -3221,9 +3327,9 @@ fn insert_implicit_conversions_ld_src_impl< src: spirv::Word, should_convert: ShouldConvert, ) -> spirv::Word { - let src_type = id_def.get_type(src); - if let Some(conv) = should_convert(src_type.unwrap(), instr_type) { - insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv) + let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1; + if let Some(conv) = should_convert(src_type, instr_type) { + insert_conversion_src(func, id_def, src, src_type, instr_type, conv) } else { src } @@ -3263,7 +3369,7 @@ fn insert_conversion_src( instr_type: ast::Type, conv: ConversionKind, ) -> spirv::Word { - let temp_src = id_def.new_id(Some(instr_type)); + let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type))); func.push(Statement::Conversion(ImplicitConversion { src: src, dst: temp_src, @@ -3309,7 +3415,7 @@ fn get_conversion_dst( kind: ConversionKind, ) -> ExpandedStatement { let original_dst = *dst; - let temp_dst = id_def.new_id(Some(instr_type)); + let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type))); *dst = temp_dst; Statement::Conversion(ImplicitConversion { src: temp_dst, @@ -3428,8 +3534,8 @@ fn insert_implicit_bitcasts( Some(t) => t, None => return desc.op, }; - let id_actual_type = id_def.get_type(desc.op).unwrap(); - if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) { + let id_actual_type = id_def.get_type(desc.op).unwrap().1; + if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) { if desc.is_dst { dst_coercion = Some(get_conversion_dst( id_def,