diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 559805e..5ced5d0 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -1,6 +1,6 @@ use crate::sys; use std::{ - ffi::{c_void, CStr}, + ffi::{c_void, CStr, CString}, fmt::Debug, marker::PhantomData, mem, ptr, @@ -238,23 +238,16 @@ impl Drop for CommandQueue { pub struct Module(sys::ze_module_handle_t); impl Module { - pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t { - self.0 - } - pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self { - Self(x) - } - pub fn new_spirv( ctx: &mut Context, d: &Device, bin: &[u8], opts: Option<&CStr>, - ) -> Result { + ) -> (Result, BuildLog) { Module::new(ctx, true, d, bin, opts) } - pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> Result { + pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result, BuildLog) { Module::new(ctx, false, d, bin, None) } @@ -264,7 +257,7 @@ impl Module { d: &Device, bin: &[u8], opts: Option<&CStr>, - ) -> Result { + ) -> (Result, BuildLog) { let desc = sys::ze_module_desc_t { stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_MODULE_DESC, pNext: ptr::null(), @@ -279,14 +272,14 @@ impl Module { pConstants: ptr::null(), }; let mut result: sys::ze_module_handle_t = ptr::null_mut(); - check!(sys::zeModuleCreate( - ctx.0, - d.0, - &desc, - &mut result, - ptr::null_mut() - )); - Ok(Module(result)) + let mut log_handle = ptr::null_mut(); + let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) }; + let log = BuildLog(log_handle); + if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS { + (Result::Err(err), log) + } else { + (Ok(Module(result)), log) + } } } @@ -297,6 +290,32 @@ impl Drop for Module { } } +pub struct BuildLog(sys::ze_module_build_log_handle_t); + +impl BuildLog { + pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t { + self.0 + } + pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self { + Self(x) + } + + pub fn get_cstring(&self) -> Result { + let mut size = 0; + check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) }; + let mut str_vec = vec![0u8; size]; + check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) }; + str_vec.pop(); + Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?) + } +} + +impl Drop for BuildLog { + fn drop(&mut self) { + check_panic!(sys::zeModuleBuildLogDestroy(self.0)); + } +} + pub trait SafeRepr {} impl SafeRepr for u8 {} impl SafeRepr for i8 {} diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs index 6f8773e..0ab3bea 100644 --- a/notcuda/src/impl/function.rs +++ b/notcuda/src/impl/function.rs @@ -1,7 +1,7 @@ use ::std::os::raw::{c_uint, c_void}; use std::ptr; -use super::{context, device, stream::Stream, CUresult}; +use super::{device, stream::Stream, CUresult}; pub struct Function { pub base: l0::Kernel<'static>, diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs index 3f92b5e..439b26f 100644 --- a/notcuda/src/impl/memory.rs +++ b/notcuda/src/impl/memory.rs @@ -46,7 +46,7 @@ unsafe fn memcpy_impl( Ok(()) } -pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> { +pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> { Ok(()) } diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 3d31da2..5a72ce4 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,4 +1,4 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st}; +use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; #[cfg(test)] diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index fc55f33..eea862b 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -1,13 +1,10 @@ use std::{ - collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, - sync::Mutex, + collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex, }; use super::{function::Function, transmute_lifetime, CUresult}; use ptx; -use super::context; - pub type Module = Mutex; pub struct ModuleData { @@ -67,14 +64,14 @@ impl ModuleData { l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None) }); match module { - Ok(Ok(module)) => Ok(Mutex::new(Self { + Ok((Ok(module), _)) => Ok(Mutex::new(Self { base: module, arg_lens: all_arg_lens .into_iter() .map(|(k, v)| (CString::new(k).unwrap(), v)) .collect(), })), - Ok(Err(err)) => Err(ModuleCompileError::from(err)), + Ok((Err(err), _)) => Err(ModuleCompileError::from(err)), Err(err) => Err(ModuleCompileError::from(err)), } } @@ -116,6 +113,6 @@ pub fn get_function( Ok(()) } -pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> { +pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> { Ok(()) } diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs index 7410100..1844677 100644 --- a/notcuda/src/impl/stream.rs +++ b/notcuda/src/impl/stream.rs @@ -30,7 +30,7 @@ mod tests { use super::super::test::CudaDriverFns; use super::super::CUresult; - use std::{ffi::c_void, ptr}; + use std::ptr; const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; @@ -41,7 +41,7 @@ mod tests { fn default_stream_uses_current_ctx_legacy() { default_stream_uses_current_ctx_impl::(CU_STREAM_LEGACY); } - + fn default_stream_uses_current_ctx_ptsd() { default_stream_uses_current_ctx_impl::(CU_STREAM_PER_THREAD); } diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs index d4366b7..dbd2eff 100644 --- a/notcuda/src/impl/test.rs +++ b/notcuda/src/impl/test.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] -use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda}; +use crate::{cuda::CUstream, r#impl as notcuda}; use crate::r#impl::CUresult; use crate::{cuda::CUuuid, r#impl::Encuda}; use ::std::{ diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 048d43a..c6510da 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,8 @@ -use std::convert::From; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; +use half::f16; + quick_error! { #[derive(Debug)] pub enum PtxError { @@ -9,11 +11,17 @@ quick_error! { display("{}", err) cause(err) } + ParseFloat (err: ParseFloatError) { + from() + display("{}", err) + cause(err) + } SyntaxError {} NonF32Ftz {} WrongArrayType {} WrongVectorElement {} MultiArrayVariable {} + ZeroDimensionArray {} } } @@ -53,7 +61,7 @@ macro_rules! sub_scalar_type { macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone, Copy)] + #[derive(PartialEq, Eq, Clone)] pub enum $type_name { $( $variant ($($field_type),+), @@ -80,11 +88,13 @@ sub_type! { } } +type VecU32 = Vec; + sub_type! { VariableLocalType { Scalar(SizedScalarType), Vector(SizedScalarType, u8), - Array(SizedScalarType, u32), + Array(SizedScalarType, VecU32), } } @@ -95,7 +105,7 @@ sub_type! { sub_type! { VariableParamType { Scalar(ParamScalarType), - Array(SizedScalarType, u32), + Array(SizedScalarType, VecU32), } } @@ -169,7 +179,12 @@ impl< pub struct Module<'a> { pub version: (u8, u8), - pub functions: Vec>, + pub directives: Vec>>, +} + +pub enum Directive<'a, P: ArgParams> { + Variable(Variable), + Method(Function<'a, &'a str, Statement

>), } pub enum MethodDecl<'a, ID> { @@ -187,7 +202,7 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement>>; -#[derive(PartialEq, Eq, Clone, Copy)] +#[derive(PartialEq, Eq, Clone)] pub enum FnArgumentType { Reg(VariableRegType), Param(VariableParamType), @@ -202,11 +217,11 @@ impl From for Type { } } -#[derive(PartialEq, Eq, Hash, Clone, Copy)] +#[derive(PartialEq, Eq, Hash, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), - Array(ScalarType, u32), + Array(ScalarType, Vec), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -274,6 +289,30 @@ sub_scalar_type!(FloatType { F64 }); +impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 => 1, + ScalarType::S8 => 1, + ScalarType::B8 => 1, + ScalarType::U16 => 2, + ScalarType::S16 => 2, + ScalarType::B16 => 2, + ScalarType::F16 => 2, + ScalarType::U32 => 4, + ScalarType::S32 => 4, + ScalarType::B32 => 4, + ScalarType::F32 => 4, + ScalarType::U64 => 8, + ScalarType::S64 => 8, + ScalarType::B64 => 8, + ScalarType::F64 => 8, + ScalarType::F16x2 => 4, + ScalarType::Pred => 1, + } + } +} + impl Default for ScalarType { fn default() -> Self { ScalarType::B8 @@ -296,13 +335,26 @@ pub struct Variable { pub align: Option, pub v_type: T, pub name: ID, + pub array_init: Vec, } -#[derive(Eq, PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Clone)] pub enum VariableType { Reg(VariableRegType), Local(VariableLocalType), Param(VariableParamType), + Global(VariableLocalType), +} + +impl VariableType { + pub fn to_type(&self) -> (StateSpace, Type) { + match self { + VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), + VariableType::Local(t) => (StateSpace::Local, t.clone().into()), + VariableType::Param(t) => (StateSpace::Param, t.clone().into()), + VariableType::Global(t) => (StateSpace::Global, t.clone().into()), + } + } } impl From for Type { @@ -311,6 +363,7 @@ impl From for Type { VariableType::Reg(t) => t.into(), VariableType::Local(t) => t.into(), VariableType::Param(t) => t.into(), + VariableType::Global(t) => t.into(), } } } @@ -318,7 +371,6 @@ impl From for Type { #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, - Sreg, Const, Global, Local, @@ -538,7 +590,7 @@ pub enum LdCacheOperator { Uncached, } -#[derive(Copy, Clone)] +#[derive(Clone)] pub struct MovDetails { pub typ: Type, pub src_is_address: bool, @@ -846,3 +898,194 @@ pub struct MinMaxFloat { pub nan: bool, pub typ: FloatType, } + +pub enum NumsOrArrays<'a> { + Nums(Vec<&'a str>), + Arrays(Vec>), +} + +impl<'a> NumsOrArrays<'a> { + pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result, PtxError> { + self.normalize_dimensions(dimensions)?; + let sizeof_t = ScalarType::from(typ).size_of() as usize; + let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); + let mut result = vec![0; result_size]; + self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; + Ok(result) + } + + fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { + match dimensions.first_mut() { + Some(first) => { + if *first == 0 { + *first = match self { + NumsOrArrays::Nums(v) => v.len() as u32, + NumsOrArrays::Arrays(v) => v.len() as u32, + }; + } + } + None => return Err(PtxError::ZeroDimensionArray), + } + for dim in dimensions { + if *dim == 0 { + return Err(PtxError::ZeroDimensionArray); + } + } + Ok(()) + } + + fn parse_and_copy( + &self, + t: SizedScalarType, + size_of_t: usize, + dimensions: &[u32], + result: &mut [u8], + ) -> Result<(), PtxError> { + match dimensions { + [] => unreachable!(), + [dim] => match self { + NumsOrArrays::Nums(vec) => { + if vec.len() > *dim as usize { + return Err(PtxError::ZeroDimensionArray); + } + for (idx, val) in vec.iter().enumerate() { + Self::parse_and_copy_single(t, idx, val, result)?; + } + } + NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), + }, + [first_dim, rest @ ..] => match self { + NumsOrArrays::Arrays(vec) => { + if vec.len() > *first_dim as usize { + return Err(PtxError::ZeroDimensionArray); + } + let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); + for (idx, this) in vec.iter().enumerate() { + this.parse_and_copy( + t, + size_of_t, + rest, + &mut result[(size_of_element * idx)..], + )?; + } + } + NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), + }, + } + Ok(()) + } + + fn parse_and_copy_single( + t: SizedScalarType, + idx: usize, + str_val: &str, + output: &mut [u8], + ) -> Result<(), PtxError> { + match t { + SizedScalarType::B8 | SizedScalarType::U8 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::B16 | SizedScalarType::U16 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::B32 | SizedScalarType::U32 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::B64 | SizedScalarType::U64 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::S8 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::S16 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::S32 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::S64 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::F16 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::F16x2 => todo!(), + SizedScalarType::F32 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + SizedScalarType::F64 => { + Self::parse_and_copy_single_t::(idx, str_val, output)?; + } + } + Ok(()) + } + + fn parse_and_copy_single_t( + idx: usize, + str_val: &str, + output: &mut [u8], + ) -> Result<(), PtxError> + where + T::Err: Into, + { + let typed_output = unsafe { + std::slice::from_raw_parts_mut::( + output.as_mut_ptr() as *mut _, + output.len() / mem::size_of::(), + ) + }; + typed_output[idx] = str_val.parse::().map_err(|e| e.into())?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn array_fails_multiple_0_dmiensions() { + let inp = NumsOrArrays::Nums(Vec::new()); + assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + } + + #[test] + fn array_fails_on_empty() { + let inp = NumsOrArrays::Nums(Vec::new()); + assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + } + + #[test] + fn array_auto_sizes_0_dimension() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2"]), + NumsOrArrays::Nums(vec!["3", "4"]), + ]); + let mut dimensions = vec![0u32, 2]; + assert_eq!( + vec![1u8, 2, 3, 4], + inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + ); + assert_eq!(dimensions, vec![2u32, 2]); + } + + #[test] + fn array_fails_wrong_structure() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2"]), + NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]), + ]); + let mut dimensions = vec![0u32, 2]; + assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + } + + #[test] + fn array_fails_too_long_component() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2", "3"]), + NumsOrArrays::Nums(vec!["4", "5"]), + ]); + let mut dimensions = vec![0u32, 2]; + assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + } +} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 2c0e365..0b6fa0f 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -2,6 +2,8 @@ use crate::ast; use crate::ast::UnwrapWithVec; use crate::{without_none, vector_index}; +use lalrpop_util::ParseError; + grammar<'a>(errors: &mut Vec); extern { @@ -27,6 +29,7 @@ match { "{", "}", "<", ">", "|", + "=", ".acquire", ".address_size", ".align", @@ -94,7 +97,6 @@ match { ".sat", ".section", ".shared", - ".sreg", ".sys", ".target", ".to", @@ -176,8 +178,8 @@ ExtendedID : &'input str = { } pub Module: ast::Module<'input> = { - Target => { - ast::Module { version: v, functions: without_none(f) } + Target => { + ast::Module { version: v, directives: without_none(d) } } }; @@ -203,11 +205,12 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option>>> = { +Directive: Option>> = { AddressSize => None, - => Some(f), + => Some(ast::Directive::Method(f)), File => None, - Section => None + Section => None, + ";" => Some(ast::Directive::Variable(v)), }; AddressSize = { @@ -242,9 +245,9 @@ FnArguments: Vec> = { }; KernelInput: ast::Variable = { - => { + => { let (align, v_type, name) = v; - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } } } @@ -252,12 +255,12 @@ FnInput: ast::Variable = { => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } }, - => { + => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } } } @@ -268,7 +271,6 @@ pub(crate) FunctionBody: Option> StateSpaceSpecifier: ast::StateSpace = { ".reg" => ast::StateSpace::Reg, - ".sreg" => ast::StateSpace::Sreg, ".const" => ast::StateSpace::Const, ".global" => ast::StateSpace::Global, ".local" => ast::StateSpace::Local, @@ -344,13 +346,13 @@ Variable: ast::Variable = { => { let (align, v_type, name) = v; let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name} + ast::Variable {align, v_type, name, array_init: Vec::new()} }, LocalVariable, => { - let (align, v_type, name) = v; + let (align, array_init, v_type, name) = v; let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name} + ast::Variable {align, v_type, name, array_init} }, }; @@ -366,32 +368,60 @@ RegVariable: (Option, ast::VariableRegType, &'input str) = { } LocalVariable: ast::Variable = { - ".local" => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable {align, v_type, name} - }, - ".local" => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable {align, v_type, name} - }, - ".local" => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr)); - ast::Variable {align, v_type, name} + ".local" => { + let (align, array_init, v_type, name) = def; + ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init } + } +} + +GlobalVariable: ast::Variable = { + ".global" => { + let (align, array_init, v_type, name) = def; + ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option, ast::VariableParamType, &'input str) = { +ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { + ".param" => { + let v_type = ast::VariableParamType::Scalar(t); + (align, Vec::new(), v_type, name) + }, + ".param" => { + let (array_init, name, (t, dimensions)) = arr; + let v_type = ast::VariableParamType::Array(t, dimensions); + (align, array_init, v_type, name) + } +} + +ParamDeclaration: (Option, ast::VariableParamType, &'input str) = { ".param" => { let v_type = ast::VariableParamType::Scalar(t); (align, v_type, name) }, - ".param" => { - let v_type = ast::VariableParamType::Array(t, arr); + ".param" => { + let (name, (t, dimensions)) = arr; + let v_type = ast::VariableParamType::Array(t, dimensions); (align, v_type, name) } } +LocalVariableDefinition: (Option, Vec, ast::VariableLocalType, &'input str) = { + => { + let v_type = ast::VariableLocalType::Scalar(t); + (align, Vec::new(), v_type, name) + }, + => { + let v_type = ast::VariableLocalType::Vector(t, v_len); + (align, Vec::new(), v_type, name) + }, + => { + let (array_init, name, (t, dimensions)) = arr; + let v_type = ast::VariableLocalType::Array(t, dimensions); + (align, array_init, v_type, name) + } +} + #[inline] SizedScalarType: ast::SizedScalarType = { ".b8" => ast::SizedScalarType::B8, @@ -431,12 +461,59 @@ ParamScalarType: ast::ParamScalarType = { ".f64" => ast::ParamScalarType::F64, } -ArraySpecifier: u32 = { - "[" "]" => { - let size = n.parse::(); - size.unwrap_with(errors) +ArrayDefinition: (Vec, &'input str, (ast::SizedScalarType, Vec)) = { + =>? { + let mut dims = dims; + let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?; + Ok(( + array_init, + name, + (typ, dims) + )) } -}; +} + +ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec)) = { + =>? { + let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::>()?; + Ok((name, (typ, dims))) + } +} + +// [0] and [] are treated the same +ArrayDimensions: Vec = { + ArrayEmptyDimension => vec![0u32], + ArrayEmptyDimension => { + let mut dims = dims; + let mut result = vec![0u32]; + result.append(&mut dims); + result + }, + => dims +} + +ArrayEmptyDimension = { + "[" "]" +} + +ArrayDimension: u32 = { + "[" "]" =>? { + str::parse::(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) + } +} + +ArrayInitializer: ast::NumsOrArrays<'input> = { + "=" => nums +} + +NumsOrArraysBracket: ast::NumsOrArrays<'input> = { + "{" "}" => nums +} + +NumsOrArrays: ast::NumsOrArrays<'input> = { + > => ast::NumsOrArrays::Arrays(n), + > => ast::NumsOrArrays::Nums(n), +} Instruction: ast::Instruction> = { InstLd, @@ -1244,3 +1321,11 @@ Comma: Vec = { } } }; + +CommaNonEmpty: Vec = { + ",")*> => { + let mut v = v; + v.push(e); + v + } +}; diff --git a/ptx/src/test/spirv_run/global_array.ptx b/ptx/src/test/spirv_run/global_array.ptx new file mode 100644 index 0000000..7ac8bce --- /dev/null +++ b/ptx/src/test/spirv_run/global_array.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.global .s32 foobar[4] = {1}; + +.visible .entry global_array( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 temp; + + mov.u64 in_addr, foobar; + ld.param.u64 out_addr, [output]; + + ld.global.u32 temp, [in_addr]; + st.global.u32 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/global_array.spvtxt b/ptx/src/test/spirv_run/global_array.spvtxt new file mode 100644 index 0000000..a4ed91d --- /dev/null +++ b/ptx/src/test/spirv_run/global_array.spvtxt @@ -0,0 +1,54 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %22 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %2 "global_array" %1 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4 + %uint_4_0 = OpConstant %uint 4 + %uint_1 = OpConstant %uint 1 + %uint_0 = OpConstant %uint 0 + %31 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0 + %1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %31 + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %2 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %20 = OpLabel + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %3 %8 + OpStore %4 %9 + %17 = OpConvertPtrToU %ulong %1 + %10 = OpCopyObject %ulong %17 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %11 = OpCopyObject %ulong %12 + OpStore %6 %11 + %14 = OpLoad %ulong %5 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14 + %13 = OpLoad %uint %18 + OpStore %7 %13 + %15 = OpLoad %ulong %6 + %16 = OpLoad %uint %7 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %15 + OpStore %19 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 8caf540..0c881d9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -66,14 +66,18 @@ test_ptx!(b64tof64, [111u64], [111u64]); test_ptx!(implicit_param, [34u32], [34u32]); test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]); -test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]); +test_ptx!( + mul_wide, + [0x01_00_00_00__01_00_00_00i64], + [0x1_00_00_00_00_00_00i64] +); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); test_ptx!(shr, [-2i32], [-1i32]); test_ptx!(or, [1u64, 2u64], [3u64]); test_ptx!(sub, [2u64], [1u64]); test_ptx!(min, [555i32, 444i32], [444i32]); test_ptx!(max, [555i32, 444i32], [555i32]); - +test_ptx!(global_array, [0xDEADu32], [1u32]); struct DisplayError { err: T, @@ -131,7 +135,15 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let mut devices = drv.devices()?; let dev = devices.drain(0..1).next().unwrap(); let queue = ze::CommandQueue::new(&mut ctx, &dev)?; - let module = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None)?; + let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None); + let module = match module { + Ok(m) => m, + Err(err) => { + let raw_err_string = log.get_cstring()?; + let err_string = raw_err_string.to_string_lossy(); + panic!("{:?}\n{}", err, err_string); + } + }; let mut kernel = ze::Kernel::new_resident(&module, name)?; kernel.set_indirect_access( ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7c15744..a86ab3c 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; +use half::f16; use rspirv::{binary::Disassemble, dr}; use std::collections::{hash_map, HashMap, HashSet}; -use std::convert::TryInto; use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; @@ -26,7 +26,7 @@ quick_error! { enum SpirvType { Base(SpirvScalarKey), Vector(SpirvScalarKey, u8), - Array(SpirvScalarKey, u32), + Array(SpirvScalarKey, Vec), Pointer(Box, spirv::StorageClass), Func(Option>, Vec), Struct(Vec), @@ -62,6 +62,7 @@ impl From for SpirvType { struct TypeWordMap { void: spirv::Word, complex: HashMap, + constants: HashMap<(SpirvType, u64), spirv::Word>, } // SPIR-V integer type definitions are signless, more below: @@ -108,6 +109,7 @@ impl TypeWordMap { TypeWordMap { void: void, complex: HashMap::::new(), + constants: HashMap::new(), } } @@ -154,13 +156,25 @@ impl TypeWordMap { .entry(t) .or_insert_with(|| b.type_vector(base, len as u32)) } - SpirvType::Array(typ, len) => { - let base = self.get_or_add_spirv_scalar(b, typ); + SpirvType::Array(typ, array_dimensions) => { let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - *self.complex.entry(t).or_insert_with(|| { - let len_word = b.constant_u32(u32_type, None, len); - b.type_array(base, len_word) - }) + let (base_type, length) = match &*array_dimensions { + &[len] => { + let base = self.get_or_add_spirv_scalar(b, typ); + let len_const = b.constant_u32(u32_type, None, len); + (base, len_const) + } + array_dimensions => { + let base = self + .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); + let len_const = b.constant_u32(u32_type, None, array_dimensions[0]); + (base, len_const) + } + }; + *self + .complex + .entry(SpirvType::Array(typ, array_dimensions)) + .or_insert_with(|| b.type_array(base_type, length)) } SpirvType::Func(ref out_params, ref in_params) => { let out_t = match out_params { @@ -211,16 +225,173 @@ impl TypeWordMap { self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), ) } + + fn get_or_add_constant( + &mut self, + b: &mut dr::Builder, + typ: &ast::Type, + init: &[u8], + ) -> Result { + Ok(match typ { + ast::Type::Scalar(t) => match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v), + ), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v, + |b, result_type, v| b.constant_u64(result_type, None, v), + ), + ast::ScalarType::F16 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u16>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), + ), + ast::ScalarType::F32 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u32>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v), + ), + ast::ScalarType::F64 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u64>(v) }, + |b, result_type, v| b.constant_f64(result_type, None, v), + ), + ast::ScalarType::F16x2 => return Err(TranslateError::Todo), + ast::ScalarType::Pred => self.get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| { + if v == 0 { + b.constant_false(result_type, None) + } else { + b.constant_true(result_type, None) + } + }, + ), + }, + ast::Type::Vector(typ, len) => { + let result_type = + self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); + let size_of_t = typ.size_of(); + let components = (0..*len) + .map(|x| { + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + ) + }) + .collect::, _>>()?; + b.constant_composite(result_type, None, &components) + } + ast::Type::Array(typ, dims) => match dims.as_slice() { + [] => return Err(TranslateError::Unreachable), + [dim] => { + let result_type = self + .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); + let size_of_t = typ.size_of(); + let components = (0..*dim) + .map(|x| { + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + ) + }) + .collect::, _>>()?; + b.constant_composite(result_type, None, &components) + } + [first_dim, rest @ ..] => { + let result_type = self.get_or_add( + b, + SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), + ); + let size_of_t = rest + .iter() + .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); + let components = (0..*first_dim) + .map(|x| { + self.get_or_add_constant( + b, + &ast::Type::Array(*typ, rest.to_vec()), + &init[((size_of_t as usize) * (x as usize))..], + ) + }) + .collect::, _>>()?; + b.constant_composite(result_type, None, &components) + } + }, + }) + } + + fn get_or_add_constant_single< + T: Copy, + CastAsU64: FnOnce(T) -> u64, + InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, + >( + &mut self, + b: &mut dr::Builder, + key: ast::ScalarType, + init: &[u8], + cast: CastAsU64, + f: InsertConstant, + ) -> spirv::Word { + let value = unsafe { *(init.as_ptr() as *const T) }; + let value_64 = cast(value); + let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); + match self.constants.get(&ht_key) { + Some(value) => *value, + None => { + let spirv_type = self.get_or_add_scalar(b, key); + let result = f(b, spirv_type, value); + self.constants.insert(ht_key, result); + result + } + } + } } pub fn to_spirv_module<'a>( ast: ast::Module<'a>, ) -> Result<(dr::Module, HashMap>), TranslateError> { let mut id_defs = GlobalStringIdResolver::new(1); - let ssa_functions = ast - .functions + let directives = ast + .directives .into_iter() - .map(|f| to_ssa_function(&mut id_defs, f)) + .map(|f| translate_directive(&mut id_defs, f)) .collect::, _>>()?; let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); @@ -233,21 +404,28 @@ pub fn to_spirv_module<'a>( let mut map = TypeWordMap::new(&mut builder); emit_builtins(&mut builder, &mut map, &id_defs); let mut args_len = HashMap::new(); - for f in ssa_functions { - let f_body = match f.body { - Some(f) => f, - None => continue, - }; - emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; - emit_function_header( - &mut builder, - &mut map, - &id_defs, - f.func_directive, - &mut args_len, - )?; - emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; - builder.end_function()?; + for d in directives { + match d { + Directive::Variable(var) => { + emit_variable(&mut builder, &mut map, &var)?; + } + Directive::Method(f) => { + let f_body = match f.body { + Some(f) => f, + None => continue, + }; + emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; + emit_function_header( + &mut builder, + &mut map, + &id_defs, + f.func_directive, + &mut args_len, + )?; + emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; + builder.end_function()?; + } + } } Ok((builder.module(), args_len)) } @@ -294,12 +472,18 @@ fn emit_function_header<'a>( let fn_id = match func_directive { ast::MethodDecl::Kernel(name, _) => { let fn_id = global.get_id(name)?; - let interface = global + let mut global_variables = global + .variables_type_check + .iter() + .filter_map(|(k, t)| t.as_ref().map(|_| *k)) + .collect::>(); + let mut interface = global .special_registers .iter() .map(|(_, id)| *id) .collect::>(); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface); + global_variables.append(&mut interface); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); fn_id } ast::MethodDecl::Func(_, name, _) => name, @@ -311,7 +495,7 @@ fn emit_function_header<'a>( func_type, )?; func_directive.visit_args(&mut |arg| { - let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into()); + let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type.clone()).into()); let inst = dr::Instruction::new( spirv::Op::FunctionParameter, Some(result_type), @@ -355,7 +539,30 @@ fn emit_memory_model(builder: &mut dr::Builder) { ); } -fn to_ssa_function<'a>( +fn translate_directive<'input>( + id_defs: &mut GlobalStringIdResolver<'input>, + d: ast::Directive<'input, ast::ParsedArgParams<'input>>, +) -> Result, TranslateError> { + Ok(match d { + ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?), + ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?), + }) +} + +fn translate_variable<'a>( + id_defs: &mut GlobalStringIdResolver<'a>, + var: ast::Variable, +) -> Result, TranslateError> { + let (state_space, typ) = var.v_type.to_type(); + Ok(ast::Variable { + align: var.align, + v_type: var.v_type, + name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)), + array_init: var.array_init, + }) +} + +fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, f: ast::ParsedFunction<'a>, ) -> Result, TranslateError> { @@ -368,9 +575,13 @@ fn expand_kernel_params<'a, 'b>( args: impl Iterator>, ) -> Vec> { args.map(|a| ast::KernelArgument { - name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))), - v_type: a.v_type, + name: fn_resolver.add_def( + a.name, + Some((StateSpace::Param, ast::Type::from(a.v_type.clone()))), + ), + v_type: a.v_type.clone(), align: a.align, + array_init: Vec::new(), }) .collect() } @@ -385,9 +596,10 @@ fn expand_fn_params<'a, 'b>( ast::FnArgumentType::Param(_) => StateSpace::Param, }; ast::FnArgument { - name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))), - v_type: a.v_type, + name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))), + v_type: a.v_type.clone(), align: a.align, + array_init: Vec::new(), } }) .collect() @@ -628,7 +840,7 @@ fn to_resolved_fn_args( params .into_iter() .zip(params_decl.iter()) - .map(|(id, typ)| (id, *typ)) + .map(|(id, typ)| (id, typ.clone())) .collect::>() } @@ -719,12 +931,13 @@ fn insert_mem_ssa_statements<'a, 'b>( let out_param = match &mut f_args { ast::MethodDecl::Kernel(_, in_params) => { for p in in_params.iter_mut() { - let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(typ); + let typ = ast::Type::from(p.v_type.clone()); + let new_id = id_def.new_id(typ.clone()); result.push(Statement::Variable(ast::Variable { align: p.align, - v_type: ast::VariableType::Param(p.v_type), + v_type: ast::VariableType::Param(p.v_type.clone()), name: p.name, + array_init: p.array_init.clone(), })); result.push(Statement::StoreVar( ast::Arg2St { @@ -739,20 +952,21 @@ fn insert_mem_ssa_statements<'a, 'b>( } ast::MethodDecl::Func(out_params, _, in_params) => { for p in in_params.iter_mut() { - let typ = ast::Type::from(p.v_type); - let new_id = id_def.new_id(typ); - let var_typ = ast::VariableType::from(p.v_type); + let typ = ast::Type::from(p.v_type.clone()); + let new_id = id_def.new_id(typ.clone()); + let var_typ = ast::VariableType::from(p.v_type.clone()); result.push(Statement::Variable(ast::Variable { align: p.align, v_type: var_typ, name: p.name, + array_init: p.array_init.clone(), })); result.push(Statement::StoreVar( ast::Arg2St { src1: p.name, src2: new_id, }, - typ, + typ.clone(), )); p.name = new_id; } @@ -760,8 +974,9 @@ fn insert_mem_ssa_statements<'a, 'b>( [p] => { result.push(Statement::Variable(ast::Variable { align: p.align, - v_type: ast::VariableType::from(p.v_type), + v_type: ast::VariableType::from(p.v_type.clone()), name: p.name, + array_init: p.array_init.clone(), })); Some(p.name) } @@ -779,13 +994,13 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => { if let Some(out_param) = out_param { let typ = id_def.get_typed(out_param)?; - let new_id = id_def.new_id(typ); + let new_id = id_def.new_id(typ.clone()); result.push(Statement::LoadVar( ast::Arg2 { dst: new_id, src: out_param, }, - typ, + typ.clone(), )); result.push(Statement::RetValue(d, new_id)); } else { @@ -824,7 +1039,7 @@ trait VisitVariable: Sized { 'a, F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -835,7 +1050,7 @@ trait VisitVariableExpanded { fn visit_variable_extended< F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -861,7 +1076,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( | (t, ArgumentSemantics::DefaultRelaxed) | (t, ArgumentSemantics::PhysicalPointer) => t, }; - let generated_id = id_def.new_id(id_type); + let generated_id = id_def.new_id(id_type.clone()); if !desc.is_dst { result.push(Statement::LoadVar( Arg2 { @@ -909,10 +1124,12 @@ fn expand_arguments<'a, 'b>( align, v_type, name, + array_init, }) => result.push(Statement::Variable(ast::Variable { align, v_type, name, + array_init, })), Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), @@ -969,7 +1186,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg( &mut self, desc: ArgumentDescriptor, - _: Option, + _: Option<&ast::Type>, ) -> Result { Ok(desc.op) } @@ -977,8 +1194,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg_offset( &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, - mut typ: ast::Type, + typ: &ast::Type, ) -> Result { + let mut typ = typ.clone(); let (reg, offset) = desc.op; match desc.sema { ArgumentSemantics::Default @@ -997,7 +1215,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ScalarKind::Float2 => return Err(TranslateError::MismatchedType), ScalarKind::Pred => return Err(TranslateError::MismatchedType), }; - (scalar_t.width(), kind) + (scalar_t.size_of(), kind) } _ => return Err(TranslateError::MismatchedType), }; @@ -1009,7 +1227,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else { ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) }; - let id_constant_stmt = self.id_def.new_id(typ); + let id_constant_stmt = self.id_def.new_id(typ.clone()); let result_id = self.id_def.new_id(typ); // TODO: check for edge cases around min value/max value/wrapping if offset < 0 && kind != ScalarKind::Signed { @@ -1060,10 +1278,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn immediate( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { - scalar + *scalar } else { todo!() }; @@ -1098,14 +1316,14 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn vector( &mut self, desc: ArgumentDescriptor<&Vec>, - typ: ast::Type, + typ: &ast::Type, ) -> Result { let (scalar_type, vec_len) = typ.get_vector()?; if !desc.is_dst { - let mut new_id = self.id_def.new_id(typ); - self.func.push(Statement::Undef(typ, new_id)); + let mut new_id = self.id_def.new_id(typ.clone()); + self.func.push(Statement::Undef(typ.clone(), new_id)); for (idx, id) in desc.op.iter().enumerate() { - let newer_id = self.id_def.new_id(typ); + let newer_id = self.id_def.new_id(typ.clone()); self.func.push(Statement::Instruction(ast::Instruction::Mov( ast::MovDetails { typ: ast::Type::Scalar(scalar_type), @@ -1124,7 +1342,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } Ok(new_id) } else { - let new_id = self.id_def.new_id(typ); + let new_id = self.id_def.new_id(typ.clone()); for (idx, id) in desc.op.iter().enumerate() { Self::insert_composite_read( &mut self.post_stmts, @@ -1144,7 +1362,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn id( &mut self, desc: ArgumentDescriptor, - t: Option, + t: Option<&ast::Type>, ) -> Result { self.reg(desc, t) } @@ -1152,7 +1370,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn operand( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result { match desc.op { ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), @@ -1166,7 +1384,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn src_call_operand( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result { match desc.op { ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), @@ -1185,7 +1403,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn id_or_vector( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result { match desc.op { ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), @@ -1196,7 +1414,7 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn operand_or_vector( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result { match desc.op { ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), @@ -1236,8 +1454,8 @@ fn insert_implicit_conversions( None, )?, Statement::Instruction(inst) => { - let mut default_conversion_fn = should_bitcast_wrapper - as fn(_, _, _) -> Result, TranslateError>; + let mut default_conversion_fn = + should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _; let mut state_space = None; if let ast::Instruction::Ld(d, _) = &inst { state_space = Some(d.state_space); @@ -1281,9 +1499,9 @@ fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, stmt: impl VisitVariableExpanded, - default_conversion_fn: fn( - ast::Type, - ast::Type, + default_conversion_fn: for<'a> fn( + &'a ast::Type, + &'a ast::Type, Option, ) -> Result, TranslateError>, state_space: Option, @@ -1315,16 +1533,16 @@ fn insert_implicit_conversions_impl( conversion_fn = force_bitcast_ptr_to_bit; } }; - match conversion_fn(operand_type, instr_type, state_space)? { + match conversion_fn(&operand_type, instr_type, state_space)? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv } else { &mut *func }; - let mut from = instr_type; + let mut from = instr_type.clone(); let mut to = operand_type; - let mut src = id_def.new_id(instr_type); + let mut src = id_def.new_id(instr_type.clone()); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { @@ -1358,17 +1576,17 @@ fn get_function_type( builder, out_params .iter() - .map(|p| SpirvType::from(ast::Type::from(p.v_type))), + .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), in_params .iter() - .map(|p| SpirvType::from(ast::Type::from(p.v_type))), + .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), ), ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( builder, iter::empty(), params .iter() - .map(|p| SpirvType::from(ast::Type::from(p.v_type))), + .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), ), } } @@ -1398,7 +1616,7 @@ fn emit_function_body_ops( Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { [(id, typ)] => ( - map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))), + map.get_or_add(builder, SpirvType::from(ast::Type::from(typ.clone()))), Some(*id), ), [] => (map.void(), None), @@ -1411,28 +1629,8 @@ fn emit_function_body_ops( .collect::>(); builder.function_call(result_type, result_id, call.func, arg_list)?; } - Statement::Variable(ast::Variable { - align, - v_type, - name, - }) => { - let st_class = match v_type { - ast::VariableType::Reg(_) - | ast::VariableType::Param(_) - | ast::VariableType::Local(_) => spirv::StorageClass::Function, - }; - let type_id = map.get_or_add( - builder, - SpirvType::new_pointer(ast::Type::from(*v_type), st_class), - ); - builder.variable(type_id, Some(*name), st_class, None); - if let Some(align) = align { - builder.decorate( - *name, - spirv::Decoration::Alignment, - &[dr::Operand::LiteralInt32(*align)], - ); - } + Statement::Variable(var) => { + emit_variable(builder, map, var)?; } Statement::Constant(cnst) => { let typ_id = map.get_or_add_scalar(builder, cnst.typ); @@ -1479,13 +1677,14 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); + let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone())); match data.state_space { ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param | ast::LdStateSpace::Local => { - let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); + let result_type = + map.get_or_add(builder, SpirvType::from(data.typ.clone())); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } _ => todo!(), @@ -1498,7 +1697,8 @@ fn emit_function_body_ops( if data.state_space == ast::StStateSpace::Param || data.state_space == ast::StStateSpace::Local { - let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); + let result_type = + map.get_or_add(builder, SpirvType::from(data.typ.clone())); builder.copy_object(result_type, Some(arg.src1), arg.src2)?; } else if data.state_space == ast::StStateSpace::Generic || data.state_space == ast::StStateSpace::Global @@ -1513,8 +1713,8 @@ fn emit_function_body_ops( ast::Instruction::Mov(d, arg) => match arg { ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src }) | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => { - let result_type = - map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ))); + let result_type = map + .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); builder.copy_object(result_type, Some(*dst), *src)?; } ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( @@ -1645,7 +1845,7 @@ fn emit_function_body_ops( } }, Statement::LoadVar(arg, typ) => { - let type_id = map.get_or_add(builder, SpirvType::from(*typ)); + let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); builder.load(type_id, Some(arg.dst), arg.src, None, [])?; } Statement::StoreVar(arg, _) => { @@ -1665,7 +1865,7 @@ fn emit_function_body_ops( )?; } Statement::Undef(t, id) => { - let result_type = map.get_or_add(builder, SpirvType::from(*t)); + let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); builder.undef(result_type, Some(*id)); } } @@ -1673,6 +1873,41 @@ fn emit_function_body_ops( Ok(()) } +fn emit_variable( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + var: &ast::Variable, +) -> Result<(), TranslateError> { + let (should_init, st_class) = match var.v_type { + ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { + (false, spirv::StorageClass::Function) + } + ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), + }; + let type_id = map.get_or_add( + builder, + SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), + ); + let initalizer = if should_init { + Some(map.get_or_add_constant( + builder, + &ast::Type::from(var.v_type.clone()), + &*var.array_init, + )?) + } else { + None + }; + builder.variable(type_id, Some(var.name), st_class, initalizer); + if let Some(align) = var.align { + builder.decorate( + var.name, + spirv::Decoration::Alignment, + &[dr::Operand::LiteralInt32(align)], + ); + } + Ok(()) +} + fn emit_mad_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1876,7 +2111,7 @@ fn emit_cvt( dst: new_dst, from: ast::Type::Scalar(src_t), to: ast::Type::Scalar(ast::ScalarType::from_parts( - dest_t.width(), + dest_t.size_of(), src_t.kind(), )), kind: ConversionKind::Default, @@ -2041,7 +2276,7 @@ fn emit_mul_sint( ]); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; - let instr_width = instruction_type.width(); + let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); @@ -2088,7 +2323,7 @@ fn emit_mul_uint( ]); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; - let instr_width = instruction_type.width(); + let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); @@ -2193,13 +2428,13 @@ fn emit_implicit_conversion( (_, _, ConversionKind::BitToPtr(space)) => { let dst_type = map.get_or_add( builder, - SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), + SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()), ); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::from(cv.to)); + let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); if from_parts.scalar_kind != ScalarKind::Float && to_parts.scalar_kind != ScalarKind::Float { @@ -2222,7 +2457,8 @@ fn emit_implicit_conversion( scalar_kind: ScalarKind::Bit, ..to_parts }); - let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type)); + let wide_bit_type_spirv = + map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); if to_parts.scalar_kind == ScalarKind::Unsigned || to_parts.scalar_kind == ScalarKind::Bit { @@ -2237,7 +2473,7 @@ fn emit_implicit_conversion( src: wide_bit_value, dst: cv.dst, from: wide_bit_type, - to: cv.to, + to: cv.to.clone(), kind: ConversionKind::Default, }, )?; @@ -2248,7 +2484,7 @@ fn emit_implicit_conversion( (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); + let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } _ => unreachable!(), @@ -2301,23 +2537,29 @@ fn expand_map_variables<'a, 'b>( ast::VariableType::Reg(_) => StateSpace::Reg, ast::VariableType::Local(_) => StateSpace::Local, ast::VariableType::Param(_) => StateSpace::ParamReg, + ast::VariableType::Global(_) => todo!(), }; match var.count { Some(count) => { - for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) { + for new_id in + id_defs.add_defs(var.var.name, count, ss, var.var.v_type.clone().into()) + { result.push(Statement::Variable(ast::Variable { align: var.var.align, - v_type: var.var.v_type, + v_type: var.var.v_type.clone(), name: new_id, + array_init: var.var.array_init.clone(), })) } } None => { - let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into()))); + let new_id = + id_defs.add_def(var.var.name, Some((ss, var.var.v_type.clone().into()))); result.push(Statement::Variable(ast::Variable { align: var.var.align, - v_type: var.var.v_type, + v_type: var.var.v_type.clone(), name: new_id, + array_init: var.var.array_init, })); } } @@ -2367,6 +2609,7 @@ impl PtxSpecialRegister { struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, + variables_type_check: HashMap>, special_registers: HashMap, fns: HashMap, } @@ -2381,13 +2624,26 @@ impl<'a> GlobalStringIdResolver<'a> { Self { current_id: start_id, variables: HashMap::new(), + variables_type_check: HashMap::new(), special_registers: HashMap::new(), fns: HashMap::new(), } } fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { - match self.variables.entry(Cow::Borrowed(id)) { + self.get_or_add_impl(id, None) + } + + fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word { + self.get_or_add_impl(id, Some(typ)) + } + + fn get_or_add_impl( + &mut self, + id: &'a str, + typ: Option<(StateSpace, ast::Type)>, + ) -> spirv::Word { + let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { let numeric_id = self.current_id; @@ -2395,7 +2651,9 @@ impl<'a> GlobalStringIdResolver<'a> { self.current_id += 1; numeric_id } - } + }; + self.variables_type_check.insert(id, typ); + id } fn get_id(&self, id: &str) -> Result { @@ -2422,6 +2680,7 @@ impl<'a> GlobalStringIdResolver<'a> { let mut fn_resolver = FnStringIdResolver { current_id: &mut self.current_id, global_variables: &self.variables, + global_type_check: &self.variables_type_check, special_registers: &mut self.special_registers, variables: vec![HashMap::new(); 1], type_check: HashMap::new(), @@ -2436,8 +2695,8 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert( name_id, FnDecl { - ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(), - params: params_ids.iter().map(|p| p.v_type).collect(), + ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), + params: params_ids.iter().map(|p| p.v_type.clone()).collect(), }, ); ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) @@ -2475,6 +2734,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, + global_type_check: &'b HashMap>, special_registers: &'b mut HashMap, variables: Vec, spirv::Word>>, type_check: HashMap>, @@ -2484,6 +2744,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { fn finish(self) -> NumericIdResolver<'b> { NumericIdResolver { current_id: self.current_id, + global_type_check: self.global_type_check, type_check: self.type_check, special_registers: self .special_registers @@ -2551,7 +2812,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); - self.type_check.insert(numeric_id + i, Some((ss, typ))); + self.type_check + .insert(numeric_id + i, Some((ss, typ.clone()))); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) @@ -2560,6 +2822,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, + global_type_check: &'b HashMap>, type_check: HashMap>, special_registers: HashMap, } @@ -2571,11 +2834,14 @@ impl<'b> NumericIdResolver<'b> { fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> { match self.type_check.get(&id) { - Some(Some(x)) => Ok(*x), + Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(&id) { Some(x) => Ok((StateSpace::Reg, x.get_type())), - None => Err(TranslateError::UntypedSymbol), + None => match self.global_type_check.get(&id) { + Some(Some(x)) => Ok(x.clone()), + Some(None) | None => Err(TranslateError::UntypedSymbol), + }, }, } } @@ -2655,7 +2921,7 @@ impl> ResolvedCall { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(typ.into()), + Some(&typ.clone().into()), )?; Ok((new_id, typ)) }) @@ -2678,7 +2944,7 @@ impl> ResolvedCall { is_dst: false, sema: ArgumentSemantics::Default, }, - typ.into(), + &typ.clone().into(), )?; Ok((new_id, typ)) }) @@ -2697,7 +2963,7 @@ impl VisitVariable for ResolvedCall { 'a, F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -2711,7 +2977,7 @@ impl VisitVariableExpanded for ResolvedCall { fn visit_variable_extended< F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -2821,6 +3087,24 @@ pub enum StateSpace { ParamReg, } +impl From for StateSpace { + fn from(ss: ast::StateSpace) -> Self { + match ss { + ast::StateSpace::Reg => StateSpace::Reg, + ast::StateSpace::Const => StateSpace::Const, + ast::StateSpace::Global => StateSpace::Global, + ast::StateSpace::Local => StateSpace::Local, + ast::StateSpace::Shared => StateSpace::Shared, + ast::StateSpace::Param => StateSpace::Param, + } + } +} + +enum Directive<'input> { + Variable(ast::Variable), + Method(Function<'input>), +} + struct Function<'input> { pub func_directive: ast::MethodDecl<'input, spirv::Word>, pub globals: Vec, @@ -2831,27 +3115,27 @@ pub trait ArgumentMapVisitor { fn id( &mut self, desc: ArgumentDescriptor, - typ: Option, + typ: Option<&ast::Type>, ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result; fn id_or_vector( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result; fn operand_or_vector( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result; fn src_call_operand( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result; fn src_member_operand( &mut self, @@ -2864,13 +3148,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option, + t: Option<&ast::Type>, ) -> Result { self(desc, t) } @@ -2878,7 +3162,7 @@ where fn operand( &mut self, desc: ArgumentDescriptor, - t: ast::Type, + t: &ast::Type, ) -> Result { self(desc, Some(t)) } @@ -2886,7 +3170,7 @@ where fn id_or_vector( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result { self(desc, Some(typ)) } @@ -2894,7 +3178,7 @@ where fn operand_or_vector( &mut self, desc: ArgumentDescriptor, - typ: ast::Type, + typ: &ast::Type, ) -> Result { self(desc, Some(typ)) } @@ -2902,7 +3186,7 @@ where fn src_call_operand( &mut self, desc: ArgumentDescriptor, - t: ast::Type, + t: &ast::Type, ) -> Result { self(desc, Some(t)) } @@ -2912,7 +3196,7 @@ where desc: ArgumentDescriptor, (scalar_type, _): (ast::ScalarType, u8), ) -> Result { - self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type))) + self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type))) } } @@ -2923,7 +3207,7 @@ where fn id( &mut self, desc: ArgumentDescriptor<&str>, - _: Option, + _: Option<&ast::Type>, ) -> Result { self(desc.op) } @@ -2931,7 +3215,7 @@ where fn operand( &mut self, desc: ArgumentDescriptor>, - _: ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), @@ -2943,7 +3227,7 @@ where fn id_or_vector( &mut self, desc: ArgumentDescriptor>, - _: ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)), @@ -2956,7 +3240,7 @@ where fn operand_or_vector( &mut self, desc: ArgumentDescriptor>, - _: ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)), @@ -2973,7 +3257,7 @@ where fn src_call_operand( &mut self, desc: ArgumentDescriptor>, - _: ast::Type, + _: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)), @@ -3027,39 +3311,39 @@ impl ast::Instruction { ) -> Result, TranslateError> { Ok(match self { ast::Instruction::Abs(d, arg) => { - ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?) + ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), ast::Instruction::Ld(d, a) => { - let inst_type = d.typ; let is_param = d.state_space == ast::LdStateSpace::Param || d.state_space == ast::LdStateSpace::Local; - ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?) + let new_args = a.map(visitor, &d.typ, is_param)?; + ast::Instruction::Ld(d, new_args) } ast::Instruction::Mov(d, a) => { - let mapped = a.map(visitor, d)?; + let mapped = a.map(visitor, &d)?; ast::Instruction::Mov(d, mapped) } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); let is_wide = d.is_wide(); - ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?) + ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?) + ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; - ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))?) + ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } ast::Instruction::SetpBool(d, a) => { let inst_type = d.typ; - ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?) + ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } ast::Instruction::Not(t, a) => { - ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?) + ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?) } ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { @@ -3080,46 +3364,46 @@ impl ast::Instruction { ast::Type::Scalar(desc.src.into()), ), }; - ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)?) + ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { - ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?) + ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) } ast::Instruction::Shr(t, a) => { - ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?) + ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) } ast::Instruction::St(d, a) => { - let inst_type = d.typ; let is_param = d.state_space == ast::StStateSpace::Param || d.state_space == ast::StStateSpace::Local; - ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?) + let new_args = a.map(visitor, &d.typ, is_param)?; + ast::Instruction::St(d, new_args) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); - ast::Instruction::Cvta(d, a.map(visitor, false, inst_type)?) + ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?) } ast::Instruction::Mad(d, a) => { let inst_type = d.get_type(); let is_wide = d.is_wide(); - ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?) + ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?) } ast::Instruction::Or(t, a) => ast::Instruction::Or( t, - a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?, + a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, ), ast::Instruction::Sub(d, a) => { let typ = d.get_type(); - ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?) + ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?) } ast::Instruction::Min(d, a) => { let typ = d.get_type(); - ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?) + ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?) } ast::Instruction::Max(d, a) => { let typ = d.get_type(); - ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?) + ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?) } }) } @@ -3130,7 +3414,7 @@ impl VisitVariable for ast::Instruction { 'a, F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -3144,13 +3428,13 @@ impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, - t: Option, + t: Option<&ast::Type>, ) -> Result { self(desc, t) } @@ -3158,7 +3442,7 @@ where fn operand( &mut self, desc: ArgumentDescriptor>, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)), @@ -3173,7 +3457,7 @@ where fn src_call_operand( &mut self, desc: ArgumentDescriptor>, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)), @@ -3184,7 +3468,7 @@ where fn id_or_vector( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)), @@ -3199,7 +3483,7 @@ where fn operand_or_vector( &mut self, desc: ArgumentDescriptor>, - typ: ast::Type, + typ: &ast::Type, ) -> Result, TranslateError> { match desc.op { ast::OperandOrVector::Reg(id) => { @@ -3226,7 +3510,7 @@ where Ok(( self( desc.new_op(desc.op.0), - Some(ast::Type::Vector(scalar_type.into(), vector_len)), + Some(&ast::Type::Vector(scalar_type.into(), vector_len)), )?, desc.op.1, )) @@ -3238,7 +3522,7 @@ impl ast::Type { match self { ast::Type::Scalar(scalar) => { let kind = scalar.kind(); - let width = scalar.width(); + let width = scalar.size_of(); if (kind != ScalarKind::Signed && kind != ScalarKind::Unsigned && kind != ScalarKind::Bit) @@ -3255,25 +3539,25 @@ impl ast::Type { } } - fn to_parts(self) -> TypeParts { + fn to_parts(&self) -> TypeParts { match self { ast::Type::Scalar(scalar) => TypeParts { kind: TypeKind::Scalar, scalar_kind: scalar.kind(), - width: scalar.width(), - components: 0, + width: scalar.size_of(), + components: Vec::new(), }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), - width: scalar.width(), - components: components as u32, + width: scalar.size_of(), + components: vec![*components as u32], }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), - width: scalar.width(), - components: components, + width: scalar.size_of(), + components: components.clone(), }, } } @@ -3285,7 +3569,7 @@ impl ast::Type { } TypeKind::Vector => ast::Type::Vector( ast::ScalarType::from_parts(t.width, t.scalar_kind), - t.components as u8, + t.components[0] as u8, ), TypeKind::Array => ast::Type::Array( ast::ScalarType::from_parts(t.width, t.scalar_kind), @@ -3295,12 +3579,12 @@ impl ast::Type { } } -#[derive(Eq, PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Clone)] struct TypeParts { kind: TypeKind, scalar_kind: ScalarKind, width: u8, - components: u32, + components: Vec, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -3342,7 +3626,7 @@ impl VisitVariableExpanded for ast::Instruction { fn visit_variable_extended< F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -3368,7 +3652,7 @@ impl VisitVariableExpanded for CompositeRead { fn visit_variable_extended< F: FnMut( ArgumentDescriptor, - Option, + Option<&ast::Type>, ) -> Result, >( self, @@ -3384,7 +3668,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: true, sema: dst_sema, }, - Some(ast::Type::Scalar(self.typ)), + Some(&ast::Type::Scalar(self.typ)), )?, src_composite: f( ArgumentDescriptor { @@ -3392,7 +3676,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Vector(self.typ, self.src_len as u8)), + Some(&ast::Type::Vector(self.typ, self.src_len as u8)), )?, ..self })) @@ -3411,7 +3695,7 @@ struct BrachCondition { if_false: spirv::Word, } -#[derive(Copy, Clone)] +#[derive(Clone)] struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, @@ -3471,11 +3755,12 @@ impl<'a> ast::Instruction> { } impl ast::VariableParamType { - fn width(self) -> usize { + fn width(&self) -> usize { match self { - ast::VariableParamType::Scalar(t) => ast::ScalarType::from(t).width() as usize, + ast::VariableParamType::Scalar(t) => ast::ScalarType::from(*t).size_of() as usize, ast::VariableParamType::Array(t, len) => { - (ast::ScalarType::from(t).width() as usize) * (len as usize) + (ast::ScalarType::from(*t).size_of() as usize) + * (len.iter().fold(1, |x, y| x * (*y)) as usize) } } } @@ -3489,7 +3774,7 @@ impl ast::Arg1 { fn map>( self, visitor: &mut V, - t: Option, + t: Option<&ast::Type>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { @@ -3515,7 +3800,7 @@ impl ast::Arg2 { self, visitor: &mut V, src_is_addr: bool, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { let new_dst = visitor.id( ArgumentDescriptor { @@ -3546,8 +3831,8 @@ impl ast::Arg2 { fn map_cvt>( self, visitor: &mut V, - dst_t: ast::Type, - src_t: ast::Type, + dst_t: &ast::Type, + src_t: &ast::Type, ) -> Result, TranslateError> { let dst = visitor.id( ArgumentDescriptor { @@ -3582,7 +3867,7 @@ impl ast::Arg2Ld { fn map>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, is_param: bool, ) -> Result, TranslateError> { let dst = visitor.id_or_vector( @@ -3591,7 +3876,7 @@ impl ast::Arg2Ld { is_dst: true, sema: ArgumentSemantics::DefaultRelaxed, }, - t.into(), + &ast::Type::from(t.clone()), )?; let src = visitor.operand( ArgumentDescriptor { @@ -3622,7 +3907,7 @@ impl ast::Arg2St { fn map>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, is_param: bool, ) -> Result, TranslateError> { let src1 = visitor.operand( @@ -3653,7 +3938,7 @@ impl ast::Arg2Mov { fn map>( self, visitor: &mut V, - details: ast::MovDetails, + details: &ast::MovDetails, ) -> Result, TranslateError> { Ok(match self { ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?), @@ -3675,7 +3960,7 @@ impl ast::Arg2MovNormal

{ fn map>( self, visitor: &mut V, - details: ast::MovDetails, + details: &ast::MovDetails, ) -> Result, TranslateError> { let dst = visitor.id_or_vector( ArgumentDescriptor { @@ -3683,7 +3968,7 @@ impl ast::Arg2MovNormal

{ is_dst: true, sema: ArgumentSemantics::Default, }, - details.typ.into(), + &details.typ.clone().into(), )?; let src = visitor.operand_or_vector( ArgumentDescriptor { @@ -3695,7 +3980,7 @@ impl ast::Arg2MovNormal

{ ArgumentSemantics::Default }, }, - details.typ.into(), + &details.typ.clone().into(), )?; Ok(ast::Arg2MovNormal { dst, src }) } @@ -3733,7 +4018,7 @@ impl ast::Arg2MovMember { fn map>( self, visitor: &mut V, - details: ast::MovDetails, + details: &ast::MovDetails, ) -> Result, TranslateError> { match self { ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { @@ -3744,7 +4029,7 @@ impl ast::Arg2MovMember { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Vector(scalar_type, details.dst_width)), + Some(&ast::Type::Vector(scalar_type, details.dst_width)), )?; let src1 = visitor.id( ArgumentDescriptor { @@ -3752,7 +4037,7 @@ impl ast::Arg2MovMember { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Vector(scalar_type, details.dst_width)), + Some(&ast::Type::Vector(scalar_type, details.dst_width)), )?; let src2 = visitor.id( ArgumentDescriptor { @@ -3766,7 +4051,7 @@ impl ast::Arg2MovMember { ArgumentSemantics::Default }, }, - Some(details.typ.into()), + Some(&details.typ.clone().into()), )?; Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2)) } @@ -3777,7 +4062,7 @@ impl ast::Arg2MovMember { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(details.typ.into()), + Some(&details.typ.clone().into()), )?; let scalar_typ = details.typ.get_scalar()?; let src = visitor.src_member_operand( @@ -3798,7 +4083,7 @@ impl ast::Arg2MovMember { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Vector(scalar_type, details.dst_width)), + Some(&ast::Type::Vector(scalar_type, details.dst_width)), )?; let composite_src = visitor.id( ArgumentDescriptor { @@ -3806,7 +4091,7 @@ impl ast::Arg2MovMember { is_dst: false, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Vector(scalar_type, details.dst_width)), + Some(&ast::Type::Vector(scalar_type, details.dst_width)), )?; let src = visitor.src_member_operand( ArgumentDescriptor { @@ -3838,16 +4123,21 @@ impl ast::Arg3 { fn map_non_shift>( self, visitor: &mut V, - typ: ast::Type, + typ: &ast::Type, is_wide: bool, ) -> Result, TranslateError> { + let wide_type = if is_wide { + Some(typ.clone().widen()?) + } else { + None + }; let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(if is_wide { typ.widen()? } else { typ }), + Some(wide_type.as_ref().unwrap_or(typ)), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -3871,7 +4161,7 @@ impl ast::Arg3 { fn map_shift>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { let dst = visitor.id( ArgumentDescriptor { @@ -3895,7 +4185,7 @@ impl ast::Arg3 { is_dst: false, sema: ArgumentSemantics::Default, }, - ast::Type::Scalar(ast::ScalarType::U32), + &ast::Type::Scalar(ast::ScalarType::U32), )?; Ok(ast::Arg3 { dst, src1, src2 }) } @@ -3914,16 +4204,21 @@ impl ast::Arg4 { fn map>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, is_wide: bool, ) -> Result, TranslateError> { + let wide_type = if is_wide { + Some(t.clone().widen()?) + } else { + None + }; let dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(if is_wide { t.widen()? } else { t }), + Some(wide_type.as_ref().unwrap_or(t)), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -3971,7 +4266,7 @@ impl ast::Arg4Setp { fn map>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { @@ -3979,7 +4274,7 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), + Some(&ast::Type::Scalar(ast::ScalarType::Pred)), )?; let dst2 = self .dst2 @@ -3990,7 +4285,7 @@ impl ast::Arg4Setp { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), + Some(&ast::Type::Scalar(ast::ScalarType::Pred)), ) }) .transpose()?; @@ -4033,7 +4328,7 @@ impl ast::Arg5 { fn map>( self, visitor: &mut V, - t: ast::Type, + t: &ast::Type, ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { @@ -4041,7 +4336,7 @@ impl ast::Arg5 { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), + Some(&ast::Type::Scalar(ast::ScalarType::Pred)), )?; let dst2 = self .dst2 @@ -4052,7 +4347,7 @@ impl ast::Arg5 { is_dst: true, sema: ArgumentSemantics::Default, }, - Some(ast::Type::Scalar(ast::ScalarType::Pred)), + Some(&ast::Type::Scalar(ast::ScalarType::Pred)), ) }) .transpose()?; @@ -4078,7 +4373,7 @@ impl ast::Arg5 { is_dst: false, sema: ArgumentSemantics::Default, }, - ast::Type::Scalar(ast::ScalarType::Pred), + &ast::Type::Scalar(ast::ScalarType::Pred), )?; Ok(ast::Arg5 { dst1, @@ -4091,16 +4386,16 @@ impl ast::Arg5 { } impl ast::Type { - fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> { + fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> { match self { - ast::Type::Vector(t, len) => Ok((t, len)), + ast::Type::Vector(t, len) => Ok((*t, *len)), _ => Err(TranslateError::MismatchedType), } } - fn get_scalar(self) -> Result { + fn get_scalar(&self) -> Result { match self { - ast::Type::Scalar(t) => Ok(t), + ast::Type::Scalar(t) => Ok(*t), _ => Err(TranslateError::MismatchedType), } } @@ -4141,28 +4436,6 @@ enum ScalarKind { } impl ast::ScalarType { - fn width(self) -> u8 { - match self { - ast::ScalarType::U8 => 1, - ast::ScalarType::S8 => 1, - ast::ScalarType::B8 => 1, - ast::ScalarType::U16 => 2, - ast::ScalarType::S16 => 2, - ast::ScalarType::B16 => 2, - ast::ScalarType::F16 => 2, - ast::ScalarType::U32 => 4, - ast::ScalarType::S32 => 4, - ast::ScalarType::B32 => 4, - ast::ScalarType::F32 => 4, - ast::ScalarType::U64 => 8, - ast::ScalarType::S64 => 8, - ast::ScalarType::B64 => 8, - ast::ScalarType::F64 => 8, - ast::ScalarType::F16x2 => 4, - ast::ScalarType::Pred => 1, - } - } - fn kind(self) -> ScalarKind { match self { ast::ScalarType::U8 => ScalarKind::Unsigned, @@ -4283,20 +4556,6 @@ impl ast::MinMaxDetails { } } -impl ast::IntType { - fn try_new(t: ast::ScalarType) -> Option { - match t { - ast::ScalarType::U16 => Some(ast::IntType::U16), - ast::ScalarType::U32 => Some(ast::IntType::U32), - ast::ScalarType::U64 => Some(ast::IntType::U64), - ast::ScalarType::S16 => Some(ast::IntType::S16), - ast::ScalarType::S32 => Some(ast::IntType::S32), - ast::ScalarType::S64 => Some(ast::IntType::S64), - _ => None, - } - } -} - impl ast::SIntType { fn from_size(width: u8) -> Self { match width { @@ -4372,8 +4631,8 @@ impl ast::MulDetails { } fn force_bitcast( - operand: ast::Type, - instr: ast::Type, + operand: &ast::Type, + instr: &ast::Type, _: Option, ) -> Result, TranslateError> { if instr != operand { @@ -4384,8 +4643,8 @@ fn force_bitcast( } fn bitcast_physical_pointer( - operand_type: ast::Type, - _: ast::Type, + operand_type: &ast::Type, + _: &ast::Type, ss: Option, ) -> Result, TranslateError> { match operand_type { @@ -4403,17 +4662,17 @@ fn bitcast_physical_pointer( } fn force_bitcast_ptr_to_bit( - _: ast::Type, - _: ast::Type, + _: &ast::Type, + _: &ast::Type, _: Option, ) -> Result, TranslateError> { Ok(Some(ConversionKind::PtrToBit)) } -fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { match (instr, operand) { (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { - if inst.width() != operand.width() { + if inst.size_of() != operand.size_of() { return false; } match inst.kind() { @@ -4431,22 +4690,22 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { } (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { - should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) } _ => false, } } fn should_bitcast_packed( - operand: ast::Type, - instr: ast::Type, + operand: &ast::Type, + instr: &ast::Type, ss: Option, ) -> Result, TranslateError> { if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = (operand, instr) { if scalar.kind() == ScalarKind::Bit - && scalar.width() == (vec_underlying_type.width() * vec_len) + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) { return Ok(Some(ConversionKind::Default)); } @@ -4455,8 +4714,8 @@ fn should_bitcast_packed( } fn should_bitcast_wrapper( - operand: ast::Type, - instr: ast::Type, + operand: &ast::Type, + instr: &ast::Type, _: Option, ) -> Result, TranslateError> { if instr == operand { @@ -4470,8 +4729,8 @@ fn should_bitcast_wrapper( } fn should_convert_relaxed_src_wrapper( - src_type: ast::Type, - instr_type: ast::Type, + src_type: &ast::Type, + instr_type: &ast::Type, _: Option, ) -> Result, TranslateError> { if src_type == instr_type { @@ -4485,8 +4744,8 @@ fn should_convert_relaxed_src_wrapper( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands fn should_convert_relaxed_src( - src_type: ast::Type, - instr_type: ast::Type, + src_type: &ast::Type, + instr_type: &ast::Type, ) -> Option { if src_type == instr_type { return None; @@ -4494,21 +4753,24 @@ fn should_convert_relaxed_src( match (src_type, instr_type) { (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { ScalarKind::Bit => { - if instr_type.width() <= src_type.width() { + if instr_type.size_of() <= src_type.size_of() { Some(ConversionKind::Default) } else { None } } ScalarKind::Signed | ScalarKind::Unsigned => { - if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ScalarKind::Float + { Some(ConversionKind::Default) } else { None } } ScalarKind::Float => { - if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit { + if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit + { Some(ConversionKind::Default) } else { None @@ -4519,15 +4781,18 @@ fn should_convert_relaxed_src( }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { - should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) } _ => None, } } fn should_convert_relaxed_dst_wrapper( - dst_type: ast::Type, - instr_type: ast::Type, + dst_type: &ast::Type, + instr_type: &ast::Type, _: Option, ) -> Result, TranslateError> { if dst_type == instr_type { @@ -4541,8 +4806,8 @@ fn should_convert_relaxed_dst_wrapper( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands fn should_convert_relaxed_dst( - dst_type: ast::Type, - instr_type: ast::Type, + dst_type: &ast::Type, + instr_type: &ast::Type, ) -> Option { if dst_type == instr_type { return None; @@ -4550,7 +4815,7 @@ fn should_convert_relaxed_dst( match (dst_type, instr_type) { (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { ScalarKind::Bit => { - if instr_type.width() <= dst_type.width() { + if instr_type.size_of() <= dst_type.size_of() { Some(ConversionKind::Default) } else { None @@ -4558,9 +4823,9 @@ fn should_convert_relaxed_dst( } ScalarKind::Signed => { if dst_type.kind() != ScalarKind::Float { - if instr_type.width() == dst_type.width() { + if instr_type.size_of() == dst_type.size_of() { Some(ConversionKind::Default) - } else if instr_type.width() < dst_type.width() { + } else if instr_type.size_of() < dst_type.size_of() { Some(ConversionKind::SignExtend) } else { None @@ -4570,14 +4835,17 @@ fn should_convert_relaxed_dst( } } ScalarKind::Unsigned => { - if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ScalarKind::Float + { Some(ConversionKind::Default) } else { None } } ScalarKind::Float => { - if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit { + if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit + { Some(ConversionKind::Default) } else { None @@ -4588,7 +4856,10 @@ fn should_convert_relaxed_dst( }, (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { - should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) } _ => None, } @@ -4611,7 +4882,8 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> { f(&ast::FnArgument { align: arg.align, name: arg.name, - v_type: ast::FnArgumentType::Param(arg.v_type), + v_type: ast::FnArgumentType::Param(arg.v_type.clone()), + array_init: arg.array_init.clone(), }) }), } @@ -4698,14 +4970,17 @@ mod tests { .collect::>() } - fn assert_conversion_table Option>( + fn assert_conversion_table Option>( table: &'static str, f: F, ) { let conv_table = parse_conversion_table(table); for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() { for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() { - let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type)); + let conversion = f( + &ast::Type::Scalar(*op_type), + &ast::Type::Scalar(*instr_type), + ); if instr_idx == op_idx { assert_eq!(conversion, None); } else {