Throw away special variable types

This commit is contained in:
Andrzej Janik 2021-04-17 14:01:50 +02:00
parent a55c851eaa
commit d51aaaf552
3 changed files with 256 additions and 490 deletions

View file

@ -1,6 +1,5 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
use std::convert::TryInto;
use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
@ -34,107 +33,12 @@ pub enum PtxError {
NonExternPointer,
}
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
sub_type! { $type_name : Type {
$(
$variant ($($field_type),+),
)+
}}
};
($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone)]
pub enum $type_name {
$(
$variant ($($field_type),+),
)+
}
impl From<$type_name> for $base_type {
#[allow(non_snake_case)]
fn from(t: $type_name) -> $base_type {
match t {
$(
$type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
)+
}
}
}
impl std::convert::TryFrom<$base_type> for $type_name {
type Error = ();
#[allow(non_snake_case)]
#[allow(unreachable_patterns)]
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t {
$(
$base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
)+
_ => Err(()),
}
}
}
};
}
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(ScalarType, u8),
// Array type is used when emiting SSA statements at the start of a method
Array(ScalarType, VecU32),
// Pointer variant is used when passing around SLM pointer between
// function calls for dynamic SLM
Pointer(ScalarType, LdStateSpace)
}
}
type VecU32 = Vec<u32>;
sub_type! {
VariableLocalType {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
}
}
impl TryFrom<VariableGlobalType> for VariableLocalType {
type Error = PtxError;
fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
match value {
VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
}
}
}
sub_type! {
VariableGlobalType {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
// even more interestingly this is legal, but only in .func (not in .entry):
// .param .b32 foobar[]
sub_type! {
VariableParamType {
Scalar(ScalarType),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails {
@ -178,7 +82,7 @@ pub struct Module<'a> {
}
pub enum Directive<'a, P: ArgParams> {
Variable(Variable<VariableType, P::Id>),
Variable(Variable<P::Id>),
Method(Function<'a, &'a str, Statement<P>>),
}
@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> {
},
}
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
pub type FnArgument<ID> = Variable<ID>;
pub type KernelArgument<ID> = Variable<ID>;
pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>,
@ -201,76 +105,6 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone)]
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
Shared,
}
#[derive(PartialEq, Eq, Clone)]
pub enum KernelArgumentType {
Normal(VariableParamType),
Shared,
}
impl From<KernelArgumentType> for Type {
fn from(this: KernelArgumentType) -> Self {
match this {
KernelArgumentType::Normal(typ) => typ.into(),
KernelArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
}
impl FnArgumentType {
pub fn to_type(&self, is_kernel: bool) -> Type {
if is_kernel {
self.to_kernel_type()
} else {
self.to_func_type()
}
}
pub fn to_kernel_type(&self) -> Type {
match self {
FnArgumentType::Reg(x) => x.clone().into(),
FnArgumentType::Param(x) => x.clone().into(),
FnArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
pub fn to_func_type(&self) -> Type {
match self {
FnArgumentType::Reg(x) => x.clone().into(),
FnArgumentType::Param(VariableParamType::Scalar(t)) => {
Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param)
}
FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer(
PointerType::Array((*t).into(), dims.clone()),
LdStateSpace::Param,
),
FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer(
PointerType::Pointer((*t).into(), (*space).into()),
LdStateSpace::Param,
),
FnArgumentType::Shared => {
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
}
}
}
pub fn is_param(&self) -> bool {
match self {
FnArgumentType::Param(_) => true,
_ => false,
}
}
}
#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
@ -283,7 +117,7 @@ pub enum Type {
pub enum PointerType {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
Array(ScalarType, Vec<u32>),
// Instances of this variant are generated during stateful conversion
Pointer(ScalarType, LdStateSpace),
}
@ -366,51 +200,19 @@ pub enum Statement<P: ArgParams> {
}
pub struct MultiVariable<ID> {
pub var: Variable<VariableType, ID>,
pub var: Variable<ID>,
pub count: Option<u32>,
}
#[derive(Clone)]
pub struct Variable<T, ID> {
pub struct Variable<ID> {
pub align: Option<u32>,
pub v_type: T,
pub v_type: Type,
pub state_space: StateSpace,
pub name: ID,
pub array_init: Vec<u8>,
}
#[derive(Eq, PartialEq, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
Global(VariableGlobalType),
Shared(VariableGlobalType),
}
impl VariableType {
pub fn to_type(&self) -> (StateSpace, Type) {
match self {
VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
}
}
}
impl From<VariableType> for Type {
fn from(t: VariableType) -> Self {
match t {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
VariableType::Global(t) => t.into(),
VariableType::Shared(t) => t.into(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
@ -419,6 +221,7 @@ pub enum StateSpace {
Local,
Shared,
Param,
Generic,
}
pub struct PredAt<ID> {

View file

@ -404,28 +404,29 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
KernelInput: ast::Variable<&'input str> = {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
ast::Variable {
align,
v_type: ast::KernelArgumentType::Normal(v_type),
v_type,
state_space: ast::StateSpace::Param,
name,
array_init: Vec::new()
}
}
}
FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
FnInput: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type);
ast::Variable{ align, v_type, name, array_init: Vec::new() }
let state_space = ast::StateSpace::Reg;
ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
},
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Param(v_type);
ast::Variable{ align, v_type, name, array_init: Vec::new() }
let state_space = ast::StateSpace::Param;
ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
}
}
@ -508,102 +509,109 @@ VariableParam: u32 = {
"<" <n:U32Num> ">" => n
}
Variable: ast::Variable<ast::VariableType, &'input str> = {
Variable: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
ast::Variable {align, v_type, name, array_init: Vec::new()}
let state_space = ast::StateSpace::Reg;
ast::Variable {align, v_type, state_space, name, array_init: Vec::new()}
},
LocalVariable,
<v:ParamVariable> => {
let (align, array_init, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type);
ast::Variable {align, v_type, name, array_init}
let state_space = ast::StateSpace::Param;
ast::Variable {align, v_type, state_space, name, array_init}
},
SharedVariable,
};
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
RegVariable: (Option<u32>, ast::Type, &'input str) = {
".reg" <var:VariableScalar<ScalarType>> => {
let (align, t, name) = var;
let v_type = ast::VariableRegType::Scalar(t);
let v_type = ast::Type::Scalar(t);
(align, v_type, name)
},
".reg" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
let v_type = ast::VariableRegType::Vector(t, v_len);
let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name)
}
}
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
LocalVariable: ast::Variable<&'input str> = {
".local" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
ast::Variable { align, v_type, name, array_init: Vec::new() }
let v_type = ast::Type::Scalar(t);
let state_space = ast::StateSpace::Local;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
ast::Variable { align, v_type, name, array_init: Vec::new() }
let v_type = ast::Type::Vector(t, v_len);
let state_space = ast::StateSpace::Local;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
let state_space = ast::StateSpace::Local;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableLocalType::Array(t, dimensions), init)
(ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
SharedVariable: ast::Variable<ast::VariableType, &'input str> = {
SharedVariable: ast::Variable<&'input str> = {
".shared" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
let v_type = ast::VariableGlobalType::Scalar(t);
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
let state_space = ast::StateSpace::Shared;
let v_type = ast::Type::Scalar(t);
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
let v_type = ast::VariableGlobalType::Vector(t, v_len);
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
let state_space = ast::StateSpace::Shared;
let v_type = ast::Type::Vector(t, v_len);
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
let state_space = ast::StateSpace::Shared;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableGlobalType::Array(t, dimensions), init)
(ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init })
Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
ModuleVariable: ast::Variable<&'input str> = {
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
let state_space = ast::StateSpace::Global;
ast::Variable { align, v_type, state_space, name, array_init }
},
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
let state_space = ast::StateSpace::Shared;
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr {
let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
if space == ".global" {
(ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
(ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
} else {
(ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
(ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
}
}
ast::ArrayOrPointer::Pointer => {
@ -611,38 +619,38 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new())
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new())
} else {
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new())
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new())
}
}
};
Ok(ast::Variable{ align, array_init, v_type, name })
Ok(ast::Variable{ align, v_type, state_space, name, array_init })
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
".param" <var:VariableScalar<LdStScalarType>> => {
let (align, t, name) = var;
let v_type = ast::VariableParamType::Scalar(t);
let v_type = ast::Type::Scalar(t);
(align, Vec::new(), v_type, name)
},
".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
(ast::VariableParamType::Array(t, dimensions), init)
(ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
(ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new())
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new())
}
};
(align, array_init, v_type, name)
}
}
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
ParamDeclaration: (Option<u32>, ast::Type, &'input str) = {
<var:ParamVariable> =>? {
let (align, array_init, v_type, name) = var;
if array_init.len() > 0 {
@ -653,15 +661,15 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
}
}
GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = {
<scalar:VariableScalar<SizedScalarType>> => {
let (align, t, name) = scalar;
let v_type = ast::VariableGlobalType::Scalar(t);
let v_type = ast::Type::Scalar(t);
(align, v_type, name, Vec::new())
},
<var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
let v_type = ast::VariableGlobalType::Vector(t, v_len);
let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name, Vec::new())
},
}

View file

@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(var) => {
if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
var.v_type
{
extern_shared_decls.insert(var.name, p_type);
}
Directive::Variable(ast::Variable {
v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
state_space: ast::StateSpace::Shared,
name,
..
}) => {
extern_shared_decls.insert(*name, p_type.clone());
}
_ => {}
}
@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>(
let shared_id_param = new_id();
spirv_decl.input.push({
ast::Variable {
name: shared_id_param,
align: None,
v_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
),
state_space: ast::StateSpace::Param,
array_init: Vec::new(),
name: shared_id_param,
}
});
spirv_decl.uses_shared_mem = true;
let shared_var_id = new_id();
let shared_var = ExpandedStatement::Variable(ast::Variable {
align: None,
name: shared_var_id,
array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
ast::ScalarType::B8,
align: None,
v_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
)),
),
state_space: ast::StateSpace::Reg,
array_init: Vec::new(),
});
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>(
result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word,
shared_var_id: spirv::Word,
@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>(
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
call.param_list
.push((shared_id_param, ast::FnArgumentType::Shared));
call.param_list.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
ast::StateSpace::Shared,
));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
if *typ == ast::ScalarType::B8 {
return shared_var_id;
}
@ -1067,7 +1073,7 @@ fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
synthetic_globals: &[ast::Variable<spirv::Word>],
func_decl: &SpirvMethodDecl<'a>,
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
@ -1204,9 +1210,9 @@ fn translate_directive<'input>(
fn translate_variable<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
var: ast::Variable<ast::VariableType, &'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
let (space, var_type) = var.v_type.to_type();
var: ast::Variable<&'a str>,
) -> Result<ast::Variable<spirv::Word>, TranslateError> {
let (space, var_type) = (var.state_space, var.v_type.clone());
let mut is_variable = false;
let var_type = match space {
ast::StateSpace::Reg => {
@ -1226,10 +1232,12 @@ fn translate_variable<'a>(
}
}
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
ast::StateSpace::Generic => todo!(),
};
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
state_space: var.state_space,
name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>(
false,
),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
let is_variable = match a.v_type {
ast::FnArgumentType::Reg(_) => true,
_ => false,
};
let var_type = a.v_type.to_func_type();
let is_variable = a.state_space == ast::StateSpace::Reg;
Ok(ast::FnArgument {
name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
) -> (
Vec<ExpandedStatement>,
Vec<ast::Variable<ast::VariableType, spirv::Word>>,
) {
) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
v_type: ast::VariableType::Shared(_),
state_space: ast::StateSpace::Shared,
..
},
)
@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
v_type: ast::VariableType::Global(_),
state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@ -1592,10 +1595,10 @@ fn convert_to_typed_statements(
let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
.into_iter()
.partition(|(_, arg_type)| arg_type.is_param());
.partition(|(_, _, space)| *space == ast::StateSpace::Param);
let normalized_input_args = out_params
.into_iter()
.map(|(id, typ)| (ast::Operand::Reg(id), typ))
.map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space))
.chain(in_args.into_iter())
.collect();
let resolved_call = ResolvedCall {
@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call(
vec![
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
typ, ptr_space,
)),
v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
ret_params: vec![(
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
)],
ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
ast::Type::Scalar(scalar_typ),
ast::StateSpace::Reg,
),
],
})
@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call(
vec![
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
ret_params: vec![(
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
(
arg.src3,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
],
})
@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call(
vec![
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
ret_params: vec![(
arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
ast::Type::Scalar(typ.into()),
ast::StateSpace::Reg,
),
(
arg.src3,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
(
arg.src4,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
ast::Type::Scalar(ast::ScalarType::U32),
ast::StateSpace::Reg,
),
],
})
@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call(
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[ast::FnArgumentType],
) -> Vec<(T, ast::FnArgumentType)> {
params_decl: &[(ast::Type, ast::StateSpace)],
) -> Vec<(T, ast::Type, ast::StateSpace)> {
params
.into_iter()
.zip(params_decl.iter())
.map(|(id, typ)| (id, typ.clone()))
.map(|(id, (typ, space))| (id, typ.clone(), *space))
.collect::<Vec<_>>()
}
@ -2096,50 +2101,38 @@ fn normalize_predicates(
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
_: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
let is_func = match ast_fn_decl {
ast::MethodDecl::Func(..) => true,
ast::MethodDecl::Kernel { .. } => false,
};
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
match type_to_variable_type(&arg.v_type, is_func)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: var_type,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
None => return Err(error_unreachable()),
}
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
for spirv_arg in fn_decl.input.iter_mut() {
match type_to_variable_type(&spirv_arg.v_type, is_func)? {
Some(var_type) => {
let typ = spirv_arg.v_type.clone();
let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable {
align: spirv_arg.align,
v_type: var_type,
name: spirv_arg.name,
array_init: spirv_arg.array_init.clone(),
}));
result.push(Statement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
src1: spirv_arg.name,
src2: new_id,
},
typ,
member_index: None,
}));
spirv_arg.name = new_id;
}
None => {}
}
let typ = spirv_arg.v_type.clone();
let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable {
align: spirv_arg.align,
v_type: spirv_arg.v_type.clone(),
state_space: spirv_arg.state_space,
name: spirv_arg.name,
array_init: spirv_arg.array_init.clone(),
}));
result.push(Statement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
src1: spirv_arg.name,
src2: new_id,
},
typ,
member_index: None,
}));
spirv_arg.name = new_id;
}
for s in func {
match s {
@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result)
}
fn type_to_variable_type(
t: &ast::Type,
is_func: bool,
) -> Result<Option<ast::VariableType>, TranslateError> {
Ok(match t {
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
(*typ)
.try_into()
.map_err(|_| TranslateError::MismatchedType)?,
*len,
))),
ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
(*typ)
.try_into()
.map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
if is_func {
return Ok(None);
}
Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
scalar_type
.clone()
.try_into()
.map_err(|_| error_unreachable())?,
(*space).try_into().map_err(|_| error_unreachable())?,
)))
}
ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
_ => return Err(error_unreachable()),
})
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>(
Statement::Variable(ast::Variable {
align,
v_type,
state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
state_space,
name,
array_init,
})),
@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
spirv_input: &[ast::Variable<spirv::Word>],
spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
@ -2822,8 +2782,8 @@ fn emit_function_body_ops(
Statement::Label(_) => (),
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
[(id, typ)] => (
map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
[(id, typ, _)] => (
map.get_or_add(builder, SpirvType::from(typ.clone())),
Some(*id),
),
[] => (map.void(), None),
@ -2832,7 +2792,7 @@ fn emit_function_body_ops(
let arg_list = call
.param_list
.iter()
.map(|(id, _)| *id)
.map(|(id, _, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
@ -3602,14 +3562,16 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
var: &ast::Variable<ast::VariableType, spirv::Word>,
var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
let (must_init, st_class) = match var.v_type {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
let (must_init, st_class) = match var.state_space {
ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => todo!(),
ast::StateSpace::Generic => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>(
ast::Statement::Variable(var) => {
let mut var_type = ast::Type::from(var.var.v_type.clone());
let mut is_variable = false;
var_type = match var.var.v_type {
ast::VariableType::Reg(_) => {
var_type = match var.var.state_space {
ast::StateSpace::Reg => {
is_variable = true;
var_type
}
ast::VariableType::Shared(_) => {
ast::StateSpace::Shared => {
// If it's a pointer it will be translated to a method parameter later
if let ast::Type::Pointer(..) = var_type {
is_variable = true;
@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>(
var_type.param_pointer_to(ast::LdStateSpace::Shared)?
}
}
ast::VariableType::Global(_) => {
var_type.param_pointer_to(ast::LdStateSpace::Global)?
}
ast::VariableType::Param(_) => {
var_type.param_pointer_to(ast::LdStateSpace::Param)?
}
ast::VariableType::Local(_) => {
var_type.param_pointer_to(ast::LdStateSpace::Local)?
}
ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
ast::StateSpace::Const => todo!(),
ast::StateSpace::Generic => todo!(),
};
match var.count {
Some(count) => {
@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>(
align: None,
name: new_id,
array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
ast::ScalarType::U8,
v_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global,
)),
),
state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> {
}
pub struct FnDecl {
ret_vals: Vec<ast::FnArgumentType>,
params: Vec<ast::FnArgumentType>,
ret_vals: Vec<(ast::Type, ast::StateSpace)>,
params: Vec<(ast::Type, ast::StateSpace)>,
}
impl<'a> GlobalStringIdResolver<'a> {
@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert(
name_id,
FnDecl {
ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
ret_vals: ret_params_ids
.iter()
.map(|p| (p.v_type.clone(), p.state_space))
.collect(),
params: params_ids
.iter()
.map(|p| (p.v_type.clone(), p.state_space))
.collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(ast::Variable<ast::VariableType, P::Id>),
Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@ -5352,16 +5319,17 @@ impl ExpandedStatement {
Statement::StoreVar(details)
}
Statement::Call(mut call) => {
for (id, typ) in call.ret_params.iter_mut() {
let is_dst = match typ {
ast::FnArgumentType::Reg(_) => true,
ast::FnArgumentType::Param(_) => false,
ast::FnArgumentType::Shared => false,
for (id, _, space) in call.ret_params.iter_mut() {
let is_dst = match space {
ast::StateSpace::Reg => true,
ast::StateSpace::Param => false,
ast::StateSpace::Shared => false,
_ => todo!(),
};
*id = f(*id, is_dst);
}
call.func = f(call.func, false);
for (id, _) in call.param_list.iter_mut() {
for (id, _, _) in call.param_list.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@ -5502,9 +5470,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>,
pub func: P::Id,
pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
@ -5526,16 +5494,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let ret_params = self
.ret_params
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
is_dst: !typ.is_param(),
sema: typ.semantics(),
is_dst: space != ast::StateSpace::Param,
sema: space.semantics(),
},
Some(&typ.to_func_type()),
Some(&typ),
)?;
Ok((new_id, typ))
Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
@ -5549,16 +5517,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let param_list = self
.param_list
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
sema: typ.semantics(),
sema: space.semantics(),
},
&typ.to_func_type(),
&typ,
)?;
Ok((new_id, typ))
Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams {
}
enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>),
Variable(ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
pub spirv_decl: SpirvMethodDecl<'input>,
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@ -7300,16 +7268,6 @@ impl ast::LdStateSpace {
}
}
impl From<ast::FnArgumentType> for ast::VariableType {
fn from(t: ast::FnArgumentType) -> Self {
match t {
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
ast::FnArgumentType::Shared => todo!(),
}
}
}
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
@ -7362,12 +7320,13 @@ impl ast::AtomSemantics {
}
}
impl ast::FnArgumentType {
fn semantics(&self) -> ArgumentSemantics {
impl ast::StateSpace {
fn semantics(self) -> ArgumentSemantics {
match self {
ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
ast::StateSpace::Reg => ArgumentSemantics::Default,
ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer,
_ => todo!(),
}
}
}
@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> {
}
struct SpirvMethodDecl<'input> {
input: Vec<ast::Variable<ast::Type, spirv::Word>>,
output: Vec<ast::Variable<ast::Type, spirv::Word>>,
input: Vec<ast::Variable<spirv::Word>>,
output: Vec<ast::Variable<spirv::Word>>,
name: MethodName<'input>,
uses_shared_mem: bool,
}
@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> {
ast::MethodDecl::Kernel { in_args, .. } => {
let spirv_input = in_args
.iter()
.map(|var| {
let v_type = match &var.v_type {
ast::KernelArgumentType::Normal(t) => {
ast::FnArgumentType::Param(t.clone())
}
ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
};
ast::Variable {
name: var.name,
align: var.align,
v_type: v_type.to_kernel_type(),
array_init: var.array_init.clone(),
}
.map(|var| ast::Variable {
name: var.name,
align: var.align,
v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();
(spirv_input, Vec::new())
}
ast::MethodDecl::Func(out_args, _, in_args) => {
let (param_output, non_param_output): (Vec<_>, Vec<_>) =
out_args.iter().partition(|var| var.v_type.is_param());
let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
.iter()
.partition(|var| var.state_space == ast::StateSpace::Param);
let spirv_output = non_param_output
.into_iter()
.cloned()
.map(|var| ast::Variable {
name: var.name,
align: var.align,
v_type: var.v_type.to_func_type(),
v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();
@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> {
.map(|var| ast::Variable {
name: var.name,
align: var.align,
v_type: var.v_type.to_func_type(),
v_type: var.v_type.clone(),
state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();