mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-06 00:00:13 +00:00
Add dynamic shared mem support
This commit is contained in:
parent
28a0968294
commit
85ee8210df
6 changed files with 491 additions and 221 deletions
|
@ -5,8 +5,8 @@ members = [
|
||||||
"level_zero",
|
"level_zero",
|
||||||
"spirv_tools-sys",
|
"spirv_tools-sys",
|
||||||
"notcuda",
|
"notcuda",
|
||||||
"notcuda_inject",
|
#"notcuda_inject",
|
||||||
"notcuda_redirect",
|
#"notcuda_redirect",
|
||||||
"ptx",
|
"ptx",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
127
ptx/src/ast.rs
127
ptx/src/ast.rs
|
@ -28,8 +28,11 @@ quick_error! {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! sub_scalar_type {
|
macro_rules! sub_enum {
|
||||||
($name:ident { $($variant:ident),+ $(,)? }) => {
|
($name:ident { $($variant:ident),+ $(,)? }) => {
|
||||||
|
sub_enum!{ $name : ScalarType { $($variant),+ } }
|
||||||
|
};
|
||||||
|
($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
|
||||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||||
pub enum $name {
|
pub enum $name {
|
||||||
$(
|
$(
|
||||||
|
@ -37,23 +40,23 @@ macro_rules! sub_scalar_type {
|
||||||
)+
|
)+
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<$name> for ScalarType {
|
impl From<$name> for $base_type {
|
||||||
fn from(t: $name) -> ScalarType {
|
fn from(t: $name) -> $base_type {
|
||||||
match t {
|
match t {
|
||||||
$(
|
$(
|
||||||
$name::$variant => ScalarType::$variant,
|
$name::$variant => $base_type::$variant,
|
||||||
)+
|
)+
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::convert::TryFrom<ScalarType> for $name {
|
impl std::convert::TryFrom<$base_type> for $name {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn try_from(t: ScalarType) -> Result<Self, Self::Error> {
|
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
|
||||||
match t {
|
match t {
|
||||||
$(
|
$(
|
||||||
ScalarType::$variant => Ok($name::$variant),
|
$base_type::$variant => Ok($name::$variant),
|
||||||
)+
|
)+
|
||||||
_ => Err(()),
|
_ => Err(()),
|
||||||
}
|
}
|
||||||
|
@ -64,6 +67,13 @@ macro_rules! sub_scalar_type {
|
||||||
|
|
||||||
macro_rules! sub_type {
|
macro_rules! sub_type {
|
||||||
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||||
|
sub_type! { $type_name : Type {
|
||||||
|
$(
|
||||||
|
$variant ($($field_type),+),
|
||||||
|
)+
|
||||||
|
}}
|
||||||
|
};
|
||||||
|
($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||||
#[derive(PartialEq, Eq, Clone)]
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub enum $type_name {
|
pub enum $type_name {
|
||||||
$(
|
$(
|
||||||
|
@ -71,26 +81,26 @@ macro_rules! sub_type {
|
||||||
)+
|
)+
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<$type_name> for Type {
|
impl From<$type_name> for $base_type {
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
fn from(t: $type_name) -> Type {
|
fn from(t: $type_name) -> $base_type {
|
||||||
match t {
|
match t {
|
||||||
$(
|
$(
|
||||||
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
|
$type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
|
||||||
)+
|
)+
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::convert::TryFrom<Type> for $type_name {
|
impl std::convert::TryFrom<$base_type> for $type_name {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
#[allow(unreachable_patterns)]
|
#[allow(unreachable_patterns)]
|
||||||
fn try_from(t: Type) -> Result<Self, Self::Error> {
|
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
|
||||||
match t {
|
match t {
|
||||||
$(
|
$(
|
||||||
Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
|
$base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
|
||||||
)+
|
)+
|
||||||
_ => Err(()),
|
_ => Err(()),
|
||||||
}
|
}
|
||||||
|
@ -99,10 +109,12 @@ macro_rules! sub_type {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pointer is used when doing SLM conversion to SPIRV
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableRegType {
|
VariableRegType {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(SizedScalarType, u8),
|
Vector(SizedScalarType, u8),
|
||||||
|
Pointer(SizedScalarType, PointerStateSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,13 +158,13 @@ sub_type! {
|
||||||
// .param .b32 foobar[]
|
// .param .b32 foobar[]
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableParamType {
|
VariableParamType {
|
||||||
Scalar(ParamScalarType),
|
Scalar(LdStScalarType),
|
||||||
Array(SizedScalarType, VecU32),
|
Array(SizedScalarType, VecU32),
|
||||||
Pointer(SizedScalarType, PointerStateSpace),
|
Pointer(SizedScalarType, PointerStateSpace),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_scalar_type!(SizedScalarType {
|
sub_enum!(SizedScalarType {
|
||||||
B8,
|
B8,
|
||||||
B16,
|
B16,
|
||||||
B32,
|
B32,
|
||||||
|
@ -171,7 +183,7 @@ sub_scalar_type!(SizedScalarType {
|
||||||
F64,
|
F64,
|
||||||
});
|
});
|
||||||
|
|
||||||
sub_scalar_type!(ParamScalarType {
|
sub_enum!(LdStScalarType {
|
||||||
B8,
|
B8,
|
||||||
B16,
|
B16,
|
||||||
B32,
|
B32,
|
||||||
|
@ -232,7 +244,11 @@ pub enum Directive<'a, P: ArgParams> {
|
||||||
|
|
||||||
pub enum MethodDecl<'a, ID> {
|
pub enum MethodDecl<'a, ID> {
|
||||||
Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
|
Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
|
||||||
Kernel(&'a str, Vec<KernelArgument<ID>>),
|
Kernel {
|
||||||
|
name: &'a str,
|
||||||
|
in_args: Vec<KernelArgument<ID>>,
|
||||||
|
uses_shared_mem: bool,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
|
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
|
||||||
|
@ -262,25 +278,52 @@ impl From<FnArgumentType> for Type {
|
||||||
match t {
|
match t {
|
||||||
FnArgumentType::Reg(x) => x.into(),
|
FnArgumentType::Reg(x) => x.into(),
|
||||||
FnArgumentType::Param(x) => x.into(),
|
FnArgumentType::Param(x) => x.into(),
|
||||||
FnArgumentType::Shared => Type::Scalar(ScalarType::B64),
|
FnArgumentType::Shared => {
|
||||||
|
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
sub_enum!(
|
||||||
pub enum PointerStateSpace {
|
PointerStateSpace : LdStateSpace {
|
||||||
Global,
|
Global,
|
||||||
Const,
|
Const,
|
||||||
Shared,
|
Shared,
|
||||||
Param,
|
Param,
|
||||||
}
|
}
|
||||||
|
);
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone)]
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(ScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
Array(ScalarType, Vec<u32>),
|
Array(ScalarType, Vec<u32>),
|
||||||
Pointer(ScalarType, PointerStateSpace),
|
Pointer(PointerType, LdStateSpace),
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_type! {
|
||||||
|
PointerType {
|
||||||
|
Scalar(ScalarType),
|
||||||
|
Vector(ScalarType, u8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SizedScalarType> for PointerType {
|
||||||
|
fn from(t: SizedScalarType) -> Self {
|
||||||
|
PointerType::Scalar(t.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<PointerType> for SizedScalarType {
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
fn try_from(value: PointerType) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
PointerType::Scalar(t) => Ok(t.try_into()?),
|
||||||
|
PointerType::Vector(_, _) => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||||
|
@ -304,7 +347,7 @@ pub enum ScalarType {
|
||||||
Pred,
|
Pred,
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_scalar_type!(IntType {
|
sub_enum!(IntType {
|
||||||
U8,
|
U8,
|
||||||
U16,
|
U16,
|
||||||
U32,
|
U32,
|
||||||
|
@ -315,9 +358,9 @@ sub_scalar_type!(IntType {
|
||||||
S64
|
S64
|
||||||
});
|
});
|
||||||
|
|
||||||
sub_scalar_type!(UIntType { U8, U16, U32, U64 });
|
sub_enum!(UIntType { U8, U16, U32, U64 });
|
||||||
|
|
||||||
sub_scalar_type!(SIntType { S8, S16, S32, S64 });
|
sub_enum!(SIntType { S8, S16, S32, S64 });
|
||||||
|
|
||||||
impl IntType {
|
impl IntType {
|
||||||
pub fn is_signed(self) -> bool {
|
pub fn is_signed(self) -> bool {
|
||||||
|
@ -341,7 +384,7 @@ impl IntType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_scalar_type!(FloatType {
|
sub_enum!(FloatType {
|
||||||
F16,
|
F16,
|
||||||
F16x2,
|
F16x2,
|
||||||
F32,
|
F32,
|
||||||
|
@ -615,7 +658,23 @@ pub struct LdDetails {
|
||||||
pub qualifier: LdStQualifier,
|
pub qualifier: LdStQualifier,
|
||||||
pub state_space: LdStateSpace,
|
pub state_space: LdStateSpace,
|
||||||
pub caching: LdCacheOperator,
|
pub caching: LdCacheOperator,
|
||||||
pub typ: Type,
|
pub typ: LdStType,
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_type! {
|
||||||
|
LdStType {
|
||||||
|
Scalar(LdStScalarType),
|
||||||
|
Vector(LdStScalarType, u8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<LdStType> for PointerType {
|
||||||
|
fn from(t: LdStType) -> Self {
|
||||||
|
match t {
|
||||||
|
LdStType::Scalar(t) => PointerType::Scalar(t.into()),
|
||||||
|
LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||||
|
@ -860,7 +919,7 @@ pub enum ShlType {
|
||||||
B64,
|
B64,
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_scalar_type!(ShrType {
|
sub_enum!(ShrType {
|
||||||
B16,
|
B16,
|
||||||
B32,
|
B32,
|
||||||
B64,
|
B64,
|
||||||
|
@ -876,7 +935,7 @@ pub struct StData {
|
||||||
pub qualifier: LdStQualifier,
|
pub qualifier: LdStQualifier,
|
||||||
pub state_space: StStateSpace,
|
pub state_space: StStateSpace,
|
||||||
pub caching: StCacheOperator,
|
pub caching: StCacheOperator,
|
||||||
pub typ: Type,
|
pub typ: LdStType,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||||
|
@ -900,7 +959,7 @@ pub struct RetData {
|
||||||
pub uniform: bool,
|
pub uniform: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_scalar_type!(OrType {
|
sub_enum!(OrType {
|
||||||
Pred,
|
Pred,
|
||||||
B16,
|
B16,
|
||||||
B32,
|
B32,
|
||||||
|
|
|
@ -237,7 +237,8 @@ LinkingDirectives: ast::LinkingDirective = {
|
||||||
}
|
}
|
||||||
|
|
||||||
MethodDecl: ast::MethodDecl<'input, &'input str> = {
|
MethodDecl: ast::MethodDecl<'input, &'input str> = {
|
||||||
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
".entry" <name:ExtendedID> <in_args:KernelArguments> =>
|
||||||
|
ast::MethodDecl::Kernel{ name, in_args, uses_shared_mem: false },
|
||||||
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
||||||
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
||||||
}
|
}
|
||||||
|
@ -294,10 +295,6 @@ ScalarType: ast::ScalarType = {
|
||||||
".f16" => ast::ScalarType::F16,
|
".f16" => ast::ScalarType::F16,
|
||||||
".f16x2" => ast::ScalarType::F16x2,
|
".f16x2" => ast::ScalarType::F16x2,
|
||||||
".pred" => ast::ScalarType::Pred,
|
".pred" => ast::ScalarType::Pred,
|
||||||
LdStScalarType
|
|
||||||
};
|
|
||||||
|
|
||||||
LdStScalarType: ast::ScalarType = {
|
|
||||||
".b8" => ast::ScalarType::B8,
|
".b8" => ast::ScalarType::B8,
|
||||||
".b16" => ast::ScalarType::B16,
|
".b16" => ast::ScalarType::B16,
|
||||||
".b32" => ast::ScalarType::B32,
|
".b32" => ast::ScalarType::B32,
|
||||||
|
@ -442,7 +439,7 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||||
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
||||||
".param" <var:VariableScalar<ParamScalarType>> => {
|
".param" <var:VariableScalar<LdStScalarType>> => {
|
||||||
let (align, t, name) = var;
|
let (align, t, name) = var;
|
||||||
let v_type = ast::VariableParamType::Scalar(t);
|
let v_type = ast::VariableParamType::Scalar(t);
|
||||||
(align, Vec::new(), v_type, name)
|
(align, Vec::new(), v_type, name)
|
||||||
|
@ -506,22 +503,22 @@ SizedScalarType: ast::SizedScalarType = {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
ParamScalarType: ast::ParamScalarType = {
|
LdStScalarType: ast::LdStScalarType = {
|
||||||
".b8" => ast::ParamScalarType::B8,
|
".b8" => ast::LdStScalarType::B8,
|
||||||
".b16" => ast::ParamScalarType::B16,
|
".b16" => ast::LdStScalarType::B16,
|
||||||
".b32" => ast::ParamScalarType::B32,
|
".b32" => ast::LdStScalarType::B32,
|
||||||
".b64" => ast::ParamScalarType::B64,
|
".b64" => ast::LdStScalarType::B64,
|
||||||
".u8" => ast::ParamScalarType::U8,
|
".u8" => ast::LdStScalarType::U8,
|
||||||
".u16" => ast::ParamScalarType::U16,
|
".u16" => ast::LdStScalarType::U16,
|
||||||
".u32" => ast::ParamScalarType::U32,
|
".u32" => ast::LdStScalarType::U32,
|
||||||
".u64" => ast::ParamScalarType::U64,
|
".u64" => ast::LdStScalarType::U64,
|
||||||
".s8" => ast::ParamScalarType::S8,
|
".s8" => ast::LdStScalarType::S8,
|
||||||
".s16" => ast::ParamScalarType::S16,
|
".s16" => ast::LdStScalarType::S16,
|
||||||
".s32" => ast::ParamScalarType::S32,
|
".s32" => ast::LdStScalarType::S32,
|
||||||
".s64" => ast::ParamScalarType::S64,
|
".s64" => ast::LdStScalarType::S64,
|
||||||
".f16" => ast::ParamScalarType::F16,
|
".f16" => ast::LdStScalarType::F16,
|
||||||
".f32" => ast::ParamScalarType::F32,
|
".f32" => ast::LdStScalarType::F32,
|
||||||
".f64" => ast::ParamScalarType::F64,
|
".f64" => ast::LdStScalarType::F64,
|
||||||
}
|
}
|
||||||
|
|
||||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
|
@ -572,9 +569,9 @@ OperandOrVector: ast::OperandOrVector<&'input str> = {
|
||||||
<dst:VectorExtract> => ast::OperandOrVector::Vec(dst)
|
<dst:VectorExtract> => ast::OperandOrVector::Vec(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
LdStType: ast::Type = {
|
LdStType: ast::LdStType = {
|
||||||
<v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v),
|
<v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v),
|
||||||
<t:LdStScalarType> => ast::Type::Scalar(t),
|
<t:LdStScalarType> => ast::LdStType::Scalar(t),
|
||||||
}
|
}
|
||||||
|
|
||||||
LdStQualifier: ast::LdStQualifier = {
|
LdStQualifier: ast::LdStQualifier = {
|
||||||
|
|
|
@ -2,52 +2,67 @@
|
||||||
OpCapability Linkage
|
OpCapability Linkage
|
||||||
OpCapability Addresses
|
OpCapability Addresses
|
||||||
OpCapability Kernel
|
OpCapability Kernel
|
||||||
OpCapability Int64
|
|
||||||
OpCapability Int8
|
OpCapability Int8
|
||||||
%29 = OpExtInstImport "OpenCL.std"
|
OpCapability Int16
|
||||||
|
OpCapability Int64
|
||||||
|
OpCapability Float16
|
||||||
|
OpCapability Float64
|
||||||
|
%32 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "cvta"
|
OpEntryPoint Kernel %2 "extern_shared" %1
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
|
%uint = OpTypeInt 32 0
|
||||||
|
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
|
||||||
|
%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint
|
||||||
|
%1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%32 = OpTypeFunction %void %ulong %ulong
|
%uchar = OpTypeInt 8 0
|
||||||
|
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
|
||||||
|
%40 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
|
||||||
|
%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
|
||||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
%float = OpTypeFloat 32
|
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
|
||||||
%_ptr_Function_float = OpTypePointer Function %float
|
%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint
|
||||||
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
|
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
|
||||||
%1 = OpFunction %void None %32
|
%2 = OpFunction %void None %40
|
||||||
%7 = OpFunctionParameter %ulong
|
|
||||||
%8 = OpFunctionParameter %ulong
|
%8 = OpFunctionParameter %ulong
|
||||||
%27 = OpLabel
|
%9 = OpFunctionParameter %ulong
|
||||||
%2 = OpVariable %_ptr_Function_ulong Function
|
%28 = OpFunctionParameter %_ptr_Workgroup_uchar
|
||||||
|
%41 = OpLabel
|
||||||
|
%29 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
|
||||||
%3 = OpVariable %_ptr_Function_ulong Function
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
%4 = OpVariable %_ptr_Function_ulong Function
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
%5 = OpVariable %_ptr_Function_ulong Function
|
%5 = OpVariable %_ptr_Function_ulong Function
|
||||||
%6 = OpVariable %_ptr_Function_float Function
|
%6 = OpVariable %_ptr_Function_ulong Function
|
||||||
OpStore %2 %7
|
%7 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %29 %28
|
||||||
|
OpBranch %26
|
||||||
|
%26 = OpLabel
|
||||||
OpStore %3 %8
|
OpStore %3 %8
|
||||||
%10 = OpLoad %ulong %2
|
|
||||||
%9 = OpCopyObject %ulong %10
|
|
||||||
OpStore %4 %9
|
OpStore %4 %9
|
||||||
%12 = OpLoad %ulong %3
|
%11 = OpLoad %ulong %3
|
||||||
%11 = OpCopyObject %ulong %12
|
%10 = OpCopyObject %ulong %11
|
||||||
OpStore %5 %11
|
OpStore %5 %10
|
||||||
%14 = OpLoad %ulong %4
|
%13 = OpLoad %ulong %4
|
||||||
%22 = OpCopyObject %ulong %14
|
%12 = OpCopyObject %ulong %13
|
||||||
%21 = OpCopyObject %ulong %22
|
OpStore %6 %12
|
||||||
%13 = OpCopyObject %ulong %21
|
%15 = OpLoad %ulong %5
|
||||||
OpStore %4 %13
|
%22 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %15
|
||||||
%16 = OpLoad %ulong %5
|
%14 = OpLoad %ulong %22
|
||||||
%24 = OpCopyObject %ulong %16
|
OpStore %7 %14
|
||||||
%23 = OpCopyObject %ulong %24
|
%30 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29
|
||||||
%15 = OpCopyObject %ulong %23
|
%16 = OpLoad %_ptr_Workgroup_uint %30
|
||||||
OpStore %5 %15
|
%17 = OpLoad %ulong %7
|
||||||
%18 = OpLoad %ulong %4
|
%23 = OpBitcast %_ptr_Workgroup_ulong %16
|
||||||
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
|
OpStore %23 %17
|
||||||
%17 = OpLoad %float %25
|
%31 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %29
|
||||||
OpStore %6 %17
|
%19 = OpLoad %_ptr_Workgroup_uint %31
|
||||||
%19 = OpLoad %ulong %5
|
%24 = OpBitcast %_ptr_Workgroup_ulong %19
|
||||||
%20 = OpLoad %float %6
|
%18 = OpLoad %ulong %24
|
||||||
%26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
|
OpStore %7 %18
|
||||||
OpStore %26 %20
|
%20 = OpLoad %ulong %6
|
||||||
|
%21 = OpLoad %ulong %7
|
||||||
|
%25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20
|
||||||
|
OpStore %25 %21
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
|
|
|
@ -107,27 +107,33 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
|
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
|
||||||
assert!(errors.len() == 0);
|
assert!(errors.len() == 0);
|
||||||
let (spirv, _) = translate::to_spirv(ast)?;
|
let notcuda_module = translate::to_spirv_module(ast)?;
|
||||||
let name = CString::new(name)?;
|
let name = CString::new(name)?;
|
||||||
let result =
|
let result = run_spirv(name.as_c_str(), notcuda_module, input, output)
|
||||||
run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?;
|
.map_err(|err| DisplayError { err })?;
|
||||||
assert_eq!(output, result.as_slice());
|
assert_eq!(output, result.as_slice());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
||||||
name: &CStr,
|
name: &CStr,
|
||||||
spirv: &[u32],
|
module: translate::Module,
|
||||||
input: &[T],
|
input: &[T],
|
||||||
output: &mut [T],
|
output: &mut [T],
|
||||||
) -> ze::Result<Vec<T>> {
|
) -> ze::Result<Vec<T>> {
|
||||||
ze::init()?;
|
ze::init()?;
|
||||||
|
let spirv = module.spirv.assemble();
|
||||||
let byte_il = unsafe {
|
let byte_il = unsafe {
|
||||||
slice::from_raw_parts::<u8>(
|
slice::from_raw_parts::<u8>(
|
||||||
spirv.as_ptr() as *const _,
|
spirv.as_ptr() as *const _,
|
||||||
spirv.len() * mem::size_of::<u32>(),
|
spirv.len() * mem::size_of::<u32>(),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
let use_shared_mem = module
|
||||||
|
.kernel_info
|
||||||
|
.get(name.to_str().unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.uses_shared_mem;
|
||||||
let mut result = vec![0u8.into(); output.len()];
|
let mut result = vec![0u8.into(); output.len()];
|
||||||
{
|
{
|
||||||
let mut drivers = ze::Driver::get()?;
|
let mut drivers = ze::Driver::get()?;
|
||||||
|
@ -140,7 +146,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
||||||
let module = match module {
|
let module = match module {
|
||||||
Ok(m) => m,
|
Ok(m) => m,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let raw_err_string = log.get_cstring()?;
|
let raw_err_string = log.get_cstring()?;
|
||||||
let err_string = raw_err_string.to_string_lossy();
|
let err_string = raw_err_string.to_string_lossy();
|
||||||
panic!("{:?}\n{}", err, err_string);
|
panic!("{:?}\n{}", err, err_string);
|
||||||
}
|
}
|
||||||
|
@ -164,6 +170,9 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
||||||
kernel.set_group_size(1, 1, 1)?;
|
kernel.set_group_size(1, 1, 1)?;
|
||||||
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
|
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
|
||||||
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
|
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
|
||||||
|
if use_shared_mem {
|
||||||
|
unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
|
||||||
|
}
|
||||||
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
|
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
|
||||||
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
|
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
|
||||||
queue.execute(cmd_list)?;
|
queue.execute(cmd_list)?;
|
||||||
|
@ -179,7 +188,7 @@ fn test_spvtxt_assert<'a>(
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
||||||
assert!(errors.len() == 0);
|
assert!(errors.len() == 0);
|
||||||
let (ptx_mod, _) = translate::to_spirv_module(ast)?;
|
let spirv_module = translate::to_spirv_module(ast)?;
|
||||||
let spv_context =
|
let spv_context =
|
||||||
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
||||||
assert!(spv_context != ptr::null_mut());
|
assert!(spv_context != ptr::null_mut());
|
||||||
|
@ -211,9 +220,9 @@ fn test_spvtxt_assert<'a>(
|
||||||
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
|
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
|
||||||
let spvtxt_mod = loader.module();
|
let spvtxt_mod = loader.module();
|
||||||
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
|
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
|
||||||
if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) {
|
if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) {
|
||||||
// We could simply use ptx_mod.disassemble, but SPIRV-Tools text formatting is so much nicer
|
// We could simply use spirv_module.spirv.disassemble, but SPIRV-Tools text formatting is so much nicer
|
||||||
let spv_from_ptx_binary = ptx_mod.assemble();
|
let spv_from_ptx_binary = spirv_module.spirv.assemble();
|
||||||
let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
|
let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
|
||||||
let result = unsafe {
|
let result = unsafe {
|
||||||
spirv_tools::spvBinaryToText(
|
spirv_tools::spvBinaryToText(
|
||||||
|
@ -234,7 +243,7 @@ fn test_spvtxt_assert<'a>(
|
||||||
// TODO: stop leaking kernel text
|
// TODO: stop leaking kernel text
|
||||||
Cow::Borrowed(spv_from_ptx_text)
|
Cow::Borrowed(spv_from_ptx_text)
|
||||||
} else {
|
} else {
|
||||||
Cow::Owned(ptx_mod.disassemble())
|
Cow::Owned(spirv_module.spirv.disassemble())
|
||||||
};
|
};
|
||||||
if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") {
|
if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") {
|
||||||
let mut path = PathBuf::from(dump_path);
|
let mut path = PathBuf::from(dump_path);
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
use crate::ast;
|
use crate::ast;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use rspirv::{binary::Disassemble, dr};
|
use rspirv::{binary::Disassemble, dr};
|
||||||
use std::collections::{hash_map, HashMap, HashSet};
|
|
||||||
use std::{borrow::Cow, iter, mem};
|
use std::{borrow::Cow, iter, mem};
|
||||||
|
use std::{
|
||||||
|
collections::{hash_map, HashMap, HashSet},
|
||||||
|
convert::TryFrom,
|
||||||
|
};
|
||||||
|
|
||||||
use rspirv::binary::Assemble;
|
use rspirv::binary::Assemble;
|
||||||
|
|
||||||
|
@ -12,7 +15,7 @@ quick_error! {
|
||||||
UnknownSymbol {}
|
UnknownSymbol {}
|
||||||
UntypedSymbol {}
|
UntypedSymbol {}
|
||||||
MismatchedType {}
|
MismatchedType {}
|
||||||
Spirv (err: rspirv::dr::Error) {
|
Spirv(err: rspirv::dr::Error) {
|
||||||
from()
|
from()
|
||||||
display("{}", err)
|
display("{}", err)
|
||||||
cause(err)
|
cause(err)
|
||||||
|
@ -45,8 +48,15 @@ impl From<ast::Type> for SpirvType {
|
||||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||||
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
||||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
||||||
ast::Type::Pointer(typ, state_space) => {
|
ast::Type::Pointer(ast::PointerType::Scalar(typ), state_space) => SpirvType::Pointer(
|
||||||
SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into())
|
Box::new(SpirvType::Base(typ.into())),
|
||||||
|
state_space.to_spirv(),
|
||||||
|
),
|
||||||
|
ast::Type::Pointer(ast::PointerType::Vector(typ, len), state_space) => {
|
||||||
|
SpirvType::Pointer(
|
||||||
|
Box::new(SpirvType::Vector(typ.into(), len)),
|
||||||
|
state_space.to_spirv(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -365,12 +375,16 @@ impl TypeWordMap {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ast::Type::Pointer(typ, state_space) => {
|
ast::Type::Pointer(typ, state_space) => {
|
||||||
let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?;
|
let base_t = typ.clone().into();
|
||||||
|
let base = self.get_or_add_constant(b, &base_t, &[])?;
|
||||||
let result_type = self.get_or_add(
|
let result_type = self.get_or_add(
|
||||||
b,
|
b,
|
||||||
SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()),
|
SpirvType::Pointer(
|
||||||
|
Box::new(SpirvType::from(base_t)),
|
||||||
|
(*state_space).to_spirv(),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
b.variable(result_type, None, (*state_space).into(), Some(base))
|
b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -402,9 +416,17 @@ impl TypeWordMap {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_spirv_module<'a>(
|
pub struct Module {
|
||||||
ast: ast::Module<'a>,
|
pub spirv: dr::Module,
|
||||||
) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> {
|
pub kernel_info: HashMap<String, KernelInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KernelInfo {
|
||||||
|
pub arguments_sizes: Vec<usize>,
|
||||||
|
pub uses_shared_mem: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
|
||||||
let mut id_defs = GlobalStringIdResolver::new(1);
|
let mut id_defs = GlobalStringIdResolver::new(1);
|
||||||
let directives = ast
|
let directives = ast
|
||||||
.directives
|
.directives
|
||||||
|
@ -413,6 +435,9 @@ pub fn to_spirv_module<'a>(
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
let mut builder = dr::Builder::new();
|
let mut builder = dr::Builder::new();
|
||||||
builder.reserve_ids(id_defs.current_id());
|
builder.reserve_ids(id_defs.current_id());
|
||||||
|
let mut directives =
|
||||||
|
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
|
||||||
|
normalize_variable_decls(&mut directives);
|
||||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||||
builder.set_version(1, 3);
|
builder.set_version(1, 3);
|
||||||
emit_capabilities(&mut builder);
|
emit_capabilities(&mut builder);
|
||||||
|
@ -421,7 +446,7 @@ pub fn to_spirv_module<'a>(
|
||||||
emit_memory_model(&mut builder);
|
emit_memory_model(&mut builder);
|
||||||
let mut map = TypeWordMap::new(&mut builder);
|
let mut map = TypeWordMap::new(&mut builder);
|
||||||
emit_builtins(&mut builder, &mut map, &id_defs);
|
emit_builtins(&mut builder, &mut map, &id_defs);
|
||||||
let mut args_len = HashMap::new();
|
let mut kernel_info = HashMap::new();
|
||||||
for d in directives {
|
for d in directives {
|
||||||
match d {
|
match d {
|
||||||
Directive::Variable(var) => {
|
Directive::Variable(var) => {
|
||||||
|
@ -433,13 +458,20 @@ pub fn to_spirv_module<'a>(
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
|
||||||
emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?;
|
emit_function_header(
|
||||||
|
&mut builder,
|
||||||
|
&mut map,
|
||||||
|
&id_defs,
|
||||||
|
f.func_decl,
|
||||||
|
&mut kernel_info,
|
||||||
|
)?;
|
||||||
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
|
||||||
builder.end_function()?;
|
builder.end_function()?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok((builder.module(), args_len))
|
let spirv = builder.module();
|
||||||
|
Ok(Module { spirv, kernel_info })
|
||||||
}
|
}
|
||||||
|
|
||||||
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
|
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
|
||||||
|
@ -461,16 +493,18 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
|
||||||
// This pass looks for all uses of .extern .shared and converts them to
|
// This pass looks for all uses of .extern .shared and converts them to
|
||||||
// an additional method argument
|
// an additional method argument
|
||||||
fn convert_dynamic_shared_memory_usage<'input>(
|
fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
new_id: &mut impl FnMut() -> spirv::Word,
|
|
||||||
id_defs: &mut GlobalStringIdResolver<'input>,
|
id_defs: &mut GlobalStringIdResolver<'input>,
|
||||||
module: Vec<Directive<'input>>,
|
module: Vec<Directive<'input>>,
|
||||||
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
) -> Vec<Directive<'input>> {
|
) -> Vec<Directive<'input>> {
|
||||||
let mut extern_shared_decls = HashSet::new();
|
let mut extern_shared_decls = HashMap::new();
|
||||||
for dir in module.iter() {
|
for dir in module.iter() {
|
||||||
match dir {
|
match dir {
|
||||||
Directive::Variable(var) => {
|
Directive::Variable(var) => {
|
||||||
if let ast::VariableType::Shared(_) = var.v_type {
|
if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
|
||||||
extern_shared_decls.insert(var.name);
|
var.v_type
|
||||||
|
{
|
||||||
|
extern_shared_decls.insert(var.name, p_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
|
@ -490,7 +524,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = match func_decl {
|
let call_key = match func_decl {
|
||||||
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
|
ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
|
||||||
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
||||||
};
|
};
|
||||||
let statements = statements
|
let statements = statements
|
||||||
|
@ -501,7 +535,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
Statement::Call(call)
|
Statement::Call(call)
|
||||||
}
|
}
|
||||||
statement => statement.map_id(&mut |id| {
|
statement => statement.map_id(&mut |id| {
|
||||||
if extern_shared_decls.contains(&id) {
|
if extern_shared_decls.contains_key(&id) {
|
||||||
methods_using_extern_shared.insert(call_key);
|
methods_using_extern_shared.insert(call_key);
|
||||||
}
|
}
|
||||||
id
|
id
|
||||||
|
@ -530,7 +564,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
body: Some(statements),
|
body: Some(statements),
|
||||||
}) => {
|
}) => {
|
||||||
let call_key = match func_decl {
|
let call_key = match func_decl {
|
||||||
ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
|
ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
|
||||||
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
|
||||||
};
|
};
|
||||||
if !methods_using_extern_shared.contains(&call_key) {
|
if !methods_using_extern_shared.contains(&call_key) {
|
||||||
|
@ -550,8 +584,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
name: shared_id_param,
|
name: shared_id_param,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
ast::MethodDecl::Kernel(_, input_args) => {
|
ast::MethodDecl::Kernel {
|
||||||
input_args.push(ast::Variable {
|
in_args,
|
||||||
|
uses_shared_mem,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
*uses_shared_mem = true;
|
||||||
|
in_args.push(ast::Variable {
|
||||||
align: None,
|
align: None,
|
||||||
v_type: ast::KernelArgumentType::Shared,
|
v_type: ast::KernelArgumentType::Shared,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
|
@ -559,33 +598,37 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let statements = statements
|
let shared_var_id = new_id();
|
||||||
.into_iter()
|
let shared_var = ExpandedStatement::Variable(ast::Variable {
|
||||||
.map(|statement| match statement {
|
align: None,
|
||||||
Statement::Call(mut call) => {
|
name: shared_var_id,
|
||||||
// We can safely skip checking call arguments,
|
array_init: Vec::new(),
|
||||||
// because there's simply no way to pass shared ptr
|
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||||
// without converting it to .b64 first
|
ast::SizedScalarType::B8,
|
||||||
if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func))
|
ast::PointerStateSpace::Shared,
|
||||||
{
|
)),
|
||||||
call.param_list
|
});
|
||||||
.push((shared_id_param, ast::FnArgumentType::Shared));
|
let shared_var_st = ExpandedStatement::StoreVar(
|
||||||
}
|
ast::Arg2St {
|
||||||
Statement::Call(call)
|
src1: shared_var_id,
|
||||||
}
|
src2: shared_id_param,
|
||||||
statement => statement.map_id(&mut |id| {
|
},
|
||||||
if extern_shared_decls.contains(&id) {
|
ast::Type::Scalar(ast::ScalarType::B8),
|
||||||
shared_id_param
|
);
|
||||||
} else {
|
let mut new_statements = vec![shared_var, shared_var_st];
|
||||||
id
|
replace_uses_of_shared_memory(
|
||||||
}
|
&mut new_statements,
|
||||||
}),
|
new_id,
|
||||||
})
|
&extern_shared_decls,
|
||||||
.collect();
|
&mut methods_using_extern_shared,
|
||||||
|
shared_id_param,
|
||||||
|
shared_var_id,
|
||||||
|
statements,
|
||||||
|
);
|
||||||
Directive::Method(Function {
|
Directive::Method(Function {
|
||||||
func_decl,
|
func_decl,
|
||||||
globals,
|
globals,
|
||||||
body: Some(statements),
|
body: Some(new_statements),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
directive => directive,
|
directive => directive,
|
||||||
|
@ -593,6 +636,57 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn replace_uses_of_shared_memory<'a>(
|
||||||
|
result: &mut Vec<ExpandedStatement>,
|
||||||
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
|
extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
|
||||||
|
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
|
||||||
|
shared_id_param: spirv::Word,
|
||||||
|
shared_var_id: spirv::Word,
|
||||||
|
statements: Vec<ExpandedStatement>,
|
||||||
|
) {
|
||||||
|
for statement in statements {
|
||||||
|
match statement {
|
||||||
|
Statement::Call(mut call) => {
|
||||||
|
// We can safely skip checking call arguments,
|
||||||
|
// because there's simply no way to pass shared ptr
|
||||||
|
// without converting it to .b64 first
|
||||||
|
if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func)) {
|
||||||
|
call.param_list
|
||||||
|
.push((shared_id_param, ast::FnArgumentType::Shared));
|
||||||
|
}
|
||||||
|
result.push(Statement::Call(call))
|
||||||
|
}
|
||||||
|
statement => {
|
||||||
|
let new_statement = statement.map_id(&mut |id| {
|
||||||
|
if let Some(typ) = extern_shared_decls.get(&id) {
|
||||||
|
let replacement_id = new_id();
|
||||||
|
if *typ != ast::SizedScalarType::B8 {
|
||||||
|
result.push(Statement::Conversion(ImplicitConversion {
|
||||||
|
src: shared_var_id,
|
||||||
|
dst: replacement_id,
|
||||||
|
from: ast::Type::Pointer(
|
||||||
|
ast::PointerType::Scalar(ast::ScalarType::B8),
|
||||||
|
ast::LdStateSpace::Shared,
|
||||||
|
),
|
||||||
|
to: ast::Type::Pointer(
|
||||||
|
ast::PointerType::Scalar((*typ).into()),
|
||||||
|
ast::LdStateSpace::Shared,
|
||||||
|
),
|
||||||
|
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
replacement_id
|
||||||
|
} else {
|
||||||
|
id
|
||||||
|
}
|
||||||
|
});
|
||||||
|
result.push(new_statement);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_callers_of_extern_shared<'a>(
|
fn get_callers_of_extern_shared<'a>(
|
||||||
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
|
methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
|
||||||
directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
|
directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
|
||||||
|
@ -670,15 +764,26 @@ fn emit_function_header<'a>(
|
||||||
map: &mut TypeWordMap,
|
map: &mut TypeWordMap,
|
||||||
global: &GlobalStringIdResolver<'a>,
|
global: &GlobalStringIdResolver<'a>,
|
||||||
func_directive: ast::MethodDecl<spirv::Word>,
|
func_directive: ast::MethodDecl<spirv::Word>,
|
||||||
all_args_lens: &mut HashMap<String, Vec<usize>>,
|
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||||
) -> Result<(), TranslateError> {
|
) -> Result<(), TranslateError> {
|
||||||
if let ast::MethodDecl::Kernel(name, args) = &func_directive {
|
if let ast::MethodDecl::Kernel {
|
||||||
let args_lens = args.iter().map(|param| param.v_type.width()).collect();
|
name,
|
||||||
all_args_lens.insert(name.to_string(), args_lens);
|
in_args,
|
||||||
|
uses_shared_mem,
|
||||||
|
} = &func_directive
|
||||||
|
{
|
||||||
|
let args_lens = in_args.iter().map(|param| param.v_type.width()).collect();
|
||||||
|
kernel_info.insert(
|
||||||
|
name.to_string(),
|
||||||
|
KernelInfo {
|
||||||
|
arguments_sizes: args_lens,
|
||||||
|
uses_shared_mem: *uses_shared_mem,
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
|
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
|
||||||
let fn_id = match func_directive {
|
let fn_id = match func_directive {
|
||||||
ast::MethodDecl::Kernel(name, _) => {
|
ast::MethodDecl::Kernel { name, .. } => {
|
||||||
let fn_id = global.get_id(name)?;
|
let fn_id = global.get_id(name)?;
|
||||||
let mut global_variables = global
|
let mut global_variables = global
|
||||||
.variables_type_check
|
.variables_type_check
|
||||||
|
@ -718,8 +823,15 @@ fn emit_function_header<'a>(
|
||||||
pub fn to_spirv<'a>(
|
pub fn to_spirv<'a>(
|
||||||
ast: ast::Module<'a>,
|
ast: ast::Module<'a>,
|
||||||
) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
|
) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
|
||||||
let (module, all_args_lens) = to_spirv_module(ast)?;
|
let module = to_spirv_module(ast)?;
|
||||||
Ok((module.assemble(), all_args_lens))
|
Ok((
|
||||||
|
module.spirv.assemble(),
|
||||||
|
module
|
||||||
|
.kernel_info
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k, v.arguments_sizes))
|
||||||
|
.collect(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||||
|
@ -843,8 +955,7 @@ fn to_ssa<'input, 'b>(
|
||||||
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
|
||||||
let mut numeric_id_defs = numeric_id_defs.unmut();
|
let mut numeric_id_defs = numeric_id_defs.unmut();
|
||||||
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
|
||||||
let sorted_statements = normalize_variable_decls(labeled_statements);
|
let (f_body, globals) = extract_globals(labeled_statements);
|
||||||
let (f_body, globals) = extract_globals(sorted_statements);
|
|
||||||
Ok(Function {
|
Ok(Function {
|
||||||
func_decl: f_args,
|
func_decl: f_args,
|
||||||
globals: globals,
|
globals: globals,
|
||||||
|
@ -859,12 +970,20 @@ fn extract_globals(
|
||||||
(sorted_statements, Vec::new())
|
(sorted_statements, Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
|
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
|
||||||
func[1..].sort_by_key(|s| match s {
|
for directive in directives {
|
||||||
Statement::Variable(_) => 0,
|
match directive {
|
||||||
_ => 1,
|
Directive::Method(Function {
|
||||||
});
|
body: Some(func), ..
|
||||||
func
|
}) => {
|
||||||
|
func[1..].sort_by_key(|s| match s {
|
||||||
|
Statement::Variable(_) => 0,
|
||||||
|
_ => 1,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_to_typed_statements(
|
fn convert_to_typed_statements(
|
||||||
|
@ -1138,8 +1257,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
||||||
) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> {
|
) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> {
|
||||||
let mut result = Vec::with_capacity(func.len());
|
let mut result = Vec::with_capacity(func.len());
|
||||||
let out_param = match &mut f_args {
|
let out_param = match &mut f_args {
|
||||||
ast::MethodDecl::Kernel(_, in_params) => {
|
ast::MethodDecl::Kernel { in_args, .. } => {
|
||||||
for p in in_params.iter_mut() {
|
for p in in_args.iter_mut() {
|
||||||
let typ = ast::Type::from(p.v_type.clone());
|
let typ = ast::Type::from(p.v_type.clone());
|
||||||
let new_id = id_def.new_id(typ.clone());
|
let new_id = id_def.new_id(typ.clone());
|
||||||
result.push(Statement::Variable(ast::Variable {
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
@ -1736,7 +1855,7 @@ fn insert_implicit_conversions_impl(
|
||||||
conversion_fn = bitcast_physical_pointer;
|
conversion_fn = bitcast_physical_pointer;
|
||||||
}
|
}
|
||||||
ArgumentSemantics::RegisterPointer => {
|
ArgumentSemantics::RegisterPointer => {
|
||||||
conversion_fn = force_bitcast;
|
conversion_fn = bitcast_logical_pointer;
|
||||||
}
|
}
|
||||||
ArgumentSemantics::Address => {
|
ArgumentSemantics::Address => {
|
||||||
conversion_fn = force_bitcast_ptr_to_bit;
|
conversion_fn = force_bitcast_ptr_to_bit;
|
||||||
|
@ -1790,10 +1909,10 @@ fn get_function_type(
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
|
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
|
||||||
),
|
),
|
||||||
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
|
ast::MethodDecl::Kernel { in_args, .. } => map.get_or_add_fn(
|
||||||
builder,
|
builder,
|
||||||
iter::empty(),
|
iter::empty(),
|
||||||
params
|
in_args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
|
.map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
|
||||||
),
|
),
|
||||||
|
@ -1886,14 +2005,19 @@ fn emit_function_body_ops(
|
||||||
if data.qualifier != ast::LdStQualifier::Weak {
|
if data.qualifier != ast::LdStQualifier::Weak {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone()));
|
let result_type =
|
||||||
|
map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
|
||||||
match data.state_space {
|
match data.state_space {
|
||||||
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
|
ast::LdStateSpace::Generic
|
||||||
|
| ast::LdStateSpace::Global
|
||||||
|
| ast::LdStateSpace::Shared => {
|
||||||
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
|
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
|
||||||
}
|
}
|
||||||
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
|
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
|
||||||
let result_type =
|
let result_type = map.get_or_add(
|
||||||
map.get_or_add(builder, SpirvType::from(data.typ.clone()));
|
builder,
|
||||||
|
SpirvType::from(ast::Type::from(data.typ.clone())),
|
||||||
|
);
|
||||||
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
|
@ -1906,11 +2030,14 @@ fn emit_function_body_ops(
|
||||||
if data.state_space == ast::StStateSpace::Param
|
if data.state_space == ast::StStateSpace::Param
|
||||||
|| data.state_space == ast::StStateSpace::Local
|
|| data.state_space == ast::StStateSpace::Local
|
||||||
{
|
{
|
||||||
let result_type =
|
let result_type = map.get_or_add(
|
||||||
map.get_or_add(builder, SpirvType::from(data.typ.clone()));
|
builder,
|
||||||
|
SpirvType::from(ast::Type::from(data.typ.clone())),
|
||||||
|
);
|
||||||
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
|
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
|
||||||
} else if data.state_space == ast::StStateSpace::Generic
|
} else if data.state_space == ast::StStateSpace::Generic
|
||||||
|| data.state_space == ast::StStateSpace::Global
|
|| data.state_space == ast::StStateSpace::Global
|
||||||
|
|| data.state_space == ast::StStateSpace::Shared
|
||||||
{
|
{
|
||||||
builder.store(arg.src1, arg.src2, None, &[])?;
|
builder.store(arg.src1, arg.src2, None, &[])?;
|
||||||
} else {
|
} else {
|
||||||
|
@ -2642,10 +2769,7 @@ fn emit_implicit_conversion(
|
||||||
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
|
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
|
||||||
}
|
}
|
||||||
(_, _, ConversionKind::BitToPtr(space)) => {
|
(_, _, ConversionKind::BitToPtr(space)) => {
|
||||||
let dst_type = map.get_or_add(
|
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
|
||||||
builder,
|
|
||||||
SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()),
|
|
||||||
);
|
|
||||||
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
|
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
|
||||||
}
|
}
|
||||||
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
|
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
|
||||||
|
@ -2703,6 +2827,20 @@ fn emit_implicit_conversion(
|
||||||
let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
|
let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
|
||||||
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
|
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
|
||||||
}
|
}
|
||||||
|
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
|
||||||
|
let result_type = if spirv_ptr {
|
||||||
|
map.get_or_add(
|
||||||
|
builder,
|
||||||
|
SpirvType::Pointer(
|
||||||
|
Box::new(SpirvType::from(cv.to.clone())),
|
||||||
|
spirv::StorageClass::Function,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
map.get_or_add(builder, SpirvType::from(cv.to.clone()))
|
||||||
|
};
|
||||||
|
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -2903,9 +3041,15 @@ impl<'a> GlobalStringIdResolver<'a> {
|
||||||
type_check: HashMap::new(),
|
type_check: HashMap::new(),
|
||||||
};
|
};
|
||||||
let new_fn_decl = match header {
|
let new_fn_decl = match header {
|
||||||
ast::MethodDecl::Kernel(name, params) => {
|
ast::MethodDecl::Kernel {
|
||||||
ast::MethodDecl::Kernel(name, expand_kernel_params(&mut fn_resolver, params.iter()))
|
name,
|
||||||
}
|
in_args,
|
||||||
|
uses_shared_mem,
|
||||||
|
} => ast::MethodDecl::Kernel {
|
||||||
|
name,
|
||||||
|
in_args: expand_kernel_params(&mut fn_resolver, in_args.iter()),
|
||||||
|
uses_shared_mem: *uses_shared_mem,
|
||||||
|
},
|
||||||
ast::MethodDecl::Func(ret_params, _, params) => {
|
ast::MethodDecl::Func(ret_params, _, params) => {
|
||||||
let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter());
|
let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter());
|
||||||
let params_ids = expand_fn_params(&mut fn_resolver, params.iter());
|
let params_ids = expand_fn_params(&mut fn_resolver, params.iter());
|
||||||
|
@ -3598,7 +3742,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
||||||
ast::Instruction::Ld(d, a) => {
|
ast::Instruction::Ld(d, a) => {
|
||||||
let is_param = d.state_space == ast::LdStateSpace::Param
|
let is_param = d.state_space == ast::LdStateSpace::Param
|
||||||
|| d.state_space == ast::LdStateSpace::Local;
|
|| d.state_space == ast::LdStateSpace::Local;
|
||||||
let new_args = a.map(visitor, &d.typ, is_param)?;
|
let new_args = a.map(visitor, &d, is_param)?;
|
||||||
ast::Instruction::Ld(d, new_args)
|
ast::Instruction::Ld(d, new_args)
|
||||||
}
|
}
|
||||||
ast::Instruction::Mov(d, a) => {
|
ast::Instruction::Mov(d, a) => {
|
||||||
|
@ -3655,7 +3799,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
||||||
ast::Instruction::St(d, a) => {
|
ast::Instruction::St(d, a) => {
|
||||||
let is_param = d.state_space == ast::StStateSpace::Param
|
let is_param = d.state_space == ast::StStateSpace::Param
|
||||||
|| d.state_space == ast::StStateSpace::Local;
|
|| d.state_space == ast::StStateSpace::Local;
|
||||||
let new_args = a.map(visitor, &d.typ, is_param)?;
|
let new_args = a.map(visitor, &d, is_param)?;
|
||||||
ast::Instruction::St(d, new_args)
|
ast::Instruction::St(d, new_args)
|
||||||
}
|
}
|
||||||
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
|
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
|
||||||
|
@ -3826,29 +3970,36 @@ impl ast::Type {
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: Vec::new(),
|
components: Vec::new(),
|
||||||
state_space: ast::PointerStateSpace::Global,
|
state_space: ast::LdStateSpace::Global,
|
||||||
},
|
},
|
||||||
ast::Type::Vector(scalar, components) => TypeParts {
|
ast::Type::Vector(scalar, components) => TypeParts {
|
||||||
kind: TypeKind::Vector,
|
kind: TypeKind::Vector,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: vec![*components as u32],
|
components: vec![*components as u32],
|
||||||
state_space: ast::PointerStateSpace::Global,
|
state_space: ast::LdStateSpace::Global,
|
||||||
},
|
},
|
||||||
ast::Type::Array(scalar, components) => TypeParts {
|
ast::Type::Array(scalar, components) => TypeParts {
|
||||||
kind: TypeKind::Array,
|
kind: TypeKind::Array,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: components.clone(),
|
components: components.clone(),
|
||||||
state_space: ast::PointerStateSpace::Global,
|
state_space: ast::LdStateSpace::Global,
|
||||||
},
|
},
|
||||||
ast::Type::Pointer(scalar, state_space) => TypeParts {
|
ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
|
||||||
kind: TypeKind::Pointer,
|
kind: TypeKind::PointerScalar,
|
||||||
scalar_kind: scalar.kind(),
|
scalar_kind: scalar.kind(),
|
||||||
width: scalar.size_of(),
|
width: scalar.size_of(),
|
||||||
components: Vec::new(),
|
components: Vec::new(),
|
||||||
state_space: *state_space,
|
state_space: *state_space,
|
||||||
},
|
},
|
||||||
|
ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
|
||||||
|
kind: TypeKind::PointerVector,
|
||||||
|
scalar_kind: scalar.kind(),
|
||||||
|
width: scalar.size_of(),
|
||||||
|
components: vec![*len as u32],
|
||||||
|
state_space: *state_space,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3865,8 +4016,15 @@ impl ast::Type {
|
||||||
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
||||||
t.components,
|
t.components,
|
||||||
),
|
),
|
||||||
TypeKind::Pointer => ast::Type::Pointer(
|
TypeKind::PointerScalar => ast::Type::Pointer(
|
||||||
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
|
||||||
|
t.state_space,
|
||||||
|
),
|
||||||
|
TypeKind::PointerVector => ast::Type::Pointer(
|
||||||
|
ast::PointerType::Vector(
|
||||||
|
ast::ScalarType::from_parts(t.width, t.scalar_kind),
|
||||||
|
t.components[0] as u8,
|
||||||
|
),
|
||||||
t.state_space,
|
t.state_space,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
@ -3879,7 +4037,7 @@ struct TypeParts {
|
||||||
scalar_kind: ScalarKind,
|
scalar_kind: ScalarKind,
|
||||||
width: u8,
|
width: u8,
|
||||||
components: Vec<u32>,
|
components: Vec<u32>,
|
||||||
state_space: ast::PointerStateSpace,
|
state_space: ast::LdStateSpace,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Copy, Clone)]
|
#[derive(Eq, PartialEq, Copy, Clone)]
|
||||||
|
@ -3887,7 +4045,8 @@ enum TypeKind {
|
||||||
Scalar,
|
Scalar,
|
||||||
Vector,
|
Vector,
|
||||||
Array,
|
Array,
|
||||||
Pointer,
|
PointerScalar,
|
||||||
|
PointerVector,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ast::Instruction<ExpandedArgParams> {
|
impl ast::Instruction<ExpandedArgParams> {
|
||||||
|
@ -4007,6 +4166,7 @@ enum ConversionKind {
|
||||||
SignExtend,
|
SignExtend,
|
||||||
BitToPtr(ast::LdStateSpace),
|
BitToPtr(ast::LdStateSpace),
|
||||||
PtrToBit,
|
PtrToBit,
|
||||||
|
PtrToPtr { spirv_ptr: bool },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> ast::PredAt<T> {
|
impl<T> ast::PredAt<T> {
|
||||||
|
@ -4058,7 +4218,7 @@ impl ast::VariableParamType {
|
||||||
(ast::ScalarType::from(*t).size_of() as usize)
|
(ast::ScalarType::from(*t).size_of() as usize)
|
||||||
* (len.iter().fold(1, |x, y| x * (*y)) as usize)
|
* (len.iter().fold(1, |x, y| x * (*y)) as usize)
|
||||||
}
|
}
|
||||||
ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>()
|
ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4076,7 +4236,10 @@ impl From<ast::KernelArgumentType> for ast::Type {
|
||||||
fn from(this: ast::KernelArgumentType) -> Self {
|
fn from(this: ast::KernelArgumentType) -> Self {
|
||||||
match this {
|
match this {
|
||||||
ast::KernelArgumentType::Normal(typ) => typ.into(),
|
ast::KernelArgumentType::Normal(typ) => typ.into(),
|
||||||
ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64),
|
ast::KernelArgumentType::Shared => ast::Type::Pointer(
|
||||||
|
ast::PointerType::Scalar(ast::ScalarType::B8),
|
||||||
|
ast::LdStateSpace::Shared,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4085,9 +4248,10 @@ impl ast::KernelArgumentType {
|
||||||
fn to_param(self) -> ast::VariableParamType {
|
fn to_param(self) -> ast::VariableParamType {
|
||||||
match self {
|
match self {
|
||||||
ast::KernelArgumentType::Normal(p) => p,
|
ast::KernelArgumentType::Normal(p) => p,
|
||||||
ast::KernelArgumentType::Shared => {
|
ast::KernelArgumentType::Shared => ast::VariableParamType::Pointer(
|
||||||
ast::VariableParamType::Scalar(ast::ParamScalarType::B64)
|
ast::SizedScalarType::B8,
|
||||||
}
|
ast::PointerStateSpace::Shared,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4193,7 +4357,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
|
||||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||||
self,
|
self,
|
||||||
visitor: &mut V,
|
visitor: &mut V,
|
||||||
t: &ast::Type,
|
details: &ast::LdDetails,
|
||||||
is_param: bool,
|
is_param: bool,
|
||||||
) -> Result<ast::Arg2Ld<U>, TranslateError> {
|
) -> Result<ast::Arg2Ld<U>, TranslateError> {
|
||||||
let dst = visitor.id_or_vector(
|
let dst = visitor.id_or_vector(
|
||||||
|
@ -4202,7 +4366,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
|
||||||
is_dst: true,
|
is_dst: true,
|
||||||
sema: ArgumentSemantics::DefaultRelaxed,
|
sema: ArgumentSemantics::DefaultRelaxed,
|
||||||
},
|
},
|
||||||
&ast::Type::from(t.clone()),
|
&ast::Type::from(details.typ.clone()),
|
||||||
)?;
|
)?;
|
||||||
let src = visitor.operand(
|
let src = visitor.operand(
|
||||||
ArgumentDescriptor {
|
ArgumentDescriptor {
|
||||||
|
@ -4214,7 +4378,14 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
|
||||||
ArgumentSemantics::PhysicalPointer
|
ArgumentSemantics::PhysicalPointer
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
t,
|
&(if is_param {
|
||||||
|
ast::Type::from(details.typ.clone())
|
||||||
|
} else {
|
||||||
|
ast::Type::Pointer(
|
||||||
|
ast::PointerType::from(details.typ.clone()),
|
||||||
|
details.state_space,
|
||||||
|
)
|
||||||
|
}),
|
||||||
)?;
|
)?;
|
||||||
Ok(ast::Arg2Ld { dst, src })
|
Ok(ast::Arg2Ld { dst, src })
|
||||||
}
|
}
|
||||||
|
@ -4233,7 +4404,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
||||||
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||||
self,
|
self,
|
||||||
visitor: &mut V,
|
visitor: &mut V,
|
||||||
t: &ast::Type,
|
details: &ast::StData,
|
||||||
is_param: bool,
|
is_param: bool,
|
||||||
) -> Result<ast::Arg2St<U>, TranslateError> {
|
) -> Result<ast::Arg2St<U>, TranslateError> {
|
||||||
let src1 = visitor.operand(
|
let src1 = visitor.operand(
|
||||||
|
@ -4246,7 +4417,14 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
||||||
ArgumentSemantics::PhysicalPointer
|
ArgumentSemantics::PhysicalPointer
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
t,
|
&(if is_param {
|
||||||
|
details.typ.clone().into()
|
||||||
|
} else {
|
||||||
|
ast::Type::Pointer(
|
||||||
|
ast::PointerType::from(details.typ.clone()),
|
||||||
|
details.state_space.to_ld_ss(),
|
||||||
|
)
|
||||||
|
}),
|
||||||
)?;
|
)?;
|
||||||
let src2 = visitor.operand_or_vector(
|
let src2 = visitor.operand_or_vector(
|
||||||
ArgumentDescriptor {
|
ArgumentDescriptor {
|
||||||
|
@ -4254,7 +4432,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
|
||||||
is_dst: false,
|
is_dst: false,
|
||||||
sema: ArgumentSemantics::DefaultRelaxed,
|
sema: ArgumentSemantics::DefaultRelaxed,
|
||||||
},
|
},
|
||||||
t,
|
&details.typ.clone().into(),
|
||||||
)?;
|
)?;
|
||||||
Ok(ast::Arg2St { src1, src2 })
|
Ok(ast::Arg2St { src1, src2 })
|
||||||
}
|
}
|
||||||
|
@ -4957,7 +5135,7 @@ impl ast::MulDetails {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn force_bitcast(
|
fn bitcast_logical_pointer(
|
||||||
operand: &ast::Type,
|
operand: &ast::Type,
|
||||||
instr: &ast::Type,
|
instr: &ast::Type,
|
||||||
_: Option<ast::LdStateSpace>,
|
_: Option<ast::LdStateSpace>,
|
||||||
|
@ -4971,21 +5149,12 @@ fn force_bitcast(
|
||||||
|
|
||||||
fn bitcast_physical_pointer(
|
fn bitcast_physical_pointer(
|
||||||
operand_type: &ast::Type,
|
operand_type: &ast::Type,
|
||||||
_: &ast::Type,
|
instr_type: &ast::Type,
|
||||||
ss: Option<ast::LdStateSpace>,
|
ss: Option<ast::LdStateSpace>,
|
||||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
match operand_type {
|
match operand_type {
|
||||||
// array decays to a pointer
|
// array decays to a pointer
|
||||||
ast::Type::Array(_, vec) => {
|
ast::Type::Array(_, _) => todo!(),
|
||||||
if vec.len() != 0 {
|
|
||||||
return Err(TranslateError::MismatchedType);
|
|
||||||
}
|
|
||||||
if let Some(space) = ss {
|
|
||||||
Ok(Some(ConversionKind::BitToPtr(space)))
|
|
||||||
} else {
|
|
||||||
Err(TranslateError::Unreachable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ast::Type::Scalar(ast::ScalarType::B64)
|
ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
| ast::Type::Scalar(ast::ScalarType::S64) => {
|
| ast::Type::Scalar(ast::ScalarType::S64) => {
|
||||||
|
@ -4995,6 +5164,27 @@ fn bitcast_physical_pointer(
|
||||||
Err(TranslateError::Unreachable)
|
Err(TranslateError::Unreachable)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ast::Type::Pointer(op_scalar_t, op_space) => {
|
||||||
|
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
|
||||||
|
if op_space == instr_space {
|
||||||
|
if op_scalar_t == instr_scalar_t {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if *op_space == ast::LdStateSpace::Generic
|
||||||
|
|| *instr_space == ast::LdStateSpace::Generic
|
||||||
|
{
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
|
||||||
|
} else {
|
||||||
|
Err(TranslateError::MismatchedType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(TranslateError::MismatchedType)
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => Err(TranslateError::MismatchedType),
|
_ => Err(TranslateError::MismatchedType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5206,7 +5396,7 @@ fn should_convert_relaxed_dst(
|
||||||
impl<'a> ast::MethodDecl<'a, &'a str> {
|
impl<'a> ast::MethodDecl<'a, &'a str> {
|
||||||
fn name(&self) -> &'a str {
|
fn name(&self) -> &'a str {
|
||||||
match self {
|
match self {
|
||||||
ast::MethodDecl::Kernel(name, _) => name,
|
ast::MethodDecl::Kernel { name, .. } => name,
|
||||||
ast::MethodDecl::Func(_, name, _) => name,
|
ast::MethodDecl::Func(_, name, _) => name,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5216,7 +5406,7 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> {
|
||||||
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) {
|
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) {
|
||||||
match self {
|
match self {
|
||||||
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
|
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
|
||||||
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
|
ast::MethodDecl::Kernel { in_args, .. } => in_args.iter().for_each(|arg| {
|
||||||
f(&ast::FnArgument {
|
f(&ast::FnArgument {
|
||||||
align: arg.align,
|
align: arg.align,
|
||||||
name: arg.name,
|
name: arg.name,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue