Be more precise about types admitted in register definitions and method arguments

This commit is contained in:
Andrzej Janik 2020-09-11 00:40:13 +02:00
parent 76afbeba63
commit 1238796dfd
7 changed files with 647 additions and 351 deletions

View file

@ -12,9 +12,117 @@ quick_error! {
SyntaxError {}
NonF32Ftz {}
WrongArrayType {}
WrongVectorElement {}
MultiArrayVariable {}
}
}
macro_rules! sub_scalar_type {
($name:ident { $($variant:ident),+ $(,)? }) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $name {
$(
$variant,
)+
}
impl From<$name> for ScalarType {
fn from(t: $name) -> ScalarType {
match t {
$(
$name::$variant => ScalarType::$variant,
)+
}
}
}
};
}
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $type_name {
$(
$variant ($($field_type),+),
)+
}
impl From<$type_name> for Type {
#[allow(non_snake_case)]
fn from(t: $type_name) -> Type {
match t {
$(
$type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
)+
}
}
}
};
}
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(SizedScalarType, u8),
}
}
sub_type! {
VariableLocalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
Array(SizedScalarType, u32),
}
}
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
Array(SizedScalarType, u32),
}
}
sub_scalar_type!(SizedScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F16x2,
F32,
F64,
});
sub_scalar_type!(ParamScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F32,
F64,
});
pub trait UnwrapWithVec<E, To> {
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
}
@ -56,6 +164,9 @@ pub enum MethodDecl<'a, P: ArgParams> {
Kernel(&'a str, Vec<KernelArgument<P>>),
}
pub type FnArgument<P: ArgParams> = Variable<FnArgumentType, P>;
pub type KernelArgument<P: ArgParams> = Variable<VariableParamType, P>;
pub struct Function<'a, P: ArgParams, S> {
pub func_directive: MethodDecl<'a, P>,
pub body: Option<Vec<S>>,
@ -63,43 +174,28 @@ pub struct Function<'a, P: ArgParams, S> {
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
pub struct FnArgument<P: ArgParams> {
pub base: KernelArgument<P>,
pub state_space: FnArgStateSpace,
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum FnArgStateSpace {
Reg,
Param,
}
#[derive(Default, Copy, Clone)]
pub struct KernelArgument<P: ArgParams> {
pub name: P::ID,
pub a_type: ScalarType,
// TODO: turn length into part of type definition
pub length: u32,
impl From<FnArgumentType> for Type {
fn from(t: FnArgumentType) -> Self {
match t {
FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(),
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum Type {
Scalar(ScalarType),
ExtendedScalar(ExtendedScalarType),
Vector(ScalarType, u8),
Array(ScalarType, u32),
}
impl From<FloatType> for Type {
fn from(t: FloatType) -> Self {
match t {
FloatType::F16 => Type::Scalar(ScalarType::F16),
FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2),
FloatType::F32 => Type::Scalar(ScalarType::F32),
FloatType::F64 => Type::Scalar(ScalarType::F64),
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ScalarType {
B8,
@ -117,25 +213,11 @@ pub enum ScalarType {
F16,
F32,
F64,
F16x2,
Pred,
}
impl From<IntType> for ScalarType {
fn from(t: IntType) -> Self {
match t {
IntType::S8 => ScalarType::S8,
IntType::S16 => ScalarType::S16,
IntType::S32 => ScalarType::S32,
IntType::S64 => ScalarType::S64,
IntType::U8 => ScalarType::U8,
IntType::U16 => ScalarType::U16,
IntType::U32 => ScalarType::U32,
IntType::U64 => ScalarType::U64,
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum IntType {
sub_scalar_type!(IntType {
U8,
U16,
U32,
@ -143,8 +225,8 @@ pub enum IntType {
S8,
S16,
S32,
S64,
}
S64
});
impl IntType {
pub fn is_signed(self) -> bool {
@ -168,19 +250,12 @@ impl IntType {
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum FloatType {
sub_scalar_type!(FloatType {
F16,
F16x2,
F32,
F64,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExtendedScalarType {
F16x2,
Pred,
}
F64
});
impl Default for ScalarType {
fn default() -> Self {
@ -190,19 +265,39 @@ impl Default for ScalarType {
pub enum Statement<P: ArgParams> {
Label(P::ID),
Variable(Variable<P>),
Variable(MultiVariable<P>),
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
Block(Vec<Statement<P>>),
}
pub struct Variable<P: ArgParams> {
pub space: StateSpace,
pub align: Option<u32>,
pub v_type: Type,
pub name: P::ID,
pub struct MultiVariable<P: ArgParams> {
pub var: Variable<VariableType, P>,
pub count: Option<u32>,
}
pub struct Variable<T, P: ArgParams> {
pub align: Option<u32>,
pub v_type: T,
pub name: P::ID,
}
#[derive(Eq, PartialEq, Copy, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
}
impl From<VariableType> for Type {
fn from(t: VariableType) -> Self {
match t {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
@ -322,7 +417,7 @@ pub enum CallOperand<ID> {
pub enum MovOperand<ID> {
Op(Operand<ID>),
Vec(String, String),
Vec(ID, u8),
}
pub enum VectorPrefix {
@ -334,7 +429,7 @@ pub struct LdData {
pub qualifier: LdStQualifier,
pub state_space: LdStateSpace,
pub caching: LdCacheOperator,
pub vector: Option<VectorPrefix>,
pub vector: Option<u8>,
pub typ: ScalarType,
}
@ -376,6 +471,37 @@ pub struct MovData {
pub typ: Type,
}
sub_scalar_type!(MovScalarType {
B16,
B32,
B64,
U16,
U32,
U64,
S16,
S32,
S64,
F32,
F64,
Pred,
});
enum MovType {
Scalar(MovScalarType),
Vector(MovScalarType, u8),
Array(MovScalarType, u32),
}
impl From<MovType> for Type {
fn from(t: MovType) -> Self {
match t {
MovType::Scalar(t) => Type::Scalar(t.into()),
MovType::Vector(t, len) => Type::Vector(t.into(), len),
MovType::Array(t, len) => Type::Array(t.into(), len),
}
}
}
pub enum MulDetails {
Int(MulIntDesc),
Float(MulFloatDesc),
@ -587,7 +713,7 @@ pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
pub caching: StCacheOperator,
pub vector: Option<VectorPrefix>,
pub vector: Option<u8>,
pub typ: ScalarType,
}

View file

@ -21,6 +21,7 @@ match {
"@",
"[", "]",
"{", "}",
"<", ">",
"|",
".acquire",
".address_size",
@ -133,8 +134,6 @@ match {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
r"\.[a-zA-Z][a-zA-Z0-9_$]*" => DotID,
} else {
r"(?:[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+)<[0-9]+>" => ParametrizedID,
}
ExtendedID : &'input str = {
@ -214,7 +213,9 @@ LinkingDirective = {
MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
}
};
KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
@ -225,32 +226,25 @@ FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
"(" <args:Comma<FnInput>> ")" => args
};
FnInput: ast::FnArgument<ast::ParsedArgParams<'input>> = {
".reg" <_type:ScalarType> <name:ExtendedID> => {
ast::FnArgument {
base: ast::KernelArgument {a_type: _type, name: name, length: 1 },
state_space: ast::FnArgStateSpace::Reg,
}
},
<p:KernelInput> => {
ast::FnArgument {
base: p,
state_space: ast::FnArgStateSpace::Param,
}
KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = {
<v:ParamVariable> => {
let (align, v_type, name) = v;
ast::Variable{ align, v_type, name }
}
};
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
KernelInput: ast::KernelArgument<ast::ParsedArgParams<'input>> = {
".param" <_type:ScalarType> <name:ExtendedID> => {
ast::KernelArgument {a_type: _type, name: name, length: 1 }
FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type);
ast::Variable{ align, v_type, name }
},
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => {
let length = length.parse::<u32>();
let length = length.unwrap_with(errors);
ast::KernelArgument { a_type: a_type, name: name, length: length }
<v:ParamVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Param(v_type);
ast::Variable{ align, v_type, name }
}
};
}
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
"{" <s:Statement*> "}" => { Some(without_none(s)) },
@ -267,22 +261,13 @@ StateSpaceSpecifier: ast::StateSpace = {
".param" => ast::StateSpace::Param, // used to prepare function call
};
Type: ast::Type = {
<t:ScalarType> => ast::Type::Scalar(t),
<t:ExtendedScalarType> => ast::Type::ExtendedScalar(t),
};
ScalarType: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".pred" => ast::ScalarType::Pred,
MemoryType
};
ExtendedScalarType: ast::ExtendedScalarType = {
".f16x2" => ast::ExtendedScalarType::F16x2,
".pred" => ast::ExtendedScalarType::Pred,
};
MemoryType: ast::ScalarType = {
".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16,
@ -303,7 +288,7 @@ MemoryType: ast::ScalarType = {
Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
};
@ -328,21 +313,109 @@ Align: u32 = {
}
};
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
<s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => {
let (name, count) = v;
let t = match (t, arr) {
(ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
(t, Some(_)) => {
errors.push(ast::PtxError::WrongArrayType);
t
},
(t, None) => t,
};
ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = {
<var:Variable> <count:VariableParam?> => ast::MultiVariable{<>}
}
VariableParam: u32 = {
"<" <n:Num> ">" => {
let size = n.parse::<u32>();
size.unwrap_with(errors)
}
}
Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
ast::Variable {align, v_type, name}
},
LocalVariable,
<v:ParamVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type);
ast::Variable {align, v_type, name}
},
};
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
let v_type = ast::VariableRegType::Scalar(t);
(align, v_type, name)
},
".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableRegType::Vector(t, v_len);
(align, v_type, name)
}
}
LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
ast::Variable {align, v_type, name}
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
let v_type = ast::VariableParamType::Scalar(t);
(align, v_type, name)
},
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableParamType::Array(t, arr);
(align, v_type, name)
}
}
#[inline]
SizedScalarType: ast::SizedScalarType = {
".b8" => ast::SizedScalarType::B8,
".b16" => ast::SizedScalarType::B16,
".b32" => ast::SizedScalarType::B32,
".b64" => ast::SizedScalarType::B64,
".u8" => ast::SizedScalarType::U8,
".u16" => ast::SizedScalarType::U16,
".u32" => ast::SizedScalarType::U32,
".u64" => ast::SizedScalarType::U64,
".s8" => ast::SizedScalarType::S8,
".s16" => ast::SizedScalarType::S16,
".s32" => ast::SizedScalarType::S32,
".s64" => ast::SizedScalarType::S64,
".f16" => ast::SizedScalarType::F16,
".f16x2" => ast::SizedScalarType::F16x2,
".f32" => ast::SizedScalarType::F32,
".f64" => ast::SizedScalarType::F64,
}
#[inline]
ParamScalarType: ast::ParamScalarType = {
".b8" => ast::ParamScalarType::B8,
".b16" => ast::ParamScalarType::B16,
".b32" => ast::ParamScalarType::B32,
".b64" => ast::ParamScalarType::B64,
".u8" => ast::ParamScalarType::U8,
".u16" => ast::ParamScalarType::U16,
".u32" => ast::ParamScalarType::U32,
".u64" => ast::ParamScalarType::U64,
".s8" => ast::ParamScalarType::S8,
".s16" => ast::ParamScalarType::S16,
".s32" => ast::ParamScalarType::S32,
".s64" => ast::ParamScalarType::S64,
".f16" => ast::ParamScalarType::F16,
".f32" => ast::ParamScalarType::F32,
".f64" => ast::ParamScalarType::F64,
}
ArraySpecifier: u32 = {
"[" <n:Num> "]" => {
let size = n.parse::<u32>();
@ -350,20 +423,6 @@ ArraySpecifier: u32 = {
}
};
VariableName: (&'input str, Option<u32>) = {
<id:ExtendedID> => (id, None),
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
<id:ParametrizedID> => {
let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap();
let count = id[left_angle+1..id.len()-1].parse::<u32>();
let count = match count {
Ok(c) => Some(c),
Err(e) => { errors.push(e.into()); None },
};
(&id[0..left_angle], count)
}
};
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
InstMov,
@ -445,7 +504,7 @@ MovType: ast::Type = {
".s64" => ast::Type::Scalar(ast::ScalarType::S64),
".f32" => ast::Type::Scalar(ast::ScalarType::F32),
".f64" => ast::Type::Scalar(ast::ScalarType::F64),
".pred" => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)
".pred" => ast::Type::Scalar(ast::ScalarType::Pred)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@ -934,7 +993,17 @@ MovOperand: ast::MovOperand<&'input str> = {
<o:Operand> => ast::MovOperand::Op(o),
<o:VectorOperand> => {
let (pref, suf) = o;
ast::MovOperand::Vec(pref.to_string(), suf.to_string())
let suf_idx = match suf {
"x" | "r" => 0,
"y" | "g" => 1,
"z" | "b" => 2,
"w" | "a" => 3,
_ => {
errors.push(ast::PtxError::WrongVectorElement);
0
}
};
ast::MovOperand::Vec(pref, suf_idx)
}
};
@ -980,9 +1049,9 @@ OptionalDst: &'input str = {
"|" <dst2:ExtendedID> => dst2
}
VectorPrefix: ast::VectorPrefix = {
".v2" => ast::VectorPrefix::V2,
".v4" => ast::VectorPrefix::V4
VectorPrefix: u8 = {
".v2" => 2,
".v4" => 4
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#debugging-directives-file

View file

@ -8,16 +8,16 @@ fn parse_and_assert(s: &str) {
assert!(errors.len() == 0);
}
#[test]
fn empty() {
parse_and_assert(".version 6.5 .target sm_30, debug");
fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
crate::to_spirv(ast)?;
Ok(())
}
#[test]
#[allow(non_snake_case)]
fn vectorAdd_kernel64_ptx() {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add);
fn empty() {
parse_and_assert(".version 6.5 .target sm_30, debug");
}
#[test]
@ -28,8 +28,14 @@ fn operands_ptx() {
#[test]
#[allow(non_snake_case)]
fn _Z9vectorAddPKfS0_Pfi_ptx() {
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
parse_and_assert(vector_add);
fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
compile_and_assert(vector_add)
}
#[test]
#[allow(non_snake_case)]
fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> {
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
compile_and_assert(vector_add)
}

View file

@ -54,6 +54,7 @@ test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
struct DisplayError<T: Debug> {
err: T,

View file

@ -0,0 +1,44 @@
// Excersise as many features of vector types as possible
.version 6.5
.target sm_53
.address_size 64
.func (.reg .v2 .u32 output) impl(
.reg .v2 .u32 input
)
{
.reg .v2 .u32 temp_v;
.reg .u32 temp1;
.reg .u32 temp2;
mov.u32 temp1, input.x;
mov.u32 temp2, input.y;
add.u32 temp2, temp1, temp2;
mov.u32 temp_v.x, temp2;
mov.u32 temp_v.y, temp2;
mov.v2.u32 output, temp_v;
ret;
}
.visible .entry vector(
.param .u64 input_p,
.param .u64 output_p
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .v2 .u32 temp;
.reg .u32 temp1;
.reg .u32 temp2;
.reg .b64 packed;
ld.param.u64 in_addr, [input_p];
ld.param.u64 out_addr, [output_p];
ld.v2.u32 temp, [in_addr];
call (temp), impl, (temp);
mov.b64 packed, temp;
st.v2.u32 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,46 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%1 = OpFunction %void None %28
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%23 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %8
OpStore %3 %9
%11 = OpLoad %ulong %2
%10 = OpCopyObject %ulong %11
OpStore %4 %10
%13 = OpLoad %ulong %3
%12 = OpCopyObject %ulong %13
OpStore %5 %12
%15 = OpLoad %ulong %4
%21 = OpConvertUToPtr %_ptr_Generic_ulong %15
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
%16 = OpIAdd %ulong %17 %ulong_1
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
%22 = OpConvertUToPtr %_ptr_Generic_ulong %18
OpStore %22 %19
OpReturn
OpFunctionEnd

View file

@ -8,6 +8,7 @@ use rspirv::binary::Assemble;
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
Vector(SpirvScalarKey, u8),
Array(SpirvScalarKey, u32),
Pointer(Box<SpirvType>, spirv::StorageClass),
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
@ -17,7 +18,7 @@ impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
let key = match t {
ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
};
SpirvType::Pointer(Box::new(key), sc)
@ -28,7 +29,7 @@ impl From<ast::Type> for SpirvType {
fn from(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
}
}
@ -77,15 +78,8 @@ impl From<ast::ScalarType> for SpirvScalarKey {
ast::ScalarType::F16 => SpirvScalarKey::F16,
ast::ScalarType::F32 => SpirvScalarKey::F32,
ast::ScalarType::F64 => SpirvScalarKey::F64,
}
}
}
impl From<ast::ExtendedScalarType> for SpirvScalarKey {
fn from(t: ast::ExtendedScalarType) -> Self {
match t {
ast::ExtendedScalarType::Pred => SpirvScalarKey::Pred,
ast::ExtendedScalarType::F16x2 => SpirvScalarKey::F16x2,
ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
ast::ScalarType::Pred => SpirvScalarKey::Pred,
}
}
}
@ -135,6 +129,13 @@ impl TypeWordMap {
.entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base))
}
SpirvType::Vector(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
.complex
.entry(t)
.or_insert_with(|| b.type_vector(base, len as u32))
}
SpirvType::Array(typ, len) => {
let base = self.get_or_add_spirv_scalar(b, typ);
*self
@ -232,8 +233,8 @@ fn emit_function_header<'a>(
spirv::FunctionControl::NONE,
func_type,
)?;
func_directive.visit_args(|arg| {
let result_type = map.get_or_add_scalar(builder, arg.a_type);
func_directive.visit_args(&mut |arg| {
let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
@ -285,9 +286,9 @@ fn expand_kernel_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
args.map(|a| ast::KernelArgument {
name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
a_type: a.a_type,
length: a.length,
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
v_type: a.v_type,
align: a.align,
})
.collect()
}
@ -297,12 +298,9 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
args.map(|a| ast::FnArgument {
state_space: a.state_space,
base: ast::KernelArgument {
name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))),
a_type: a.base.a_type,
length: a.base.length,
},
name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
v_type: a.v_type,
align: a.align,
})
.collect()
}
@ -375,16 +373,12 @@ fn resolve_fn_calls(
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)],
) -> Vec<ArgCall<T>> {
params_decl: &[ast::FnArgumentType],
) -> Vec<(T, ast::FnArgumentType)> {
params
.into_iter()
.zip(params_decl.iter())
.map(|(id, &(space, typ))| ArgCall {
id,
typ: ast::Type::Scalar(typ),
space: space,
})
.map(|(id, typ)| (id, *typ))
.collect::<Vec<_>>()
}
@ -476,12 +470,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
let typ = ast::Type::Scalar(p.a_type);
let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ));
result.push(Statement::Variable(VariableDecl {
space: ast::StateSpace::Reg,
align: None,
v_type: typ,
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: ast::VariableType::Param(p.v_type),
name: p.name,
}));
result.push(Statement::StoreVar(
@ -497,32 +490,31 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
let typ = ast::Type::Scalar(p.base.a_type);
let typ = ast::Type::from(p.v_type);
let new_id = id_def.new_id(Some(typ));
result.push(Statement::Variable(VariableDecl {
space: ast::StateSpace::Reg,
align: None,
v_type: typ,
name: p.base.name,
let var_typ = ast::VariableType::from(p.v_type);
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: var_typ,
name: p.name,
}));
result.push(Statement::StoreVar(
ast::Arg2St {
src1: p.base.name,
src1: p.name,
src2: new_id,
},
typ,
));
p.base.name = new_id;
p.name = new_id;
}
match &mut **out_params {
[p] => {
result.push(Statement::Variable(VariableDecl {
space: ast::StateSpace::Reg,
align: None,
v_type: ast::Type::Scalar(p.base.a_type),
name: p.base.name,
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: ast::VariableType::from(p.v_type),
name: p.name,
}));
Some(p.base.name)
Some(p.name)
}
[] => None,
_ => todo!(),
@ -552,15 +544,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
},
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
ast::ExtendedScalarType::Pred,
)));
let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: bra.predicate,
},
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
ast::Type::Scalar(ast::ScalarType::Pred),
));
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
@ -642,7 +632,15 @@ fn expand_arguments<'a, 'b>(
let new_inst = inst.map(&mut visitor);
result.push(Statement::Instruction(new_inst));
}
Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)),
Statement::Variable(ast::Variable {
align,
v_type,
name,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
name,
})),
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@ -745,7 +743,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
) -> spirv::Word {
match &desc.op {
ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
ast::MovOperand::Vec(_, _) => todo!(),
ast::MovOperand::Vec(opr, _) => self.variable(desc.new_op(*opr)),
}
}
}
@ -835,13 +833,19 @@ fn get_function_type(
match method_decl {
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
builder,
out_params.iter().map(|p| SpirvType::from(p.base.a_type)),
in_params.iter().map(|p| SpirvType::from(p.base.a_type)),
out_params
.iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
in_params
.iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
),
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
builder,
iter::empty(),
params.iter().map(|p| SpirvType::from(p.a_type)),
params
.iter()
.map(|p| SpirvType::from(ast::Type::from(p.v_type))),
),
}
}
@ -870,31 +874,38 @@ fn emit_function_body_ops(
Statement::Label(_) => (),
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
[p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id),
[(id, typ)] => (
map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
*id,
),
_ => todo!(),
};
let arg_list = call.param_list.iter().map(|p| p.id).collect::<Vec<_>>();
let arg_list = call
.param_list
.iter()
.map(|(id, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
}
Statement::Variable(VariableDecl {
name: id,
v_type: typ,
space: ss,
Statement::Variable(ast::Variable {
align,
v_type,
name,
}) => {
let type_id = map.get_or_add(
builder,
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
SpirvType::new_pointer(ast::Type::from(*v_type), spirv::StorageClass::Function),
);
let st_class = match ss {
ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function,
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
_ => todo!(),
let st_class = match v_type {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
spirv::StorageClass::Function
}
ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
};
builder.variable(type_id, Some(*id), st_class, None);
builder.variable(type_id, Some(*name), st_class, None);
if let Some(align) = align {
builder.decorate(
*id,
*name,
spirv::Decoration::Alignment,
&[dr::Operand::LiteralInt32(*align)],
);
@ -1051,7 +1062,7 @@ fn emit_cvt(
if desc.saturate || desc.flush_to_zero {
todo!()
}
let dest_t: ast::Type = desc.dst.into();
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
builder.f_convert(result_type, Some(arg.dst), arg.src)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
@ -1060,7 +1071,7 @@ fn emit_cvt(
if desc.saturate || desc.flush_to_zero {
todo!()
}
let dest_t: ast::Type = desc.dst.into();
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.src.is_signed() {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
@ -1367,7 +1378,7 @@ fn normalize_identifiers<'a, 'b>(
fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
@ -1386,21 +1397,19 @@ fn expand_map_variables<'a, 'b>(
))),
ast::Statement::Variable(var) => match var.count {
Some(count) => {
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
result.push(Statement::Variable(VariableDecl {
space: var.space,
align: var.align,
v_type: var.v_type,
for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
}))
}
}
None => {
let new_id = id_defs.add_def(var.name, Some(var.v_type));
result.push(Statement::Variable(VariableDecl {
space: var.space,
align: var.align,
v_type: var.v_type,
let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
}));
}
@ -1408,15 +1417,38 @@ fn expand_map_variables<'a, 'b>(
}
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)]
enum PtxSpecialRegister {
Tid,
Ntid,
Ctaid,
Nctaid,
Gridid,
}
impl PtxSpecialRegister {
fn try_parse(s: &str) -> Option<Self> {
match s {
"%tid" => Some(Self::Tid),
"%ntid" => Some(Self::Ntid),
"%ctaid" => Some(Self::Ctaid),
"%nctaid" => Some(Self::Nctaid),
"%gridid" => Some(Self::Gridid),
_ => None,
}
}
}
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
pub struct FnDecl {
ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
ret_vals: Vec<ast::FnArgumentType>,
params: Vec<ast::FnArgumentType>,
}
impl<'a> GlobalStringIdResolver<'a> {
@ -1424,6 +1456,7 @@ impl<'a> GlobalStringIdResolver<'a> {
Self {
current_id: start_id,
variables: HashMap::new(),
special_registers: HashMap::new(),
fns: HashMap::new(),
}
}
@ -1461,6 +1494,7 @@ impl<'a> GlobalStringIdResolver<'a> {
let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
global_variables: &self.variables,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
@ -1474,14 +1508,8 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert(
name_id,
FnDecl {
ret_vals: ret_params_ids
.iter()
.map(|p| (p.state_space, p.base.a_type))
.collect(),
params: params_ids
.iter()
.map(|p| (p.state_space, p.base.a_type))
.collect(),
ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
params: params_ids.iter().map(|p| p.v_type).collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@ -1516,7 +1544,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
//global: &'b mut GlobalStringIdResolver<'a>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, ast::Type>,
}
@ -1537,14 +1565,28 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
self.variables.pop();
}
fn get_id(&self, id: &str) -> spirv::Word {
fn get_id(&mut self, id: &str) -> spirv::Word {
for scope in self.variables.iter().rev() {
match scope.get(id) {
Some(id) => return *id,
None => continue,
}
}
self.global_variables[id]
match self.global_variables.get(id) {
Some(id) => *id,
None => {
let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!());
match self.special_registers.entry(sreg) {
hash_map::Entry::Occupied(e) => *e.get(),
hash_map::Entry::Vacant(e) => {
let numeric_id = *self.current_id;
*self.current_id += 1;
e.insert(numeric_id);
numeric_id
}
}
}
}
}
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
@ -1602,7 +1644,7 @@ impl<'b> NumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(VariableDecl),
Variable(ast::Variable<ast::VariableType, P>),
Instruction(I),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
@ -1614,18 +1656,11 @@ enum Statement<I, P: ast::ArgParams> {
RetValue(ast::RetData, spirv::Word),
}
struct VariableDecl {
pub space: ast::StateSpace,
pub align: Option<u32>,
pub v_type: ast::Type,
pub name: spirv::Word,
}
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
pub ret_params: Vec<ArgCall<spirv::Word>>,
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
pub func: spirv::Word,
pub param_list: Vec<ArgCall<P::CallOperand>>,
pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
}
impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
@ -1636,18 +1671,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
let ret_params = self
.ret_params
.into_iter()
.map(|p| {
.map(|(id, typ)| {
let new_id = visitor.variable(ArgumentDescriptor {
op: p.id,
typ: Some(p.typ),
op: id,
typ: Some(typ.into()),
is_dst: true,
is_pointer: false,
});
ArgCall {
id: new_id,
typ: p.typ,
space: p.space,
}
(new_id, typ)
})
.collect();
let func = visitor.variable(ArgumentDescriptor {
@ -1659,18 +1690,14 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
let param_list = self
.param_list
.into_iter()
.map(|p| {
.map(|(id, typ)| {
let new_id = visitor.src_call_operand(ArgumentDescriptor {
op: p.id,
typ: Some(p.typ),
op: id,
typ: Some(typ.into()),
is_dst: false,
is_pointer: false,
});
ArgCall {
id: new_id,
typ: p.typ,
space: p.space,
}
(new_id, typ)
})
.collect();
ResolvedCall {
@ -1700,12 +1727,6 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
}
}
struct ArgCall<ID> {
id: ID,
typ: ast::Type,
space: ast::FnArgStateSpace,
}
pub trait ArgParamsEx: ast::ArgParams {
fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
}
@ -1817,7 +1838,9 @@ where
) -> ast::MovOperand<spirv::Word> {
match desc.op {
ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
ast::MovOperand::Vec(reg, x2) => {
ast::MovOperand::Vec(self.variable(desc.new_op(reg)), x2)
}
}
}
}
@ -1881,13 +1904,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (desc.dst.into(), desc.src.into()),
ast::CvtDetails::FloatFromInt(desc) => {
(desc.dst.into(), ast::Type::Scalar(desc.src.into()))
}
ast::CvtDetails::IntFromFloat(desc) => {
(ast::Type::Scalar(desc.dst.into()), desc.src.into())
}
ast::CvtDetails::FloatFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::FloatFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromFloat(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
),
ast::CvtDetails::IntFromInt(desc) => (
ast::Type::Scalar(desc.dst.into()),
ast::Type::Scalar(desc.src.into()),
@ -2261,14 +2289,14 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ast::Arg4 {
dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
visitor.variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
@ -2298,14 +2326,14 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ast::Arg5 {
dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
visitor.variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
@ -2324,7 +2352,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}),
src3: visitor.operand(ArgumentDescriptor {
op: self.src3,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
is_dst: false,
is_pointer: false,
}),
@ -2332,65 +2360,6 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}
}
/*
impl<T: ArgParamsEx> ast::ArgCall<T> {
fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
fn_resolve: &GlobalFnDeclResolver<'a>,
) -> ast::ArgCall<U> {
// TODO: error out if lengths don't match
let fn_decl = T::get_fn_decl(&self.func, fn_resolve);
let ret_params = self
.ret_params
.into_iter()
.zip(fn_decl.ret_vals.iter().copied())
.map(|(a, (space, typ))| {
visitor.variable(ArgumentDescriptor {
op: a,
typ: Some(ast::Type::Scalar(typ)),
is_dst: true,
is_pointer: if space == ast::FnArgStateSpace::Reg {
false
} else {
true
},
})
})
.collect();
let func = visitor.variable(ArgumentDescriptor {
op: self.func,
typ: None,
is_dst: false,
is_pointer: false,
});
let param_list = self
.param_list
.into_iter()
.zip(fn_decl.params.iter().copied())
.map(|(a, (space, typ))| {
visitor.src_call_operand(ArgumentDescriptor {
op: a,
typ: Some(ast::Type::Scalar(typ)),
is_dst: false,
is_pointer: if space == ast::FnArgStateSpace::Reg {
false
} else {
true
},
})
})
.collect();
ast::ArgCall {
uniform: false,
ret_params,
func: func,
param_list: param_list,
}
}
}
*/
impl<T> ast::CallOperand<T> {
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
match self {
@ -2418,6 +2387,8 @@ enum ScalarKind {
Unsigned,
Signed,
Float,
Float2,
Pred,
}
impl ast::ScalarType {
@ -2438,6 +2409,8 @@ impl ast::ScalarType {
ast::ScalarType::S64 => 8,
ast::ScalarType::B64 => 8,
ast::ScalarType::F64 => 8,
ast::ScalarType::F16x2 => 4,
ast::ScalarType::Pred => 1,
}
}
@ -2458,6 +2431,8 @@ impl ast::ScalarType {
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}
@ -2490,6 +2465,11 @@ impl ast::ScalarType {
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
@ -2497,7 +2477,7 @@ impl ast::ScalarType {
impl ast::NotType {
fn to_type(self) -> ast::Type {
match self {
ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
@ -2519,7 +2499,9 @@ impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
}
}
}
@ -2528,7 +2510,9 @@ impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
ast::Type::Scalar((*typ).into())
}
}
}
}
@ -2560,6 +2544,15 @@ impl ast::LdStateSpace {
}
}
impl From<ast::FnArgumentType> for ast::VariableType {
fn from(t: ast::FnArgumentType) -> Self {
match t {
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@ -2575,6 +2568,8 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
ScalarKind::Unsigned => {
operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
}
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => false,
}
}
_ => false,
@ -2758,6 +2753,8 @@ fn should_convert_relaxed_src(
None
}
}
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
_ => None,
}
@ -2807,6 +2804,8 @@ fn should_convert_relaxed_dst(
None
}
}
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
_ => None,
}
@ -2862,16 +2861,21 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
}
}
impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> {
fn visit_args(&self, f: impl FnMut(&ast::KernelArgument<P>)) {
impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
match self {
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f),
ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f),
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
f(&ast::FnArgument {
align: arg.align,
name: arg.name,
v_type: ast::FnArgumentType::Param(arg.v_type),
})
}),
}
}
}
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {
use super::*;