Add dynamic shared mem support

This commit is contained in:
Andrzej Janik 2020-10-25 00:24:40 +02:00
commit 85ee8210df
6 changed files with 491 additions and 221 deletions

View file

@ -5,8 +5,8 @@ members = [
"level_zero", "level_zero",
"spirv_tools-sys", "spirv_tools-sys",
"notcuda", "notcuda",
"notcuda_inject", #"notcuda_inject",
"notcuda_redirect", #"notcuda_redirect",
"ptx", "ptx",
] ]

View file

@ -28,8 +28,11 @@ quick_error! {
} }
} }
macro_rules! sub_scalar_type { macro_rules! sub_enum {
($name:ident { $($variant:ident),+ $(,)? }) => { ($name:ident { $($variant:ident),+ $(,)? }) => {
sub_enum!{ $name : ScalarType { $($variant),+ } }
};
($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
#[derive(PartialEq, Eq, Clone, Copy)] #[derive(PartialEq, Eq, Clone, Copy)]
pub enum $name { pub enum $name {
$( $(
@ -37,23 +40,23 @@ macro_rules! sub_scalar_type {
)+ )+
} }
impl From<$name> for ScalarType { impl From<$name> for $base_type {
fn from(t: $name) -> ScalarType { fn from(t: $name) -> $base_type {
match t { match t {
$( $(
$name::$variant => ScalarType::$variant, $name::$variant => $base_type::$variant,
)+ )+
} }
} }
} }
impl std::convert::TryFrom<ScalarType> for $name { impl std::convert::TryFrom<$base_type> for $name {
type Error = (); type Error = ();
fn try_from(t: ScalarType) -> Result<Self, Self::Error> { fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t { match t {
$( $(
ScalarType::$variant => Ok($name::$variant), $base_type::$variant => Ok($name::$variant),
)+ )+
_ => Err(()), _ => Err(()),
} }
@ -64,6 +67,13 @@ macro_rules! sub_scalar_type {
macro_rules! sub_type { macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
sub_type! { $type_name : Type {
$(
$variant ($($field_type),+),
)+
}}
};
($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub enum $type_name { pub enum $type_name {
$( $(
@ -71,26 +81,26 @@ macro_rules! sub_type {
)+ )+
} }
impl From<$type_name> for Type { impl From<$type_name> for $base_type {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn from(t: $type_name) -> Type { fn from(t: $type_name) -> $base_type {
match t { match t {
$( $(
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+), $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
)+ )+
} }
} }
} }
impl std::convert::TryFrom<Type> for $type_name { impl std::convert::TryFrom<$base_type> for $type_name {
type Error = (); type Error = ();
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(unreachable_patterns)] #[allow(unreachable_patterns)]
fn try_from(t: Type) -> Result<Self, Self::Error> { fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t { match t {
$( $(
Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
)+ )+
_ => Err(()), _ => Err(()),
} }
@ -99,10 +109,12 @@ macro_rules! sub_type {
}; };
} }
// Pointer is used when doing SLM conversion to SPIRV
sub_type! { sub_type! {
VariableRegType { VariableRegType {
Scalar(ScalarType), Scalar(ScalarType),
Vector(SizedScalarType, u8), Vector(SizedScalarType, u8),
Pointer(SizedScalarType, PointerStateSpace)
} }
} }
@ -146,13 +158,13 @@ sub_type! {
// .param .b32 foobar[] // .param .b32 foobar[]
sub_type! { sub_type! {
VariableParamType { VariableParamType {
Scalar(ParamScalarType), Scalar(LdStScalarType),
Array(SizedScalarType, VecU32), Array(SizedScalarType, VecU32),
Pointer(SizedScalarType, PointerStateSpace), Pointer(SizedScalarType, PointerStateSpace),
} }
} }
sub_scalar_type!(SizedScalarType { sub_enum!(SizedScalarType {
B8, B8,
B16, B16,
B32, B32,
@ -171,7 +183,7 @@ sub_scalar_type!(SizedScalarType {
F64, F64,
}); });
sub_scalar_type!(ParamScalarType { sub_enum!(LdStScalarType {
B8, B8,
B16, B16,
B32, B32,
@ -232,7 +244,11 @@ pub enum Directive<'a, P: ArgParams> {
pub enum MethodDecl<'a, ID> { pub enum MethodDecl<'a, ID> {
Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
Kernel(&'a str, Vec<KernelArgument<ID>>), Kernel {
name: &'a str,
in_args: Vec<KernelArgument<ID>>,
uses_shared_mem: bool,
},
} }
pub type FnArgument<ID> = Variable<FnArgumentType, ID>; pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
@ -262,25 +278,52 @@ impl From<FnArgumentType> for Type {
match t { match t {
FnArgumentType::Reg(x) => x.into(), FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(), FnArgumentType::Param(x) => x.into(),
FnArgumentType::Shared => Type::Scalar(ScalarType::B64), FnArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
} }
} }
} }
#[derive(PartialEq, Eq, Clone, Copy)] sub_enum!(
pub enum PointerStateSpace { PointerStateSpace : LdStateSpace {
Global, Global,
Const, Const,
Shared, Shared,
Param, Param,
} }
);
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub enum Type { pub enum Type {
Scalar(ScalarType), Scalar(ScalarType),
Vector(ScalarType, u8), Vector(ScalarType, u8),
Array(ScalarType, Vec<u32>), Array(ScalarType, Vec<u32>),
Pointer(ScalarType, PointerStateSpace), Pointer(PointerType, LdStateSpace),
}
// Pointee types accepted by `Type::Pointer`: a scalar or a vector of scalars.
// No base type is named here, so `sub_type!` falls back to `Type` and
// generates `From<PointerType> for Type` and `TryFrom<Type> for PointerType`.
sub_type! {
PointerType {
Scalar(ScalarType),
Vector(ScalarType, u8),
}
}
impl From<SizedScalarType> for PointerType {
fn from(t: SizedScalarType) -> Self {
PointerType::Scalar(t.into())
}
}
impl TryFrom<PointerType> for SizedScalarType {
type Error = ();
fn try_from(value: PointerType) -> Result<Self, Self::Error> {
match value {
PointerType::Scalar(t) => Ok(t.try_into()?),
PointerType::Vector(_, _) => Err(()),
}
}
} }
#[derive(PartialEq, Eq, Hash, Clone, Copy)] #[derive(PartialEq, Eq, Hash, Clone, Copy)]
@ -304,7 +347,7 @@ pub enum ScalarType {
Pred, Pred,
} }
sub_scalar_type!(IntType { sub_enum!(IntType {
U8, U8,
U16, U16,
U32, U32,
@ -315,9 +358,9 @@ sub_scalar_type!(IntType {
S64 S64
}); });
sub_scalar_type!(UIntType { U8, U16, U32, U64 }); sub_enum!(UIntType { U8, U16, U32, U64 });
sub_scalar_type!(SIntType { S8, S16, S32, S64 }); sub_enum!(SIntType { S8, S16, S32, S64 });
impl IntType { impl IntType {
pub fn is_signed(self) -> bool { pub fn is_signed(self) -> bool {
@ -341,7 +384,7 @@ impl IntType {
} }
} }
sub_scalar_type!(FloatType { sub_enum!(FloatType {
F16, F16,
F16x2, F16x2,
F32, F32,
@ -615,7 +658,23 @@ pub struct LdDetails {
pub qualifier: LdStQualifier, pub qualifier: LdStQualifier,
pub state_space: LdStateSpace, pub state_space: LdStateSpace,
pub caching: LdCacheOperator, pub caching: LdCacheOperator,
pub typ: Type, pub typ: LdStType,
}
// Type used for the `typ` field of `LdDetails` and `StData`:
// either a single `LdStScalarType` or a vector of them (element type, length).
sub_type! {
LdStType {
Scalar(LdStScalarType),
Vector(LdStScalarType, u8),
}
}
impl From<LdStType> for PointerType {
fn from(t: LdStType) -> Self {
match t {
LdStType::Scalar(t) => PointerType::Scalar(t.into()),
LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
}
}
} }
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]
@ -860,7 +919,7 @@ pub enum ShlType {
B64, B64,
} }
sub_scalar_type!(ShrType { sub_enum!(ShrType {
B16, B16,
B32, B32,
B64, B64,
@ -876,7 +935,7 @@ pub struct StData {
pub qualifier: LdStQualifier, pub qualifier: LdStQualifier,
pub state_space: StStateSpace, pub state_space: StStateSpace,
pub caching: StCacheOperator, pub caching: StCacheOperator,
pub typ: Type, pub typ: LdStType,
} }
#[derive(PartialEq, Eq, Copy, Clone)] #[derive(PartialEq, Eq, Copy, Clone)]
@ -900,7 +959,7 @@ pub struct RetData {
pub uniform: bool, pub uniform: bool,
} }
sub_scalar_type!(OrType { sub_enum!(OrType {
Pred, Pred,
B16, B16,
B32, B32,

View file

@ -237,7 +237,8 @@ LinkingDirectives: ast::LinkingDirective = {
} }
MethodDecl: ast::MethodDecl<'input, &'input str> = { MethodDecl: ast::MethodDecl<'input, &'input str> = {
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params), ".entry" <name:ExtendedID> <in_args:KernelArguments> =>
ast::MethodDecl::Kernel{ name, in_args, uses_shared_mem: false },
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => { ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
} }
@ -294,10 +295,6 @@ ScalarType: ast::ScalarType = {
".f16" => ast::ScalarType::F16, ".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2, ".f16x2" => ast::ScalarType::F16x2,
".pred" => ast::ScalarType::Pred, ".pred" => ast::ScalarType::Pred,
LdStScalarType
};
LdStScalarType: ast::ScalarType = {
".b8" => ast::ScalarType::B8, ".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16, ".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32, ".b32" => ast::ScalarType::B32,
@ -442,7 +439,7 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = { ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
".param" <var:VariableScalar<ParamScalarType>> => { ".param" <var:VariableScalar<LdStScalarType>> => {
let (align, t, name) = var; let (align, t, name) = var;
let v_type = ast::VariableParamType::Scalar(t); let v_type = ast::VariableParamType::Scalar(t);
(align, Vec::new(), v_type, name) (align, Vec::new(), v_type, name)
@ -506,22 +503,22 @@ SizedScalarType: ast::SizedScalarType = {
} }
#[inline] #[inline]
ParamScalarType: ast::ParamScalarType = { LdStScalarType: ast::LdStScalarType = {
".b8" => ast::ParamScalarType::B8, ".b8" => ast::LdStScalarType::B8,
".b16" => ast::ParamScalarType::B16, ".b16" => ast::LdStScalarType::B16,
".b32" => ast::ParamScalarType::B32, ".b32" => ast::LdStScalarType::B32,
".b64" => ast::ParamScalarType::B64, ".b64" => ast::LdStScalarType::B64,
".u8" => ast::ParamScalarType::U8, ".u8" => ast::LdStScalarType::U8,
".u16" => ast::ParamScalarType::U16, ".u16" => ast::LdStScalarType::U16,
".u32" => ast::ParamScalarType::U32, ".u32" => ast::LdStScalarType::U32,
".u64" => ast::ParamScalarType::U64, ".u64" => ast::LdStScalarType::U64,
".s8" => ast::ParamScalarType::S8, ".s8" => ast::LdStScalarType::S8,
".s16" => ast::ParamScalarType::S16, ".s16" => ast::LdStScalarType::S16,
".s32" => ast::ParamScalarType::S32, ".s32" => ast::LdStScalarType::S32,
".s64" => ast::ParamScalarType::S64, ".s64" => ast::LdStScalarType::S64,
".f16" => ast::ParamScalarType::F16, ".f16" => ast::LdStScalarType::F16,
".f32" => ast::ParamScalarType::F32, ".f32" => ast::LdStScalarType::F32,
".f64" => ast::ParamScalarType::F64, ".f64" => ast::LdStScalarType::F64,
} }
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
@ -572,9 +569,9 @@ OperandOrVector: ast::OperandOrVector<&'input str> = {
<dst:VectorExtract> => ast::OperandOrVector::Vec(dst) <dst:VectorExtract> => ast::OperandOrVector::Vec(dst)
} }
LdStType: ast::Type = { LdStType: ast::LdStType = {
<v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v), <v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v),
<t:LdStScalarType> => ast::Type::Scalar(t), <t:LdStScalarType> => ast::LdStType::Scalar(t),
} }
LdStQualifier: ast::LdStQualifier = { LdStQualifier: ast::LdStQualifier = {

View file

@ -2,52 +2,67 @@
OpCapability Linkage OpCapability Linkage
OpCapability Addresses OpCapability Addresses
OpCapability Kernel OpCapability Kernel
OpCapability Int64
OpCapability Int8 OpCapability Int8
%29 = OpExtInstImport "OpenCL.std" OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%32 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvta" OpEntryPoint Kernel %2 "extern_shared" %1
%void = OpTypeVoid %void = OpTypeVoid
%uint = OpTypeInt 32 0
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint
%1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup
%ulong = OpTypeInt 64 0 %ulong = OpTypeInt 64 0
%32 = OpTypeFunction %void %ulong %ulong %uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
%40 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong
%float = OpTypeFloat 32 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Function_float = OpTypePointer Function %float %_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%1 = OpFunction %void None %32 %2 = OpFunction %void None %40
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong
%27 = OpLabel %9 = OpFunctionParameter %ulong
%2 = OpVariable %_ptr_Function_ulong Function %28 = OpFunctionParameter %_ptr_Workgroup_uchar
%41 = OpLabel
%29 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
%3 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_float Function %6 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %7 %7 = OpVariable %_ptr_Function_ulong Function
OpStore %29 %28
OpBranch %26
%26 = OpLabel
OpStore %3 %8 OpStore %3 %8
%10 = OpLoad %ulong %2
%9 = OpCopyObject %ulong %10
OpStore %4 %9 OpStore %4 %9
%12 = OpLoad %ulong %3 %11 = OpLoad %ulong %3
%11 = OpCopyObject %ulong %12 %10 = OpCopyObject %ulong %11
OpStore %5 %11 OpStore %5 %10
%14 = OpLoad %ulong %4 %13 = OpLoad %ulong %4
%22 = OpCopyObject %ulong %14 %12 = OpCopyObject %ulong %13
%21 = OpCopyObject %ulong %22 OpStore %6 %12
%13 = OpCopyObject %ulong %21 %15 = OpLoad %ulong %5
OpStore %4 %13 %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %15
%16 = OpLoad %ulong %5 %14 = OpLoad %ulong %22
%24 = OpCopyObject %ulong %16 OpStore %7 %14
%23 = OpCopyObject %ulong %24 %30 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29
%15 = OpCopyObject %ulong %23 %16 = OpLoad %_ptr_Workgroup_uint %30
OpStore %5 %15 %17 = OpLoad %ulong %7
%18 = OpLoad %ulong %4 %23 = OpBitcast %_ptr_Workgroup_ulong %16
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 OpStore %23 %17
%17 = OpLoad %float %25 %31 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29
OpStore %6 %17 %19 = OpLoad %_ptr_Workgroup_uint %31
%19 = OpLoad %ulong %5 %24 = OpBitcast %_ptr_Workgroup_ulong %19
%20 = OpLoad %float %6 %18 = OpLoad %ulong %24
%26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 OpStore %7 %18
OpStore %26 %20 %20 = OpLoad %ulong %6
%21 = OpLoad %ulong %7
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20
OpStore %25 %21
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd

View file

@ -107,27 +107,33 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
let mut errors = Vec::new(); let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
assert!(errors.len() == 0); assert!(errors.len() == 0);
let (spirv, _) = translate::to_spirv(ast)?; let notcuda_module = translate::to_spirv_module(ast)?;
let name = CString::new(name)?; let name = CString::new(name)?;
let result = let result = run_spirv(name.as_c_str(), notcuda_module, input, output)
run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?; .map_err(|err| DisplayError { err })?;
assert_eq!(output, result.as_slice()); assert_eq!(output, result.as_slice());
Ok(()) Ok(())
} }
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
name: &CStr, name: &CStr,
spirv: &[u32], module: translate::Module,
input: &[T], input: &[T],
output: &mut [T], output: &mut [T],
) -> ze::Result<Vec<T>> { ) -> ze::Result<Vec<T>> {
ze::init()?; ze::init()?;
let spirv = module.spirv.assemble();
let byte_il = unsafe { let byte_il = unsafe {
slice::from_raw_parts::<u8>( slice::from_raw_parts::<u8>(
spirv.as_ptr() as *const _, spirv.as_ptr() as *const _,
spirv.len() * mem::size_of::<u32>(), spirv.len() * mem::size_of::<u32>(),
) )
}; };
let use_shared_mem = module
.kernel_info
.get(name.to_str().unwrap())
.unwrap()
.uses_shared_mem;
let mut result = vec![0u8.into(); output.len()]; let mut result = vec![0u8.into(); output.len()];
{ {
let mut drivers = ze::Driver::get()?; let mut drivers = ze::Driver::get()?;
@ -140,7 +146,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let module = match module { let module = match module {
Ok(m) => m, Ok(m) => m,
Err(err) => { Err(err) => {
let raw_err_string = log.get_cstring()?; let raw_err_string = log.get_cstring()?;
let err_string = raw_err_string.to_string_lossy(); let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string); panic!("{:?}\n{}", err, err_string);
} }
@ -164,6 +170,9 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
kernel.set_group_size(1, 1, 1)?; kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?; kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
kernel.set_arg_buffer(1, out_b_ptr_mut)?; kernel.set_arg_buffer(1, out_b_ptr_mut)?;
if use_shared_mem {
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
}
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?; cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?; cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
queue.execute(cmd_list)?; queue.execute(cmd_list)?;
@ -179,7 +188,7 @@ fn test_spvtxt_assert<'a>(
let mut errors = Vec::new(); let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0); assert!(errors.len() == 0);
let (ptx_mod, _) = translate::to_spirv_module(ast)?; let spirv_module = translate::to_spirv_module(ast)?;
let spv_context = let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut()); assert!(spv_context != ptr::null_mut());
@ -211,9 +220,9 @@ fn test_spvtxt_assert<'a>(
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?; rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
let spvtxt_mod = loader.module(); let spvtxt_mod = loader.module();
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) }; unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) { if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) {
// We could simply use ptx_mod.disassemble, but SPIRV-Tools text formatting is so much nicer // We could simply use ptx_mod.disassemble, but SPIRV-Tools text formatting is so much nicer
let spv_from_ptx_binary = ptx_mod.assemble(); let spv_from_ptx_binary = spirv_module.spirv.assemble();
let mut spv_text: spirv_tools::spv_text = ptr::null_mut(); let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
let result = unsafe { let result = unsafe {
spirv_tools::spvBinaryToText( spirv_tools::spvBinaryToText(
@ -234,7 +243,7 @@ fn test_spvtxt_assert<'a>(
// TODO: stop leaking kernel text // TODO: stop leaking kernel text
Cow::Borrowed(spv_from_ptx_text) Cow::Borrowed(spv_from_ptx_text)
} else { } else {
Cow::Owned(ptx_mod.disassemble()) Cow::Owned(spirv_module.spirv.disassemble())
}; };
if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") { if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") {
let mut path = PathBuf::from(dump_path); let mut path = PathBuf::from(dump_path);

View file

@ -1,8 +1,11 @@
use crate::ast; use crate::ast;
use half::f16; use half::f16;
use rspirv::{binary::Disassemble, dr}; use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, iter, mem}; use std::{borrow::Cow, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryFrom,
};
use rspirv::binary::Assemble; use rspirv::binary::Assemble;
@ -12,7 +15,7 @@ quick_error! {
UnknownSymbol {} UnknownSymbol {}
UntypedSymbol {} UntypedSymbol {}
MismatchedType {} MismatchedType {}
Spirv (err: rspirv::dr::Error) { Spirv(err: rspirv::dr::Error) {
from() from()
display("{}", err) display("{}", err)
cause(err) cause(err)
@ -45,8 +48,15 @@ impl From<ast::Type> for SpirvType {
ast::Type::Scalar(t) => SpirvType::Base(t.into()), ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
ast::Type::Pointer(typ, state_space) => { ast::Type::Pointer(ast::PointerType::Scalar(typ), state_space) => SpirvType::Pointer(
SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into()) Box::new(SpirvType::Base(typ.into())),
state_space.to_spirv(),
),
ast::Type::Pointer(ast::PointerType::Vector(typ, len), state_space) => {
SpirvType::Pointer(
Box::new(SpirvType::Vector(typ.into(), len)),
state_space.to_spirv(),
)
} }
} }
} }
@ -365,12 +375,16 @@ impl TypeWordMap {
} }
}, },
ast::Type::Pointer(typ, state_space) => { ast::Type::Pointer(typ, state_space) => {
let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?; let base_t = typ.clone().into();
let base = self.get_or_add_constant(b, &base_t, &[])?;
let result_type = self.get_or_add( let result_type = self.get_or_add(
b, b,
SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()), SpirvType::Pointer(
Box::new(SpirvType::from(base_t)),
(*state_space).to_spirv(),
),
); );
b.variable(result_type, None, (*state_space).into(), Some(base)) b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
} }
}) })
} }
@ -402,9 +416,17 @@ impl TypeWordMap {
} }
} }
pub fn to_spirv_module<'a>( pub struct Module {
ast: ast::Module<'a>, pub spirv: dr::Module,
) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> { pub kernel_info: HashMap<String, KernelInfo>,
}
/// Per-kernel metadata recorded while emitting a kernel's function header.
pub struct KernelInfo {
/// `width()` of each kernel input argument's type, in `in_args` order.
pub arguments_sizes: Vec<usize>,
/// Copy of the `uses_shared_mem` flag from the kernel's `MethodDecl::Kernel`.
pub uses_shared_mem: bool,
}
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1); let mut id_defs = GlobalStringIdResolver::new(1);
let directives = ast let directives = ast
.directives .directives
@ -413,6 +435,9 @@ pub fn to_spirv_module<'a>(
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let mut builder = dr::Builder::new(); let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id()); builder.reserve_ids(id_defs.current_id());
let mut directives =
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3); builder.set_version(1, 3);
emit_capabilities(&mut builder); emit_capabilities(&mut builder);
@ -421,7 +446,7 @@ pub fn to_spirv_module<'a>(
emit_memory_model(&mut builder); emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder); let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs); emit_builtins(&mut builder, &mut map, &id_defs);
let mut args_len = HashMap::new(); let mut kernel_info = HashMap::new();
for d in directives { for d in directives {
match d { match d {
Directive::Variable(var) => { Directive::Variable(var) => {
@ -433,13 +458,20 @@ pub fn to_spirv_module<'a>(
None => continue, None => continue,
}; };
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?; emit_function_header(
&mut builder,
&mut map,
&id_defs,
f.func_decl,
&mut kernel_info,
)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?; builder.end_function()?;
} }
} }
} }
Ok((builder.module(), args_len)) let spirv = builder.module();
Ok(Module { spirv, kernel_info })
} }
type MultiHashMap<K, V> = HashMap<K, Vec<V>>; type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
@ -461,16 +493,18 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
// This pass looks for all uses of .extern .shared and converts them to // This pass looks for all uses of .extern .shared and converts them to
// an additional method argument // an additional method argument
fn convert_dynamic_shared_memory_usage<'input>( fn convert_dynamic_shared_memory_usage<'input>(
new_id: &mut impl FnMut() -> spirv::Word,
id_defs: &mut GlobalStringIdResolver<'input>, id_defs: &mut GlobalStringIdResolver<'input>,
module: Vec<Directive<'input>>, module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> { ) -> Vec<Directive<'input>> {
let mut extern_shared_decls = HashSet::new(); let mut extern_shared_decls = HashMap::new();
for dir in module.iter() { for dir in module.iter() {
match dir { match dir {
Directive::Variable(var) => { Directive::Variable(var) => {
if let ast::VariableType::Shared(_) = var.v_type { if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
extern_shared_decls.insert(var.name); var.v_type
{
extern_shared_decls.insert(var.name, p_type);
} }
} }
_ => {} _ => {}
@ -490,7 +524,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements), body: Some(statements),
}) => { }) => {
let call_key = match func_decl { let call_key = match func_decl {
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
}; };
let statements = statements let statements = statements
@ -501,7 +535,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
Statement::Call(call) Statement::Call(call)
} }
statement => statement.map_id(&mut |id| { statement => statement.map_id(&mut |id| {
if extern_shared_decls.contains(&id) { if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key); methods_using_extern_shared.insert(call_key);
} }
id id
@ -530,7 +564,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements), body: Some(statements),
}) => { }) => {
let call_key = match func_decl { let call_key = match func_decl {
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name), ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id), ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
}; };
if !methods_using_extern_shared.contains(&call_key) { if !methods_using_extern_shared.contains(&call_key) {
@ -550,8 +584,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
name: shared_id_param, name: shared_id_param,
}); });
} }
ast::MethodDecl::Kernel(_, input_args) => { ast::MethodDecl::Kernel {
input_args.push(ast::Variable { in_args,
uses_shared_mem,
..
} => {
*uses_shared_mem = true;
in_args.push(ast::Variable {
align: None, align: None,
v_type: ast::KernelArgumentType::Shared, v_type: ast::KernelArgumentType::Shared,
array_init: Vec::new(), array_init: Vec::new(),
@ -559,33 +598,37 @@ fn convert_dynamic_shared_memory_usage<'input>(
}); });
} }
} }
let statements = statements let shared_var_id = new_id();
.into_iter() let shared_var = ExpandedStatement::Variable(ast::Variable {
.map(|statement| match statement { align: None,
Statement::Call(mut call) => { name: shared_var_id,
// We can safely skip checking call arguments, array_init: Vec::new(),
// because there's simply no way to pass shared ptr v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
// without converting it to .b64 first ast::SizedScalarType::B8,
if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) ast::PointerStateSpace::Shared,
{ )),
call.param_list });
.push((shared_id_param, ast::FnArgumentType::Shared)); let shared_var_st = ExpandedStatement::StoreVar(
} ast::Arg2St {
Statement::Call(call) src1: shared_var_id,
} src2: shared_id_param,
statement => statement.map_id(&mut |id| { },
if extern_shared_decls.contains(&id) { ast::Type::Scalar(ast::ScalarType::B8),
shared_id_param );
} else { let mut new_statements = vec![shared_var, shared_var_st];
id replace_uses_of_shared_memory(
} &mut new_statements,
}), new_id,
}) &extern_shared_decls,
.collect(); &mut methods_using_extern_shared,
shared_id_param,
shared_var_id,
statements,
);
Directive::Method(Function { Directive::Method(Function {
func_decl, func_decl,
globals, globals,
body: Some(statements), body: Some(new_statements),
}) })
} }
directive => directive, directive => directive,
@ -593,6 +636,57 @@ fn convert_dynamic_shared_memory_usage<'input>(
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
/// Copies `statements` into `result`, redirecting every use of an
/// `.extern .shared` declaration to the method's shared-memory pointer.
///
/// * `result` - output list; rewritten statements, preceded by any pointer
///   conversions they need, are appended in order.
/// * `new_id` - allocator for fresh ids used as conversion destinations.
/// * `extern_shared_decls` - ids of `.extern .shared` declarations mapped to
///   their declared element type.
/// * `methods_using_extern_shared` - calls to functions in this set get
///   `shared_id_param` appended as an extra argument.
/// * `shared_id_param` - id of the current method's shared-memory argument.
/// * `shared_var_id` - id of the local holding that pointer, typed as a
///   pointer to `.b8` in the shared state space.
/// * `statements` - the method body to rewrite.
fn replace_uses_of_shared_memory<'a>(
    result: &mut Vec<ExpandedStatement>,
    new_id: &mut impl FnMut() -> spirv::Word,
    extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
    methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
    shared_id_param: spirv::Word,
    shared_var_id: spirv::Word,
    statements: Vec<ExpandedStatement>,
) {
    for statement in statements {
        match statement {
            Statement::Call(mut call) => {
                // We can safely skip checking call arguments,
                // because there's simply no way to pass shared ptr
                // without converting it to .b64 first
                if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) {
                    call.param_list
                        .push((shared_id_param, ast::FnArgumentType::Shared));
                }
                result.push(Statement::Call(call));
            }
            statement => {
                let new_statement = statement.map_id(&mut |id| {
                    let typ = match extern_shared_decls.get(&id) {
                        Some(typ) => *typ,
                        // Not an extern shared declaration: leave untouched.
                        None => return id,
                    };
                    // The shared variable already is a pointer to .b8, so no
                    // conversion is needed. Reuse it directly instead of
                    // handing out a fresh id that nothing would ever define.
                    if typ == ast::SizedScalarType::B8 {
                        return shared_var_id;
                    }
                    // Otherwise cast the .b8 pointer to the declared element
                    // type; the cast is emitted before the statement using it.
                    let replacement_id = new_id();
                    result.push(Statement::Conversion(ImplicitConversion {
                        src: shared_var_id,
                        dst: replacement_id,
                        from: ast::Type::Pointer(
                            ast::PointerType::Scalar(ast::ScalarType::B8),
                            ast::LdStateSpace::Shared,
                        ),
                        to: ast::Type::Pointer(
                            ast::PointerType::Scalar(typ.into()),
                            ast::LdStateSpace::Shared,
                        ),
                        kind: ConversionKind::PtrToPtr { spirv_ptr: true },
                    }));
                    replacement_id
                });
                result.push(new_statement);
            }
        }
    }
}
fn get_callers_of_extern_shared<'a>( fn get_callers_of_extern_shared<'a>(
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>, methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>, directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
@ -670,15 +764,26 @@ fn emit_function_header<'a>(
map: &mut TypeWordMap, map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>, global: &GlobalStringIdResolver<'a>,
func_directive: ast::MethodDecl<spirv::Word>, func_directive: ast::MethodDecl<spirv::Word>,
all_args_lens: &mut HashMap<String, Vec<usize>>, kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel(name, args) = &func_directive { if let ast::MethodDecl::Kernel {
let args_lens = args.iter().map(|param| param.v_type.width()).collect(); name,
all_args_lens.insert(name.to_string(), args_lens); in_args,
uses_shared_mem,
} = &func_directive
{
let args_lens = in_args.iter().map(|param| param.v_type.width()).collect();
kernel_info.insert(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
uses_shared_mem: *uses_shared_mem,
},
);
} }
let (ret_type, func_type) = get_function_type(builder, map, &func_directive); let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
let fn_id = match func_directive { let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => { ast::MethodDecl::Kernel { name, .. } => {
let fn_id = global.get_id(name)?; let fn_id = global.get_id(name)?;
let mut global_variables = global let mut global_variables = global
.variables_type_check .variables_type_check
@ -718,8 +823,15 @@ fn emit_function_header<'a>(
pub fn to_spirv<'a>( pub fn to_spirv<'a>(
ast: ast::Module<'a>, ast: ast::Module<'a>,
) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> { ) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let (module, all_args_lens) = to_spirv_module(ast)?; let module = to_spirv_module(ast)?;
Ok((module.assemble(), all_args_lens)) Ok((
module.spirv.assemble(),
module
.kernel_info
.into_iter()
.map(|(k, v)| (k, v.arguments_sizes))
.collect(),
))
} }
fn emit_capabilities(builder: &mut dr::Builder) { fn emit_capabilities(builder: &mut dr::Builder) {
@ -843,8 +955,7 @@ fn to_ssa<'input, 'b>(
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut(); let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let sorted_statements = normalize_variable_decls(labeled_statements); let (f_body, globals) = extract_globals(labeled_statements);
let (f_body, globals) = extract_globals(sorted_statements);
Ok(Function { Ok(Function {
func_decl: f_args, func_decl: f_args,
globals: globals, globals: globals,
@ -859,12 +970,20 @@ fn extract_globals(
(sorted_statements, Vec::new()) (sorted_statements, Vec::new())
} }
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> { fn normalize_variable_decls(directives: &mut Vec<Directive>) {
func[1..].sort_by_key(|s| match s { for directive in directives {
Statement::Variable(_) => 0, match directive {
_ => 1, Directive::Method(Function {
}); body: Some(func), ..
func }) => {
func[1..].sort_by_key(|s| match s {
Statement::Variable(_) => 0,
_ => 1,
});
}
_ => (),
}
}
} }
fn convert_to_typed_statements( fn convert_to_typed_statements(
@ -1138,8 +1257,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> { ) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> {
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
let out_param = match &mut f_args { let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => { ast::MethodDecl::Kernel { in_args, .. } => {
for p in in_params.iter_mut() { for p in in_args.iter_mut() {
let typ = ast::Type::from(p.v_type.clone()); let typ = ast::Type::from(p.v_type.clone());
let new_id = id_def.new_id(typ.clone()); let new_id = id_def.new_id(typ.clone());
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
@ -1736,7 +1855,7 @@ fn insert_implicit_conversions_impl(
conversion_fn = bitcast_physical_pointer; conversion_fn = bitcast_physical_pointer;
} }
ArgumentSemantics::RegisterPointer => { ArgumentSemantics::RegisterPointer => {
conversion_fn = force_bitcast; conversion_fn = bitcast_logical_pointer;
} }
ArgumentSemantics::Address => { ArgumentSemantics::Address => {
conversion_fn = force_bitcast_ptr_to_bit; conversion_fn = force_bitcast_ptr_to_bit;
@ -1790,10 +1909,10 @@ fn get_function_type(
.iter() .iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
), ),
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( ast::MethodDecl::Kernel { in_args, .. } => map.get_or_add_fn(
builder, builder,
iter::empty(), iter::empty(),
params in_args
.iter() .iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))), .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
), ),
@ -1886,14 +2005,19 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak { if data.qualifier != ast::LdStQualifier::Weak {
todo!() todo!()
} }
let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone())); let result_type =
map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
match data.state_space { match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { ast::LdStateSpace::Generic
| ast::LdStateSpace::Global
| ast::LdStateSpace::Shared => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?; builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
} }
ast::LdStateSpace::Param | ast::LdStateSpace::Local => { ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
let result_type = let result_type = map.get_or_add(
map.get_or_add(builder, SpirvType::from(data.typ.clone())); builder,
SpirvType::from(ast::Type::from(data.typ.clone())),
);
builder.copy_object(result_type, Some(arg.dst), arg.src)?; builder.copy_object(result_type, Some(arg.dst), arg.src)?;
} }
_ => todo!(), _ => todo!(),
@ -1906,11 +2030,14 @@ fn emit_function_body_ops(
if data.state_space == ast::StStateSpace::Param if data.state_space == ast::StStateSpace::Param
|| data.state_space == ast::StStateSpace::Local || data.state_space == ast::StStateSpace::Local
{ {
let result_type = let result_type = map.get_or_add(
map.get_or_add(builder, SpirvType::from(data.typ.clone())); builder,
SpirvType::from(ast::Type::from(data.typ.clone())),
);
builder.copy_object(result_type, Some(arg.src1), arg.src2)?; builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
} else if data.state_space == ast::StStateSpace::Generic } else if data.state_space == ast::StStateSpace::Generic
|| data.state_space == ast::StStateSpace::Global || data.state_space == ast::StStateSpace::Global
|| data.state_space == ast::StStateSpace::Shared
{ {
builder.store(arg.src1, arg.src2, None, &[])?; builder.store(arg.src1, arg.src2, None, &[])?;
} else { } else {
@ -2642,10 +2769,7 @@ fn emit_implicit_conversion(
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
} }
(_, _, ConversionKind::BitToPtr(space)) => { (_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add( let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
} }
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
@ -2703,6 +2827,20 @@ fn emit_implicit_conversion(
let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?; builder.bitcast(into_type, Some(cv.dst), cv.src)?;
} }
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
let result_type = if spirv_ptr {
map.get_or_add(
builder,
SpirvType::Pointer(
Box::new(SpirvType::from(cv.to.clone())),
spirv::StorageClass::Function,
),
)
} else {
map.get_or_add(builder, SpirvType::from(cv.to.clone()))
};
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(), _ => unreachable!(),
} }
Ok(()) Ok(())
@ -2903,9 +3041,15 @@ impl<'a> GlobalStringIdResolver<'a> {
type_check: HashMap::new(), type_check: HashMap::new(),
}; };
let new_fn_decl = match header { let new_fn_decl = match header {
ast::MethodDecl::Kernel(name, params) => { ast::MethodDecl::Kernel {
ast::MethodDecl::Kernel(name, expand_kernel_params(&mut fn_resolver, params.iter())) name,
} in_args,
uses_shared_mem,
} => ast::MethodDecl::Kernel {
name,
in_args: expand_kernel_params(&mut fn_resolver, in_args.iter()),
uses_shared_mem: *uses_shared_mem,
},
ast::MethodDecl::Func(ret_params, _, params) => { ast::MethodDecl::Func(ret_params, _, params) => {
let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter()); let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter());
let params_ids = expand_fn_params(&mut fn_resolver, params.iter()); let params_ids = expand_fn_params(&mut fn_resolver, params.iter());
@ -3598,7 +3742,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Ld(d, a) => { ast::Instruction::Ld(d, a) => {
let is_param = d.state_space == ast::LdStateSpace::Param let is_param = d.state_space == ast::LdStateSpace::Param
|| d.state_space == ast::LdStateSpace::Local; || d.state_space == ast::LdStateSpace::Local;
let new_args = a.map(visitor, &d.typ, is_param)?; let new_args = a.map(visitor, &d, is_param)?;
ast::Instruction::Ld(d, new_args) ast::Instruction::Ld(d, new_args)
} }
ast::Instruction::Mov(d, a) => { ast::Instruction::Mov(d, a) => {
@ -3655,7 +3799,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::St(d, a) => { ast::Instruction::St(d, a) => {
let is_param = d.state_space == ast::StStateSpace::Param let is_param = d.state_space == ast::StStateSpace::Param
|| d.state_space == ast::StStateSpace::Local; || d.state_space == ast::StStateSpace::Local;
let new_args = a.map(visitor, &d.typ, is_param)?; let new_args = a.map(visitor, &d, is_param)?;
ast::Instruction::St(d, new_args) ast::Instruction::St(d, new_args)
} }
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
@ -3826,29 +3970,36 @@ impl ast::Type {
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: Vec::new(), components: Vec::new(),
state_space: ast::PointerStateSpace::Global, state_space: ast::LdStateSpace::Global,
}, },
ast::Type::Vector(scalar, components) => TypeParts { ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector, kind: TypeKind::Vector,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: vec![*components as u32], components: vec![*components as u32],
state_space: ast::PointerStateSpace::Global, state_space: ast::LdStateSpace::Global,
}, },
ast::Type::Array(scalar, components) => TypeParts { ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array, kind: TypeKind::Array,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: components.clone(), components: components.clone(),
state_space: ast::PointerStateSpace::Global, state_space: ast::LdStateSpace::Global,
}, },
ast::Type::Pointer(scalar, state_space) => TypeParts { ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
kind: TypeKind::Pointer, kind: TypeKind::PointerScalar,
scalar_kind: scalar.kind(), scalar_kind: scalar.kind(),
width: scalar.size_of(), width: scalar.size_of(),
components: Vec::new(), components: Vec::new(),
state_space: *state_space, state_space: *state_space,
}, },
ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
kind: TypeKind::PointerVector,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*len as u32],
state_space: *state_space,
},
} }
} }
@ -3865,8 +4016,15 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind), ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components, t.components,
), ),
TypeKind::Pointer => ast::Type::Pointer( TypeKind::PointerScalar => ast::Type::Pointer(
ast::ScalarType::from_parts(t.width, t.scalar_kind), ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
t.state_space,
),
TypeKind::PointerVector => ast::Type::Pointer(
ast::PointerType::Vector(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components[0] as u8,
),
t.state_space, t.state_space,
), ),
} }
@ -3879,7 +4037,7 @@ struct TypeParts {
scalar_kind: ScalarKind, scalar_kind: ScalarKind,
width: u8, width: u8,
components: Vec<u32>, components: Vec<u32>,
state_space: ast::PointerStateSpace, state_space: ast::LdStateSpace,
} }
#[derive(Eq, PartialEq, Copy, Clone)] #[derive(Eq, PartialEq, Copy, Clone)]
@ -3887,7 +4045,8 @@ enum TypeKind {
Scalar, Scalar,
Vector, Vector,
Array, Array,
Pointer, PointerScalar,
PointerVector,
} }
impl ast::Instruction<ExpandedArgParams> { impl ast::Instruction<ExpandedArgParams> {
@ -4007,6 +4166,7 @@ enum ConversionKind {
SignExtend, SignExtend,
BitToPtr(ast::LdStateSpace), BitToPtr(ast::LdStateSpace),
PtrToBit, PtrToBit,
PtrToPtr { spirv_ptr: bool },
} }
impl<T> ast::PredAt<T> { impl<T> ast::PredAt<T> {
@ -4058,7 +4218,7 @@ impl ast::VariableParamType {
(ast::ScalarType::from(*t).size_of() as usize) (ast::ScalarType::from(*t).size_of() as usize)
* (len.iter().fold(1, |x, y| x * (*y)) as usize) * (len.iter().fold(1, |x, y| x * (*y)) as usize)
} }
ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>() ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>(),
} }
} }
} }
@ -4076,7 +4236,10 @@ impl From<ast::KernelArgumentType> for ast::Type {
fn from(this: ast::KernelArgumentType) -> Self { fn from(this: ast::KernelArgumentType) -> Self {
match this { match this {
ast::KernelArgumentType::Normal(typ) => typ.into(), ast::KernelArgumentType::Normal(typ) => typ.into(),
ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64), ast::KernelArgumentType::Shared => ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
),
} }
} }
} }
@ -4085,9 +4248,10 @@ impl ast::KernelArgumentType {
fn to_param(self) -> ast::VariableParamType { fn to_param(self) -> ast::VariableParamType {
match self { match self {
ast::KernelArgumentType::Normal(p) => p, ast::KernelArgumentType::Normal(p) => p,
ast::KernelArgumentType::Shared => { ast::KernelArgumentType::Shared => ast::VariableParamType::Pointer(
ast::VariableParamType::Scalar(ast::ParamScalarType::B64) ast::SizedScalarType::B8,
} ast::PointerStateSpace::Shared,
),
} }
} }
} }
@ -4193,7 +4357,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: &ast::Type, details: &ast::LdDetails,
is_param: bool, is_param: bool,
) -> Result<ast::Arg2Ld<U>, TranslateError> { ) -> Result<ast::Arg2Ld<U>, TranslateError> {
let dst = visitor.id_or_vector( let dst = visitor.id_or_vector(
@ -4202,7 +4366,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
is_dst: true, is_dst: true,
sema: ArgumentSemantics::DefaultRelaxed, sema: ArgumentSemantics::DefaultRelaxed,
}, },
&ast::Type::from(t.clone()), &ast::Type::from(details.typ.clone()),
)?; )?;
let src = visitor.operand( let src = visitor.operand(
ArgumentDescriptor { ArgumentDescriptor {
@ -4214,7 +4378,14 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentSemantics::PhysicalPointer ArgumentSemantics::PhysicalPointer
}, },
}, },
t, &(if is_param {
ast::Type::from(details.typ.clone())
} else {
ast::Type::Pointer(
ast::PointerType::from(details.typ.clone()),
details.state_space,
)
}),
)?; )?;
Ok(ast::Arg2Ld { dst, src }) Ok(ast::Arg2Ld { dst, src })
} }
@ -4233,7 +4404,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>( fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self, self,
visitor: &mut V, visitor: &mut V,
t: &ast::Type, details: &ast::StData,
is_param: bool, is_param: bool,
) -> Result<ast::Arg2St<U>, TranslateError> { ) -> Result<ast::Arg2St<U>, TranslateError> {
let src1 = visitor.operand( let src1 = visitor.operand(
@ -4246,7 +4417,14 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
ArgumentSemantics::PhysicalPointer ArgumentSemantics::PhysicalPointer
}, },
}, },
t, &(if is_param {
details.typ.clone().into()
} else {
ast::Type::Pointer(
ast::PointerType::from(details.typ.clone()),
details.state_space.to_ld_ss(),
)
}),
)?; )?;
let src2 = visitor.operand_or_vector( let src2 = visitor.operand_or_vector(
ArgumentDescriptor { ArgumentDescriptor {
@ -4254,7 +4432,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
is_dst: false, is_dst: false,
sema: ArgumentSemantics::DefaultRelaxed, sema: ArgumentSemantics::DefaultRelaxed,
}, },
t, &details.typ.clone().into(),
)?; )?;
Ok(ast::Arg2St { src1, src2 }) Ok(ast::Arg2St { src1, src2 })
} }
@ -4957,7 +5135,7 @@ impl ast::MulDetails {
} }
} }
fn force_bitcast( fn bitcast_logical_pointer(
operand: &ast::Type, operand: &ast::Type,
instr: &ast::Type, instr: &ast::Type,
_: Option<ast::LdStateSpace>, _: Option<ast::LdStateSpace>,
@ -4971,21 +5149,12 @@ fn force_bitcast(
fn bitcast_physical_pointer( fn bitcast_physical_pointer(
operand_type: &ast::Type, operand_type: &ast::Type,
_: &ast::Type, instr_type: &ast::Type,
ss: Option<ast::LdStateSpace>, ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> { ) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type { match operand_type {
// array decays to a pointer // array decays to a pointer
ast::Type::Array(_, vec) => { ast::Type::Array(_, _) => todo!(),
if vec.len() != 0 {
return Err(TranslateError::MismatchedType);
}
if let Some(space) = ss {
Ok(Some(ConversionKind::BitToPtr(space)))
} else {
Err(TranslateError::Unreachable)
}
}
ast::Type::Scalar(ast::ScalarType::B64) ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => { | ast::Type::Scalar(ast::ScalarType::S64) => {
@ -4995,6 +5164,27 @@ fn bitcast_physical_pointer(
Err(TranslateError::Unreachable) Err(TranslateError::Unreachable)
} }
} }
ast::Type::Pointer(op_scalar_t, op_space) => {
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
if op_space == instr_space {
if op_scalar_t == instr_scalar_t {
Ok(None)
} else {
Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
}
} else {
if *op_space == ast::LdStateSpace::Generic
|| *instr_space == ast::LdStateSpace::Generic
{
Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
} else {
Err(TranslateError::MismatchedType)
}
}
} else {
Err(TranslateError::MismatchedType)
}
}
_ => Err(TranslateError::MismatchedType), _ => Err(TranslateError::MismatchedType),
} }
} }
@ -5206,7 +5396,7 @@ fn should_convert_relaxed_dst(
impl<'a> ast::MethodDecl<'a, &'a str> { impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str { fn name(&self) -> &'a str {
match self { match self {
ast::MethodDecl::Kernel(name, _) => name, ast::MethodDecl::Kernel { name, .. } => name,
ast::MethodDecl::Func(_, name, _) => name, ast::MethodDecl::Func(_, name, _) => name,
} }
} }
@ -5216,7 +5406,7 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> {
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) { fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) {
match self { match self {
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f), ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| { ast::MethodDecl::Kernel { in_args, .. } => in_args.iter().for_each(|arg| {
f(&ast::FnArgument { f(&ast::FnArgument {
align: arg.align, align: arg.align,
name: arg.name, name: arg.name,