diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index ad4b473..455a8c2 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -467,8 +467,22 @@ fn convert_to_stateful_memory_access_postprocess( Some(new_id) => { let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; // TODO: readd if required - if let Some(..) = type_space { - if relaxed_conversion { + if let Some((expected_type, expected_space)) = type_space { + let implicit_conversion = if relaxed_conversion { + if is_dst { + super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper + } else { + super::insert_implicit_conversions::should_convert_relaxed_src_wrapper + } + } else { + super::insert_implicit_conversions::default_implicit_conversion + }; + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { return Ok(*new_id); } } diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 2342ad5..c2af204 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -67,7 +67,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ) -> Result { // mov.u32 foobar, {a,b}; let scalar_t = match typ { - ast::Type::Vector(scalar_t, _) => *scalar_t, + ast::Type::Vector(_, scalar_t) => *scalar_t, _ => return Err(error_mismatched_type()), }; let temp_vec = self diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index e2e6a3b..8aa4576 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -291,7 +291,7 @@ impl TypeWordMap { | ast::ScalarType::BF16x2 | ast::ScalarType::B128 => todo!(), }, - ast::Type::Vector(typ, len) => { + ast::Type::Vector(len, typ) => { let result_type = self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); let size_of_t = typ.size_of(); @@ -309,7 +309,7 @@ impl TypeWordMap { .collect::, _>>()?; SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) } - ast::Type::Array(typ, dims) => match dims.as_slice() { + ast::Type::Array(_, typ, dims) => match dims.as_slice() { [] => return Err(error_unreachable()), [dim] => { let result_type = self @@ -342,7 +342,7 @@ impl TypeWordMap { Ok::<_, TranslateError>( self.get_or_add_constant( b, - &ast::Type::Array(*typ, rest.to_vec()), + &ast::Type::Array(None, *typ, rest.to_vec()), &init[((size_of_t as usize) * (x as usize))..], )? .0, @@ -397,8 +397,8 @@ impl SpirvType { fn new(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len), - ast::Type::Array(t, len) => SpirvType::Array(t.into(), len), + ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), + ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( Box::new(SpirvType::Base(pointer_t.into())), space_to_spirv(space), @@ -809,8 +809,8 @@ fn emit_function_header<'input>( pub fn type_size_of(this: &ast::Type) -> usize { match this { ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(typ, len) => len + ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(_, typ, len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), ast::Type::Pointer(..) => mem::size_of::(), @@ -1853,11 +1853,16 @@ fn emit_mul_int( builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; } ast::MulIntControl::High => { + let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_mul_hi + } else { + spirv::CLOp::u_mul_hi + }; builder.ext_inst( inst_type.0, Some(arg.dst.0), opencl, - spirv::CLOp::s_mul_hi as spirv::Word, + opencl_inst as spirv::Word, [ dr::Operand::IdRef(arg.src1.0), dr::Operand::IdRef(arg.src2.0), @@ -2646,7 +2651,7 @@ fn emit_load_var( match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { - ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t), _ => return Err(error_mismatched_type()), }; let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); @@ -2710,14 +2715,14 @@ fn to_parts(this: &ast::Type) -> TypeParts { width: scalar.size_of(), components: Vec::new(), }, - ast::Type::Vector(scalar, components) => TypeParts { + ast::Type::Vector(components, scalar) => TypeParts { kind: TypeKind::Vector, state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], }, - ast::Type::Array(scalar, components) => TypeParts { + ast::Type::Array(_, scalar, components) => TypeParts { kind: TypeKind::Array, state_space: ast::StateSpace::Reg, scalar_kind: scalar.kind(), @@ -2738,12 +2743,14 @@ fn type_from_parts(t: TypeParts) -> ast::Type { match t.kind { TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), TypeKind::Vector => ast::Type::Vector( - scalar_from_parts(t.width, t.scalar_kind), t.components[0] as u8, + scalar_from_parts(t.width, t.scalar_kind), + ), + TypeKind::Array => ast::Type::Array( + None, + scalar_from_parts(t.width, t.scalar_kind), + t.components, ), - TypeKind::Array => { - ast::Type::Array(scalar_from_parts(t.width, t.scalar_kind), t.components) - } TypeKind::Pointer => { ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) } diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 0dce598..2857551 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -123,13 +123,13 @@ fn insert_implicit_conversions_impl( Ok(()) } -fn default_implicit_conversion( +pub(crate) fn default_implicit_conversion( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if instruction_space == ast::StateSpace::Reg { if space_is_compatible(operand_space, ast::StateSpace::Reg) { - if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) = + if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = (operand_type, instruction_type) { if scalar.kind() == ast::ScalarKind::Bit @@ -282,15 +282,15 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { ast::ScalarKind::Pred => false, } } - (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _)) - | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => { + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) } _ => false, } } -fn should_convert_relaxed_dst_wrapper( +pub(crate) fn should_convert_relaxed_dst_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { @@ -356,8 +356,8 @@ fn should_convert_relaxed_dst( } ast::ScalarKind::Pred => None, }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { should_convert_relaxed_dst( &ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*instr_type), @@ -367,7 +367,7 @@ fn should_convert_relaxed_dst( } } -fn should_convert_relaxed_src_wrapper( +pub(crate) fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { @@ -420,8 +420,8 @@ fn should_convert_relaxed_src( } ast::ScalarKind::Pred => None, }, - (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _)) - | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => { + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { should_convert_relaxed_src( &ast::Type::Scalar(*dst_type), &ast::Type::Scalar(*instr_type), diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index c1e30b0..e314b05 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -195,7 +195,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { let member_index = match member_index { Some(idx) => { let vector_width = match var_type { - ast::Type::Vector(scalar_t, width) => { + ast::Type::Vector(width, scalar_t) => { var_type = ast::Type::Scalar(scalar_t); width } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 4ca2f02..92d1bf4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,6 +1,7 @@ use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; use std::hash::Hash; +use std::num::NonZeroU8; use std::{ borrow::Cow, cell::RefCell, @@ -360,7 +361,7 @@ impl PtxSpecialRegister { PtxSpecialRegister::Tid | PtxSpecialRegister::Ntid | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4), + | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()), _ => ast::Type::Scalar(self.get_function_return_type()), } } @@ -764,7 +765,12 @@ impl> Statement, T> { }) } Statement::Conditional(conditional) => { - let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?; + let predicate = visitor.visit_ident( + conditional.predicate, + Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)), + false, + false, + )?; let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; Statement::Conditional(BrachCondition { @@ -919,7 +925,7 @@ impl> Statement, T> { let packed = visitor.visit_ident( packed, Some(( - &ast::Type::Vector(typ, unpacked.len() as u8), + &ast::Type::Vector(unpacked.len() as u8, typ), ast::StateSpace::Reg, )), false, @@ -930,7 +936,7 @@ impl> Statement, T> { let packed = visitor.visit_ident( packed, Some(( - &ast::Type::Vector(typ, unpacked.len() as u8), + &ast::Type::Vector(unpacked.len() as u8, typ), ast::StateSpace::Reg, )), true, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 2a6bb53..c266947 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{PtxError, PtxParserState}; use bitflags::bitflags; -use std::cmp::Ordering; +use std::{cmp::Ordering, num::NonZeroU8}; pub enum Statement { Label(P::Ident), @@ -760,19 +760,37 @@ pub enum Type { // .param.b32 foo; Scalar(ScalarType), // .param.v2.b32 foo; - Vector(ScalarType, u8), + Vector(u8, ScalarType), // .param.b32 foo[4]; - Array(ScalarType, Vec), + Array(Option, ScalarType, Vec), Pointer(ScalarType, StateSpace), } impl Type { pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { match vector { - Some(prefix) => Type::Vector(scalar, prefix.len()), + Some(prefix) => Type::Vector(prefix.len().get(), scalar), None => Type::Scalar(scalar), } } + + pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { + match prefix { + Some(prefix) => Type::Vector(prefix.get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_array( + prefix: Option, + scalar: ScalarType, + array: Option>, + ) -> Self { + match array { + Some(dimensions) => Type::Array(prefix, scalar, dimensions), + None => Self::maybe_vector_parsed(prefix, scalar), + } + } } impl ScalarType { @@ -1304,7 +1322,9 @@ impl CallArgs { .input_arguments .into_iter() .zip(details.input_arguments.iter()) - .map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), false, false)) + .map(|(param, (type_, space))| { + visitor.visit(param, Some((type_, *space)), false, false) + }) .collect::, _>>()?; Ok(CallArgs { return_arguments, diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py new file mode 100644 index 0000000..04ffdb9 --- /dev/null +++ b/ptx_parser/src/check_args.py @@ -0,0 +1,69 @@ +import os, sys, subprocess + + +SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"] +TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"] +MULTIVAR = ["", "<1>" ] +VECTOR = ["", ".v2" ] + +HEADER = """ + .version 8.5 + .target sm_90 + .address_size 64 +""" + + +def directive(space, variable, multivar, vector): + return """{3} + {0} {4} .b32 variable{2} {1}; + """.format(space, variable, multivar, HEADER, vector) + +def entry_arg(space, variable, multivar, vector): + return """{3} + .entry foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_arg(space, variable, multivar, vector): + return """{3} + .func foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_body(space, variable, multivar, vector): + return """{3} + .func foobar () + {{ + {0} {4} .b32 variable{2} {1}; + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def generate(generator): + legal = [] + for space in SPACE: + for init in TYPE_AND_INIT: + for multi in MULTIVAR: + for vector in VECTOR: + ptx = generator(space, init, multi, vector) + if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): # + legal.append((space, vector, init, multi)) + print(generator.__name__) + print(legal) + + +def main(): + generate(directive) + generate(entry_arg) + generate(fn_arg) + generate(fn_body) + +if __name__ == "__main__": + main() diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 3a9ece5..dfe78ee 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -3,9 +3,10 @@ use logos::Logos; use ptx_parser_macros::derive_parser; use rustc_hash::FxHashMap; use std::fmt::Debug; -use std::num::{ParseFloatError, ParseIntError}; +use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; use winnow::ascii::dec_uint; use winnow::combinator::*; +use winnow::error::{ErrMode, ErrorKind}; use winnow::stream::Accumulate; use winnow::token::any; use winnow::{ @@ -72,11 +73,13 @@ impl From for ast::RoundingMode { } impl VectorPrefix { - pub(crate) fn len(self) -> u8 { - match self { - VectorPrefix::V2 => 2, - VectorPrefix::V4 => 4, - VectorPrefix::V8 => 8, + pub(crate) fn len(self) -> NonZeroU8 { + unsafe { + match self { + VectorPrefix::V2 => NonZeroU8::new_unchecked(2), + VectorPrefix::V4 => NonZeroU8::new_unchecked(4), + VectorPrefix::V8 => NonZeroU8::new_unchecked(8), + } } } } @@ -386,22 +389,14 @@ fn module_variable<'a, 'input>( ) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { ( linking_directives, - module_variable_state_space.flat_map(variable_scalar_or_vector), + global_space + .flat_map(multi_variable) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var), ) .parse_next(stream) } -fn module_variable_state_space<'a, 'input>( - stream: &mut PtxParser<'a, 'input>, -) -> PResult { - alt(( - Token::DotConst.value(StateSpace::Const), - Token::DotGlobal.value(StateSpace::Global), - Token::DotShared.value(StateSpace::Shared), - )) - .parse_next(stream) -} - fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { ( Token::DotFile, @@ -547,17 +542,13 @@ fn kernel_arguments<'a, 'input>( fn kernel_input<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult> { - preceded( - Token::DotParam, - variable_scalar_or_vector(StateSpace::Param), - ) - .parse_next(stream) + preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) } fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { dispatch! { any; - Token::DotParam => variable_scalar_or_vector(StateSpace::Param), - Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), + Token::DotParam => method_parameter(StateSpace::Param), + Token::DotReg => method_parameter(StateSpace::Reg), _ => fail } .parse_next(stream) @@ -596,7 +587,7 @@ fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32 } } - separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma) + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma) .map(|acc| acc.value) .parse_next(stream) } @@ -618,7 +609,12 @@ fn statement<'a, 'input>( alt(( label.map(Some), debug_directive.map(|_| None), - multi_variable.map(Some), + terminated( + method_space + .flat_map(multi_variable) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), predicated_instruction.map(Some), pragma.map(|_| None), block_statement.map(Some), @@ -632,59 +628,328 @@ fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { .parse_next(stream) } -fn multi_variable<'a, 'input>( +fn method_parameter<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + let array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + // TODO: push this check into array_dimensions(...) + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: Vec::new(), + }) + } +} + +// TODO: split to a separate type +fn variable_declaration<'a, 'input>( stream: &mut PtxParser<'a, 'input>, -) -> PResult>> { +) -> PResult<(Option, Option, ScalarType, &'input str)> { ( - variable, - opt(delimited(Token::Lt, u32, Token::Gt)), - Token::Semicolon, + opt(align.verify(|x| x.count_ones() == 1)), + vector_prefix, + scalar_type, + ident, ) - .map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count })) .parse_next(stream) } -fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { - dispatch! {any; - Token::DotReg => variable_scalar_or_vector(StateSpace::Reg), - Token::DotLocal => variable_scalar_or_vector(StateSpace::Local), - Token::DotParam => variable_scalar_or_vector(StateSpace::Param), - Token::DotShared => variable_scalar_or_vector(StateSpace::Shared), - _ => fail +fn multi_variable<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, MultiVariable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let ((align, vector, type_, name), count) = ( + variable_declaration, + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names + opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), + ) + .parse_next(stream)?; + if count.is_some() { + return Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_vector_parsed(vector, type_), + state_space, + name, + array_init: Vec::new(), + }, + count, + }); + } + let mut array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + let initializer = match state_space { + StateSpace::Global | StateSpace::Const => match array_dimensions { + Some(ref mut dimensions) => { + opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? + } + None => opt(value_initializer(vector, type_)).parse_next(stream)?, + }, + _ => None, + }; + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: initializer.unwrap_or(Vec::new()), + }, + count, + }) } +} + +fn array_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, + array_dimensions: &mut Vec, +) -> impl Parser, Vec, ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants and multi dim arrays + if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + delimited( + Token::LBracket, + separated( + array_dimensions[0] as usize..=array_dimensions[0] as usize, + single_value_append(&mut result, type_), + Token::Comma, + ), + Token::RBracket, + ) + .parse_next(stream)?; + Ok(result) + } +} + +fn value_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, +) -> impl Parser, Vec, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants + if vector.is_some() { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + single_value_append(&mut result, type_).parse_next(stream)?; + Ok(result) + } +} + +fn single_value_append<'a, 'input: 'a>( + accumulator: &mut Vec, + type_: ScalarType, +) -> impl Parser, (), ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + let value = immediate_value.parse_next(stream)?; + match (type_, value) { + (ScalarType::U8, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &u64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::U64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &u64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S8, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S8, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i8::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S16, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S16, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i16::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S32, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S32, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i32::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S64, ImmediateValue::U64(x)) => accumulator.extend_from_slice( + &i64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::S64, ImmediateValue::S64(x)) => accumulator.extend_from_slice( + &i64::try_from(x) + .map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))? + .to_le_bytes(), + ), + (ScalarType::F32, ImmediateValue::F32(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + (ScalarType::F64, ImmediateValue::F64(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), + } + Ok(()) + } +} + +fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + let dimension = delimited( + Token::LBracket, + opt(u32).verify(|dim| *dim != Some(0)), + Token::RBracket, + ) + .parse_next(stream)?; + let result = vec![dimension.unwrap_or(0)]; + repeat_fold_0_or_more( + delimited( + Token::LBracket, + u32.verify(|dim| *dim != 0), + Token::RBracket, + ), + move || result, + |mut result: Vec, x| { + result.push(x); + result + }, + stream, + ) +} + +// Copied and fixed from Winnow sources (fold_repeat0_) +// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator, +// this really should be FnOnce() -> Result +fn repeat_fold_0_or_more( + mut f: F, + init: H, + mut g: G, + input: &mut I, +) -> PResult +where + I: Stream, + F: Parser, + G: FnMut(R, O) -> R, + H: FnOnce() -> R, + E: ParserError, +{ + use winnow::error::ErrMode; + let mut res = init(); + loop { + let start = input.checkpoint(); + match f.parse_next(input) { + Ok(o) => { + res = g(res, o); + } + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + return Ok(res); + } + Err(e) => { + return Err(e); + } + } + } +} + +fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotGlobal.value(StateSpace::Global), + Token::DotConst.value(StateSpace::Const), + Token::DotShared.value(StateSpace::Shared), + )) .parse_next(stream) } -fn variable_scalar_or_vector<'a, 'input: 'a>( - state_space: StateSpace, -) -> impl Parser, ast::Variable<&'input str>, ContextError> { - move |stream: &mut PtxParser<'a, 'input>| { - (opt(align), scalar_vector_type, ident) - .map(|(align, v_type, name)| ast::Variable { - align, - v_type, - state_space, - name, - array_init: Vec::new(), - }) - .parse_next(stream) - } +fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotReg.value(StateSpace::Reg), + Token::DotLocal.value(StateSpace::Local), + Token::DotParam.value(StateSpace::Param), + global_space, + )) + .parse_next(stream) } fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { preceded(Token::DotAlign, u32).parse_next(stream) } -fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { - ( - opt(alt(( - Token::DotV2.value(VectorPrefix::V2), - Token::DotV4.value(VectorPrefix::V4), - ))), - scalar_type, - ) - .map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar)) - .parse_next(stream) +fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + opt(alt(( + Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }), + Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }), + Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }), + ))) + .parse_next(stream) } fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { @@ -1157,6 +1422,8 @@ derive_parser!( Minus, #[token("+")] Plus, + #[token("=")] + Eq, #[token(".version")] DotVersion, #[token(".loc")] @@ -2509,7 +2776,7 @@ derive_parser!( scope: scope.unwrap_or(MemScope::Gpu), space: global.unwrap_or(StateSpace::Generic), op: ast::AtomicOp::new(float_op, f32.kind()), - type_: ast::Type::Vector(f32, vec_32_bit.len()) + type_: ast::Type::Vector(vec_32_bit.len().get(), f32) }, arguments: AtomArgs { dst: d, src1: a, src2: b } } @@ -2840,7 +3107,7 @@ derive_parser!( // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; prmt.b32 d, a, b, c => { match c { - ast::ParsedOperand::Imm(ImmediateValue::U64(control)) => ast::Instruction::Prmt { + ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt { data: control as u16, arguments: PrmtArgs { dst: d, src1: a, src2: b