Throw away special variable types

This commit is contained in:
Andrzej Janik 2021-04-17 14:01:50 +02:00
commit d51aaaf552
3 changed files with 256 additions and 490 deletions

View file

@ -1,6 +1,5 @@
use half::f16; use half::f16;
use lalrpop_util::{lexer::Token, ParseError}; use lalrpop_util::{lexer::Token, ParseError};
use std::convert::TryInto;
use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError}; use std::{marker::PhantomData, num::ParseIntError};
@ -34,107 +33,12 @@ pub enum PtxError {
NonExternPointer, NonExternPointer,
} }
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
sub_type! { $type_name : Type {
$(
$variant ($($field_type),+),
)+
}}
};
($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone)]
pub enum $type_name {
$(
$variant ($($field_type),+),
)+
}
impl From<$type_name> for $base_type {
#[allow(non_snake_case)]
fn from(t: $type_name) -> $base_type {
match t {
$(
$type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
)+
}
}
}
impl std::convert::TryFrom<$base_type> for $type_name {
type Error = ();
#[allow(non_snake_case)]
#[allow(unreachable_patterns)]
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t {
$(
$base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
)+
_ => Err(()),
}
}
}
};
}
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(ScalarType, u8),
// Array type is used when emiting SSA statements at the start of a method
Array(ScalarType, VecU32),
// Pointer variant is used when passing around SLM pointer between
// function calls for dynamic SLM
Pointer(ScalarType, LdStateSpace)
}
}
type VecU32 = Vec<u32>;
sub_type! {
VariableLocalType {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
}
}
impl TryFrom<VariableGlobalType> for VariableLocalType {
type Error = PtxError;
fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
match value {
VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
}
}
}
sub_type! {
VariableGlobalType {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
// For some weird reson this is illegal: // For some weird reson this is illegal:
// .param .f16x2 foobar; // .param .f16x2 foobar;
// but this is legal: // but this is legal:
// .param .f16x2 foobar[1]; // .param .f16x2 foobar[1];
// even more interestingly this is legal, but only in .func (not in .entry): // even more interestingly this is legal, but only in .func (not in .entry):
// .param .b32 foobar[] // .param .b32 foobar[]
sub_type! {
VariableParamType {
Scalar(ScalarType),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails { pub enum BarDetails {
@ -178,7 +82,7 @@ pub struct Module<'a> {
} }
pub enum Directive<'a, P: ArgParams> { pub enum Directive<'a, P: ArgParams> {
Variable(Variable<VariableType, P::Id>), Variable(Variable<P::Id>),
Method(Function<'a, &'a str, Statement<P>>), Method(Function<'a, &'a str, Statement<P>>),
} }
@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> {
}, },
} }
pub type FnArgument<ID> = Variable<FnArgumentType, ID>; pub type FnArgument<ID> = Variable<ID>;
pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; pub type KernelArgument<ID> = Variable<ID>;
pub struct Function<'a, ID, S> { pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>, pub func_directive: MethodDecl<'a, ID>,
@ -201,76 +105,6 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone)]
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
Shared,
}
#[derive(PartialEq, Eq, Clone)]
pub enum KernelArgumentType {
Normal(VariableParamType),
Shared,
}
impl From<KernelArgumentType> for Type {
fn from(this: KernelArgumentType) -> Self {
match this {
KernelArgumentType::Normal(typ) => typ.into(),
KernelArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
}
impl FnArgumentType {
pub fn to_type(&self, is_kernel: bool) -> Type {
if is_kernel {
self.to_kernel_type()
} else {
self.to_func_type()
}
}
pub fn to_kernel_type(&self) -> Type {
match self {
FnArgumentType::Reg(x) => x.clone().into(),
FnArgumentType::Param(x) => x.clone().into(),
FnArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
pub fn to_func_type(&self) -> Type {
match self {
FnArgumentType::Reg(x) => x.clone().into(),
FnArgumentType::Param(VariableParamType::Scalar(t)) => {
Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param)
}
FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer(
PointerType::Array((*t).into(), dims.clone()),
LdStateSpace::Param,
),
FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer(
PointerType::Pointer((*t).into(), (*space).into()),
LdStateSpace::Param,
),
FnArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
pub fn is_param(&self) -> bool {
match self {
FnArgumentType::Param(_) => true,
_ => false,
}
}
}
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub enum Type { pub enum Type {
Scalar(ScalarType), Scalar(ScalarType),
@ -283,7 +117,7 @@ pub enum Type {
pub enum PointerType { pub enum PointerType {
Scalar(ScalarType), Scalar(ScalarType),
Vector(ScalarType, u8), Vector(ScalarType, u8),
Array(ScalarType, VecU32), Array(ScalarType, Vec<u32>),
// Instances of this variant are generated during stateful conversion // Instances of this variant are generated during stateful conversion
Pointer(ScalarType, LdStateSpace), Pointer(ScalarType, LdStateSpace),
} }
@ -366,51 +200,19 @@ pub enum Statement<P: ArgParams> {
} }
pub struct MultiVariable<ID> { pub struct MultiVariable<ID> {
pub var: Variable<VariableType, ID>, pub var: Variable<ID>,
pub count: Option<u32>, pub count: Option<u32>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Variable<T, ID> { pub struct Variable<ID> {
pub align: Option<u32>, pub align: Option<u32>,
pub v_type: T, pub v_type: Type,
pub state_space: StateSpace,
pub name: ID, pub name: ID,
pub array_init: Vec<u8>, pub array_init: Vec<u8>,
} }
#[derive(Eq, PartialEq, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
Global(VariableGlobalType),
Shared(VariableGlobalType),
}
impl VariableType {
pub fn to_type(&self) -> (StateSpace, Type) {
match self {
VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
}
}
}
impl From<VariableType> for Type {
fn from(t: VariableType) -> Self {
match t {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
VariableType::Global(t) => t.into(),
VariableType::Shared(t) => t.into(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)] #[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace { pub enum StateSpace {
Reg, Reg,
@ -419,6 +221,7 @@ pub enum StateSpace {
Local, Local,
Shared, Shared,
Param, Param,
Generic,
} }
pub struct PredAt<ID> { pub struct PredAt<ID> {

View file

@ -404,28 +404,29 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args "(" <args:Comma<FnInput>> ")" => args
}; };
KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = { KernelInput: ast::Variable<&'input str> = {
<v:ParamDeclaration> => { <v:ParamDeclaration> => {
let (align, v_type, name) = v; let (align, v_type, name) = v;
ast::Variable { ast::Variable {
align, align,
v_type: ast::KernelArgumentType::Normal(v_type), v_type,
state_space: ast::StateSpace::Param,
name, name,
array_init: Vec::new() array_init: Vec::new()
} }
} }
} }
FnInput: ast::Variable<ast::FnArgumentType, &'input str> = { FnInput: ast::Variable<&'input str> = {
<v:RegVariable> => { <v:RegVariable> => {
let (align, v_type, name) = v; let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type); let state_space = ast::StateSpace::Reg;
ast::Variable{ align, v_type, name, array_init: Vec::new() } ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
}, },
<v:ParamDeclaration> => { <v:ParamDeclaration> => {
let (align, v_type, name) = v; let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Param(v_type); let state_space = ast::StateSpace::Param;
ast::Variable{ align, v_type, name, array_init: Vec::new() } ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
} }
} }
@ -508,102 +509,109 @@ VariableParam: u32 = {
"<" <n:U32Num> ">" => n "<" <n:U32Num> ">" => n
} }
Variable: ast::Variable<ast::VariableType, &'input str> = { Variable: ast::Variable<&'input str> = {
<v:RegVariable> => { <v:RegVariable> => {
let (align, v_type, name) = v; let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type); let state_space = ast::StateSpace::Reg;
ast::Variable {align, v_type, name, array_init: Vec::new()} ast::Variable {align, v_type, state_space, name, array_init: Vec::new()}
}, },
LocalVariable, LocalVariable,
<v:ParamVariable> => { <v:ParamVariable> => {
let (align, array_init, v_type, name) = v; let (align, array_init, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type); let state_space = ast::StateSpace::Param;
ast::Variable {align, v_type, name, array_init} ast::Variable {align, v_type, state_space, name, array_init}
}, },
SharedVariable, SharedVariable,
}; };
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { RegVariable: (Option<u32>, ast::Type, &'input str) = {
".reg" <var:VariableScalar<ScalarType>> => { ".reg" <var:VariableScalar<ScalarType>> => {
let (align, t, name) = var; let (align, t, name) = var;
let v_type = ast::VariableRegType::Scalar(t); let v_type = ast::Type::Scalar(t);
(align, v_type, name) (align, v_type, name)
}, },
".reg" <var:VariableVector<SizedScalarType>> => { ".reg" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var; let (align, v_len, t, name) = var;
let v_type = ast::VariableRegType::Vector(t, v_len); let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name) (align, v_type, name)
} }
} }
LocalVariable: ast::Variable<ast::VariableType, &'input str> = { LocalVariable: ast::Variable<&'input str> = {
".local" <var:VariableScalar<SizedScalarType>> => { ".local" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var; let (align, t, name) = var;
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); let v_type = ast::Type::Scalar(t);
ast::Variable { align, v_type, name, array_init: Vec::new() } let state_space = ast::StateSpace::Local;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
}, },
".local" <var:VariableVector<SizedScalarType>> => { ".local" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var; let (align, v_len, t, name) = var;
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); let v_type = ast::Type::Vector(t, v_len);
ast::Variable { align, v_type, name, array_init: Vec::new() } let state_space = ast::StateSpace::Local;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
}, },
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? { ".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var; let (align, t, name, arr_or_ptr) = var;
let state_space = ast::StateSpace::Local;
let (v_type, array_init) = match arr_or_ptr { let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => { ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableLocalType::Array(t, dimensions), init) (ast::Type::Array(t, dimensions), init)
} }
ast::ArrayOrPointer::Pointer => { ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
} }
}; };
Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) Ok(ast::Variable { align, v_type, state_space, name, array_init })
} }
} }
SharedVariable: ast::Variable<ast::VariableType, &'input str> = { SharedVariable: ast::Variable<&'input str> = {
".shared" <var:VariableScalar<SizedScalarType>> => { ".shared" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var; let (align, t, name) = var;
let v_type = ast::VariableGlobalType::Scalar(t); let state_space = ast::StateSpace::Shared;
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } let v_type = ast::Type::Scalar(t);
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
}, },
".shared" <var:VariableVector<SizedScalarType>> => { ".shared" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var; let (align, v_len, t, name) = var;
let v_type = ast::VariableGlobalType::Vector(t, v_len); let state_space = ast::StateSpace::Shared;
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } let v_type = ast::Type::Vector(t, v_len);
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
}, },
".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? { ".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var; let (align, t, name, arr_or_ptr) = var;
let state_space = ast::StateSpace::Shared;
let (v_type, array_init) = match arr_or_ptr { let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => { ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableGlobalType::Array(t, dimensions), init) (ast::Type::Array(t, dimensions), init)
} }
ast::ArrayOrPointer::Pointer => { ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
} }
}; };
Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) Ok(ast::Variable { align, v_type, state_space, name, array_init })
} }
} }
ModuleVariable: ast::Variable<&'input str> = {
ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => { LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def; let (align, v_type, name, array_init) = def;
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } let state_space = ast::StateSpace::Global;
ast::Variable { align, v_type, state_space, name, array_init }
}, },
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => { LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def; let (align, v_type, name, array_init) = def;
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } let state_space = ast::StateSpace::Shared;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
}, },
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? { <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var; let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr { let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => { ast::ArrayOrPointer::Array { dimensions, init } => {
if space == ".global" { if space == ".global" {
(ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
} else { } else {
(ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
} }
} }
ast::ArrayOrPointer::Pointer => { ast::ArrayOrPointer::Pointer => {
@ -611,38 +619,38 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
} }
if space == ".global" { if space == ".global" {
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new())
} else { } else {
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new())
} }
} }
}; };
Ok(ast::Variable{ align, array_init, v_type, name }) Ok(ast::Variable{ align, v_type, state_space, name, array_init })
} }
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = { ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
".param" <var:VariableScalar<LdStScalarType>> => { ".param" <var:VariableScalar<LdStScalarType>> => {
let (align, t, name) = var; let (align, t, name) = var;
let v_type = ast::VariableParamType::Scalar(t); let v_type = ast::Type::Scalar(t);
(align, Vec::new(), v_type, name) (align, Vec::new(), v_type, name)
}, },
".param" <var:VariableArrayOrPointer<SizedScalarType>> => { ".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
let (align, t, name, arr_or_ptr) = var; let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr { let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => { ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableParamType::Array(t, dimensions), init) (ast::Type::Array(t, dimensions), init)
} }
ast::ArrayOrPointer::Pointer => { ast::ArrayOrPointer::Pointer => {
(ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new())
} }
}; };
(align, array_init, v_type, name) (align, array_init, v_type, name)
} }
} }
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { ParamDeclaration: (Option<u32>, ast::Type, &'input str) = {
<var:ParamVariable> =>? { <var:ParamVariable> =>? {
let (align, array_init, v_type, name) = var; let (align, array_init, v_type, name) = var;
if array_init.len() > 0 { if array_init.len() > 0 {
@ -653,15 +661,15 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
} }
} }
GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = { GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = {
<scalar:VariableScalar<SizedScalarType>> => { <scalar:VariableScalar<SizedScalarType>> => {
let (align, t, name) = scalar; let (align, t, name) = scalar;
let v_type = ast::VariableGlobalType::Scalar(t); let v_type = ast::Type::Scalar(t);
(align, v_type, name, Vec::new()) (align, v_type, name, Vec::new())
}, },
<var:VariableVector<SizedScalarType>> => { <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var; let (align, v_len, t, name) = var;
let v_type = ast::VariableGlobalType::Vector(t, v_len); let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name, Vec::new()) (align, v_type, name, Vec::new())
}, },
} }

View file

@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
let mut extern_shared_decls = HashMap::new(); let mut extern_shared_decls = HashMap::new();
for dir in module.iter() { for dir in module.iter() {
match dir { match dir {
Directive::Variable(var) => { Directive::Variable(ast::Variable {
if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) = v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
var.v_type state_space: ast::StateSpace::Shared,
{ name,
extern_shared_decls.insert(var.name, p_type); ..
} }) => {
extern_shared_decls.insert(*name, p_type.clone());
} }
_ => {} _ => {}
} }
@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>(
let shared_id_param = new_id(); let shared_id_param = new_id();
spirv_decl.input.push({ spirv_decl.input.push({
ast::Variable { ast::Variable {
name: shared_id_param,
align: None, align: None,
v_type: ast::Type::Pointer( v_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8), ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared, ast::LdStateSpace::Shared,
), ),
state_space: ast::StateSpace::Param,
array_init: Vec::new(), array_init: Vec::new(),
name: shared_id_param,
} }
}); });
spirv_decl.uses_shared_mem = true; spirv_decl.uses_shared_mem = true;
let shared_var_id = new_id(); let shared_var_id = new_id();
let shared_var = ExpandedStatement::Variable(ast::Variable { let shared_var = ExpandedStatement::Variable(ast::Variable {
align: None,
name: shared_var_id, name: shared_var_id,
array_init: Vec::new(), align: None,
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( v_type: ast::Type::Pointer(
ast::ScalarType::B8, ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared, ast::LdStateSpace::Shared,
)), ),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
}); });
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
arg: ast::Arg2St { arg: ast::Arg2St {
@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>( fn replace_uses_of_shared_memory<'a>(
result: &mut Vec<ExpandedStatement>, result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word, new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>, extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>, methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word, shared_id_param: spirv::Word,
shared_var_id: spirv::Word, shared_var_id: spirv::Word,
@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>(
// because there's simply no way to pass shared ptr // because there's simply no way to pass shared ptr
// without converting it to .b64 first // without converting it to .b64 first
if methods_using_extern_shared.contains(&MethodName::Func(call.func)) { if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
call.param_list call.param_list.push((
.push((shared_id_param, ast::FnArgumentType::Shared)); shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
ast::StateSpace::Shared,
));
} }
result.push(Statement::Call(call)) result.push(Statement::Call(call))
} }
statement => { statement => {
let new_statement = statement.map_id(&mut |id, _| { let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) { if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
if *typ == ast::ScalarType::B8 { if *typ == ast::ScalarType::B8 {
return shared_var_id; return shared_var_id;
} }
@ -1067,7 +1073,7 @@ fn emit_function_header<'a>(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>, defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>], synthetic_globals: &[ast::Variable<spirv::Word>],
func_decl: &SpirvMethodDecl<'a>, func_decl: &SpirvMethodDecl<'a>,
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>, _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>, call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
@ -1204,9 +1210,9 @@ fn translate_directive<'input>(
fn translate_variable<'a>( fn translate_variable<'a>(
id_defs: &mut GlobalStringIdResolver<'a>, id_defs: &mut GlobalStringIdResolver<'a>,
var: ast::Variable<ast::VariableType, &'a str>, var: ast::Variable<&'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> { ) -> Result<ast::Variable<spirv::Word>, TranslateError> {
let (space, var_type) = var.v_type.to_type(); let (space, var_type) = (var.state_space, var.v_type.clone());
let mut is_variable = false; let mut is_variable = false;
let var_type = match space { let var_type = match space {
ast::StateSpace::Reg => { ast::StateSpace::Reg => {
@ -1226,10 +1232,12 @@ fn translate_variable<'a>(
} }
} }
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
ast::StateSpace::Generic => todo!(),
}; };
Ok(ast::Variable { Ok(ast::Variable {
align: var.align, align: var.align,
v_type: var.v_type, v_type: var.v_type,
state_space: var.state_space,
name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init, array_init: var.array_init,
}) })
@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>(
false, false,
), ),
v_type: a.v_type.clone(), v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align, align: a.align,
array_init: Vec::new(), array_init: Vec::new(),
}) })
@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>, args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> { ) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| { args.map(|a| {
let is_variable = match a.v_type { let is_variable = a.state_space == ast::StateSpace::Reg;
ast::FnArgumentType::Reg(_) => true,
_ => false,
};
let var_type = a.v_type.to_func_type();
Ok(ast::FnArgument { Ok(ast::FnArgument {
name: fn_resolver.add_def(a.name, Some(var_type), is_variable), name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
v_type: a.v_type.clone(), v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align, align: a.align,
array_init: Vec::new(), array_init: Vec::new(),
}) })
@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>, sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>, ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
) -> ( ) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
Vec<ExpandedStatement>,
Vec<ast::Variable<ast::VariableType, spirv::Word>>,
) {
let mut local = Vec::with_capacity(sorted_statements.len()); let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new(); let mut global = Vec::new();
for statement in sorted_statements { for statement in sorted_statements {
@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>(
var var
@ @
ast::Variable { ast::Variable {
v_type: ast::VariableType::Shared(_), state_space: ast::StateSpace::Shared,
.. ..
}, },
) )
@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>(
var var
@ @
ast::Variable { ast::Variable {
v_type: ast::VariableType::Global(_), state_space: ast::StateSpace::Global,
.. ..
}, },
) => global.push(var), ) => global.push(var),
@ -1592,10 +1595,10 @@ fn convert_to_typed_statements(
let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
.into_iter() .into_iter()
.partition(|(_, arg_type)| arg_type.is_param()); .partition(|(_, _, space)| *space == ast::StateSpace::Param);
let normalized_input_args = out_params let normalized_input_args = out_params
.into_iter() .into_iter()
.map(|(id, typ)| (ast::Operand::Reg(id), typ)) .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space))
.chain(in_args.into_iter()) .chain(in_args.into_iter())
.collect(); .collect();
let resolved_call = ResolvedCall { let resolved_call = ResolvedCall {
@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument { vec![ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}], }],
@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call(
vec![ vec![
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
typ, ptr_space, state_space: ast::StateSpace::Reg,
)),
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call(
Statement::Call(ResolvedCall { Statement::Call(ResolvedCall {
uniform: false, uniform: false,
func: fn_id, func: fn_id,
ret_params: vec![( ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
)],
param_list: vec![ param_list: vec![
( (
arg.src1, arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)), ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
ast::StateSpace::Reg,
), ),
( (
arg.src2, arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)), ast::Type::Scalar(scalar_typ),
ast::StateSpace::Reg,
), ),
], ],
}) })
@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument { vec![ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}], }],
@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call(
vec![ vec![
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( v_type: ast::Type::Scalar(ast::ScalarType::U32),
ast::ScalarType::U32, state_space: ast::StateSpace::Reg,
)),
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( v_type: ast::Type::Scalar(ast::ScalarType::U32),
ast::ScalarType::U32, state_space: ast::StateSpace::Reg,
)),
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call(
Statement::Call(ResolvedCall { Statement::Call(ResolvedCall {
uniform: false, uniform: false,
func: fn_id, func: fn_id,
ret_params: vec![( ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
param_list: vec![ param_list: vec![
( (
arg.src1, arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
), ),
( (
arg.src2, arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
), ),
( (
arg.src3, arg.src3,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
), ),
], ],
}) })
@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument { vec![ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}], }],
@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call(
vec![ vec![
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( v_type: ast::Type::Scalar(ast::ScalarType::U32),
ast::ScalarType::U32, state_space: ast::StateSpace::Reg,
)),
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
ast::FnArgument { ast::FnArgument {
align: None, align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( v_type: ast::Type::Scalar(ast::ScalarType::U32),
ast::ScalarType::U32, state_space: ast::StateSpace::Reg,
)),
name: id_defs.new_non_variable(None), name: id_defs.new_non_variable(None),
array_init: Vec::new(), array_init: Vec::new(),
}, },
@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call(
Statement::Call(ResolvedCall { Statement::Call(ResolvedCall {
uniform: false, uniform: false,
func: fn_id, func: fn_id,
ret_params: vec![( ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
param_list: vec![ param_list: vec![
( (
arg.src1, arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
), ),
( (
arg.src2, arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
), ),
( (
arg.src3, arg.src3,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
), ),
( (
arg.src4, arg.src4,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
), ),
], ],
}) })
@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call(
fn to_resolved_fn_args<T>( fn to_resolved_fn_args<T>(
params: Vec<T>, params: Vec<T>,
params_decl: &[ast::FnArgumentType], params_decl: &[(ast::Type, ast::StateSpace)],
) -> Vec<(T, ast::FnArgumentType)> { ) -> Vec<(T, ast::Type, ast::StateSpace)> {
params params
.into_iter() .into_iter()
.zip(params_decl.iter()) .zip(params_decl.iter())
.map(|(id, typ)| (id, typ.clone())) .map(|(id, (typ, space))| (id, typ.clone(), *space))
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
@ -2096,35 +2101,26 @@ fn normalize_predicates(
fn insert_mem_ssa_statements<'a, 'b>( fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>, func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver, id_def: &mut NumericIdResolver,
ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, _: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl, fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> { ) -> Result<Vec<TypedStatement>, TranslateError> {
let is_func = match ast_fn_decl {
ast::MethodDecl::Func(..) => true,
ast::MethodDecl::Kernel { .. } => false,
};
let mut result = Vec::with_capacity(func.len()); let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() { for arg in fn_decl.output.iter() {
match type_to_variable_type(&arg.v_type, is_func)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: arg.align, align: arg.align,
v_type: var_type, v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name, name: arg.name,
array_init: arg.array_init.clone(), array_init: arg.array_init.clone(),
})); }));
} }
None => return Err(error_unreachable()),
}
}
for spirv_arg in fn_decl.input.iter_mut() { for spirv_arg in fn_decl.input.iter_mut() {
match type_to_variable_type(&spirv_arg.v_type, is_func)? {
Some(var_type) => {
let typ = spirv_arg.v_type.clone(); let typ = spirv_arg.v_type.clone();
let new_id = id_def.new_non_variable(Some(typ.clone())); let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: spirv_arg.align, align: spirv_arg.align,
v_type: var_type, v_type: spirv_arg.v_type.clone(),
state_space: spirv_arg.state_space,
name: spirv_arg.name, name: spirv_arg.name,
array_init: spirv_arg.array_init.clone(), array_init: spirv_arg.array_init.clone(),
})); }));
@ -2138,9 +2134,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
})); }));
spirv_arg.name = new_id; spirv_arg.name = new_id;
} }
None => {}
}
}
for s in func { for s in func {
match s { match s {
Statement::Call(call) => { Statement::Call(call) => {
@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result) Ok(result)
} }
fn type_to_variable_type(
t: &ast::Type,
is_func: bool,
) -> Result<Option<ast::VariableType>, TranslateError> {
Ok(match t {
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
(*typ)
.try_into()
.map_err(|_| TranslateError::MismatchedType)?,
*len,
))),
ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
(*typ)
.try_into()
.map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
if is_func {
return Ok(None);
}
Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
scalar_type
.clone()
.try_into()
.map_err(|_| error_unreachable())?,
(*space).try_into().map_err(|_| error_unreachable())?,
)))
}
ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
_ => return Err(error_unreachable()),
})
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized { trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit( fn visit(
self, self,
@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>(
Statement::Variable(ast::Variable { Statement::Variable(ast::Variable {
align, align,
v_type, v_type,
state_space,
name, name,
array_init, array_init,
}) => result.push(Statement::Variable(ast::Variable { }) => result.push(Statement::Variable(ast::Variable {
align, align,
v_type, v_type,
state_space,
name, name,
array_init, array_init,
})), })),
@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl(
fn get_function_type( fn get_function_type(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
spirv_input: &[ast::Variable<ast::Type, spirv::Word>], spirv_input: &[ast::Variable<spirv::Word>],
spirv_output: &[ast::Variable<ast::Type, spirv::Word>], spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) { ) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn( map.get_or_add_fn(
builder, builder,
@ -2822,8 +2782,8 @@ fn emit_function_body_ops(
Statement::Label(_) => (), Statement::Label(_) => (),
Statement::Call(call) => { Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params { let (result_type, result_id) = match &*call.ret_params {
[(id, typ)] => ( [(id, typ, _)] => (
map.get_or_add(builder, SpirvType::from(typ.to_func_type())), map.get_or_add(builder, SpirvType::from(typ.clone())),
Some(*id), Some(*id),
), ),
[] => (map.void(), None), [] => (map.void(), None),
@ -2832,7 +2792,7 @@ fn emit_function_body_ops(
let arg_list = call let arg_list = call
.param_list .param_list
.iter() .iter()
.map(|(id, _)| *id) .map(|(id, _, _)| *id)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?; builder.function_call(result_type, result_id, call.func, arg_list)?;
} }
@ -3602,14 +3562,16 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
fn emit_variable( fn emit_variable(
builder: &mut dr::Builder, builder: &mut dr::Builder,
map: &mut TypeWordMap, map: &mut TypeWordMap,
var: &ast::Variable<ast::VariableType, spirv::Word>, var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> { ) -> Result<(), TranslateError> {
let (must_init, st_class) = match var.v_type { let (must_init, st_class) = match var.state_space {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function) (false, spirv::StorageClass::Function)
} }
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => todo!(),
ast::StateSpace::Generic => todo!(),
}; };
let initalizer = if var.array_init.len() > 0 { let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant( Some(map.get_or_add_constant(
@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>(
ast::Statement::Variable(var) => { ast::Statement::Variable(var) => {
let mut var_type = ast::Type::from(var.var.v_type.clone()); let mut var_type = ast::Type::from(var.var.v_type.clone());
let mut is_variable = false; let mut is_variable = false;
var_type = match var.var.v_type { var_type = match var.var.state_space {
ast::VariableType::Reg(_) => { ast::StateSpace::Reg => {
is_variable = true; is_variable = true;
var_type var_type
} }
ast::VariableType::Shared(_) => { ast::StateSpace::Shared => {
// If it's a pointer it will be translated to a method parameter later // If it's a pointer it will be translated to a method parameter later
if let ast::Type::Pointer(..) = var_type { if let ast::Type::Pointer(..) = var_type {
is_variable = true; is_variable = true;
@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>(
var_type.param_pointer_to(ast::LdStateSpace::Shared)? var_type.param_pointer_to(ast::LdStateSpace::Shared)?
} }
} }
ast::VariableType::Global(_) => { ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
var_type.param_pointer_to(ast::LdStateSpace::Global)? ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
} ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
ast::VariableType::Param(_) => { ast::StateSpace::Const => todo!(),
var_type.param_pointer_to(ast::LdStateSpace::Param)? ast::StateSpace::Generic => todo!(),
}
ast::VariableType::Local(_) => {
var_type.param_pointer_to(ast::LdStateSpace::Local)?
}
}; };
match var.count { match var.count {
Some(count) => { Some(count) => {
@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: var.var.align, align: var.var.align,
v_type: var.var.v_type.clone(), v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id, name: new_id,
array_init: var.var.array_init.clone(), array_init: var.var.array_init.clone(),
})) }))
@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable { result.push(Statement::Variable(ast::Variable {
align: var.var.align, align: var.var.align,
v_type: var.var.v_type.clone(), v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id, name: new_id,
array_init: var.var.array_init, array_init: var.var.array_init,
})); }));
@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>(
align: None, align: None,
name: new_id, name: new_id,
array_init: Vec::new(), array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( v_type: ast::Type::Pointer(
ast::ScalarType::U8, ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global, ast::LdStateSpace::Global,
)), ),
state_space: ast::StateSpace::Reg,
})); }));
remapped_ids.insert(reg, new_id); remapped_ids.insert(reg, new_id);
} }
@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> {
} }
pub struct FnDecl { pub struct FnDecl {
ret_vals: Vec<ast::FnArgumentType>, ret_vals: Vec<(ast::Type, ast::StateSpace)>,
params: Vec<ast::FnArgumentType>, params: Vec<(ast::Type, ast::StateSpace)>,
} }
impl<'a> GlobalStringIdResolver<'a> { impl<'a> GlobalStringIdResolver<'a> {
@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert( self.fns.insert(
name_id, name_id,
FnDecl { FnDecl {
ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), ret_vals: ret_params_ids
params: params_ids.iter().map(|p| p.v_type.clone()).collect(), .iter()
.map(|p| (p.v_type.clone(), p.state_space))
.collect(),
params: params_ids
.iter()
.map(|p| (p.v_type.clone(), p.state_space))
.collect(),
}, },
); );
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> { enum Statement<I, P: ast::ArgParams> {
Label(u32), Label(u32),
Variable(ast::Variable<ast::VariableType, P::Id>), Variable(ast::Variable<P::Id>),
Instruction(I), Instruction(I),
// SPIR-V compatible replacement for PTX predicates // SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition), Conditional(BrachCondition),
@ -5352,16 +5319,17 @@ impl ExpandedStatement {
Statement::StoreVar(details) Statement::StoreVar(details)
} }
Statement::Call(mut call) => { Statement::Call(mut call) => {
for (id, typ) in call.ret_params.iter_mut() { for (id, _, space) in call.ret_params.iter_mut() {
let is_dst = match typ { let is_dst = match space {
ast::FnArgumentType::Reg(_) => true, ast::StateSpace::Reg => true,
ast::FnArgumentType::Param(_) => false, ast::StateSpace::Param => false,
ast::FnArgumentType::Shared => false, ast::StateSpace::Shared => false,
_ => todo!(),
}; };
*id = f(*id, is_dst); *id = f(*id, is_dst);
} }
call.func = f(call.func, false); call.func = f(call.func, false);
for (id, _) in call.param_list.iter_mut() { for (id, _, _) in call.param_list.iter_mut() {
*id = f(*id, false); *id = f(*id, false);
} }
Statement::Call(call) Statement::Call(call)
@ -5502,9 +5470,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
struct ResolvedCall<P: ast::ArgParams> { struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool, pub uniform: bool,
pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>,
pub func: P::Id, pub func: P::Id,
pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
} }
impl<T: ast::ArgParams> ResolvedCall<T> { impl<T: ast::ArgParams> ResolvedCall<T> {
@ -5526,16 +5494,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let ret_params = self let ret_params = self
.ret_params .ret_params
.into_iter() .into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| { .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id( let new_id = visitor.id(
ArgumentDescriptor { ArgumentDescriptor {
op: id, op: id,
is_dst: !typ.is_param(), is_dst: space != ast::StateSpace::Param,
sema: typ.semantics(), sema: space.semantics(),
}, },
Some(&typ.to_func_type()), Some(&typ),
)?; )?;
Ok((new_id, typ)) Ok((new_id, typ, space))
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let func = visitor.id( let func = visitor.id(
@ -5549,16 +5517,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let param_list = self let param_list = self
.param_list .param_list
.into_iter() .into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| { .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand( let new_id = visitor.operand(
ArgumentDescriptor { ArgumentDescriptor {
op: id, op: id,
is_dst: false, is_dst: false,
sema: typ.semantics(), sema: space.semantics(),
}, },
&typ.to_func_type(), &typ,
)?; )?;
Ok((new_id, typ)) Ok((new_id, typ, space))
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall { Ok(ResolvedCall {
@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams {
} }
enum Directive<'input> { enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>), Variable(ast::Variable<spirv::Word>),
Method(Function<'input>), Method(Function<'input>),
} }
struct Function<'input> { struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>, pub func_decl: ast::MethodDecl<'input, spirv::Word>,
pub spirv_decl: SpirvMethodDecl<'input>, pub spirv_decl: SpirvMethodDecl<'input>,
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>, pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>, pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>, import_as: Option<String>,
tuning: Vec<ast::TuningDirective>, tuning: Vec<ast::TuningDirective>,
@ -7300,16 +7268,6 @@ impl ast::LdStateSpace {
} }
} }
impl From<ast::FnArgumentType> for ast::VariableType {
fn from(t: ast::FnArgumentType) -> Self {
match t {
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
ast::FnArgumentType::Shared => todo!(),
}
}
}
impl<T> ast::Operand<T> { impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> { fn underlying(&self) -> Option<&T> {
match self { match self {
@ -7362,12 +7320,13 @@ impl ast::AtomSemantics {
} }
} }
impl ast::FnArgumentType { impl ast::StateSpace {
fn semantics(&self) -> ArgumentSemantics { fn semantics(self) -> ArgumentSemantics {
match self { match self {
ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default, ast::StateSpace::Reg => ArgumentSemantics::Default,
ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer, ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer, ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer,
_ => todo!(),
} }
} }
} }
@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> {
} }
struct SpirvMethodDecl<'input> { struct SpirvMethodDecl<'input> {
input: Vec<ast::Variable<ast::Type, spirv::Word>>, input: Vec<ast::Variable<spirv::Word>>,
output: Vec<ast::Variable<ast::Type, spirv::Word>>, output: Vec<ast::Variable<spirv::Word>>,
name: MethodName<'input>, name: MethodName<'input>,
uses_shared_mem: bool, uses_shared_mem: bool,
} }
@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> {
ast::MethodDecl::Kernel { in_args, .. } => { ast::MethodDecl::Kernel { in_args, .. } => {
let spirv_input = in_args let spirv_input = in_args
.iter() .iter()
.map(|var| { .map(|var| ast::Variable {
let v_type = match &var.v_type {
ast::KernelArgumentType::Normal(t) => {
ast::FnArgumentType::Param(t.clone())
}
ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
};
ast::Variable {
name: var.name, name: var.name,
align: var.align, align: var.align,
v_type: v_type.to_kernel_type(), v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(), array_init: var.array_init.clone(),
}
}) })
.collect(); .collect();
(spirv_input, Vec::new()) (spirv_input, Vec::new())
} }
ast::MethodDecl::Func(out_args, _, in_args) => { ast::MethodDecl::Func(out_args, _, in_args) => {
let (param_output, non_param_output): (Vec<_>, Vec<_>) = let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
out_args.iter().partition(|var| var.v_type.is_param()); .iter()
.partition(|var| var.state_space == ast::StateSpace::Param);
let spirv_output = non_param_output let spirv_output = non_param_output
.into_iter() .into_iter()
.cloned() .cloned()
.map(|var| ast::Variable { .map(|var| ast::Variable {
name: var.name, name: var.name,
align: var.align, align: var.align,
v_type: var.v_type.to_func_type(), v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(), array_init: var.array_init.clone(),
}) })
.collect(); .collect();
@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> {
.map(|var| ast::Variable { .map(|var| ast::Variable {
name: var.name, name: var.name,
align: var.align, align: var.align,
v_type: var.v_type.to_func_type(), v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(), array_init: var.array_init.clone(),
}) })
.collect(); .collect();