mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Be more precise about types admitted in register definitions and method arguments
This commit is contained in:
parent
76afbeba63
commit
1238796dfd
7 changed files with 647 additions and 351 deletions
256
ptx/src/ast.rs
256
ptx/src/ast.rs
|
@ -12,9 +12,117 @@ quick_error! {
|
|||
SyntaxError {}
|
||||
NonF32Ftz {}
|
||||
WrongArrayType {}
|
||||
WrongVectorElement {}
|
||||
MultiArrayVariable {}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! sub_scalar_type {
|
||||
($name:ident { $($variant:ident),+ $(,)? }) => {
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub enum $name {
|
||||
$(
|
||||
$variant,
|
||||
)+
|
||||
}
|
||||
|
||||
impl From<$name> for ScalarType {
|
||||
fn from(t: $name) -> ScalarType {
|
||||
match t {
|
||||
$(
|
||||
$name::$variant => ScalarType::$variant,
|
||||
)+
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! sub_type {
|
||||
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub enum $type_name {
|
||||
$(
|
||||
$variant ($($field_type),+),
|
||||
)+
|
||||
}
|
||||
|
||||
impl From<$type_name> for Type {
|
||||
#[allow(non_snake_case)]
|
||||
fn from(t: $type_name) -> Type {
|
||||
match t {
|
||||
$(
|
||||
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
|
||||
)+
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
sub_type! {
|
||||
VariableRegType {
|
||||
Scalar(ScalarType),
|
||||
Vector(SizedScalarType, u8),
|
||||
}
|
||||
}
|
||||
|
||||
sub_type! {
|
||||
VariableLocalType {
|
||||
Scalar(SizedScalarType),
|
||||
Vector(SizedScalarType, u8),
|
||||
Array(SizedScalarType, u32),
|
||||
}
|
||||
}
|
||||
|
||||
// For some weird reson this is illegal:
|
||||
// .param .f16x2 foobar;
|
||||
// but this is legal:
|
||||
// .param .f16x2 foobar[1];
|
||||
sub_type! {
|
||||
VariableParamType {
|
||||
Scalar(ParamScalarType),
|
||||
Array(SizedScalarType, u32),
|
||||
}
|
||||
}
|
||||
|
||||
sub_scalar_type!(SizedScalarType {
|
||||
B8,
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S8,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
F16,
|
||||
F16x2,
|
||||
F32,
|
||||
F64,
|
||||
});
|
||||
|
||||
sub_scalar_type!(ParamScalarType {
|
||||
B8,
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S8,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
});
|
||||
|
||||
pub trait UnwrapWithVec<E, To> {
|
||||
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
|
||||
}
|
||||
|
@ -56,6 +164,9 @@ pub enum MethodDecl<'a, P: ArgParams> {
|
|||
Kernel(&'a str, Vec<KernelArgument<P>>),
|
||||
}
|
||||
|
||||
pub type FnArgument<P: ArgParams> = Variable<FnArgumentType, P>;
|
||||
pub type KernelArgument<P: ArgParams> = Variable<VariableParamType, P>;
|
||||
|
||||
pub struct Function<'a, P: ArgParams, S> {
|
||||
pub func_directive: MethodDecl<'a, P>,
|
||||
pub body: Option<Vec<S>>,
|
||||
|
@ -63,43 +174,28 @@ pub struct Function<'a, P: ArgParams, S> {
|
|||
|
||||
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
|
||||
|
||||
pub struct FnArgument<P: ArgParams> {
|
||||
pub base: KernelArgument<P>,
|
||||
pub state_space: FnArgStateSpace,
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub enum FnArgumentType {
|
||||
Reg(VariableRegType),
|
||||
Param(VariableParamType),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum FnArgStateSpace {
|
||||
Reg,
|
||||
Param,
|
||||
}
|
||||
|
||||
#[derive(Default, Copy, Clone)]
|
||||
pub struct KernelArgument<P: ArgParams> {
|
||||
pub name: P::ID,
|
||||
pub a_type: ScalarType,
|
||||
// TODO: turn length into part of type definition
|
||||
pub length: u32,
|
||||
impl From<FnArgumentType> for Type {
|
||||
fn from(t: FnArgumentType) -> Self {
|
||||
match t {
|
||||
FnArgumentType::Reg(x) => x.into(),
|
||||
FnArgumentType::Param(x) => x.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum Type {
|
||||
Scalar(ScalarType),
|
||||
ExtendedScalar(ExtendedScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
Array(ScalarType, u32),
|
||||
}
|
||||
|
||||
impl From<FloatType> for Type {
|
||||
fn from(t: FloatType) -> Self {
|
||||
match t {
|
||||
FloatType::F16 => Type::Scalar(ScalarType::F16),
|
||||
FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2),
|
||||
FloatType::F32 => Type::Scalar(ScalarType::F32),
|
||||
FloatType::F64 => Type::Scalar(ScalarType::F64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum ScalarType {
|
||||
B8,
|
||||
|
@ -117,25 +213,11 @@ pub enum ScalarType {
|
|||
F16,
|
||||
F32,
|
||||
F64,
|
||||
F16x2,
|
||||
Pred,
|
||||
}
|
||||
|
||||
impl From<IntType> for ScalarType {
|
||||
fn from(t: IntType) -> Self {
|
||||
match t {
|
||||
IntType::S8 => ScalarType::S8,
|
||||
IntType::S16 => ScalarType::S16,
|
||||
IntType::S32 => ScalarType::S32,
|
||||
IntType::S64 => ScalarType::S64,
|
||||
IntType::U8 => ScalarType::U8,
|
||||
IntType::U16 => ScalarType::U16,
|
||||
IntType::U32 => ScalarType::U32,
|
||||
IntType::U64 => ScalarType::U64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum IntType {
|
||||
sub_scalar_type!(IntType {
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
|
@ -143,8 +225,8 @@ pub enum IntType {
|
|||
S8,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
}
|
||||
S64
|
||||
});
|
||||
|
||||
impl IntType {
|
||||
pub fn is_signed(self) -> bool {
|
||||
|
@ -168,19 +250,12 @@ impl IntType {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum FloatType {
|
||||
sub_scalar_type!(FloatType {
|
||||
F16,
|
||||
F16x2,
|
||||
F32,
|
||||
F64,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub enum ExtendedScalarType {
|
||||
F16x2,
|
||||
Pred,
|
||||
}
|
||||
F64
|
||||
});
|
||||
|
||||
impl Default for ScalarType {
|
||||
fn default() -> Self {
|
||||
|
@ -190,19 +265,39 @@ impl Default for ScalarType {
|
|||
|
||||
pub enum Statement<P: ArgParams> {
|
||||
Label(P::ID),
|
||||
Variable(Variable<P>),
|
||||
Variable(MultiVariable<P>),
|
||||
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
|
||||
Block(Vec<Statement<P>>),
|
||||
}
|
||||
|
||||
pub struct Variable<P: ArgParams> {
|
||||
pub space: StateSpace,
|
||||
pub align: Option<u32>,
|
||||
pub v_type: Type,
|
||||
pub name: P::ID,
|
||||
pub struct MultiVariable<P: ArgParams> {
|
||||
pub var: Variable<VariableType, P>,
|
||||
pub count: Option<u32>,
|
||||
}
|
||||
|
||||
pub struct Variable<T, P: ArgParams> {
|
||||
pub align: Option<u32>,
|
||||
pub v_type: T,
|
||||
pub name: P::ID,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Copy, Clone)]
|
||||
pub enum VariableType {
|
||||
Reg(VariableRegType),
|
||||
Local(VariableLocalType),
|
||||
Param(VariableParamType),
|
||||
}
|
||||
|
||||
impl From<VariableType> for Type {
|
||||
fn from(t: VariableType) -> Self {
|
||||
match t {
|
||||
VariableType::Reg(t) => t.into(),
|
||||
VariableType::Local(t) => t.into(),
|
||||
VariableType::Param(t) => t.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
|
@ -322,7 +417,7 @@ pub enum CallOperand<ID> {
|
|||
|
||||
pub enum MovOperand<ID> {
|
||||
Op(Operand<ID>),
|
||||
Vec(String, String),
|
||||
Vec(ID, u8),
|
||||
}
|
||||
|
||||
pub enum VectorPrefix {
|
||||
|
@ -334,7 +429,7 @@ pub struct LdData {
|
|||
pub qualifier: LdStQualifier,
|
||||
pub state_space: LdStateSpace,
|
||||
pub caching: LdCacheOperator,
|
||||
pub vector: Option<VectorPrefix>,
|
||||
pub vector: Option<u8>,
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
|
@ -376,6 +471,37 @@ pub struct MovData {
|
|||
pub typ: Type,
|
||||
}
|
||||
|
||||
sub_scalar_type!(MovScalarType {
|
||||
B16,
|
||||
B32,
|
||||
B64,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
S16,
|
||||
S32,
|
||||
S64,
|
||||
F32,
|
||||
F64,
|
||||
Pred,
|
||||
});
|
||||
|
||||
enum MovType {
|
||||
Scalar(MovScalarType),
|
||||
Vector(MovScalarType, u8),
|
||||
Array(MovScalarType, u32),
|
||||
}
|
||||
|
||||
impl From<MovType> for Type {
|
||||
fn from(t: MovType) -> Self {
|
||||
match t {
|
||||
MovType::Scalar(t) => Type::Scalar(t.into()),
|
||||
MovType::Vector(t, len) => Type::Vector(t.into(), len),
|
||||
MovType::Array(t, len) => Type::Array(t.into(), len),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum MulDetails {
|
||||
Int(MulIntDesc),
|
||||
Float(MulFloatDesc),
|
||||
|
@ -587,7 +713,7 @@ pub struct StData {
|
|||
pub qualifier: LdStQualifier,
|
||||
pub state_space: StStateSpace,
|
||||
pub caching: StCacheOperator,
|
||||
pub vector: Option<VectorPrefix>,
|
||||
pub vector: Option<u8>,
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ match {
|
|||
"@",
|
||||
"[", "]",
|
||||
"{", "}",
|
||||
"<", ">",
|
||||
"|",
|
||||
".acquire",
|
||||
".address_size",
|
||||
|
@ -133,8 +134,6 @@ match {
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
|
||||
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
|
||||
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
|
||||
} else {
|
||||
r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID,
|
||||
}
|
||||
|
||||
ExtendedID : &'input str = {
|
||||
|
@ -214,7 +213,9 @@ LinkingDirective = {
|
|||
|
||||
MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
|
||||
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
|
||||
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
||||
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
|
||||
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
|
||||
}
|
||||
};
|
||||
|
||||
KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
|
||||
|
@ -225,32 +226,25 @@ FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
|
|||
"(" <args:Comma<FnInput>> ")" => args
|
||||
};
|
||||
|
||||
FnInput: ast::FnArgument<ast::ParsedArgParams<'input>> = {
|
||||
".reg" <_type:ScalarType> <name:ExtendedID> => {
|
||||
ast::FnArgument {
|
||||
base: ast::KernelArgument {a_type: _type, name: name, length: 1 },
|
||||
state_space: ast::FnArgStateSpace::Reg,
|
||||
}
|
||||
},
|
||||
<p:KernelInput> => {
|
||||
ast::FnArgument {
|
||||
base: p,
|
||||
state_space: ast::FnArgStateSpace::Param,
|
||||
}
|
||||
KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = {
|
||||
<v:ParamVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
ast::Variable{ align, v_type, name }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
KernelInput: ast::KernelArgument<ast::ParsedArgParams<'input>> = {
|
||||
".param" <_type:ScalarType> <name:ExtendedID> => {
|
||||
ast::KernelArgument {a_type: _type, name: name, length: 1 }
|
||||
FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = {
|
||||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Reg(v_type);
|
||||
ast::Variable{ align, v_type, name }
|
||||
},
|
||||
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => {
|
||||
let length = length.parse::<u32>();
|
||||
let length = length.unwrap_with(errors);
|
||||
ast::KernelArgument { a_type: a_type, name: name, length: length }
|
||||
<v:ParamVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Param(v_type);
|
||||
ast::Variable{ align, v_type, name }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
|
||||
"{" <s:Statement*> "}" => { Some(without_none(s)) },
|
||||
|
@ -267,22 +261,13 @@ StateSpaceSpecifier: ast::StateSpace = {
|
|||
".param" => ast::StateSpace::Param, // used to prepare function call
|
||||
};
|
||||
|
||||
|
||||
Type: ast::Type = {
|
||||
<t:ScalarType> => ast::Type::Scalar(t),
|
||||
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
|
||||
};
|
||||
|
||||
ScalarType: ast::ScalarType = {
|
||||
".f16" => ast::ScalarType::F16,
|
||||
".f16x2" => ast::ScalarType::F16x2,
|
||||
".pred" => ast::ScalarType::Pred,
|
||||
MemoryType
|
||||
};
|
||||
|
||||
ExtendedScalarType: ast::ExtendedScalarType = {
|
||||
".f16x2" => ast::ExtendedScalarType::F16x2,
|
||||
".pred" => ast::ExtendedScalarType::Pred,
|
||||
};
|
||||
|
||||
MemoryType: ast::ScalarType = {
|
||||
".b8" => ast::ScalarType::B8,
|
||||
".b16" => ast::ScalarType::B16,
|
||||
|
@ -303,7 +288,7 @@ MemoryType: ast::ScalarType = {
|
|||
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||
<l:Label> => Some(ast::Statement::Label(l)),
|
||||
DebugDirective => None,
|
||||
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
|
||||
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
|
||||
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
|
||||
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
|
||||
};
|
||||
|
@ -328,21 +313,109 @@ Align: u32 = {
|
|||
}
|
||||
};
|
||||
|
||||
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
|
||||
<s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => {
|
||||
let (name, count) = v;
|
||||
let t = match (t, arr) {
|
||||
(ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
|
||||
(t, Some(_)) => {
|
||||
errors.push(ast::PtxError::WrongArrayType);
|
||||
t
|
||||
},
|
||||
(t, None) => t,
|
||||
};
|
||||
ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||
MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = {
|
||||
<var:Variable> <count:VariableParam?> => ast::MultiVariable{<>}
|
||||
}
|
||||
|
||||
VariableParam: u32 = {
|
||||
"<" <n:Num> ">" => {
|
||||
let size = n.parse::<u32>();
|
||||
size.unwrap_with(errors)
|
||||
}
|
||||
}
|
||||
|
||||
Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
|
||||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Reg(v_type);
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
LocalVariable,
|
||||
<v:ParamVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Param(v_type);
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
};
|
||||
|
||||
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
||||
".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableRegType::Scalar(t);
|
||||
(align, v_type, name)
|
||||
},
|
||||
".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableRegType::Vector(t, v_len);
|
||||
(align, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
|
||||
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
|
||||
ast::Variable {align, v_type, name}
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableParamType::Scalar(t);
|
||||
(align, v_type, name)
|
||||
},
|
||||
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||
let v_type = ast::VariableParamType::Array(t, arr);
|
||||
(align, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
SizedScalarType: ast::SizedScalarType = {
|
||||
".b8" => ast::SizedScalarType::B8,
|
||||
".b16" => ast::SizedScalarType::B16,
|
||||
".b32" => ast::SizedScalarType::B32,
|
||||
".b64" => ast::SizedScalarType::B64,
|
||||
".u8" => ast::SizedScalarType::U8,
|
||||
".u16" => ast::SizedScalarType::U16,
|
||||
".u32" => ast::SizedScalarType::U32,
|
||||
".u64" => ast::SizedScalarType::U64,
|
||||
".s8" => ast::SizedScalarType::S8,
|
||||
".s16" => ast::SizedScalarType::S16,
|
||||
".s32" => ast::SizedScalarType::S32,
|
||||
".s64" => ast::SizedScalarType::S64,
|
||||
".f16" => ast::SizedScalarType::F16,
|
||||
".f16x2" => ast::SizedScalarType::F16x2,
|
||||
".f32" => ast::SizedScalarType::F32,
|
||||
".f64" => ast::SizedScalarType::F64,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
ParamScalarType: ast::ParamScalarType = {
|
||||
".b8" => ast::ParamScalarType::B8,
|
||||
".b16" => ast::ParamScalarType::B16,
|
||||
".b32" => ast::ParamScalarType::B32,
|
||||
".b64" => ast::ParamScalarType::B64,
|
||||
".u8" => ast::ParamScalarType::U8,
|
||||
".u16" => ast::ParamScalarType::U16,
|
||||
".u32" => ast::ParamScalarType::U32,
|
||||
".u64" => ast::ParamScalarType::U64,
|
||||
".s8" => ast::ParamScalarType::S8,
|
||||
".s16" => ast::ParamScalarType::S16,
|
||||
".s32" => ast::ParamScalarType::S32,
|
||||
".s64" => ast::ParamScalarType::S64,
|
||||
".f16" => ast::ParamScalarType::F16,
|
||||
".f32" => ast::ParamScalarType::F32,
|
||||
".f64" => ast::ParamScalarType::F64,
|
||||
}
|
||||
|
||||
ArraySpecifier: u32 = {
|
||||
"[" <n:Num> "]" => {
|
||||
let size = n.parse::<u32>();
|
||||
|
@ -350,20 +423,6 @@ ArraySpecifier: u32 = {
|
|||
}
|
||||
};
|
||||
|
||||
VariableName: (&'input str, Option<u32>) = {
|
||||
<id:ExtendedID> => (id, None),
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
|
||||
<id:ParametrizedID> => {
|
||||
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
|
||||
let count = id[left_angle+1..id.len()-1].parse::<u32>();
|
||||
let count = match count {
|
||||
Ok(c) => Some(c),
|
||||
Err(e) => { errors.push(e.into()); None },
|
||||
};
|
||||
(&id[0..left_angle], count)
|
||||
}
|
||||
};
|
||||
|
||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
InstLd,
|
||||
InstMov,
|
||||
|
@ -445,7 +504,7 @@ MovType: ast::Type = {
|
|||
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
|
||||
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
|
||||
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
|
||||
".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)
|
||||
".pred" => ast::Type::Scalar(ast::ScalarType::Pred)
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
|
||||
|
@ -934,7 +993,17 @@ MovOperand: ast::MovOperand<&'input str> = {
|
|||
<o:Operand> => ast::MovOperand::Op(o),
|
||||
<o:VectorOperand> => {
|
||||
let (pref, suf) = o;
|
||||
ast::MovOperand::Vec(pref.to_string(), suf.to_string())
|
||||
let suf_idx = match suf {
|
||||
"x" | "r" => 0,
|
||||
"y" | "g" => 1,
|
||||
"z" | "b" => 2,
|
||||
"w" | "a" => 3,
|
||||
_ => {
|
||||
errors.push(ast::PtxError::WrongVectorElement);
|
||||
0
|
||||
}
|
||||
};
|
||||
ast::MovOperand::Vec(pref, suf_idx)
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -980,9 +1049,9 @@ OptionalDst: &'input str = {
|
|||
"|" <dst2:ExtendedID> => dst2
|
||||
}
|
||||
|
||||
VectorPrefix: ast::VectorPrefix = {
|
||||
".v2" => ast::VectorPrefix::V2,
|
||||
".v4" => ast::VectorPrefix::V4
|
||||
VectorPrefix: u8 = {
|
||||
".v2" => 2,
|
||||
".v4" => 4
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file
|
||||
|
|
|
@ -8,16 +8,16 @@ fn parse_and_assert(s: &str) {
|
|||
assert!(errors.len() == 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty() {
|
||||
parse_and_assert(".version 6.5 .target sm_30, debug");
|
||||
fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
|
||||
crate::to_spirv(ast)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn vectorAdd_kernel64_ptx() {
|
||||
let vector_add = include_str!("vectorAdd_kernel64.ptx");
|
||||
parse_and_assert(vector_add);
|
||||
fn empty() {
|
||||
parse_and_assert(".version 6.5 .target sm_30, debug");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -28,8 +28,14 @@ fn operands_ptx() {
|
|||
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn _Z9vectorAddPKfS0_Pfi_ptx() {
|
||||
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
||||
parse_and_assert(vector_add);
|
||||
fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> {
|
||||
let vector_add = include_str!("vectorAdd_kernel64.ptx");
|
||||
compile_and_assert(vector_add)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(non_snake_case)]
|
||||
fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> {
|
||||
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
||||
compile_and_assert(vector_add)
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
|
|||
test_ptx!(block, [1u64], [2u64]);
|
||||
test_ptx!(local_align, [1u64], [1u64]);
|
||||
test_ptx!(call, [1u64], [2u64]);
|
||||
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
|
44
ptx/src/test/spirv_run/vector.ptx
Normal file
44
ptx/src/test/spirv_run/vector.ptx
Normal file
|
@ -0,0 +1,44 @@
|
|||
// Excersise as many features of vector types as possible
|
||||
|
||||
.version 6.5
|
||||
.target sm_53
|
||||
.address_size 64
|
||||
|
||||
.func (.reg .v2 .u32 output) impl(
|
||||
.reg .v2 .u32 input
|
||||
)
|
||||
{
|
||||
.reg .v2 .u32 temp_v;
|
||||
.reg .u32 temp1;
|
||||
.reg .u32 temp2;
|
||||
|
||||
mov.u32 temp1, input.x;
|
||||
mov.u32 temp2, input.y;
|
||||
add.u32 temp2, temp1, temp2;
|
||||
mov.u32 temp_v.x, temp2;
|
||||
mov.u32 temp_v.y, temp2;
|
||||
mov.v2.u32 output, temp_v;
|
||||
ret;
|
||||
}
|
||||
|
||||
.visible .entry vector(
|
||||
.param .u64 input_p,
|
||||
.param .u64 output_p
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .v2 .u32 temp;
|
||||
.reg .u32 temp1;
|
||||
.reg .u32 temp2;
|
||||
.reg .b64 packed;
|
||||
|
||||
ld.param.u64 in_addr, [input_p];
|
||||
ld.param.u64 out_addr, [output_p];
|
||||
|
||||
ld.v2.u32 temp, [in_addr];
|
||||
call (temp), impl, (temp);
|
||||
mov.b64 packed, temp;
|
||||
st.v2.u32 [out_addr], temp;
|
||||
ret;
|
||||
}
|
46
ptx/src/test/spirv_run/vector.spvtxt
Normal file
46
ptx/src/test/spirv_run/vector.spvtxt
Normal file
|
@ -0,0 +1,46 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%25 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "add"
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%28 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
|
||||
%ulong_1 = OpConstant %ulong 1
|
||||
%1 = OpFunction %void None %28
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%23 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %8
|
||||
OpStore %3 %9
|
||||
%11 = OpLoad %ulong %2
|
||||
%10 = OpCopyObject %ulong %11
|
||||
OpStore %4 %10
|
||||
%13 = OpLoad %ulong %3
|
||||
%12 = OpCopyObject %ulong %13
|
||||
OpStore %5 %12
|
||||
%15 = OpLoad %ulong %4
|
||||
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
|
||||
%14 = OpLoad %ulong %21
|
||||
OpStore %6 %14
|
||||
%17 = OpLoad %ulong %6
|
||||
%16 = OpIAdd %ulong %17 %ulong_1
|
||||
OpStore %7 %16
|
||||
%18 = OpLoad %ulong %5
|
||||
%19 = OpLoad %ulong %7
|
||||
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
|
||||
OpStore %22 %19
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -8,6 +8,7 @@ use rspirv::binary::Assemble;
|
|||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
enum SpirvType {
|
||||
Base(SpirvScalarKey),
|
||||
Vector(SpirvScalarKey, u8),
|
||||
Array(SpirvScalarKey, u32),
|
||||
Pointer(Box<SpirvType>, spirv::StorageClass),
|
||||
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
|
||||
|
@ -17,7 +18,7 @@ impl SpirvType {
|
|||
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
|
||||
let key = match t {
|
||||
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
||||
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
|
||||
ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
|
||||
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
|
||||
};
|
||||
SpirvType::Pointer(Box::new(key), sc)
|
||||
|
@ -28,7 +29,7 @@ impl From<ast::Type> for SpirvType {
|
|||
fn from(t: ast::Type) -> Self {
|
||||
match t {
|
||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
|
||||
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
|
||||
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
|
||||
}
|
||||
}
|
||||
|
@ -77,15 +78,8 @@ impl From<ast::ScalarType> for SpirvScalarKey {
|
|||
ast::ScalarType::F16 => SpirvScalarKey::F16,
|
||||
ast::ScalarType::F32 => SpirvScalarKey::F32,
|
||||
ast::ScalarType::F64 => SpirvScalarKey::F64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ast::ExtendedScalarType> for SpirvScalarKey {
|
||||
fn from(t: ast::ExtendedScalarType) -> Self {
|
||||
match t {
|
||||
ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred,
|
||||
ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2,
|
||||
ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
|
||||
ast::ScalarType::Pred => SpirvScalarKey::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,6 +129,13 @@ impl TypeWordMap {
|
|||
.entry(t)
|
||||
.or_insert_with(|| b.type_pointer(None, storage, base))
|
||||
}
|
||||
SpirvType::Vector(typ, len) => {
|
||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
.or_insert_with(|| b.type_vector(base, len as u32))
|
||||
}
|
||||
SpirvType::Array(typ, len) => {
|
||||
let base = self.get_or_add_spirv_scalar(b, typ);
|
||||
*self
|
||||
|
@ -232,8 +233,8 @@ fn emit_function_header<'a>(
|
|||
spirv::FunctionControl::NONE,
|
||||
func_type,
|
||||
)?;
|
||||
func_directive.visit_args(|arg| {
|
||||
let result_type = map.get_or_add_scalar(builder, arg.a_type);
|
||||
func_directive.visit_args(&mut |arg| {
|
||||
let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
|
||||
let inst = dr::Instruction::new(
|
||||
spirv::Op::FunctionParameter,
|
||||
Some(result_type),
|
||||
|
@ -285,9 +286,9 @@ fn expand_kernel_params<'a, 'b>(
|
|||
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
|
||||
args.map(|a| ast::KernelArgument {
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
|
||||
a_type: a.a_type,
|
||||
length: a.length,
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||
v_type: a.v_type,
|
||||
align: a.align,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -297,12 +298,9 @@ fn expand_fn_params<'a, 'b>(
|
|||
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
|
||||
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
|
||||
args.map(|a| ast::FnArgument {
|
||||
state_space: a.state_space,
|
||||
base: ast::KernelArgument {
|
||||
name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))),
|
||||
a_type: a.base.a_type,
|
||||
length: a.base.length,
|
||||
},
|
||||
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
|
||||
v_type: a.v_type,
|
||||
align: a.align,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -375,16 +373,12 @@ fn resolve_fn_calls(
|
|||
|
||||
fn to_resolved_fn_args<T>(
|
||||
params: Vec<T>,
|
||||
params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)],
|
||||
) -> Vec<ArgCall<T>> {
|
||||
params_decl: &[ast::FnArgumentType],
|
||||
) -> Vec<(T, ast::FnArgumentType)> {
|
||||
params
|
||||
.into_iter()
|
||||
.zip(params_decl.iter())
|
||||
.map(|(id, &(space, typ))| ArgCall {
|
||||
id,
|
||||
typ: ast::Type::Scalar(typ),
|
||||
space: space,
|
||||
})
|
||||
.map(|(id, typ)| (id, *typ))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
|
@ -476,12 +470,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
let out_param = match &mut f_args {
|
||||
ast::MethodDecl::Kernel(_, in_params) => {
|
||||
for p in in_params.iter_mut() {
|
||||
let typ = ast::Type::Scalar(p.a_type);
|
||||
let typ = ast::Type::from(p.v_type);
|
||||
let new_id = id_def.new_id(Some(typ));
|
||||
result.push(Statement::Variable(VariableDecl {
|
||||
space: ast::StateSpace::Reg,
|
||||
align: None,
|
||||
v_type: typ,
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: p.align,
|
||||
v_type: ast::VariableType::Param(p.v_type),
|
||||
name: p.name,
|
||||
}));
|
||||
result.push(Statement::StoreVar(
|
||||
|
@ -497,32 +490,31 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
}
|
||||
ast::MethodDecl::Func(out_params, _, in_params) => {
|
||||
for p in in_params.iter_mut() {
|
||||
let typ = ast::Type::Scalar(p.base.a_type);
|
||||
let typ = ast::Type::from(p.v_type);
|
||||
let new_id = id_def.new_id(Some(typ));
|
||||
result.push(Statement::Variable(VariableDecl {
|
||||
space: ast::StateSpace::Reg,
|
||||
align: None,
|
||||
v_type: typ,
|
||||
name: p.base.name,
|
||||
let var_typ = ast::VariableType::from(p.v_type);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: p.align,
|
||||
v_type: var_typ,
|
||||
name: p.name,
|
||||
}));
|
||||
result.push(Statement::StoreVar(
|
||||
ast::Arg2St {
|
||||
src1: p.base.name,
|
||||
src1: p.name,
|
||||
src2: new_id,
|
||||
},
|
||||
typ,
|
||||
));
|
||||
p.base.name = new_id;
|
||||
p.name = new_id;
|
||||
}
|
||||
match &mut **out_params {
|
||||
[p] => {
|
||||
result.push(Statement::Variable(VariableDecl {
|
||||
space: ast::StateSpace::Reg,
|
||||
align: None,
|
||||
v_type: ast::Type::Scalar(p.base.a_type),
|
||||
name: p.base.name,
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: p.align,
|
||||
v_type: ast::VariableType::from(p.v_type),
|
||||
name: p.name,
|
||||
}));
|
||||
Some(p.base.name)
|
||||
Some(p.name)
|
||||
}
|
||||
[] => None,
|
||||
_ => todo!(),
|
||||
|
@ -552,15 +544,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
|
|||
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
|
||||
},
|
||||
Statement::Conditional(mut bra) => {
|
||||
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
|
||||
ast::ExtendedScalarType::Pred,
|
||||
)));
|
||||
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
dst: generated_id,
|
||||
src: bra.predicate,
|
||||
},
|
||||
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
||||
ast::Type::Scalar(ast::ScalarType::Pred),
|
||||
));
|
||||
bra.predicate = generated_id;
|
||||
result.push(Statement::Conditional(bra));
|
||||
|
@ -642,7 +632,15 @@ fn expand_arguments<'a, 'b>(
|
|||
let new_inst = inst.map(&mut visitor);
|
||||
result.push(Statement::Instruction(new_inst));
|
||||
}
|
||||
Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)),
|
||||
Statement::Variable(ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
name,
|
||||
}) => result.push(Statement::Variable(ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
name,
|
||||
})),
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
|
||||
|
@ -745,7 +743,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
|
|||
) -> spirv::Word {
|
||||
match &desc.op {
|
||||
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
|
||||
ast::MovOperand::Vec(_, _) => todo!(),
|
||||
ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -835,13 +833,19 @@ fn get_function_type(
|
|||
match method_decl {
|
||||
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
|
||||
builder,
|
||||
out_params.iter().map(|p| SpirvType::from(p.base.a_type)),
|
||||
in_params.iter().map(|p| SpirvType::from(p.base.a_type)),
|
||||
out_params
|
||||
.iter()
|
||||
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||
in_params
|
||||
.iter()
|
||||
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||
),
|
||||
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
|
||||
builder,
|
||||
iter::empty(),
|
||||
params.iter().map(|p| SpirvType::from(p.a_type)),
|
||||
params
|
||||
.iter()
|
||||
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -870,31 +874,38 @@ fn emit_function_body_ops(
|
|||
Statement::Label(_) => (),
|
||||
Statement::Call(call) => {
|
||||
let (result_type, result_id) = match &*call.ret_params {
|
||||
[p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id),
|
||||
[(id, typ)] => (
|
||||
map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
|
||||
*id,
|
||||
),
|
||||
_ => todo!(),
|
||||
};
|
||||
let arg_list = call.param_list.iter().map(|p| p.id).collect::<Vec<_>>();
|
||||
let arg_list = call
|
||||
.param_list
|
||||
.iter()
|
||||
.map(|(id, _)| *id)
|
||||
.collect::<Vec<_>>();
|
||||
builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
|
||||
}
|
||||
Statement::Variable(VariableDecl {
|
||||
name: id,
|
||||
v_type: typ,
|
||||
space: ss,
|
||||
Statement::Variable(ast::Variable {
|
||||
align,
|
||||
v_type,
|
||||
name,
|
||||
}) => {
|
||||
let type_id = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
|
||||
SpirvType::new_pointer(ast::Type::from(*v_type), spirv::StorageClass::Function),
|
||||
);
|
||||
let st_class = match ss {
|
||||
ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function,
|
||||
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
|
||||
_ => todo!(),
|
||||
let st_class = match v_type {
|
||||
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
|
||||
spirv::StorageClass::Function
|
||||
}
|
||||
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
|
||||
};
|
||||
builder.variable(type_id, Some(*id), st_class, None);
|
||||
builder.variable(type_id, Some(*name), st_class, None);
|
||||
if let Some(align) = align {
|
||||
builder.decorate(
|
||||
*id,
|
||||
*name,
|
||||
spirv::Decoration::Alignment,
|
||||
&[dr::Operand::LiteralInt32(*align)],
|
||||
);
|
||||
|
@ -1051,7 +1062,7 @@ fn emit_cvt(
|
|||
if desc.saturate || desc.flush_to_zero {
|
||||
todo!()
|
||||
}
|
||||
let dest_t: ast::Type = desc.dst.into();
|
||||
let dest_t: ast::ScalarType = desc.dst.into();
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||
builder.f_convert(result_type, Some(arg.dst), arg.src)?;
|
||||
emit_rounding_decoration(builder, arg.dst, desc.rounding);
|
||||
|
@ -1060,7 +1071,7 @@ fn emit_cvt(
|
|||
if desc.saturate || desc.flush_to_zero {
|
||||
todo!()
|
||||
}
|
||||
let dest_t: ast::Type = desc.dst.into();
|
||||
let dest_t: ast::ScalarType = desc.dst.into();
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
|
||||
if desc.src.is_signed() {
|
||||
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
|
||||
|
@ -1367,7 +1378,7 @@ fn normalize_identifiers<'a, 'b>(
|
|||
|
||||
fn expand_map_variables<'a, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver,
|
||||
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
s: ast::Statement<ast::ParsedArgParams<'a>>,
|
||||
) {
|
||||
|
@ -1386,21 +1397,19 @@ fn expand_map_variables<'a, 'b>(
|
|||
))),
|
||||
ast::Statement::Variable(var) => match var.count {
|
||||
Some(count) => {
|
||||
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
|
||||
result.push(Statement::Variable(VariableDecl {
|
||||
space: var.space,
|
||||
align: var.align,
|
||||
v_type: var.v_type,
|
||||
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type,
|
||||
name: new_id,
|
||||
}))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let new_id = id_defs.add_def(var.name, Some(var.v_type));
|
||||
result.push(Statement::Variable(VariableDecl {
|
||||
space: var.space,
|
||||
align: var.align,
|
||||
v_type: var.v_type,
|
||||
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type,
|
||||
name: new_id,
|
||||
}));
|
||||
}
|
||||
|
@ -1408,15 +1417,38 @@ fn expand_map_variables<'a, 'b>(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)]
|
||||
enum PtxSpecialRegister {
|
||||
Tid,
|
||||
Ntid,
|
||||
Ctaid,
|
||||
Nctaid,
|
||||
Gridid,
|
||||
}
|
||||
|
||||
impl PtxSpecialRegister {
|
||||
fn try_parse(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"%tid" => Some(Self::Tid),
|
||||
"%ntid" => Some(Self::Ntid),
|
||||
"%ctaid" => Some(Self::Ctaid),
|
||||
"%nctaid" => Some(Self::Nctaid),
|
||||
"%gridid" => Some(Self::Gridid),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalStringIdResolver<'input> {
|
||||
current_id: spirv::Word,
|
||||
variables: HashMap<Cow<'input, str>, spirv::Word>,
|
||||
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
fns: HashMap<spirv::Word, FnDecl>,
|
||||
}
|
||||
|
||||
pub struct FnDecl {
|
||||
ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
|
||||
params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
|
||||
ret_vals: Vec<ast::FnArgumentType>,
|
||||
params: Vec<ast::FnArgumentType>,
|
||||
}
|
||||
|
||||
impl<'a> GlobalStringIdResolver<'a> {
|
||||
|
@ -1424,6 +1456,7 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
Self {
|
||||
current_id: start_id,
|
||||
variables: HashMap::new(),
|
||||
special_registers: HashMap::new(),
|
||||
fns: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
@ -1461,6 +1494,7 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
let mut fn_resolver = FnStringIdResolver {
|
||||
current_id: &mut self.current_id,
|
||||
global_variables: &self.variables,
|
||||
special_registers: &mut self.special_registers,
|
||||
variables: vec![HashMap::new(); 1],
|
||||
type_check: HashMap::new(),
|
||||
};
|
||||
|
@ -1474,14 +1508,8 @@ impl<'a> GlobalStringIdResolver<'a> {
|
|||
self.fns.insert(
|
||||
name_id,
|
||||
FnDecl {
|
||||
ret_vals: ret_params_ids
|
||||
.iter()
|
||||
.map(|p| (p.state_space, p.base.a_type))
|
||||
.collect(),
|
||||
params: params_ids
|
||||
.iter()
|
||||
.map(|p| (p.state_space, p.base.a_type))
|
||||
.collect(),
|
||||
ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
|
||||
params: params_ids.iter().map(|p| p.v_type).collect(),
|
||||
},
|
||||
);
|
||||
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
|
||||
|
@ -1516,7 +1544,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
|
|||
struct FnStringIdResolver<'input, 'b> {
|
||||
current_id: &'b mut spirv::Word,
|
||||
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
|
||||
//global: &'b mut GlobalStringIdResolver<'a>,
|
||||
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
|
||||
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
|
||||
type_check: HashMap<u32, ast::Type>,
|
||||
}
|
||||
|
@ -1537,14 +1565,28 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
|
|||
self.variables.pop();
|
||||
}
|
||||
|
||||
fn get_id(&self, id: &str) -> spirv::Word {
|
||||
fn get_id(&mut self, id: &str) -> spirv::Word {
|
||||
for scope in self.variables.iter().rev() {
|
||||
match scope.get(id) {
|
||||
Some(id) => return *id,
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
self.global_variables[id]
|
||||
match self.global_variables.get(id) {
|
||||
Some(id) => *id,
|
||||
None => {
|
||||
let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!());
|
||||
match self.special_registers.entry(sreg) {
|
||||
hash_map::Entry::Occupied(e) => *e.get(),
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
let numeric_id = *self.current_id;
|
||||
*self.current_id += 1;
|
||||
e.insert(numeric_id);
|
||||
numeric_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
|
||||
|
@ -1602,7 +1644,7 @@ impl<'b> NumericIdResolver<'b> {
|
|||
|
||||
enum Statement<I, P: ast::ArgParams> {
|
||||
Label(u32),
|
||||
Variable(VariableDecl),
|
||||
Variable(ast::Variable<ast::VariableType, P>),
|
||||
Instruction(I),
|
||||
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
|
||||
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
|
||||
|
@ -1614,18 +1656,11 @@ enum Statement<I, P: ast::ArgParams> {
|
|||
RetValue(ast::RetData, spirv::Word),
|
||||
}
|
||||
|
||||
struct VariableDecl {
|
||||
pub space: ast::StateSpace,
|
||||
pub align: Option<u32>,
|
||||
pub v_type: ast::Type,
|
||||
pub name: spirv::Word,
|
||||
}
|
||||
|
||||
struct ResolvedCall<P: ast::ArgParams> {
|
||||
pub uniform: bool,
|
||||
pub ret_params: Vec<ArgCall<spirv::Word>>,
|
||||
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
|
||||
pub func: spirv::Word,
|
||||
pub param_list: Vec<ArgCall<P::CallOperand>>,
|
||||
pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
|
||||
}
|
||||
|
||||
impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
||||
|
@ -1636,18 +1671,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
|||
let ret_params = self
|
||||
.ret_params
|
||||
.into_iter()
|
||||
.map(|p| {
|
||||
.map(|(id, typ)| {
|
||||
let new_id = visitor.variable(ArgumentDescriptor {
|
||||
op: p.id,
|
||||
typ: Some(p.typ),
|
||||
op: id,
|
||||
typ: Some(typ.into()),
|
||||
is_dst: true,
|
||||
is_pointer: false,
|
||||
});
|
||||
ArgCall {
|
||||
id: new_id,
|
||||
typ: p.typ,
|
||||
space: p.space,
|
||||
}
|
||||
(new_id, typ)
|
||||
})
|
||||
.collect();
|
||||
let func = visitor.variable(ArgumentDescriptor {
|
||||
|
@ -1659,18 +1690,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
|
|||
let param_list = self
|
||||
.param_list
|
||||
.into_iter()
|
||||
.map(|p| {
|
||||
.map(|(id, typ)| {
|
||||
let new_id = visitor.src_call_operand(ArgumentDescriptor {
|
||||
op: p.id,
|
||||
typ: Some(p.typ),
|
||||
op: id,
|
||||
typ: Some(typ.into()),
|
||||
is_dst: false,
|
||||
is_pointer: false,
|
||||
});
|
||||
ArgCall {
|
||||
id: new_id,
|
||||
typ: p.typ,
|
||||
space: p.space,
|
||||
}
|
||||
(new_id, typ)
|
||||
})
|
||||
.collect();
|
||||
ResolvedCall {
|
||||
|
@ -1700,12 +1727,6 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
|
|||
}
|
||||
}
|
||||
|
||||
struct ArgCall<ID> {
|
||||
id: ID,
|
||||
typ: ast::Type,
|
||||
space: ast::FnArgStateSpace,
|
||||
}
|
||||
|
||||
pub trait ArgParamsEx: ast::ArgParams {
|
||||
fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
|
||||
}
|
||||
|
@ -1817,7 +1838,9 @@ where
|
|||
) -> ast::MovOperand<spirv::Word> {
|
||||
match desc.op {
|
||||
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
|
||||
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
|
||||
ast::MovOperand::Vec(reg, x2) => {
|
||||
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1881,13 +1904,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
|
|||
}
|
||||
ast::Instruction::Cvt(d, a) => {
|
||||
let (dst_t, src_t) = match &d {
|
||||
ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()),
|
||||
ast::CvtDetails::FloatFromInt(desc) => {
|
||||
(desc.dst.into(), ast::Type::Scalar(desc.src.into()))
|
||||
}
|
||||
ast::CvtDetails::IntFromFloat(desc) => {
|
||||
(ast::Type::Scalar(desc.dst.into()), desc.src.into())
|
||||
}
|
||||
ast::CvtDetails::FloatFromFloat(desc) => (
|
||||
ast::Type::Scalar(desc.dst.into()),
|
||||
ast::Type::Scalar(desc.src.into()),
|
||||
),
|
||||
ast::CvtDetails::FloatFromInt(desc) => (
|
||||
ast::Type::Scalar(desc.dst.into()),
|
||||
ast::Type::Scalar(desc.src.into()),
|
||||
),
|
||||
ast::CvtDetails::IntFromFloat(desc) => (
|
||||
ast::Type::Scalar(desc.dst.into()),
|
||||
ast::Type::Scalar(desc.src.into()),
|
||||
),
|
||||
ast::CvtDetails::IntFromInt(desc) => (
|
||||
ast::Type::Scalar(desc.dst.into()),
|
||||
ast::Type::Scalar(desc.src.into()),
|
||||
|
@ -2261,14 +2289,14 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
|
|||
ast::Arg4 {
|
||||
dst1: visitor.variable(ArgumentDescriptor {
|
||||
op: self.dst1,
|
||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
||||
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||
is_dst: true,
|
||||
is_pointer: false,
|
||||
}),
|
||||
dst2: self.dst2.map(|dst2| {
|
||||
visitor.variable(ArgumentDescriptor {
|
||||
op: dst2,
|
||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
||||
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||
is_dst: true,
|
||||
is_pointer: false,
|
||||
})
|
||||
|
@ -2298,14 +2326,14 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||
ast::Arg5 {
|
||||
dst1: visitor.variable(ArgumentDescriptor {
|
||||
op: self.dst1,
|
||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
||||
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||
is_dst: true,
|
||||
is_pointer: false,
|
||||
}),
|
||||
dst2: self.dst2.map(|dst2| {
|
||||
visitor.variable(ArgumentDescriptor {
|
||||
op: dst2,
|
||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
||||
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||
is_dst: true,
|
||||
is_pointer: false,
|
||||
})
|
||||
|
@ -2324,7 +2352,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||
}),
|
||||
src3: visitor.operand(ArgumentDescriptor {
|
||||
op: self.src3,
|
||||
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
|
||||
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
|
||||
is_dst: false,
|
||||
is_pointer: false,
|
||||
}),
|
||||
|
@ -2332,65 +2360,6 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
impl<T: ArgParamsEx> ast::ArgCall<T> {
|
||||
fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
|
||||
self,
|
||||
visitor: &mut V,
|
||||
fn_resolve: &GlobalFnDeclResolver<'a>,
|
||||
) -> ast::ArgCall<U> {
|
||||
// TODO: error out if lengths don't match
|
||||
let fn_decl = T::get_fn_decl(&self.func, fn_resolve);
|
||||
let ret_params = self
|
||||
.ret_params
|
||||
.into_iter()
|
||||
.zip(fn_decl.ret_vals.iter().copied())
|
||||
.map(|(a, (space, typ))| {
|
||||
visitor.variable(ArgumentDescriptor {
|
||||
op: a,
|
||||
typ: Some(ast::Type::Scalar(typ)),
|
||||
is_dst: true,
|
||||
is_pointer: if space == ast::FnArgStateSpace::Reg {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let func = visitor.variable(ArgumentDescriptor {
|
||||
op: self.func,
|
||||
typ: None,
|
||||
is_dst: false,
|
||||
is_pointer: false,
|
||||
});
|
||||
let param_list = self
|
||||
.param_list
|
||||
.into_iter()
|
||||
.zip(fn_decl.params.iter().copied())
|
||||
.map(|(a, (space, typ))| {
|
||||
visitor.src_call_operand(ArgumentDescriptor {
|
||||
op: a,
|
||||
typ: Some(ast::Type::Scalar(typ)),
|
||||
is_dst: false,
|
||||
is_pointer: if space == ast::FnArgStateSpace::Reg {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
ast::ArgCall {
|
||||
uniform: false,
|
||||
ret_params,
|
||||
func: func,
|
||||
param_list: param_list,
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
impl<T> ast::CallOperand<T> {
|
||||
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
|
||||
match self {
|
||||
|
@ -2418,6 +2387,8 @@ enum ScalarKind {
|
|||
Unsigned,
|
||||
Signed,
|
||||
Float,
|
||||
Float2,
|
||||
Pred,
|
||||
}
|
||||
|
||||
impl ast::ScalarType {
|
||||
|
@ -2438,6 +2409,8 @@ impl ast::ScalarType {
|
|||
ast::ScalarType::S64 => 8,
|
||||
ast::ScalarType::B64 => 8,
|
||||
ast::ScalarType::F64 => 8,
|
||||
ast::ScalarType::F16x2 => 4,
|
||||
ast::ScalarType::Pred => 1,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2458,6 +2431,8 @@ impl ast::ScalarType {
|
|||
ast::ScalarType::F16 => ScalarKind::Float,
|
||||
ast::ScalarType::F32 => ScalarKind::Float,
|
||||
ast::ScalarType::F64 => ScalarKind::Float,
|
||||
ast::ScalarType::F16x2 => ScalarKind::Float,
|
||||
ast::ScalarType::Pred => ScalarKind::Pred,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2490,6 +2465,11 @@ impl ast::ScalarType {
|
|||
8 => ast::ScalarType::U64,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Float2 => match width {
|
||||
4 => ast::ScalarType::F16x2,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
ScalarKind::Pred => ast::ScalarType::Pred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2497,7 +2477,7 @@ impl ast::ScalarType {
|
|||
impl ast::NotType {
|
||||
fn to_type(self) -> ast::Type {
|
||||
match self {
|
||||
ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
||||
ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
|
||||
ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
|
||||
ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
|
||||
ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
|
||||
|
@ -2519,7 +2499,9 @@ impl ast::AddDetails {
|
|||
fn get_type(&self) -> ast::Type {
|
||||
match self {
|
||||
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
||||
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(),
|
||||
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
|
||||
ast::Type::Scalar((*typ).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2528,7 +2510,9 @@ impl ast::MulDetails {
|
|||
fn get_type(&self) -> ast::Type {
|
||||
match self {
|
||||
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
|
||||
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(),
|
||||
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
|
||||
ast::Type::Scalar((*typ).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2560,6 +2544,15 @@ impl ast::LdStateSpace {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ast::FnArgumentType> for ast::VariableType {
|
||||
fn from(t: ast::FnArgumentType) -> Self {
|
||||
match t {
|
||||
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
|
||||
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
|
@ -2575,6 +2568,8 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
|||
ScalarKind::Unsigned => {
|
||||
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
|
||||
}
|
||||
ScalarKind::Float2 => todo!(),
|
||||
ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
_ => false,
|
||||
|
@ -2758,6 +2753,8 @@ fn should_convert_relaxed_src(
|
|||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float2 => todo!(),
|
||||
ScalarKind::Pred => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
|
@ -2807,6 +2804,8 @@ fn should_convert_relaxed_dst(
|
|||
None
|
||||
}
|
||||
}
|
||||
ScalarKind::Float2 => todo!(),
|
||||
ScalarKind::Pred => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
|
@ -2862,16 +2861,21 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> {
|
||||
fn visit_args(&self, f: impl FnMut(&ast::KernelArgument<P>)) {
|
||||
impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
|
||||
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
|
||||
match self {
|
||||
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f),
|
||||
ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f),
|
||||
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
|
||||
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
|
||||
f(&ast::FnArgument {
|
||||
align: arg.align,
|
||||
name: arg.name,
|
||||
v_type: ast::FnArgumentType::Param(arg.v_type),
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CFGs below taken from "Modern Compiler Implementation in Java"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
Loading…
Add table
Reference in a new issue