mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Throw away special variable types
This commit is contained in:
parent
a55c851eaa
commit
d51aaaf552
3 changed files with 256 additions and 490 deletions
215
ptx/src/ast.rs
215
ptx/src/ast.rs
|
@ -1,6 +1,5 @@
|
|||
use half::f16;
|
||||
use lalrpop_util::{lexer::Token, ParseError};
|
||||
use std::convert::TryInto;
|
||||
use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
|
||||
use std::{marker::PhantomData, num::ParseIntError};
|
||||
|
||||
|
@ -34,107 +33,12 @@ pub enum PtxError {
|
|||
NonExternPointer,
|
||||
}
|
||||
|
||||
macro_rules! sub_type {
|
||||
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||
sub_type! { $type_name : Type {
|
||||
$(
|
||||
$variant ($($field_type),+),
|
||||
)+
|
||||
}}
|
||||
};
|
||||
($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum $type_name {
|
||||
$(
|
||||
$variant ($($field_type),+),
|
||||
)+
|
||||
}
|
||||
|
||||
impl From<$type_name> for $base_type {
|
||||
#[allow(non_snake_case)]
|
||||
fn from(t: $type_name) -> $base_type {
|
||||
match t {
|
||||
$(
|
||||
$type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
|
||||
)+
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::TryFrom<$base_type> for $type_name {
|
||||
type Error = ();
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
#[allow(unreachable_patterns)]
|
||||
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
|
||||
match t {
|
||||
$(
|
||||
$base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
|
||||
)+
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
sub_type! {
|
||||
VariableRegType {
|
||||
Scalar(ScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
// Array type is used when emiting SSA statements at the start of a method
|
||||
Array(ScalarType, VecU32),
|
||||
// Pointer variant is used when passing around SLM pointer between
|
||||
// function calls for dynamic SLM
|
||||
Pointer(ScalarType, LdStateSpace)
|
||||
}
|
||||
}
|
||||
|
||||
type VecU32 = Vec<u32>;
|
||||
|
||||
sub_type! {
|
||||
VariableLocalType {
|
||||
Scalar(ScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
Array(ScalarType, VecU32),
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<VariableGlobalType> for VariableLocalType {
|
||||
type Error = PtxError;
|
||||
|
||||
fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
|
||||
VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
|
||||
VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
|
||||
VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sub_type! {
|
||||
VariableGlobalType {
|
||||
Scalar(ScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
Array(ScalarType, VecU32),
|
||||
Pointer(ScalarType, LdStateSpace),
|
||||
}
|
||||
}
|
||||
|
||||
// For some weird reson this is illegal:
|
||||
// .param .f16x2 foobar;
|
||||
// but this is legal:
|
||||
// .param .f16x2 foobar[1];
|
||||
// even more interestingly this is legal, but only in .func (not in .entry):
|
||||
// .param .b32 foobar[]
|
||||
sub_type! {
|
||||
VariableParamType {
|
||||
Scalar(ScalarType),
|
||||
Array(ScalarType, VecU32),
|
||||
Pointer(ScalarType, LdStateSpace),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum BarDetails {
|
||||
|
@ -178,7 +82,7 @@ pub struct Module<'a> {
|
|||
}
|
||||
|
||||
pub enum Directive<'a, P: ArgParams> {
|
||||
Variable(Variable<VariableType, P::Id>),
|
||||
Variable(Variable<P::Id>),
|
||||
Method(Function<'a, &'a str, Statement<P>>),
|
||||
}
|
||||
|
||||
|
@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> {
|
|||
},
|
||||
}
|
||||
|
||||
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
|
||||
pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
|
||||
pub type FnArgument<ID> = Variable<ID>;
|
||||
pub type KernelArgument<ID> = Variable<ID>;
|
||||
|
||||
pub struct Function<'a, ID, S> {
|
||||
pub func_directive: MethodDecl<'a, ID>,
|
||||
|
@ -201,76 +105,6 @@ pub struct Function<'a, ID, S> {
|
|||
|
||||
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum FnArgumentType {
|
||||
Reg(VariableRegType),
|
||||
Param(VariableParamType),
|
||||
Shared,
|
||||
}
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum KernelArgumentType {
|
||||
Normal(VariableParamType),
|
||||
Shared,
|
||||
}
|
||||
|
||||
impl From<KernelArgumentType> for Type {
|
||||
fn from(this: KernelArgumentType) -> Self {
|
||||
match this {
|
||||
KernelArgumentType::Normal(typ) => typ.into(),
|
||||
KernelArgumentType::Shared => {
|
||||
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FnArgumentType {
|
||||
pub fn to_type(&self, is_kernel: bool) -> Type {
|
||||
if is_kernel {
|
||||
self.to_kernel_type()
|
||||
} else {
|
||||
self.to_func_type()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_kernel_type(&self) -> Type {
|
||||
match self {
|
||||
FnArgumentType::Reg(x) => x.clone().into(),
|
||||
FnArgumentType::Param(x) => x.clone().into(),
|
||||
FnArgumentType::Shared => {
|
||||
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_func_type(&self) -> Type {
|
||||
match self {
|
||||
FnArgumentType::Reg(x) => x.clone().into(),
|
||||
FnArgumentType::Param(VariableParamType::Scalar(t)) => {
|
||||
Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param)
|
||||
}
|
||||
FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer(
|
||||
PointerType::Array((*t).into(), dims.clone()),
|
||||
LdStateSpace::Param,
|
||||
),
|
||||
FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer(
|
||||
PointerType::Pointer((*t).into(), (*space).into()),
|
||||
LdStateSpace::Param,
|
||||
),
|
||||
FnArgumentType::Shared => {
|
||||
Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_param(&self) -> bool {
|
||||
match self {
|
||||
FnArgumentType::Param(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum Type {
|
||||
Scalar(ScalarType),
|
||||
|
@ -283,7 +117,7 @@ pub enum Type {
|
|||
pub enum PointerType {
|
||||
Scalar(ScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
Array(ScalarType, VecU32),
|
||||
Array(ScalarType, Vec<u32>),
|
||||
// Instances of this variant are generated during stateful conversion
|
||||
Pointer(ScalarType, LdStateSpace),
|
||||
}
|
||||
|
@ -366,51 +200,19 @@ pub enum Statement<P: ArgParams> {
|
|||
}
|
||||
|
||||
pub struct MultiVariable<ID> {
|
||||
pub var: Variable<VariableType, ID>,
|
||||
pub var: Variable<ID>,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Variable<T, ID> {
|
||||
pub struct Variable<ID> {
|
||||
pub align: Option<u32>,
|
||||
pub v_type: T,
|
||||
pub v_type: Type,
|
||||
pub state_space: StateSpace,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Clone)]
|
||||
pub enum VariableType {
|
||||
Reg(VariableRegType),
|
||||
Local(VariableLocalType),
|
||||
Param(VariableParamType),
|
||||
Global(VariableGlobalType),
|
||||
Shared(VariableGlobalType),
|
||||
}
|
||||
|
||||
impl VariableType {
|
||||
pub fn to_type(&self) -> (StateSpace, Type) {
|
||||
match self {
|
||||
VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
|
||||
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
|
||||
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
|
||||
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
|
||||
VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VariableType> for Type {
|
||||
fn from(t: VariableType) -> Self {
|
||||
match t {
|
||||
VariableType::Reg(t) => t.into(),
|
||||
VariableType::Local(t) => t.into(),
|
||||
VariableType::Param(t) => t.into(),
|
||||
VariableType::Global(t) => t.into(),
|
||||
VariableType::Shared(t) => t.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
|
@ -419,6 +221,7 @@ pub enum StateSpace {
|
|||
Local,
|
||||
Shared,
|
||||
Param,
|
||||
Generic,
|
||||
}
|
||||
|
||||
pub struct PredAt<ID> {
|
||||
|
|
|
@ -404,28 +404,29 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
|
|||
"(" <args:Comma<FnInput>> ")" => args
|
||||
};
|
||||
|
||||
KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
|
||||
KernelInput: ast::Variable<&'input str> = {
|
||||
<v:ParamDeclaration> => {
|
||||
let (align, v_type, name) = v;
|
||||
ast::Variable {
|
||||
align,
|
||||
v_type: ast::KernelArgumentType::Normal(v_type),
|
||||
v_type,
|
||||
state_space: ast::StateSpace::Param,
|
||||
name,
|
||||
array_init: Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
|
||||
FnInput: ast::Variable<&'input str> = {
|
||||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Reg(v_type);
|
||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
||||
let state_space = ast::StateSpace::Reg;
|
||||
ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
<v:ParamDeclaration> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Param(v_type);
|
||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
||||
let state_space = ast::StateSpace::Param;
|
||||
ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -508,102 +509,109 @@ VariableParam: u32 = {
|
|||
"<" <n:U32Num> ">" => n
|
||||
}
|
||||
|
||||
Variable: ast::Variable<ast::VariableType, &'input str> = {
|
||||
Variable: ast::Variable<&'input str> = {
|
||||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Reg(v_type);
|
||||
ast::Variable {align, v_type, name, array_init: Vec::new()}
|
||||
let state_space = ast::StateSpace::Reg;
|
||||
ast::Variable {align, v_type, state_space, name, array_init: Vec::new()}
|
||||
},
|
||||
LocalVariable,
|
||||
<v:ParamVariable> => {
|
||||
let (align, array_init, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Param(v_type);
|
||||
ast::Variable {align, v_type, name, array_init}
|
||||
let state_space = ast::StateSpace::Param;
|
||||
ast::Variable {align, v_type, state_space, name, array_init}
|
||||
},
|
||||
SharedVariable,
|
||||
};
|
||||
|
||||
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
||||
RegVariable: (Option<u32>, ast::Type, &'input str) = {
|
||||
".reg" <var:VariableScalar<ScalarType>> => {
|
||||
let (align, t, name) = var;
|
||||
let v_type = ast::VariableRegType::Scalar(t);
|
||||
let v_type = ast::Type::Scalar(t);
|
||||
(align, v_type, name)
|
||||
},
|
||||
".reg" <var:VariableVector<SizedScalarType>> => {
|
||||
let (align, v_len, t, name) = var;
|
||||
let v_type = ast::VariableRegType::Vector(t, v_len);
|
||||
let v_type = ast::Type::Vector(t, v_len);
|
||||
(align, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||
LocalVariable: ast::Variable<&'input str> = {
|
||||
".local" <var:VariableScalar<SizedScalarType>> => {
|
||||
let (align, t, name) = var;
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
|
||||
ast::Variable { align, v_type, name, array_init: Vec::new() }
|
||||
let v_type = ast::Type::Scalar(t);
|
||||
let state_space = ast::StateSpace::Local;
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
".local" <var:VariableVector<SizedScalarType>> => {
|
||||
let (align, v_len, t, name) = var;
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
|
||||
ast::Variable { align, v_type, name, array_init: Vec::new() }
|
||||
let v_type = ast::Type::Vector(t, v_len);
|
||||
let state_space = ast::StateSpace::Local;
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||
let (align, t, name, arr_or_ptr) = var;
|
||||
let state_space = ast::StateSpace::Local;
|
||||
let (v_type, array_init) = match arr_or_ptr {
|
||||
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||
(ast::VariableLocalType::Array(t, dimensions), init)
|
||||
(ast::Type::Array(t, dimensions), init)
|
||||
}
|
||||
ast::ArrayOrPointer::Pointer => {
|
||||
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
|
||||
}
|
||||
};
|
||||
Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
|
||||
Ok(ast::Variable { align, v_type, state_space, name, array_init })
|
||||
}
|
||||
}
|
||||
|
||||
SharedVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||
SharedVariable: ast::Variable<&'input str> = {
|
||||
".shared" <var:VariableScalar<SizedScalarType>> => {
|
||||
let (align, t, name) = var;
|
||||
let v_type = ast::VariableGlobalType::Scalar(t);
|
||||
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
|
||||
let state_space = ast::StateSpace::Shared;
|
||||
let v_type = ast::Type::Scalar(t);
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
".shared" <var:VariableVector<SizedScalarType>> => {
|
||||
let (align, v_len, t, name) = var;
|
||||
let v_type = ast::VariableGlobalType::Vector(t, v_len);
|
||||
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
|
||||
let state_space = ast::StateSpace::Shared;
|
||||
let v_type = ast::Type::Vector(t, v_len);
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||
let (align, t, name, arr_or_ptr) = var;
|
||||
let state_space = ast::StateSpace::Shared;
|
||||
let (v_type, array_init) = match arr_or_ptr {
|
||||
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||
(ast::VariableGlobalType::Array(t, dimensions), init)
|
||||
(ast::Type::Array(t, dimensions), init)
|
||||
}
|
||||
ast::ArrayOrPointer::Pointer => {
|
||||
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
|
||||
}
|
||||
};
|
||||
Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init })
|
||||
Ok(ast::Variable { align, v_type, state_space, name, array_init })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||
ModuleVariable: ast::Variable<&'input str> = {
|
||||
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
|
||||
let (align, v_type, name, array_init) = def;
|
||||
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
|
||||
let state_space = ast::StateSpace::Global;
|
||||
ast::Variable { align, v_type, state_space, name, array_init }
|
||||
},
|
||||
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
|
||||
let (align, v_type, name, array_init) = def;
|
||||
ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
|
||||
let state_space = ast::StateSpace::Shared;
|
||||
ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
|
||||
},
|
||||
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
|
||||
let (align, t, name, arr_or_ptr) = var;
|
||||
let (v_type, array_init) = match arr_or_ptr {
|
||||
let (v_type, state_space, array_init) = match arr_or_ptr {
|
||||
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||
if space == ".global" {
|
||||
(ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
|
||||
(ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
|
||||
} else {
|
||||
(ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
|
||||
(ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
|
||||
}
|
||||
}
|
||||
ast::ArrayOrPointer::Pointer => {
|
||||
|
@ -611,38 +619,38 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
|
|||
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
|
||||
}
|
||||
if space == ".global" {
|
||||
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new())
|
||||
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new())
|
||||
} else {
|
||||
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new())
|
||||
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new())
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(ast::Variable{ align, array_init, v_type, name })
|
||||
Ok(ast::Variable{ align, v_type, state_space, name, array_init })
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
||||
ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
|
||||
".param" <var:VariableScalar<LdStScalarType>> => {
|
||||
let (align, t, name) = var;
|
||||
let v_type = ast::VariableParamType::Scalar(t);
|
||||
let v_type = ast::Type::Scalar(t);
|
||||
(align, Vec::new(), v_type, name)
|
||||
},
|
||||
".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
|
||||
let (align, t, name, arr_or_ptr) = var;
|
||||
let (v_type, array_init) = match arr_or_ptr {
|
||||
ast::ArrayOrPointer::Array { dimensions, init } => {
|
||||
(ast::VariableParamType::Array(t, dimensions), init)
|
||||
(ast::Type::Array(t, dimensions), init)
|
||||
}
|
||||
ast::ArrayOrPointer::Pointer => {
|
||||
(ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new())
|
||||
(ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new())
|
||||
}
|
||||
};
|
||||
(align, array_init, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||
ParamDeclaration: (Option<u32>, ast::Type, &'input str) = {
|
||||
<var:ParamVariable> =>? {
|
||||
let (align, array_init, v_type, name) = var;
|
||||
if array_init.len() > 0 {
|
||||
|
@ -653,15 +661,15 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
|
|||
}
|
||||
}
|
||||
|
||||
GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
|
||||
GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = {
|
||||
<scalar:VariableScalar<SizedScalarType>> => {
|
||||
let (align, t, name) = scalar;
|
||||
let v_type = ast::VariableGlobalType::Scalar(t);
|
||||
let v_type = ast::Type::Scalar(t);
|
||||
(align, v_type, name, Vec::new())
|
||||
},
|
||||
<var:VariableVector<SizedScalarType>> => {
|
||||
let (align, v_len, t, name) = var;
|
||||
let v_type = ast::VariableGlobalType::Vector(t, v_len);
|
||||
let v_type = ast::Type::Vector(t, v_len);
|
||||
(align, v_type, name, Vec::new())
|
||||
},
|
||||
}
|
||||
|
|
|
@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
let mut extern_shared_decls = HashMap::new();
|
||||
for dir in module.iter() {
|
||||
match dir {
|
||||
Directive::Variable(var) => {
|
||||
if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
|
||||
var.v_type
|
||||
{
|
||||
extern_shared_decls.insert(var.name, p_type);
|
||||
}
|
||||
Directive::Variable(ast::Variable {
|
||||
v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name,
|
||||
..
|
||||
}) => {
|
||||
extern_shared_decls.insert(*name, p_type.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
let shared_id_param = new_id();
|
||||
spirv_decl.input.push({
|
||||
ast::Variable {
|
||||
name: shared_id_param,
|
||||
align: None,
|
||||
v_type: ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
||||
ast::PointerType::Scalar(ast::ScalarType::B8),
|
||||
ast::LdStateSpace::Shared,
|
||||
),
|
||||
state_space: ast::StateSpace::Param,
|
||||
array_init: Vec::new(),
|
||||
name: shared_id_param,
|
||||
}
|
||||
});
|
||||
spirv_decl.uses_shared_mem = true;
|
||||
let shared_var_id = new_id();
|
||||
let shared_var = ExpandedStatement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: shared_var_id,
|
||||
array_init: Vec::new(),
|
||||
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||
ast::ScalarType::B8,
|
||||
align: None,
|
||||
v_type: ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::B8),
|
||||
ast::LdStateSpace::Shared,
|
||||
)),
|
||||
),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
array_init: Vec::new(),
|
||||
});
|
||||
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
|
||||
arg: ast::Arg2St {
|
||||
|
@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
|||
fn replace_uses_of_shared_memory<'a>(
|
||||
result: &mut Vec<ExpandedStatement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
|
||||
extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
|
||||
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
|
||||
shared_id_param: spirv::Word,
|
||||
shared_var_id: spirv::Word,
|
||||
|
@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>(
|
|||
// because there's simply no way to pass shared ptr
|
||||
// without converting it to .b64 first
|
||||
if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
|
||||
call.param_list
|
||||
.push((shared_id_param, ast::FnArgumentType::Shared));
|
||||
call.param_list.push((
|
||||
shared_id_param,
|
||||
ast::Type::Scalar(ast::ScalarType::B8),
|
||||
ast::StateSpace::Shared,
|
||||
));
|
||||
}
|
||||
result.push(Statement::Call(call))
|
||||
}
|
||||
statement => {
|
||||
let new_statement = statement.map_id(&mut |id, _| {
|
||||
if let Some(typ) = extern_shared_decls.get(&id) {
|
||||
if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
|
||||
if *typ == ast::ScalarType::B8 {
|
||||
return shared_var_id;
|
||||
}
|
||||
|
@ -1067,7 +1073,7 @@ fn emit_function_header<'a>(
|
|||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
defined_globals: &GlobalStringIdResolver<'a>,
|
||||
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
|
||||
synthetic_globals: &[ast::Variable<spirv::Word>],
|
||||
func_decl: &SpirvMethodDecl<'a>,
|
||||
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
|
||||
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
|
||||
|
@ -1204,9 +1210,9 @@ fn translate_directive<'input>(
|
|||
|
||||
fn translate_variable<'a>(
|
||||
id_defs: &mut GlobalStringIdResolver<'a>,
|
||||
var: ast::Variable<ast::VariableType, &'a str>,
|
||||
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
|
||||
let (space, var_type) = var.v_type.to_type();
|
||||
var: ast::Variable<&'a str>,
|
||||
) -> Result<ast::Variable<spirv::Word>, TranslateError> {
|
||||
let (space, var_type) = (var.state_space, var.v_type.clone());
|
||||
let mut is_variable = false;
|
||||
let var_type = match space {
|
||||
ast::StateSpace::Reg => {
|
||||
|
@ -1226,10 +1232,12 @@ fn translate_variable<'a>(
|
|||
}
|
||||
}
|
||||
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
|
||||
ast::StateSpace::Generic => todo!(),
|
||||
};
|
||||
Ok(ast::Variable {
|
||||
align: var.align,
|
||||
v_type: var.v_type,
|
||||
state_space: var.state_space,
|
||||
name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
|
||||
array_init: var.array_init,
|
||||
})
|
||||
|
@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>(
|
|||
false,
|
||||
),
|
||||
v_type: a.v_type.clone(),
|
||||
state_space: a.state_space,
|
||||
align: a.align,
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
|
@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>(
|
|||
args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
|
||||
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
|
||||
args.map(|a| {
|
||||
let is_variable = match a.v_type {
|
||||
ast::FnArgumentType::Reg(_) => true,
|
||||
_ => false,
|
||||
};
|
||||
let var_type = a.v_type.to_func_type();
|
||||
let is_variable = a.state_space == ast::StateSpace::Reg;
|
||||
Ok(ast::FnArgument {
|
||||
name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
|
||||
name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
|
||||
v_type: a.v_type.clone(),
|
||||
state_space: a.state_space,
|
||||
align: a.align,
|
||||
array_init: Vec::new(),
|
||||
})
|
||||
|
@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>(
|
|||
sorted_statements: Vec<ExpandedStatement>,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> (
|
||||
Vec<ExpandedStatement>,
|
||||
Vec<ast::Variable<ast::VariableType, spirv::Word>>,
|
||||
) {
|
||||
) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
|
||||
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||
let mut global = Vec::new();
|
||||
for statement in sorted_statements {
|
||||
|
@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>(
|
|||
var
|
||||
@
|
||||
ast::Variable {
|
||||
v_type: ast::VariableType::Shared(_),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
..
|
||||
},
|
||||
)
|
||||
|
@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>(
|
|||
var
|
||||
@
|
||||
ast::Variable {
|
||||
v_type: ast::VariableType::Global(_),
|
||||
state_space: ast::StateSpace::Global,
|
||||
..
|
||||
},
|
||||
) => global.push(var),
|
||||
|
@ -1592,10 +1595,10 @@ fn convert_to_typed_statements(
|
|||
let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
|
||||
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
|
||||
.into_iter()
|
||||
.partition(|(_, arg_type)| arg_type.is_param());
|
||||
.partition(|(_, _, space)| *space == ast::StateSpace::Param);
|
||||
let normalized_input_args = out_params
|
||||
.into_iter()
|
||||
.map(|(id, typ)| (ast::Operand::Reg(id), typ))
|
||||
.map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space))
|
||||
.chain(in_args.into_iter())
|
||||
.collect();
|
||||
let resolved_call = ResolvedCall {
|
||||
|
@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call(
|
|||
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
|
||||
vec![ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
|
||||
v_type: ast::Type::Scalar(scalar_typ),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
|
@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call(
|
|||
vec![
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
|
||||
typ, ptr_space,
|
||||
)),
|
||||
v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
|
||||
v_type: ast::Type::Scalar(scalar_typ),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
|
@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call(
|
|||
Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
func: fn_id,
|
||||
ret_params: vec![(
|
||||
arg.dst,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
|
||||
)],
|
||||
ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
|
||||
param_list: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
|
||||
ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
|
||||
ast::Type::Scalar(scalar_typ),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
|
@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call(
|
|||
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
|
||||
vec![ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
|
@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call(
|
|||
vec![
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
|
||||
ast::ScalarType::U32,
|
||||
)),
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
|
||||
ast::ScalarType::U32,
|
||||
)),
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
|
@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call(
|
|||
Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
func: fn_id,
|
||||
ret_params: vec![(
|
||||
arg.dst,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
)],
|
||||
ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
||||
param_list: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src3,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
|
@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call(
|
|||
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
|
||||
vec![ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
}],
|
||||
|
@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call(
|
|||
vec![
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
v_type: ast::Type::Scalar(typ.into()),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
|
||||
ast::ScalarType::U32,
|
||||
)),
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
ast::FnArgument {
|
||||
align: None,
|
||||
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
|
||||
ast::ScalarType::U32,
|
||||
)),
|
||||
v_type: ast::Type::Scalar(ast::ScalarType::U32),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: id_defs.new_non_variable(None),
|
||||
array_init: Vec::new(),
|
||||
},
|
||||
|
@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call(
|
|||
Statement::Call(ResolvedCall {
|
||||
uniform: false,
|
||||
func: fn_id,
|
||||
ret_params: vec![(
|
||||
arg.dst,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
)],
|
||||
ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
|
||||
param_list: vec![
|
||||
(
|
||||
arg.src1,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src2,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
|
||||
ast::Type::Scalar(typ.into()),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src3,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
(
|
||||
arg.src4,
|
||||
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
|
||||
ast::Type::Scalar(ast::ScalarType::U32),
|
||||
ast::StateSpace::Reg,
|
||||
),
|
||||
],
|
||||
})
|
||||
|
@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call(
|
|||
|
||||
fn to_resolved_fn_args<T>(
|
||||
params: Vec<T>,
|
||||
params_decl: &[ast::FnArgumentType],
|
||||
) -> Vec<(T, ast::FnArgumentType)> {
|
||||
params_decl: &[(ast::Type, ast::StateSpace)],
|
||||
) -> Vec<(T, ast::Type, ast::StateSpace)> {
|
||||
params
|
||||
.into_iter()
|
||||
.zip(params_decl.iter())
|
||||
.map(|(id, typ)| (id, typ.clone()))
|
||||
.map(|(id, (typ, space))| (id, typ.clone(), *space))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
|
@ -2096,50 +2101,38 @@ fn normalize_predicates(
|
|||
fn insert_mem_ssa_statements<'a, 'b>(
|
||||
func: Vec<TypedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
|
||||
_: &'a ast::MethodDecl<'b, spirv::Word>,
|
||||
fn_decl: &mut SpirvMethodDecl,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let is_func = match ast_fn_decl {
|
||||
ast::MethodDecl::Func(..) => true,
|
||||
ast::MethodDecl::Kernel { .. } => false,
|
||||
};
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for arg in fn_decl.output.iter() {
|
||||
match type_to_variable_type(&arg.v_type, is_func)? {
|
||||
Some(var_type) => {
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: var_type,
|
||||
name: arg.name,
|
||||
array_init: arg.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
None => return Err(error_unreachable()),
|
||||
}
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: arg.state_space,
|
||||
name: arg.name,
|
||||
array_init: arg.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
for spirv_arg in fn_decl.input.iter_mut() {
|
||||
match type_to_variable_type(&spirv_arg.v_type, is_func)? {
|
||||
Some(var_type) => {
|
||||
let typ = spirv_arg.v_type.clone();
|
||||
let new_id = id_def.new_non_variable(Some(typ.clone()));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: spirv_arg.align,
|
||||
v_type: var_type,
|
||||
name: spirv_arg.name,
|
||||
array_init: spirv_arg.array_init.clone(),
|
||||
}));
|
||||
result.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::Arg2St {
|
||||
src1: spirv_arg.name,
|
||||
src2: new_id,
|
||||
},
|
||||
typ,
|
||||
member_index: None,
|
||||
}));
|
||||
spirv_arg.name = new_id;
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
let typ = spirv_arg.v_type.clone();
|
||||
let new_id = id_def.new_non_variable(Some(typ.clone()));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: spirv_arg.align,
|
||||
v_type: spirv_arg.v_type.clone(),
|
||||
state_space: spirv_arg.state_space,
|
||||
name: spirv_arg.name,
|
||||
array_init: spirv_arg.array_init.clone(),
|
||||
}));
|
||||
result.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::Arg2St {
|
||||
src1: spirv_arg.name,
|
||||
src2: new_id,
|
||||
},
|
||||
typ,
|
||||
member_index: None,
|
||||
}));
|
||||
spirv_arg.name = new_id;
|
||||
}
|
||||
for s in func {
|
||||
match s {
|
||||
|
@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
fn type_to_variable_type(
|
||||
t: &ast::Type,
|
||||
is_func: bool,
|
||||
) -> Result<Option<ast::VariableType>, TranslateError> {
|
||||
Ok(match t {
|
||||
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
|
||||
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
|
||||
(*typ)
|
||||
.try_into()
|
||||
.map_err(|_| TranslateError::MismatchedType)?,
|
||||
*len,
|
||||
))),
|
||||
ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
|
||||
(*typ)
|
||||
.try_into()
|
||||
.map_err(|_| TranslateError::MismatchedType)?,
|
||||
len.clone(),
|
||||
))),
|
||||
ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
|
||||
if is_func {
|
||||
return Ok(None);
|
||||
}
|
||||
Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||
scalar_type
|
||||
.clone()
|
||||
.try_into()
|
||||
.map_err(|_| error_unreachable())?,
|
||||
(*space).try_into().map_err(|_| error_unreachable())?,
|
||||
)))
|
||||
}
|
||||
ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
|
||||
_ => return Err(error_unreachable()),
|
||||
})
|
||||
}
|
||||
|
||||
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
|
||||
fn visit(
|
||||
self,
|
||||
|
@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>(
|
|||
Statement::Variable(ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
state_space,
|
||||
name,
|
||||
array_init,
|
||||
}) => result.push(Statement::Variable(ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
state_space,
|
||||
name,
|
||||
array_init,
|
||||
})),
|
||||
|
@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl(
|
|||
fn get_function_type(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
|
||||
spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
|
||||
spirv_input: &[ast::Variable<spirv::Word>],
|
||||
spirv_output: &[ast::Variable<spirv::Word>],
|
||||
) -> (spirv::Word, spirv::Word) {
|
||||
map.get_or_add_fn(
|
||||
builder,
|
||||
|
@ -2822,8 +2782,8 @@ fn emit_function_body_ops(
|
|||
Statement::Label(_) => (),
|
||||
Statement::Call(call) => {
|
||||
let (result_type, result_id) = match &*call.ret_params {
|
||||
[(id, typ)] => (
|
||||
map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
|
||||
[(id, typ, _)] => (
|
||||
map.get_or_add(builder, SpirvType::from(typ.clone())),
|
||||
Some(*id),
|
||||
),
|
||||
[] => (map.void(), None),
|
||||
|
@ -2832,7 +2792,7 @@ fn emit_function_body_ops(
|
|||
let arg_list = call
|
||||
.param_list
|
||||
.iter()
|
||||
.map(|(id, _)| *id)
|
||||
.map(|(id, _, _)| *id)
|
||||
.collect::<Vec<_>>();
|
||||
builder.function_call(result_type, result_id, call.func, arg_list)?;
|
||||
}
|
||||
|
@ -3602,14 +3562,16 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
|
|||
fn emit_variable(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
var: &ast::Variable<ast::VariableType, spirv::Word>,
|
||||
var: &ast::Variable<spirv::Word>,
|
||||
) -> Result<(), TranslateError> {
|
||||
let (must_init, st_class) = match var.v_type {
|
||||
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
|
||||
let (must_init, st_class) = match var.state_space {
|
||||
ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
|
||||
(false, spirv::StorageClass::Function)
|
||||
}
|
||||
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
|
||||
ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
|
||||
ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
|
||||
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
|
||||
ast::StateSpace::Const => todo!(),
|
||||
ast::StateSpace::Generic => todo!(),
|
||||
};
|
||||
let initalizer = if var.array_init.len() > 0 {
|
||||
Some(map.get_or_add_constant(
|
||||
|
@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>(
|
|||
ast::Statement::Variable(var) => {
|
||||
let mut var_type = ast::Type::from(var.var.v_type.clone());
|
||||
let mut is_variable = false;
|
||||
var_type = match var.var.v_type {
|
||||
ast::VariableType::Reg(_) => {
|
||||
var_type = match var.var.state_space {
|
||||
ast::StateSpace::Reg => {
|
||||
is_variable = true;
|
||||
var_type
|
||||
}
|
||||
ast::VariableType::Shared(_) => {
|
||||
ast::StateSpace::Shared => {
|
||||
// If it's a pointer it will be translated to a method parameter later
|
||||
if let ast::Type::Pointer(..) = var_type {
|
||||
is_variable = true;
|
||||
|
@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>(
|
|||
var_type.param_pointer_to(ast::LdStateSpace::Shared)?
|
||||
}
|
||||
}
|
||||
ast::VariableType::Global(_) => {
|
||||
var_type.param_pointer_to(ast::LdStateSpace::Global)?
|
||||
}
|
||||
ast::VariableType::Param(_) => {
|
||||
var_type.param_pointer_to(ast::LdStateSpace::Param)?
|
||||
}
|
||||
ast::VariableType::Local(_) => {
|
||||
var_type.param_pointer_to(ast::LdStateSpace::Local)?
|
||||
}
|
||||
ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
|
||||
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
|
||||
ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
|
||||
ast::StateSpace::Const => todo!(),
|
||||
ast::StateSpace::Generic => todo!(),
|
||||
};
|
||||
match var.count {
|
||||
Some(count) => {
|
||||
|
@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>(
|
|||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init.clone(),
|
||||
}))
|
||||
|
@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>(
|
|||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init,
|
||||
}));
|
||||
|
@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>(
|
|||
align: None,
|
||||
name: new_id,
|
||||
array_init: Vec::new(),
|
||||
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||
ast::ScalarType::U8,
|
||||
v_type: ast::Type::Pointer(
|
||||
ast::PointerType::Scalar(ast::ScalarType::U8),
|
||||
ast::LdStateSpace::Global,
|
||||
)),
|
||||
),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
}));
|
||||
remapped_ids.insert(reg, new_id);
|
||||
}
|
||||
|
@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> {
|
|||
}
|
||||
|
||||
pub struct FnDecl {
|
||||
ret_vals: Vec<ast::FnArgumentType>,
|
||||
params: Vec<ast::FnArgumentType>,
|
||||
ret_vals: Vec<(ast::Type, ast::StateSpace)>,
|
||||
params: Vec<(ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
||||
impl<'a> GlobalStringIdResolver<'a> {
|
||||
|
@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
self.fns.insert(
|
||||
name_id,
|
||||
FnDecl {
|
||||
ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
|
||||
params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
|
||||
ret_vals: ret_params_ids
|
||||
.iter()
|
||||
.map(|p| (p.v_type.clone(), p.state_space))
|
||||
.collect(),
|
||||
params: params_ids
|
||||
.iter()
|
||||
.map(|p| (p.v_type.clone(), p.state_space))
|
||||
.collect(),
|
||||
},
|
||||
);
|
||||
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
|
||||
|
@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> {
|
|||
|
||||
enum Statement<I, P: ast::ArgParams> {
|
||||
Label(u32),
|
||||
Variable(ast::Variable<ast::VariableType, P::Id>),
|
||||
Variable(ast::Variable<P::Id>),
|
||||
Instruction(I),
|
||||
// SPIR-V compatible replacement for PTX predicates
|
||||
Conditional(BrachCondition),
|
||||
|
@ -5352,16 +5319,17 @@ impl ExpandedStatement {
|
|||
Statement::StoreVar(details)
|
||||
}
|
||||
Statement::Call(mut call) => {
|
||||
for (id, typ) in call.ret_params.iter_mut() {
|
||||
let is_dst = match typ {
|
||||
ast::FnArgumentType::Reg(_) => true,
|
||||
ast::FnArgumentType::Param(_) => false,
|
||||
ast::FnArgumentType::Shared => false,
|
||||
for (id, _, space) in call.ret_params.iter_mut() {
|
||||
let is_dst = match space {
|
||||
ast::StateSpace::Reg => true,
|
||||
ast::StateSpace::Param => false,
|
||||
ast::StateSpace::Shared => false,
|
||||
_ => todo!(),
|
||||
};
|
||||
*id = f(*id, is_dst);
|
||||
}
|
||||
call.func = f(call.func, false);
|
||||
for (id, _) in call.param_list.iter_mut() {
|
||||
for (id, _, _) in call.param_list.iter_mut() {
|
||||
*id = f(*id, false);
|
||||
}
|
||||
Statement::Call(call)
|
||||
|
@ -5502,9 +5470,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
|
|||
|
||||
struct ResolvedCall<P: ast::ArgParams> {
|
||||
pub uniform: bool,
|
||||
pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
|
||||
pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>,
|
||||
pub func: P::Id,
|
||||
pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
|
||||
pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
|
||||
}
|
||||
|
||||
impl<T: ast::ArgParams> ResolvedCall<T> {
|
||||
|
@ -5526,16 +5494,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
|
|||
let ret_params = self
|
||||
.ret_params
|
||||
.into_iter()
|
||||
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
|
||||
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
||||
let new_id = visitor.id(
|
||||
ArgumentDescriptor {
|
||||
op: id,
|
||||
is_dst: !typ.is_param(),
|
||||
sema: typ.semantics(),
|
||||
is_dst: space != ast::StateSpace::Param,
|
||||
sema: space.semantics(),
|
||||
},
|
||||
Some(&typ.to_func_type()),
|
||||
Some(&typ),
|
||||
)?;
|
||||
Ok((new_id, typ))
|
||||
Ok((new_id, typ, space))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let func = visitor.id(
|
||||
|
@ -5549,16 +5517,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
|
|||
let param_list = self
|
||||
.param_list
|
||||
.into_iter()
|
||||
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
|
||||
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
|
||||
let new_id = visitor.operand(
|
||||
ArgumentDescriptor {
|
||||
op: id,
|
||||
is_dst: false,
|
||||
sema: typ.semantics(),
|
||||
sema: space.semantics(),
|
||||
},
|
||||
&typ.to_func_type(),
|
||||
&typ,
|
||||
)?;
|
||||
Ok((new_id, typ))
|
||||
Ok((new_id, typ, space))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(ResolvedCall {
|
||||
|
@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams {
|
|||
}
|
||||
|
||||
enum Directive<'input> {
|
||||
Variable(ast::Variable<ast::VariableType, spirv::Word>),
|
||||
Variable(ast::Variable<spirv::Word>),
|
||||
Method(Function<'input>),
|
||||
}
|
||||
|
||||
struct Function<'input> {
|
||||
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
|
||||
pub spirv_decl: SpirvMethodDecl<'input>,
|
||||
pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
|
||||
pub globals: Vec<ast::Variable<spirv::Word>>,
|
||||
pub body: Option<Vec<ExpandedStatement>>,
|
||||
import_as: Option<String>,
|
||||
tuning: Vec<ast::TuningDirective>,
|
||||
|
@ -7300,16 +7268,6 @@ impl ast::LdStateSpace {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ast::FnArgumentType> for ast::VariableType {
|
||||
fn from(t: ast::FnArgumentType) -> Self {
|
||||
match t {
|
||||
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
|
||||
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
|
||||
ast::FnArgumentType::Shared => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::Operand<T> {
|
||||
fn underlying(&self) -> Option<&T> {
|
||||
match self {
|
||||
|
@ -7362,12 +7320,13 @@ impl ast::AtomSemantics {
|
|||
}
|
||||
}
|
||||
|
||||
impl ast::FnArgumentType {
|
||||
fn semantics(&self) -> ArgumentSemantics {
|
||||
impl ast::StateSpace {
|
||||
fn semantics(self) -> ArgumentSemantics {
|
||||
match self {
|
||||
ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
|
||||
ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
|
||||
ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
|
||||
ast::StateSpace::Reg => ArgumentSemantics::Default,
|
||||
ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
|
||||
ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer,
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> {
|
|||
}
|
||||
|
||||
struct SpirvMethodDecl<'input> {
|
||||
input: Vec<ast::Variable<ast::Type, spirv::Word>>,
|
||||
output: Vec<ast::Variable<ast::Type, spirv::Word>>,
|
||||
input: Vec<ast::Variable<spirv::Word>>,
|
||||
output: Vec<ast::Variable<spirv::Word>>,
|
||||
name: MethodName<'input>,
|
||||
uses_shared_mem: bool,
|
||||
}
|
||||
|
@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> {
|
|||
ast::MethodDecl::Kernel { in_args, .. } => {
|
||||
let spirv_input = in_args
|
||||
.iter()
|
||||
.map(|var| {
|
||||
let v_type = match &var.v_type {
|
||||
ast::KernelArgumentType::Normal(t) => {
|
||||
ast::FnArgumentType::Param(t.clone())
|
||||
}
|
||||
ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
|
||||
};
|
||||
ast::Variable {
|
||||
name: var.name,
|
||||
align: var.align,
|
||||
v_type: v_type.to_kernel_type(),
|
||||
array_init: var.array_init.clone(),
|
||||
}
|
||||
.map(|var| ast::Variable {
|
||||
name: var.name,
|
||||
align: var.align,
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
array_init: var.array_init.clone(),
|
||||
})
|
||||
.collect();
|
||||
(spirv_input, Vec::new())
|
||||
}
|
||||
ast::MethodDecl::Func(out_args, _, in_args) => {
|
||||
let (param_output, non_param_output): (Vec<_>, Vec<_>) =
|
||||
out_args.iter().partition(|var| var.v_type.is_param());
|
||||
let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
|
||||
.iter()
|
||||
.partition(|var| var.state_space == ast::StateSpace::Param);
|
||||
let spirv_output = non_param_output
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.map(|var| ast::Variable {
|
||||
name: var.name,
|
||||
align: var.align,
|
||||
v_type: var.v_type.to_func_type(),
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
array_init: var.array_init.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> {
|
|||
.map(|var| ast::Variable {
|
||||
name: var.name,
|
||||
align: var.align,
|
||||
v_type: var.v_type.to_func_type(),
|
||||
v_type: var.v_type.clone(),
|
||||
state_space: var.state_space,
|
||||
array_init: var.array_init.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
|
Loading…
Add table
Reference in a new issue