diff --git a/Cargo.toml b/Cargo.toml index 42be95a..1666bee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,8 +5,8 @@ members = [ "level_zero", "spirv_tools-sys", "notcuda", - "notcuda_inject", - "notcuda_redirect", + #"notcuda_inject", + #"notcuda_redirect", "ptx", ] diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1e90eba..1cbe721 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -28,8 +28,11 @@ quick_error! { } } -macro_rules! sub_scalar_type { +macro_rules! sub_enum { ($name:ident { $($variant:ident),+ $(,)? }) => { + sub_enum!{ $name : ScalarType { $($variant),+ } } + }; + ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { #[derive(PartialEq, Eq, Clone, Copy)] pub enum $name { $( @@ -37,23 +40,23 @@ macro_rules! sub_scalar_type { )+ } - impl From<$name> for ScalarType { - fn from(t: $name) -> ScalarType { + impl From<$name> for $base_type { + fn from(t: $name) -> $base_type { match t { $( - $name::$variant => ScalarType::$variant, + $name::$variant => $base_type::$variant, )+ } } } - impl std::convert::TryFrom for $name { + impl std::convert::TryFrom<$base_type> for $name { type Error = (); - fn try_from(t: ScalarType) -> Result { + fn try_from(t: $base_type) -> Result { match t { $( - ScalarType::$variant => Ok($name::$variant), + $base_type::$variant => Ok($name::$variant), )+ _ => Err(()), } @@ -64,6 +67,13 @@ macro_rules! sub_scalar_type { macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { + sub_type! { $type_name : Type { + $( + $variant ($($field_type),+), + )+ + }} + }; + ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { #[derive(PartialEq, Eq, Clone)] pub enum $type_name { $( @@ -71,26 +81,26 @@ macro_rules! sub_type { )+ } - impl From<$type_name> for Type { + impl From<$type_name> for $base_type { #[allow(non_snake_case)] - fn from(t: $type_name) -> Type { + fn from(t: $type_name) -> $base_type { match t { $( - $type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+), + $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), )+ } } } - impl std::convert::TryFrom for $type_name { + impl std::convert::TryFrom<$base_type> for $type_name { type Error = (); #[allow(non_snake_case)] #[allow(unreachable_patterns)] - fn try_from(t: Type) -> Result { + fn try_from(t: $base_type) -> Result { match t { $( - Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), + $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), )+ _ => Err(()), } @@ -99,10 +109,12 @@ macro_rules! sub_type { }; } +// Pointer is used when doing SLM converison to SPIRV sub_type! { VariableRegType { Scalar(ScalarType), Vector(SizedScalarType, u8), + Pointer(SizedScalarType, PointerStateSpace) } } @@ -146,13 +158,13 @@ sub_type! { // .param .b32 foobar[] sub_type! { VariableParamType { - Scalar(ParamScalarType), + Scalar(LdStScalarType), Array(SizedScalarType, VecU32), Pointer(SizedScalarType, PointerStateSpace), } } -sub_scalar_type!(SizedScalarType { +sub_enum!(SizedScalarType { B8, B16, B32, @@ -171,7 +183,7 @@ sub_scalar_type!(SizedScalarType { F64, }); -sub_scalar_type!(ParamScalarType { +sub_enum!(LdStScalarType { B8, B16, B32, @@ -232,7 +244,11 @@ pub enum Directive<'a, P: ArgParams> { pub enum MethodDecl<'a, ID> { Func(Vec>, ID, Vec>), - Kernel(&'a str, Vec>), + Kernel { + name: &'a str, + in_args: Vec>, + uses_shared_mem: bool, + }, } pub type FnArgument = Variable; @@ -262,25 +278,52 @@ impl From for Type { match t { FnArgumentType::Reg(x) => x.into(), FnArgumentType::Param(x) => x.into(), - FnArgumentType::Shared => Type::Scalar(ScalarType::B64), + FnArgumentType::Shared => { + Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) + } } } } -#[derive(PartialEq, Eq, Clone, Copy)] -pub enum PointerStateSpace { - Global, - Const, - Shared, - Param, -} +sub_enum!( + PointerStateSpace : LdStateSpace { + Global, + Const, + Shared, + Param, + } +); #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec), - Pointer(ScalarType, PointerStateSpace), + Pointer(PointerType, LdStateSpace), +} + +sub_type! { + PointerType { + Scalar(ScalarType), + Vector(ScalarType, u8), + } +} + +impl From for PointerType { + fn from(t: SizedScalarType) -> Self { + PointerType::Scalar(t.into()) + } +} + +impl TryFrom for SizedScalarType { + type Error = (); + + fn try_from(value: PointerType) -> Result { + match value { + PointerType::Scalar(t) => Ok(t.try_into()?), + PointerType::Vector(_, _) => Err(()), + } + } } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -304,7 +347,7 @@ pub enum ScalarType { Pred, } -sub_scalar_type!(IntType { +sub_enum!(IntType { U8, U16, U32, @@ -315,9 +358,9 @@ sub_scalar_type!(IntType { S64 }); -sub_scalar_type!(UIntType { U8, U16, U32, U64 }); +sub_enum!(UIntType { U8, U16, U32, U64 }); -sub_scalar_type!(SIntType { S8, S16, S32, S64 }); +sub_enum!(SIntType { S8, S16, S32, S64 }); impl IntType { pub fn is_signed(self) -> bool { @@ -341,7 +384,7 @@ impl IntType { } } -sub_scalar_type!(FloatType { +sub_enum!(FloatType { F16, F16x2, F32, @@ -615,7 +658,23 @@ pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub typ: Type, + pub typ: LdStType, +} + +sub_type! { + LdStType { + Scalar(LdStScalarType), + Vector(LdStScalarType, u8), + } +} + +impl From for PointerType { + fn from(t: LdStType) -> Self { + match t { + LdStType::Scalar(t) => PointerType::Scalar(t.into()), + LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), + } + } } #[derive(Copy, Clone, PartialEq, Eq)] @@ -860,7 +919,7 @@ pub enum ShlType { B64, } -sub_scalar_type!(ShrType { +sub_enum!(ShrType { B16, B32, B64, @@ -876,7 +935,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub typ: Type, + pub typ: LdStType, } #[derive(PartialEq, Eq, Copy, Clone)] @@ -900,7 +959,7 @@ pub struct RetData { pub uniform: bool, } -sub_scalar_type!(OrType { +sub_enum!(OrType { Pred, B16, B32, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index c066ae4..c29d16b 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -237,7 +237,8 @@ LinkingDirectives: ast::LinkingDirective = { } MethodDecl: ast::MethodDecl<'input, &'input str> = { - ".entry" => ast::MethodDecl::Kernel(name, params), + ".entry" => + ast::MethodDecl::Kernel{ name, in_args, uses_shared_mem: false }, ".func" => { ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) } @@ -294,10 +295,6 @@ ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, ".pred" => ast::ScalarType::Pred, - LdStScalarType -}; - -LdStScalarType: ast::ScalarType = { ".b8" => ast::ScalarType::B8, ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, @@ -442,7 +439,7 @@ ModuleVariable: ast::Variable = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space ParamVariable: (Option, Vec, ast::VariableParamType, &'input str) = { - ".param" > => { + ".param" > => { let (align, t, name) = var; let v_type = ast::VariableParamType::Scalar(t); (align, Vec::new(), v_type, name) @@ -506,22 +503,22 @@ SizedScalarType: ast::SizedScalarType = { } #[inline] -ParamScalarType: ast::ParamScalarType = { - ".b8" => ast::ParamScalarType::B8, - ".b16" => ast::ParamScalarType::B16, - ".b32" => ast::ParamScalarType::B32, - ".b64" => ast::ParamScalarType::B64, - ".u8" => ast::ParamScalarType::U8, - ".u16" => ast::ParamScalarType::U16, - ".u32" => ast::ParamScalarType::U32, - ".u64" => ast::ParamScalarType::U64, - ".s8" => ast::ParamScalarType::S8, - ".s16" => ast::ParamScalarType::S16, - ".s32" => ast::ParamScalarType::S32, - ".s64" => ast::ParamScalarType::S64, - ".f16" => ast::ParamScalarType::F16, - ".f32" => ast::ParamScalarType::F32, - ".f64" => ast::ParamScalarType::F64, +LdStScalarType: ast::LdStScalarType = { + ".b8" => ast::LdStScalarType::B8, + ".b16" => ast::LdStScalarType::B16, + ".b32" => ast::LdStScalarType::B32, + ".b64" => ast::LdStScalarType::B64, + ".u8" => ast::LdStScalarType::U8, + ".u16" => ast::LdStScalarType::U16, + ".u32" => ast::LdStScalarType::U32, + ".u64" => ast::LdStScalarType::U64, + ".s8" => ast::LdStScalarType::S8, + ".s16" => ast::LdStScalarType::S16, + ".s32" => ast::LdStScalarType::S32, + ".s64" => ast::LdStScalarType::S64, + ".f16" => ast::LdStScalarType::F16, + ".f32" => ast::LdStScalarType::F32, + ".f64" => ast::LdStScalarType::F64, } Instruction: ast::Instruction> = { @@ -572,9 +569,9 @@ OperandOrVector: ast::OperandOrVector<&'input str> = { => ast::OperandOrVector::Vec(dst) } -LdStType: ast::Type = { - => ast::Type::Vector(t, v), - => ast::Type::Scalar(t), +LdStType: ast::LdStType = { + => ast::LdStType::Vector(t, v), + => ast::LdStType::Scalar(t), } LdStQualifier: ast::LdStQualifier = { diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index 84e7eac..b184980 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -2,52 +2,67 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 - %29 = OpExtInstImport "OpenCL.std" + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %32 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "cvta" + OpEntryPoint Kernel %2 "extern_shared" %1 %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint + %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar + %40 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar +%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %32 - %7 = OpFunctionParameter %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong +%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %2 = OpFunction %void None %40 %8 = OpFunctionParameter %ulong - %27 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function + %9 = OpFunctionParameter %ulong + %28 = OpFunctionParameter %_ptr_Workgroup_uchar + %41 = OpLabel + %29 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %29 %28 + OpBranch %26 + %26 = OpLabel OpStore %3 %8 - %10 = OpLoad %ulong %2 - %9 = OpCopyObject %ulong %10 OpStore %4 %9 - %12 = OpLoad %ulong %3 - %11 = OpCopyObject %ulong %12 - OpStore %5 %11 - %14 = OpLoad %ulong %4 - %22 = OpCopyObject %ulong %14 - %21 = OpCopyObject %ulong %22 - %13 = OpCopyObject %ulong %21 - OpStore %4 %13 - %16 = OpLoad %ulong %5 - %24 = OpCopyObject %ulong %16 - %23 = OpCopyObject %ulong %24 - %15 = OpCopyObject %ulong %23 - OpStore %5 %15 - %18 = OpLoad %ulong %4 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 - %17 = OpLoad %float %25 - OpStore %6 %17 - %19 = OpLoad %ulong %5 - %20 = OpLoad %float %6 - %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 - OpStore %26 %20 + %11 = OpLoad %ulong %3 + %10 = OpCopyObject %ulong %11 + OpStore %5 %10 + %13 = OpLoad %ulong %4 + %12 = OpCopyObject %ulong %13 + OpStore %6 %12 + %15 = OpLoad %ulong %5 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %15 + %14 = OpLoad %ulong %22 + OpStore %7 %14 + %30 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29 + %16 = OpLoad %_ptr_Workgroup_uint %30 + %17 = OpLoad %ulong %7 + %23 = OpBitcast %_ptr_Workgroup_ulong %16 + OpStore %23 %17 + %31 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29 + %19 = OpLoad %_ptr_Workgroup_uint %31 + %24 = OpBitcast %_ptr_Workgroup_ulong %19 + %18 = OpLoad %ulong %24 + OpStore %7 %18 + %20 = OpLoad %ulong %6 + %21 = OpLoad %ulong %7 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20 + OpStore %25 %21 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 14c3bc9..4c5f9b3 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -107,27 +107,33 @@ fn test_ptx_assert<'a, T: From + ze::SafeRepr + Debug + Copy + PartialEq>( let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; assert!(errors.len() == 0); - let (spirv, _) = translate::to_spirv(ast)?; + let notcuda_module = translate::to_spirv_module(ast)?; let name = CString::new(name)?; - let result = - run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?; + let result = run_spirv(name.as_c_str(), notcuda_module, input, output) + .map_err(|err| DisplayError { err })?; assert_eq!(output, result.as_slice()); Ok(()) } fn run_spirv + ze::SafeRepr + Copy + Debug>( name: &CStr, - spirv: &[u32], + module: translate::Module, input: &[T], output: &mut [T], ) -> ze::Result> { ze::init()?; + let spirv = module.spirv.assemble(); let byte_il = unsafe { slice::from_raw_parts::( spirv.as_ptr() as *const _, spirv.len() * mem::size_of::(), ) }; + let use_shared_mem = module + .kernel_info + .get(name.to_str().unwrap()) + .unwrap() + .uses_shared_mem; let mut result = vec![0u8.into(); output.len()]; { let mut drivers = ze::Driver::get()?; @@ -140,7 +146,7 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( let module = match module { Ok(m) => m, Err(err) => { - let raw_err_string = log.get_cstring()?; + let raw_err_string = log.get_cstring()?; let err_string = raw_err_string.to_string_lossy(); panic!("{:?}\n{}", err, err_string); } @@ -164,6 +170,9 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( kernel.set_group_size(1, 1, 1)?; kernel.set_arg_buffer(0, inp_b_ptr_mut)?; kernel.set_arg_buffer(1, out_b_ptr_mut)?; + if use_shared_mem { + unsafe { kernel.set_arg_raw(2, 128, ptr::null())? }; + } cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?; cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?; queue.execute(cmd_list)?; @@ -179,7 +188,7 @@ fn test_spvtxt_assert<'a>( let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; assert!(errors.len() == 0); - let (ptx_mod, _) = translate::to_spirv_module(ast)?; + let spirv_module = translate::to_spirv_module(ast)?; let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut()); @@ -211,9 +220,9 @@ fn test_spvtxt_assert<'a>( rspirv::binary::parse_words(&parsed_spirv, &mut loader)?; let spvtxt_mod = loader.module(); unsafe { spirv_tools::spvBinaryDestroy(spv_binary) }; - if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) { + if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) { // We could simply use ptx_mod.disassemble, but SPIRV-Tools text formattinmg is so much nicer - let spv_from_ptx_binary = ptx_mod.assemble(); + let spv_from_ptx_binary = spirv_module.spirv.assemble(); let mut spv_text: spirv_tools::spv_text = ptr::null_mut(); let result = unsafe { spirv_tools::spvBinaryToText( @@ -234,7 +243,7 @@ fn test_spvtxt_assert<'a>( // TODO: stop leaking kernel text Cow::Borrowed(spv_from_ptx_text) } else { - Cow::Owned(ptx_mod.disassemble()) + Cow::Owned(spirv_module.spirv.disassemble()) }; if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") { let mut path = PathBuf::from(dump_path); diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 09dd0bb..ab7187f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,8 +1,11 @@ use crate::ast; use half::f16; use rspirv::{binary::Disassemble, dr}; -use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, iter, mem}; +use std::{ + collections::{hash_map, HashMap, HashSet}, + convert::TryFrom, +}; use rspirv::binary::Assemble; @@ -12,7 +15,7 @@ quick_error! { UnknownSymbol {} UntypedSymbol {} MismatchedType {} - Spirv (err: rspirv::dr::Error) { + Spirv(err: rspirv::dr::Error) { from() display("{}", err) cause(err) @@ -45,8 +48,15 @@ impl From for SpirvType { ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(typ, state_space) => { - SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into()) + ast::Type::Pointer(ast::PointerType::Scalar(typ), state_space) => SpirvType::Pointer( + Box::new(SpirvType::Base(typ.into())), + state_space.to_spirv(), + ), + ast::Type::Pointer(ast::PointerType::Vector(typ, len), state_space) => { + SpirvType::Pointer( + Box::new(SpirvType::Vector(typ.into(), len)), + state_space.to_spirv(), + ) } } } @@ -365,12 +375,16 @@ impl TypeWordMap { } }, ast::Type::Pointer(typ, state_space) => { - let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?; + let base_t = typ.clone().into(); + let base = self.get_or_add_constant(b, &base_t, &[])?; let result_type = self.get_or_add( b, - SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()), + SpirvType::Pointer( + Box::new(SpirvType::from(base_t)), + (*state_space).to_spirv(), + ), ); - b.variable(result_type, None, (*state_space).into(), Some(base)) + b.variable(result_type, None, (*state_space).to_spirv(), Some(base)) } }) } @@ -402,9 +416,17 @@ impl TypeWordMap { } } -pub fn to_spirv_module<'a>( - ast: ast::Module<'a>, -) -> Result<(dr::Module, HashMap>), TranslateError> { +pub struct Module { + pub spirv: dr::Module, + pub kernel_info: HashMap, +} + +pub struct KernelInfo { + pub arguments_sizes: Vec, + pub uses_shared_mem: bool, +} + +pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { let mut id_defs = GlobalStringIdResolver::new(1); let directives = ast .directives @@ -413,6 +435,9 @@ pub fn to_spirv_module<'a>( .collect::, _>>()?; let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); + let mut directives = + convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id()); + normalize_variable_decls(&mut directives); // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module builder.set_version(1, 3); emit_capabilities(&mut builder); @@ -421,7 +446,7 @@ pub fn to_spirv_module<'a>( emit_memory_model(&mut builder); let mut map = TypeWordMap::new(&mut builder); emit_builtins(&mut builder, &mut map, &id_defs); - let mut args_len = HashMap::new(); + let mut kernel_info = HashMap::new(); for d in directives { match d { Directive::Variable(var) => { @@ -433,13 +458,20 @@ pub fn to_spirv_module<'a>( None => continue, }; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; - emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?; + emit_function_header( + &mut builder, + &mut map, + &id_defs, + f.func_decl, + &mut kernel_info, + )?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; builder.end_function()?; } } } - Ok((builder.module(), args_len)) + let spirv = builder.module(); + Ok(Module { spirv, kernel_info }) } type MultiHashMap = HashMap>; @@ -461,16 +493,18 @@ fn multi_hash_map_append(m: &mut MultiHashMap, // This pass looks for all uses of .extern .shared and converts them to // an additional method argument fn convert_dynamic_shared_memory_usage<'input>( - new_id: &mut impl FnMut() -> spirv::Word, id_defs: &mut GlobalStringIdResolver<'input>, module: Vec>, + new_id: &mut impl FnMut() -> spirv::Word, ) -> Vec> { - let mut extern_shared_decls = HashSet::new(); + let mut extern_shared_decls = HashMap::new(); for dir in module.iter() { match dir { Directive::Variable(var) => { - if let ast::VariableType::Shared(_) = var.v_type { - extern_shared_decls.insert(var.name); + if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) = + var.v_type + { + extern_shared_decls.insert(var.name, p_type); } } _ => {} @@ -490,7 +524,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), }) => { let call_key = match func_decl { - ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), + ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name), ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), }; let statements = statements @@ -501,7 +535,7 @@ fn convert_dynamic_shared_memory_usage<'input>( Statement::Call(call) } statement => statement.map_id(&mut |id| { - if extern_shared_decls.contains(&id) { + if extern_shared_decls.contains_key(&id) { methods_using_extern_shared.insert(call_key); } id @@ -530,7 +564,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements), }) => { let call_key = match func_decl { - ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), + ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name), ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), }; if !methods_using_extern_shared.contains(&call_key) { @@ -550,8 +584,13 @@ fn convert_dynamic_shared_memory_usage<'input>( name: shared_id_param, }); } - ast::MethodDecl::Kernel(_, input_args) => { - input_args.push(ast::Variable { + ast::MethodDecl::Kernel { + in_args, + uses_shared_mem, + .. + } => { + *uses_shared_mem = true; + in_args.push(ast::Variable { align: None, v_type: ast::KernelArgumentType::Shared, array_init: Vec::new(), @@ -559,33 +598,37 @@ fn convert_dynamic_shared_memory_usage<'input>( }); } } - let statements = statements - .into_iter() - .map(|statement| match statement { - Statement::Call(mut call) => { - // We can safely skip checking call arguments, - // because there's simply no way to pass shared ptr - // without converting it to .b64 first - if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) - { - call.param_list - .push((shared_id_param, ast::FnArgumentType::Shared)); - } - Statement::Call(call) - } - statement => statement.map_id(&mut |id| { - if extern_shared_decls.contains(&id) { - shared_id_param - } else { - id - } - }), - }) - .collect(); + let shared_var_id = new_id(); + let shared_var = ExpandedStatement::Variable(ast::Variable { + align: None, + name: shared_var_id, + array_init: Vec::new(), + v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( + ast::SizedScalarType::B8, + ast::PointerStateSpace::Shared, + )), + }); + let shared_var_st = ExpandedStatement::StoreVar( + ast::Arg2St { + src1: shared_var_id, + src2: shared_id_param, + }, + ast::Type::Scalar(ast::ScalarType::B8), + ); + let mut new_statements = vec![shared_var, shared_var_st]; + replace_uses_of_shared_memory( + &mut new_statements, + new_id, + &extern_shared_decls, + &mut methods_using_extern_shared, + shared_id_param, + shared_var_id, + statements, + ); Directive::Method(Function { func_decl, globals, - body: Some(statements), + body: Some(new_statements), }) } directive => directive, @@ -593,6 +636,57 @@ fn convert_dynamic_shared_memory_usage<'input>( .collect::>() } +fn replace_uses_of_shared_memory<'a>( + result: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + extern_shared_decls: &HashMap, + methods_using_extern_shared: &mut HashSet>, + shared_id_param: spirv::Word, + shared_var_id: spirv::Word, + statements: Vec, +) { + for statement in statements { + match statement { + Statement::Call(mut call) => { + // We can safely skip checking call arguments, + // because there's simply no way to pass shared ptr + // without converting it to .b64 first + if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) { + call.param_list + .push((shared_id_param, ast::FnArgumentType::Shared)); + } + result.push(Statement::Call(call)) + } + statement => { + let new_statement = statement.map_id(&mut |id| { + if let Some(typ) = extern_shared_decls.get(&id) { + let replacement_id = new_id(); + if *typ != ast::SizedScalarType::B8 { + result.push(Statement::Conversion(ImplicitConversion { + src: shared_var_id, + dst: replacement_id, + from: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::B8), + ast::LdStateSpace::Shared, + ), + to: ast::Type::Pointer( + ast::PointerType::Scalar((*typ).into()), + ast::LdStateSpace::Shared, + ), + kind: ConversionKind::PtrToPtr { spirv_ptr: true }, + })); + } + replacement_id + } else { + id + } + }); + result.push(new_statement); + } + } + } +} + fn get_callers_of_extern_shared<'a>( methods_using_extern_shared: &mut HashSet>, directly_called_by: &MultiHashMap>, @@ -670,15 +764,26 @@ fn emit_function_header<'a>( map: &mut TypeWordMap, global: &GlobalStringIdResolver<'a>, func_directive: ast::MethodDecl, - all_args_lens: &mut HashMap>, + kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { - if let ast::MethodDecl::Kernel(name, args) = &func_directive { - let args_lens = args.iter().map(|param| param.v_type.width()).collect(); - all_args_lens.insert(name.to_string(), args_lens); + if let ast::MethodDecl::Kernel { + name, + in_args, + uses_shared_mem, + } = &func_directive + { + let args_lens = in_args.iter().map(|param| param.v_type.width()).collect(); + kernel_info.insert( + name.to_string(), + KernelInfo { + arguments_sizes: args_lens, + uses_shared_mem: *uses_shared_mem, + }, + ); } let (ret_type, func_type) = get_function_type(builder, map, &func_directive); let fn_id = match func_directive { - ast::MethodDecl::Kernel(name, _) => { + ast::MethodDecl::Kernel { name, .. } => { let fn_id = global.get_id(name)?; let mut global_variables = global .variables_type_check @@ -718,8 +823,15 @@ fn emit_function_header<'a>( pub fn to_spirv<'a>( ast: ast::Module<'a>, ) -> Result<(Vec, HashMap>), TranslateError> { - let (module, all_args_lens) = to_spirv_module(ast)?; - Ok((module.assemble(), all_args_lens)) + let module = to_spirv_module(ast)?; + Ok(( + module.spirv.assemble(), + module + .kernel_info + .into_iter() + .map(|(k, v)| (k, v.arguments_sizes)) + .collect(), + )) } fn emit_capabilities(builder: &mut dr::Builder) { @@ -843,8 +955,7 @@ fn to_ssa<'input, 'b>( insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - let sorted_statements = normalize_variable_decls(labeled_statements); - let (f_body, globals) = extract_globals(sorted_statements); + let (f_body, globals) = extract_globals(labeled_statements); Ok(Function { func_decl: f_args, globals: globals, @@ -859,12 +970,20 @@ fn extract_globals( (sorted_statements, Vec::new()) } -fn normalize_variable_decls(mut func: Vec) -> Vec { - func[1..].sort_by_key(|s| match s { - Statement::Variable(_) => 0, - _ => 1, - }); - func +fn normalize_variable_decls(directives: &mut Vec) { + for directive in directives { + match directive { + Directive::Method(Function { + body: Some(func), .. + }) => { + func[1..].sort_by_key(|s| match s { + Statement::Variable(_) => 0, + _ => 1, + }); + } + _ => (), + } + } } fn convert_to_typed_statements( @@ -1138,8 +1257,8 @@ fn insert_mem_ssa_statements<'a, 'b>( ) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec), TranslateError> { let mut result = Vec::with_capacity(func.len()); let out_param = match &mut f_args { - ast::MethodDecl::Kernel(_, in_params) => { - for p in in_params.iter_mut() { + ast::MethodDecl::Kernel { in_args, .. } => { + for p in in_args.iter_mut() { let typ = ast::Type::from(p.v_type.clone()); let new_id = id_def.new_id(typ.clone()); result.push(Statement::Variable(ast::Variable { @@ -1736,7 +1855,7 @@ fn insert_implicit_conversions_impl( conversion_fn = bitcast_physical_pointer; } ArgumentSemantics::RegisterPointer => { - conversion_fn = force_bitcast; + conversion_fn = bitcast_logical_pointer; } ArgumentSemantics::Address => { conversion_fn = force_bitcast_ptr_to_bit; @@ -1790,10 +1909,10 @@ fn get_function_type( .iter() .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), ), - ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( + ast::MethodDecl::Kernel { in_args, .. } => map.get_or_add_fn( builder, iter::empty(), - params + in_args .iter() .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), ), @@ -1886,14 +2005,19 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak { todo!() } - let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone())); + let result_type = + map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone()))); match data.state_space { - ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { + ast::LdStateSpace::Generic + | ast::LdStateSpace::Global + | ast::LdStateSpace::Shared => { builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param | ast::LdStateSpace::Local => { - let result_type = - map.get_or_add(builder, SpirvType::from(data.typ.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::from(ast::Type::from(data.typ.clone())), + ); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } _ => todo!(), @@ -1906,11 +2030,14 @@ fn emit_function_body_ops( if data.state_space == ast::StStateSpace::Param || data.state_space == ast::StStateSpace::Local { - let result_type = - map.get_or_add(builder, SpirvType::from(data.typ.clone())); + let result_type = map.get_or_add( + builder, + SpirvType::from(ast::Type::from(data.typ.clone())), + ); builder.copy_object(result_type, Some(arg.src1), arg.src2)?; } else if data.state_space == ast::StStateSpace::Generic || data.state_space == ast::StStateSpace::Global + || data.state_space == ast::StStateSpace::Shared { builder.store(arg.src1, arg.src2, None, &[])?; } else { @@ -2642,10 +2769,7 @@ fn emit_implicit_conversion( builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::BitToPtr(space)) => { - let dst_type = map.get_or_add( - builder, - SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()), - ); + let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { @@ -2703,6 +2827,20 @@ fn emit_implicit_conversion( let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } + (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { + let result_type = if spirv_ptr { + map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::from(cv.to.clone())), + spirv::StorageClass::Function, + ), + ) + } else { + map.get_or_add(builder, SpirvType::from(cv.to.clone())) + }; + builder.bitcast(result_type, Some(cv.dst), cv.src)?; + } _ => unreachable!(), } Ok(()) @@ -2903,9 +3041,15 @@ impl<'a> GlobalStringIdResolver<'a> { type_check: HashMap::new(), }; let new_fn_decl = match header { - ast::MethodDecl::Kernel(name, params) => { - ast::MethodDecl::Kernel(name, expand_kernel_params(&mut fn_resolver, params.iter())) - } + ast::MethodDecl::Kernel { + name, + in_args, + uses_shared_mem, + } => ast::MethodDecl::Kernel { + name, + in_args: expand_kernel_params(&mut fn_resolver, in_args.iter()), + uses_shared_mem: *uses_shared_mem, + }, ast::MethodDecl::Func(ret_params, _, params) => { let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter()); let params_ids = expand_fn_params(&mut fn_resolver, params.iter()); @@ -3598,7 +3742,7 @@ impl ast::Instruction { ast::Instruction::Ld(d, a) => { let is_param = d.state_space == ast::LdStateSpace::Param || d.state_space == ast::LdStateSpace::Local; - let new_args = a.map(visitor, &d.typ, is_param)?; + let new_args = a.map(visitor, &d, is_param)?; ast::Instruction::Ld(d, new_args) } ast::Instruction::Mov(d, a) => { @@ -3655,7 +3799,7 @@ impl ast::Instruction { ast::Instruction::St(d, a) => { let is_param = d.state_space == ast::StStateSpace::Param || d.state_space == ast::StStateSpace::Local; - let new_args = a.map(visitor, &d.typ, is_param)?; + let new_args = a.map(visitor, &d, is_param)?; ast::Instruction::St(d, new_args) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), @@ -3826,29 +3970,36 @@ impl ast::Type { scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), - state_space: ast::PointerStateSpace::Global, + state_space: ast::LdStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], - state_space: ast::PointerStateSpace::Global, + state_space: ast::LdStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), - state_space: ast::PointerStateSpace::Global, + state_space: ast::LdStateSpace::Global, }, - ast::Type::Pointer(scalar, state_space) => TypeParts { - kind: TypeKind::Pointer, + ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts { + kind: TypeKind::PointerScalar, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), state_space: *state_space, }, + ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts { + kind: TypeKind::PointerVector, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: vec![*len as u32], + state_space: *state_space, + }, } } @@ -3865,8 +4016,15 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), - TypeKind::Pointer => ast::Type::Pointer( - ast::ScalarType::from_parts(t.width, t.scalar_kind), + TypeKind::PointerScalar => ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), + t.state_space, + ), + TypeKind::PointerVector => ast::Type::Pointer( + ast::PointerType::Vector( + ast::ScalarType::from_parts(t.width, t.scalar_kind), + t.components[0] as u8, + ), t.state_space, ), } @@ -3879,7 +4037,7 @@ struct TypeParts { scalar_kind: ScalarKind, width: u8, components: Vec, - state_space: ast::PointerStateSpace, + state_space: ast::LdStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -3887,7 +4045,8 @@ enum TypeKind { Scalar, Vector, Array, - Pointer, + PointerScalar, + PointerVector, } impl ast::Instruction { @@ -4007,6 +4166,7 @@ enum ConversionKind { SignExtend, BitToPtr(ast::LdStateSpace), PtrToBit, + PtrToPtr { spirv_ptr: bool }, } impl ast::PredAt { @@ -4058,7 +4218,7 @@ impl ast::VariableParamType { (ast::ScalarType::from(*t).size_of() as usize) * (len.iter().fold(1, |x, y| x * (*y)) as usize) } - ast::VariableParamType::Pointer(_, _) => mem::size_of::() + ast::VariableParamType::Pointer(_, _) => mem::size_of::(), } } } @@ -4076,7 +4236,10 @@ impl From for ast::Type { fn from(this: ast::KernelArgumentType) -> Self { match this { ast::KernelArgumentType::Normal(typ) => typ.into(), - ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64), + ast::KernelArgumentType::Shared => ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::B8), + ast::LdStateSpace::Shared, + ), } } } @@ -4085,9 +4248,10 @@ impl ast::KernelArgumentType { fn to_param(self) -> ast::VariableParamType { match self { ast::KernelArgumentType::Normal(p) => p, - ast::KernelArgumentType::Shared => { - ast::VariableParamType::Scalar(ast::ParamScalarType::B64) - } + ast::KernelArgumentType::Shared => ast::VariableParamType::Pointer( + ast::SizedScalarType::B8, + ast::PointerStateSpace::Shared, + ), } } } @@ -4193,7 +4357,7 @@ impl ast::Arg2Ld { fn map>( self, visitor: &mut V, - t: &ast::Type, + details: &ast::LdDetails, is_param: bool, ) -> Result, TranslateError> { let dst = visitor.id_or_vector( @@ -4202,7 +4366,7 @@ impl ast::Arg2Ld { is_dst: true, sema: ArgumentSemantics::DefaultRelaxed, }, - &ast::Type::from(t.clone()), + &ast::Type::from(details.typ.clone()), )?; let src = visitor.operand( ArgumentDescriptor { @@ -4214,7 +4378,14 @@ impl ast::Arg2Ld { ArgumentSemantics::PhysicalPointer }, }, - t, + &(if is_param { + ast::Type::from(details.typ.clone()) + } else { + ast::Type::Pointer( + ast::PointerType::from(details.typ.clone()), + details.state_space, + ) + }), )?; Ok(ast::Arg2Ld { dst, src }) } @@ -4233,7 +4404,7 @@ impl ast::Arg2St { fn map>( self, visitor: &mut V, - t: &ast::Type, + details: &ast::StData, is_param: bool, ) -> Result, TranslateError> { let src1 = visitor.operand( @@ -4246,7 +4417,14 @@ impl ast::Arg2St { ArgumentSemantics::PhysicalPointer }, }, - t, + &(if is_param { + details.typ.clone().into() + } else { + ast::Type::Pointer( + ast::PointerType::from(details.typ.clone()), + details.state_space.to_ld_ss(), + ) + }), )?; let src2 = visitor.operand_or_vector( ArgumentDescriptor { @@ -4254,7 +4432,7 @@ impl ast::Arg2St { is_dst: false, sema: ArgumentSemantics::DefaultRelaxed, }, - t, + &details.typ.clone().into(), )?; Ok(ast::Arg2St { src1, src2 }) } @@ -4957,7 +5135,7 @@ impl ast::MulDetails { } } -fn force_bitcast( +fn bitcast_logical_pointer( operand: &ast::Type, instr: &ast::Type, _: Option, @@ -4971,21 +5149,12 @@ fn force_bitcast( fn bitcast_physical_pointer( operand_type: &ast::Type, - _: &ast::Type, + instr_type: &ast::Type, ss: Option, ) -> Result, TranslateError> { match operand_type { // array decays to a pointer - ast::Type::Array(_, vec) => { - if vec.len() != 0 { - return Err(TranslateError::MismatchedType); - } - if let Some(space) = ss { - Ok(Some(ConversionKind::BitToPtr(space))) - } else { - Err(TranslateError::Unreachable) - } - } + ast::Type::Array(_, _) => todo!(), ast::Type::Scalar(ast::ScalarType::B64) | ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::S64) => { @@ -4995,6 +5164,27 @@ fn bitcast_physical_pointer( Err(TranslateError::Unreachable) } } + ast::Type::Pointer(op_scalar_t, op_space) => { + if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { + if op_space == instr_space { + if op_scalar_t == instr_scalar_t { + Ok(None) + } else { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } + } else { + if *op_space == ast::LdStateSpace::Generic + || *instr_space == ast::LdStateSpace::Generic + { + Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) + } else { + Err(TranslateError::MismatchedType) + } + } + } else { + Err(TranslateError::MismatchedType) + } + } _ => Err(TranslateError::MismatchedType), } } @@ -5206,7 +5396,7 @@ fn should_convert_relaxed_dst( impl<'a> ast::MethodDecl<'a, &'a str> { fn name(&self) -> &'a str { match self { - ast::MethodDecl::Kernel(name, _) => name, + ast::MethodDecl::Kernel { name, .. } => name, ast::MethodDecl::Func(_, name, _) => name, } } @@ -5216,7 +5406,7 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> { fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument)) { match self { ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f), - ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| { + ast::MethodDecl::Kernel { in_args, .. } => in_args.iter().for_each(|arg| { f(&ast::FnArgument { align: arg.align, name: arg.name,