mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-02 22:30:41 +00:00
Remove all remaining subenums
This commit is contained in:
parent
a0baad9456
commit
4d04fe251d
3 changed files with 92 additions and 187 deletions
161
ptx/src/ast.rs
161
ptx/src/ast.rs
|
@ -34,43 +34,6 @@ pub enum PtxError {
|
||||||
NonExternPointer,
|
NonExternPointer,
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! sub_enum {
|
|
||||||
($name:ident { $($variant:ident),+ $(,)? }) => {
|
|
||||||
sub_enum!{ $name : ScalarType { $($variant),+ } }
|
|
||||||
};
|
|
||||||
($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
|
||||||
pub enum $name {
|
|
||||||
$(
|
|
||||||
$variant,
|
|
||||||
)+
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<$name> for $base_type {
|
|
||||||
fn from(t: $name) -> $base_type {
|
|
||||||
match t {
|
|
||||||
$(
|
|
||||||
$name::$variant => $base_type::$variant,
|
|
||||||
)+
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::convert::TryFrom<$base_type> for $name {
|
|
||||||
type Error = ();
|
|
||||||
|
|
||||||
fn try_from(t: $base_type) -> Result<Self, Self::Error> {
|
|
||||||
match t {
|
|
||||||
$(
|
|
||||||
$base_type::$variant => Ok($name::$variant),
|
|
||||||
)+
|
|
||||||
_ => Err(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! sub_type {
|
macro_rules! sub_type {
|
||||||
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||||
sub_type! { $type_name : Type {
|
sub_type! { $type_name : Type {
|
||||||
|
@ -118,12 +81,12 @@ macro_rules! sub_type {
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableRegType {
|
VariableRegType {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(SizedScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
// Array type is used when emiting SSA statements at the start of a method
|
// Array type is used when emiting SSA statements at the start of a method
|
||||||
Array(ScalarType, VecU32),
|
Array(ScalarType, VecU32),
|
||||||
// Pointer variant is used when passing around SLM pointer between
|
// Pointer variant is used when passing around SLM pointer between
|
||||||
// function calls for dynamic SLM
|
// function calls for dynamic SLM
|
||||||
Pointer(SizedScalarType, PointerStateSpace)
|
Pointer(ScalarType, LdStateSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,9 +94,9 @@ type VecU32 = Vec<u32>;
|
||||||
|
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableLocalType {
|
VariableLocalType {
|
||||||
Scalar(SizedScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(SizedScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
Array(SizedScalarType, VecU32),
|
Array(ScalarType, VecU32),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,10 +115,10 @@ impl TryFrom<VariableGlobalType> for VariableLocalType {
|
||||||
|
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableGlobalType {
|
VariableGlobalType {
|
||||||
Scalar(SizedScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(SizedScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
Array(SizedScalarType, VecU32),
|
Array(ScalarType, VecU32),
|
||||||
Pointer(SizedScalarType, PointerStateSpace),
|
Pointer(ScalarType, LdStateSpace),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,49 +130,12 @@ sub_type! {
|
||||||
// .param .b32 foobar[]
|
// .param .b32 foobar[]
|
||||||
sub_type! {
|
sub_type! {
|
||||||
VariableParamType {
|
VariableParamType {
|
||||||
Scalar(LdStScalarType),
|
Scalar(ScalarType),
|
||||||
Array(SizedScalarType, VecU32),
|
Array(ScalarType, VecU32),
|
||||||
Pointer(SizedScalarType, PointerStateSpace),
|
Pointer(ScalarType, LdStateSpace),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_enum!(SizedScalarType {
|
|
||||||
B8,
|
|
||||||
B16,
|
|
||||||
B32,
|
|
||||||
B64,
|
|
||||||
U8,
|
|
||||||
U16,
|
|
||||||
U32,
|
|
||||||
U64,
|
|
||||||
S8,
|
|
||||||
S16,
|
|
||||||
S32,
|
|
||||||
S64,
|
|
||||||
F16,
|
|
||||||
F16x2,
|
|
||||||
F32,
|
|
||||||
F64,
|
|
||||||
});
|
|
||||||
|
|
||||||
sub_enum!(LdStScalarType {
|
|
||||||
B8,
|
|
||||||
B16,
|
|
||||||
B32,
|
|
||||||
B64,
|
|
||||||
U8,
|
|
||||||
U16,
|
|
||||||
U32,
|
|
||||||
U64,
|
|
||||||
S8,
|
|
||||||
S16,
|
|
||||||
S32,
|
|
||||||
S64,
|
|
||||||
F16,
|
|
||||||
F32,
|
|
||||||
F64,
|
|
||||||
});
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||||
pub enum BarDetails {
|
pub enum BarDetails {
|
||||||
SyncAligned,
|
SyncAligned,
|
||||||
|
@ -345,16 +271,6 @@ impl FnArgumentType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sub_enum!(
|
|
||||||
PointerStateSpace : LdStateSpace {
|
|
||||||
Generic,
|
|
||||||
Global,
|
|
||||||
Const,
|
|
||||||
Shared,
|
|
||||||
Param,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone)]
|
#[derive(PartialEq, Eq, Clone)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
|
@ -371,18 +287,18 @@ pub enum PointerType {
|
||||||
Pointer(ScalarType, LdStateSpace),
|
Pointer(ScalarType, LdStateSpace),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<SizedScalarType> for PointerType {
|
impl From<ScalarType> for PointerType {
|
||||||
fn from(t: SizedScalarType) -> Self {
|
fn from(t: ScalarType) -> Self {
|
||||||
PointerType::Scalar(t.into())
|
PointerType::Scalar(t.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<PointerType> for SizedScalarType {
|
impl TryFrom<PointerType> for ScalarType {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn try_from(value: PointerType) -> Result<Self, Self::Error> {
|
fn try_from(value: PointerType) -> Result<Self, Self::Error> {
|
||||||
match value {
|
match value {
|
||||||
PointerType::Scalar(t) => Ok(t.try_into()?),
|
PointerType::Scalar(t) => Ok(t),
|
||||||
PointerType::Vector(_, _) => Err(()),
|
PointerType::Vector(_, _) => Err(()),
|
||||||
PointerType::Array(_, _) => Err(()),
|
PointerType::Array(_, _) => Err(()),
|
||||||
PointerType::Pointer(_, _) => Err(()),
|
PointerType::Pointer(_, _) => Err(()),
|
||||||
|
@ -685,8 +601,8 @@ pub struct LdDetails {
|
||||||
|
|
||||||
sub_type! {
|
sub_type! {
|
||||||
LdStType {
|
LdStType {
|
||||||
Scalar(LdStScalarType),
|
Scalar(ScalarType),
|
||||||
Vector(LdStScalarType, u8),
|
Vector(ScalarType, u8),
|
||||||
// Used in generated code
|
// Used in generated code
|
||||||
Pointer(PointerType, LdStateSpace),
|
Pointer(PointerType, LdStateSpace),
|
||||||
}
|
}
|
||||||
|
@ -1135,7 +1051,7 @@ pub struct NegDetails {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> NumsOrArrays<'a> {
|
impl<'a> NumsOrArrays<'a> {
|
||||||
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
|
pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
|
||||||
self.normalize_dimensions(dimensions)?;
|
self.normalize_dimensions(dimensions)?;
|
||||||
let sizeof_t = ScalarType::from(typ).size_of() as usize;
|
let sizeof_t = ScalarType::from(typ).size_of() as usize;
|
||||||
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
|
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
|
||||||
|
@ -1166,7 +1082,7 @@ impl<'a> NumsOrArrays<'a> {
|
||||||
|
|
||||||
fn parse_and_copy(
|
fn parse_and_copy(
|
||||||
&self,
|
&self,
|
||||||
t: SizedScalarType,
|
t: ScalarType,
|
||||||
size_of_t: usize,
|
size_of_t: usize,
|
||||||
dimensions: &[u32],
|
dimensions: &[u32],
|
||||||
result: &mut [u8],
|
result: &mut [u8],
|
||||||
|
@ -1206,47 +1122,48 @@ impl<'a> NumsOrArrays<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_and_copy_single(
|
fn parse_and_copy_single(
|
||||||
t: SizedScalarType,
|
t: ScalarType,
|
||||||
idx: usize,
|
idx: usize,
|
||||||
str_val: &str,
|
str_val: &str,
|
||||||
radix: u32,
|
radix: u32,
|
||||||
output: &mut [u8],
|
output: &mut [u8],
|
||||||
) -> Result<(), PtxError> {
|
) -> Result<(), PtxError> {
|
||||||
match t {
|
match t {
|
||||||
SizedScalarType::B8 | SizedScalarType::U8 => {
|
ScalarType::B8 | ScalarType::U8 => {
|
||||||
Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::B16 | SizedScalarType::U16 => {
|
ScalarType::B16 | ScalarType::U16 => {
|
||||||
Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::B32 | SizedScalarType::U32 => {
|
ScalarType::B32 | ScalarType::U32 => {
|
||||||
Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::B64 | SizedScalarType::U64 => {
|
ScalarType::B64 | ScalarType::U64 => {
|
||||||
Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::S8 => {
|
ScalarType::S8 => {
|
||||||
Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::S16 => {
|
ScalarType::S16 => {
|
||||||
Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::S32 => {
|
ScalarType::S32 => {
|
||||||
Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::S64 => {
|
ScalarType::S64 => {
|
||||||
Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::F16 => {
|
ScalarType::F16 => {
|
||||||
Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::F16x2 => todo!(),
|
ScalarType::F16x2 => todo!(),
|
||||||
SizedScalarType::F32 => {
|
ScalarType::F32 => {
|
||||||
Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
SizedScalarType::F64 => {
|
ScalarType::F64 => {
|
||||||
Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
|
Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
|
||||||
}
|
}
|
||||||
|
ScalarType::Pred => todo!()
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1334,13 +1251,13 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn array_fails_multiple_0_dmiensions() {
|
fn array_fails_multiple_0_dmiensions() {
|
||||||
let inp = NumsOrArrays::Nums(Vec::new());
|
let inp = NumsOrArrays::Nums(Vec::new());
|
||||||
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err());
|
assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn array_fails_on_empty() {
|
fn array_fails_on_empty() {
|
||||||
let inp = NumsOrArrays::Nums(Vec::new());
|
let inp = NumsOrArrays::Nums(Vec::new());
|
||||||
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err());
|
assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1352,7 +1269,7 @@ mod tests {
|
||||||
let mut dimensions = vec![0u32, 2];
|
let mut dimensions = vec![0u32, 2];
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
vec![1u8, 2, 3, 4],
|
vec![1u8, 2, 3, 4],
|
||||||
inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
|
inp.to_vec(ScalarType::B8, &mut dimensions).unwrap()
|
||||||
);
|
);
|
||||||
assert_eq!(dimensions, vec![2u32, 2]);
|
assert_eq!(dimensions, vec![2u32, 2]);
|
||||||
}
|
}
|
||||||
|
@ -1364,7 +1281,7 @@ mod tests {
|
||||||
NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
|
NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
|
||||||
]);
|
]);
|
||||||
let mut dimensions = vec![0u32, 2];
|
let mut dimensions = vec![0u32, 2];
|
||||||
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
|
assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1374,6 +1291,6 @@ mod tests {
|
||||||
NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
|
NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
|
||||||
]);
|
]);
|
||||||
let mut dimensions = vec![0u32, 2];
|
let mut dimensions = vec![0u32, 2];
|
||||||
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
|
assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -611,9 +611,9 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||||
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
|
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
|
||||||
}
|
}
|
||||||
if space == ".global" {
|
if space == ".global" {
|
||||||
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new())
|
(ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new())
|
||||||
} else {
|
} else {
|
||||||
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new())
|
(ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -635,7 +635,7 @@ ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
||||||
(ast::VariableParamType::Array(t, dimensions), init)
|
(ast::VariableParamType::Array(t, dimensions), init)
|
||||||
}
|
}
|
||||||
ast::ArrayOrPointer::Pointer => {
|
ast::ArrayOrPointer::Pointer => {
|
||||||
(ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new())
|
(ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
(align, array_init, v_type, name)
|
(align, array_init, v_type, name)
|
||||||
|
@ -667,42 +667,42 @@ GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
SizedScalarType: ast::SizedScalarType = {
|
SizedScalarType: ast::ScalarType = {
|
||||||
".b8" => ast::SizedScalarType::B8,
|
".b8" => ast::ScalarType::B8,
|
||||||
".b16" => ast::SizedScalarType::B16,
|
".b16" => ast::ScalarType::B16,
|
||||||
".b32" => ast::SizedScalarType::B32,
|
".b32" => ast::ScalarType::B32,
|
||||||
".b64" => ast::SizedScalarType::B64,
|
".b64" => ast::ScalarType::B64,
|
||||||
".u8" => ast::SizedScalarType::U8,
|
".u8" => ast::ScalarType::U8,
|
||||||
".u16" => ast::SizedScalarType::U16,
|
".u16" => ast::ScalarType::U16,
|
||||||
".u32" => ast::SizedScalarType::U32,
|
".u32" => ast::ScalarType::U32,
|
||||||
".u64" => ast::SizedScalarType::U64,
|
".u64" => ast::ScalarType::U64,
|
||||||
".s8" => ast::SizedScalarType::S8,
|
".s8" => ast::ScalarType::S8,
|
||||||
".s16" => ast::SizedScalarType::S16,
|
".s16" => ast::ScalarType::S16,
|
||||||
".s32" => ast::SizedScalarType::S32,
|
".s32" => ast::ScalarType::S32,
|
||||||
".s64" => ast::SizedScalarType::S64,
|
".s64" => ast::ScalarType::S64,
|
||||||
".f16" => ast::SizedScalarType::F16,
|
".f16" => ast::ScalarType::F16,
|
||||||
".f16x2" => ast::SizedScalarType::F16x2,
|
".f16x2" => ast::ScalarType::F16x2,
|
||||||
".f32" => ast::SizedScalarType::F32,
|
".f32" => ast::ScalarType::F32,
|
||||||
".f64" => ast::SizedScalarType::F64,
|
".f64" => ast::ScalarType::F64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
LdStScalarType: ast::LdStScalarType = {
|
LdStScalarType: ast::ScalarType = {
|
||||||
".b8" => ast::LdStScalarType::B8,
|
".b8" => ast::ScalarType::B8,
|
||||||
".b16" => ast::LdStScalarType::B16,
|
".b16" => ast::ScalarType::B16,
|
||||||
".b32" => ast::LdStScalarType::B32,
|
".b32" => ast::ScalarType::B32,
|
||||||
".b64" => ast::LdStScalarType::B64,
|
".b64" => ast::ScalarType::B64,
|
||||||
".u8" => ast::LdStScalarType::U8,
|
".u8" => ast::ScalarType::U8,
|
||||||
".u16" => ast::LdStScalarType::U16,
|
".u16" => ast::ScalarType::U16,
|
||||||
".u32" => ast::LdStScalarType::U32,
|
".u32" => ast::ScalarType::U32,
|
||||||
".u64" => ast::LdStScalarType::U64,
|
".u64" => ast::ScalarType::U64,
|
||||||
".s8" => ast::LdStScalarType::S8,
|
".s8" => ast::ScalarType::S8,
|
||||||
".s16" => ast::LdStScalarType::S16,
|
".s16" => ast::ScalarType::S16,
|
||||||
".s32" => ast::LdStScalarType::S32,
|
".s32" => ast::ScalarType::S32,
|
||||||
".s64" => ast::LdStScalarType::S64,
|
".s64" => ast::ScalarType::S64,
|
||||||
".f16" => ast::LdStScalarType::F16,
|
".f16" => ast::ScalarType::F16,
|
||||||
".f32" => ast::LdStScalarType::F32,
|
".f32" => ast::ScalarType::F32,
|
||||||
".f64" => ast::LdStScalarType::F64,
|
".f64" => ast::ScalarType::F64,
|
||||||
}
|
}
|
||||||
|
|
||||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
|
|
|
@ -97,18 +97,6 @@ impl ast::Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Into<spirv::StorageClass> for ast::PointerStateSpace {
|
|
||||||
fn into(self) -> spirv::StorageClass {
|
|
||||||
match self {
|
|
||||||
ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
|
|
||||||
ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
|
|
||||||
ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
|
|
||||||
ast::PointerStateSpace::Param => spirv::StorageClass::Function,
|
|
||||||
ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ast::ScalarType> for SpirvType {
|
impl From<ast::ScalarType> for SpirvType {
|
||||||
fn from(t: ast::ScalarType) -> Self {
|
fn from(t: ast::ScalarType) -> Self {
|
||||||
SpirvType::Base(t.into())
|
SpirvType::Base(t.into())
|
||||||
|
@ -824,8 +812,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
name: shared_var_id,
|
name: shared_var_id,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||||
ast::SizedScalarType::B8,
|
ast::ScalarType::B8,
|
||||||
ast::PointerStateSpace::Shared,
|
ast::LdStateSpace::Shared,
|
||||||
)),
|
)),
|
||||||
});
|
});
|
||||||
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
|
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
|
||||||
|
@ -863,7 +851,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
|
||||||
fn replace_uses_of_shared_memory<'a>(
|
fn replace_uses_of_shared_memory<'a>(
|
||||||
result: &mut Vec<ExpandedStatement>,
|
result: &mut Vec<ExpandedStatement>,
|
||||||
new_id: &mut impl FnMut() -> spirv::Word,
|
new_id: &mut impl FnMut() -> spirv::Word,
|
||||||
extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
|
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
|
||||||
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
|
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
|
||||||
shared_id_param: spirv::Word,
|
shared_id_param: spirv::Word,
|
||||||
shared_var_id: spirv::Word,
|
shared_var_id: spirv::Word,
|
||||||
|
@ -884,7 +872,7 @@ fn replace_uses_of_shared_memory<'a>(
|
||||||
statement => {
|
statement => {
|
||||||
let new_statement = statement.map_id(&mut |id, _| {
|
let new_statement = statement.map_id(&mut |id, _| {
|
||||||
if let Some(typ) = extern_shared_decls.get(&id) {
|
if let Some(typ) = extern_shared_decls.get(&id) {
|
||||||
if *typ == ast::SizedScalarType::B8 {
|
if *typ == ast::ScalarType::B8 {
|
||||||
return shared_var_id;
|
return shared_var_id;
|
||||||
}
|
}
|
||||||
let replacement_id = new_id();
|
let replacement_id = new_id();
|
||||||
|
@ -1505,7 +1493,7 @@ fn extract_globals<'input, 'b>(
|
||||||
d,
|
d,
|
||||||
a,
|
a,
|
||||||
"inc",
|
"inc",
|
||||||
ast::SizedScalarType::U32,
|
ast::ScalarType::U32,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
|
@ -1527,7 +1515,7 @@ fn extract_globals<'input, 'b>(
|
||||||
d,
|
d,
|
||||||
a,
|
a,
|
||||||
"dec",
|
"dec",
|
||||||
ast::SizedScalarType::U32,
|
ast::ScalarType::U32,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
|
@ -1553,8 +1541,8 @@ fn extract_globals<'input, 'b>(
|
||||||
space,
|
space,
|
||||||
};
|
};
|
||||||
let (op, typ) = match typ {
|
let (op, typ) = match typ {
|
||||||
ast::ScalarType::F32 => ("add_f32", ast::SizedScalarType::F32),
|
ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
|
||||||
ast::ScalarType::F64 => ("add_f64", ast::SizedScalarType::F64),
|
ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
local.push(to_ptx_impl_atomic_call(
|
local.push(to_ptx_impl_atomic_call(
|
||||||
|
@ -1734,7 +1722,7 @@ fn to_ptx_impl_atomic_call(
|
||||||
details: ast::AtomDetails,
|
details: ast::AtomDetails,
|
||||||
arg: ast::Arg3<ExpandedArgParams>,
|
arg: ast::Arg3<ExpandedArgParams>,
|
||||||
op: &'static str,
|
op: &'static str,
|
||||||
typ: ast::SizedScalarType,
|
typ: ast::ScalarType,
|
||||||
) -> ExpandedStatement {
|
) -> ExpandedStatement {
|
||||||
let semantics = ptx_semantics_name(details.semantics);
|
let semantics = ptx_semantics_name(details.semantics);
|
||||||
let scope = ptx_scope_name(details.scope);
|
let scope = ptx_scope_name(details.scope);
|
||||||
|
@ -1745,9 +1733,9 @@ fn to_ptx_impl_atomic_call(
|
||||||
);
|
);
|
||||||
// TODO: extract to a function
|
// TODO: extract to a function
|
||||||
let ptr_space = match details.space {
|
let ptr_space = match details.space {
|
||||||
ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
|
ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
|
||||||
ast::AtomSpace::Global => ast::PointerStateSpace::Global,
|
ast::AtomSpace::Global => ast::LdStateSpace::Global,
|
||||||
ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
|
ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
|
||||||
};
|
};
|
||||||
let scalar_typ = ast::ScalarType::from(typ);
|
let scalar_typ = ast::ScalarType::from(typ);
|
||||||
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
let fn_id = match ptx_impl_imports.entry(fn_name) {
|
||||||
|
@ -4565,7 +4553,7 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
Statement::Instruction(ast::Instruction::Ld(
|
Statement::Instruction(ast::Instruction::Ld(
|
||||||
ast::LdDetails {
|
ast::LdDetails {
|
||||||
state_space: ast::LdStateSpace::Param,
|
state_space: ast::LdStateSpace::Param,
|
||||||
typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
|
typ: ast::LdStType::Scalar(ast::ScalarType::U64),
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
arg,
|
arg,
|
||||||
|
@ -4573,7 +4561,7 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
| Statement::Instruction(ast::Instruction::Ld(
|
| Statement::Instruction(ast::Instruction::Ld(
|
||||||
ast::LdDetails {
|
ast::LdDetails {
|
||||||
state_space: ast::LdStateSpace::Param,
|
state_space: ast::LdStateSpace::Param,
|
||||||
typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
|
typ: ast::LdStType::Scalar(ast::ScalarType::S64),
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
arg,
|
arg,
|
||||||
|
@ -4581,7 +4569,7 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
| Statement::Instruction(ast::Instruction::Ld(
|
| Statement::Instruction(ast::Instruction::Ld(
|
||||||
ast::LdDetails {
|
ast::LdDetails {
|
||||||
state_space: ast::LdStateSpace::Param,
|
state_space: ast::LdStateSpace::Param,
|
||||||
typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
|
typ: ast::LdStType::Scalar(ast::ScalarType::B64),
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
arg,
|
arg,
|
||||||
|
@ -4672,8 +4660,8 @@ fn convert_to_stateful_memory_access<'a>(
|
||||||
name: new_id,
|
name: new_id,
|
||||||
array_init: Vec::new(),
|
array_init: Vec::new(),
|
||||||
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
|
||||||
ast::SizedScalarType::U8,
|
ast::ScalarType::U8,
|
||||||
ast::PointerStateSpace::Global,
|
ast::LdStateSpace::Global,
|
||||||
)),
|
)),
|
||||||
}));
|
}));
|
||||||
remapped_ids.insert(reg, new_id);
|
remapped_ids.insert(reg, new_id);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue