mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
PTX parser rewrite (#267)
Replaces traditional LALRPOP-based parser with winnow-based parser to handle out-of-order instruction modifer. Generate instruction type and instruction visitor from a macro instead of writing by hand. Add separate compilation path using the new parser that only works in tests for now
This commit is contained in:
parent
872054ae40
commit
193eb29be8
34 changed files with 14776 additions and 55 deletions
|
@ -1,5 +1,7 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
|
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
members = [
|
members = [
|
||||||
"cuda_base",
|
"cuda_base",
|
||||||
"cuda_types",
|
"cuda_types",
|
||||||
|
@ -15,6 +17,9 @@ members = [
|
||||||
"zluda_redirect",
|
"zluda_redirect",
|
||||||
"zluda_ml",
|
"zluda_ml",
|
||||||
"ptx",
|
"ptx",
|
||||||
|
"ptx_parser",
|
||||||
|
"ptx_parser_macros",
|
||||||
|
"ptx_parser_macros_impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]
|
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]
|
||||||
|
|
|
@ -7,7 +7,7 @@ edition = "2018"
|
||||||
[lib]
|
[lib]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
lalrpop-util = "0.19"
|
ptx_parser = { path = "../ptx_parser" }
|
||||||
regex = "1"
|
regex = "1"
|
||||||
rspirv = "0.7"
|
rspirv = "0.7"
|
||||||
spirv_headers = "1.5"
|
spirv_headers = "1.5"
|
||||||
|
@ -17,8 +17,12 @@ bit-vec = "0.6"
|
||||||
half ="1.6"
|
half ="1.6"
|
||||||
bitflags = "1.2"
|
bitflags = "1.2"
|
||||||
|
|
||||||
|
[dependencies.lalrpop-util]
|
||||||
|
version = "0.19.12"
|
||||||
|
features = ["lexer"]
|
||||||
|
|
||||||
[build-dependencies.lalrpop]
|
[build-dependencies.lalrpop]
|
||||||
version = "0.19"
|
version = "0.19.12"
|
||||||
features = ["lexer"]
|
features = ["lexer"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -16,6 +16,8 @@ pub enum PtxError {
|
||||||
source: ParseFloatError,
|
source: ParseFloatError,
|
||||||
},
|
},
|
||||||
#[error("")]
|
#[error("")]
|
||||||
|
Unsupported32Bit,
|
||||||
|
#[error("")]
|
||||||
SyntaxError,
|
SyntaxError,
|
||||||
#[error("")]
|
#[error("")]
|
||||||
NonF32Ftz,
|
NonF32Ftz,
|
||||||
|
@ -32,15 +34,9 @@ pub enum PtxError {
|
||||||
#[error("")]
|
#[error("")]
|
||||||
NonExternPointer,
|
NonExternPointer,
|
||||||
#[error("{start}:{end}")]
|
#[error("{start}:{end}")]
|
||||||
UnrecognizedStatement {
|
UnrecognizedStatement { start: usize, end: usize },
|
||||||
start: usize,
|
|
||||||
end: usize,
|
|
||||||
},
|
|
||||||
#[error("{start}:{end}")]
|
#[error("{start}:{end}")]
|
||||||
UnrecognizedDirective {
|
UnrecognizedDirective { start: usize, end: usize },
|
||||||
start: usize,
|
|
||||||
end: usize,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For some weird reson this is illegal:
|
// For some weird reson this is illegal:
|
||||||
|
@ -576,11 +572,15 @@ impl CvtDetails {
|
||||||
if saturate {
|
if saturate {
|
||||||
if src.kind() == ScalarKind::Signed {
|
if src.kind() == ScalarKind::Signed {
|
||||||
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
|
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
|
||||||
err.push(ParseError::from(PtxError::SyntaxError));
|
err.push(ParseError::User {
|
||||||
|
error: PtxError::SyntaxError,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if dst == src || dst.size_of() >= src.size_of() {
|
if dst == src || dst.size_of() >= src.size_of() {
|
||||||
err.push(ParseError::from(PtxError::SyntaxError));
|
err.push(ParseError::User {
|
||||||
|
error: PtxError::SyntaxError,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -596,7 +596,9 @@ impl CvtDetails {
|
||||||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if flush_to_zero && dst != ScalarType::F32 {
|
if flush_to_zero && dst != ScalarType::F32 {
|
||||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||||
|
error: PtxError::NonF32Ftz,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
CvtDetails::FloatFromInt(CvtDesc {
|
CvtDetails::FloatFromInt(CvtDesc {
|
||||||
dst,
|
dst,
|
||||||
|
@ -616,7 +618,9 @@ impl CvtDetails {
|
||||||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
if flush_to_zero && src != ScalarType::F32 {
|
if flush_to_zero && src != ScalarType::F32 {
|
||||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||||
|
error: PtxError::NonF32Ftz,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
CvtDetails::IntFromFloat(CvtDesc {
|
CvtDetails::IntFromFloat(CvtDesc {
|
||||||
dst,
|
dst,
|
||||||
|
|
|
@ -24,6 +24,7 @@ lalrpop_mod!(
|
||||||
);
|
);
|
||||||
|
|
||||||
pub mod ast;
|
pub mod ast;
|
||||||
|
pub(crate) mod pass;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
mod translate;
|
mod translate;
|
||||||
|
|
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
|
@ -0,0 +1,299 @@
|
||||||
|
use std::collections::{BTreeMap, BTreeSet};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
/*
|
||||||
|
PTX represents dynamically allocated shared local memory as
|
||||||
|
.extern .shared .b32 shared_mem[];
|
||||||
|
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
|
||||||
|
And in AMD compilation
|
||||||
|
This pass looks for all uses of .extern .shared and converts them to
|
||||||
|
an additional method argument
|
||||||
|
The question is how this artificial argument should be expressed. There are
|
||||||
|
several options:
|
||||||
|
* Straight conversion:
|
||||||
|
.shared .b32 shared_mem[]
|
||||||
|
* Introduce .param_shared statespace:
|
||||||
|
.param_shared .b32 shared_mem
|
||||||
|
or
|
||||||
|
.param_shared .b32 shared_mem[]
|
||||||
|
* Introduce .shared_ptr <SCALAR> type:
|
||||||
|
.param .shared_ptr .b32 shared_mem
|
||||||
|
* Reuse .ptr hint:
|
||||||
|
.param .u64 .ptr shared_mem
|
||||||
|
This is the most tempting, but also the most nonsensical, .ptr is just a
|
||||||
|
hint, which has no semantical meaning (and the output of our
|
||||||
|
transformation has a semantical meaning - we emit additional
|
||||||
|
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
|
||||||
|
*/
|
||||||
|
pub(super) fn run<'input>(
|
||||||
|
module: Vec<Directive<'input>>,
|
||||||
|
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||||
|
new_id: &mut impl FnMut() -> SpirvWord,
|
||||||
|
) -> Result<Vec<Directive<'input>>, TranslateError> {
|
||||||
|
let mut globals_shared = HashMap::new();
|
||||||
|
for dir in module.iter() {
|
||||||
|
match dir {
|
||||||
|
Directive::Variable(
|
||||||
|
_,
|
||||||
|
ast::Variable {
|
||||||
|
state_space: ast::StateSpace::Shared,
|
||||||
|
name,
|
||||||
|
v_type,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
) => {
|
||||||
|
globals_shared.insert(*name, v_type.clone());
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if globals_shared.len() == 0 {
|
||||||
|
return Ok(module);
|
||||||
|
}
|
||||||
|
let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
|
||||||
|
let module = module
|
||||||
|
.into_iter()
|
||||||
|
.map(|directive| match directive {
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
import_as,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
}) => {
|
||||||
|
let call_key = (*func_decl).borrow().name;
|
||||||
|
let statements = statements
|
||||||
|
.into_iter()
|
||||||
|
.map(|statement| {
|
||||||
|
statement.visit_map(
|
||||||
|
&mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
|
||||||
|
if let Some(_) = globals_shared.get(&id) {
|
||||||
|
methods_to_directly_used_shared_globals
|
||||||
|
.entry(call_key)
|
||||||
|
.or_insert_with(HashSet::new)
|
||||||
|
.insert(id);
|
||||||
|
}
|
||||||
|
Ok::<_, TranslateError>(id)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
Ok::<_, TranslateError>(Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
import_as,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
directive => Ok(directive),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
||||||
|
// make sure it gets propagated to `fn1` and `kernel`
|
||||||
|
let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
|
||||||
|
methods_to_directly_used_shared_globals,
|
||||||
|
kernels_methods_call_map,
|
||||||
|
);
|
||||||
|
// now visit every method declaration and inject those additional arguments
|
||||||
|
let mut directives = Vec::with_capacity(module.len());
|
||||||
|
for directive in module.into_iter() {
|
||||||
|
match directive {
|
||||||
|
Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
import_as,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
}) => {
|
||||||
|
let statements = {
|
||||||
|
let func_decl_ref = &mut (*func_decl).borrow_mut();
|
||||||
|
let method_name = func_decl_ref.name;
|
||||||
|
insert_arguments_remap_statements(
|
||||||
|
new_id,
|
||||||
|
kernels_methods_call_map,
|
||||||
|
&globals_shared,
|
||||||
|
&methods_to_indirectly_used_shared_globals,
|
||||||
|
method_name,
|
||||||
|
&mut directives,
|
||||||
|
func_decl_ref,
|
||||||
|
statements,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
directives.push(Directive::Method(Function {
|
||||||
|
func_decl,
|
||||||
|
globals,
|
||||||
|
body: Some(statements),
|
||||||
|
import_as,
|
||||||
|
tuning,
|
||||||
|
linkage,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
directive => directives.push(directive),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(directives)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to compute two kinds of information:
|
||||||
|
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
|
||||||
|
// * If it's a function -> does it use .shared global (directly or indirectly)
|
||||||
|
fn resolve_indirect_uses_of_globals_shared<'input>(
|
||||||
|
methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
|
||||||
|
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||||
|
) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
|
||||||
|
let mut result = HashMap::new();
|
||||||
|
for (method, callees) in kernels_methods_call_map.methods() {
|
||||||
|
let mut indirect_globals = methods_use_of_globals_shared
|
||||||
|
.get(&method)
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.copied()
|
||||||
|
.collect::<BTreeSet<_>>();
|
||||||
|
for &callee in callees {
|
||||||
|
indirect_globals.extend(
|
||||||
|
methods_use_of_globals_shared
|
||||||
|
.get(&ast::MethodName::Func(callee))
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.copied(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
result.insert(method, indirect_globals);
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_arguments_remap_statements<'input>(
|
||||||
|
new_id: &mut impl FnMut() -> SpirvWord,
|
||||||
|
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||||
|
globals_shared: &HashMap<SpirvWord, ast::Type>,
|
||||||
|
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||||
|
ast::MethodName<'input, SpirvWord>,
|
||||||
|
BTreeSet<SpirvWord>,
|
||||||
|
>,
|
||||||
|
method_name: ast::MethodName<SpirvWord>,
|
||||||
|
result: &mut Vec<Directive>,
|
||||||
|
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
|
||||||
|
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||||
|
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||||
|
let remapped_globals_in_method =
|
||||||
|
if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
|
||||||
|
match method_name {
|
||||||
|
ast::MethodName::Func(..) => {
|
||||||
|
let remapped_globals = method_globals
|
||||||
|
.iter()
|
||||||
|
.map(|global| {
|
||||||
|
(
|
||||||
|
*global,
|
||||||
|
(
|
||||||
|
new_id(),
|
||||||
|
globals_shared
|
||||||
|
.get(&global)
|
||||||
|
.unwrap_or_else(|| todo!())
|
||||||
|
.clone(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<BTreeMap<_, _>>();
|
||||||
|
for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
|
||||||
|
func_decl_ref.input_arguments.push(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
v_type: shared_global_type.clone(),
|
||||||
|
state_space: ast::StateSpace::Shared,
|
||||||
|
name: *new_shared_global_id,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
remapped_globals
|
||||||
|
}
|
||||||
|
ast::MethodName::Kernel(..) => method_globals
|
||||||
|
.iter()
|
||||||
|
.map(|global| {
|
||||||
|
(
|
||||||
|
*global,
|
||||||
|
(
|
||||||
|
*global,
|
||||||
|
globals_shared
|
||||||
|
.get(&global)
|
||||||
|
.unwrap_or_else(|| todo!())
|
||||||
|
.clone(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<BTreeMap<_, _>>(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Ok(statements);
|
||||||
|
};
|
||||||
|
replace_uses_of_shared_memory(
|
||||||
|
new_id,
|
||||||
|
methods_to_indirectly_used_shared_globals,
|
||||||
|
statements,
|
||||||
|
remapped_globals_in_method,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replace_uses_of_shared_memory<'input>(
|
||||||
|
new_id: &mut impl FnMut() -> SpirvWord,
|
||||||
|
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||||
|
ast::MethodName<'input, SpirvWord>,
|
||||||
|
BTreeSet<SpirvWord>,
|
||||||
|
>,
|
||||||
|
statements: Vec<ExpandedStatement>,
|
||||||
|
remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
|
||||||
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(statements.len());
|
||||||
|
for statement in statements {
|
||||||
|
match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Call {
|
||||||
|
mut data,
|
||||||
|
mut arguments,
|
||||||
|
}) => {
|
||||||
|
// We can safely skip checking call arguments,
|
||||||
|
// because there's simply no way to pass shared ptr
|
||||||
|
// without converting it to .b64 first
|
||||||
|
if let Some(shared_globals_used_by_callee) =
|
||||||
|
methods_to_indirectly_used_shared_globals
|
||||||
|
.get(&ast::MethodName::Func(arguments.func))
|
||||||
|
{
|
||||||
|
for &shared_global_used_by_callee in shared_globals_used_by_callee {
|
||||||
|
let (remapped_shared_id, type_) = remapped_globals_in_method
|
||||||
|
.get(&shared_global_used_by_callee)
|
||||||
|
.unwrap_or_else(|| todo!());
|
||||||
|
data.input_arguments
|
||||||
|
.push((type_.clone(), ast::StateSpace::Shared));
|
||||||
|
arguments.input_arguments.push(*remapped_shared_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data,
|
||||||
|
arguments,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
statement => {
|
||||||
|
let new_statement =
|
||||||
|
statement.visit_map(&mut |id,
|
||||||
|
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
_,
|
||||||
|
_| {
|
||||||
|
Ok::<_, TranslateError>(
|
||||||
|
if let Some((remapped_shared_id, _)) =
|
||||||
|
remapped_globals_in_method.get(&id)
|
||||||
|
{
|
||||||
|
*remapped_shared_id
|
||||||
|
} else {
|
||||||
|
id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
result.push(new_statement);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
524
ptx/src/pass/convert_to_stateful_memory_access.rs
Normal file
524
ptx/src/pass/convert_to_stateful_memory_access.rs
Normal file
|
@ -0,0 +1,524 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
use std::{
|
||||||
|
collections::{BTreeSet, HashSet},
|
||||||
|
iter,
|
||||||
|
rc::Rc,
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Our goal here is to transform
|
||||||
|
.visible .entry foobar(.param .u64 input) {
|
||||||
|
.reg .b64 in_addr;
|
||||||
|
.reg .b64 in_addr2;
|
||||||
|
ld.param.u64 in_addr, [input];
|
||||||
|
cvta.to.global.u64 in_addr2, in_addr;
|
||||||
|
}
|
||||||
|
into:
|
||||||
|
.visible .entry foobar(.param .u8 input[]) {
|
||||||
|
.reg .u8 in_addr[];
|
||||||
|
.reg .u8 in_addr2[];
|
||||||
|
ld.param.u8[] in_addr, [input];
|
||||||
|
mov.u8[] in_addr2, in_addr;
|
||||||
|
}
|
||||||
|
or:
|
||||||
|
.visible .entry foobar(.reg .u8 input[]) {
|
||||||
|
.reg .u8 in_addr[];
|
||||||
|
.reg .u8 in_addr2[];
|
||||||
|
mov.u8[] in_addr, input;
|
||||||
|
mov.u8[] in_addr2, in_addr;
|
||||||
|
}
|
||||||
|
or:
|
||||||
|
.visible .entry foobar(.param ptr<u8, global> input) {
|
||||||
|
.reg ptr<u8, global> in_addr;
|
||||||
|
.reg ptr<u8, global> in_addr2;
|
||||||
|
ld.param.ptr<u8, global> in_addr, [input];
|
||||||
|
mov.ptr<u8, global> in_addr2, in_addr;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// TODO: detect more patterns (mov, call via reg, call via param)
|
||||||
|
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
|
||||||
|
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||||
|
// argument expansion
|
||||||
|
// TODO: propagate out of calls and into calls
|
||||||
|
pub(super) fn run<'a, 'input>(
|
||||||
|
func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||||
|
func_body: Vec<TypedStatement>,
|
||||||
|
id_defs: &mut NumericIdResolver<'a>,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||||
|
Vec<TypedStatement>,
|
||||||
|
),
|
||||||
|
TranslateError,
|
||||||
|
> {
|
||||||
|
let mut method_decl = func_args.borrow_mut();
|
||||||
|
if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
|
||||||
|
drop(method_decl);
|
||||||
|
return Ok((func_args, func_body));
|
||||||
|
}
|
||||||
|
if Rc::strong_count(&func_args) != 1 {
|
||||||
|
return Err(error_unreachable());
|
||||||
|
}
|
||||||
|
let func_args_64bit = (*method_decl)
|
||||||
|
.input_arguments
|
||||||
|
.iter()
|
||||||
|
.filter_map(|arg| match arg.v_type {
|
||||||
|
ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
let mut stateful_markers = Vec::new();
|
||||||
|
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
|
||||||
|
for statement in func_body.iter() {
|
||||||
|
match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Cvta {
|
||||||
|
data:
|
||||||
|
ast::CvtaDetails {
|
||||||
|
state_space: ast::StateSpace::Global,
|
||||||
|
direction: ast::CvtaDirection::GenericToExplicit,
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
if let (TypedOperand::Reg(dst), Some(src)) =
|
||||||
|
(arguments.dst, arguments.src.underlying_register())
|
||||||
|
{
|
||||||
|
if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
|
||||||
|
stateful_markers.push((dst, src));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Ld {
|
||||||
|
data:
|
||||||
|
ast::LdDetails {
|
||||||
|
state_space: ast::StateSpace::Param,
|
||||||
|
typ: ast::Type::Scalar(ast::ScalarType::U64),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Ld {
|
||||||
|
data:
|
||||||
|
ast::LdDetails {
|
||||||
|
state_space: ast::StateSpace::Param,
|
||||||
|
typ: ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Ld {
|
||||||
|
data:
|
||||||
|
ast::LdDetails {
|
||||||
|
state_space: ast::StateSpace::Param,
|
||||||
|
typ: ast::Type::Scalar(ast::ScalarType::B64),
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
if let (TypedOperand::Reg(dst), Some(src)) =
|
||||||
|
(arguments.dst, arguments.src.underlying_register())
|
||||||
|
{
|
||||||
|
if func_args_64bit.contains(&src) {
|
||||||
|
multi_hash_map_append(&mut stateful_init_reg, dst, src);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if stateful_markers.len() == 0 {
|
||||||
|
drop(method_decl);
|
||||||
|
return Ok((func_args, func_body));
|
||||||
|
}
|
||||||
|
let mut func_args_ptr = HashSet::new();
|
||||||
|
let mut regs_ptr_current = HashSet::new();
|
||||||
|
for (dst, src) in stateful_markers {
|
||||||
|
if let Some(func_args) = stateful_init_reg.get(&src) {
|
||||||
|
for a in func_args {
|
||||||
|
func_args_ptr.insert(*a);
|
||||||
|
regs_ptr_current.insert(src);
|
||||||
|
regs_ptr_current.insert(dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// BTreeSet here to have a stable order of iteration,
|
||||||
|
// unfortunately our tests rely on it
|
||||||
|
let mut regs_ptr_seen = BTreeSet::new();
|
||||||
|
while regs_ptr_current.len() > 0 {
|
||||||
|
let mut regs_ptr_new = HashSet::new();
|
||||||
|
for statement in func_body.iter() {
|
||||||
|
match statement {
|
||||||
|
Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::U64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::S64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
// TODO: don't mark result of double pointer sub or double
|
||||||
|
// pointer add as ptr result
|
||||||
|
if let (TypedOperand::Reg(dst), Some(src1)) =
|
||||||
|
(arguments.dst, arguments.src1.underlying_register())
|
||||||
|
{
|
||||||
|
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
|
||||||
|
regs_ptr_new.insert(dst);
|
||||||
|
}
|
||||||
|
} else if let (TypedOperand::Reg(dst), Some(src2)) =
|
||||||
|
(arguments.dst, arguments.src2.underlying_register())
|
||||||
|
{
|
||||||
|
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
|
||||||
|
regs_ptr_new.insert(dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Statement::Instruction(ast::Instruction::Sub {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::U64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Sub {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::S64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
// TODO: don't mark result of double pointer sub or double
|
||||||
|
// pointer add as ptr result
|
||||||
|
if let (TypedOperand::Reg(dst), Some(src1)) =
|
||||||
|
(arguments.dst, arguments.src1.underlying_register())
|
||||||
|
{
|
||||||
|
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
|
||||||
|
regs_ptr_new.insert(dst);
|
||||||
|
}
|
||||||
|
} else if let (TypedOperand::Reg(dst), Some(src2)) =
|
||||||
|
(arguments.dst, arguments.src2.underlying_register())
|
||||||
|
{
|
||||||
|
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
|
||||||
|
regs_ptr_new.insert(dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id in regs_ptr_current {
|
||||||
|
regs_ptr_seen.insert(id);
|
||||||
|
}
|
||||||
|
regs_ptr_current = regs_ptr_new;
|
||||||
|
}
|
||||||
|
drop(regs_ptr_current);
|
||||||
|
let mut remapped_ids = HashMap::new();
|
||||||
|
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
|
||||||
|
for reg in regs_ptr_seen {
|
||||||
|
let new_id = id_defs.register_variable(
|
||||||
|
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
);
|
||||||
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
align: None,
|
||||||
|
name: new_id,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
}));
|
||||||
|
remapped_ids.insert(reg, new_id);
|
||||||
|
}
|
||||||
|
for arg in (*method_decl).input_arguments.iter_mut() {
|
||||||
|
if !func_args_ptr.contains(&arg.name) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let new_id = id_defs.register_variable(
|
||||||
|
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||||
|
ast::StateSpace::Param,
|
||||||
|
);
|
||||||
|
let old_name = arg.name;
|
||||||
|
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||||
|
arg.name = new_id;
|
||||||
|
remapped_ids.insert(old_name, new_id);
|
||||||
|
}
|
||||||
|
for statement in func_body {
|
||||||
|
match statement {
|
||||||
|
l @ Statement::Label(_) => result.push(l),
|
||||||
|
c @ Statement::Conditional(_) => result.push(c),
|
||||||
|
c @ Statement::Constant(..) => result.push(c),
|
||||||
|
Statement::Variable(var) => {
|
||||||
|
if !remapped_ids.contains_key(&var.name) {
|
||||||
|
result.push(Statement::Variable(var));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::U64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::S64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
}) if is_add_ptr_direct(&remapped_ids, &arguments) => {
|
||||||
|
let (ptr, offset) = match arguments.src1.underlying_register() {
|
||||||
|
Some(src1) if remapped_ids.contains_key(&src1) => {
|
||||||
|
(remapped_ids.get(&src1).unwrap(), arguments.src2)
|
||||||
|
}
|
||||||
|
Some(src2) if remapped_ids.contains_key(&src2) => {
|
||||||
|
(remapped_ids.get(&src2).unwrap(), arguments.src1)
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let dst = arguments.dst.unwrap_reg()?;
|
||||||
|
result.push(Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||||
|
state_space: ast::StateSpace::Global,
|
||||||
|
dst: *remapped_ids.get(&dst).unwrap(),
|
||||||
|
ptr_src: *ptr,
|
||||||
|
offset_src: offset,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Sub {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::U64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
| Statement::Instruction(ast::Instruction::Sub {
|
||||||
|
data:
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: ast::ScalarType::S64,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
arguments,
|
||||||
|
}) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
|
||||||
|
let (ptr, offset) = match arguments.src1.underlying_register() {
|
||||||
|
Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let offset_neg = id_defs.register_intermediate(Some((
|
||||||
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
result.push(Statement::Instruction(ast::Instruction::Neg {
|
||||||
|
data: ast::TypeFtz {
|
||||||
|
type_: ast::ScalarType::S64,
|
||||||
|
flush_to_zero: None,
|
||||||
|
},
|
||||||
|
arguments: ast::NegArgs {
|
||||||
|
src: offset,
|
||||||
|
dst: TypedOperand::Reg(offset_neg),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
let dst = arguments.dst.unwrap_reg()?;
|
||||||
|
result.push(Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||||
|
state_space: ast::StateSpace::Global,
|
||||||
|
dst: *remapped_ids.get(&dst).unwrap(),
|
||||||
|
ptr_src: *ptr,
|
||||||
|
offset_src: TypedOperand::Reg(offset_neg),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
inst @ Statement::Instruction(_) => {
|
||||||
|
let mut post_statements = Vec::new();
|
||||||
|
let new_statement = inst.visit_map(&mut FnVisitor::new(
|
||||||
|
|operand, type_space, is_dst, relaxed_conversion| {
|
||||||
|
convert_to_stateful_memory_access_postprocess(
|
||||||
|
id_defs,
|
||||||
|
&remapped_ids,
|
||||||
|
&mut result,
|
||||||
|
&mut post_statements,
|
||||||
|
operand,
|
||||||
|
type_space,
|
||||||
|
is_dst,
|
||||||
|
relaxed_conversion,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
))?;
|
||||||
|
result.push(new_statement);
|
||||||
|
result.extend(post_statements);
|
||||||
|
}
|
||||||
|
repack @ Statement::RepackVector(_) => {
|
||||||
|
let mut post_statements = Vec::new();
|
||||||
|
let new_statement = repack.visit_map(&mut FnVisitor::new(
|
||||||
|
|operand, type_space, is_dst, relaxed_conversion| {
|
||||||
|
convert_to_stateful_memory_access_postprocess(
|
||||||
|
id_defs,
|
||||||
|
&remapped_ids,
|
||||||
|
&mut result,
|
||||||
|
&mut post_statements,
|
||||||
|
operand,
|
||||||
|
type_space,
|
||||||
|
is_dst,
|
||||||
|
relaxed_conversion,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
))?;
|
||||||
|
result.push(new_statement);
|
||||||
|
result.extend(post_statements);
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drop(method_decl);
|
||||||
|
Ok((func_args, result))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
|
||||||
|
match id_defs.get_typed(id) {
|
||||||
|
Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
|
||||||
|
| Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
|
||||||
|
| Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_add_ptr_direct(
|
||||||
|
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||||
|
arg: &ast::AddArgs<TypedOperand>,
|
||||||
|
) -> bool {
|
||||||
|
match arg.dst {
|
||||||
|
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
TypedOperand::Reg(dst) => {
|
||||||
|
if !remapped_ids.contains_key(&dst) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if let Some(ref src1_reg) = arg.src1.underlying_register() {
|
||||||
|
if remapped_ids.contains_key(src1_reg) {
|
||||||
|
// don't trigger optimization when adding two pointers
|
||||||
|
if let Some(ref src2_reg) = arg.src2.underlying_register() {
|
||||||
|
return !remapped_ids.contains_key(src2_reg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(ref src2_reg) = arg.src2.underlying_register() {
|
||||||
|
remapped_ids.contains_key(src2_reg)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_sub_ptr_direct(
|
||||||
|
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||||
|
arg: &ast::SubArgs<TypedOperand>,
|
||||||
|
) -> bool {
|
||||||
|
match arg.dst {
|
||||||
|
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
TypedOperand::Reg(dst) => {
|
||||||
|
if !remapped_ids.contains_key(&dst) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match arg.src1.underlying_register() {
|
||||||
|
Some(ref src1_reg) => {
|
||||||
|
if remapped_ids.contains_key(src1_reg) {
|
||||||
|
// don't trigger optimization when subtracting two pointers
|
||||||
|
arg.src2
|
||||||
|
.underlying_register()
|
||||||
|
.map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_to_stateful_memory_access_postprocess(
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||||
|
result: &mut Vec<TypedStatement>,
|
||||||
|
post_statements: &mut Vec<TypedStatement>,
|
||||||
|
operand: TypedOperand,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_conversion: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
operand.map(|operand, _| {
|
||||||
|
Ok(match remapped_ids.get(&operand) {
|
||||||
|
Some(new_id) => {
|
||||||
|
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
|
||||||
|
// TODO: readd if required
|
||||||
|
if let Some((expected_type, expected_space)) = type_space {
|
||||||
|
let implicit_conversion = if relaxed_conversion {
|
||||||
|
if is_dst {
|
||||||
|
super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
|
||||||
|
} else {
|
||||||
|
super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
super::insert_implicit_conversions::default_implicit_conversion
|
||||||
|
};
|
||||||
|
if implicit_conversion(
|
||||||
|
(new_operand_space, &new_operand_type),
|
||||||
|
(expected_space, expected_type),
|
||||||
|
)
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
|
return Ok(*new_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
|
||||||
|
let converting_id = id_defs
|
||||||
|
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||||
|
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
|
||||||
|
ConversionKind::Default
|
||||||
|
} else {
|
||||||
|
ConversionKind::PtrToPtr
|
||||||
|
};
|
||||||
|
if is_dst {
|
||||||
|
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||||
|
src: converting_id,
|
||||||
|
dst: *new_id,
|
||||||
|
from_type: old_operand_type,
|
||||||
|
from_space: old_operand_space,
|
||||||
|
to_type: new_operand_type,
|
||||||
|
to_space: new_operand_space,
|
||||||
|
kind,
|
||||||
|
}));
|
||||||
|
converting_id
|
||||||
|
} else {
|
||||||
|
result.push(Statement::Conversion(ImplicitConversion {
|
||||||
|
src: *new_id,
|
||||||
|
dst: converting_id,
|
||||||
|
from_type: new_operand_type,
|
||||||
|
from_space: new_operand_space,
|
||||||
|
to_type: old_operand_type,
|
||||||
|
to_space: old_operand_space,
|
||||||
|
kind,
|
||||||
|
}));
|
||||||
|
converting_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => operand,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
138
ptx/src/pass/convert_to_typed.rs
Normal file
138
ptx/src/pass/convert_to_typed.rs
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run(
|
||||||
|
func: Vec<UnconditionalStatement>,
|
||||||
|
fn_defs: &GlobalFnDeclResolver,
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Mov {
|
||||||
|
data,
|
||||||
|
arguments:
|
||||||
|
ast::MovArgs {
|
||||||
|
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||||
|
src: ast::ParsedOperand::Reg(src_reg),
|
||||||
|
},
|
||||||
|
} if fn_defs.fns.contains_key(&src_reg) => {
|
||||||
|
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
||||||
|
dst: dst_reg,
|
||||||
|
src: src_reg,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
ast::Instruction::Call { data, arguments } => {
|
||||||
|
let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
|
||||||
|
let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
|
||||||
|
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||||
|
let reresolved_call =
|
||||||
|
Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
|
||||||
|
visitor.func.push(reresolved_call);
|
||||||
|
visitor.func.extend(visitor.post_stmts);
|
||||||
|
}
|
||||||
|
inst => {
|
||||||
|
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||||
|
let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
|
||||||
|
visitor.func.push(instruction);
|
||||||
|
visitor.func.extend(visitor.post_stmts);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||||
|
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||||
|
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct VectorRepackVisitor<'a, 'b> {
|
||||||
|
func: &'b mut Vec<TypedStatement>,
|
||||||
|
id_def: &'b mut NumericIdResolver<'a>,
|
||||||
|
post_stmts: Option<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||||
|
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
|
||||||
|
VectorRepackVisitor {
|
||||||
|
func,
|
||||||
|
id_def,
|
||||||
|
post_stmts: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_vector(
|
||||||
|
&mut self,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
typ: &ast::Type,
|
||||||
|
state_space: ast::StateSpace,
|
||||||
|
idx: Vec<SpirvWord>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
// mov.u32 foobar, {a,b};
|
||||||
|
let scalar_t = match typ {
|
||||||
|
ast::Type::Vector(_, scalar_t) => *scalar_t,
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let temp_vec = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(Some((typ.clone(), state_space)));
|
||||||
|
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||||
|
is_extract: is_dst,
|
||||||
|
typ: scalar_t,
|
||||||
|
packed: temp_vec,
|
||||||
|
unpacked: idx,
|
||||||
|
relaxed_type_check,
|
||||||
|
});
|
||||||
|
if is_dst {
|
||||||
|
self.post_stmts = Some(statement);
|
||||||
|
} else {
|
||||||
|
self.func.push(statement);
|
||||||
|
}
|
||||||
|
Ok(temp_vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
|
||||||
|
for VectorRepackVisitor<'a, 'b>
|
||||||
|
{
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
ident: SpirvWord,
|
||||||
|
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_: bool,
|
||||||
|
_: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
Ok(ident)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
op: ast::ParsedOperand<SpirvWord>,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
Ok(match op {
|
||||||
|
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
|
||||||
|
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
|
||||||
|
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
||||||
|
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
||||||
|
ast::ParsedOperand::VecPack(vec) => {
|
||||||
|
let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
|
||||||
|
TypedOperand::Reg(self.convert_vector(
|
||||||
|
is_dst,
|
||||||
|
relaxed_type_check,
|
||||||
|
type_,
|
||||||
|
space,
|
||||||
|
vec,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
2763
ptx/src/pass/emit_spirv.rs
Normal file
2763
ptx/src/pass/emit_spirv.rs
Normal file
File diff suppressed because it is too large
Load diff
181
ptx/src/pass/expand_arguments.rs
Normal file
181
ptx/src/pass/expand_arguments.rs
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'b>(
|
||||||
|
func: Vec<TypedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
|
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||||
|
Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
|
||||||
|
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||||
|
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||||
|
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||||
|
Statement::Constant(c) => result.push(Statement::Constant(c)),
|
||||||
|
Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
|
||||||
|
s => {
|
||||||
|
let (new_statement, post_stmts) = {
|
||||||
|
let mut visitor = FlattenArguments::new(&mut result, id_def);
|
||||||
|
(s.visit_map(&mut visitor)?, visitor.post_stmts)
|
||||||
|
};
|
||||||
|
result.push(new_statement);
|
||||||
|
result.extend(post_stmts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FlattenArguments<'a, 'b> {
|
||||||
|
func: &'b mut Vec<ExpandedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
post_stmts: Vec<ExpandedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> FlattenArguments<'a, 'b> {
|
||||||
|
fn new(
|
||||||
|
func: &'b mut Vec<ExpandedStatement>,
|
||||||
|
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||||
|
) -> Self {
|
||||||
|
FlattenArguments {
|
||||||
|
func,
|
||||||
|
id_def,
|
||||||
|
post_stmts: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reg_offset(
|
||||||
|
&mut self,
|
||||||
|
reg: SpirvWord,
|
||||||
|
offset: i32,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||||
|
(type_, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
|
||||||
|
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||||
|
if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
let reg_scalar_type = match reg_type {
|
||||||
|
ast::Type::Scalar(underlying_type) => underlying_type,
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let id_constant_stmt = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: reg_scalar_type,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let arith_details = match reg_scalar_type.kind() {
|
||||||
|
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
}),
|
||||||
|
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||||
|
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||||
|
type_: reg_scalar_type,
|
||||||
|
saturate: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
|
||||||
|
self.func
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Add {
|
||||||
|
data: arith_details,
|
||||||
|
arguments: ast::AddArgs {
|
||||||
|
dst: id_add_result,
|
||||||
|
src1: reg,
|
||||||
|
src2: id_constant_stmt,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
Ok(id_add_result)
|
||||||
|
} else {
|
||||||
|
let id_constant_stmt = self.id_def.register_intermediate(
|
||||||
|
ast::Type::Scalar(ast::ScalarType::S64),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id_constant_stmt,
|
||||||
|
typ: ast::ScalarType::S64,
|
||||||
|
value: ast::ImmediateValue::S64(offset as i64),
|
||||||
|
}));
|
||||||
|
let dst = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(type_.clone(), state_space);
|
||||||
|
self.func.push(Statement::PtrAccess(PtrAccess {
|
||||||
|
underlying_type: type_.clone(),
|
||||||
|
state_space: state_space,
|
||||||
|
dst,
|
||||||
|
ptr_src: reg,
|
||||||
|
offset_src: id_constant_stmt,
|
||||||
|
}));
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn immediate(
|
||||||
|
&mut self,
|
||||||
|
value: ast::ImmediateValue,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
let (scalar_t, state_space) =
|
||||||
|
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||||
|
(*scalar, state_space)
|
||||||
|
} else {
|
||||||
|
return Err(TranslateError::UntypedSymbol);
|
||||||
|
};
|
||||||
|
let id = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(ast::Type::Scalar(scalar_t), state_space);
|
||||||
|
self.func.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: id,
|
||||||
|
typ: scalar_t,
|
||||||
|
value,
|
||||||
|
}));
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
args: TypedOperand,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
match args {
|
||||||
|
TypedOperand::Reg(r) => self.reg(r),
|
||||||
|
TypedOperand::Imm(x) => self.immediate(x, type_space),
|
||||||
|
TypedOperand::RegOffset(reg, offset) => {
|
||||||
|
self.reg_offset(reg, offset, type_space, is_dst)
|
||||||
|
}
|
||||||
|
TypedOperand::VecMember(..) => Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
name: <TypedOperand as ptx_parser::Operand>::Ident,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
_is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
|
||||||
|
self.reg(name)
|
||||||
|
}
|
||||||
|
}
|
282
ptx/src/pass/extract_globals.rs
Normal file
282
ptx/src/pass/extract_globals.rs
Normal file
|
@ -0,0 +1,282 @@
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run<'input, 'b>(
|
||||||
|
sorted_statements: Vec<ExpandedStatement>,
|
||||||
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||||
|
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||||
|
let mut global = Vec::new();
|
||||||
|
for statement in sorted_statements {
|
||||||
|
match statement {
|
||||||
|
Statement::Variable(
|
||||||
|
var @ ast::Variable {
|
||||||
|
state_space: ast::StateSpace::Shared,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
)
|
||||||
|
| Statement::Variable(
|
||||||
|
var @ ast::Variable {
|
||||||
|
state_space: ast::StateSpace::Global,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
) => global.push(var),
|
||||||
|
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
|
||||||
|
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Bfe { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
|
||||||
|
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Bfi { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
|
||||||
|
let fn_name: String =
|
||||||
|
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Brev { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
|
||||||
|
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Activemask { arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Atom {
|
||||||
|
data:
|
||||||
|
data @ ast::AtomDetails {
|
||||||
|
op: ast::AtomicOp::IncrementWrap,
|
||||||
|
semantics,
|
||||||
|
scope,
|
||||||
|
space,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
let fn_name = [
|
||||||
|
ZLUDA_PTX_PREFIX,
|
||||||
|
"atom_",
|
||||||
|
semantics_to_ptx_name(semantics),
|
||||||
|
"_",
|
||||||
|
scope_to_ptx_name(scope),
|
||||||
|
"_",
|
||||||
|
space_to_ptx_name(space),
|
||||||
|
"_inc",
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Atom { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Atom {
|
||||||
|
data:
|
||||||
|
data @ ast::AtomDetails {
|
||||||
|
op: ast::AtomicOp::DecrementWrap,
|
||||||
|
semantics,
|
||||||
|
scope,
|
||||||
|
space,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
let fn_name = [
|
||||||
|
ZLUDA_PTX_PREFIX,
|
||||||
|
"atom_",
|
||||||
|
semantics_to_ptx_name(semantics),
|
||||||
|
"_",
|
||||||
|
scope_to_ptx_name(scope),
|
||||||
|
"_",
|
||||||
|
space_to_ptx_name(space),
|
||||||
|
"_dec",
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Atom { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
Statement::Instruction(ast::Instruction::Atom {
|
||||||
|
data:
|
||||||
|
data @ ast::AtomDetails {
|
||||||
|
op: ast::AtomicOp::FloatAdd,
|
||||||
|
semantics,
|
||||||
|
scope,
|
||||||
|
space,
|
||||||
|
..
|
||||||
|
},
|
||||||
|
arguments,
|
||||||
|
}) => {
|
||||||
|
let scalar_type = match data.type_ {
|
||||||
|
ptx_parser::Type::Scalar(scalar) => scalar,
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
let fn_name = [
|
||||||
|
ZLUDA_PTX_PREFIX,
|
||||||
|
"atom_",
|
||||||
|
semantics_to_ptx_name(semantics),
|
||||||
|
"_",
|
||||||
|
scope_to_ptx_name(scope),
|
||||||
|
"_",
|
||||||
|
space_to_ptx_name(space),
|
||||||
|
"_add_",
|
||||||
|
scalar_to_ptx_name(scalar_type),
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
local.push(instruction_to_fn_call(
|
||||||
|
id_def,
|
||||||
|
ptx_impl_imports,
|
||||||
|
ast::Instruction::Atom { data, arguments },
|
||||||
|
fn_name,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
s => local.push(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((local, global))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn instruction_to_fn_call(
|
||||||
|
id_defs: &mut NumericIdResolver,
|
||||||
|
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||||
|
inst: ast::Instruction<SpirvWord>,
|
||||||
|
fn_name: String,
|
||||||
|
) -> Result<ExpandedStatement, TranslateError> {
|
||||||
|
let mut arguments = Vec::new();
|
||||||
|
ast::visit_map(inst, &mut |operand,
|
||||||
|
type_space: Option<(
|
||||||
|
&ast::Type,
|
||||||
|
ast::StateSpace,
|
||||||
|
)>,
|
||||||
|
is_dst,
|
||||||
|
_| {
|
||||||
|
let (typ, space) = match type_space {
|
||||||
|
Some((typ, space)) => (typ.clone(), space),
|
||||||
|
None => return Err(error_unreachable()),
|
||||||
|
};
|
||||||
|
arguments.push((operand, is_dst, typ, space));
|
||||||
|
Ok(SpirvWord(0))
|
||||||
|
})?;
|
||||||
|
let return_arguments_count = arguments
|
||||||
|
.iter()
|
||||||
|
.position(|(desc, is_dst, _, _)| !is_dst)
|
||||||
|
.unwrap_or(arguments.len());
|
||||||
|
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
|
||||||
|
let fn_id = register_external_fn_call(
|
||||||
|
id_defs,
|
||||||
|
ptx_impl_imports,
|
||||||
|
fn_name,
|
||||||
|
return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, _, typ, state)| (typ, *state)),
|
||||||
|
input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, _, typ, state)| (typ, *state)),
|
||||||
|
)?;
|
||||||
|
Ok(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data: ast::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
},
|
||||||
|
arguments: ast::CallArgs {
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(name, _, _, _)| *name)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
func: fn_id,
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(name, _, _, _)| *name)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
||||||
|
match this {
|
||||||
|
ast::ScalarType::B8 => "b8",
|
||||||
|
ast::ScalarType::B16 => "b16",
|
||||||
|
ast::ScalarType::B32 => "b32",
|
||||||
|
ast::ScalarType::B64 => "b64",
|
||||||
|
ast::ScalarType::B128 => "b128",
|
||||||
|
ast::ScalarType::U8 => "u8",
|
||||||
|
ast::ScalarType::U16 => "u16",
|
||||||
|
ast::ScalarType::U16x2 => "u16x2",
|
||||||
|
ast::ScalarType::U32 => "u32",
|
||||||
|
ast::ScalarType::U64 => "u64",
|
||||||
|
ast::ScalarType::S8 => "s8",
|
||||||
|
ast::ScalarType::S16 => "s16",
|
||||||
|
ast::ScalarType::S16x2 => "s16x2",
|
||||||
|
ast::ScalarType::S32 => "s32",
|
||||||
|
ast::ScalarType::S64 => "s64",
|
||||||
|
ast::ScalarType::F16 => "f16",
|
||||||
|
ast::ScalarType::F16x2 => "f16x2",
|
||||||
|
ast::ScalarType::F32 => "f32",
|
||||||
|
ast::ScalarType::F64 => "f64",
|
||||||
|
ast::ScalarType::BF16 => "bf16",
|
||||||
|
ast::ScalarType::BF16x2 => "bf16x2",
|
||||||
|
ast::ScalarType::Pred => "pred",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
|
||||||
|
match this {
|
||||||
|
ast::AtomSemantics::Relaxed => "relaxed",
|
||||||
|
ast::AtomSemantics::Acquire => "acquire",
|
||||||
|
ast::AtomSemantics::Release => "release",
|
||||||
|
ast::AtomSemantics::AcqRel => "acq_rel",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
|
||||||
|
match this {
|
||||||
|
ast::MemScope::Cta => "cta",
|
||||||
|
ast::MemScope::Gpu => "gpu",
|
||||||
|
ast::MemScope::Sys => "sys",
|
||||||
|
ast::MemScope::Cluster => "cluster",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
|
||||||
|
match this {
|
||||||
|
ast::StateSpace::Generic => "generic",
|
||||||
|
ast::StateSpace::Global => "global",
|
||||||
|
ast::StateSpace::Shared => "shared",
|
||||||
|
ast::StateSpace::Reg => "reg",
|
||||||
|
ast::StateSpace::Const => "const",
|
||||||
|
ast::StateSpace::Local => "local",
|
||||||
|
ast::StateSpace::Param => "param",
|
||||||
|
ast::StateSpace::Sreg => "sreg",
|
||||||
|
ast::StateSpace::SharedCluster => "shared_cluster",
|
||||||
|
ast::StateSpace::ParamEntry => "param_entry",
|
||||||
|
ast::StateSpace::SharedCta => "shared_cta",
|
||||||
|
ast::StateSpace::ParamFunc => "param_func",
|
||||||
|
}
|
||||||
|
}
|
130
ptx/src/pass/fix_special_registers.rs
Normal file
130
ptx/src/pass/fix_special_registers.rs
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
use super::*;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub(super) fn run<'a, 'b, 'input>(
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
typed_statements: Vec<TypedStatement>,
|
||||||
|
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let result = Vec::with_capacity(typed_statements.len());
|
||||||
|
let mut sreg_sresolver = SpecialRegisterResolver {
|
||||||
|
ptx_impl_imports,
|
||||||
|
numeric_id_defs,
|
||||||
|
result,
|
||||||
|
};
|
||||||
|
for statement in typed_statements {
|
||||||
|
let statement = statement.visit_map(&mut sreg_sresolver)?;
|
||||||
|
sreg_sresolver.result.push(statement);
|
||||||
|
}
|
||||||
|
Ok(sreg_sresolver.result)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
|
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||||
|
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||||
|
result: Vec<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||||
|
for SpecialRegisterResolver<'a, 'b, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
operand: TypedOperand,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.replace_sreg(args, is_dst, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
||||||
|
fn replace_sreg(
|
||||||
|
&mut self,
|
||||||
|
name: SpirvWord,
|
||||||
|
is_dst: bool,
|
||||||
|
vector_index: Option<u8>,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
|
||||||
|
if is_dst {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||||
|
(Some(idx), Some(inp_type)) => {
|
||||||
|
if inp_type != ast::ScalarType::U8 {
|
||||||
|
return Err(TranslateError::Unreachable);
|
||||||
|
}
|
||||||
|
let constant = self.numeric_id_defs.register_intermediate(Some((
|
||||||
|
ast::Type::Scalar(inp_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
self.result.push(Statement::Constant(ConstantDefinition {
|
||||||
|
dst: constant,
|
||||||
|
typ: inp_type,
|
||||||
|
value: ast::ImmediateValue::U64(idx as u64),
|
||||||
|
}));
|
||||||
|
vec![(
|
||||||
|
TypedOperand::Reg(constant),
|
||||||
|
ast::Type::Scalar(inp_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
(None, None) => Vec::new(),
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||||
|
let return_type = sreg.get_function_return_type();
|
||||||
|
let fn_result = self.numeric_id_defs.register_intermediate(Some((
|
||||||
|
ast::Type::Scalar(return_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
let return_arguments = vec![(
|
||||||
|
fn_result,
|
||||||
|
ast::Type::Scalar(return_type),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)];
|
||||||
|
let fn_call = register_external_fn_call(
|
||||||
|
self.numeric_id_defs,
|
||||||
|
self.ptx_impl_imports,
|
||||||
|
ocl_fn_name.to_string(),
|
||||||
|
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
|
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||||
|
)?;
|
||||||
|
let data = ast::CallDetails {
|
||||||
|
uniform: false,
|
||||||
|
return_arguments: return_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
input_arguments: input_arguments
|
||||||
|
.iter()
|
||||||
|
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||||
|
.collect(),
|
||||||
|
};
|
||||||
|
let arguments = ast::CallArgs {
|
||||||
|
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
|
func: fn_call,
|
||||||
|
input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||||
|
};
|
||||||
|
self.result
|
||||||
|
.push(Statement::Instruction(ast::Instruction::Call {
|
||||||
|
data,
|
||||||
|
arguments,
|
||||||
|
}));
|
||||||
|
Ok(fn_result)
|
||||||
|
} else {
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
432
ptx/src/pass/insert_implicit_conversions.rs
Normal file
432
ptx/src/pass/insert_implicit_conversions.rs
Normal file
|
@ -0,0 +1,432 @@
|
||||||
|
use std::mem;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
/*
|
||||||
|
There are several kinds of implicit conversions in PTX:
|
||||||
|
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||||
|
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||||
|
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||||
|
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||||
|
documented special ld/st/cvt conversion rules for destination operands
|
||||||
|
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||||
|
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||||
|
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||||
|
documented special ld/st/cvt conversion rules are applied to dst
|
||||||
|
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||||
|
b64/u64/s64, which is bitcast to a pointer
|
||||||
|
*/
|
||||||
|
pub(super) fn run(
|
||||||
|
func: Vec<ExpandedStatement>,
|
||||||
|
id_def: &mut MutableNumericIdResolver,
|
||||||
|
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func.into_iter() {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => {
|
||||||
|
insert_implicit_conversions_impl(
|
||||||
|
&mut result,
|
||||||
|
id_def,
|
||||||
|
Statement::Instruction(inst),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Statement::PtrAccess(access) => {
|
||||||
|
insert_implicit_conversions_impl(
|
||||||
|
&mut result,
|
||||||
|
id_def,
|
||||||
|
Statement::PtrAccess(access),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Statement::RepackVector(repack) => {
|
||||||
|
insert_implicit_conversions_impl(
|
||||||
|
&mut result,
|
||||||
|
id_def,
|
||||||
|
Statement::RepackVector(repack),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
s @ Statement::Conditional(_)
|
||||||
|
| s @ Statement::Conversion(_)
|
||||||
|
| s @ Statement::Label(_)
|
||||||
|
| s @ Statement::Constant(_)
|
||||||
|
| s @ Statement::Variable(_)
|
||||||
|
| s @ Statement::LoadVar(..)
|
||||||
|
| s @ Statement::StoreVar(..)
|
||||||
|
| s @ Statement::RetValue(..)
|
||||||
|
| s @ Statement::FunctionPointer(..) => result.push(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_implicit_conversions_impl(
|
||||||
|
func: &mut Vec<ExpandedStatement>,
|
||||||
|
id_def: &mut MutableNumericIdResolver,
|
||||||
|
stmt: ExpandedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let mut post_conv = Vec::new();
|
||||||
|
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||||
|
&mut |operand,
|
||||||
|
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst,
|
||||||
|
relaxed_type_check| {
|
||||||
|
let (instr_type, instruction_space) = match type_state {
|
||||||
|
None => return Ok(operand),
|
||||||
|
Some(t) => t,
|
||||||
|
};
|
||||||
|
let (operand_type, operand_space) = id_def.get_typed(operand)?;
|
||||||
|
let conversion_fn = if relaxed_type_check {
|
||||||
|
if is_dst {
|
||||||
|
should_convert_relaxed_dst_wrapper
|
||||||
|
} else {
|
||||||
|
should_convert_relaxed_src_wrapper
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
default_implicit_conversion
|
||||||
|
};
|
||||||
|
match conversion_fn(
|
||||||
|
(operand_space, &operand_type),
|
||||||
|
(instruction_space, instr_type),
|
||||||
|
)? {
|
||||||
|
Some(conv_kind) => {
|
||||||
|
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||||
|
let mut from_type = instr_type.clone();
|
||||||
|
let mut from_space = instruction_space;
|
||||||
|
let mut to_type = operand_type;
|
||||||
|
let mut to_space = operand_space;
|
||||||
|
let mut src =
|
||||||
|
id_def.register_intermediate(instr_type.clone(), instruction_space);
|
||||||
|
let mut dst = operand;
|
||||||
|
let result = Ok::<_, TranslateError>(src);
|
||||||
|
if !is_dst {
|
||||||
|
mem::swap(&mut src, &mut dst);
|
||||||
|
mem::swap(&mut from_type, &mut to_type);
|
||||||
|
mem::swap(&mut from_space, &mut to_space);
|
||||||
|
}
|
||||||
|
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||||
|
src,
|
||||||
|
dst,
|
||||||
|
from_type,
|
||||||
|
from_space,
|
||||||
|
to_type,
|
||||||
|
to_space,
|
||||||
|
kind: conv_kind,
|
||||||
|
}));
|
||||||
|
result
|
||||||
|
}
|
||||||
|
None => Ok(operand),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
func.push(statement);
|
||||||
|
func.append(&mut post_conv);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn default_implicit_conversion(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if instruction_space == ast::StateSpace::Reg {
|
||||||
|
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||||
|
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||||
|
(operand_type, instruction_type)
|
||||||
|
{
|
||||||
|
if scalar.kind() == ast::ScalarKind::Bit
|
||||||
|
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||||
|
{
|
||||||
|
return Ok(Some(ConversionKind::Default));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if is_addressable(operand_space) {
|
||||||
|
return Ok(Some(ConversionKind::AddressOf));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !space_is_compatible(instruction_space, operand_space) {
|
||||||
|
default_implicit_conversion_space(
|
||||||
|
(operand_space, operand_type),
|
||||||
|
(instruction_space, instruction_type),
|
||||||
|
)
|
||||||
|
} else if instruction_type != operand_type {
|
||||||
|
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||||
|
match this {
|
||||||
|
ast::StateSpace::Const
|
||||||
|
| ast::StateSpace::Generic
|
||||||
|
| ast::StateSpace::Global
|
||||||
|
| ast::StateSpace::Local
|
||||||
|
| ast::StateSpace::Shared => true,
|
||||||
|
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
|
||||||
|
ast::StateSpace::SharedCluster
|
||||||
|
| ast::StateSpace::SharedCta
|
||||||
|
| ast::StateSpace::ParamEntry
|
||||||
|
| ast::StateSpace::ParamFunc => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Space is different
|
||||||
|
fn default_implicit_conversion_space(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||||
|
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||||
|
{
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||||
|
match operand_type {
|
||||||
|
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||||
|
if *operand_ptr_space == instruction_space =>
|
||||||
|
{
|
||||||
|
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: 32 bit
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||||
|
ast::StateSpace::Global
|
||||||
|
| ast::StateSpace::Generic
|
||||||
|
| ast::StateSpace::Const
|
||||||
|
| ast::StateSpace::Local
|
||||||
|
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
},
|
||||||
|
ast::Type::Scalar(ast::ScalarType::B32)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||||
|
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||||
|
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||||
|
Ok(Some(ConversionKind::BitToPtr))
|
||||||
|
}
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
},
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
||||||
|
match instruction_type {
|
||||||
|
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||||
|
if operand_space == *instruction_ptr_space =>
|
||||||
|
{
|
||||||
|
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(error_mismatched_type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Space is same, but type is different
|
||||||
|
fn default_implicit_conversion_type(
|
||||||
|
space: ast::StateSpace,
|
||||||
|
operand_type: &ast::Type,
|
||||||
|
instruction_type: &ast::Type,
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if space_is_compatible(space, ast::StateSpace::Reg) {
|
||||||
|
if should_bitcast(instruction_type, operand_type) {
|
||||||
|
Ok(Some(ConversionKind::Default))
|
||||||
|
} else {
|
||||||
|
Err(TranslateError::MismatchedType)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(Some(ConversionKind::PtrToPtr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||||
|
match this {
|
||||||
|
ast::StateSpace::Global
|
||||||
|
| ast::StateSpace::Const
|
||||||
|
| ast::StateSpace::Local
|
||||||
|
| ptx_parser::StateSpace::SharedCta
|
||||||
|
| ast::StateSpace::SharedCluster
|
||||||
|
| ast::StateSpace::Shared => true,
|
||||||
|
ast::StateSpace::Reg
|
||||||
|
| ast::StateSpace::Param
|
||||||
|
| ast::StateSpace::ParamEntry
|
||||||
|
| ast::StateSpace::ParamFunc
|
||||||
|
| ast::StateSpace::Generic
|
||||||
|
| ast::StateSpace::Sreg => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||||
|
match (instr, operand) {
|
||||||
|
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||||
|
if inst.size_of() != operand.size_of() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match inst.kind() {
|
||||||
|
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||||
|
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||||
|
ast::ScalarKind::Signed => {
|
||||||
|
operand.kind() == ast::ScalarKind::Bit
|
||||||
|
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Unsigned => {
|
||||||
|
operand.kind() == ast::ScalarKind::Bit
|
||||||
|
|| operand.kind() == ast::ScalarKind::Signed
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Pred => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||||
|
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||||
|
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if !space_is_compatible(operand_space, instruction_space) {
|
||||||
|
return Err(TranslateError::MismatchedType);
|
||||||
|
}
|
||||||
|
if operand_type == instruction_type {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||||
|
conv @ Some(_) => Ok(conv),
|
||||||
|
None => Err(TranslateError::MismatchedType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||||
|
fn should_convert_relaxed_dst(
|
||||||
|
dst_type: &ast::Type,
|
||||||
|
instr_type: &ast::Type,
|
||||||
|
) -> Option<ConversionKind> {
|
||||||
|
if dst_type == instr_type {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
match (dst_type, instr_type) {
|
||||||
|
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
|
ast::ScalarKind::Bit => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Signed => {
|
||||||
|
if dst_type.kind() != ast::ScalarKind::Float {
|
||||||
|
if instr_type.size_of() == dst_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else if instr_type.size_of() < dst_type.size_of() {
|
||||||
|
Some(ConversionKind::SignExtend)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Unsigned => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
|
&& dst_type.kind() != ast::ScalarKind::Float
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Float => {
|
||||||
|
if instr_type.size_of() <= dst_type.size_of()
|
||||||
|
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Pred => None,
|
||||||
|
},
|
||||||
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
|
should_convert_relaxed_dst(
|
||||||
|
&ast::Type::Scalar(*dst_type),
|
||||||
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||||
|
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||||
|
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||||
|
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||||
|
if !space_is_compatible(operand_space, instruction_space) {
|
||||||
|
return Err(error_mismatched_type());
|
||||||
|
}
|
||||||
|
if operand_type == instruction_type {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||||
|
conv @ Some(_) => Ok(conv),
|
||||||
|
None => Err(error_mismatched_type()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||||
|
fn should_convert_relaxed_src(
|
||||||
|
src_type: &ast::Type,
|
||||||
|
instr_type: &ast::Type,
|
||||||
|
) -> Option<ConversionKind> {
|
||||||
|
if src_type == instr_type {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
match (src_type, instr_type) {
|
||||||
|
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||||
|
ast::ScalarKind::Bit => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of() {
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
|
&& src_type.kind() != ast::ScalarKind::Float
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Float => {
|
||||||
|
if instr_type.size_of() <= src_type.size_of()
|
||||||
|
&& src_type.kind() == ast::ScalarKind::Bit
|
||||||
|
{
|
||||||
|
Some(ConversionKind::Default)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::ScalarKind::Pred => None,
|
||||||
|
},
|
||||||
|
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||||
|
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||||
|
should_convert_relaxed_src(
|
||||||
|
&ast::Type::Scalar(*dst_type),
|
||||||
|
&ast::Type::Scalar(*instr_type),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
275
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
275
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
|
@ -0,0 +1,275 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
/*
|
||||||
|
How do we handle arguments:
|
||||||
|
- input .params in kernels
|
||||||
|
.param .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
We do this for two reasons. One, common treatment for argument-declared
|
||||||
|
.param variables and .param variables inside function (we assume that
|
||||||
|
at SPIR-V level every .param is a pointer in Function storage class)
|
||||||
|
- input .params in functions
|
||||||
|
.param .b64 in_arg
|
||||||
|
get turned into this SPIR-V:
|
||||||
|
%1 = OpFunctionParameter %_ptr_Function_ulong
|
||||||
|
- input .regs
|
||||||
|
.reg .b64 in_arg
|
||||||
|
get turned into the same SPIR-V as kernel .params:
|
||||||
|
%1 = OpFunctionParameter %ulong
|
||||||
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
OpStore %2 %1
|
||||||
|
- output .regs
|
||||||
|
.reg .b64 out_arg
|
||||||
|
get just a variable declaration:
|
||||||
|
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||||
|
- output .params don't exist, they have been moved to input positions
|
||||||
|
by an earlier pass
|
||||||
|
Distinguishing betweem kernel .params and function .params is not the
|
||||||
|
cleanest solution. Alternatively, we could "deparamize" all kernel .param
|
||||||
|
arguments by turning them into .reg arguments like this:
|
||||||
|
.param .b64 arg -> .reg ptr<.b64,.param> arg
|
||||||
|
This has the massive downside that this transformation would have to run
|
||||||
|
very early and would muddy up already difficult code. It's simpler to just
|
||||||
|
have an if here
|
||||||
|
*/
|
||||||
|
pub(super) fn run<'a, 'b>(
|
||||||
|
func: Vec<TypedStatement>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
|
||||||
|
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for arg in fn_decl.input_arguments.iter_mut() {
|
||||||
|
insert_mem_ssa_argument(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
arg,
|
||||||
|
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for arg in fn_decl.return_arguments.iter() {
|
||||||
|
insert_mem_ssa_argument_reg_return(&mut result, arg);
|
||||||
|
}
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Ret { data } => {
|
||||||
|
// TODO: handle multiple output args
|
||||||
|
match &fn_decl.return_arguments[..] {
|
||||||
|
[return_reg] => {
|
||||||
|
let new_id = id_def.register_intermediate(Some((
|
||||||
|
return_reg.v_type.clone(),
|
||||||
|
ast::StateSpace::Reg,
|
||||||
|
)));
|
||||||
|
result.push(Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg: ast::LdArgs {
|
||||||
|
dst: new_id,
|
||||||
|
src: return_reg.name,
|
||||||
|
},
|
||||||
|
typ: return_reg.v_type.clone(),
|
||||||
|
member_index: None,
|
||||||
|
}));
|
||||||
|
result.push(Statement::RetValue(data, new_id));
|
||||||
|
}
|
||||||
|
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inst => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::Instruction(inst),
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
Statement::Conditional(bra) => {
|
||||||
|
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
|
||||||
|
}
|
||||||
|
Statement::Conversion(conv) => {
|
||||||
|
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
|
||||||
|
}
|
||||||
|
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::PtrAccess(ptr_access),
|
||||||
|
)?,
|
||||||
|
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::RepackVector(repack),
|
||||||
|
)?,
|
||||||
|
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
|
||||||
|
id_def,
|
||||||
|
&mut result,
|
||||||
|
Statement::FunctionPointer(func_ptr),
|
||||||
|
)?,
|
||||||
|
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
|
||||||
|
result.push(s)
|
||||||
|
}
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument(
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &mut ast::Variable<SpirvWord>,
|
||||||
|
is_kernel: bool,
|
||||||
|
) {
|
||||||
|
if !is_kernel && arg.state_space == ast::StateSpace::Param {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: ast::StateSpace::Reg,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: Vec::new(),
|
||||||
|
}));
|
||||||
|
func.push(Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::StArgs {
|
||||||
|
src1: arg.name,
|
||||||
|
src2: new_id,
|
||||||
|
},
|
||||||
|
typ: arg.v_type.clone(),
|
||||||
|
member_index: None,
|
||||||
|
}));
|
||||||
|
arg.name = new_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_argument_reg_return(
|
||||||
|
func: &mut Vec<TypedStatement>,
|
||||||
|
arg: &ast::Variable<SpirvWord>,
|
||||||
|
) {
|
||||||
|
func.push(Statement::Variable(ast::Variable {
|
||||||
|
align: arg.align,
|
||||||
|
v_type: arg.v_type.clone(),
|
||||||
|
state_space: arg.state_space,
|
||||||
|
name: arg.name,
|
||||||
|
array_init: arg.array_init.clone(),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_mem_ssa_statement_default<'a, 'input>(
|
||||||
|
id_def: &'a mut NumericIdResolver<'input>,
|
||||||
|
func: &'a mut Vec<TypedStatement>,
|
||||||
|
stmt: TypedStatement,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
let mut visitor = InsertMemSSAVisitor {
|
||||||
|
id_def,
|
||||||
|
func,
|
||||||
|
post_statements: Vec::new(),
|
||||||
|
};
|
||||||
|
let new_stmt = stmt.visit_map(&mut visitor)?;
|
||||||
|
visitor.func.push(new_stmt);
|
||||||
|
visitor.func.extend(visitor.post_statements);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
id_def: &'a mut NumericIdResolver<'input>,
|
||||||
|
func: &'a mut Vec<TypedStatement>,
|
||||||
|
post_statements: Vec<TypedStatement>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||||
|
fn symbol(
|
||||||
|
&mut self,
|
||||||
|
symbol: SpirvWord,
|
||||||
|
member_index: Option<u8>,
|
||||||
|
expected: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
if expected.is_none() {
|
||||||
|
return Ok(symbol);
|
||||||
|
};
|
||||||
|
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||||
|
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
||||||
|
return Ok(symbol);
|
||||||
|
};
|
||||||
|
let member_index = match member_index {
|
||||||
|
Some(idx) => {
|
||||||
|
let vector_width = match var_type {
|
||||||
|
ast::Type::Vector(width, scalar_t) => {
|
||||||
|
var_type = ast::Type::Scalar(scalar_t);
|
||||||
|
width
|
||||||
|
}
|
||||||
|
_ => return Err(error_mismatched_type()),
|
||||||
|
};
|
||||||
|
Some((
|
||||||
|
idx,
|
||||||
|
if self.id_def.special_registers.get(symbol).is_some() {
|
||||||
|
Some(vector_width)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let generated_id = self
|
||||||
|
.id_def
|
||||||
|
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
|
||||||
|
if !is_dst {
|
||||||
|
self.func.push(Statement::LoadVar(LoadVarDetails {
|
||||||
|
arg: ast::LdArgs {
|
||||||
|
dst: generated_id,
|
||||||
|
src: symbol,
|
||||||
|
},
|
||||||
|
typ: var_type,
|
||||||
|
member_index,
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
self.post_statements
|
||||||
|
.push(Statement::StoreVar(StoreVarDetails {
|
||||||
|
arg: ast::StArgs {
|
||||||
|
src1: symbol,
|
||||||
|
src2: generated_id,
|
||||||
|
},
|
||||||
|
typ: var_type,
|
||||||
|
member_index: member_index.map(|(idx, _)| idx),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Ok(generated_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||||
|
for InsertMemSSAVisitor<'a, 'input>
|
||||||
|
{
|
||||||
|
fn visit(
|
||||||
|
&mut self,
|
||||||
|
operand: TypedOperand,
|
||||||
|
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
_relaxed_type_check: bool,
|
||||||
|
) -> Result<TypedOperand, TranslateError> {
|
||||||
|
Ok(match operand {
|
||||||
|
TypedOperand::Reg(reg) => {
|
||||||
|
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
|
||||||
|
}
|
||||||
|
TypedOperand::RegOffset(reg, offset) => {
|
||||||
|
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
|
||||||
|
}
|
||||||
|
op @ TypedOperand::Imm(..) => op,
|
||||||
|
TypedOperand::VecMember(symbol, index) => {
|
||||||
|
TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ident(
|
||||||
|
&mut self,
|
||||||
|
args: SpirvWord,
|
||||||
|
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||||
|
is_dst: bool,
|
||||||
|
relaxed_type_check: bool,
|
||||||
|
) -> Result<SpirvWord, TranslateError> {
|
||||||
|
self.symbol(args, None, type_space, is_dst)
|
||||||
|
}
|
||||||
|
}
|
1677
ptx/src/pass/mod.rs
Normal file
1677
ptx/src/pass/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
80
ptx/src/pass/normalize_identifiers.rs
Normal file
80
ptx/src/pass/normalize_identifiers.rs
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run<'input, 'b>(
|
||||||
|
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||||
|
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||||
|
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||||
|
) -> Result<Vec<NormalizedStatement>, TranslateError> {
|
||||||
|
for s in func.iter() {
|
||||||
|
match s {
|
||||||
|
ast::Statement::Label(id) => {
|
||||||
|
id_defs.add_def(*id, None, false);
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for s in func {
|
||||||
|
expand_map_variables(id_defs, fn_defs, &mut result, s)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expand_map_variables<'a, 'b>(
|
||||||
|
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||||
|
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
||||||
|
result: &mut Vec<NormalizedStatement>,
|
||||||
|
s: ast::Statement<ast::ParsedOperand<&'a str>>,
|
||||||
|
) -> Result<(), TranslateError> {
|
||||||
|
match s {
|
||||||
|
ast::Statement::Block(block) => {
|
||||||
|
id_defs.start_block();
|
||||||
|
for s in block {
|
||||||
|
expand_map_variables(id_defs, fn_defs, result, s)?;
|
||||||
|
}
|
||||||
|
id_defs.end_block();
|
||||||
|
}
|
||||||
|
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
|
||||||
|
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
||||||
|
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
||||||
|
.transpose()?,
|
||||||
|
ast::visit_map(i, &mut |id,
|
||||||
|
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||||
|
_: bool,
|
||||||
|
_: bool| {
|
||||||
|
id_defs.get_id(id)
|
||||||
|
})?,
|
||||||
|
))),
|
||||||
|
ast::Statement::Variable(var) => {
|
||||||
|
let var_type = var.var.v_type.clone();
|
||||||
|
match var.count {
|
||||||
|
Some(count) => {
|
||||||
|
for new_id in
|
||||||
|
id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
|
||||||
|
{
|
||||||
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
align: var.var.align,
|
||||||
|
v_type: var.var.v_type.clone(),
|
||||||
|
state_space: var.var.state_space,
|
||||||
|
name: new_id,
|
||||||
|
array_init: var.var.array_init.clone(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let new_id =
|
||||||
|
id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
|
||||||
|
result.push(Statement::Variable(ast::Variable {
|
||||||
|
align: var.var.align,
|
||||||
|
v_type: var.var.v_type.clone(),
|
||||||
|
state_space: var.var.state_space,
|
||||||
|
name: new_id,
|
||||||
|
array_init: var.var.array_init,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
48
ptx/src/pass/normalize_labels.rs
Normal file
48
ptx/src/pass/normalize_labels.rs
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
use std::{collections::HashSet, iter};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub(super) fn run(
|
||||||
|
func: Vec<ExpandedStatement>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
) -> Vec<ExpandedStatement> {
|
||||||
|
let mut labels_in_use = HashSet::new();
|
||||||
|
for s in func.iter() {
|
||||||
|
match s {
|
||||||
|
Statement::Instruction(i) => {
|
||||||
|
if let Some(target) = jump_target(i) {
|
||||||
|
labels_in_use.insert(target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Conditional(cond) => {
|
||||||
|
labels_in_use.insert(cond.if_true);
|
||||||
|
labels_in_use.insert(cond.if_false);
|
||||||
|
}
|
||||||
|
Statement::Variable(..)
|
||||||
|
| Statement::LoadVar(..)
|
||||||
|
| Statement::StoreVar(..)
|
||||||
|
| Statement::RetValue(..)
|
||||||
|
| Statement::Conversion(..)
|
||||||
|
| Statement::Constant(..)
|
||||||
|
| Statement::Label(..)
|
||||||
|
| Statement::PtrAccess { .. }
|
||||||
|
| Statement::RepackVector(..)
|
||||||
|
| Statement::FunctionPointer(..) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iter::once(Statement::Label(id_def.register_intermediate(None)))
|
||||||
|
.chain(func.into_iter().filter(|s| match s {
|
||||||
|
Statement::Label(i) => labels_in_use.contains(i),
|
||||||
|
_ => true,
|
||||||
|
}))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
|
||||||
|
this: &ast::Instruction<T>,
|
||||||
|
) -> Option<SpirvWord> {
|
||||||
|
match this {
|
||||||
|
ast::Instruction::Bra { arguments } => Some(arguments.src),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
44
ptx/src/pass/normalize_predicates.rs
Normal file
44
ptx/src/pass/normalize_predicates.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
use super::*;
|
||||||
|
use ptx_parser as ast;
|
||||||
|
|
||||||
|
pub(crate) fn run(
|
||||||
|
func: Vec<NormalizedStatement>,
|
||||||
|
id_def: &mut NumericIdResolver,
|
||||||
|
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
|
||||||
|
let mut result = Vec::with_capacity(func.len());
|
||||||
|
for s in func {
|
||||||
|
match s {
|
||||||
|
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
|
Statement::Instruction((pred, inst)) => {
|
||||||
|
if let Some(pred) = pred {
|
||||||
|
let if_true = id_def.register_intermediate(None);
|
||||||
|
let if_false = id_def.register_intermediate(None);
|
||||||
|
let folded_bra = match &inst {
|
||||||
|
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
let mut branch = BrachCondition {
|
||||||
|
predicate: pred.label,
|
||||||
|
if_true: folded_bra.unwrap_or(if_true),
|
||||||
|
if_false,
|
||||||
|
};
|
||||||
|
if pred.not {
|
||||||
|
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||||
|
}
|
||||||
|
result.push(Statement::Conditional(branch));
|
||||||
|
if folded_bra.is_none() {
|
||||||
|
result.push(Statement::Label(if_true));
|
||||||
|
result.push(Statement::Instruction(inst));
|
||||||
|
}
|
||||||
|
result.push(Statement::Label(if_false));
|
||||||
|
} else {
|
||||||
|
result.push(Statement::Instruction(inst));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||||
|
// Blocks are flattened when resolving ids
|
||||||
|
_ => return Err(error_unreachable()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
|
@ -7,20 +7,24 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
%21 = OpExtInstImport "OpenCL.std"
|
OpCapability DenormFlushToZero
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||||
|
%22 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "clz"
|
OpEntryPoint Kernel %1 "clz"
|
||||||
|
OpExecutionMode %1 ContractionOff
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%24 = OpTypeFunction %void %ulong %ulong
|
%25 = OpTypeFunction %void %ulong %ulong
|
||||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
%uint = OpTypeInt 32 0
|
%uint = OpTypeInt 32 0
|
||||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||||
%1 = OpFunction %void None %24
|
%1 = OpFunction %void None %25
|
||||||
%7 = OpFunctionParameter %ulong
|
%7 = OpFunctionParameter %ulong
|
||||||
%8 = OpFunctionParameter %ulong
|
%8 = OpFunctionParameter %ulong
|
||||||
%19 = OpLabel
|
%20 = OpLabel
|
||||||
%2 = OpVariable %_ptr_Function_ulong Function
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
%3 = OpVariable %_ptr_Function_ulong Function
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
%4 = OpVariable %_ptr_Function_ulong Function
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
@ -37,11 +41,12 @@
|
||||||
%11 = OpLoad %uint %17 Aligned 4
|
%11 = OpLoad %uint %17 Aligned 4
|
||||||
OpStore %6 %11
|
OpStore %6 %11
|
||||||
%14 = OpLoad %uint %6
|
%14 = OpLoad %uint %6
|
||||||
%13 = OpExtInst %uint %21 clz %14
|
%18 = OpExtInst %uint %22 clz %14
|
||||||
|
%13 = OpCopyObject %uint %18
|
||||||
OpStore %6 %13
|
OpStore %6 %13
|
||||||
%15 = OpLoad %ulong %5
|
%15 = OpLoad %ulong %5
|
||||||
%16 = OpLoad %uint %6
|
%16 = OpLoad %uint %6
|
||||||
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
|
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||||
OpStore %18 %16 Aligned 4
|
OpStore %19 %16 Aligned 4
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
|
|
|
@ -7,6 +7,9 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
|
OpCapability DenormFlushToZero
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||||
%24 = OpExtInstImport "OpenCL.std"
|
%24 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "cvt_s16_s8"
|
OpEntryPoint Kernel %1 "cvt_s16_s8"
|
||||||
|
@ -45,9 +48,7 @@
|
||||||
%32 = OpBitcast %uint %15
|
%32 = OpBitcast %uint %15
|
||||||
%34 = OpUConvert %uchar %32
|
%34 = OpUConvert %uchar %32
|
||||||
%20 = OpCopyObject %uchar %34
|
%20 = OpCopyObject %uchar %34
|
||||||
%35 = OpBitcast %uchar %20
|
%19 = OpSConvert %ushort %20
|
||||||
%37 = OpSConvert %ushort %35
|
|
||||||
%19 = OpCopyObject %ushort %37
|
|
||||||
%14 = OpSConvert %uint %19
|
%14 = OpSConvert %uint %19
|
||||||
OpStore %6 %14
|
OpStore %6 %14
|
||||||
%16 = OpLoad %ulong %5
|
%16 = OpLoad %ulong %5
|
||||||
|
|
|
@ -7,9 +7,13 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
|
OpCapability DenormFlushToZero
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||||
%24 = OpExtInstImport "OpenCL.std"
|
%24 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "cvt_s64_s32"
|
OpEntryPoint Kernel %1 "cvt_s64_s32"
|
||||||
|
OpExecutionMode %1 ContractionOff
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%27 = OpTypeFunction %void %ulong %ulong
|
%27 = OpTypeFunction %void %ulong %ulong
|
||||||
|
@ -40,9 +44,7 @@
|
||||||
%12 = OpCopyObject %uint %18
|
%12 = OpCopyObject %uint %18
|
||||||
OpStore %6 %12
|
OpStore %6 %12
|
||||||
%15 = OpLoad %uint %6
|
%15 = OpLoad %uint %6
|
||||||
%32 = OpBitcast %uint %15
|
%14 = OpSConvert %ulong %15
|
||||||
%33 = OpSConvert %ulong %32
|
|
||||||
%14 = OpCopyObject %ulong %33
|
|
||||||
OpStore %7 %14
|
OpStore %7 %14
|
||||||
%16 = OpLoad %ulong %5
|
%16 = OpLoad %ulong %5
|
||||||
%17 = OpLoad %ulong %7
|
%17 = OpLoad %ulong %7
|
||||||
|
|
|
@ -7,9 +7,13 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
|
OpCapability DenormFlushToZero
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||||
%25 = OpExtInstImport "OpenCL.std"
|
%25 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "cvt_sat_s_u"
|
OpEntryPoint Kernel %1 "cvt_sat_s_u"
|
||||||
|
OpExecutionMode %1 ContractionOff
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%28 = OpTypeFunction %void %ulong %ulong
|
%28 = OpTypeFunction %void %ulong %ulong
|
||||||
|
@ -42,7 +46,7 @@
|
||||||
%15 = OpSatConvertSToU %uint %16
|
%15 = OpSatConvertSToU %uint %16
|
||||||
OpStore %7 %15
|
OpStore %7 %15
|
||||||
%18 = OpLoad %uint %7
|
%18 = OpLoad %uint %7
|
||||||
%17 = OpBitcast %uint %18
|
%17 = OpCopyObject %uint %18
|
||||||
OpStore %8 %17
|
OpStore %8 %17
|
||||||
%19 = OpLoad %ulong %5
|
%19 = OpLoad %ulong %5
|
||||||
%20 = OpLoad %uint %8
|
%20 = OpLoad %uint %8
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::pass;
|
||||||
use crate::ptx;
|
use crate::ptx;
|
||||||
use crate::translate;
|
use crate::translate;
|
||||||
use hip_runtime_sys::hipError_t;
|
use hip_runtime_sys::hipError_t;
|
||||||
|
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
|
||||||
spirv_txt: &'a [u8],
|
spirv_txt: &'a [u8],
|
||||||
spirv_file_name: &'a str,
|
spirv_file_name: &'a str,
|
||||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||||
let mut errors = Vec::new();
|
let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap();
|
||||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
let spirv_module = pass::to_spirv_module(ast)?;
|
||||||
assert!(errors.len() == 0);
|
|
||||||
let spirv_module = translate::to_spirv_module(ast)?;
|
|
||||||
let spv_context =
|
let spv_context =
|
||||||
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
||||||
assert!(spv_context != ptr::null_mut());
|
assert!(spv_context != ptr::null_mut());
|
||||||
|
|
|
@ -7,20 +7,24 @@
|
||||||
OpCapability Int64
|
OpCapability Int64
|
||||||
OpCapability Float16
|
OpCapability Float16
|
||||||
OpCapability Float64
|
OpCapability Float64
|
||||||
%21 = OpExtInstImport "OpenCL.std"
|
OpCapability DenormFlushToZero
|
||||||
|
OpExtension "SPV_KHR_float_controls"
|
||||||
|
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||||
|
%22 = OpExtInstImport "OpenCL.std"
|
||||||
OpMemoryModel Physical64 OpenCL
|
OpMemoryModel Physical64 OpenCL
|
||||||
OpEntryPoint Kernel %1 "popc"
|
OpEntryPoint Kernel %1 "popc"
|
||||||
|
OpExecutionMode %1 ContractionOff
|
||||||
%void = OpTypeVoid
|
%void = OpTypeVoid
|
||||||
%ulong = OpTypeInt 64 0
|
%ulong = OpTypeInt 64 0
|
||||||
%24 = OpTypeFunction %void %ulong %ulong
|
%25 = OpTypeFunction %void %ulong %ulong
|
||||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||||
%uint = OpTypeInt 32 0
|
%uint = OpTypeInt 32 0
|
||||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||||
%1 = OpFunction %void None %24
|
%1 = OpFunction %void None %25
|
||||||
%7 = OpFunctionParameter %ulong
|
%7 = OpFunctionParameter %ulong
|
||||||
%8 = OpFunctionParameter %ulong
|
%8 = OpFunctionParameter %ulong
|
||||||
%19 = OpLabel
|
%20 = OpLabel
|
||||||
%2 = OpVariable %_ptr_Function_ulong Function
|
%2 = OpVariable %_ptr_Function_ulong Function
|
||||||
%3 = OpVariable %_ptr_Function_ulong Function
|
%3 = OpVariable %_ptr_Function_ulong Function
|
||||||
%4 = OpVariable %_ptr_Function_ulong Function
|
%4 = OpVariable %_ptr_Function_ulong Function
|
||||||
|
@ -37,11 +41,12 @@
|
||||||
%11 = OpLoad %uint %17 Aligned 4
|
%11 = OpLoad %uint %17 Aligned 4
|
||||||
OpStore %6 %11
|
OpStore %6 %11
|
||||||
%14 = OpLoad %uint %6
|
%14 = OpLoad %uint %6
|
||||||
%13 = OpBitCount %uint %14
|
%18 = OpBitCount %uint %14
|
||||||
|
%13 = OpCopyObject %uint %18
|
||||||
OpStore %6 %13
|
OpStore %6 %13
|
||||||
%15 = OpLoad %ulong %5
|
%15 = OpLoad %ulong %5
|
||||||
%16 = OpLoad %uint %6
|
%16 = OpLoad %uint %6
|
||||||
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
|
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||||
OpStore %18 %16 Aligned 4
|
OpStore %19 %16 Aligned 4
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Excersise as many features of vector types as possible
|
// Exercise as many features of vector types as possible
|
||||||
|
|
||||||
.version 6.5
|
.version 6.5
|
||||||
.target sm_60
|
.target sm_60
|
||||||
|
|
|
@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
|
||||||
for statement in sorted_statements {
|
for statement in sorted_statements {
|
||||||
match statement {
|
match statement {
|
||||||
Statement::Variable(
|
Statement::Variable(
|
||||||
var
|
var @ ast::Variable {
|
||||||
@
|
|
||||||
ast::Variable {
|
|
||||||
state_space: ast::StateSpace::Shared,
|
state_space: ast::StateSpace::Shared,
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
| Statement::Variable(
|
| Statement::Variable(
|
||||||
var
|
var @ ast::Variable {
|
||||||
@
|
|
||||||
ast::Variable {
|
|
||||||
state_space: ast::StateSpace::Global,
|
state_space: ast::StateSpace::Global,
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
|
@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
details
|
details @ ast::AtomDetails {
|
||||||
@
|
|
||||||
ast::AtomDetails {
|
|
||||||
inner:
|
inner:
|
||||||
ast::AtomInnerDetails::Unsigned {
|
ast::AtomInnerDetails::Unsigned {
|
||||||
op: ast::AtomUIntOp::Inc,
|
op: ast::AtomUIntOp::Inc,
|
||||||
|
@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
details
|
details @ ast::AtomDetails {
|
||||||
@
|
|
||||||
ast::AtomDetails {
|
|
||||||
inner:
|
inner:
|
||||||
ast::AtomInnerDetails::Unsigned {
|
ast::AtomInnerDetails::Unsigned {
|
||||||
op: ast::AtomUIntOp::Dec,
|
op: ast::AtomUIntOp::Dec,
|
||||||
|
@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
Statement::Instruction(ast::Instruction::Atom(
|
Statement::Instruction(ast::Instruction::Atom(
|
||||||
details
|
details @ ast::AtomDetails {
|
||||||
@
|
|
||||||
ast::AtomDetails {
|
|
||||||
inner:
|
inner:
|
||||||
ast::AtomInnerDetails::Float {
|
ast::AtomInnerDetails::Float {
|
||||||
op: ast::AtomFloatOp::Add,
|
op: ast::AtomFloatOp::Add,
|
||||||
|
|
17
ptx_parser/Cargo.toml
Normal file
17
ptx_parser/Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
name = "ptx_parser"
|
||||||
|
version = "0.0.0"
|
||||||
|
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
logos = "0.14"
|
||||||
|
winnow = { version = "0.6.18" }
|
||||||
|
#winnow = { version = "0.6.18", features = ["debug"] }
|
||||||
|
ptx_parser_macros = { path = "../ptx_parser_macros" }
|
||||||
|
thiserror = "1.0"
|
||||||
|
bitflags = "1.2"
|
||||||
|
rustc-hash = "2.0.0"
|
||||||
|
derive_more = { version = "1", features = ["display"] }
|
1695
ptx_parser/src/ast.rs
Normal file
1695
ptx_parser/src/ast.rs
Normal file
File diff suppressed because it is too large
Load diff
69
ptx_parser/src/check_args.py
Normal file
69
ptx_parser/src/check_args.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import os, sys, subprocess
|
||||||
|
|
||||||
|
|
||||||
|
SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
|
||||||
|
TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
|
||||||
|
MULTIVAR = ["", "<1>" ]
|
||||||
|
VECTOR = ["", ".v2" ]
|
||||||
|
|
||||||
|
HEADER = """
|
||||||
|
.version 8.5
|
||||||
|
.target sm_90
|
||||||
|
.address_size 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def directive(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
{0} {4} .b32 variable{2} {1};
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
def entry_arg(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.entry foobar ( {0} {4} .b32 variable{2} {1})
|
||||||
|
{{
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def fn_arg(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.func foobar ( {0} {4} .b32 variable{2} {1})
|
||||||
|
{{
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def fn_body(space, variable, multivar, vector):
|
||||||
|
return """{3}
|
||||||
|
.func foobar ()
|
||||||
|
{{
|
||||||
|
{0} {4} .b32 variable{2} {1};
|
||||||
|
ret;
|
||||||
|
}}
|
||||||
|
""".format(space, variable, multivar, HEADER, vector)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(generator):
|
||||||
|
legal = []
|
||||||
|
for space in SPACE:
|
||||||
|
for init in TYPE_AND_INIT:
|
||||||
|
for multi in MULTIVAR:
|
||||||
|
for vector in VECTOR:
|
||||||
|
ptx = generator(space, init, multi, vector)
|
||||||
|
if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
|
||||||
|
legal.append((space, vector, init, multi))
|
||||||
|
print(generator.__name__)
|
||||||
|
print(legal)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generate(directive)
|
||||||
|
generate(entry_arg)
|
||||||
|
generate(fn_arg)
|
||||||
|
generate(fn_body)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
3269
ptx_parser/src/lib.rs
Normal file
3269
ptx_parser/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
17
ptx_parser_macros/Cargo.toml
Normal file
17
ptx_parser_macros/Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
name = "ptx_parser_macros"
|
||||||
|
version = "0.0.0"
|
||||||
|
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
proc-macro = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" }
|
||||||
|
convert_case = "0.6.0"
|
||||||
|
rustc-hash = "2.0.0"
|
||||||
|
syn = "2.0.67"
|
||||||
|
quote = "1.0"
|
||||||
|
proc-macro2 = "1.0.86"
|
||||||
|
either = "1.13.0"
|
1023
ptx_parser_macros/src/lib.rs
Normal file
1023
ptx_parser_macros/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
13
ptx_parser_macros_impl/Cargo.toml
Normal file
13
ptx_parser_macros_impl/Cargo.toml
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
[package]
|
||||||
|
name = "ptx_parser_macros_impl"
|
||||||
|
version = "0.0.0"
|
||||||
|
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
syn = { version = "2.0.67", features = ["extra-traits", "full"] }
|
||||||
|
quote = "1.0"
|
||||||
|
proc-macro2 = "1.0.86"
|
||||||
|
rustc-hash = "2.0.0"
|
881
ptx_parser_macros_impl/src/lib.rs
Normal file
881
ptx_parser_macros_impl/src/lib.rs
Normal file
|
@ -0,0 +1,881 @@
|
||||||
|
use proc_macro2::TokenStream;
|
||||||
|
use quote::{format_ident, quote, ToTokens};
|
||||||
|
use syn::{
|
||||||
|
braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token,
|
||||||
|
Type, TypeParam, Visibility,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod parser;
|
||||||
|
|
||||||
|
pub struct GenerateInstructionType {
|
||||||
|
pub visibility: Option<Visibility>,
|
||||||
|
pub name: Ident,
|
||||||
|
pub type_parameters: Punctuated<TypeParam, Token![,]>,
|
||||||
|
pub short_parameters: Punctuated<Ident, Token![,]>,
|
||||||
|
pub variants: Punctuated<InstructionVariant, Token![,]>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateInstructionType {
|
||||||
|
pub fn emit_arg_types(&self, tokens: &mut TokenStream) {
|
||||||
|
for v in self.variants.iter() {
|
||||||
|
v.emit_type(&self.visibility, tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_instruction_type(&self, tokens: &mut TokenStream) {
|
||||||
|
let vis = &self.visibility;
|
||||||
|
let type_name = &self.name;
|
||||||
|
let type_parameters = &self.type_parameters;
|
||||||
|
let variants = self.variants.iter().map(|v| v.emit_variant());
|
||||||
|
quote! {
|
||||||
|
#vis enum #type_name<#type_parameters> {
|
||||||
|
#(#variants),*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_visit(&self, tokens: &mut TokenStream) {
|
||||||
|
self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_visit_mut(&self, tokens: &mut TokenStream) {
|
||||||
|
self.emit_visit_impl(
|
||||||
|
VisitKind::RefMut,
|
||||||
|
tokens,
|
||||||
|
InstructionVariant::emit_visit_mut,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn emit_visit_map(&self, tokens: &mut TokenStream) {
|
||||||
|
self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_impl(
|
||||||
|
&self,
|
||||||
|
kind: VisitKind,
|
||||||
|
tokens: &mut TokenStream,
|
||||||
|
mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream),
|
||||||
|
) {
|
||||||
|
let type_name = &self.name;
|
||||||
|
let type_parameters = &self.type_parameters;
|
||||||
|
let short_parameters = &self.short_parameters;
|
||||||
|
let mut inner_tokens = TokenStream::new();
|
||||||
|
for v in self.variants.iter() {
|
||||||
|
fn_(v, type_name, &mut inner_tokens);
|
||||||
|
}
|
||||||
|
let visit_ref = kind.reference();
|
||||||
|
let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
|
||||||
|
let visit_fn = format_ident!("visit{}", kind.fn_suffix());
|
||||||
|
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
|
||||||
|
(
|
||||||
|
quote! { <#type_parameters, To: Operand, Err> },
|
||||||
|
quote! { <#short_parameters, To, Err> },
|
||||||
|
quote! { std::result::Result<#type_name<To>, Err> },
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
quote! { <#type_parameters, Err> },
|
||||||
|
quote! { <#short_parameters, Err> },
|
||||||
|
quote! { std::result::Result<(), Err> },
|
||||||
|
)
|
||||||
|
};
|
||||||
|
quote! {
|
||||||
|
pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
|
||||||
|
Ok(match i {
|
||||||
|
#inner_tokens
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}.to_tokens(tokens);
|
||||||
|
if kind == VisitKind::Map {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum VisitKind {
|
||||||
|
Ref,
|
||||||
|
RefMut,
|
||||||
|
Map,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisitKind {
|
||||||
|
fn fn_suffix(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
VisitKind::Ref => "",
|
||||||
|
VisitKind::RefMut => "_mut",
|
||||||
|
VisitKind::Map => "_map",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn type_suffix(self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
VisitKind::Ref => "",
|
||||||
|
VisitKind::RefMut => "Mut",
|
||||||
|
VisitKind::Map => "Map",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reference(self) -> Option<proc_macro2::TokenStream> {
|
||||||
|
match self {
|
||||||
|
VisitKind::Ref => Some(quote! { & }),
|
||||||
|
VisitKind::RefMut => Some(quote! { &mut }),
|
||||||
|
VisitKind::Map => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for GenerateInstructionType {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let visibility = if !input.peek(Token![enum]) {
|
||||||
|
Some(input.parse::<Visibility>()?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
input.parse::<Token![enum]>()?;
|
||||||
|
let name = input.parse::<Ident>()?;
|
||||||
|
input.parse::<Token![<]>()?;
|
||||||
|
let type_parameters = Punctuated::parse_separated_nonempty(input)?;
|
||||||
|
let short_parameters = type_parameters
|
||||||
|
.iter()
|
||||||
|
.map(|p: &TypeParam| p.ident.clone())
|
||||||
|
.collect();
|
||||||
|
input.parse::<Token![>]>()?;
|
||||||
|
let variants_buffer;
|
||||||
|
braced!(variants_buffer in input);
|
||||||
|
let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?;
|
||||||
|
Ok(Self {
|
||||||
|
visibility,
|
||||||
|
name,
|
||||||
|
type_parameters,
|
||||||
|
short_parameters,
|
||||||
|
variants,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct InstructionVariant {
|
||||||
|
pub name: Ident,
|
||||||
|
pub type_: Option<Option<Expr>>,
|
||||||
|
pub space: Option<Expr>,
|
||||||
|
pub data: Option<Type>,
|
||||||
|
pub arguments: Option<Arguments>,
|
||||||
|
pub visit: Option<Expr>,
|
||||||
|
pub visit_mut: Option<Expr>,
|
||||||
|
pub map: Option<Expr>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InstructionVariant {
|
||||||
|
fn args_name(&self) -> Ident {
|
||||||
|
format_ident!("{}Args", self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_variant(&self) -> TokenStream {
|
||||||
|
let name = &self.name;
|
||||||
|
let data = match &self.data {
|
||||||
|
None => {
|
||||||
|
quote! {}
|
||||||
|
}
|
||||||
|
Some(data_type) => {
|
||||||
|
quote! {
|
||||||
|
data: #data_type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let arguments = match &self.arguments {
|
||||||
|
None => {
|
||||||
|
quote! {}
|
||||||
|
}
|
||||||
|
Some(args) => {
|
||||||
|
let args_name = self.args_name();
|
||||||
|
match &args {
|
||||||
|
Arguments::Def(InstructionArguments { generic: None, .. }) => {
|
||||||
|
quote! {
|
||||||
|
arguments: #args_name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Arguments::Def(InstructionArguments {
|
||||||
|
generic: Some(generics),
|
||||||
|
..
|
||||||
|
}) => {
|
||||||
|
quote! {
|
||||||
|
arguments: #args_name <#generics>,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Arguments::Decl(type_) => quote! {
|
||||||
|
arguments: #type_,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
quote! {
|
||||||
|
#name { #data #arguments }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||||
|
self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||||
|
self.emit_visit_impl(
|
||||||
|
&self.visit_mut,
|
||||||
|
enum_,
|
||||||
|
tokens,
|
||||||
|
InstructionArguments::emit_visit_mut,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_impl(
|
||||||
|
&self,
|
||||||
|
visit_fn: &Option<Expr>,
|
||||||
|
enum_: &Ident,
|
||||||
|
tokens: &mut TokenStream,
|
||||||
|
mut fn_: impl FnMut(&InstructionArguments, &Option<Option<Expr>>, &Option<Expr>) -> TokenStream,
|
||||||
|
) {
|
||||||
|
let name = &self.name;
|
||||||
|
let arguments = match &self.arguments {
|
||||||
|
None => {
|
||||||
|
quote! {
|
||||||
|
#enum_ :: #name { .. } => { }
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Some(Arguments::Decl(_)) => {
|
||||||
|
quote! {
|
||||||
|
#enum_ :: #name { data, arguments } => { #visit_fn }
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Some(Arguments::Def(args)) => args,
|
||||||
|
};
|
||||||
|
let data = &self.data.as_ref().map(|_| quote! { data,});
|
||||||
|
let arg_calls = fn_(arguments, &self.type_, &self.space);
|
||||||
|
quote! {
|
||||||
|
#enum_ :: #name { #data arguments } => {
|
||||||
|
#arg_calls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||||
|
let name = &self.name;
|
||||||
|
let data = &self.data.as_ref().map(|_| quote! { data,});
|
||||||
|
let arguments = match self.arguments {
|
||||||
|
None => None,
|
||||||
|
Some(Arguments::Decl(_)) => {
|
||||||
|
let map = self.map.as_ref().unwrap();
|
||||||
|
quote! {
|
||||||
|
#enum_ :: #name { #data arguments } => {
|
||||||
|
#map
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Some(Arguments::Def(ref def)) => Some(def),
|
||||||
|
};
|
||||||
|
let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,});
|
||||||
|
let mut arg_calls = None;
|
||||||
|
let arguments_init = arguments.as_ref().map(|arguments| {
|
||||||
|
let arg_type = self.args_name();
|
||||||
|
arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space));
|
||||||
|
let arg_names = arguments.fields.iter().map(|arg| &arg.name);
|
||||||
|
quote! {
|
||||||
|
arguments: #arg_type { #(#arg_names),* }
|
||||||
|
}
|
||||||
|
});
|
||||||
|
quote! {
|
||||||
|
#enum_ :: #name { #data #arguments_ident } => {
|
||||||
|
#arg_calls
|
||||||
|
#enum_ :: #name { #data #arguments_init }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_type(&self, vis: &Option<Visibility>, tokens: &mut TokenStream) {
|
||||||
|
let arguments = match self.arguments {
|
||||||
|
Some(Arguments::Def(ref a)) => a,
|
||||||
|
Some(Arguments::Decl(_)) => return,
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
let name = self.args_name();
|
||||||
|
let type_parameters = if arguments.generic.is_some() {
|
||||||
|
Some(quote! { <T> })
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let fields = arguments.fields.iter().map(|f| f.emit_field(vis));
|
||||||
|
quote! {
|
||||||
|
#vis struct #name #type_parameters {
|
||||||
|
#(#fields),*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.to_tokens(tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for InstructionVariant {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let name = input.parse::<Ident>()?;
|
||||||
|
let properties_buffer;
|
||||||
|
braced!(properties_buffer in input);
|
||||||
|
let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?;
|
||||||
|
let mut type_ = None;
|
||||||
|
let mut space = None;
|
||||||
|
let mut data = None;
|
||||||
|
let mut arguments = None;
|
||||||
|
let mut visit = None;
|
||||||
|
let mut visit_mut = None;
|
||||||
|
let mut map = None;
|
||||||
|
for property in properties {
|
||||||
|
match property {
|
||||||
|
VariantProperty::Type(t) => type_ = Some(t),
|
||||||
|
VariantProperty::Space(s) => space = Some(s),
|
||||||
|
VariantProperty::Data(d) => data = Some(d),
|
||||||
|
VariantProperty::Arguments(a) => arguments = Some(a),
|
||||||
|
VariantProperty::Visit(e) => visit = Some(e),
|
||||||
|
VariantProperty::VisitMut(e) => visit_mut = Some(e),
|
||||||
|
VariantProperty::Map(e) => map = Some(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
type_,
|
||||||
|
space,
|
||||||
|
data,
|
||||||
|
arguments,
|
||||||
|
visit,
|
||||||
|
visit_mut,
|
||||||
|
map,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum VariantProperty {
|
||||||
|
Type(Option<Expr>),
|
||||||
|
Space(Expr),
|
||||||
|
Data(Type),
|
||||||
|
Arguments(Arguments),
|
||||||
|
Visit(Expr),
|
||||||
|
VisitMut(Expr),
|
||||||
|
Map(Expr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VariantProperty {
|
||||||
|
pub fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
Ok(if lookahead.peek(Token![type]) {
|
||||||
|
input.parse::<Token![type]>()?;
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::Type(if input.peek(Token![!]) {
|
||||||
|
input.parse::<Token![!]>()?;
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(input.parse::<Expr>()?)
|
||||||
|
})
|
||||||
|
} else if lookahead.peek(Ident) {
|
||||||
|
let key = input.parse::<Ident>()?;
|
||||||
|
match &*key.to_string() {
|
||||||
|
"data" => {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::Data(input.parse::<Type>()?)
|
||||||
|
}
|
||||||
|
"space" => {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::Space(input.parse::<Expr>()?)
|
||||||
|
}
|
||||||
|
"arguments" => {
|
||||||
|
let generics = if input.peek(Token![<]) {
|
||||||
|
input.parse::<Token![<]>()?;
|
||||||
|
let gen_params =
|
||||||
|
Punctuated::<PathSegment, syn::token::PathSep>::parse_separated_nonempty(input)?;
|
||||||
|
input.parse::<Token![>]>()?;
|
||||||
|
Some(gen_params)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
if input.peek(token::Brace) {
|
||||||
|
let fields;
|
||||||
|
braced!(fields in input);
|
||||||
|
VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse(
|
||||||
|
generics, &fields,
|
||||||
|
)?))
|
||||||
|
} else {
|
||||||
|
VariantProperty::Arguments(Arguments::Decl(input.parse::<Type>()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"visit" => {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::Visit(input.parse::<Expr>()?)
|
||||||
|
}
|
||||||
|
"visit_mut" => {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::VisitMut(input.parse::<Expr>()?)
|
||||||
|
}
|
||||||
|
"map" => {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
VariantProperty::Map(input.parse::<Expr>()?)
|
||||||
|
}
|
||||||
|
x => {
|
||||||
|
return Err(syn::Error::new(
|
||||||
|
key.span(),
|
||||||
|
format!(
|
||||||
|
"Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.",
|
||||||
|
x
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Arguments {
|
||||||
|
Decl(Type),
|
||||||
|
Def(InstructionArguments),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct InstructionArguments {
|
||||||
|
pub generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
|
||||||
|
pub fields: Punctuated<ArgumentField, Token![,]>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InstructionArguments {
|
||||||
|
pub fn parse(
|
||||||
|
generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
|
||||||
|
input: syn::parse::ParseStream,
|
||||||
|
) -> syn::Result<Self> {
|
||||||
|
let fields = Punctuated::<ArgumentField, Token![,]>::parse_terminated_with(
|
||||||
|
input,
|
||||||
|
ArgumentField::parse,
|
||||||
|
)?;
|
||||||
|
Ok(Self { generic, fields })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_mut(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_map(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_impl(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
mut fn_: impl FnMut(&ArgumentField, &Option<Option<Expr>>, &Option<Expr>, bool) -> TokenStream,
|
||||||
|
) -> TokenStream {
|
||||||
|
let is_ident = if let Some(ref generic) = self.generic {
|
||||||
|
generic.len() > 1
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
let field_calls = self
|
||||||
|
.fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| fn_(f, parent_type, parent_space, is_ident));
|
||||||
|
quote! {
|
||||||
|
#(#field_calls)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ArgumentField {
|
||||||
|
pub name: Ident,
|
||||||
|
pub is_dst: bool,
|
||||||
|
pub repr: Type,
|
||||||
|
pub space: Option<Expr>,
|
||||||
|
pub type_: Option<Expr>,
|
||||||
|
pub relaxed_type_check: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ArgumentField {
|
||||||
|
fn parse_block(
|
||||||
|
input: syn::parse::ParseStream,
|
||||||
|
) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>, bool)> {
|
||||||
|
let content;
|
||||||
|
braced!(content in input);
|
||||||
|
let all_fields =
|
||||||
|
Punctuated::<ExprOrPath, Token![,]>::parse_terminated_with(&content, |content| {
|
||||||
|
let lookahead = content.lookahead1();
|
||||||
|
Ok(if lookahead.peek(Token![type]) {
|
||||||
|
content.parse::<Token![type]>()?;
|
||||||
|
content.parse::<Token![:]>()?;
|
||||||
|
ExprOrPath::Type(content.parse::<Expr>()?)
|
||||||
|
} else if lookahead.peek(Ident) {
|
||||||
|
let name_ident = content.parse::<Ident>()?;
|
||||||
|
content.parse::<Token![:]>()?;
|
||||||
|
match &*name_ident.to_string() {
|
||||||
|
"relaxed_type_check" => {
|
||||||
|
ExprOrPath::RelaxedTypeCheck(content.parse::<LitBool>()?.value)
|
||||||
|
}
|
||||||
|
"repr" => ExprOrPath::Repr(content.parse::<Type>()?),
|
||||||
|
"space" => ExprOrPath::Space(content.parse::<Expr>()?),
|
||||||
|
"dst" => {
|
||||||
|
let ident = content.parse::<LitBool>()?;
|
||||||
|
ExprOrPath::Dst(ident.value)
|
||||||
|
}
|
||||||
|
name => {
|
||||||
|
return Err(syn::Error::new(
|
||||||
|
name_ident.span(),
|
||||||
|
format!("Unexpected key `{}`, expected `repr` or `space", name),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
let mut repr = None;
|
||||||
|
let mut type_ = None;
|
||||||
|
let mut space = None;
|
||||||
|
let mut is_dst = None;
|
||||||
|
let mut relaxed_type_check = false;
|
||||||
|
for exp_or_path in all_fields {
|
||||||
|
match exp_or_path {
|
||||||
|
ExprOrPath::Repr(r) => repr = Some(r),
|
||||||
|
ExprOrPath::Type(t) => type_ = Some(t),
|
||||||
|
ExprOrPath::Space(s) => space = Some(s),
|
||||||
|
ExprOrPath::Dst(x) => is_dst = Some(x),
|
||||||
|
ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
|
||||||
|
input.parse::<Type>()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
is_ident: bool,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.emit_visit_impl(parent_type, parent_space, is_ident, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_mut(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
is_ident: bool,
|
||||||
|
) -> TokenStream {
|
||||||
|
self.emit_visit_impl(parent_type, parent_space, is_ident, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_impl(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
is_ident: bool,
|
||||||
|
is_mut: bool,
|
||||||
|
) -> TokenStream {
|
||||||
|
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
|
||||||
|
(Some(type_), _) => (false, Some(type_)),
|
||||||
|
(None, None) => panic!("No type set"),
|
||||||
|
(None, Some(None)) => (true, None),
|
||||||
|
(None, Some(Some(type_))) => (false, Some(type_)),
|
||||||
|
};
|
||||||
|
let space = self
|
||||||
|
.space
|
||||||
|
.as_ref()
|
||||||
|
.or(parent_space.as_ref())
|
||||||
|
.map(|space| quote! { #space })
|
||||||
|
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
||||||
|
let is_dst = self.is_dst;
|
||||||
|
let relaxed_type_check = self.relaxed_type_check;
|
||||||
|
let name = &self.name;
|
||||||
|
let type_space = if is_typeless {
|
||||||
|
quote! {
|
||||||
|
let type_space = None;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
let type_ = #type_;
|
||||||
|
let space = #space;
|
||||||
|
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if is_ident {
|
||||||
|
if is_mut {
|
||||||
|
quote! {
|
||||||
|
{
|
||||||
|
#type_space
|
||||||
|
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
{
|
||||||
|
#type_space
|
||||||
|
visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let (operand_fn, arguments_name) = if is_mut {
|
||||||
|
(
|
||||||
|
quote! {
|
||||||
|
VisitOperand::visit_mut
|
||||||
|
},
|
||||||
|
quote! {
|
||||||
|
&mut arguments.#name
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
quote! {
|
||||||
|
VisitOperand::visit
|
||||||
|
},
|
||||||
|
quote! {
|
||||||
|
& arguments.#name
|
||||||
|
},
|
||||||
|
)
|
||||||
|
};
|
||||||
|
quote! {{
|
||||||
|
#type_space
|
||||||
|
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?;
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_visit_map(
|
||||||
|
&self,
|
||||||
|
parent_type: &Option<Option<Expr>>,
|
||||||
|
parent_space: &Option<Expr>,
|
||||||
|
is_ident: bool,
|
||||||
|
) -> TokenStream {
|
||||||
|
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
|
||||||
|
(Some(type_), _) => (false, Some(type_)),
|
||||||
|
(None, None) => panic!("No type set"),
|
||||||
|
(None, Some(None)) => (true, None),
|
||||||
|
(None, Some(Some(type_))) => (false, Some(type_)),
|
||||||
|
};
|
||||||
|
let space = self
|
||||||
|
.space
|
||||||
|
.as_ref()
|
||||||
|
.or(parent_space.as_ref())
|
||||||
|
.map(|space| quote! { #space })
|
||||||
|
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
||||||
|
let is_dst = self.is_dst;
|
||||||
|
let relaxed_type_check = self.relaxed_type_check;
|
||||||
|
let name = &self.name;
|
||||||
|
let type_space = if is_typeless {
|
||||||
|
quote! {
|
||||||
|
let type_space = None;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
let type_ = #type_;
|
||||||
|
let space = #space;
|
||||||
|
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let map_call = if is_ident {
|
||||||
|
quote! {
|
||||||
|
visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)?
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
quote! {
|
||||||
|
let #name = {
|
||||||
|
#type_space
|
||||||
|
#map_call
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_dst(name: &Ident) -> syn::Result<bool> {
|
||||||
|
if name.to_string().starts_with("dst") {
|
||||||
|
Ok(true)
|
||||||
|
} else if name.to_string().starts_with("src") {
|
||||||
|
Ok(false)
|
||||||
|
} else {
|
||||||
|
return Err(syn::Error::new(
|
||||||
|
name.span(),
|
||||||
|
format!(
|
||||||
|
"Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`",
|
||||||
|
name
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn emit_field(&self, vis: &Option<Visibility>) -> TokenStream {
|
||||||
|
let name = &self.name;
|
||||||
|
let type_ = &self.repr;
|
||||||
|
quote! {
|
||||||
|
#vis #name: #type_
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for ArgumentField {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let name = input.parse::<Ident>()?;
|
||||||
|
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) {
|
||||||
|
Self::parse_block(input)?
|
||||||
|
} else if lookahead.peek(syn::Ident) {
|
||||||
|
(Self::parse_basic(input)?, None, None, None, false)
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
};
|
||||||
|
let is_dst = match is_dst {
|
||||||
|
Some(x) => x,
|
||||||
|
None => Self::is_dst(&name)?,
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
is_dst,
|
||||||
|
repr,
|
||||||
|
type_,
|
||||||
|
space,
|
||||||
|
relaxed_type_check
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ExprOrPath {
|
||||||
|
Repr(Type),
|
||||||
|
Type(Expr),
|
||||||
|
Space(Expr),
|
||||||
|
Dst(bool),
|
||||||
|
RelaxedTypeCheck(bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use proc_macro2::Span;
|
||||||
|
use quote::{quote, ToTokens};
|
||||||
|
|
||||||
|
fn to_string(x: impl ToTokens) -> String {
|
||||||
|
quote! { #x }.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_argument_field_basic() {
|
||||||
|
let input = quote! {
|
||||||
|
dst: P::Operand
|
||||||
|
};
|
||||||
|
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||||
|
assert_eq!("dst", arg.name.to_string());
|
||||||
|
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||||
|
assert!(matches!(arg.type_, None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_argument_field_block() {
|
||||||
|
let input = quote! {
|
||||||
|
dst: {
|
||||||
|
type: ScalarType::U32,
|
||||||
|
space: StateSpace::Global,
|
||||||
|
repr: P::Operand,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||||
|
assert_eq!("dst", arg.name.to_string());
|
||||||
|
assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap()));
|
||||||
|
assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap()));
|
||||||
|
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_argument_field_block_untyped() {
|
||||||
|
let input = quote! {
|
||||||
|
dst: {
|
||||||
|
repr: P::Operand,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||||
|
assert_eq!("dst", arg.name.to_string());
|
||||||
|
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||||
|
assert!(matches!(arg.type_, None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_variant_complex() {
|
||||||
|
let input = quote! {
|
||||||
|
Ld {
|
||||||
|
type: ScalarType::U32,
|
||||||
|
space: StateSpace::Global,
|
||||||
|
data: LdDetails,
|
||||||
|
arguments<P>: {
|
||||||
|
dst: {
|
||||||
|
repr: P::Operand,
|
||||||
|
type: ScalarType::U32,
|
||||||
|
space: StateSpace::Shared,
|
||||||
|
},
|
||||||
|
src: P::Operand,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
|
||||||
|
assert_eq!("Ld", variant.name.to_string());
|
||||||
|
assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
|
||||||
|
assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
|
||||||
|
assert_eq!("LdDetails", to_string(variant.data.unwrap()));
|
||||||
|
let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
|
||||||
|
a
|
||||||
|
} else {
|
||||||
|
panic!()
|
||||||
|
};
|
||||||
|
assert_eq!("P", to_string(arguments.generic));
|
||||||
|
let mut fields = arguments.fields.into_iter();
|
||||||
|
let dst = fields.next().unwrap();
|
||||||
|
assert_eq!("P :: Operand", to_string(dst.repr));
|
||||||
|
assert_eq!("ScalarType :: U32", to_string(dst.type_));
|
||||||
|
assert_eq!("StateSpace :: Shared", to_string(dst.space));
|
||||||
|
let src = fields.next().unwrap();
|
||||||
|
assert_eq!("P :: Operand", to_string(src.repr));
|
||||||
|
assert!(matches!(src.type_, None));
|
||||||
|
assert!(matches!(src.space, None));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn visit_variant_empty() {
|
||||||
|
let input = quote! {
|
||||||
|
Ret {
|
||||||
|
data: RetData
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
|
||||||
|
let mut output = TokenStream::new();
|
||||||
|
variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output);
|
||||||
|
assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }");
|
||||||
|
}
|
||||||
|
}
|
844
ptx_parser_macros_impl/src/parser.rs
Normal file
844
ptx_parser_macros_impl/src/parser.rs
Normal file
|
@ -0,0 +1,844 @@
|
||||||
|
use proc_macro2::Span;
|
||||||
|
use proc_macro2::TokenStream;
|
||||||
|
use quote::quote;
|
||||||
|
use quote::ToTokens;
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::fmt::Write;
|
||||||
|
use syn::bracketed;
|
||||||
|
use syn::parse::Peek;
|
||||||
|
use syn::punctuated::Punctuated;
|
||||||
|
use syn::spanned::Spanned;
|
||||||
|
use syn::LitInt;
|
||||||
|
use syn::Type;
|
||||||
|
use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token};
|
||||||
|
|
||||||
|
pub struct ParseDefinitions {
|
||||||
|
pub token_type: ItemEnum,
|
||||||
|
pub additional_enums: FxHashMap<Ident, ItemEnum>,
|
||||||
|
pub definitions: Vec<OpcodeDefinition>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for ParseDefinitions {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let token_type = input.parse::<ItemEnum>()?;
|
||||||
|
let mut additional_enums = FxHashMap::default();
|
||||||
|
while input.peek(Token![#]) {
|
||||||
|
let enum_ = input.parse::<ItemEnum>()?;
|
||||||
|
additional_enums.insert(enum_.ident.clone(), enum_);
|
||||||
|
}
|
||||||
|
let mut definitions = Vec::new();
|
||||||
|
while !input.is_empty() {
|
||||||
|
definitions.push(input.parse::<OpcodeDefinition>()?);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
token_type,
|
||||||
|
additional_enums,
|
||||||
|
definitions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpcodeDefinition(pub Patterns, pub Vec<Rule>);
|
||||||
|
|
||||||
|
impl Parse for OpcodeDefinition {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let patterns = input.parse::<Patterns>()?;
|
||||||
|
let mut rules = Vec::new();
|
||||||
|
while Rule::peek(input) {
|
||||||
|
rules.push(input.parse::<Rule>()?);
|
||||||
|
input.parse::<Token![;]>()?;
|
||||||
|
}
|
||||||
|
Ok(Self(patterns, rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>);
|
||||||
|
|
||||||
|
impl Parse for Patterns {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
loop {
|
||||||
|
if !OpcodeDecl::peek(input) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let decl = input.parse::<OpcodeDecl>()?;
|
||||||
|
let code_block = input.parse::<CodeBlock>()?;
|
||||||
|
result.push((decl, code_block))
|
||||||
|
}
|
||||||
|
Ok(Self(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OpcodeDecl(pub Instruction, pub Arguments);
|
||||||
|
|
||||||
|
impl OpcodeDecl {
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
Instruction::peek(input) && !input.peek2(Token![=])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for OpcodeDecl {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
Ok(Self(
|
||||||
|
input.parse::<Instruction>()?,
|
||||||
|
input.parse::<Arguments>()?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CodeBlock {
|
||||||
|
pub special: bool,
|
||||||
|
pub code: proc_macro2::Group,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for CodeBlock {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
let (special, code) = if lookahead.peek(Token![<]) {
|
||||||
|
input.parse::<Token![<]>()?;
|
||||||
|
input.parse::<Token![=]>()?;
|
||||||
|
//input.parse::<Token![>]>()?;
|
||||||
|
(true, input.parse::<proc_macro2::Group>()?)
|
||||||
|
} else if lookahead.peek(Token![=]) {
|
||||||
|
input.parse::<Token![=]>()?;
|
||||||
|
input.parse::<Token![>]>()?;
|
||||||
|
(false, input.parse::<proc_macro2::Group>()?)
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
};
|
||||||
|
Ok(Self { special, code })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Rule {
|
||||||
|
pub modifier: Option<DotModifier>,
|
||||||
|
pub type_: Option<Type>,
|
||||||
|
pub alternatives: Vec<DotModifier>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Rule {
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
DotModifier::peek(input)
|
||||||
|
|| (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>]))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result<Vec<DotModifier>> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
Self::parse_with_alternative(input, &mut result)?;
|
||||||
|
loop {
|
||||||
|
if !input.peek(Token![,]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
input.parse::<Token![,]>()?;
|
||||||
|
Self::parse_with_alternative(input, &mut result)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_with_alternative(
|
||||||
|
input: &syn::parse::ParseBuffer,
|
||||||
|
result: &mut Vec<DotModifier>,
|
||||||
|
) -> Result<(), syn::Error> {
|
||||||
|
input.parse::<Token![.]>()?;
|
||||||
|
let part1 = input.parse::<IdentLike>()?;
|
||||||
|
if input.peek(token::Brace) {
|
||||||
|
result.push(DotModifier {
|
||||||
|
part1: part1.clone(),
|
||||||
|
part2: None,
|
||||||
|
});
|
||||||
|
let suffix_content;
|
||||||
|
braced!(suffix_content in input);
|
||||||
|
let suffixes = Punctuated::<IdentOrTypeSuffix, Token![,]>::parse_separated_nonempty(
|
||||||
|
&suffix_content,
|
||||||
|
)?;
|
||||||
|
for part2 in suffixes {
|
||||||
|
result.push(DotModifier {
|
||||||
|
part1: part1.clone(),
|
||||||
|
part2: Some(part2),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if IdentOrTypeSuffix::peek(input) {
|
||||||
|
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
|
||||||
|
result.push(DotModifier { part1, part2 });
|
||||||
|
} else {
|
||||||
|
result.push(DotModifier { part1, part2: None });
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||||
|
struct IdentOrTypeSuffix(IdentLike);
|
||||||
|
|
||||||
|
impl IdentOrTypeSuffix {
|
||||||
|
fn span(&self) -> Span {
|
||||||
|
self.0.span()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
input.peek(Token![::])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToTokens for IdentOrTypeSuffix {
|
||||||
|
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||||
|
let ident = &self.0;
|
||||||
|
quote! { :: #ident }.to_tokens(tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for IdentOrTypeSuffix {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
input.parse::<Token![::]>()?;
|
||||||
|
Ok(Self(input.parse::<IdentLike>()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for Rule {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let (modifier, type_) = if DotModifier::peek(input) {
|
||||||
|
let modifier = Some(input.parse::<DotModifier>()?);
|
||||||
|
if input.peek(Token![:]) {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
(modifier, Some(input.parse::<Type>()?))
|
||||||
|
} else {
|
||||||
|
(modifier, None)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(None, Some(input.parse::<Type>()?))
|
||||||
|
};
|
||||||
|
input.parse::<Token![=]>()?;
|
||||||
|
let content;
|
||||||
|
braced!(content in input);
|
||||||
|
let alternatives = Self::parse_alternatives(&content)?;
|
||||||
|
Ok(Self {
|
||||||
|
modifier,
|
||||||
|
type_,
|
||||||
|
alternatives,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Instruction {
|
||||||
|
pub name: Ident,
|
||||||
|
pub modifiers: Vec<MaybeDotModifier>,
|
||||||
|
}
|
||||||
|
impl Instruction {
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
input.peek(Ident)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for Instruction {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let instruction = input.parse::<Ident>()?;
|
||||||
|
let mut modifiers = Vec::new();
|
||||||
|
loop {
|
||||||
|
if !MaybeDotModifier::peek(input) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
modifiers.push(MaybeDotModifier::parse(input)?);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
name: instruction,
|
||||||
|
modifiers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MaybeDotModifier {
|
||||||
|
pub optional: bool,
|
||||||
|
pub modifier: DotModifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeDotModifier {
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
input.peek(token::Brace) || DotModifier::peek(input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for MaybeDotModifier {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
Ok(if input.peek(token::Brace) {
|
||||||
|
let content;
|
||||||
|
braced!(content in input);
|
||||||
|
let modifier = DotModifier::parse(&content)?;
|
||||||
|
Self {
|
||||||
|
modifier,
|
||||||
|
optional: true,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let modifier = DotModifier::parse(input)?;
|
||||||
|
Self {
|
||||||
|
modifier,
|
||||||
|
optional: false,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||||
|
pub struct DotModifier {
|
||||||
|
part1: IdentLike,
|
||||||
|
part2: Option<IdentOrTypeSuffix>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for DotModifier {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, ".")?;
|
||||||
|
self.part1.fmt(f)?;
|
||||||
|
if let Some(ref part2) = self.part2 {
|
||||||
|
write!(f, "::")?;
|
||||||
|
part2.0.fmt(f)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for DotModifier {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
std::fmt::Display::fmt(&self, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DotModifier {
|
||||||
|
pub fn span(&self) -> Span {
|
||||||
|
let part1 = self.part1.span();
|
||||||
|
if let Some(ref part2) = self.part2 {
|
||||||
|
part1.join(part2.span()).unwrap_or(part1)
|
||||||
|
} else {
|
||||||
|
part1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ident(&self) -> Ident {
|
||||||
|
let mut result = String::new();
|
||||||
|
write!(&mut result, "{}", self.part1).unwrap();
|
||||||
|
if let Some(ref part2) = self.part2 {
|
||||||
|
write!(&mut result, "_{}", part2.0).unwrap();
|
||||||
|
} else {
|
||||||
|
match self.part1 {
|
||||||
|
IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
|
||||||
|
IdentLike::Ident(_) | IdentLike::Integer(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ident::new(&result.to_ascii_lowercase(), self.span())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn variant_capitalized(&self) -> Ident {
|
||||||
|
self.capitalized_impl(String::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dot_capitalized(&self) -> Ident {
|
||||||
|
self.capitalized_impl("Dot".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn capitalized_impl(&self, prefix: String) -> Ident {
|
||||||
|
let mut temp = String::new();
|
||||||
|
write!(&mut temp, "{}", &self.part1).unwrap();
|
||||||
|
if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 {
|
||||||
|
write!(&mut temp, "_{}", part2).unwrap();
|
||||||
|
}
|
||||||
|
let mut result = prefix;
|
||||||
|
let mut capitalize = true;
|
||||||
|
for c in temp.chars() {
|
||||||
|
if c == '_' {
|
||||||
|
capitalize = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Special hack to emit `BF16`` instead of `Bf16``
|
||||||
|
let c = if capitalize || c == 'f' && result.ends_with('B') {
|
||||||
|
capitalize = false;
|
||||||
|
c.to_ascii_uppercase()
|
||||||
|
} else {
|
||||||
|
c
|
||||||
|
};
|
||||||
|
result.push(c);
|
||||||
|
}
|
||||||
|
Ident::new(&result, self.span())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokens(&self) -> TokenStream {
|
||||||
|
let part1 = &self.part1;
|
||||||
|
let part2 = &self.part2;
|
||||||
|
match self.part2 {
|
||||||
|
None => quote! { . #part1 },
|
||||||
|
Some(_) => quote! { . #part1 #part2 },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||||
|
input.peek(Token![.])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for DotModifier {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
input.parse::<Token![.]>()?;
|
||||||
|
let part1 = input.parse::<IdentLike>()?;
|
||||||
|
if IdentOrTypeSuffix::peek(input) {
|
||||||
|
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
|
||||||
|
Ok(Self { part1, part2 })
|
||||||
|
} else {
|
||||||
|
Ok(Self { part1, part2: None })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||||
|
enum IdentLike {
|
||||||
|
Type(Token![type]),
|
||||||
|
Const(Token![const]),
|
||||||
|
Ident(Ident),
|
||||||
|
Integer(LitInt),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IdentLike {
|
||||||
|
fn span(&self) -> Span {
|
||||||
|
match self {
|
||||||
|
IdentLike::Type(c) => c.span(),
|
||||||
|
IdentLike::Const(t) => t.span(),
|
||||||
|
IdentLike::Ident(i) => i.span(),
|
||||||
|
IdentLike::Integer(l) => l.span(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for IdentLike {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
IdentLike::Type(_) => f.write_str("type"),
|
||||||
|
IdentLike::Const(_) => f.write_str("const"),
|
||||||
|
IdentLike::Ident(ident) => write!(f, "{}", ident),
|
||||||
|
IdentLike::Integer(integer) => write!(f, "{}", integer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToTokens for IdentLike {
|
||||||
|
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||||
|
match self {
|
||||||
|
IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
|
||||||
|
IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
|
||||||
|
IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
|
||||||
|
IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for IdentLike {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
Ok(if lookahead.peek(Token![const]) {
|
||||||
|
IdentLike::Const(input.parse::<Token![const]>()?)
|
||||||
|
} else if lookahead.peek(Token![type]) {
|
||||||
|
IdentLike::Type(input.parse::<Token![type]>()?)
|
||||||
|
} else if lookahead.peek(Ident) {
|
||||||
|
IdentLike::Ident(input.parse::<Ident>()?)
|
||||||
|
} else if lookahead.peek(LitInt) {
|
||||||
|
IdentLike::Integer(input.parse::<LitInt>()?)
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Arguments decalaration can loook like this:
|
||||||
|
// a{, b}
|
||||||
|
// That's why we don't parse Arguments as Punctuated<Argument, Token![,]>
|
||||||
|
#[derive(PartialEq, Eq)]
|
||||||
|
pub struct Arguments(pub Vec<Argument>);
|
||||||
|
|
||||||
|
impl Parse for Arguments {
|
||||||
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
loop {
|
||||||
|
if input.peek(Token![,]) {
|
||||||
|
input.parse::<Token![,]>()?;
|
||||||
|
}
|
||||||
|
let mut optional = false;
|
||||||
|
let mut can_be_negated = false;
|
||||||
|
let mut pre_pipe = false;
|
||||||
|
let ident;
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
if lookahead.peek(token::Brace) {
|
||||||
|
let content;
|
||||||
|
braced!(content in input);
|
||||||
|
let lookahead = content.lookahead1();
|
||||||
|
if lookahead.peek(Token![!]) {
|
||||||
|
content.parse::<Token![!]>()?;
|
||||||
|
can_be_negated = true;
|
||||||
|
ident = input.parse::<Ident>()?;
|
||||||
|
} else if lookahead.peek(Token![,]) {
|
||||||
|
optional = true;
|
||||||
|
content.parse::<Token![,]>()?;
|
||||||
|
ident = content.parse::<Ident>()?;
|
||||||
|
} else {
|
||||||
|
return Err(lookahead.error());
|
||||||
|
}
|
||||||
|
} else if lookahead.peek(token::Bracket) {
|
||||||
|
let bracketed;
|
||||||
|
bracketed!(bracketed in input);
|
||||||
|
if bracketed.peek(Token![|]) {
|
||||||
|
optional = true;
|
||||||
|
bracketed.parse::<Token![|]>()?;
|
||||||
|
pre_pipe = true;
|
||||||
|
ident = bracketed.parse::<Ident>()?;
|
||||||
|
} else {
|
||||||
|
let mut sub_args = Self::parse(&bracketed)?;
|
||||||
|
sub_args.0.first_mut().unwrap().pre_bracket = true;
|
||||||
|
sub_args.0.last_mut().unwrap().post_bracket = true;
|
||||||
|
if peek_brace_token(input, Token![.]) {
|
||||||
|
let optional_suffix;
|
||||||
|
braced!(optional_suffix in input);
|
||||||
|
optional_suffix.parse::<Token![.]>()?;
|
||||||
|
let unified_ident = optional_suffix.parse::<Ident>()?;
|
||||||
|
if unified_ident.to_string() != "unified" {
|
||||||
|
return Err(syn::Error::new(
|
||||||
|
unified_ident.span(),
|
||||||
|
format!("Exptected `unified`, got `{}`", unified_ident),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
for a in sub_args.0.iter_mut() {
|
||||||
|
a.unified = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.extend(sub_args.0);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else if lookahead.peek(Ident) {
|
||||||
|
ident = input.parse::<Ident>()?;
|
||||||
|
} else if lookahead.peek(Token![|]) {
|
||||||
|
input.parse::<Token![|]>()?;
|
||||||
|
pre_pipe = true;
|
||||||
|
ident = input.parse::<Ident>()?;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
result.push(Argument {
|
||||||
|
optional,
|
||||||
|
pre_pipe,
|
||||||
|
can_be_negated,
|
||||||
|
pre_bracket: false,
|
||||||
|
ident,
|
||||||
|
post_bracket: false,
|
||||||
|
unified: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(Self(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is effectively input.peek(token::Brace) && input.peek2(Token![.])
|
||||||
|
// input.peek2 is supposed to skip over next token, but it skips over whole
|
||||||
|
// braced token group. Not sure if it's a bug
|
||||||
|
fn peek_brace_token<T: Peek>(input: syn::parse::ParseStream, _t: T) -> bool {
|
||||||
|
use syn::token::Token;
|
||||||
|
let cursor = input.cursor();
|
||||||
|
cursor
|
||||||
|
.group(proc_macro2::Delimiter::Brace)
|
||||||
|
.map_or(false, |(content, ..)| T::Token::peek(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq)]
|
||||||
|
pub struct Argument {
|
||||||
|
pub optional: bool,
|
||||||
|
pub pre_bracket: bool,
|
||||||
|
pub pre_pipe: bool,
|
||||||
|
pub can_be_negated: bool,
|
||||||
|
pub ident: Ident,
|
||||||
|
pub post_bracket: bool,
|
||||||
|
pub unified: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{Arguments, DotModifier, MaybeDotModifier};
|
||||||
|
use quote::{quote, ToTokens};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_modifier_complex() {
|
||||||
|
let input = quote! {
|
||||||
|
.level::eviction_priority
|
||||||
|
};
|
||||||
|
let modifier = syn::parse2::<DotModifier>(input).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
". level :: eviction_priority",
|
||||||
|
modifier.tokens().to_string()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_modifier_optional() {
|
||||||
|
let input = quote! {
|
||||||
|
{ .level::eviction_priority }
|
||||||
|
};
|
||||||
|
let maybe_modifider = syn::parse2::<MaybeDotModifier>(input).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
". level :: eviction_priority",
|
||||||
|
maybe_modifider.modifier.tokens().to_string()
|
||||||
|
);
|
||||||
|
assert!(maybe_modifider.optional);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_type_token() {
|
||||||
|
let input = quote! {
|
||||||
|
. type
|
||||||
|
};
|
||||||
|
let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
|
||||||
|
assert_eq!(". type", maybe_modifier.modifier.tokens().to_string());
|
||||||
|
assert!(!maybe_modifier.optional);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn arguments_memory() {
|
||||||
|
let input = quote! {
|
||||||
|
[a], b
|
||||||
|
};
|
||||||
|
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||||
|
let a = &arguments.0[0];
|
||||||
|
assert!(!a.optional);
|
||||||
|
assert_eq!("a", a.ident.to_string());
|
||||||
|
assert!(a.pre_bracket);
|
||||||
|
assert!(!a.pre_pipe);
|
||||||
|
assert!(a.post_bracket);
|
||||||
|
assert!(!a.can_be_negated);
|
||||||
|
let b = &arguments.0[1];
|
||||||
|
assert!(!b.optional);
|
||||||
|
assert_eq!("b", b.ident.to_string());
|
||||||
|
assert!(!b.pre_bracket);
|
||||||
|
assert!(!b.pre_pipe);
|
||||||
|
assert!(!b.post_bracket);
|
||||||
|
assert!(!b.can_be_negated);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn arguments_optional() {
|
||||||
|
let input = quote! {
|
||||||
|
b{, cache_policy}
|
||||||
|
};
|
||||||
|
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||||
|
let b = &arguments.0[0];
|
||||||
|
assert!(!b.optional);
|
||||||
|
assert_eq!("b", b.ident.to_string());
|
||||||
|
assert!(!b.pre_bracket);
|
||||||
|
assert!(!b.pre_pipe);
|
||||||
|
assert!(!b.post_bracket);
|
||||||
|
assert!(!b.can_be_negated);
|
||||||
|
let cache_policy = &arguments.0[1];
|
||||||
|
assert!(cache_policy.optional);
|
||||||
|
assert_eq!("cache_policy", cache_policy.ident.to_string());
|
||||||
|
assert!(!cache_policy.pre_bracket);
|
||||||
|
assert!(!cache_policy.pre_pipe);
|
||||||
|
assert!(!cache_policy.post_bracket);
|
||||||
|
assert!(!cache_policy.can_be_negated);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn arguments_optional_pred() {
|
||||||
|
let input = quote! {
|
||||||
|
p[|q], a
|
||||||
|
};
|
||||||
|
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||||
|
assert_eq!(arguments.0.len(), 3);
|
||||||
|
let p = &arguments.0[0];
|
||||||
|
assert!(!p.optional);
|
||||||
|
assert_eq!("p", p.ident.to_string());
|
||||||
|
assert!(!p.pre_bracket);
|
||||||
|
assert!(!p.pre_pipe);
|
||||||
|
assert!(!p.post_bracket);
|
||||||
|
assert!(!p.can_be_negated);
|
||||||
|
let q = &arguments.0[1];
|
||||||
|
assert!(q.optional);
|
||||||
|
assert_eq!("q", q.ident.to_string());
|
||||||
|
assert!(!q.pre_bracket);
|
||||||
|
assert!(q.pre_pipe);
|
||||||
|
assert!(!q.post_bracket);
|
||||||
|
assert!(!q.can_be_negated);
|
||||||
|
let a = &arguments.0[2];
|
||||||
|
assert!(!a.optional);
|
||||||
|
assert_eq!("a", a.ident.to_string());
|
||||||
|
assert!(!a.pre_bracket);
|
||||||
|
assert!(!a.pre_pipe);
|
||||||
|
assert!(!a.post_bracket);
|
||||||
|
assert!(!a.can_be_negated);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn arguments_optional_with_negate() {
|
||||||
|
let input = quote! {
|
||||||
|
b, {!}c
|
||||||
|
};
|
||||||
|
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||||
|
assert_eq!(arguments.0.len(), 2);
|
||||||
|
let b = &arguments.0[0];
|
||||||
|
assert!(!b.optional);
|
||||||
|
assert_eq!("b", b.ident.to_string());
|
||||||
|
assert!(!b.pre_bracket);
|
||||||
|
assert!(!b.pre_pipe);
|
||||||
|
assert!(!b.post_bracket);
|
||||||
|
assert!(!b.can_be_negated);
|
||||||
|
let c = &arguments.0[1];
|
||||||
|
assert!(!c.optional);
|
||||||
|
assert_eq!("c", c.ident.to_string());
|
||||||
|
assert!(!c.pre_bracket);
|
||||||
|
assert!(!c.pre_pipe);
|
||||||
|
assert!(!c.post_bracket);
|
||||||
|
assert!(c.can_be_negated);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn arguments_tex() {
|
||||||
|
let input = quote! {
|
||||||
|
d[|p], [a{, b}, c], dpdx, dpdy {, e}
|
||||||
|
};
|
||||||
|
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||||
|
assert_eq!(arguments.0.len(), 8);
|
||||||
|
{
|
||||||
|
let d = &arguments.0[0];
|
||||||
|
assert!(!d.optional);
|
||||||
|
assert_eq!("d", d.ident.to_string());
|
||||||
|
assert!(!d.pre_bracket);
|
||||||
|
assert!(!d.pre_pipe);
|
||||||
|
assert!(!d.post_bracket);
|
||||||
|
assert!(!d.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let p = &arguments.0[1];
|
||||||
|
assert!(p.optional);
|
||||||
|
assert_eq!("p", p.ident.to_string());
|
||||||
|
assert!(!p.pre_bracket);
|
||||||
|
assert!(p.pre_pipe);
|
||||||
|
assert!(!p.post_bracket);
|
||||||
|
assert!(!p.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let a = &arguments.0[2];
|
||||||
|
assert!(!a.optional);
|
||||||
|
assert_eq!("a", a.ident.to_string());
|
||||||
|
assert!(a.pre_bracket);
|
||||||
|
assert!(!a.pre_pipe);
|
||||||
|
assert!(!a.post_bracket);
|
||||||
|
assert!(!a.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let b = &arguments.0[3];
|
||||||
|
assert!(b.optional);
|
||||||
|
assert_eq!("b", b.ident.to_string());
|
||||||
|
assert!(!b.pre_bracket);
|
||||||
|
assert!(!b.pre_pipe);
|
||||||
|
assert!(!b.post_bracket);
|
||||||
|
assert!(!b.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let c = &arguments.0[4];
|
||||||
|
assert!(!c.optional);
|
||||||
|
assert_eq!("c", c.ident.to_string());
|
||||||
|
assert!(!c.pre_bracket);
|
||||||
|
assert!(!c.pre_pipe);
|
||||||
|
assert!(c.post_bracket);
|
||||||
|
assert!(!c.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let dpdx = &arguments.0[5];
|
||||||
|
assert!(!dpdx.optional);
|
||||||
|
assert_eq!("dpdx", dpdx.ident.to_string());
|
||||||
|
assert!(!dpdx.pre_bracket);
|
||||||
|
assert!(!dpdx.pre_pipe);
|
||||||
|
assert!(!dpdx.post_bracket);
|
||||||
|
assert!(!dpdx.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let dpdy = &arguments.0[6];
|
||||||
|
assert!(!dpdy.optional);
|
||||||
|
assert_eq!("dpdy", dpdy.ident.to_string());
|
||||||
|
assert!(!dpdy.pre_bracket);
|
||||||
|
assert!(!dpdy.pre_pipe);
|
||||||
|
assert!(!dpdy.post_bracket);
|
||||||
|
assert!(!dpdy.can_be_negated);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let e = &arguments.0[7];
|
||||||
|
assert!(e.optional);
|
||||||
|
assert_eq!("e", e.ident.to_string());
|
||||||
|
assert!(!e.pre_bracket);
|
||||||
|
assert!(!e.pre_pipe);
|
||||||
|
assert!(!e.post_bracket);
|
||||||
|
assert!(!e.can_be_negated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rule_multi() {
|
||||||
|
let input = quote! {
|
||||||
|
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }
|
||||||
|
};
|
||||||
|
let rule = syn::parse2::<super::Rule>(input).unwrap();
|
||||||
|
assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string());
|
||||||
|
assert_eq!(
|
||||||
|
"StateSpace",
|
||||||
|
rule.type_.unwrap().to_token_stream().to_string()
|
||||||
|
);
|
||||||
|
let alts = rule
|
||||||
|
.alternatives
|
||||||
|
.iter()
|
||||||
|
.map(|m| m.tokens().to_string())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(
|
||||||
|
vec![
|
||||||
|
". global",
|
||||||
|
". local",
|
||||||
|
". param",
|
||||||
|
". param :: func",
|
||||||
|
". shared",
|
||||||
|
". shared :: cta",
|
||||||
|
". shared :: cluster"
|
||||||
|
],
|
||||||
|
alts
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rule_multi2() {
|
||||||
|
let input = quote! {
|
||||||
|
.cop: StCacheOperator = { .wb, .cg, .cs, .wt }
|
||||||
|
};
|
||||||
|
let rule = syn::parse2::<super::Rule>(input).unwrap();
|
||||||
|
assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string());
|
||||||
|
assert_eq!(
|
||||||
|
"StCacheOperator",
|
||||||
|
rule.type_.unwrap().to_token_stream().to_string()
|
||||||
|
);
|
||||||
|
let alts = rule
|
||||||
|
.alternatives
|
||||||
|
.iter()
|
||||||
|
.map(|m| m.tokens().to_string())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn args_unified() {
|
||||||
|
let input = quote! {
|
||||||
|
d, [a]{.unified}{, cache_policy}
|
||||||
|
};
|
||||||
|
let args = syn::parse2::<super::Arguments>(input).unwrap();
|
||||||
|
let a = &args.0[1];
|
||||||
|
assert!(!a.optional);
|
||||||
|
assert_eq!("a", a.ident.to_string());
|
||||||
|
assert!(a.pre_bracket);
|
||||||
|
assert!(!a.pre_pipe);
|
||||||
|
assert!(a.post_bracket);
|
||||||
|
assert!(!a.can_be_negated);
|
||||||
|
assert!(a.unified);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn special_block() {
|
||||||
|
let input = quote! {
|
||||||
|
bra <= { bra(stream) }
|
||||||
|
};
|
||||||
|
syn::parse2::<super::OpcodeDefinition>(input).unwrap();
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue