mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Parse call instruction
This commit is contained in:
parent
34b0a67f0a
commit
c21c55dfc2
3 changed files with 177 additions and 39 deletions
|
@ -5,7 +5,8 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
logos = "0.14"
|
||||
winnow = { version = "0.6.18", features = ["debug"] }
|
||||
winnow = { version = "0.6.18" }
|
||||
gen = { path = "../gen" }
|
||||
thiserror = "1.0"
|
||||
bitflags = "1.2"
|
||||
rustc-hash = "2.0.0"
|
||||
|
|
|
@ -557,7 +557,7 @@ pub struct SetpData {
|
|||
|
||||
impl SetpData {
|
||||
pub(crate) fn try_parse(
|
||||
errors: &mut PtxParserState,
|
||||
state: &mut PtxParserState,
|
||||
cmp_op: super::RawSetpCompareOp,
|
||||
ftz: bool,
|
||||
type_: ScalarType,
|
||||
|
@ -565,7 +565,7 @@ impl SetpData {
|
|||
let flush_to_zero = match (ftz, type_) {
|
||||
(_, ScalarType::F32) => Some(ftz),
|
||||
_ => {
|
||||
errors.push(PtxError::NonF32Ftz);
|
||||
state.errors.push(PtxError::NonF32Ftz);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
@ -576,7 +576,7 @@ impl SetpData {
|
|||
match SetpCompareInt::try_from(cmp_op) {
|
||||
Ok(op) => SetpCompareOp::Integer(op),
|
||||
Err(err) => {
|
||||
errors.push(err);
|
||||
state.errors.push(err);
|
||||
SetpCompareOp::Integer(SetpCompareInt::Eq)
|
||||
}
|
||||
}
|
||||
|
@ -682,36 +682,52 @@ impl From<RawSetpCompareOp> for SetpCompareFloat {
|
|||
}
|
||||
|
||||
pub struct CallDetails {
|
||||
uniform: bool,
|
||||
ret_params: Vec<(Type, StateSpace)>,
|
||||
param_list: Vec<(Type, StateSpace)>,
|
||||
pub uniform: bool,
|
||||
pub return_arguments: Vec<(Type, StateSpace)>,
|
||||
pub input_arguments: Vec<(Type, StateSpace)>,
|
||||
}
|
||||
|
||||
pub struct CallArgs<T: Operand> {
|
||||
pub ret_params: Vec<T::Ident>,
|
||||
pub return_arguments: Vec<T::Ident>,
|
||||
pub func: T::Ident,
|
||||
pub param_list: Vec<T>,
|
||||
pub input_arguments: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: Operand> CallArgs<T> {
|
||||
#[allow(dead_code)] // Used by generated code
|
||||
fn visit(&self, details: &CallDetails, visitor: &mut impl Visitor<T>) {
|
||||
for (param, (type_, space)) in self.ret_params.iter().zip(details.ret_params.iter()) {
|
||||
for (param, (type_, space)) in self
|
||||
.return_arguments
|
||||
.iter()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true);
|
||||
}
|
||||
visitor.visit_ident(&self.func, None, false);
|
||||
for (param, (type_, space)) in self.param_list.iter().zip(details.param_list.iter()) {
|
||||
for (param, (type_, space)) in self
|
||||
.input_arguments
|
||||
.iter()
|
||||
.zip(details.input_arguments.iter())
|
||||
{
|
||||
visitor.visit(param, Some((type_, *space)), true);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Used by generated code
|
||||
fn visit_mut(&mut self, details: &CallDetails, visitor: &mut impl VisitorMut<T>) {
|
||||
for (param, (type_, space)) in self.ret_params.iter_mut().zip(details.ret_params.iter()) {
|
||||
for (param, (type_, space)) in self
|
||||
.return_arguments
|
||||
.iter_mut()
|
||||
.zip(details.return_arguments.iter())
|
||||
{
|
||||
visitor.visit_ident(param, Some((type_, *space)), true);
|
||||
}
|
||||
visitor.visit_ident(&mut self.func, None, false);
|
||||
for (param, (type_, space)) in self.param_list.iter_mut().zip(details.param_list.iter()) {
|
||||
for (param, (type_, space)) in self
|
||||
.input_arguments
|
||||
.iter_mut()
|
||||
.zip(details.input_arguments.iter())
|
||||
{
|
||||
visitor.visit(param, Some((type_, *space)), true);
|
||||
}
|
||||
}
|
||||
|
@ -722,23 +738,23 @@ impl<T: Operand> CallArgs<T> {
|
|||
details: &CallDetails,
|
||||
visitor: &mut impl VisitorMap<T, U>,
|
||||
) -> CallArgs<U> {
|
||||
let ret_params = self
|
||||
.ret_params
|
||||
let return_arguments = self
|
||||
.return_arguments
|
||||
.into_iter()
|
||||
.zip(details.ret_params.iter())
|
||||
.zip(details.return_arguments.iter())
|
||||
.map(|(param, (type_, space))| visitor.visit_ident(param, Some((type_, *space)), true))
|
||||
.collect::<Vec<_>>();
|
||||
let func = visitor.visit_ident(self.func, None, false);
|
||||
let param_list = self
|
||||
.param_list
|
||||
let input_arguments = self
|
||||
.input_arguments
|
||||
.into_iter()
|
||||
.zip(details.param_list.iter())
|
||||
.zip(details.input_arguments.iter())
|
||||
.map(|(param, (type_, space))| visitor.visit(param, Some((type_, *space)), true))
|
||||
.collect::<Vec<_>>();
|
||||
CallArgs {
|
||||
ret_params,
|
||||
return_arguments,
|
||||
func,
|
||||
param_list,
|
||||
input_arguments,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use gen::derive_parser;
|
||||
use logos::Logos;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::mem;
|
||||
use std::num::{ParseFloatError, ParseIntError};
|
||||
use winnow::ascii::dec_uint;
|
||||
|
@ -69,8 +71,49 @@ impl From<RawFloatRounding> for ast::RoundingMode {
|
|||
}
|
||||
}
|
||||
|
||||
type PtxParserState = Vec<PtxError>;
|
||||
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState>;
|
||||
struct PtxParserState<'input> {
|
||||
errors: Vec<PtxError>,
|
||||
function_declarations:
|
||||
FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
|
||||
}
|
||||
|
||||
impl<'input> PtxParserState<'input> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
errors: Vec::new(),
|
||||
function_declarations: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) {
|
||||
let name = match function_decl.name {
|
||||
MethodName::Kernel(name) => name,
|
||||
MethodName::Func(name) => name,
|
||||
};
|
||||
let return_arguments = Self::get_type_space(&*function_decl.return_arguments);
|
||||
let input_arguments = Self::get_type_space(&*function_decl.input_arguments);
|
||||
// TODO: check if declarations match
|
||||
self.function_declarations
|
||||
.insert(name, (return_arguments, input_arguments));
|
||||
}
|
||||
|
||||
fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> {
|
||||
input_arguments
|
||||
.iter()
|
||||
.map(|var| (var.v_type.clone(), var.state_space))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'input> Debug for PtxParserState<'input> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PtxParserState")
|
||||
.field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'input>>;
|
||||
|
||||
fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
|
||||
any.verify_map(|t| {
|
||||
|
@ -127,7 +170,7 @@ fn take_error<'a, 'input: 'a, O, E>(
|
|||
Ok(match parser.parse_next(input)? {
|
||||
Ok(x) => x,
|
||||
Err((x, err)) => {
|
||||
input.state.push(err);
|
||||
input.state.errors.push(err);
|
||||
x
|
||||
}
|
||||
})
|
||||
|
@ -353,7 +396,7 @@ fn function<'a, 'input>(
|
|||
ast::LinkingDirective,
|
||||
ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>,
|
||||
)> {
|
||||
(
|
||||
let (linking, function) = (
|
||||
linking_directives,
|
||||
method_declaration,
|
||||
repeat(0.., tuning_directive),
|
||||
|
@ -369,7 +412,9 @@ fn function<'a, 'input>(
|
|||
},
|
||||
)
|
||||
})
|
||||
.parse_next(stream)
|
||||
.parse_next(stream)?;
|
||||
stream.state.record_function(&function.func_directive);
|
||||
Ok((linking, function))
|
||||
}
|
||||
|
||||
fn linking_directives<'a, 'input>(
|
||||
|
@ -771,6 +816,10 @@ pub enum PtxError {
|
|||
#[error("")]
|
||||
WrongType,
|
||||
#[error("")]
|
||||
UnknownFunction,
|
||||
#[error("")]
|
||||
MalformedCall,
|
||||
#[error("")]
|
||||
WrongArrayType,
|
||||
#[error("")]
|
||||
WrongVectorElement,
|
||||
|
@ -903,6 +952,74 @@ fn bra<'a, 'input>(
|
|||
.parse_next(stream)
|
||||
}
|
||||
|
||||
fn call<'a, 'input>(
|
||||
stream: &mut PtxParser<'a, 'input>,
|
||||
) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
|
||||
let (uni, return_arguments, name, input_arguments) = (
|
||||
opt(Token::DotUni),
|
||||
opt((
|
||||
Token::LParen,
|
||||
separated(1.., ident, Token::Comma).map(|x: Vec<_>| x),
|
||||
Token::RParen,
|
||||
Token::Comma,
|
||||
)
|
||||
.map(|(_, arguments, _, _)| arguments)),
|
||||
ident,
|
||||
opt((
|
||||
Token::Comma.void(),
|
||||
Token::LParen.void(),
|
||||
separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x),
|
||||
Token::RParen.void(),
|
||||
)
|
||||
.map(|(_, _, arguments, _)| arguments)),
|
||||
)
|
||||
.parse_next(stream)?;
|
||||
let uniform = uni.is_some();
|
||||
let recorded_fn = match stream.state.function_declarations.get(name) {
|
||||
Some(decl) => decl,
|
||||
None => {
|
||||
stream.state.errors.push(PtxError::UnknownFunction);
|
||||
return Ok(empty_call(uniform, name));
|
||||
}
|
||||
};
|
||||
let return_arguments = return_arguments.unwrap_or(Vec::new());
|
||||
let input_arguments = input_arguments.unwrap_or(Vec::new());
|
||||
if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len()
|
||||
{
|
||||
stream.state.errors.push(PtxError::MalformedCall);
|
||||
return Ok(empty_call(uniform, name));
|
||||
}
|
||||
let data = CallDetails {
|
||||
uniform,
|
||||
return_arguments: recorded_fn.0.clone(),
|
||||
input_arguments: recorded_fn.1.clone(),
|
||||
};
|
||||
let arguments = CallArgs {
|
||||
return_arguments,
|
||||
func: name,
|
||||
input_arguments,
|
||||
};
|
||||
Ok(ast::Instruction::Call { data, arguments })
|
||||
}
|
||||
|
||||
fn empty_call<'input>(
|
||||
uniform: bool,
|
||||
name: &'input str,
|
||||
) -> ast::Instruction<ParsedOperandStr<'input>> {
|
||||
ast::Instruction::Call {
|
||||
data: CallDetails {
|
||||
uniform,
|
||||
return_arguments: Vec::new(),
|
||||
input_arguments: Vec::new(),
|
||||
},
|
||||
arguments: CallArgs {
|
||||
return_arguments: Vec::new(),
|
||||
func: name,
|
||||
input_arguments: Vec::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Modifiers are turned into arguments to the blocks, with type:
|
||||
// * If it is an alternative:
|
||||
// * If it is mandatory then its type is Foo (as defined by the relevant rule)
|
||||
|
@ -1033,7 +1150,7 @@ derive_parser!(
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
|
||||
st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
|
||||
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::St {
|
||||
data: StData {
|
||||
|
@ -1058,7 +1175,7 @@ derive_parser!(
|
|||
}
|
||||
st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
|
||||
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::St {
|
||||
data: StData {
|
||||
|
@ -1072,7 +1189,7 @@ derive_parser!(
|
|||
}
|
||||
st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
|
||||
if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::St {
|
||||
data: StData {
|
||||
|
@ -1085,7 +1202,7 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
st.mmio.relaxed.sys{.global}.type [a], b => {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
Instruction::St {
|
||||
data: ast::StData {
|
||||
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
|
||||
|
@ -1114,7 +1231,7 @@ derive_parser!(
|
|||
ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => {
|
||||
let (a, unified) = a;
|
||||
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Ld {
|
||||
data: LdDetails {
|
||||
|
@ -1129,7 +1246,7 @@ derive_parser!(
|
|||
}
|
||||
ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => {
|
||||
if level_prefetch_size.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Ld {
|
||||
data: LdDetails {
|
||||
|
@ -1144,7 +1261,7 @@ derive_parser!(
|
|||
}
|
||||
ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
|
||||
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Ld {
|
||||
data: LdDetails {
|
||||
|
@ -1159,7 +1276,7 @@ derive_parser!(
|
|||
}
|
||||
ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
|
||||
if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
}
|
||||
Instruction::Ld {
|
||||
data: LdDetails {
|
||||
|
@ -1173,7 +1290,7 @@ derive_parser!(
|
|||
}
|
||||
}
|
||||
ld.mmio.relaxed.sys{.global}.type d, [a] => {
|
||||
state.push(PtxError::Todo);
|
||||
state.errors.push(PtxError::Todo);
|
||||
Instruction::Ld {
|
||||
data: LdDetails {
|
||||
qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
|
||||
|
@ -1506,6 +1623,9 @@ derive_parser!(
|
|||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
|
||||
bra <= { bra(stream) }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
|
||||
call <= { call(stream) }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||
ret{.uni} => {
|
||||
Instruction::Ret { data: RetData { uniform: uni } }
|
||||
|
@ -1558,7 +1678,7 @@ fn main() {
|
|||
println!("{:?}", &tokens);
|
||||
let stream = PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
state: PtxParserState::new(),
|
||||
};
|
||||
let _module = module.parse(stream).unwrap();
|
||||
println!("{}", mem::size_of::<Token>());
|
||||
|
@ -1567,6 +1687,7 @@ fn main() {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::target;
|
||||
use super::PtxParserState;
|
||||
use super::Token;
|
||||
use logos::Logos;
|
||||
use winnow::prelude::*;
|
||||
|
@ -1578,7 +1699,7 @@ mod tests {
|
|||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
state: PtxParserState::new(),
|
||||
};
|
||||
assert_eq!(target.parse(stream).unwrap(), (11, None));
|
||||
}
|
||||
|
@ -1590,7 +1711,7 @@ mod tests {
|
|||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
state: PtxParserState::new(),
|
||||
};
|
||||
assert_eq!(target.parse(stream).unwrap(), (90, Some('a')));
|
||||
}
|
||||
|
@ -1602,7 +1723,7 @@ mod tests {
|
|||
.unwrap();
|
||||
let stream = super::PtxParser {
|
||||
input: &tokens[..],
|
||||
state: Vec::new(),
|
||||
state: PtxParserState::new(),
|
||||
};
|
||||
assert!(target.parse(stream).is_err());
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue