From 1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 27 Sep 2020 23:51:34 +0200 Subject: [PATCH] Implement vector-destructuring mov/ld/st --- ptx/src/ast.rs | 181 ++-- ptx/src/ptx.lalrpop | 136 +-- ptx/src/test/spirv_run/ntid.spvtxt | 13 +- ptx/src/translate.rs | 1229 ++++++++++++++++++++-------- 4 files changed, 1064 insertions(+), 495 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index acefdc1..7edfa70 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -35,6 +35,19 @@ macro_rules! sub_scalar_type { } } } + + impl std::convert::TryFrom for $name { + type Error = (); + + fn try_from(t: ScalarType) -> Result { + match t { + $( + ScalarType::$variant => Ok($name::$variant), + )+ + _ => Err(()), + } + } + } }; } @@ -159,20 +172,20 @@ pub struct Module<'a> { pub functions: Vec>, } -pub enum MethodDecl<'a, P: ArgParams> { - Func(Vec>, P::ID, Vec>), - Kernel(&'a str, Vec>), +pub enum MethodDecl<'a, ID> { + Func(Vec>, ID, Vec>), + Kernel(&'a str, Vec>), } -pub type FnArgument

= Variable; -pub type KernelArgument

= Variable; +pub type FnArgument = Variable; +pub type KernelArgument = Variable; -pub struct Function<'a, P: ArgParams, S> { - pub func_directive: MethodDecl<'a, P>, +pub struct Function<'a, ID, S> { + pub func_directive: MethodDecl<'a, ID>, pub body: Option>, } -pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement>>; +pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; #[derive(PartialEq, Eq, Clone, Copy)] pub enum FnArgumentType { @@ -264,21 +277,21 @@ impl Default for ScalarType { } pub enum Statement { - Label(P::ID), - Variable(MultiVariable

), - Instruction(Option>, Instruction

), + Label(P::Id), + Variable(MultiVariable), + Instruction(Option>, Instruction

), Block(Vec>), } -pub struct MultiVariable { - pub var: Variable, +pub struct MultiVariable { + pub var: Variable, pub count: Option, } -pub struct Variable { +pub struct Variable { pub align: Option, pub v_type: T, - pub name: P::ID, + pub name: ID, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -315,9 +328,8 @@ pub struct PredAt { } pub enum Instruction { - Ld(LdData, Arg2

), - Mov(MovDetails, Arg2

), - MovVector(MovVectorDetails, Arg2Vec

), + Ld(LdDetails, Arg2Ld

), + Mov(MovDetails, Arg2Mov

), Mul(MulDetails, Arg3

), Add(AddDetails, Arg3

), Setp(SetpData, Arg4Setp

), @@ -337,11 +349,6 @@ pub enum Instruction { #[derive(Copy, Clone)] pub struct MadFloatDesc {} -#[derive(Copy, Clone)] -pub struct MovVectorDetails { - pub typ: MovVectorType, - pub length: u8, -} #[derive(Copy, Clone)] pub struct AbsDetails { pub flush_to_zero: bool, @@ -350,16 +357,18 @@ pub struct AbsDetails { pub struct CallInst { pub uniform: bool, - pub ret_params: Vec, - pub func: P::ID, + pub ret_params: Vec, + pub func: P::Id, pub param_list: Vec, } pub trait ArgParams { - type ID; + type Id; type Operand; + type IdOrVector; + type OperandOrVector; type CallOperand; - type VecOperand; + type SrcMemberOperand; } pub struct ParsedArgParams<'a> { @@ -367,57 +376,73 @@ pub struct ParsedArgParams<'a> { } impl<'a> ArgParams for ParsedArgParams<'a> { - type ID = &'a str; + type Id = &'a str; type Operand = Operand<&'a str>; type CallOperand = CallOperand<&'a str>; - type VecOperand = (&'a str, u8); + type IdOrVector = IdOrVector<&'a str>; + type OperandOrVector = OperandOrVector<&'a str>; + type SrcMemberOperand = (&'a str, u8); } pub struct Arg1 { - pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand + pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand } pub struct Arg2 { - pub dst: P::ID, + pub dst: P::Id, + pub src: P::Operand, +} +pub struct Arg2Ld { + pub dst: P::IdOrVector, pub src: P::Operand, } pub struct Arg2St { pub src1: P::Operand, - pub src2: P::Operand, + pub src2: P::OperandOrVector, +} + +pub enum Arg2Mov { + Normal(Arg2MovNormal

), + Member(Arg2MovMember

), +} + +pub struct Arg2MovNormal { + pub dst: P::IdOrVector, + pub src: P::OperandOrVector, } // We duplicate dst here because during further compilation // composite dst and composite src will receive different ids -pub enum Arg2Vec { - Dst((P::ID, u8), P::ID, P::ID), - Src(P::ID, P::VecOperand), - Both((P::ID, u8), P::ID, P::VecOperand), +pub enum Arg2MovMember { + Dst((P::Id, u8), P::Id, P::Id), + Src(P::Id, P::SrcMemberOperand), + Both((P::Id, u8), P::Id, P::SrcMemberOperand), } pub struct Arg3 { - pub dst: P::ID, + pub dst: P::Id, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg4 { - pub dst: P::ID, + pub dst: P::Id, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, } pub struct Arg4Setp { - pub dst1: P::ID, - pub dst2: Option, + pub dst1: P::Id, + pub dst2: Option, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg5 { - pub dst1: P::ID, - pub dst2: Option, + pub dst1: P::Id, + pub dst2: Option, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, @@ -436,12 +461,34 @@ pub enum CallOperand { Imm(u32), } +pub enum IdOrVector { + Reg(ID), + Vec(Vec) +} + +pub enum OperandOrVector { + Reg(ID), + RegOffset(ID, i32), + Imm(u32), + Vec(Vec) +} + +impl From> for OperandOrVector { + fn from(this: Operand) -> Self { + match this { + Operand::Reg(r) => OperandOrVector::Reg(r), + Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), + Operand::Imm(imm) => OperandOrVector::Imm(imm), + } + } +} + pub enum VectorPrefix { V2, V4, } -pub struct LdData { +pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, @@ -482,45 +529,23 @@ pub enum LdCacheOperator { Uncached, } -sub_scalar_type!(MovScalarType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, - Pred, -}); - -// pred vectors are illegal -sub_scalar_type!(MovVectorType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); - +#[derive(Copy, Clone)] pub struct MovDetails { - pub typ: MovType, + pub typ: Type, pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, } -sub_type! { - MovType { - Scalar(MovScalarType), - Vector(MovVectorType, u8), +impl MovDetails { + pub fn new(typ: Type) -> Self { + MovDetails { + typ, + src_is_address: false, + dst_width: 0, + src_width: 0 + } } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 50a6aeb..ba3fc2b 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -194,7 +194,7 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option, ast::Statement>>> = { +Directive: Option>>> = { AddressSize => None, => Some(f), File => None, @@ -205,7 +205,7 @@ AddressSize = { ".address_size" Num }; -Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement>> = { +Function: ast::Function<'input, &'input str, ast::Statement>> = { LinkingDirective* => ast::Function{<>} @@ -217,29 +217,29 @@ LinkingDirective = { ".weak" }; -MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = { +MethodDecl: ast::MethodDecl<'input, &'input str> = { ".entry" => ast::MethodDecl::Kernel(name, params), ".func" => { ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) } }; -KernelArguments: Vec>> = { +KernelArguments: Vec> = { "(" > ")" => args }; -FnArguments: Vec>> = { +FnArguments: Vec> = { "(" > ")" => args }; -KernelInput: ast::Variable> = { +KernelInput: ast::Variable = { => { let (align, v_type, name) = v; ast::Variable{ align, v_type, name } } } -FnInput: ast::Variable> = { +FnInput: ast::Variable = { => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Reg(v_type); @@ -320,7 +320,7 @@ Align: u32 = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names -MultiVariable: ast::MultiVariable> = { +MultiVariable: ast::MultiVariable<&'input str> = { => ast::MultiVariable{<>} } @@ -331,7 +331,7 @@ VariableParam: u32 = { } } -Variable: ast::Variable> = { +Variable: ast::Variable = { => { let (align, v_type, name) = v; let v_type = ast::VariableType::Reg(v_type); @@ -356,7 +356,7 @@ RegVariable: (Option, ast::VariableRegType, &'input str) = { } } -LocalVariable: ast::Variable> = { +LocalVariable: ast::Variable = { ".local" => { let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); ast::Variable {align, v_type, name} @@ -449,19 +449,29 @@ Instruction: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction> = { - "ld" "," => { + "ld" "," => { ast::Instruction::Ld( - ast::LdData { + ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::LdStateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t }, - ast::Arg2 { dst:dst, src:src } + ast::Arg2Ld { dst:dst, src:src } ) } }; +IdOrVector: ast::IdOrVector<&'input str> = { + => ast::IdOrVector::Reg(dst), + => ast::IdOrVector::Vec(dst) +} + +OperandOrVector: ast::OperandOrVector<&'input str> = { + => ast::OperandOrVector::from(op), + => ast::OperandOrVector::Vec(dst) +} + LdStType: ast::Type = { => ast::Type::Vector(t, v), => ast::Type::Scalar(t), @@ -498,49 +508,58 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction> = { - "mov" => { - ast::Instruction::Mov(ast::MovDetails{ src_is_address: false, typ: t }, a) - }, - "mov" => { - ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a) - } + => ast::Instruction::Mov(m.0, m.1), + => ast::Instruction::Mov(m.0, m.1), }; -#[inline] -MovType: ast::MovType = { - => ast::MovType::Scalar(t), - => ast::MovType::Vector(t, pref) + +MovNormal: (ast::MovDetails, ast::Arg2Mov>) = { + "mov" "," => {( + ast::MovDetails::new(ast::Type::Scalar(t)), + ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() }) + )}, + "mov" "," => {( + ast::MovDetails::new(ast::Type::Vector(t, pref)), + ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src }) + )} +} + +MovVector: (ast::MovDetails, ast::Arg2Mov>) = { + "mov" => {( + ast::MovDetails::new(ast::Type::Scalar(t.into())), + ast::Arg2Mov::Member(a) + )}, } #[inline] -MovScalarType: ast::MovScalarType = { - ".b16" => ast::MovScalarType::B16, - ".b32" => ast::MovScalarType::B32, - ".b64" => ast::MovScalarType::B64, - ".u16" => ast::MovScalarType::U16, - ".u32" => ast::MovScalarType::U32, - ".u64" => ast::MovScalarType::U64, - ".s16" => ast::MovScalarType::S16, - ".s32" => ast::MovScalarType::S32, - ".s64" => ast::MovScalarType::S64, - ".f32" => ast::MovScalarType::F32, - ".f64" => ast::MovScalarType::F64, - ".pred" => ast::MovScalarType::Pred +MovScalarType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, + ".pred" => ast::ScalarType::Pred }; #[inline] -MovVectorType: ast::MovVectorType = { - ".b16" => ast::MovVectorType::B16, - ".b32" => ast::MovVectorType::B32, - ".b64" => ast::MovVectorType::B64, - ".u16" => ast::MovVectorType::U16, - ".u32" => ast::MovVectorType::U32, - ".u64" => ast::MovVectorType::U64, - ".s16" => ast::MovVectorType::S16, - ".s32" => ast::MovVectorType::S32, - ".s64" => ast::MovVectorType::S64, - ".f32" => ast::MovVectorType::F32, - ".f64" => ast::MovVectorType::F64, +MovVectorType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul @@ -902,7 +921,7 @@ ShlType: ast::ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { - "st" "," => { + "st" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -1044,13 +1063,13 @@ Arg2: ast::Arg2> = { "," => ast::Arg2{<>} }; -Arg2Vec: ast::Arg2Vec> = { - "," => ast::Arg2Vec::Dst(dst, dst.0, src), - "," => ast::Arg2Vec::Src(dst, src), - "," => ast::Arg2Vec::Both(dst, dst.0, src), +Arg2MovMember: ast::Arg2MovMember> = { + "," => ast::Arg2MovMember::Dst(dst, dst.0, src), + "," => ast::Arg2MovMember::Src(dst, src), + "," => ast::Arg2MovMember::Both(dst, dst.0, src), }; -VectorOperand: (&'input str, u8) = { +MemberOperand: (&'input str, u8) = { "." =>? { let suf_idx = vector_index(suf)?; Ok((pref, suf_idx)) @@ -1061,6 +1080,15 @@ VectorOperand: (&'input str, u8) = { } }; +VectorExtract: Vec<&'input str> = { + "{" "," "}" => { + vec![r1, r2] + }, + "{" "," "," "," "}" => { + vec![r1, r2, r3, r4] + }, +}; + Arg3: ast::Arg3> = { "," "," => ast::Arg3{<>} }; diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index ef308f0..be16d2e 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -4,15 +4,16 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 + OpCapability Float64 %29 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" %GlobalSize - OpDecorate %GlobalSize BuiltIn GlobalSize + OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize %void = OpTypeVoid %uint = OpTypeInt 32 0 - %v3uint = OpTypeVector %uint 3 -%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint - %GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant + %v4uint = OpTypeVector %uint 4 +%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint +%gl_WorkGroupSize = OpVariable %_ptr_UniformConstant_v4uint UniformConstant %ulong = OpTypeInt 64 0 %35 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong @@ -40,7 +41,7 @@ %25 = OpConvertUToPtr %_ptr_Generic_uint %16 %15 = OpLoad %uint %25 OpStore %6 %15 - %18 = OpLoad %v3uint %GlobalSize + %18 = OpLoad %v4uint %gl_WorkGroupSize %24 = OpCompositeExtract %uint %18 0 %17 = OpCopyObject %uint %24 OpStore %7 %17 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a1d4b6a..981da86 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,6 +1,7 @@ use crate::ast; use rspirv::{binary::Disassemble, dr}; use std::collections::{hash_map, HashMap, HashSet}; +use std::convert::TryInto; use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; @@ -282,7 +283,7 @@ fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, global: &GlobalStringIdResolver<'a>, - func_directive: ast::MethodDecl, + func_directive: ast::MethodDecl, all_args_lens: &mut HashMap>, ) -> Result<(), TranslateError> { if let ast::MethodDecl::Kernel(name, args) = &func_directive { @@ -334,8 +335,10 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Addresses); builder.capability(spirv::Capability::Kernel); - builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Int8); + builder.capability(spirv::Capability::Int16); + builder.capability(spirv::Capability::Int64); + builder.capability(spirv::Capability::Float16); builder.capability(spirv::Capability::Float64); } @@ -362,8 +365,8 @@ fn to_ssa_function<'a>( fn expand_kernel_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>>, -) -> Vec> { + args: impl Iterator>, +) -> Vec> { args.map(|a| ast::KernelArgument { name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))), v_type: a.v_type, @@ -374,8 +377,8 @@ fn expand_kernel_params<'a, 'b>( fn expand_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: impl Iterator>>, -) -> Vec> { + args: impl Iterator>, +) -> Vec> { args.map(|a| { let ss = match a.v_type { ast::FnArgumentType::Reg(_) => StateSpace::Reg, @@ -393,7 +396,7 @@ fn expand_fn_params<'a, 'b>( fn to_ssa<'input, 'b>( mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, - f_args: ast::MethodDecl<'input, ExpandedArgParams>, + f_args: ast::MethodDecl<'input, spirv::Word>, f_body: Option>>>, ) -> Result, TranslateError> { let f_body = match f_body { @@ -409,11 +412,11 @@ fn to_ssa<'input, 'b>( let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); - let unadorned_statements = - add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + let typed_statements = + convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let (f_args, ssa_statements) = - insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args)?; + insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, f_args)?; let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; @@ -443,15 +446,16 @@ fn normalize_variable_decls(mut func: Vec) -> Vec, +fn convert_to_typed_statements( + func: Vec, fn_defs: &GlobalFnDeclResolver, id_defs: &NumericIdResolver, -) -> Result, TranslateError> { - func.into_iter() - .map(|s| { - match s { - Statement::Instruction(ast::Instruction::Call(call)) => { +) -> Result, TranslateError> { + let mut result = Vec::::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Call(call) => { // TODO: error out if lengths don't match let fn_def = fn_defs.get_fn_decl(call.func)?; let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); @@ -462,7 +466,7 @@ fn add_types_to_statements( func: call.func, param_list, }; - Ok(Statement::Call(resolved_call)) + result.push(Statement::Call(resolved_call)); } // Supported ld/st: // global: only compatible with reg b64/u64/s64 source/dest @@ -477,25 +481,24 @@ fn add_types_to_statements( // One complication: immediate address is only allowed in local, // It is not supported in generic ld // ld.local foo, [1]; - Statement::Instruction(ast::Instruction::Ld(mut d, arg)) => { + ast::Instruction::Ld(mut d, arg) => { match arg.src.underlying() { - None => return Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))), + None => {} Some(u) => { let (ss, _) = id_defs.get_typed(*u)?; match (d.state_space, ss) { (ast::LdStateSpace::Generic, StateSpace::Local) => { d.state_space = ast::LdStateSpace::Local; } - _ => (), + _ => {} }; } }; - - Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))) + result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast()))); } - Statement::Instruction(ast::Instruction::St(mut d, arg)) => { + ast::Instruction::St(mut d, arg) => { match arg.src1.underlying() { - None => return Ok(Statement::Instruction(ast::Instruction::St(d, arg))), + None => {} Some(u) => { let (ss, _) = id_defs.get_typed(*u)?; match (d.state_space, ss) { @@ -506,39 +509,101 @@ fn add_types_to_statements( }; } }; - Ok(Statement::Instruction(ast::Instruction::St(d, arg))) + result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast()))); } - Statement::Instruction(ast::Instruction::Mov(mut d, arg)) => { - if let Some(src_id) = arg.src.underlying() { - let (scope, _) = id_defs.get_typed(*src_id)?; - d.src_is_address = match scope { - StateSpace::Reg => false, - StateSpace::Const - | StateSpace::Global - | StateSpace::Local - | StateSpace::Shared - | StateSpace::Param - | StateSpace::ParamReg => true, - }; + ast::Instruction::Mov(mut d, args) => match args { + ast::Arg2Mov::Normal(arg) => { + if let Some(src_id) = arg.src.single_underlying() { + let (scope, _) = id_defs.get_typed(*src_id)?; + d.src_is_address = match scope { + StateSpace::Reg => false, + StateSpace::Const + | StateSpace::Global + | StateSpace::Local + | StateSpace::Shared + | StateSpace::Param + | StateSpace::ParamReg => true, + }; + } + result.push(Statement::Instruction(ast::Instruction::Mov( + d, + ast::Arg2Mov::Normal(arg.cast()), + ))); } - Ok(Statement::Instruction(ast::Instruction::Mov(d, arg))) + ast::Arg2Mov::Member(args) => { + if let Some(dst_typ) = args.vector_dst() { + match id_defs.get_typed(*dst_typ)? { + (_, ast::Type::Vector(_, len)) => { + d.dst_width = len; + } + _ => return Err(TranslateError::MismatchedType), + } + }; + if let Some((src_typ, _)) = args.vector_src() { + match id_defs.get_typed(*src_typ)? { + (_, ast::Type::Vector(_, len)) => { + d.src_width = len; + } + _ => return Err(TranslateError::MismatchedType), + } + }; + result.push(Statement::Instruction(ast::Instruction::Mov( + d, + ast::Arg2Mov::Member(args.cast()), + ))); + } + }, + ast::Instruction::Mul(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast()))) } - Statement::Instruction(ast::Instruction::MovVector(dets, args)) => { - let new_dets = match id_defs.get_typed(*args.dst())? { - (_, ast::Type::Vector(_, len)) => ast::MovVectorDetails { - length: len, - ..dets - }, - _ => dets, - }; - Ok(Statement::Instruction(ast::Instruction::MovVector( - new_dets, args, - ))) + ast::Instruction::Add(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast()))) } - s => Ok(s), - } - }) - .collect::, _>>() + ast::Instruction::Setp(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast()))) + } + ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction( + ast::Instruction::SetpBool(d, a.cast()), + )), + ast::Instruction::Not(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast()))) + } + ast::Instruction::Bra(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast()))) + } + ast::Instruction::Cvt(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast()))) + } + ast::Instruction::Cvta(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast()))) + } + ast::Instruction::Shl(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast()))) + } + ast::Instruction::Ret(d) => { + result.push(Statement::Instruction(ast::Instruction::Ret(d))) + } + ast::Instruction::Abs(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast()))) + } + ast::Instruction::Mad(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast()))) + } + }, + Statement::Label(i) => result.push(Statement::Label(i)), + Statement::Variable(v) => result.push(Statement::Variable(v)), + Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)), + Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)), + Statement::Call(c) => result.push(Statement::Call(c.cast())), + Statement::Composite(c) => result.push(Statement::Composite(c)), + Statement::Conditional(c) => result.push(Statement::Conditional(c)), + Statement::Conversion(c) => result.push(Statement::Conversion(c)), + Statement::Constant(c) => result.push(Statement::Constant(c)), + Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), + Statement::Undef(_, _) => return Err(TranslateError::Unreachable), + } + } + Ok(result) } fn to_resolved_fn_args( @@ -576,7 +641,8 @@ fn normalize_labels( | Statement::RetValue(_, _) | Statement::Conversion(_) | Statement::Constant(_) - | Statement::Label(_) => (), + | Statement::Label(_) + | Statement::Undef(_, _) => (), } } iter::once(Statement::Label(id_def.new_id(None))) @@ -590,7 +656,7 @@ fn normalize_labels( fn normalize_predicates( func: Vec, id_def: &mut NumericIdResolver, -) -> Vec { +) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { match s { @@ -630,16 +696,10 @@ fn normalize_predicates( } fn insert_mem_ssa_statements<'a, 'b>( - func: Vec, + func: Vec, id_def: &mut MutableNumericIdResolver, - mut f_args: ast::MethodDecl<'a, ExpandedArgParams>, -) -> Result< - ( - ast::MethodDecl<'a, ExpandedArgParams>, - Vec, - ), - TranslateError, -> { + mut f_args: ast::MethodDecl<'a, spirv::Word>, +) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec), TranslateError> { let mut result = Vec::with_capacity(func.len()); let out_param = match &mut f_args { ast::MethodDecl::Kernel(_, in_params) => { @@ -697,7 +757,9 @@ fn insert_mem_ssa_statements<'a, 'b>( }; for s in func { match s { - Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call)?, + Statement::Call(call) => { + insert_mem_ssa_statement_default(id_def, &mut result, call.cast())? + } Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { if let Some(out_param) = out_param { @@ -734,7 +796,8 @@ fn insert_mem_ssa_statements<'a, 'b>( | Statement::StoreVar(_, _) | Statement::Conversion(_) | Statement::RetValue(_, _) - | Statement::Constant(_) => unreachable!(), + | Statement::Constant(_) + | Statement::Undef(_, _) => {} Statement::Composite(_) => todo!(), } } @@ -751,7 +814,7 @@ trait VisitVariable: Sized { >( self, f: &mut F, - ) -> Result; + ) -> Result; } trait VisitVariableExpanded { fn visit_variable_extended< @@ -767,7 +830,7 @@ trait VisitVariableExpanded { fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( id_def: &mut MutableNumericIdResolver, - result: &mut Vec, + result: &mut Vec, stmt: F, ) -> Result<(), TranslateError> { let mut post_statements = Vec::new(); @@ -808,7 +871,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( } fn expand_arguments<'a, 'b>( - func: Vec, + func: Vec, id_def: &'b mut MutableNumericIdResolver<'a>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); @@ -840,9 +903,10 @@ fn expand_arguments<'a, 'b>( Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), - Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => { - unreachable!() - } + Statement::Composite(_) + | Statement::Conversion(_) + | Statement::Constant(_) + | Statement::Undef(_, _) => unreachable!(), } } Ok(result) @@ -865,12 +929,26 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { post_stmts: Vec::new(), } } -} -impl<'a, 'b> ArgumentMapVisitor - for FlattenArguments<'a, 'b> -{ - fn variable( + fn insert_composite_read( + func: &mut Vec, + id_def: &mut MutableNumericIdResolver<'a>, + (scalar_type, vec_len): (ast::ScalarType, u8), + scalar_dst: Option, + composite_src: (spirv::Word, u8), + ) -> spirv::Word { + let new_id = + scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len))); + func.push(Statement::Composite(CompositeRead { + typ: scalar_type, + dst: new_id, + src_composite: composite_src.0, + src_index: composite_src.1 as u32, + })); + new_id + } + + fn reg( &mut self, desc: ArgumentDescriptor, _: Option, @@ -878,90 +956,177 @@ impl<'a, 'b> ArgumentMapVisitor Ok(desc.op) } + fn reg_offset( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, i32)>, + typ: ast::Type, + ) -> Result { + let (reg, offset) = desc.op; + match desc.sema { + ArgumentSemantics::Default => { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + let result_id = self.id_def.new_id(typ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i64, + })); + let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + Ok(result_id) + } + ArgumentSemantics::PhysicalPointer => { + let scalar_t = ast::ScalarType::U64; + let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i64, + })); + let int_type = ast::IntType::U64; + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + Ok(result_id) + } + ArgumentSemantics::RegisterPointer => { + if offset == 0 { + return Ok(reg); + } + todo!() + } + ArgumentSemantics::Address => todo!(), + } + } + + fn immediate( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value: desc.op as i64, + })); + Ok(id) + } + + fn member_src( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, u8)>, + (scalar_type, vec_len): (ast::ScalarType, u8), + ) -> Result { + if desc.is_dst { + return Err(TranslateError::Unreachable); + } + let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len)); + self.func.push(Statement::Composite(CompositeRead { + typ: scalar_type, + dst: new_id, + src_composite: desc.op.0, + src_index: desc.op.1 as u32, + })); + Ok(new_id) + } + + fn vector( + &mut self, + desc: ArgumentDescriptor<&Vec>, + typ: ast::Type, + ) -> Result { + let (scalar_type, vec_len) = typ.get_vector()?; + if !desc.is_dst { + let mut new_id = self.id_def.new_id(typ); + self.func.push(Statement::Undef(typ, new_id)); + for (idx, id) in desc.op.iter().enumerate() { + let newer_id = self.id_def.new_id(typ); + self.func.push(Statement::Instruction(ast::Instruction::Mov( + ast::MovDetails { + typ: typ, + src_is_address: false, + dst_width: 0, + src_width: vec_len, + }, + ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( + (newer_id, idx as u8), + new_id, + *id, + )), + ))); + new_id = newer_id; + } + Ok(new_id) + } else { + let new_id = self.id_def.new_id(typ); + for (idx, id) in desc.op.iter().enumerate() { + Self::insert_composite_read( + &mut self.post_stmts, + self.id_def, + (scalar_type, vec_len), + Some(*id), + (new_id, idx as u8), + ); + } + Ok(new_id) + } + } +} + +impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { + fn id( + &mut self, + desc: ArgumentDescriptor, + t: Option, + ) -> Result { + self.reg(desc, t) + } + fn operand( &mut self, desc: ArgumentDescriptor>, typ: ast::Type, ) -> Result { match desc.op { - ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)), - ast::Operand::Imm(x) => { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id, - typ: scalar_t, - value: x as i64, - })); - Ok(id) + ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ), + ast::Operand::RegOffset(reg, offset) => { + self.reg_offset(desc.new_op((reg, offset)), typ) } - ast::Operand::RegOffset(reg, offset) => match desc.sema { - ArgumentSemantics::Default => { - let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar - } else { - todo!() - }; - let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - let result_id = self.id_def.new_id(typ); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - Ok(result_id) - } - ArgumentSemantics::PhysicalPointer => { - let scalar_t = ast::ScalarType::U64; - let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t)); - self.func.push(Statement::Constant(ConstantDefinition { - dst: id_constant_stmt, - typ: scalar_t, - value: offset as i64, - })); - let int_type = ast::IntType::U64; - self.func.push(Statement::Instruction( - ast::Instruction::::Add( - ast::AddDetails::Int(ast::AddIntDesc { - typ: int_type, - saturate: false, - }), - ast::Arg3 { - dst: result_id, - src1: reg, - src2: id_constant_stmt, - }, - ), - )); - Ok(result_id) - } - ArgumentSemantics::RegisterPointer => { - if offset == 0 { - return Ok(reg); - } - todo!() - } - ArgumentSemantics::Address => todo!(), - }, } } @@ -971,26 +1136,41 @@ impl<'a, 'b> ArgumentMapVisitor typ: ast::Type, ) -> Result { match desc.op { - ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)), - ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ), + ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), + ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ), } } - fn src_vec_operand( + fn src_member_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vec_len): (ast::MovVectorType, u8), + (scalar_type, vec_len): (ast::ScalarType, u8), ) -> Result { - let new_id = self - .id_def - .new_id(ast::Type::Vector(scalar_type.into(), vec_len)); - self.func.push(Statement::Composite(CompositeRead { - typ: scalar_type, - dst: new_id, - src_composite: desc.op.0, - src_index: desc.op.1 as u32, - })); - Ok(new_id) + self.member_src(desc, (scalar_type, vec_len)) + } + + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> Result { + match desc.op { + ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), + } + } + + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> Result { + match desc.op { + ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ), + ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ), + ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), + } } } @@ -1078,14 +1258,13 @@ fn insert_implicit_conversions( |arg| ast::Instruction::St(st, arg), ) } - ast::Instruction::Mov(d, mut arg) => { + ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2 // TODO: handle the case of mixed vector/scalar implicit conversions let inst_typ_is_bit = match d.typ { - ast::MovType::Scalar(t) => { - ast::ScalarType::from(t).kind() == ScalarKind::Bit - } - ast::MovType::Vector(_, _) => false, + ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit, + ast::Type::Vector(_, _) => false, + ast::Type::Array(_, _) => false, }; let mut did_vector_implicit = false; let mut post_conv = None; @@ -1115,12 +1294,15 @@ fn insert_implicit_conversions( } } if did_vector_implicit { - result.push(Statement::Instruction(ast::Instruction::Mov(d, arg))); + result.push(Statement::Instruction(ast::Instruction::Mov( + d, + ast::Arg2Mov::Normal(arg), + ))); } else { insert_implicit_bitcasts( &mut result, id_def, - ast::Instruction::Mov(d, arg), + ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)), )?; } if let Some(post_conv) = post_conv { @@ -1129,13 +1311,14 @@ fn insert_implicit_conversions( } inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?, }, - s @ Statement::Composite(_) - | s @ Statement::Conditional(_) + Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?, + s @ Statement::Conditional(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) | s @ Statement::LoadVar(_, _) | s @ Statement::StoreVar(_, _) + | s @ Statement::Undef(_, _) | s @ Statement::RetValue(_, _) => result.push(s), Statement::Conversion(_) => unreachable!(), } @@ -1146,7 +1329,7 @@ fn insert_implicit_conversions( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - method_decl: &ast::MethodDecl, + method_decl: &ast::MethodDecl, ) -> (spirv::Word, spirv::Word) { match method_decl { ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn( @@ -1173,7 +1356,7 @@ fn emit_function_body_ops( map: &mut TypeWordMap, opencl: spirv::Word, func: &[ExpandedStatement], -) -> Result<(), dr::Error> { +) -> Result<(), TranslateError> { for s in func { match s { Statement::Label(id) => { @@ -1305,11 +1488,34 @@ fn emit_function_body_ops( } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(d, arg) => { - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ))); - builder.copy_object(result_type, Some(arg.dst), arg.src)?; - } + ast::Instruction::Mov(d, arg) => match arg { + ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src }) + | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => { + let result_type = + map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ))); + builder.copy_object(result_type, Some(*dst), *src)?; + } + ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( + dst, + composite_src, + scalar_src, + )) + | ast::Arg2Mov::Member(ast::Arg2MovMember::Both( + dst, + composite_src, + scalar_src, + )) => { + let result_type = map.get_or_add(builder, SpirvType::from(d.typ)); + let result_id = Some(dst.0); + builder.composite_insert( + result_type, + result_id, + *scalar_src, + *composite_src, + [dst.1 as u32], + )?; + } + }, ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Int(ref ctr) => { emit_mul_int(builder, map, opencl, ctr, arg)?; @@ -1361,31 +1567,6 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::SetpBool(_, _) => todo!(), - ast::Instruction::MovVector(typ, arg) => match arg { - ast::Arg2Vec::Dst((dst, dst_index), composite_src, src) - | ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => { - let result_type = map.get_or_add( - builder, - SpirvType::Vector( - SpirvScalarKey::from(ast::ScalarType::from(typ.typ)), - typ.length, - ), - ); - let result_id = Some(*dst); - builder.composite_insert( - result_type, - result_id, - *src, - *composite_src, - [*dst_index as u32], - )?; - } - ast::Arg2Vec::Src(dst, src) => { - let result_type = - map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ)); - builder.copy_object(result_type, Some(*dst), *src)?; - } - }, ast::Instruction::Mad(mad, arg) => match mad { ast::MulDetails::Int(ref desc) => { emit_mad_int(builder, map, opencl, desc, arg)? @@ -1413,6 +1594,10 @@ fn emit_function_body_ops( [c.src_index], )?; } + Statement::Undef(t, id) => { + let result_type = map.get_or_add(builder, SpirvType::from(*t)); + builder.undef(result_type, Some(*id)); + } } } Ok(()) @@ -2016,11 +2201,11 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>( &'b mut self, - header: &'b ast::MethodDecl<'a, ast::ParsedArgParams<'a>>, + header: &'b ast::MethodDecl<'a, &'a str>, ) -> ( FnStringIdResolver<'a, 'b>, GlobalFnDeclResolver<'a, 'b>, - ast::MethodDecl<'a, ExpandedArgParams>, + ast::MethodDecl<'a, spirv::Word>, ) { // In case a function decl was inserted earlier we want to use its id let name_id = self.get_or_add_def(header.name()); @@ -2213,7 +2398,7 @@ impl<'b> MutableNumericIdResolver<'b> { enum Statement { Label(u32), - Variable(ast::Variable), + Variable(ast::Variable), Instruction(I), LoadVar(ast::Arg2, ast::Type), StoreVar(ast::Arg2St, ast::Type), @@ -2224,6 +2409,7 @@ enum Statement { Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), + Undef(ast::Type, spirv::Word), } struct ResolvedCall { @@ -2233,8 +2419,19 @@ struct ResolvedCall { pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>, } -impl> ResolvedCall { - fn map, V: ArgumentMapVisitor>( +impl ResolvedCall { + fn cast>(self) -> ResolvedCall { + ResolvedCall { + uniform: self.uniform, + ret_params: self.ret_params, + func: self.func, + param_list: self.param_list, + } + } +} + +impl> ResolvedCall { + fn map, V: ArgumentMapVisitor>( self, visitor: &mut V, ) -> Result, TranslateError> { @@ -2242,7 +2439,7 @@ impl> ResolvedCall { .ret_params .into_iter() .map::, _>(|(id, typ)| { - let new_id = visitor.variable( + let new_id = visitor.id( ArgumentDescriptor { op: id, is_dst: true, @@ -2253,7 +2450,7 @@ impl> ResolvedCall { Ok((new_id, typ)) }) .collect::, _>>()?; - let func = visitor.variable( + let func = visitor.id( ArgumentDescriptor { op: self.func, is_dst: false, @@ -2285,7 +2482,7 @@ impl> ResolvedCall { } } -impl VisitVariable for ResolvedCall { +impl VisitVariable for ResolvedCall { fn visit_variable< 'a, F: FnMut( @@ -2295,7 +2492,7 @@ impl VisitVariable for ResolvedCall { >( self, f: &mut F, - ) -> Result { + ) -> Result { Ok(Statement::Call(self.map(f)?)) } } @@ -2314,16 +2511,16 @@ impl VisitVariableExpanded for ResolvedCall { } } -pub trait ArgParamsEx: ast::ArgParams { +pub trait ArgParamsEx: ast::ArgParams + Sized { fn get_fn_decl<'x, 'b>( - id: &Self::ID, + id: &Self::Id, decl: &'b GlobalFnDeclResolver<'x, 'b>, ) -> Result<&'b FnDecl, TranslateError>; } impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { fn get_fn_decl<'x, 'b>( - id: &Self::ID, + id: &Self::Id, decl: &'b GlobalFnDeclResolver<'x, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl_str(id) @@ -2331,6 +2528,25 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { } enum NormalizedArgParams {} + +impl ast::ArgParams for NormalizedArgParams { + type Id = spirv::Word; + type Operand = ast::Operand; + type CallOperand = ast::CallOperand; + type IdOrVector = ast::IdOrVector; + type OperandOrVector = ast::OperandOrVector; + type SrcMemberOperand = (spirv::Word, u8); +} + +impl ArgParamsEx for NormalizedArgParams { + fn get_fn_decl<'a, 'b>( + id: &Self::Id, + decl: &'b GlobalFnDeclResolver<'a, 'b>, + ) -> Result<&'b FnDecl, TranslateError> { + decl.get_fn_decl(*id) + } +} + type NormalizedStatement = Statement< ( Option>, @@ -2338,18 +2554,46 @@ type NormalizedStatement = Statement< ), NormalizedArgParams, >; -type UnadornedStatement = Statement, NormalizedArgParams>; -impl ast::ArgParams for NormalizedArgParams { - type ID = spirv::Word; +type UnconditionalStatement = Statement, NormalizedArgParams>; + +enum TypedArgParams {} + +impl ast::ArgParams for TypedArgParams { + type Id = spirv::Word; type Operand = ast::Operand; type CallOperand = ast::CallOperand; - type VecOperand = (spirv::Word, u8); + type IdOrVector = ast::IdOrVector; + type OperandOrVector = ast::OperandOrVector; + type SrcMemberOperand = (spirv::Word, u8); } -impl ArgParamsEx for NormalizedArgParams { +impl ArgParamsEx for TypedArgParams { fn get_fn_decl<'a, 'b>( - id: &Self::ID, + id: &Self::Id, + decl: &'b GlobalFnDeclResolver<'a, 'b>, + ) -> Result<&'b FnDecl, TranslateError> { + decl.get_fn_decl(*id) + } +} + +type TypedStatement = Statement, TypedArgParams>; + +enum ExpandedArgParams {} +type ExpandedStatement = Statement, ExpandedArgParams>; + +impl ast::ArgParams for ExpandedArgParams { + type Id = spirv::Word; + type Operand = spirv::Word; + type CallOperand = spirv::Word; + type IdOrVector = spirv::Word; + type OperandOrVector = spirv::Word; + type SrcMemberOperand = spirv::Word; +} + +impl ArgParamsEx for ExpandedArgParams { + fn get_fn_decl<'a, 'b>( + id: &Self::Id, decl: &'b GlobalFnDeclResolver<'a, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) @@ -2367,52 +2611,43 @@ pub enum StateSpace { ParamReg, } -enum ExpandedArgParams {} -type ExpandedStatement = Statement, ExpandedArgParams>; - struct Function<'input> { - pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>, + pub func_directive: ast::MethodDecl<'input, spirv::Word>, pub globals: Vec, pub body: Option>, } -impl ast::ArgParams for ExpandedArgParams { - type ID = spirv::Word; - type Operand = spirv::Word; - type CallOperand = spirv::Word; - type VecOperand = spirv::Word; -} - -impl ArgParamsEx for ExpandedArgParams { - fn get_fn_decl<'a, 'b>( - id: &Self::ID, - decl: &'b GlobalFnDeclResolver<'a, 'b>, - ) -> Result<&'b FnDecl, TranslateError> { - decl.get_fn_decl(*id) - } -} - -trait ArgumentMapVisitor { - fn variable( +pub trait ArgumentMapVisitor { + fn id( &mut self, - desc: ArgumentDescriptor, + desc: ArgumentDescriptor, typ: Option, - ) -> Result; + ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, typ: ast::Type, ) -> Result; + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result; + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result; fn src_call_operand( &mut self, desc: ArgumentDescriptor, typ: ast::Type, ) -> Result; - fn src_vec_operand( + fn src_member_operand( &mut self, - desc: ArgumentDescriptor, - typ: (ast::MovVectorType, u8), - ) -> Result; + desc: ArgumentDescriptor, + typ: (ast::ScalarType, u8), + ) -> Result; } impl ArgumentMapVisitor for T @@ -2422,7 +2657,7 @@ where Option, ) -> Result, { - fn variable( + fn id( &mut self, desc: ArgumentDescriptor, t: Option, @@ -2438,6 +2673,22 @@ where self(desc, Some(t)) } + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result { + self(desc, Some(typ)) + } + + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor, + typ: ast::Type, + ) -> Result { + self(desc, Some(typ)) + } + fn src_call_operand( &mut self, desc: ArgumentDescriptor, @@ -2446,10 +2697,10 @@ where self(desc, Some(t)) } - fn src_vec_operand( + fn src_member_operand( &mut self, desc: ArgumentDescriptor, - (scalar_type, vec_len): (ast::MovVectorType, u8), + (scalar_type, vec_len): (ast::ScalarType, u8), ) -> Result { self( desc.new_op(desc.op), @@ -2462,7 +2713,7 @@ impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> fo where T: FnMut(&str) -> Result, { - fn variable( + fn id( &mut self, desc: ArgumentDescriptor<&str>, _: Option, @@ -2477,8 +2728,38 @@ where ) -> Result, TranslateError> { match desc.op { ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)), + ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), + } + } + + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor>, + _: ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)), + ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec( + ids.into_iter().map(self).collect::>()?, + )), + } + } + + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor>, + _: ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)), + ast::OperandOrVector::RegOffset(id, imm) => { + Ok(ast::OperandOrVector::RegOffset(self(id)?, imm)) + } + ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), + ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec( + ids.into_iter().map(self).collect::>()?, + )), } } @@ -2493,16 +2774,16 @@ where } } - fn src_vec_operand( + fn src_member_operand( &mut self, desc: ArgumentDescriptor<(&str, u8)>, - _: (ast::MovVectorType, u8), + _: (ast::ScalarType, u8), ) -> Result<(spirv::Word, u8), TranslateError> { Ok((self(desc.op.0)?, desc.op.1)) } } -struct ArgumentDescriptor { +pub struct ArgumentDescriptor { op: Op, is_dst: bool, sema: ArgumentSemantics, @@ -2536,22 +2817,19 @@ impl ast::Instruction { visitor: &mut V, ) -> Result, TranslateError> { Ok(match self { - ast::Instruction::MovVector(t, a) => { - ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))?) - } ast::Instruction::Abs(d, arg) => { ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => unreachable!(), + ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; let is_param = d.state_space == ast::LdStateSpace::Param || d.state_space == ast::LdStateSpace::Local; - ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?) + ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?) } ast::Instruction::Mov(d, a) => { - let mapped = a.map(visitor, d.src_is_address, d.typ.into())?; + let mapped = a.map(visitor, d)?; ast::Instruction::Mov(d, mapped) } ast::Instruction::Mul(d, a) => { @@ -2617,7 +2895,7 @@ impl ast::Instruction { } } -impl VisitVariable for ast::Instruction { +impl VisitVariable for ast::Instruction { fn visit_variable< 'a, F: FnMut( @@ -2627,19 +2905,19 @@ impl VisitVariable for ast::Instruction { >( self, f: &mut F, - ) -> Result { + ) -> Result { Ok(Statement::Instruction(self.map(f)?)) } } -impl ArgumentMapVisitor for T +impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, Option, ) -> Result, { - fn variable( + fn id( &mut self, desc: ArgumentDescriptor, t: Option, @@ -2673,10 +2951,47 @@ where } } - fn src_vec_operand( + fn id_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)), + ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec( + ids.iter() + .map(|id| self(desc.new_op(*id), Some(typ))) + .collect::>()?, + )), + } + } + + fn operand_or_vector( + &mut self, + desc: ArgumentDescriptor>, + typ: ast::Type, + ) -> Result, TranslateError> { + match desc.op { + ast::OperandOrVector::Reg(id) => { + Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?)) + } + ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset( + self(desc.new_op(id), Some(typ))?, + imm, + )), + ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), + ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec( + ids.iter() + .map(|id| self(desc.new_op(*id), Some(typ))) + .collect::>()?, + )), + } + } + + fn src_member_operand( &mut self, desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vector_len): (ast::MovVectorType, u8), + (scalar_type, vector_len): (ast::ScalarType, u8), ) -> Result<(spirv::Word, u8), TranslateError> { Ok(( self( @@ -2750,7 +3065,6 @@ impl ast::Instruction { ast::Instruction::Bra(_, a) => Some(a.src), ast::Instruction::Ld(_, _) | ast::Instruction::Mov(_, _) - | ast::Instruction::MovVector(_, _) | ast::Instruction::Mul(_, _) | ast::Instruction::Add(_, _) | ast::Instruction::Setp(_, _) @@ -2786,12 +3100,44 @@ type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; struct CompositeRead { - pub typ: ast::MovVectorType, + pub typ: ast::ScalarType, pub dst: spirv::Word, pub src_composite: spirv::Word, pub src_index: u32, } +impl VisitVariableExpanded for CompositeRead { + fn visit_variable_extended< + F: FnMut( + ArgumentDescriptor, + Option, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + Ok(Statement::Composite(CompositeRead { + dst: f( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(self.typ)), + )?, + src_composite: f( + ArgumentDescriptor { + op: self.src_composite, + is_dst: false, + sema: ArgumentSemantics::Default, + }, + Some(ast::Type::Scalar(self.typ)), + )?, + ..self + })) + } +} + struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -2875,12 +3221,16 @@ impl ast::VariableParamType { } impl ast::Arg1 { + fn cast>(self) -> ast::Arg1 { + ast::Arg1 { src: self.src } + } + fn map>( self, visitor: &mut V, t: Option, ) -> Result, TranslateError> { - let new_src = visitor.variable( + let new_src = visitor.id( ArgumentDescriptor { op: self.src, is_dst: false, @@ -2893,13 +3243,20 @@ impl ast::Arg1 { } impl ast::Arg2 { + fn cast>(self) -> ast::Arg2 { + ast::Arg2 { + src: self.src, + dst: self.dst, + } + } + fn map>( self, visitor: &mut V, src_is_addr: bool, t: ast::Type, ) -> Result, TranslateError> { - let new_dst = visitor.variable( + let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -2925,42 +3282,13 @@ impl ast::Arg2 { }) } - fn map_ld>( - self, - visitor: &mut V, - t: ast::Type, - is_param: bool, - ) -> Result, TranslateError> { - let dst = visitor.variable( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(t), - )?; - let src = visitor.operand( - ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: if is_param { - ArgumentSemantics::RegisterPointer - } else { - ArgumentSemantics::PhysicalPointer - }, - }, - t, - )?; - Ok(ast::Arg2 { dst, src }) - } - fn map_cvt>( self, visitor: &mut V, dst_t: ast::Type, src_t: ast::Type, ) -> Result, TranslateError> { - let dst = visitor.variable( + let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -2980,7 +3308,56 @@ impl ast::Arg2 { } } +impl ast::Arg2Ld { + fn cast>( + self, + ) -> ast::Arg2Ld { + ast::Arg2Ld { + dst: self.dst, + src: self.src, + } + } + + fn map>( + self, + visitor: &mut V, + t: ast::Type, + is_param: bool, + ) -> Result, TranslateError> { + let dst = visitor.id_or_vector( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + t.into(), + )?; + let src = visitor.operand( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: if is_param { + ArgumentSemantics::RegisterPointer + } else { + ArgumentSemantics::PhysicalPointer + }, + }, + t, + )?; + Ok(ast::Arg2Ld { dst, src }) + } +} + impl ast::Arg2St { + fn cast>( + self, + ) -> ast::Arg2St { + ast::Arg2St { + src1: self.src1, + src2: self.src2, + } + } + fn map>( self, visitor: &mut V, @@ -2999,7 +3376,7 @@ impl ast::Arg2St { }, t, )?; - let src2 = visitor.operand( + let src2 = visitor.operand_or_vector( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -3011,105 +3388,191 @@ impl ast::Arg2St { } } -impl ast::Arg2Vec { - fn dst(&self) -> &T::ID { - match self { - ast::Arg2Vec::Dst((d, _), _, _) - | ast::Arg2Vec::Src(d, _) - | ast::Arg2Vec::Both((d, _), _, _) => d, - } - } - +impl ast::Arg2Mov { fn map>( self, visitor: &mut V, - (scalar_type, vec_len): (ast::MovVectorType, u8), - ) -> Result, TranslateError> { + details: ast::MovDetails, + ) -> Result, TranslateError> { + Ok(match self { + ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?), + ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?), + }) + } +} + +impl ast::Arg2MovNormal

{ + fn cast>( + self, + ) -> ast::Arg2MovNormal { + ast::Arg2MovNormal { + dst: self.dst, + src: self.src, + } + } + + fn map>( + self, + visitor: &mut V, + details: ast::MovDetails, + ) -> Result, TranslateError> { + let dst = visitor.id_or_vector( + ArgumentDescriptor { + op: self.dst, + is_dst: true, + sema: ArgumentSemantics::Default, + }, + details.typ.into(), + )?; + let src = visitor.operand_or_vector( + ArgumentDescriptor { + op: self.src, + is_dst: false, + sema: if details.src_is_address { + ArgumentSemantics::RegisterPointer + } else { + ArgumentSemantics::PhysicalPointer + }, + }, + details.typ.into(), + )?; + Ok(ast::Arg2MovNormal { dst, src }) + } +} + +impl ast::Arg2MovMember { + fn cast>( + self, + ) -> ast::Arg2MovMember { match self { - ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => { - let dst = visitor.variable( + ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2), + ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src), + ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2), + } + } + + fn vector_dst(&self) -> Option<&T::Id> { + match self { + ast::Arg2MovMember::Src(_, _) => None, + ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => { + Some(d) + } + } + } + + fn vector_src(&self) -> Option<&T::SrcMemberOperand> { + match self { + ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d), + ast::Arg2MovMember::Dst(_, _, _) => None, + } + } +} + +impl ast::Arg2MovMember { + fn map>( + self, + visitor: &mut V, + details: ast::MovDetails, + ) -> Result, TranslateError> { + match self { + ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { + let dst = visitor.id( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - let src1 = visitor.variable( + let src1 = visitor.id( ArgumentDescriptor { op: composite_src, is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - let src2 = visitor.variable( + let src2 = visitor.id( ArgumentDescriptor { op: scalar_src, is_dst: false, - sema: ArgumentSemantics::Default, + sema: if details.src_is_address { + ArgumentSemantics::Address + } else { + ArgumentSemantics::Default + }, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - Ok(ast::Arg2Vec::Dst((dst, len), src1, src2)) + Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2)) } - ast::Arg2Vec::Src(dst, src) => { - let dst = visitor.variable( + ast::Arg2MovMember::Src(dst, src) => { + let dst = visitor.id( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - let src = visitor.src_vec_operand( + let scalar_typ = details.typ.get_scalar()?; + let src = visitor.src_member_operand( ArgumentDescriptor { op: src, is_dst: false, sema: ArgumentSemantics::Default, }, - (scalar_type, vec_len), + (scalar_typ.into(), details.src_width), )?; - Ok(ast::Arg2Vec::Src(dst, src)) + Ok(ast::Arg2MovMember::Src(dst, src)) } - ast::Arg2Vec::Both((dst, len), composite_src, src) => { - let dst = visitor.variable( + ast::Arg2MovMember::Both((dst, len), composite_src, src) => { + let dst = visitor.id( ArgumentDescriptor { op: dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - let composite_src = visitor.variable( + let composite_src = visitor.id( ArgumentDescriptor { op: composite_src, is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(scalar_type.into())), + Some(details.typ.into()), )?; - let src = visitor.src_vec_operand( + let scalar_typ = details.typ.get_scalar()?; + let src = visitor.src_member_operand( ArgumentDescriptor { op: src, is_dst: false, sema: ArgumentSemantics::Default, }, - (scalar_type, vec_len), + (scalar_typ.into(), details.src_width), )?; - Ok(ast::Arg2Vec::Both((dst, len), composite_src, src)) + Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src)) } } } } impl ast::Arg3 { + fn cast>(self) -> ast::Arg3 { + ast::Arg3 { + dst: self.dst, + src1: self.src1, + src2: self.src2, + } + } + fn map_non_shift>( self, visitor: &mut V, t: ast::Type, ) -> Result, TranslateError> { - let dst = visitor.variable( + let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -3141,7 +3604,7 @@ impl ast::Arg3 { visitor: &mut V, t: ast::Type, ) -> Result, TranslateError> { - let dst = visitor.variable( + let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -3170,12 +3633,21 @@ impl ast::Arg3 { } impl ast::Arg4 { + fn cast>(self) -> ast::Arg4 { + ast::Arg4 { + dst: self.dst, + src1: self.src1, + src2: self.src2, + src3: self.src3, + } + } + fn map>( self, visitor: &mut V, t: ast::Type, ) -> Result, TranslateError> { - let dst = visitor.variable( + let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -3217,12 +3689,21 @@ impl ast::Arg4 { } impl ast::Arg4Setp { + fn cast>(self) -> ast::Arg4Setp { + ast::Arg4Setp { + dst1: self.dst1, + dst2: self.dst2, + src1: self.src1, + src2: self.src2, + } + } + fn map>( self, visitor: &mut V, t: ast::Type, ) -> Result, TranslateError> { - let dst1 = visitor.variable( + let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, is_dst: true, @@ -3233,7 +3714,7 @@ impl ast::Arg4Setp { let dst2 = self .dst2 .map(|dst2| { - visitor.variable( + visitor.id( ArgumentDescriptor { op: dst2, is_dst: true, @@ -3269,12 +3750,22 @@ impl ast::Arg4Setp { } impl ast::Arg5 { + fn cast>(self) -> ast::Arg5 { + ast::Arg5 { + dst1: self.dst1, + dst2: self.dst2, + src1: self.src1, + src2: self.src2, + src3: self.src3, + } + } + fn map>( self, visitor: &mut V, t: ast::Type, ) -> Result, TranslateError> { - let dst1 = visitor.variable( + let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, is_dst: true, @@ -3285,7 +3776,7 @@ impl ast::Arg5 { let dst2 = self .dst2 .map(|dst2| { - visitor.variable( + visitor.id( ArgumentDescriptor { op: dst2, is_dst: true, @@ -3329,6 +3820,22 @@ impl ast::Arg5 { } } +impl ast::Type { + fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> { + match self { + ast::Type::Vector(t, len) => Ok((t, len)), + _ => Err(TranslateError::MismatchedType), + } + } + + fn get_scalar(self) -> Result { + match self { + ast::Type::Scalar(t) => Ok(t), + _ => Err(TranslateError::MismatchedType), + } + } +} + impl ast::CallOperand { fn map_variable Result>( self, @@ -3528,13 +4035,21 @@ impl From for ast::VariableType { impl ast::Operand { fn underlying(&self) -> Option<&T> { match self { - ast::Operand::Reg(r) => Some(r), - ast::Operand::RegOffset(r, _) => Some(r), + ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), ast::Operand::Imm(_) => None, } } } +impl ast::OperandOrVector { + fn single_underlying(&self) -> Option<&T> { + match self { + ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r), + ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None, + } + } +} + fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { @@ -3891,7 +4406,7 @@ fn insert_implicit_bitcasts( } Ok(()) } -impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { +impl<'a> ast::MethodDecl<'a, &'a str> { fn name(&self) -> &'a str { match self { ast::MethodDecl::Kernel(name, _) => name, @@ -3900,8 +4415,8 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { } } -impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> { - fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument

)) { +impl<'a> ast::MethodDecl<'a, spirv::Word> { + fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument)) { match self { ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f), ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {