Remove all remaining subenums

This commit is contained in:
Andrzej Janik 2021-04-15 19:21:52 +02:00
parent a0baad9456
commit 4d04fe251d
3 changed files with 92 additions and 187 deletions

View file

@ -34,43 +34,6 @@ pub enum PtxError {
NonExternPointer,
}
macro_rules! sub_enum {
($name:ident { $($variant:ident),+ $(,)? }) => {
sub_enum!{ $name : ScalarType { $($variant),+ } }
};
($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $name {
$(
$variant,
)+
}
impl From<$name> for $base_type {
fn from(t: $name) -> $base_type {
match t {
$(
$name::$variant => $base_type::$variant,
)+
}
}
}
impl std::convert::TryFrom<$base_type> for $name {
type Error = ();
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t {
$(
$base_type::$variant => Ok($name::$variant),
)+
_ => Err(()),
}
}
}
};
}
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
sub_type! { $type_name : Type {
@ -118,12 +81,12 @@ macro_rules! sub_type {
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(SizedScalarType, u8),
Vector(ScalarType, u8),
// Array type is used when emiting SSA statements at the start of a method
Array(ScalarType, VecU32),
// Pointer variant is used when passing around SLM pointer between
// function calls for dynamic SLM
Pointer(SizedScalarType, PointerStateSpace)
Pointer(ScalarType, LdStateSpace)
}
}
@ -131,9 +94,9 @@ type VecU32 = Vec<u32>;
sub_type! {
VariableLocalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
Array(SizedScalarType, VecU32),
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
}
}
@ -152,10 +115,10 @@ impl TryFrom<VariableGlobalType> for VariableLocalType {
sub_type! {
VariableGlobalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
Array(SizedScalarType, VecU32),
Pointer(SizedScalarType, PointerStateSpace),
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
@ -167,49 +130,12 @@ sub_type! {
// .param .b32 foobar[]
sub_type! {
VariableParamType {
Scalar(LdStScalarType),
Array(SizedScalarType, VecU32),
Pointer(SizedScalarType, PointerStateSpace),
Scalar(ScalarType),
Array(ScalarType, VecU32),
Pointer(ScalarType, LdStateSpace),
}
}
sub_enum!(SizedScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F16x2,
F32,
F64,
});
sub_enum!(LdStScalarType {
B8,
B16,
B32,
B64,
U8,
U16,
U32,
U64,
S8,
S16,
S32,
S64,
F16,
F32,
F64,
});
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails {
SyncAligned,
@ -345,16 +271,6 @@ impl FnArgumentType {
}
}
sub_enum!(
PointerStateSpace : LdStateSpace {
Generic,
Global,
Const,
Shared,
Param,
}
);
#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
@ -371,18 +287,18 @@ pub enum PointerType {
Pointer(ScalarType, LdStateSpace),
}
impl From<SizedScalarType> for PointerType {
fn from(t: SizedScalarType) -> Self {
impl From<ScalarType> for PointerType {
fn from(t: ScalarType) -> Self {
PointerType::Scalar(t.into())
}
}
impl TryFrom<PointerType> for SizedScalarType {
impl TryFrom<PointerType> for ScalarType {
type Error = ();
fn try_from(value: PointerType) -> Result<Self, Self::Error> {
match value {
PointerType::Scalar(t) => Ok(t.try_into()?),
PointerType::Scalar(t) => Ok(t),
PointerType::Vector(_, _) => Err(()),
PointerType::Array(_, _) => Err(()),
PointerType::Pointer(_, _) => Err(()),
@ -685,8 +601,8 @@ pub struct LdDetails {
sub_type! {
LdStType {
Scalar(LdStScalarType),
Vector(LdStScalarType, u8),
Scalar(ScalarType),
Vector(ScalarType, u8),
// Used in generated code
Pointer(PointerType, LdStateSpace),
}
@ -1135,7 +1051,7 @@ pub struct NegDetails {
}
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;
let sizeof_t = ScalarType::from(typ).size_of() as usize;
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
@ -1166,7 +1082,7 @@ impl<'a> NumsOrArrays<'a> {
fn parse_and_copy(
&self,
t: SizedScalarType,
t: ScalarType,
size_of_t: usize,
dimensions: &[u32],
result: &mut [u8],
@ -1206,47 +1122,48 @@ impl<'a> NumsOrArrays<'a> {
}
fn parse_and_copy_single(
t: SizedScalarType,
t: ScalarType,
idx: usize,
str_val: &str,
radix: u32,
output: &mut [u8],
) -> Result<(), PtxError> {
match t {
SizedScalarType::B8 | SizedScalarType::U8 => {
ScalarType::B8 | ScalarType::U8 => {
Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
}
SizedScalarType::B16 | SizedScalarType::U16 => {
ScalarType::B16 | ScalarType::U16 => {
Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
}
SizedScalarType::B32 | SizedScalarType::U32 => {
ScalarType::B32 | ScalarType::U32 => {
Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
}
SizedScalarType::B64 | SizedScalarType::U64 => {
ScalarType::B64 | ScalarType::U64 => {
Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
}
SizedScalarType::S8 => {
ScalarType::S8 => {
Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
}
SizedScalarType::S16 => {
ScalarType::S16 => {
Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
}
SizedScalarType::S32 => {
ScalarType::S32 => {
Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
}
SizedScalarType::S64 => {
ScalarType::S64 => {
Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
}
SizedScalarType::F16 => {
ScalarType::F16 => {
Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
}
SizedScalarType::F16x2 => todo!(),
SizedScalarType::F32 => {
ScalarType::F16x2 => todo!(),
ScalarType::F32 => {
Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
}
SizedScalarType::F64 => {
ScalarType::F64 => {
Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
}
ScalarType::Pred => todo!()
}
Ok(())
}
@ -1334,13 +1251,13 @@ mod tests {
#[test]
fn array_fails_multiple_0_dmiensions() {
let inp = NumsOrArrays::Nums(Vec::new());
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err());
assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err());
}
#[test]
fn array_fails_on_empty() {
let inp = NumsOrArrays::Nums(Vec::new());
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err());
assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err());
}
#[test]
@ -1352,7 +1269,7 @@ mod tests {
let mut dimensions = vec![0u32, 2];
assert_eq!(
vec![1u8, 2, 3, 4],
inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
inp.to_vec(ScalarType::B8, &mut dimensions).unwrap()
);
assert_eq!(dimensions, vec![2u32, 2]);
}
@ -1364,7 +1281,7 @@ mod tests {
NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
}
#[test]
@ -1374,6 +1291,6 @@ mod tests {
NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
}
}

View file

@ -611,9 +611,9 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new())
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new())
} else {
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new())
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new())
}
}
};
@ -635,7 +635,7 @@ ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
(ast::VariableParamType::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
(ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new())
(ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new())
}
};
(align, array_init, v_type, name)
@ -667,42 +667,42 @@ GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input
}
#[inline]
SizedScalarType: ast::SizedScalarType = {
".b8" => ast::SizedScalarType::B8,
".b16" => ast::SizedScalarType::B16,
".b32" => ast::SizedScalarType::B32,
".b64" => ast::SizedScalarType::B64,
".u8" => ast::SizedScalarType::U8,
".u16" => ast::SizedScalarType::U16,
".u32" => ast::SizedScalarType::U32,
".u64" => ast::SizedScalarType::U64,
".s8" => ast::SizedScalarType::S8,
".s16" => ast::SizedScalarType::S16,
".s32" => ast::SizedScalarType::S32,
".s64" => ast::SizedScalarType::S64,
".f16" => ast::SizedScalarType::F16,
".f16x2" => ast::SizedScalarType::F16x2,
".f32" => ast::SizedScalarType::F32,
".f64" => ast::SizedScalarType::F64,
SizedScalarType: ast::ScalarType = {
".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
".u8" => ast::ScalarType::U8,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s8" => ast::ScalarType::S8,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
}
#[inline]
LdStScalarType: ast::LdStScalarType = {
".b8" => ast::LdStScalarType::B8,
".b16" => ast::LdStScalarType::B16,
".b32" => ast::LdStScalarType::B32,
".b64" => ast::LdStScalarType::B64,
".u8" => ast::LdStScalarType::U8,
".u16" => ast::LdStScalarType::U16,
".u32" => ast::LdStScalarType::U32,
".u64" => ast::LdStScalarType::U64,
".s8" => ast::LdStScalarType::S8,
".s16" => ast::LdStScalarType::S16,
".s32" => ast::LdStScalarType::S32,
".s64" => ast::LdStScalarType::S64,
".f16" => ast::LdStScalarType::F16,
".f32" => ast::LdStScalarType::F32,
".f64" => ast::LdStScalarType::F64,
LdStScalarType: ast::ScalarType = {
".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
".u8" => ast::ScalarType::U8,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s8" => ast::ScalarType::S8,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
".f16" => ast::ScalarType::F16,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
}
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {

View file

@ -97,18 +97,6 @@ impl ast::Type {
}
}
impl Into<spirv::StorageClass> for ast::PointerStateSpace {
fn into(self) -> spirv::StorageClass {
match self {
ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
ast::PointerStateSpace::Param => spirv::StorageClass::Function,
ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
}
}
}
impl From<ast::ScalarType> for SpirvType {
fn from(t: ast::ScalarType) -> Self {
SpirvType::Base(t.into())
@ -824,8 +812,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
name: shared_var_id,
array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
ast::SizedScalarType::B8,
ast::PointerStateSpace::Shared,
ast::ScalarType::B8,
ast::LdStateSpace::Shared,
)),
});
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
@ -863,7 +851,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>(
result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word,
shared_var_id: spirv::Word,
@ -884,7 +872,7 @@ fn replace_uses_of_shared_memory<'a>(
statement => {
let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
if *typ == ast::SizedScalarType::B8 {
if *typ == ast::ScalarType::B8 {
return shared_var_id;
}
let replacement_id = new_id();
@ -1505,7 +1493,7 @@ fn extract_globals<'input, 'b>(
d,
a,
"inc",
ast::SizedScalarType::U32,
ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@ -1527,7 +1515,7 @@ fn extract_globals<'input, 'b>(
d,
a,
"dec",
ast::SizedScalarType::U32,
ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@ -1553,8 +1541,8 @@ fn extract_globals<'input, 'b>(
space,
};
let (op, typ) = match typ {
ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32),
ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64),
ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
_ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
@ -1734,7 +1722,7 @@ fn to_ptx_impl_atomic_call(
details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
typ: ast::SizedScalarType,
typ: ast::ScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
@ -1745,9 +1733,9 @@ fn to_ptx_impl_atomic_call(
);
// TODO: extract to a function
let ptr_space = match details.space {
ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
ast::AtomSpace::Global => ast::PointerStateSpace::Global,
ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
ast::AtomSpace::Global => ast::LdStateSpace::Global,
ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
};
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
@ -4565,7 +4553,7 @@ fn convert_to_stateful_memory_access<'a>(
Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
state_space: ast::LdStateSpace::Param,
typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
typ: ast::LdStType::Scalar(ast::ScalarType::U64),
..
},
arg,
@ -4573,7 +4561,7 @@ fn convert_to_stateful_memory_access<'a>(
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
state_space: ast::LdStateSpace::Param,
typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
typ: ast::LdStType::Scalar(ast::ScalarType::S64),
..
},
arg,
@ -4581,7 +4569,7 @@ fn convert_to_stateful_memory_access<'a>(
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
state_space: ast::LdStateSpace::Param,
typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
typ: ast::LdStType::Scalar(ast::ScalarType::B64),
..
},
arg,
@ -4672,8 +4660,8 @@ fn convert_to_stateful_memory_access<'a>(
name: new_id,
array_init: Vec::new(),
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
ast::SizedScalarType::U8,
ast::PointerStateSpace::Global,
ast::ScalarType::U8,
ast::LdStateSpace::Global,
)),
}));
remapped_ids.insert(reg, new_id);