From 9d92a6e284dce00b0b785a50f623d3715f8aeac4 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 5 May 2021 22:56:58 +0200 Subject: [PATCH] Start converting the translation to one type type --- ptx/src/ast.rs | 85 +-- ptx/src/ptx.lalrpop | 74 +-- ptx/src/translate.rs | 1189 ++++++++++++++++++++++-------------------- 3 files changed, 666 insertions(+), 682 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index c7b9563..364ec01 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -110,35 +110,7 @@ pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec), - Pointer(PointerType, LdStateSpace), -} - -#[derive(PartialEq, Eq, Clone)] -pub enum PointerType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, Vec), - // Instances of this variant are generated during stateful conversion - Pointer(ScalarType, LdStateSpace), -} - -impl From for PointerType { - fn from(t: ScalarType) -> Self { - PointerType::Scalar(t.into()) - } -} - -impl TryFrom for ScalarType { - type Error = (); - - fn try_from(value: PointerType) -> Result { - match value { - PointerType::Scalar(t) => Ok(t), - PointerType::Vector(_, _) => Err(()), - PointerType::Array(_, _) => Err(()), - PointerType::Pointer(_, _) => Err(()), - } - } + Pointer(ScalarType), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -222,6 +194,7 @@ pub enum StateSpace { Shared, Param, Generic, + Sreg, } pub struct PredAt { @@ -397,9 +370,9 @@ pub enum VectorPrefix { pub struct LdDetails { pub qualifier: LdStQualifier, - pub state_space: LdStateSpace, + pub state_space: StateSpace, pub caching: LdCacheOperator, - pub typ: PointerType, + pub typ: Type, pub non_coherent: bool, } @@ -418,17 +391,6 @@ pub enum MemScope { Sys, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -#[repr(u8)] -pub enum LdStateSpace { - Generic, - Const, - Global, - Local, - Param, - Shared, -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, @@ -612,20 +574,11 @@ impl CvtDetails { } pub struct CvtaDetails { - pub to: CvtaStateSpace, - pub from: CvtaStateSpace, + pub to: StateSpace, + pub from: StateSpace, pub size: CvtaSize, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum CvtaStateSpace { - Generic, - Const, - Global, - Local, - Shared, -} - pub enum CvtaSize { U32, U64, @@ -633,18 +586,9 @@ pub enum CvtaSize { pub struct StData { pub qualifier: LdStQualifier, - pub state_space: StStateSpace, + pub state_space: StateSpace, pub caching: StCacheOperator, - pub typ: PointerType, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum StStateSpace { - Generic, - Global, - Local, - Param, - Shared, + pub typ: Type, } #[derive(PartialEq, Eq)] @@ -717,7 +661,7 @@ pub struct MinMaxFloat { pub struct AtomDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub inner: AtomInnerDetails, } @@ -729,13 +673,6 @@ pub enum AtomSemantics { AcquireRelease, } -#[derive(Copy, Clone)] -pub enum AtomSpace { - Generic, - Global, - Shared, -} - #[derive(Copy, Clone)] pub enum AtomInnerDetails { Bit { op: AtomBitOp, typ: ScalarType }, @@ -777,7 +714,7 @@ pub enum AtomFloatOp { pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub typ: ScalarType, } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index dc439b7..8fee7c2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -619,9 +619,9 @@ ModuleVariable: ast::Variable<&'input str> = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new()) + (ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new()) } } }; @@ -643,7 +643,7 @@ ParamVariable: (Option, Vec, ast::Type, &'input str) = { (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new()) + (ast::Type::Pointer(t), Vec::new()) } }; (align, array_init, v_type, name) @@ -763,7 +763,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -775,7 +775,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -787,7 +787,7 @@ InstLd: ast::Instruction> = { ast::Instruction::Ld( ast::LdDetails { qualifier: ast::LdStQualifier::Weak, - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: true @@ -797,9 +797,9 @@ InstLd: ast::Instruction> = { } }; -LdStType: ast::PointerType = { - => ast::PointerType::Vector(t, v), - => ast::PointerType::Scalar(t), +LdStType: ast::Type = { + => ast::Type::Vector(t, v), + => ast::Type::Scalar(t), } LdStQualifier: ast::LdStQualifier = { @@ -815,11 +815,11 @@ MemScope: ast::MemScope = { ".sys" => ast::MemScope::Sys }; -LdNonGlobalStateSpace: ast::LdStateSpace = { - ".const" => ast::LdStateSpace::Const, - ".local" => ast::LdStateSpace::Local, - ".param" => ast::LdStateSpace::Param, - ".shared" => ast::LdStateSpace::Shared, +LdNonGlobalStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; LdCacheOperator: ast::LdCacheOperator = { @@ -1235,7 +1235,7 @@ InstSt: ast::Instruction> = { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), typ: t }, @@ -1249,11 +1249,11 @@ MemoryOperand: ast::Operand<&'input str> = { "[" "]" => o } -StStateSpace: ast::StStateSpace = { - ".global" => ast::StStateSpace::Global, - ".local" => ast::StStateSpace::Local, - ".param" => ast::StStateSpace::Param, - ".shared" => ast::StStateSpace::Shared, +StStateSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; StCacheOperator: ast::StCacheOperator = { @@ -1272,7 +1272,7 @@ InstRet: ast::Instruction> = { InstCvta: ast::Instruction> = { "cvta" => { ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, + to: ast::StateSpace::Generic, from, size: s }, @@ -1281,18 +1281,18 @@ InstCvta: ast::Instruction> = { "cvta" ".to" => { ast::Instruction::Cvta(ast::CvtaDetails { to, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, size: s }, a) } } -CvtaStateSpace: ast::CvtaStateSpace = { - ".const" => ast::CvtaStateSpace::Const, - ".global" => ast::CvtaStateSpace::Global, - ".local" => ast::CvtaStateSpace::Local, - ".shared" => ast::CvtaStateSpace::Shared, +CvtaStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".shared" => ast::StateSpace::Shared, } CvtaSize: ast::CvtaSize = { @@ -1450,7 +1450,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Bit { op, typ } }; ast::Instruction::Atom(details,a) @@ -1459,7 +1459,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, typ: ast::ScalarType::U32 @@ -1471,7 +1471,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, typ: ast::ScalarType::U32 @@ -1484,7 +1484,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Float { op, typ } }; ast::Instruction::Atom(details,a) @@ -1493,7 +1493,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op, typ } }; ast::Instruction::Atom(details,a) @@ -1502,7 +1502,7 @@ InstAtom: ast::Instruction> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Signed { op, typ } }; ast::Instruction::Atom(details,a) @@ -1514,7 +1514,7 @@ InstAtomCas: ast::Instruction> = { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), typ, }; ast::Instruction::AtomCas(details,a) @@ -1528,9 +1528,9 @@ AtomSemantics: ast::AtomSemantics = { ".acq_rel" => ast::AtomSemantics::AcquireRelease } -AtomSpace: ast::AtomSpace = { - ".global" => ast::AtomSpace::Global, - ".shared" => ast::AtomSpace::Shared +AtomSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".shared" => ast::StateSpace::Shared } AtomBitOp: ast::AtomBitOp = { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 4ba5729..a743496 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -37,6 +37,12 @@ fn error_unreachable() -> TranslateError { TranslateError::Unreachable } +macro_rules! new_todo { + () => { + todo!() + }; +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -48,52 +54,40 @@ enum SpirvType { } impl SpirvType { - fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self { - let key = t.into(); - SpirvType::Pointer(Box::new(key), sc) - } -} - -impl From for SpirvType { - fn from(t: ast::Type) -> Self { + fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer( - Box::new(SpirvType::from(ast::Type::from(pointer_t))), - state_space.to_spirv(), - ), - } - } -} - -impl From for ast::Type { - fn from(t: ast::PointerType) -> Self { - match t { - ast::PointerType::Scalar(t) => ast::Type::Scalar(t), - ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len), - ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims), - ast::PointerType::Pointer(t, space) => { - ast::Type::Pointer(ast::PointerType::Scalar(t), space) + ast::Type::Pointer(pointer_t) => { + let spirv_space = match decl_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { + spirv::StorageClass::Private + } + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + }; + SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space) } } } + + fn pointer_to( + t: ast::Type, + inner_space: ast::StateSpace, + outer_space: spirv::StorageClass, + ) -> Self { + let key = Self::new(t, inner_space); + SpirvType::Pointer(Box::new(key), outer_space) + } } impl ast::Type { - fn param_pointer_to(self, space: ast::LdStateSpace) -> Result { - Ok(match self { - ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Vector(t, len) => { - ast::Type::Pointer(ast::PointerType::Vector(t, len), space) - } - ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space), - ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { - ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) - } - ast::Type::Pointer(_, _) => return Err(error_unreachable()), - }) + fn param_pointer_to(self, space: ast::StateSpace) -> Result { + Ok(self) } } @@ -398,18 +392,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter()) } }, - ast::Type::Pointer(typ, state_space) => { - let base_t = typ.clone().into(); - let base = self.get_or_add_constant(b, &base_t, &[])?; - let result_type = self.get_or_add( - b, - SpirvType::Pointer( - Box::new(SpirvType::from(base_t)), - (*state_space).to_spirv(), - ), - ); - b.variable(result_type, None, (*state_space).to_spirv(), Some(base)) - } + ast::Type::Pointer(typ) => return Err(error_unreachable()), }) } @@ -702,11 +685,29 @@ fn multi_hash_map_append(m: &mut MultiHashMap, } } -// PTX represents dynamically allocated shared local memory as -// .extern .shared .align 4 .b8 shared_mem[]; -// In SPIRV/OpenCL world this is expressed as an additional argument -// This pass looks for all uses of .extern .shared and converts them to -// an additional method argument +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") +*/ fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, new_id: &mut impl FnMut() -> spirv::Word, @@ -715,7 +716,7 @@ fn convert_dynamic_shared_memory_usage<'input>( for dir in module.iter() { match dir { Directive::Variable(ast::Variable { - v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared), + v_type: ast::Type::Pointer(p_type), state_space: ast::StateSpace::Shared, name, .. @@ -799,48 +800,23 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::Variable { name: shared_id_param, align: None, - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - state_space: ast::StateSpace::Param, + v_type: ast::Type::Pointer(ast::ScalarType::B8), + state_space: ast::StateSpace::Shared, array_init: Vec::new(), } }); spirv_decl.uses_shared_mem = true; - let shared_var_id = new_id(); - let shared_var = ExpandedStatement::Variable(ast::Variable { - name: shared_var_id, - align: None, - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - state_space: ast::StateSpace::Reg, - array_init: Vec::new(), - }); - let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { - arg: ast::Arg2St { - src1: shared_var_id, - src2: shared_id_param, - }, - typ: ast::Type::Scalar(ast::ScalarType::B8), - member_index: None, - }); - let mut new_statements = vec![shared_var, shared_var_st]; - replace_uses_of_shared_memory( - &mut new_statements, + let statements = replace_uses_of_shared_memory( new_id, &extern_shared_decls, &mut methods_using_extern_shared, shared_id_param, - shared_var_id, statements, ); Directive::Method(Function { func_decl, globals, - body: Some(new_statements), + body: Some(statements), import_as, spirv_decl, tuning, @@ -852,14 +828,13 @@ fn convert_dynamic_shared_memory_usage<'input>( } fn replace_uses_of_shared_memory<'a>( - result: &mut Vec, new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, + extern_shared_decls: &HashMap, methods_using_extern_shared: &mut HashSet>, shared_id_param: spirv::Word, - shared_var_id: spirv::Word, statements: Vec, -) { +) -> Vec { + let mut result = Vec::with_capacity(statements.len()); for statement in statements { match statement { Statement::Call(mut call) => { @@ -877,22 +852,18 @@ fn replace_uses_of_shared_memory<'a>( } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) { - if *typ == ast::ScalarType::B8 { - return shared_var_id; + if let Some(scalar_type) = extern_shared_decls.get(&id) { + if *scalar_type == ast::ScalarType::B8 { + return shared_id_param; } let replacement_id = new_id(); result.push(Statement::Conversion(ImplicitConversion { - src: shared_var_id, + src: shared_id_param, dst: replacement_id, - from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - to: ast::Type::Pointer( - ast::PointerType::Scalar((*typ).into()), - ast::LdStateSpace::Shared, - ), + from_type: ast::Type::Pointer(ast::ScalarType::B8), + from_space: ast::StateSpace::Shared, + to_type: ast::Type::Pointer((*scalar_type).into()), + to_space: ast::StateSpace::Shared, kind: ConversionKind::PtrToPtr { spirv_ptr: true }, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -906,6 +877,7 @@ fn replace_uses_of_shared_memory<'a>( } } } + result } fn get_callers_of_extern_shared<'a>( @@ -1055,8 +1027,9 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, - SpirvType::Pointer( - Box::new(SpirvType::from(reg.get_type())), + SpirvType::pointer_to( + reg.get_type(), + ast::StateSpace::Reg, spirv::StorageClass::Input, ), ); @@ -1158,7 +1131,10 @@ fn emit_function_header<'a>( } */ for input in &func_decl.input { - let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::new(input.v_type.clone(), input.state_space), + ); builder.function_parameter(Some(input.name), result_type)?; } Ok(fn_id) @@ -1219,26 +1195,26 @@ fn translate_variable<'a>( is_variable = true; var_type } - ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?, - ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, + ast::StateSpace::Const => var_type.param_pointer_to(ast::StateSpace::Const)?, + ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, ast::StateSpace::Shared => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; var_type } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? + var_type.param_pointer_to(ast::StateSpace::Shared)? } } - ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, - ast::StateSpace::Generic => todo!(), + ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, + ast::StateSpace::Generic | ast::StateSpace::Sreg => return Err(error_unreachable()), }; Ok(ast::Variable { align: var.align, v_type: var.v_type, state_space: var.state_space, - name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), + name: id_defs.get_or_add_def_typed(var.name, var_type, var.state_space, is_variable), array_init: var.array_init, }) } @@ -1283,7 +1259,10 @@ fn expand_kernel_params<'a, 'b>( Ok(ast::KernelArgument { name: fn_resolver.add_def( a.name, - Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?), + Some(( + ast::Type::from(a.v_type.clone()).param_pointer_to(ast::StateSpace::Param)?, + a.state_space, + )), false, ), v_type: a.v_type.clone(), @@ -1302,7 +1281,7 @@ fn expand_fn_params<'a, 'b>( args.map(|a| { let is_variable = a.state_space == ast::StateSpace::Reg; Ok(ast::FnArgument { - name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable), + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable), v_type: a.v_type.clone(), state_space: a.state_space, align: a.align, @@ -1339,15 +1318,15 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + //let typed_statements = + // convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, &f_args, &mut spirv_decl, )?; - let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?; + let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1366,7 +1345,7 @@ fn to_ssa<'input, 'b>( }) } -fn fix_builtins( +fn fix_special_registers( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { @@ -1402,7 +1381,8 @@ fn fix_builtins( continue; } }; - let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone())); + let temp_id = numeric_id_defs + .register_intermediate(Some((details.typ.clone(), details.state_space))); let real_dst = details.arg.dst; details.arg.dst = temp_id; result.push(Statement::LoadVar(LoadVarDetails { @@ -1410,14 +1390,17 @@ fn fix_builtins( src: sreg_src, dst: temp_id, }, + state_space: ast::StateSpace::Sreg, typ: ast::Type::Scalar(scalar_typ), member_index: Some((index, Some(vector_width))), })); result.push(Statement::Conversion(ImplicitConversion { src: temp_id, dst: real_dst, - from: ast::Type::Scalar(scalar_typ), - to: ast::Type::Scalar(ast::ScalarType::U32), + from_type: ast::Type::Scalar(scalar_typ), + from_space: ast::StateSpace::Sreg, + to_type: ast::Type::Scalar(ast::ScalarType::U32), + to_space: ast::StateSpace::Sreg, kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -1614,12 +1597,12 @@ fn convert_to_typed_statements( } ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { if let Some(src_id) = src.underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; + let (typ, _, _) = id_defs.get_typed(*src_id)?; let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, + ast::Type::Scalar(..) => false, + ast::Type::Vector(..) => false, + ast::Type::Array(..) => true, + ast::Type::Pointer(..) => true, }; d.src_is_address = take_address; } @@ -1666,6 +1649,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { is_dst: bool, vector_sema: ArgumentSemantics, typ: &ast::Type, + state_space: ast::StateSpace, idx: Vec, ) -> Result { // mov.u32 foobar, {a,b}; @@ -1673,7 +1657,9 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ast::Type::Vector(scalar_t, _) => *scalar_t, _ => return Err(TranslateError::MismatchedType), }; - let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, @@ -1696,7 +1682,7 @@ impl<'a, 'b> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -1705,15 +1691,20 @@ impl<'a, 'b> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { ast::Operand::Reg(reg) => TypedOperand::Reg(reg), ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), - ast::Operand::VecPack(vec) => { - TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) - } + ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector( + desc.is_dst, + desc.sema, + typ, + state_space, + vec, + )?), }) } } @@ -1735,37 +1726,33 @@ fn to_ptx_impl_atomic_call( semantics, scope, space, op ); // TODO: extract to a function - let ptr_space = match details.space { - ast::AtomSpace::Generic => ast::LdStateSpace::Generic, - ast::AtomSpace::Global => ast::LdStateSpace::Global, - ast::AtomSpace::Shared => ast::LdStateSpace::Shared, - }; + let ptr_space = details.space; let scalar_typ = ast::ScalarType::from(typ); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, vec![ ast::FnArgument { align: None, - v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), - state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + v_type: ast::Type::Pointer(typ), + state_space: ptr_space, + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(scalar_typ), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -1795,11 +1782,7 @@ fn to_ptx_impl_atomic_call( func: fn_id, ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)], param_list: vec![ - ( - arg.src1, - ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space), - ast::StateSpace::Reg, - ), + (arg.src1, ast::Type::Pointer(typ), ptr_space), ( arg.src2, ast::Type::Scalar(scalar_typ), @@ -1826,13 +1809,13 @@ fn to_ptx_impl_bfe_call( let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, @@ -1841,21 +1824,21 @@ fn to_ptx_impl_bfe_call( align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -1919,13 +1902,13 @@ fn to_ptx_impl_bfi_call( let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { - let fn_id = id_defs.new_non_variable(None); + let fn_id = id_defs.register_intermediate(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }], fn_id, @@ -1934,28 +1917,28 @@ fn to_ptx_impl_bfi_call( align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(typ.into()), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::Type::Scalar(ast::ScalarType::U32), state_space: ast::StateSpace::Reg, - name: id_defs.new_non_variable(None), + name: id_defs.register_intermediate(None), array_init: Vec::new(), }, ], @@ -2048,7 +2031,7 @@ fn normalize_labels( | Statement::RepackVector(..) => {} } } - iter::once(Statement::Label(id_def.new_non_variable(None))) + iter::once(Statement::Label(id_def.register_intermediate(None))) .chain(func.into_iter().filter(|s| match s { Statement::Label(i) => labels_in_use.contains(i), _ => true, @@ -2066,8 +2049,8 @@ fn normalize_predicates( Statement::Label(id) => result.push(Statement::Label(id)), Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { - let if_true = id_def.new_non_variable(None); - let if_false = id_def.new_non_variable(None); + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, @@ -2116,7 +2099,8 @@ fn insert_mem_ssa_statements<'a, 'b>( } for spirv_arg in fn_decl.input.iter_mut() { let typ = spirv_arg.v_type.clone(); - let new_id = id_def.new_non_variable(Some(typ.clone())); + let state_space = spirv_arg.state_space; + let new_id = id_def.register_intermediate(Some((typ.clone(), state_space))); result.push(Statement::Variable(ast::Variable { align: spirv_arg.align, v_type: spirv_arg.v_type.clone(), @@ -2129,6 +2113,7 @@ fn insert_mem_ssa_statements<'a, 'b>( src1: spirv_arg.name, src2: new_id, }, + state_space, typ, member_index: None, })); @@ -2143,13 +2128,15 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => { // TODO: handle multiple output args if let &[out_param] = &fn_decl.output.as_slice() { - let (typ, _) = id_def.get_typed(out_param.name)?; - let new_id = id_def.new_non_variable(Some(typ.clone())); + let (typ, space, _) = id_def.get_typed(out_param.name)?; + let new_id = id_def.register_intermediate(Some((typ.clone(), space))); result.push(Statement::LoadVar(LoadVarDetails { arg: ast::Arg2 { dst: new_id, src: out_param.name, }, + // TODO: ret with stateful conversion + state_space: new_todo!(), typ: typ.clone(), member_index: None, })); @@ -2161,13 +2148,16 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { - let generated_id = - id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); + let generated_id = id_def.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + ))); result.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: bra.predicate, }, + state_space: ast::StateSpace::Reg, typ: ast::Type::Scalar(ast::ScalarType::Pred), member_index: None, })); @@ -2204,6 +2194,7 @@ struct VisitArgumentDescriptor< > { desc: ArgumentDescriptor, typ: &'a ast::Type, + state_space: ast::StateSpace, stmt_ctor: Ctor, } @@ -2218,7 +2209,9 @@ impl< self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { - Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) + Ok((self.stmt_ctor)( + visitor.id(self.desc, Some((self.typ, self.state_space)))?, + )) } } @@ -2232,13 +2225,13 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, - expected_type: Option<&ast::Type>, + expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { let symbol = desc.op.0; if expected_type.is_none() { return Ok(symbol); }; - let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; + let (mut var_type, _, is_variable) = self.id_def.get_typed(symbol)?; if !is_variable { return Ok(symbol); }; @@ -2262,13 +2255,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { } None => None, }; - let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); + let generated_id = self + .id_def + .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); if !desc.is_dst { self.func.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: symbol, }, + state_space: ast::StateSpace::Reg, typ: var_type, member_index, })); @@ -2279,6 +2275,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { src1: symbol, src2: generated_id, }, + state_space: ast::StateSpace::Reg, typ: var_type, member_index: member_index.map(|(idx, _)| idx), })); @@ -2293,7 +2290,7 @@ impl<'a, 'input> ArgumentMapVisitor fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.symbol(desc.new_op((desc.op, None)), typ) } @@ -2302,18 +2299,20 @@ impl<'a, 'input> ArgumentMapVisitor &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { TypedOperand::Reg(reg) => { - TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) - } - TypedOperand::RegOffset(reg, offset) => { - TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?) } + TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset( + self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?, + offset, + ), op @ TypedOperand::Imm(..) => op, - TypedOperand::VecMember(symbol, index) => { - TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) - } + TypedOperand::VecMember(symbol, index) => TypedOperand::Reg( + self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?, + ), }) } } @@ -2411,7 +2410,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg( &mut self, desc: ArgumentDescriptor, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(desc.op) } @@ -2420,30 +2419,31 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let (reg, offset) = desc.op; let add_type; match typ { - ast::Type::Pointer(underlying_type, state_space) => { - let reg_typ = self.id_def.get_typed(reg)?; - if let ast::Type::Pointer(_, _) = reg_typ { - let id_constant_stmt = self.id_def.new_non_variable(typ.clone()); + ast::Type::Pointer(underlying_type) => { + let (reg_typ, space) = self.id_def.get_typed(reg)?; + if let ast::Type::Pointer(..) = reg_typ { + let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::S64, value: ast::ImmediateValue::S64(offset as i64), })); - let dst = self.id_def.new_non_variable(typ.clone()); + let dst = self.id_def.register_intermediate(typ.clone(), space); self.func.push(Statement::PtrAccess(PtrAccess { - underlying_type: underlying_type.clone(), - state_space: *state_space, + underlying_type: *underlying_type, + state_space: state_space, dst, ptr_src: reg, offset_src: id_constant_stmt, })); return Ok(dst); } else { - add_type = self.id_def.get_typed(reg)?; + add_type = self.id_def.get_typed(reg)?.0; } } _ => { @@ -2475,8 +2475,12 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ast::ScalarKind::Unsigned, )) }; - let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); - let result_id = self.id_def.new_non_variable(add_type); + let id_constant_stmt = self + .id_def + .register_intermediate(add_type.clone(), ast::StateSpace::Reg); + let result_id = self + .id_def + .register_intermediate(add_type, ast::StateSpace::Reg); // TODO: check for edge cases around min value/max value/wrapping if offset < 0 && kind != ast::ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { @@ -2518,13 +2522,16 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { *scalar } else { todo!() }; - let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t)); + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, @@ -2538,7 +2545,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self.reg(desc, t) } @@ -2547,12 +2554,13 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { match desc.op { - TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space), TypedOperand::RegOffset(reg, offset) => { - self.reg_offset(desc.new_op((reg, offset)), typ) + self.reg_offset(desc.new_op((reg, offset)), typ, state_space) } TypedOperand::VecMember(..) => Err(error_unreachable()), } @@ -2580,39 +2588,29 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { match s { - Statement::Call(call) => insert_implicit_conversions_impl( - &mut result, - id_def, - call, - should_bitcast_wrapper, - None, - )?, + Statement::Call(call) => { + insert_implicit_conversions_impl(&mut result, id_def, call, should_bitcast_wrapper)? + } Statement::Instruction(inst) => { let mut default_conversion_fn = - should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _; + should_bitcast_wrapper as for<'a> fn(&'a _, _, &'a _, _) -> _; let mut state_space = None; if let ast::Instruction::Ld(d, _) = &inst { state_space = Some(d.state_space); } if let ast::Instruction::St(d, _) = &inst { - state_space = Some(d.state_space.to_ld_ss()); + state_space = Some(d.state_space); } if let ast::Instruction::Atom(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); + state_space = Some(d.space); } if let ast::Instruction::AtomCas(d, _) = &inst { - state_space = Some(d.space.to_ld_ss()); + state_space = Some(d.space); } if let ast::Instruction::Mov(..) = &inst { default_conversion_fn = should_bitcast_packed; } - insert_implicit_conversions_impl( - &mut result, - id_def, - inst, - default_conversion_fn, - state_space, - )?; + insert_implicit_conversions_impl(&mut result, id_def, inst, default_conversion_fn)?; } Statement::PtrAccess(PtrAccess { underlying_type, @@ -2627,7 +2625,8 @@ fn insert_implicit_conversions( is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - typ: &ast::Type::Pointer(underlying_type.clone(), state_space), + typ: &ast::Type::Pointer(underlying_type), + state_space, stmt_ctor: |new_ptr_src| { Statement::PtrAccess(PtrAccess { underlying_type, @@ -2643,7 +2642,6 @@ fn insert_implicit_conversions( id_def, visit_desc, bitcast_physical_pointer, - Some(state_space), )?; } Statement::RepackVector(repack) => insert_implicit_conversions_impl( @@ -2651,7 +2649,6 @@ fn insert_implicit_conversions( id_def, repack, should_bitcast_wrapper, - None, )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) @@ -2672,19 +2669,20 @@ fn insert_implicit_conversions_impl( stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, + ast::StateSpace, &'a ast::Type, - Option, + ast::StateSpace, ) -> Result, TranslateError>, - state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit( - &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { - let instr_type = match typ { + let statement = + stmt.visit(&mut |desc: ArgumentDescriptor, + typ: Option<(&ast::Type, ast::StateSpace)>| { + let (instr_type, instruction_space) = match typ { None => return Ok(desc.op), Some(t) => t, }; - let operand_type = id_def.get_typed(desc.op)?; + let (operand_type, operand_space) = id_def.get_typed(desc.op)?; let mut conversion_fn = default_conversion_fn; match desc.sema { ArgumentSemantics::Default => {} @@ -2705,27 +2703,33 @@ fn insert_implicit_conversions_impl( conversion_fn = force_bitcast_ptr_to_bit; } }; - match conversion_fn(&operand_type, instr_type, state_space)? { + match conversion_fn(&operand_type, operand_space, instr_type, instruction_space)? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv } else { &mut *func }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); } conv_output.push(Statement::Conversion(ImplicitConversion { src, dst, - from, - to, + from_type, + from_space, + to_type, + to_space, kind: conv_kind, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -2734,8 +2738,7 @@ fn insert_implicit_conversions_impl( } None => Ok(desc.op), } - }, - )?; + })?; func.push(statement); func.append(&mut post_conv); Ok(()) @@ -2751,10 +2754,10 @@ fn get_function_type( builder, spirv_input .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), spirv_output .iter() - .map(|var| SpirvType::from(var.v_type.clone())), + .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)), ) } @@ -2782,8 +2785,8 @@ fn emit_function_body_ops( Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { - [(id, typ, _)] => ( - map.get_or_add(builder, SpirvType::from(typ.clone())), + [(id, typ, space)] => ( + map.get_or_add(builder, SpirvType::new(typ.clone(), *space)), Some(*id), ), [] => (map.void(), None), @@ -2915,8 +2918,10 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone()))); + let result_type = map.get_or_add( + builder, + SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space), + ); builder.load( result_type, Some(arg.dst), @@ -2947,8 +2952,10 @@ fn emit_function_body_ops( // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + let result_type = map.get_or_add( + builder, + SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg), + ); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { @@ -2989,7 +2996,8 @@ fn emit_function_body_ops( ast::Instruction::Shl(t, a) => { let full_type = ast::Type::Scalar(*t); let size_of = full_type.size_of(); - let result_type = map.get_or_add(builder, SpirvType::from(full_type)); + let result_type = + map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg)); let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?; builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?; } @@ -3251,8 +3259,9 @@ fn emit_function_body_ops( Some(index) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer( + SpirvType::pointer_to( details.typ.clone(), + details.state_space, spirv::StorageClass::Function, ), ); @@ -3284,14 +3293,11 @@ fn emit_function_body_ops( }) => { let u8_pointer = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - *state_space, - )), + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space), ); let result_type = map.get_or_add( builder, - SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)), + SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( @@ -3503,11 +3509,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str { } } -fn ptx_space_name(space: ast::AtomSpace) -> &'static str { +fn ptx_space_name(space: ast::StateSpace) -> &'static str { match space { - ast::AtomSpace::Generic => "generic", - ast::AtomSpace::Global => "global", - ast::AtomSpace::Shared => "shared", + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", } } @@ -3572,6 +3583,7 @@ fn emit_variable( ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Const => todo!(), ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( @@ -3580,17 +3592,14 @@ fn emit_variable( &*var.array_init, )?) } else if must_init { - let type_id = map.get_or_add( - builder, - SpirvType::from(ast::Type::from(var.v_type.clone())), - ); + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space)); Some(builder.constant_null(type_id, None)) } else { None }; let ptr_type_id = map.get_or_add( builder, - SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), + SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class), ); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { @@ -3729,7 +3738,10 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add( + builder, + SpirvType::new(desc.get_type(), ast::StateSpace::Reg), + ); builder.ext_inst( inst_type, Some(arg.dst), @@ -3754,7 +3766,10 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); + let inst_type = map.get_or_add( + builder, + SpirvType::new(desc.get_type(), ast::StateSpace::Reg), + ); builder.ext_inst( inst_type, Some(arg.dst), @@ -3865,11 +3880,13 @@ fn emit_cvt( let cv = ImplicitConversion { src: arg.src, dst: new_dst, - from: ast::Type::Scalar(src_t), - to: ast::Type::Scalar(ast::ScalarType::from_parts( + from_type: ast::Type::Scalar(src_t), + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Scalar(ast::ScalarType::from_parts( dest_t.size_of(), src_t.kind(), )), + to_space: ast::StateSpace::Reg, kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, @@ -4224,20 +4241,24 @@ fn emit_implicit_conversion( map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), TranslateError> { - let from_parts = cv.from.to_parts(); - let to_parts = cv.to.to_parts(); + let from_parts = cv.from_type.to_parts(); + let to_parts = cv.to_type.to_parts(); match (from_parts.kind, to_parts.kind, cv.kind) { (_, _, ConversionKind::PtrToBit(typ)) => { let dst_type = map.get_or_add_scalar(builder, typ.into()); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } - (_, _, ConversionKind::BitToPtr(_)) => { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + (_, _, ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()), + ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let dst_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); if from_parts.scalar_kind != ast::ScalarKind::Float && to_parts.scalar_kind != ast::ScalarKind::Float { @@ -4247,13 +4268,16 @@ fn emit_implicit_conversion( builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } } else { - // This block is safe because it's illegal to implictly convert between floating point instructions + // This block is safe because it's illegal to implictly convert between floating point values let same_width_bit_type = map.get_or_add( builder, - SpirvType::from(ast::Type::from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - })), + SpirvType::new( + ast::Type::from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + }), + cv.from_space, + ), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { @@ -4261,7 +4285,7 @@ fn emit_implicit_conversion( ..to_parts }); let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space)); if to_parts.scalar_kind == ast::ScalarKind::Unsigned || to_parts.scalar_kind == ast::ScalarKind::Bit { @@ -4282,8 +4306,10 @@ fn emit_implicit_conversion( &ImplicitConversion { src: wide_bit_value, dst: cv.dst, - from: wide_bit_type, - to: cv.to.clone(), + from_type: wide_bit_type, + from_space: new_todo!(), + to_type: cv.to_type.clone(), + to_space: new_todo!(), kind: ConversionKind::Default, src_sema: cv.src_sema, dst_sema: cv.dst_sema, @@ -4293,13 +4319,15 @@ fn emit_implicit_conversion( } } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let result_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); builder.s_convert(result_type, Some(cv.dst), cv.src)?; } (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); + let into_type = + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { @@ -4307,12 +4335,12 @@ fn emit_implicit_conversion( map.get_or_add( builder, SpirvType::Pointer( - Box::new(SpirvType::from(cv.to.clone())), + Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)), spirv::StorageClass::Function, ), ) } else { - map.get_or_add(builder, SpirvType::from(cv.to.clone())) + map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space)) }; builder.bitcast(result_type, Some(cv.dst), cv.src)?; } @@ -4326,14 +4354,18 @@ fn emit_load_var( map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::new(details.typ.clone(), details.state_space), + ); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_type_spirv = + map.get_or_add(builder, SpirvType::new(vector_type, details.state_space)); let vector_temp = builder.load( vector_type_spirv, None, @@ -4351,7 +4383,11 @@ fn emit_load_var( Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, - SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + SpirvType::pointer_to( + details.typ.clone(), + details.state_space, + spirv::StorageClass::Function, + ), ); let index_spirv = map.get_or_add_constant( builder, @@ -4433,18 +4469,25 @@ fn expand_map_variables<'a, 'b>( is_variable = true; var_type } else { - var_type.param_pointer_to(ast::LdStateSpace::Shared)? + var_type.param_pointer_to(ast::StateSpace::Shared)? } } - ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, - ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, - ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, - ast::StateSpace::Const => todo!(), - ast::StateSpace::Generic => todo!(), + ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?, + ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?, + ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?, + ast::StateSpace::Const => new_todo!(), + ast::StateSpace::Generic => new_todo!(), + ast::StateSpace::Sreg => new_todo!(), }; match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) { + for new_id in id_defs.add_defs( + var.var.name, + count, + var_type, + var.var.state_space, + is_variable, + ) { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4455,7 +4498,11 @@ fn expand_map_variables<'a, 'b>( } } None => { - let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable); + let new_id = id_defs.add_def( + var.var.name, + Some((var_type, var.var.state_space)), + is_variable, + ); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), @@ -4470,11 +4517,42 @@ fn expand_map_variables<'a, 'b>( Ok(()) } +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ // TODO: detect more patterns (mov, call via reg, call via param) // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion // TODO: propagate through calls? +/* fn convert_to_stateful_memory_access<'a>( func_args: &mut SpirvMethodDecl, func_body: Vec, @@ -4496,9 +4574,9 @@ fn convert_to_stateful_memory_access<'a>( match statement { Statement::Instruction(ast::Instruction::Cvta( ast::CvtaDetails { - to: ast::CvtaStateSpace::Global, + to: ast::StateSpace::Global, size: ast::CvtaSize::U64, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, }, arg, )) => { @@ -4512,24 +4590,24 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::U64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::S64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { - state_space: ast::LdStateSpace::Param, - typ: ast::PointerType::Scalar(ast::ScalarType::B64), + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), .. }, arg, @@ -4611,19 +4689,16 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new(); let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { - let new_id = id_defs.new_variable(ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - )); + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8), + ast::StateSpace::Global, + ); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), - v_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ), - state_space: ast::StateSpace::Reg, + v_type: ast::Type::Pointer(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, })); remapped_ids.insert(reg, new_id); } @@ -4658,8 +4733,8 @@ fn convert_to_stateful_memory_access<'a>( }; let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, @@ -4686,7 +4761,7 @@ fn convert_to_stateful_memory_access<'a>( _ => return Err(error_unreachable()), }; let offset_neg = - id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); + id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64))); result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, @@ -4699,8 +4774,8 @@ fn convert_to_stateful_memory_access<'a>( ))); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { - underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), - state_space: ast::LdStateSpace::Global, + underlying_type: ast::ScalarType::U8, + state_space: ast::StateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: TypedOperand::Reg(offset_neg), @@ -4768,10 +4843,8 @@ fn convert_to_stateful_memory_access<'a>( } for arg in func_args.input.iter_mut() { if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, - ); + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8); + arg.state_space = ast::StateSpace::Global; } } Ok(result) @@ -4790,21 +4863,21 @@ fn convert_to_stateful_memory_access_postprocess( Some(new_id) => { // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type_clone)); + let converting_id = id_defs.register_intermediate(Some(old_type_clone)); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, dst: *new_id, - from: old_type, - to: ast::Type::Pointer( + from_type: old_type, + to_type: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, + ast::StateSpace::Global, ), - kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global), + kind: ConversionKind::BitToPtr(ast::StateSpace::Global), src_sema: ArgumentSemantics::Default, dst_sema: arg_desc.sema, })); @@ -4813,11 +4886,11 @@ fn convert_to_stateful_memory_access_postprocess( result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from: ast::Type::Pointer( + from_type: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, + ast::StateSpace::Global, ), - to: old_type, + to_type: old_type, kind: ConversionKind::PtrToBit(ast::ScalarType::U64), src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, @@ -4832,19 +4905,19 @@ fn convert_to_stateful_memory_access_postprocess( } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), + Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_non_variable(Some(old_type)); + let converting_id = id_defs.register_intermediate(Some(old_type)); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from: ast::Type::Pointer( - ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global), - ast::LdStateSpace::Param, + from_type: ast::Type::Pointer( + ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, ), - to: old_type_clone, + to_type: old_type_clone, kind: ConversionKind::PtrToPtr { spirv_ptr: false }, src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, @@ -4855,6 +4928,7 @@ fn convert_to_stateful_memory_access_postprocess( }, }) } +*/ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { @@ -4876,9 +4950,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3 bool { match id_defs.get_typed(id) { - Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) - | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, _ => false, } } @@ -5007,7 +5081,7 @@ impl SpecialRegistersMap { struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, - variables_type_check: HashMap>, + variables_type_check: HashMap>, special_registers: SpecialRegistersMap, fns: HashMap, } @@ -5036,12 +5110,17 @@ impl<'a> GlobalStringIdResolver<'a> { &mut self, id: &'a str, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> spirv::Word { - self.get_or_add_impl(id, Some((typ, is_variable))) + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) } - fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word { + fn get_or_add_impl( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { @@ -5143,10 +5222,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, - global_type_check: &'b HashMap>, + global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, - type_check: HashMap>, + type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { @@ -5184,14 +5263,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { } } - fn add_def(&mut self, id: &'a str, typ: Option, is_variable: bool) -> spirv::Word { + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); - self.type_check - .insert(numeric_id, typ.map(|t| (t, is_variable))); + self.type_check.insert( + numeric_id, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); *self.current_id += 1; numeric_id } @@ -5202,6 +5288,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { base_id: &'a str, count: u32, typ: ast::Type, + state_space: ast::StateSpace, is_variable: bool, ) -> impl Iterator { let numeric_id = *self.current_id; @@ -5210,8 +5297,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check - .insert(numeric_id + i, Some((typ.clone(), is_variable))); + self.type_check.insert( + numeric_id + i, + Some((typ.clone(), state_space, is_variable)), + ); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -5220,8 +5309,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, - global_type_check: &'b HashMap>, - type_check: HashMap>, + global_type_check: &'b HashMap>, + type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } @@ -5230,12 +5319,15 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self } } - fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> { + fn get_typed( + &self, + id: spirv::Word, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), true)), + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), @@ -5246,16 +5338,18 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable // They are candidates for insertion of LoadVar/StoreVar - fn new_variable(&mut self, typ: ast::Type) -> spirv::Word { + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, Some((typ, true))); + self.type_check + .insert(new_id, Some((typ, state_space, true))); *self.current_id += 1; new_id } - fn new_non_variable(&mut self, typ: Option) -> spirv::Word { + fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word { let new_id = *self.current_id; - self.type_check.insert(new_id, typ.map(|t| (t, false))); + self.type_check + .insert(new_id, typ.map(|(t, space)| (t, space, false))); *self.current_id += 1; new_id } @@ -5270,12 +5364,16 @@ impl<'b> MutableNumericIdResolver<'b> { self.base } - fn get_typed(&self, id: spirv::Word) -> Result { - self.base.get_typed(id).map(|(t, _)| t) + fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) } - fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word { - self.base.new_non_variable(Some(typ)) + fn register_intermediate( + &mut self, + typ: ast::Type, + state_space: ast::StateSpace, + ) -> spirv::Word { + self.base.register_intermediate(Some((typ, state_space))) } } @@ -5304,7 +5402,8 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { + .visit(&mut |arg: ArgumentDescriptor<_>, + _: Option<(&ast::Type, ast::StateSpace)>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), @@ -5391,6 +5490,7 @@ impl ExpandedStatement { struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, + state_space: ast::StateSpace, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors @@ -5402,6 +5502,7 @@ struct LoadVarDetails { struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, + state_space: ast::StateSpace, member_index: Option, } @@ -5428,7 +5529,10 @@ impl RepackVectorDetails { is_dst: !self.is_extract, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + Some(( + &ast::Type::Vector(self.typ, self.unpacked.len() as u8), + ast::StateSpace::Reg, + )), )?; let scalar_type = self.typ; let is_extract = self.is_extract; @@ -5443,7 +5547,7 @@ impl RepackVectorDetails { is_dst: is_extract, sema: vector_sema, }, - Some(&ast::Type::Scalar(scalar_type)), + Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)), ) }) .collect::>()?; @@ -5501,7 +5605,7 @@ impl> ResolvedCall { is_dst: space != ast::StateSpace::Param, sema: space.semantics(), }, - Some(&typ), + Some((&typ, space)), )?; Ok((new_id, typ, space)) }) @@ -5525,6 +5629,7 @@ impl> ResolvedCall { sema: space.semantics(), }, &typ, + space, )?; Ok((new_id, typ, space)) }) @@ -5555,22 +5660,22 @@ impl> PtrAccess

{ visitor: &mut V, ) -> Result, TranslateError> { let sema = match self.state_space { - ast::LdStateSpace::Const - | ast::LdStateSpace::Global - | ast::LdStateSpace::Shared - | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, - ast::LdStateSpace::Local | ast::LdStateSpace::Param => { - ArgumentSemantics::RegisterPointer - } + ast::StateSpace::Const + | ast::StateSpace::Global + | ast::StateSpace::Shared + | ast::StateSpace::Generic => ArgumentSemantics::PhysicalPointer, + ast::StateSpace::Local | ast::StateSpace::Param => ArgumentSemantics::RegisterPointer, + ast::StateSpace::Reg => new_todo!(), + ast::StateSpace::Sreg => new_todo!(), }; - let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space); + let ptr_type = ast::Type::Pointer(self.underlying_type.clone()); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_ptr_src = visitor.id( ArgumentDescriptor { @@ -5578,7 +5683,7 @@ impl> PtrAccess

{ is_dst: false, sema, }, - Some(&ptr_type), + Some((&ptr_type, self.state_space)), )?; let new_constant_src = visitor.operand( ArgumentDescriptor { @@ -5587,6 +5692,7 @@ impl> PtrAccess

{ sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::S64), + self.state_space, )?; Ok(PtrAccess { underlying_type: self.underlying_type, @@ -5723,12 +5829,13 @@ pub trait ArgumentMapVisitor { fn id( &mut self, desc: ArgumentDescriptor, - typ: Option<&ast::Type>, + typ: Option<(&ast::Type, ast::StateSpace)>, ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result; } @@ -5736,13 +5843,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -5751,8 +5858,9 @@ where &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { - self(desc, Some(typ)) + self(desc, Some((typ, state_space))) } } @@ -5763,7 +5871,7 @@ where fn id( &mut self, desc: ArgumentDescriptor<&str>, - _: Option<&ast::Type>, + _: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc.op) } @@ -5772,6 +5880,7 @@ where &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result, TranslateError> { Ok(match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), @@ -5780,7 +5889,7 @@ where ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( ids.into_iter() - .map(|id| self.id(desc.new_op(id), Some(typ))) + .map(|id| self.id(desc.new_op(id), Some((typ, state_space)))) .collect::, _>>()?, ), }) @@ -5794,8 +5903,8 @@ pub struct ArgumentDescriptor { } pub struct PtrAccess { - underlying_type: ast::PointerType, - state_space: ast::LdStateSpace, + underlying_type: ast::ScalarType, + state_space: ast::StateSpace, dst: spirv::Word, ptr_src: spirv::Word, offset_src: P::Operand, @@ -6061,7 +6170,7 @@ impl ImplicitConversion { is_dst: true, sema: self.dst_sema, }, - Some(&self.to), + Some((&self.to_type, self.to_space)), )?; let new_src = visitor.id( ArgumentDescriptor { @@ -6069,7 +6178,7 @@ impl ImplicitConversion { is_dst: false, sema: self.src_sema, }, - Some(&self.from), + Some((&self.from_type, self.from_space)), )?; Ok(Statement::Conversion({ ImplicitConversion { @@ -6096,13 +6205,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option<&ast::Type>, + Option<(&ast::Type, ast::StateSpace)>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { self(desc, t) } @@ -6111,12 +6220,15 @@ where &mut self, desc: ArgumentDescriptor, typ: &ast::Type, + state_space: ast::StateSpace, ) -> Result { Ok(match desc.op { - TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Reg(id) => { + TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?) + } TypedOperand::Imm(imm) => TypedOperand::Imm(imm), TypedOperand::RegOffset(id, imm) => { - TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm) } TypedOperand::VecMember(reg, index) => { let scalar_type = match typ { @@ -6124,7 +6236,10 @@ where _ => return Err(error_unreachable()), }; let vec_type = ast::Type::Vector(scalar_type, index + 1); - TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) + TypedOperand::VecMember( + self(desc.new_op(reg), Some((&vec_type, state_space)))?, + index, + ) } }) } @@ -6159,54 +6274,25 @@ impl ast::Type { scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: ast::LdStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], - state_space: ast::LdStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), - state_space: ast::LdStateSpace::Global, }, - ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts { + ast::Type::Pointer(scalar) => TypeParts { kind: TypeKind::PointerScalar, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: *state_space, }, - ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts { - kind: TypeKind::PointerVector, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*len as u32], - state_space: *state_space, - }, - ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => { - TypeParts { - kind: TypeKind::PointerArray, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - state_space: *state_space, - } - } - ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => { - TypeParts { - kind: TypeKind::PointerPointer, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*inner_space as u32], - state_space: *state_space, - } - } } } @@ -6223,31 +6309,9 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::PointerScalar => ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), - t.state_space, - ), - TypeKind::PointerVector => ast::Type::Pointer( - ast::PointerType::Vector( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components[0] as u8, - ), - t.state_space, - ), - TypeKind::PointerArray => ast::Type::Pointer( - ast::PointerType::Array( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components, - ), - t.state_space, - ), - TypeKind::PointerPointer => ast::Type::Pointer( - ast::PointerType::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), - unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) }, - ), - t.state_space, - ), + TypeKind::PointerScalar => { + ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind)) + } } } @@ -6258,7 +6322,7 @@ impl ast::Type { ast::Type::Array(typ, len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(_, _) => mem::size_of::(), + ast::Type::Pointer(..) => mem::size_of::(), } } } @@ -6269,7 +6333,6 @@ struct TypeParts { scalar_kind: ast::ScalarKind, width: u8, components: Vec, - state_space: ast::LdStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -6278,9 +6341,6 @@ enum TypeKind { Vector, Array, PointerScalar, - PointerVector, - PointerArray, - PointerPointer, } impl ast::Instruction { @@ -6408,8 +6468,10 @@ struct BrachCondition { struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, - from: ast::Type, - to: ast::Type, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, kind: ConversionKind, src_sema: ArgumentSemantics, dst_sema: ArgumentSemantics, @@ -6420,7 +6482,7 @@ enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, - BitToPtr(ast::LdStateSpace), + BitToPtr, PtrToBit(ast::ScalarType), PtrToPtr { spirv_ptr: bool }, } @@ -6470,7 +6532,7 @@ impl ast::Arg1 { fn map>( self, visitor: &mut V, - t: Option<&ast::Type>, + t: Option<(&ast::Type, ast::StateSpace)>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { @@ -6496,6 +6558,7 @@ impl ast::Arg1Bar { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg1Bar { src: new_src }) } @@ -6514,6 +6577,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let new_src = visitor.operand( ArgumentDescriptor { @@ -6522,6 +6586,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst: new_dst, @@ -6542,6 +6607,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, dst_t, + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6550,6 +6616,7 @@ impl ast::Arg2 { sema: ArgumentSemantics::Default, }, src_t, + ast::StateSpace::Reg, )?; Ok(ast::Arg2 { dst, src }) } @@ -6568,9 +6635,10 @@ impl ast::Arg2Ld { sema: ArgumentSemantics::DefaultRelaxed, }, &ast::Type::from(details.typ.clone()), + ast::StateSpace::Reg, )?; - let is_logical_ptr = details.state_space == ast::LdStateSpace::Param - || details.state_space == ast::LdStateSpace::Local; + let is_logical_ptr = details.state_space == ast::StateSpace::Param + || details.state_space == ast::StateSpace::Local; let src = visitor.operand( ArgumentDescriptor { op: self.src, @@ -6581,10 +6649,8 @@ impl ast::Arg2Ld { ArgumentSemantics::PhysicalPointer }, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space, - ), + &details.typ, + details.state_space, )?; Ok(ast::Arg2Ld { dst, src }) } @@ -6596,8 +6662,8 @@ impl ast::Arg2St { visitor: &mut V, details: &ast::StData, ) -> Result, TranslateError> { - let is_logical_ptr = details.state_space == ast::StStateSpace::Param - || details.state_space == ast::StStateSpace::Local; + let is_logical_ptr = details.state_space == ast::StateSpace::Param + || details.state_space == ast::StateSpace::Local; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, @@ -6608,10 +6674,8 @@ impl ast::Arg2St { ArgumentSemantics::PhysicalPointer }, }, - &ast::Type::Pointer( - ast::PointerType::from(details.typ.clone()), - details.state_space.to_ld_ss(), - ), + &details.typ, + details.state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6620,6 +6684,7 @@ impl ast::Arg2St { sema: ArgumentSemantics::DefaultRelaxed, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2St { src1, src2 }) } @@ -6638,6 +6703,7 @@ impl ast::Arg2Mov { sema: ArgumentSemantics::Default, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6650,6 +6716,7 @@ impl ast::Arg2Mov { }, }, &details.typ.clone().into(), + ast::StateSpace::Reg, )?; Ok(ast::Arg2Mov { dst, src }) } @@ -6674,6 +6741,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(typ), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6682,6 +6750,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6690,6 +6759,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6706,6 +6776,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6714,6 +6785,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6722,6 +6794,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6730,7 +6803,7 @@ impl ast::Arg3 { self, visitor: &mut V, t: ast::ScalarType, - state_space: ast::AtomSpace, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( @@ -6740,6 +6813,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6747,10 +6821,8 @@ impl ast::Arg3 { is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + state_space, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6759,6 +6831,7 @@ impl ast::Arg3 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -6783,6 +6856,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(t), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6791,6 +6865,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6799,6 +6874,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6807,6 +6883,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6828,6 +6905,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6836,6 +6914,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6844,6 +6923,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6852,6 +6932,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6865,7 +6946,7 @@ impl ast::Arg4 { self, visitor: &mut V, t: ast::ScalarType, - state_space: ast::AtomSpace, + state_space: ast::StateSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( @@ -6875,6 +6956,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6882,10 +6964,8 @@ impl ast::Arg4 { is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, - &ast::Type::Pointer( - ast::PointerType::Scalar(scalar_type), - state_space.to_ld_ss(), - ), + &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -6894,6 +6974,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6902,6 +6983,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6923,6 +7005,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6931,6 +7014,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, typ, + ast::StateSpace::Reg, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); let src2 = visitor.operand( @@ -6940,6 +7024,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &u32_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -6948,6 +7033,7 @@ impl ast::Arg4 { sema: ArgumentSemantics::Default, }, &u32_type, + ast::StateSpace::Reg, )?; Ok(ast::Arg4 { dst, @@ -6970,7 +7056,10 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -6981,7 +7070,10 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -6992,6 +7084,7 @@ impl ast::Arg4Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7000,6 +7093,7 @@ impl ast::Arg4Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; Ok(ast::Arg4Setp { dst1, @@ -7023,6 +7117,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -7031,6 +7126,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7039,6 +7135,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, base_type, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -7047,6 +7144,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; let src4 = visitor.operand( ArgumentDescriptor { @@ -7055,6 +7153,7 @@ impl ast::Arg5 { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, )?; Ok(ast::Arg5 { dst, @@ -7078,7 +7177,10 @@ impl ast::Arg5Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), )?; let dst2 = self .dst2 @@ -7089,7 +7191,10 @@ impl ast::Arg5Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(ast::ScalarType::Pred)), + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), ) }) .transpose()?; @@ -7100,6 +7205,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src2 = visitor.operand( ArgumentDescriptor { @@ -7108,6 +7214,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, t, + ast::StateSpace::Reg, )?; let src3 = visitor.operand( ArgumentDescriptor { @@ -7116,6 +7223,7 @@ impl ast::Arg5Setp { sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, )?; Ok(ast::Arg5Setp { dst1, @@ -7153,18 +7261,6 @@ impl ast::Operand { } } -impl ast::StStateSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::StStateSpace::Generic => ast::LdStateSpace::Generic, - ast::StStateSpace::Global => ast::LdStateSpace::Global, - ast::StStateSpace::Local => ast::LdStateSpace::Local, - ast::StStateSpace::Param => ast::LdStateSpace::Param, - ast::StStateSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - impl ast::ScalarType { fn from_parts(width: u8, kind: ast::ScalarKind) -> Self { match kind { @@ -7255,15 +7351,17 @@ impl ast::AtomInnerDetails { } } -impl ast::LdStateSpace { +impl ast::StateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { - ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant, - ast::LdStateSpace::Generic => spirv::StorageClass::Generic, - ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::LdStateSpace::Local => spirv::StorageClass::Function, - ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup, - ast::LdStateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, } } } @@ -7289,16 +7387,6 @@ impl ast::MulDetails { } } -impl ast::AtomSpace { - fn to_ld_ss(self) -> ast::LdStateSpace { - match self { - ast::AtomSpace::Generic => ast::LdStateSpace::Generic, - ast::AtomSpace::Global => ast::LdStateSpace::Global, - ast::AtomSpace::Shared => ast::LdStateSpace::Shared, - } - } -} - impl ast::MemScope { fn to_spirv(self) -> spirv::Scope { match self { @@ -7333,89 +7421,44 @@ impl ast::StateSpace { fn bitcast_register_pointer( operand_type: &ast::Type, + operand_space: ast::StateSpace, instr_type: &ast::Type, - ss: Option, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { - bitcast_physical_pointer(operand_type, instr_type, ss) + bitcast_physical_pointer(operand_type, operand_space, instr_type, instruction_space) } fn bitcast_physical_pointer( operand_type: &ast::Type, - instr_type: &ast::Type, - ss: Option, + operand_space: ast::StateSpace, + instruction_type: &ast::Type, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { - match operand_type { - // array decays to a pointer - ast::Type::Array(op_scalar_t, _) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if ss == Some(*instr_space) { - if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } - } else { - if ss == Some(ast::LdStateSpace::Generic) - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } - } - } else { - Err(TranslateError::MismatchedType) - } + if operand_space == instruction_space { + if operand_type != instruction_type { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } else { + Ok(None) } - ast::Type::Scalar(ast::ScalarType::B64) - | ast::Type::Scalar(ast::ScalarType::U64) - | ast::Type::Scalar(ast::ScalarType::S64) => { - if let Some(space) = ss { - Ok(Some(ConversionKind::BitToPtr(space))) - } else { - Err(error_unreachable()) - } - } - ast::Type::Scalar(ast::ScalarType::B32) - | ast::Type::Scalar(ast::ScalarType::U32) - | ast::Type::Scalar(ast::ScalarType::S32) => match ss { - Some(ast::LdStateSpace::Shared) - | Some(ast::LdStateSpace::Generic) - | Some(ast::LdStateSpace::Param) - | Some(ast::LdStateSpace::Local) => { - Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared))) - } + } else { + match operand_space { + ast::StateSpace::Reg | ast::StateSpace::Sreg => match instruction_space { + ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Shared + | ast::StateSpace::Local => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(TranslateError::MismatchedType), + }, _ => Err(TranslateError::MismatchedType), - }, - ast::Type::Pointer(op_scalar_t, op_space) => { - if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { - if op_space == instr_space { - if op_scalar_t == instr_scalar_t { - Ok(None) - } else { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } - } else { - if *op_space == ast::LdStateSpace::Generic - || *instr_space == ast::LdStateSpace::Generic - { - Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) - } else { - Err(TranslateError::MismatchedType) - } - } - } else { - Err(TranslateError::MismatchedType) - } } - _ => Err(TranslateError::MismatchedType), } } fn force_bitcast_ptr_to_bit( _: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { // TODO: verify this on f32, u16 and the like if let ast::Type::Scalar(scalar_t) = instr_type { @@ -7457,11 +7500,12 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { fn should_bitcast_packed( operand: &ast::Type, - instr: &ast::Type, - ss: Option, + operand_space: ast::StateSpace, + instruction: &ast::Type, + instruction_space: ast::StateSpace, ) -> Result, TranslateError> { if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = - (operand, instr) + (operand, instruction) { if scalar.kind() == ast::ScalarKind::Bit && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) @@ -7469,13 +7513,14 @@ fn should_bitcast_packed( return Ok(Some(ConversionKind::Default)); } } - should_bitcast_wrapper(operand, instr, ss) + should_bitcast_wrapper(operand, operand_space, instruction, instruction_space) } fn should_bitcast_wrapper( operand: &ast::Type, + _: ast::StateSpace, instr: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if instr == operand { return Ok(None); @@ -7489,8 +7534,9 @@ fn should_bitcast_wrapper( fn should_convert_relaxed_src_wrapper( src_type: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if src_type == instr_type { return Ok(None); @@ -7552,8 +7598,9 @@ fn should_convert_relaxed_src( fn should_convert_relaxed_dst_wrapper( dst_type: &ast::Type, + _: ast::StateSpace, instr_type: &ast::Type, - _: Option, + _: ast::StateSpace, ) -> Result, TranslateError> { if dst_type == instr_type { return Ok(None);