More fixes

This commit is contained in:
Andrzej Janik 2024-09-03 02:19:27 +02:00
commit 8d15499acc
9 changed files with 487 additions and 104 deletions

View file

@ -4,7 +4,7 @@ use super::{
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
use std::cmp::Ordering;
use std::{cmp::Ordering, num::NonZeroU8};
pub enum Statement<P: Operand> {
Label(P::Ident),
@ -760,19 +760,37 @@ pub enum Type {
// .param.b32 foo;
Scalar(ScalarType),
// .param.v2.b32 foo;
Vector(ScalarType, u8),
Vector(u8, ScalarType),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
Pointer(ScalarType, StateSpace),
}
impl Type {
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
match vector {
Some(prefix) => Type::Vector(scalar, prefix.len()),
Some(prefix) => Type::Vector(prefix.len().get(), scalar),
None => Type::Scalar(scalar),
}
}
/// Builds the type for an already-parsed optional vector width.
///
/// Returns `Type::Vector` carrying the width when `prefix` is present and
/// `Type::Scalar` otherwise.
pub(crate) fn maybe_vector_parsed(prefix: Option<NonZeroU8>, scalar: ScalarType) -> Self {
    if let Some(width) = prefix {
        Type::Vector(width.get(), scalar)
    } else {
        Type::Scalar(scalar)
    }
}
/// Builds the type of a declaration that may carry array dimensions.
///
/// With dimensions present the result is `Type::Array` (keeping the optional
/// vector width); without them this defers to `maybe_vector_parsed`.
pub(crate) fn maybe_array(
    prefix: Option<NonZeroU8>,
    scalar: ScalarType,
    array: Option<Vec<u32>>,
) -> Self {
    if let Some(dimensions) = array {
        return Type::Array(prefix, scalar, dimensions);
    }
    Self::maybe_vector_parsed(prefix, scalar)
}
}
impl ScalarType {
@ -1304,7 +1322,9 @@ impl<T: Operand> CallArgs<T> {
.input_arguments
.into_iter()
.zip(details.input_arguments.iter())
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), false, false))
.map(|(param, (type_, space))| {
visitor.visit(param, Some((type_, *space)), false, false)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(CallArgs {
return_arguments,

View file

@ -0,0 +1,69 @@
import os, sys, subprocess
# State-space qualifiers that generate() tries in front of the declaration.
SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
# Suffixes placed after the variable name: nothing, a scalar initializer,
# an array dimension, and an array dimension with an initializer list.
TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
# Optional "<1>" suffix appended directly to the variable name.
MULTIVAR = ["", "<1>" ]
# Optional vector qualifier placed before the `.b32` type.
VECTOR = ["", ".v2" ]
# Preamble emitted at the start of every generated PTX snippet.
HEADER = """
.version 8.5
.target sm_90
.address_size 64
"""
def directive(space, variable, multivar, vector):
    """Return PTX text declaring the variable under test at module scope.

    As called from generate(): `space` is an entry of SPACE, `variable` an
    entry of TYPE_AND_INIT (the suffix after the name), `multivar` an entry
    of MULTIVAR and `vector` an entry of VECTOR.
    """
    template = """{3}
{0} {4} .b32 variable{2} {1};
"""
    return template.format(space, variable, multivar, HEADER, vector)
def entry_arg(space, variable, multivar, vector):
    """Return PTX text using the declaration as the sole `.entry foobar` parameter.

    Arguments have the same meaning as in the other generator functions;
    the doubled braces in the template produce literal `{` and `}`.
    """
    template = """{3}
.entry foobar ( {0} {4} .b32 variable{2} {1})
{{
ret;
}}
"""
    return template.format(space, variable, multivar, HEADER, vector)
def fn_arg(space, variable, multivar, vector):
    """Return PTX text using the declaration as the sole `.func foobar` parameter.

    Arguments have the same meaning as in the other generator functions;
    the doubled braces in the template produce literal `{` and `}`.
    """
    template = """{3}
.func foobar ( {0} {4} .b32 variable{2} {1})
{{
ret;
}}
"""
    return template.format(space, variable, multivar, HEADER, vector)
def fn_body(space, variable, multivar, vector):
    """Return PTX text declaring the variable inside the body of `.func foobar`.

    Arguments have the same meaning as in the other generator functions;
    the doubled braces in the template produce literal `{` and `}`.
    """
    template = """{3}
.func foobar ()
{{
{0} {4} .b32 variable{2} {1};
ret;
}}
"""
    return template.format(space, variable, multivar, HEADER, vector)
def generate(generator):
    """Print which declaration variants ptxas accepts for one generator.

    Every combination of SPACE x TYPE_AND_INIT x MULTIVAR x VECTOR is rendered
    with `generator` and handed to ptxas; combinations for which ptxas exits
    with status 0 are collected. The generator's name and the collected
    (space, vector, init, multi) tuples are then printed.
    """
    ptxas = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe"

    def accepted(ptx):
        # The PTX text itself is passed as the argument following "-ias".
        status = subprocess.call([ptxas, "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL)
        return status == 0

    legal = [
        (space, vector, init, multi)
        for space in SPACE
        for init in TYPE_AND_INIT
        for multi in MULTIVAR
        for vector in VECTOR
        if accepted(generator(space, init, multi, vector))
    ]
    print(generator.__name__)
    print(legal)
def main():
    """Run the ptxas legality probe for each of the four declaration contexts."""
    for generator in (directive, entry_arg, fn_arg, fn_body):
        generate(generator)


if __name__ == "__main__":
    main()

View file

@ -3,9 +3,10 @@ use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::num::{ParseFloatError, ParseIntError};
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use winnow::ascii::dec_uint;
use winnow::combinator::*;
use winnow::error::{ErrMode, ErrorKind};
use winnow::stream::Accumulate;
use winnow::token::any;
use winnow::{
@ -72,11 +73,13 @@ impl From<RawRoundingMode> for ast::RoundingMode {
}
impl VectorPrefix {
pub(crate) fn len(self) -> u8 {
match self {
VectorPrefix::V2 => 2,
VectorPrefix::V4 => 4,
VectorPrefix::V8 => 8,
/// Number of elements denoted by this vector prefix (2, 4 or 8).
pub(crate) fn len(self) -> NonZeroU8 {
    let len: u8 = match self {
        VectorPrefix::V2 => 2,
        VectorPrefix::V4 => 4,
        VectorPrefix::V8 => 8,
    };
    // Every arm above is a non-zero literal, so the checked constructor can
    // never fail; this removes the need for `unsafe { new_unchecked }`.
    NonZeroU8::new(len).expect("vector prefix lengths are non-zero literals")
}
}
@ -386,22 +389,14 @@ fn module_variable<'a, 'input>(
) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> {
(
linking_directives,
module_variable_state_space.flat_map(variable_scalar_or_vector),
global_space
.flat_map(multi_variable)
// TODO: support multi var in globals
.map(|multi_var| multi_var.var),
)
.parse_next(stream)
}
/// Parses a single state-space token and maps it to its `StateSpace`.
///
/// Accepts `Token::DotConst`, `Token::DotGlobal` or `Token::DotShared`;
/// any other token makes every `alt` branch fail.
fn module_variable_state_space<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<StateSpace> {
alt((
Token::DotConst.value(StateSpace::Const),
Token::DotGlobal.value(StateSpace::Global),
Token::DotShared.value(StateSpace::Shared),
))
.parse_next(stream)
}
fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
(
Token::DotFile,
@ -547,17 +542,13 @@ fn kernel_arguments<'a, 'input>(
fn kernel_input<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::Variable<&'input str>> {
preceded(
Token::DotParam,
variable_scalar_or_vector(StateSpace::Param),
)
.parse_next(stream)
preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream)
}
fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
dispatch! { any;
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
Token::DotParam => method_parameter(StateSpace::Param),
Token::DotReg => method_parameter(StateSpace::Reg),
_ => fail
}
.parse_next(stream)
@ -596,7 +587,7 @@ fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32
}
}
separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..3, u32, Token::Comma)
separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma)
.map(|acc| acc.value)
.parse_next(stream)
}
@ -618,7 +609,12 @@ fn statement<'a, 'input>(
alt((
label.map(Some),
debug_directive.map(|_| None),
multi_variable.map(Some),
terminated(
method_space
.flat_map(multi_variable)
.map(|var| Some(Statement::Variable(var))),
Token::Semicolon,
),
predicated_instruction.map(Some),
pragma.map(|_| None),
block_statement.map(Some),
@ -632,59 +628,328 @@ fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
.parse_next(stream)
}
fn multi_variable<'a, 'input>(
/// Builds a parser for one kernel/function parameter declared in `state_space`.
///
/// The parser reads a `variable_declaration`, then (for every space except
/// `StateSpace::Reg`) an optional array suffix. A first dimension of 0, which
/// is how `array_dimensions` records empty brackets, is rejected with a
/// `Verify` error. The resulting variable never carries an initializer.
fn method_parameter<'a, 'input: 'a>(
    state_space: StateSpace,
) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> {
    move |stream: &mut PtxParser<'a, 'input>| {
        let (align, vector, type_, name) = variable_declaration.parse_next(stream)?;
        let dims = if state_space == StateSpace::Reg {
            None
        } else {
            opt(array_dimensions).parse_next(stream)?
        };
        // TODO: push this check into array_dimensions(...)
        let unsized_array = dims.as_ref().map_or(false, |dims| dims[0] == 0);
        if unsized_array {
            return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
        }
        let v_type = Type::maybe_array(vector, type_, dims);
        Ok(Variable {
            align,
            v_type,
            state_space,
            name,
            array_init: Vec::new(),
        })
    }
}
// TODO: split to a separate type
fn variable_declaration<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> {
(
variable,
opt(delimited(Token::Lt, u32, Token::Gt)),
Token::Semicolon,
opt(align.verify(|x| x.count_ones() == 1)),
vector_prefix,
scalar_type,
ident,
)
.map(|(var, count, _)| ast::Statement::Variable(ast::MultiVariable { var, count }))
.parse_next(stream)
}
fn variable<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
dispatch! {any;
Token::DotReg => variable_scalar_or_vector(StateSpace::Reg),
Token::DotLocal => variable_scalar_or_vector(StateSpace::Local),
Token::DotParam => variable_scalar_or_vector(StateSpace::Param),
Token::DotShared => variable_scalar_or_vector(StateSpace::Shared),
_ => fail
/// Builds a parser for a variable declaration in `state_space` that may be
/// parameterized (`name<N>`), an array, and/or initialized.
///
/// * With a `<N>` suffix (N must be non-zero) the declaration ends there:
///   no array suffix or initializer is parsed.
/// * Otherwise an optional array suffix is parsed for every space except
///   `StateSpace::Reg`, followed by an optional initializer for
///   `StateSpace::Global` and `StateSpace::Const` only.
/// * A first array dimension of 0 is rejected with a `Verify` error.
fn multi_variable<'a, 'input: 'a>(
    state_space: StateSpace,
) -> impl Parser<PtxParser<'a, 'input>, MultiVariable<&'input str>, ContextError> {
    move |stream: &mut PtxParser<'a, 'input>| {
        let (align, vector, type_, name) = variable_declaration.parse_next(stream)?;
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
        let count =
            opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)).parse_next(stream)?;
        if count.is_some() {
            let var = Variable {
                align,
                v_type: Type::maybe_vector_parsed(vector, type_),
                state_space,
                name,
                array_init: Vec::new(),
            };
            return Ok(MultiVariable { var, count });
        }
        let mut dims = if state_space == StateSpace::Reg {
            None
        } else {
            opt(array_dimensions).parse_next(stream)?
        };
        let can_initialize = matches!(state_space, StateSpace::Global | StateSpace::Const);
        let initializer = if can_initialize {
            match dims.as_mut() {
                Some(dimensions) => {
                    opt(array_initializer(vector, type_, dimensions)).parse_next(stream)?
                }
                None => opt(value_initializer(vector, type_)).parse_next(stream)?,
            }
        } else {
            None
        };
        let unsized_array = dims.as_ref().map_or(false, |dims| dims[0] == 0);
        if unsized_array {
            return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
        }
        let var = Variable {
            align,
            v_type: Type::maybe_array(vector, type_, dims),
            state_space,
            name,
            array_init: initializer.unwrap_or_default(),
        };
        Ok(MultiVariable { var, count })
    }
}
/// Builds a parser for `= { v0, v1, ... }` following an array declaration.
///
/// The parser returns the little-endian bytes of all elements. Exactly
/// `array_dimensions[0]` values must be present. Vector element types,
/// unsized arrays (first dimension 0) and multi-dimensional arrays are not
/// supported yet and fail with a `Verify` error.
fn array_initializer<'a, 'input: 'a>(
    vector: Option<NonZeroU8>,
    type_: ScalarType,
    array_dimensions: &mut Vec<u32>,
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + '_ {
    move |stream: &mut PtxParser<'a, 'input>| {
        Token::Eq.parse_next(stream)?;
        let mut result = Vec::new();
        // TODO: vector constants and multi dim arrays
        if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 {
            return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
        }
        let element_count = array_dimensions[0] as usize;
        // PTX aggregate initializers are written in curly braces (`= {1, 2}`),
        // not in the square brackets used for the array dimensions.
        // NOTE(review): assumes the token enum names them LBrace/RBrace — confirm.
        delimited(
            Token::LBrace,
            // Each value is appended to `result` by `single_value_append`,
            // so the values collected by `separated` are discarded (`()`).
            separated::<_, _, (), _, _, _, _>(
                element_count..=element_count,
                single_value_append(&mut result, type_),
                Token::Comma,
            ),
            Token::RBrace,
        )
        .parse_next(stream)?;
        Ok(result)
    }
}
/// Builds a parser for `= value` following a non-array declaration.
///
/// The parser returns the little-endian bytes of the single value.
/// Initializers for vector types are not supported yet and fail with a
/// `Verify` error.
fn value_initializer<'a, 'input: 'a>(
    vector: Option<NonZeroU8>,
    type_: ScalarType,
) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> {
    move |stream: &mut PtxParser<'a, 'input>| {
        Token::Eq.parse_next(stream)?;
        let mut bytes = Vec::new();
        // TODO: vector constants
        match vector {
            Some(_) => Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)),
            None => {
                single_value_append(&mut bytes, type_).parse_next(stream)?;
                Ok(bytes)
            }
        }
    }
}
fn single_value_append<'a, 'input: 'a>(
accumulator: &mut Vec<u8>,
type_: ScalarType,
) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + '_ {
move |stream: &mut PtxParser<'a, 'input>| {
let value = immediate_value.parse_next(stream)?;
match (type_, value) {
(ScalarType::U8, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&u8::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U8, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&u8::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U16, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&u16::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U16, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&u16::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U32, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&u32::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U32, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&u32::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U64, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&u64::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::U64, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&u64::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S8, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&i8::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S8, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&i8::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S16, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&i16::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S16, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&i16::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S32, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&i32::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S32, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&i32::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S64, ImmediateValue::U64(x)) => accumulator.extend_from_slice(
&i64::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::S64, ImmediateValue::S64(x)) => accumulator.extend_from_slice(
&i64::try_from(x)
.map_err(|_| ErrMode::from_error_kind(stream, ErrorKind::Verify))?
.to_le_bytes(),
),
(ScalarType::F32, ImmediateValue::F32(x)) => {
accumulator.extend_from_slice(&x.to_le_bytes())
}
(ScalarType::F64, ImmediateValue::F64(x)) => {
accumulator.extend_from_slice(&x.to_le_bytes())
}
_ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)),
}
Ok(())
}
}
/// Parses one or more bracketed array dimensions, e.g. `[4]`, `[][2]`.
///
/// The first pair of brackets may be empty, which is recorded as dimension 0;
/// an explicit `0` is rejected everywhere. Every following pair must hold a
/// non-zero number. The returned vector therefore always has at least one
/// element.
fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Vec<u32>> {
    let first = delimited(
        Token::LBracket,
        opt(u32).verify(|dim| *dim != Some(0)),
        Token::RBracket,
    )
    .parse_next(stream)?;
    let dimensions = vec![first.unwrap_or(0)];
    let trailing = delimited(
        Token::LBracket,
        u32.verify(|dim| *dim != 0),
        Token::RBracket,
    );
    repeat_fold_0_or_more(
        trailing,
        move || dimensions,
        |mut collected: Vec<u32>, dim| {
            collected.push(dim);
            collected
        },
        stream,
    )
}
// Copied and fixed from Winnow sources (fold_repeat0_)
// Winnow Repeat::fold takes FnMut() -> Result to initialize accumulator,
// this really should be FnOnce() -> Result
//
// Applies `f` repeatedly, folding each output into the accumulator with `g`.
// A backtracking error from `f` rewinds the input to where that attempt
// started and ends the loop successfully; any other error is propagated.
// Note that the loop only exits on an error from `f`, so `f` must not
// succeed forever without failing.
fn repeat_fold_0_or_more<I, O, E, F, G, H, R>(
    mut f: F,
    init: H,
    mut g: G,
    input: &mut I,
) -> PResult<R, E>
where
    I: Stream,
    F: Parser<I, O, E>,
    G: FnMut(R, O) -> R,
    H: FnOnce() -> R,
    E: ParserError<I>,
{
    use winnow::error::ErrMode;
    let mut acc = init();
    loop {
        let checkpoint = input.checkpoint();
        let item = match f.parse_next(input) {
            Ok(item) => item,
            Err(ErrMode::Backtrack(_)) => {
                input.reset(&checkpoint);
                return Ok(acc);
            }
            Err(fatal) => return Err(fatal),
        };
        acc = g(acc, item);
    }
}
/// Parses one of the `.global`, `.const` or `.shared` tokens and yields the
/// corresponding `StateSpace`.
fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
    let mut space = alt((
        Token::DotGlobal.value(StateSpace::Global),
        Token::DotConst.value(StateSpace::Const),
        Token::DotShared.value(StateSpace::Shared),
    ));
    space.parse_next(stream)
}
fn variable_scalar_or_vector<'a, 'input: 'a>(
state_space: StateSpace,
) -> impl Parser<PtxParser<'a, 'input>, ast::Variable<&'input str>, ContextError> {
move |stream: &mut PtxParser<'a, 'input>| {
(opt(align), scalar_vector_type, ident)
.map(|(align, v_type, name)| ast::Variable {
align,
v_type,
state_space,
name,
array_init: Vec::new(),
})
.parse_next(stream)
}
/// Parses a state-space token usable inside a function body: `.reg`,
/// `.local`, `.param`, or any space accepted by `global_space`.
fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
    let mut space = alt((
        Token::DotReg.value(StateSpace::Reg),
        Token::DotLocal.value(StateSpace::Local),
        Token::DotParam.value(StateSpace::Param),
        global_space,
    ));
    space.parse_next(stream)
}
/// Parses `.align N` and yields `N`; the value itself is not validated here.
fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
    let mut alignment = preceded(Token::DotAlign, u32);
    alignment.parse_next(stream)
}
fn scalar_vector_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Type> {
(
opt(alt((
Token::DotV2.value(VectorPrefix::V2),
Token::DotV4.value(VectorPrefix::V4),
))),
scalar_type,
)
.map(|(prefix, scalar)| ast::Type::maybe_vector(prefix, scalar))
.parse_next(stream)
/// Parses an optional `.v2` / `.v4` / `.v8` prefix and yields its width.
///
/// Returns `None` when no vector prefix token is present.
fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Option<NonZeroU8>> {
    // The widths are non-zero literals, so the checked constructor cannot
    // fail; this avoids three `unsafe { new_unchecked }` blocks.
    let width = |n: u8| NonZeroU8::new(n).expect("vector width literal is non-zero");
    opt(alt((
        Token::DotV2.value(width(2)),
        Token::DotV4.value(width(4)),
        Token::DotV8.value(width(8)),
    )))
    .parse_next(stream)
}
fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
@ -1157,6 +1422,8 @@ derive_parser!(
Minus,
#[token("+")]
Plus,
#[token("=")]
Eq,
#[token(".version")]
DotVersion,
#[token(".loc")]
@ -2509,7 +2776,7 @@ derive_parser!(
scope: scope.unwrap_or(MemScope::Gpu),
space: global.unwrap_or(StateSpace::Generic),
op: ast::AtomicOp::new(float_op, f32.kind()),
type_: ast::Type::Vector(f32, vec_32_bit.len())
type_: ast::Type::Vector(vec_32_bit.len().get(), f32)
},
arguments: AtomArgs { dst: d, src1: a, src2: b }
}
@ -2840,7 +3107,7 @@ derive_parser!(
// .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 };
prmt.b32 d, a, b, c => {
match c {
ast::ParsedOperand::Imm(ImmediateValue::U64(control)) => ast::Instruction::Prmt {
ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt {
data: control as u16,
arguments: PrmtArgs {
dst: d, src1: a, src2: b