Parse call instruction

This commit is contained in:
Andrzej Janik 2024-08-20 03:53:18 +02:00
parent 34b0a67f0a
commit c21c55dfc2
3 changed files with 177 additions and 39 deletions

View file

@ -5,7 +5,8 @@ edition = "2021"
[dependencies]
logos = "0.14"
winnow = { version = "0.6.18", features = ["debug"] }
winnow = { version = "0.6.18" }
gen = { path = "../gen" }
thiserror = "1.0"
bitflags = "1.2"
rustc-hash = "2.0.0"

View file

@ -557,7 +557,7 @@ pub struct SetpData {
impl SetpData {
pub(crate) fn try_parse(
errors: &mut PtxParserState,
state: &mut PtxParserState,
cmp_op: super::RawSetpCompareOp,
ftz: bool,
type_: ScalarType,
@ -565,7 +565,7 @@ impl SetpData {
let flush_to_zero = match (ftz, type_) {
(_, ScalarType::F32) => Some(ftz),
_ => {
errors.push(PtxError::NonF32Ftz);
state.errors.push(PtxError::NonF32Ftz);
None
}
};
@ -576,7 +576,7 @@ impl SetpData {
match SetpCompareInt::try_from(cmp_op) {
Ok(op) => SetpCompareOp::Integer(op),
Err(err) => {
errors.push(err);
state.errors.push(err);
SetpCompareOp::Integer(SetpCompareInt::Eq)
}
}
@ -682,36 +682,52 @@ impl From<RawSetpCompareOp> for SetpCompareFloat {
}
pub struct CallDetails {
uniform: bool,
ret_params: Vec<(Type, StateSpace)>,
param_list: Vec<(Type, StateSpace)>,
pub uniform: bool,
pub return_arguments: Vec<(Type, StateSpace)>,
pub input_arguments: Vec<(Type, StateSpace)>,
}
pub struct CallArgs<T: Operand> {
pub ret_params: Vec<T::Ident>,
pub return_arguments: Vec<T::Ident>,
pub func: T::Ident,
pub param_list: Vec<T>,
pub input_arguments: Vec<T>,
}
impl<T: Operand> CallArgs<T> {
#[allow(dead_code)] // Used by generated code
fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor<T>) {
for (param, (type_, space)) in self.ret_params.iter().zip(details.ret_params.iter()) {
for (param, (type_, space)) in self
.return_arguments
.iter()
.zip(details.return_arguments.iter())
{
visitor.visit_ident(param, Some((type_, *space)), true);
}
visitor.visit_ident(&self.func, None, false);
for (param, (type_, space)) in self.param_list.iter().zip(details.param_list.iter()) {
for (param, (type_, space)) in self
.input_arguments
.iter()
.zip(details.input_arguments.iter())
{
visitor.visit(param, Some((type_, *space)), true);
}
}
#[allow(dead_code)] // Used by generated code
fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut<T>) {
for (param, (type_, space)) in self.ret_params.iter_mut().zip(details.ret_params.iter()) {
for (param, (type_, space)) in self
.return_arguments
.iter_mut()
.zip(details.return_arguments.iter())
{
visitor.visit_ident(param, Some((type_, *space)), true);
}
visitor.visit_ident(&mut self.func, None, false);
for (param, (type_, space)) in self.param_list.iter_mut().zip(details.param_list.iter()) {
for (param, (type_, space)) in self
.input_arguments
.iter_mut()
.zip(details.input_arguments.iter())
{
visitor.visit(param, Some((type_, *space)), true);
}
}
@ -722,23 +738,23 @@ impl<T: Operand> CallArgs<T> {
details: &CallDetails,
visitor: &mut impl VisitorMap<T, U>,
) -> CallArgs<U> {
let ret_params = self
.ret_params
let return_arguments = self
.return_arguments
.into_iter()
.zip(details.ret_params.iter())
.zip(details.return_arguments.iter())
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true))
.collect::<Vec<_>>();
let func = visitor.visit_ident(self.func, None, false);
let param_list = self
.param_list
let input_arguments = self
.input_arguments
.into_iter()
.zip(details.param_list.iter())
.zip(details.input_arguments.iter())
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true))
.collect::<Vec<_>>();
CallArgs {
ret_params,
return_arguments,
func,
param_list,
input_arguments,
}
}
}

View file

@ -1,5 +1,7 @@
use gen::derive_parser;
use logos::Logos;
use rustc_hash::FxHashMap;
use std::fmt::Debug;
use std::mem;
use std::num::{ParseFloatError, ParseIntError};
use winnow::ascii::dec_uint;
@ -69,8 +71,49 @@ impl From<RawFloatRounding> for ast::RoundingMode {
}
}
type PtxParserState = Vec<PtxError>;
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>;
struct PtxParserState<'input> {
errors: Vec<PtxError>,
function_declarations:
FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
}
impl<'input> PtxParserState<'input> {
fn new() -> Self {
Self {
errors: Vec::new(),
function_declarations: FxHashMap::default(),
}
}
fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) {
let name = match function_decl.name {
MethodName::Kernel(name) => name,
MethodName::Func(name) => name,
};
let return_arguments = Self::get_type_space(&*function_decl.return_arguments);
let input_arguments = Self::get_type_space(&*function_decl.input_arguments);
// TODO: check if declarations match
self.function_declarations
.insert(name, (return_arguments, input_arguments));
}
fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> {
input_arguments
.iter()
.map(|var| (var.v_type.clone(), var.state_space))
.collect::<Vec<_>>()
}
}
impl<'input> Debug for PtxParserState<'input> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PtxParserState")
.field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
.finish()
}
}
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>;
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
any.verify_map(|t| {
@ -127,7 +170,7 @@ fn take_error<'a, 'input: 'a, O, E>(
Ok(match parser.parse_next(input)? {
Ok(x) => x,
Err((x, err)) => {
input.state.push(err);
input.state.errors.push(err);
x
}
})
@ -353,7 +396,7 @@ fn function<'a, 'input>(
ast::LinkingDirective,
ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>,
)> {
(
let (linking, function) = (
linking_directives,
method_declaration,
repeat(0.., tuning_directive),
@ -369,7 +412,9 @@ fn function<'a, 'input>(
},
)
})
.parse_next(stream)
.parse_next(stream)?;
stream.state.record_function(&function.func_directive);
Ok((linking, function))
}
fn linking_directives<'a, 'input>(
@ -771,6 +816,10 @@ pub enum PtxError {
#[error("")]
WrongType,
#[error("")]
UnknownFunction,
#[error("")]
MalformedCall,
#[error("")]
WrongArrayType,
#[error("")]
WrongVectorElement,
@ -903,6 +952,74 @@ fn bra<'a, 'input>(
.parse_next(stream)
}
fn call<'a, 'input>(
stream: &mut PtxParser<'a, 'input>,
) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
let (uni, return_arguments, name, input_arguments) = (
opt(Token::DotUni),
opt((
Token::LParen,
separated(1.., ident, Token::Comma).map(|x: Vec<_>| x),
Token::RParen,
Token::Comma,
)
.map(|(_, arguments, _, _)| arguments)),
ident,
opt((
Token::Comma.void(),
Token::LParen.void(),
separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x),
Token::RParen.void(),
)
.map(|(_, _, arguments, _)| arguments)),
)
.parse_next(stream)?;
let uniform = uni.is_some();
let recorded_fn = match stream.state.function_declarations.get(name) {
Some(decl) => decl,
None => {
stream.state.errors.push(PtxError::UnknownFunction);
return Ok(empty_call(uniform, name));
}
};
let return_arguments = return_arguments.unwrap_or(Vec::new());
let input_arguments = input_arguments.unwrap_or(Vec::new());
if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len()
{
stream.state.errors.push(PtxError::MalformedCall);
return Ok(empty_call(uniform, name));
}
let data = CallDetails {
uniform,
return_arguments: recorded_fn.0.clone(),
input_arguments: recorded_fn.1.clone(),
};
let arguments = CallArgs {
return_arguments,
func: name,
input_arguments,
};
Ok(ast::Instruction::Call { data, arguments })
}
fn empty_call<'input>(
uniform: bool,
name: &'input str,
) -> ast::Instruction<ParsedOperandStr<'input>> {
ast::Instruction::Call {
data: CallDetails {
uniform,
return_arguments: Vec::new(),
input_arguments: Vec::new(),
},
arguments: CallArgs {
return_arguments: Vec::new(),
func: name,
input_arguments: Vec::new(),
},
}
}
// Modifiers are turned into arguments to the blocks, with type:
// * If it is an alternative:
// * If it is mandatory then its type is Foo (as defined by the relevant rule)
@ -1033,7 +1150,7 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::St {
data: StData {
@ -1058,7 +1175,7 @@ derive_parser!(
}
st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::St {
data: StData {
@ -1072,7 +1189,7 @@ derive_parser!(
}
st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::St {
data: StData {
@ -1085,7 +1202,7 @@ derive_parser!(
}
}
st.mmio.relaxed.sys{.global}.type [a], b => {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
Instruction::St {
data: ast::StData {
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
@ -1114,7 +1231,7 @@ derive_parser!(
ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => {
let (a, unified) = a;
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
@ -1129,7 +1246,7 @@ derive_parser!(
}
ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => {
if level_prefetch_size.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
@ -1144,7 +1261,7 @@ derive_parser!(
}
ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
@ -1159,7 +1276,7 @@ derive_parser!(
}
ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
}
Instruction::Ld {
data: LdDetails {
@ -1173,7 +1290,7 @@ derive_parser!(
}
}
ld.mmio.relaxed.sys{.global}.type d, [a] => {
state.push(PtxError::Todo);
state.errors.push(PtxError::Todo);
Instruction::Ld {
data: LdDetails {
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
@ -1506,6 +1623,9 @@ derive_parser!(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
bra <= { bra(stream) }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
call <= { call(stream) }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
@ -1558,7 +1678,7 @@ fn main() {
println!("{:?}", &tokens);
let stream = PtxParser {
input: &tokens[..],
state: Vec::new(),
state: PtxParserState::new(),
};
let _module = module.parse(stream).unwrap();
println!("{}", mem::size_of::<Token>());
@ -1567,6 +1687,7 @@ fn main() {
#[cfg(test)]
mod tests {
use super::target;
use super::PtxParserState;
use super::Token;
use logos::Logos;
use winnow::prelude::*;
@ -1578,7 +1699,7 @@ mod tests {
.unwrap();
let stream = super::PtxParser {
input: &tokens[..],
state: Vec::new(),
state: PtxParserState::new(),
};
assert_eq!(target.parse(stream).unwrap(), (11, None));
}
@ -1590,7 +1711,7 @@ mod tests {
.unwrap();
let stream = super::PtxParser {
input: &tokens[..],
state: Vec::new(),
state: PtxParserState::new(),
};
assert_eq!(target.parse(stream).unwrap(), (90, Some('a')));
}
@ -1602,7 +1723,7 @@ mod tests {
.unwrap();
let stream = super::PtxParser {
input: &tokens[..],
state: Vec::new(),
state: PtxParserState::new(),
};
assert!(target.parse(stream).is_err());
}