mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
PTX parser rewrite (#267)
Replaces traditional LALRPOP-based parser with winnow-based parser to handle out-of-order instruction modifer. Generate instruction type and instruction visitor from a macro instead of writing by hand. Add separate compilation path using the new parser that only works in tests for now
This commit is contained in:
parent
872054ae40
commit
193eb29be8
34 changed files with 14776 additions and 55 deletions
|
@ -1,5 +1,7 @@
|
|||
[workspace]
|
||||
|
||||
resolver = "2"
|
||||
|
||||
members = [
|
||||
"cuda_base",
|
||||
"cuda_types",
|
||||
|
@ -15,6 +17,9 @@ members = [
|
|||
"zluda_redirect",
|
||||
"zluda_ml",
|
||||
"ptx",
|
||||
"ptx_parser",
|
||||
"ptx_parser_macros",
|
||||
"ptx_parser_macros_impl",
|
||||
]
|
||||
|
||||
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]
|
||||
|
|
|
@ -7,7 +7,7 @@ edition = "2018"
|
|||
[lib]
|
||||
|
||||
[dependencies]
|
||||
lalrpop-util = "0.19"
|
||||
ptx_parser = { path = "../ptx_parser" }
|
||||
regex = "1"
|
||||
rspirv = "0.7"
|
||||
spirv_headers = "1.5"
|
||||
|
@ -17,8 +17,12 @@ bit-vec = "0.6"
|
|||
half ="1.6"
|
||||
bitflags = "1.2"
|
||||
|
||||
[dependencies.lalrpop-util]
|
||||
version = "0.19.12"
|
||||
features = ["lexer"]
|
||||
|
||||
[build-dependencies.lalrpop]
|
||||
version = "0.19"
|
||||
version = "0.19.12"
|
||||
features = ["lexer"]
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -16,6 +16,8 @@ pub enum PtxError {
|
|||
source: ParseFloatError,
|
||||
},
|
||||
#[error("")]
|
||||
Unsupported32Bit,
|
||||
#[error("")]
|
||||
SyntaxError,
|
||||
#[error("")]
|
||||
NonF32Ftz,
|
||||
|
@ -32,15 +34,9 @@ pub enum PtxError {
|
|||
#[error("")]
|
||||
NonExternPointer,
|
||||
#[error("{start}:{end}")]
|
||||
UnrecognizedStatement {
|
||||
start: usize,
|
||||
end: usize,
|
||||
},
|
||||
UnrecognizedStatement { start: usize, end: usize },
|
||||
#[error("{start}:{end}")]
|
||||
UnrecognizedDirective {
|
||||
start: usize,
|
||||
end: usize,
|
||||
},
|
||||
UnrecognizedDirective { start: usize, end: usize },
|
||||
}
|
||||
|
||||
// For some weird reson this is illegal:
|
||||
|
@ -576,11 +572,15 @@ impl CvtDetails {
|
|||
if saturate {
|
||||
if src.kind() == ScalarKind::Signed {
|
||||
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
err.push(ParseError::User {
|
||||
error: PtxError::SyntaxError,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if dst == src || dst.size_of() >= src.size_of() {
|
||||
err.push(ParseError::from(PtxError::SyntaxError));
|
||||
err.push(ParseError::User {
|
||||
error: PtxError::SyntaxError,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -596,7 +596,9 @@ impl CvtDetails {
|
|||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && dst != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||
error: PtxError::NonF32Ftz,
|
||||
}));
|
||||
}
|
||||
CvtDetails::FloatFromInt(CvtDesc {
|
||||
dst,
|
||||
|
@ -616,7 +618,9 @@ impl CvtDetails {
|
|||
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
|
||||
) -> Self {
|
||||
if flush_to_zero && src != ScalarType::F32 {
|
||||
err.push(ParseError::from(PtxError::NonF32Ftz));
|
||||
err.push(ParseError::from(lalrpop_util::ParseError::User {
|
||||
error: PtxError::NonF32Ftz,
|
||||
}));
|
||||
}
|
||||
CvtDetails::IntFromFloat(CvtDesc {
|
||||
dst,
|
||||
|
|
|
@ -24,6 +24,7 @@ lalrpop_mod!(
|
|||
);
|
||||
|
||||
pub mod ast;
|
||||
pub(crate) mod pass;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
mod translate;
|
||||
|
|
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
299
ptx/src/pass/convert_dynamic_shared_memory_usage.rs
Normal file
|
@ -0,0 +1,299 @@
|
|||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use super::*;
|
||||
|
||||
/*
|
||||
PTX represents dynamically allocated shared local memory as
|
||||
.extern .shared .b32 shared_mem[];
|
||||
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
|
||||
And in AMD compilation
|
||||
This pass looks for all uses of .extern .shared and converts them to
|
||||
an additional method argument
|
||||
The question is how this artificial argument should be expressed. There are
|
||||
several options:
|
||||
* Straight conversion:
|
||||
.shared .b32 shared_mem[]
|
||||
* Introduce .param_shared statespace:
|
||||
.param_shared .b32 shared_mem
|
||||
or
|
||||
.param_shared .b32 shared_mem[]
|
||||
* Introduce .shared_ptr <SCALAR> type:
|
||||
.param .shared_ptr .b32 shared_mem
|
||||
* Reuse .ptr hint:
|
||||
.param .u64 .ptr shared_mem
|
||||
This is the most tempting, but also the most nonsensical, .ptr is just a
|
||||
hint, which has no semantical meaning (and the output of our
|
||||
transformation has a semantical meaning - we emit additional
|
||||
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
|
||||
*/
|
||||
pub(super) fn run<'input>(
|
||||
module: Vec<Directive<'input>>,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
) -> Result<Vec<Directive<'input>>, TranslateError> {
|
||||
let mut globals_shared = HashMap::new();
|
||||
for dir in module.iter() {
|
||||
match dir {
|
||||
Directive::Variable(
|
||||
_,
|
||||
ast::Variable {
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name,
|
||||
v_type,
|
||||
..
|
||||
},
|
||||
) => {
|
||||
globals_shared.insert(*name, v_type.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if globals_shared.len() == 0 {
|
||||
return Ok(module);
|
||||
}
|
||||
let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
|
||||
let module = module
|
||||
.into_iter()
|
||||
.map(|directive| match directive {
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}) => {
|
||||
let call_key = (*func_decl).borrow().name;
|
||||
let statements = statements
|
||||
.into_iter()
|
||||
.map(|statement| {
|
||||
statement.visit_map(
|
||||
&mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
|
||||
if let Some(_) = globals_shared.get(&id) {
|
||||
methods_to_directly_used_shared_globals
|
||||
.entry(call_key)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(id);
|
||||
}
|
||||
Ok::<_, TranslateError>(id)
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok::<_, TranslateError>(Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}))
|
||||
}
|
||||
directive => Ok(directive),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
|
||||
// make sure it gets propagated to `fn1` and `kernel`
|
||||
let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
|
||||
methods_to_directly_used_shared_globals,
|
||||
kernels_methods_call_map,
|
||||
);
|
||||
// now visit every method declaration and inject those additional arguments
|
||||
let mut directives = Vec::with_capacity(module.len());
|
||||
for directive in module.into_iter() {
|
||||
match directive {
|
||||
Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}) => {
|
||||
let statements = {
|
||||
let func_decl_ref = &mut (*func_decl).borrow_mut();
|
||||
let method_name = func_decl_ref.name;
|
||||
insert_arguments_remap_statements(
|
||||
new_id,
|
||||
kernels_methods_call_map,
|
||||
&globals_shared,
|
||||
&methods_to_indirectly_used_shared_globals,
|
||||
method_name,
|
||||
&mut directives,
|
||||
func_decl_ref,
|
||||
statements,
|
||||
)?
|
||||
};
|
||||
directives.push(Directive::Method(Function {
|
||||
func_decl,
|
||||
globals,
|
||||
body: Some(statements),
|
||||
import_as,
|
||||
tuning,
|
||||
linkage,
|
||||
}));
|
||||
}
|
||||
directive => directives.push(directive),
|
||||
}
|
||||
}
|
||||
Ok(directives)
|
||||
}
|
||||
|
||||
// We need to compute two kinds of information:
|
||||
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
|
||||
// * If it's a function -> does it use .shared global (directly or indirectly)
|
||||
fn resolve_indirect_uses_of_globals_shared<'input>(
|
||||
methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
|
||||
let mut result = HashMap::new();
|
||||
for (method, callees) in kernels_methods_call_map.methods() {
|
||||
let mut indirect_globals = methods_use_of_globals_shared
|
||||
.get(&method)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.copied()
|
||||
.collect::<BTreeSet<_>>();
|
||||
for &callee in callees {
|
||||
indirect_globals.extend(
|
||||
methods_use_of_globals_shared
|
||||
.get(&ast::MethodName::Func(callee))
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.copied(),
|
||||
);
|
||||
}
|
||||
result.insert(method, indirect_globals);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn insert_arguments_remap_statements<'input>(
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
kernels_methods_call_map: &MethodsCallMap<'input>,
|
||||
globals_shared: &HashMap<SpirvWord, ast::Type>,
|
||||
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||
ast::MethodName<'input, SpirvWord>,
|
||||
BTreeSet<SpirvWord>,
|
||||
>,
|
||||
method_name: ast::MethodName<SpirvWord>,
|
||||
result: &mut Vec<Directive>,
|
||||
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
|
||||
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
|
||||
let remapped_globals_in_method =
|
||||
if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
|
||||
match method_name {
|
||||
ast::MethodName::Func(..) => {
|
||||
let remapped_globals = method_globals
|
||||
.iter()
|
||||
.map(|global| {
|
||||
(
|
||||
*global,
|
||||
(
|
||||
new_id(),
|
||||
globals_shared
|
||||
.get(&global)
|
||||
.unwrap_or_else(|| todo!())
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
|
||||
func_decl_ref.input_arguments.push(ast::Variable {
|
||||
align: None,
|
||||
v_type: shared_global_type.clone(),
|
||||
state_space: ast::StateSpace::Shared,
|
||||
name: *new_shared_global_id,
|
||||
array_init: Vec::new(),
|
||||
});
|
||||
}
|
||||
remapped_globals
|
||||
}
|
||||
ast::MethodName::Kernel(..) => method_globals
|
||||
.iter()
|
||||
.map(|global| {
|
||||
(
|
||||
*global,
|
||||
(
|
||||
*global,
|
||||
globals_shared
|
||||
.get(&global)
|
||||
.unwrap_or_else(|| todo!())
|
||||
.clone(),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect::<BTreeMap<_, _>>(),
|
||||
}
|
||||
} else {
|
||||
return Ok(statements);
|
||||
};
|
||||
replace_uses_of_shared_memory(
|
||||
new_id,
|
||||
methods_to_indirectly_used_shared_globals,
|
||||
statements,
|
||||
remapped_globals_in_method,
|
||||
)
|
||||
}
|
||||
|
||||
fn replace_uses_of_shared_memory<'input>(
|
||||
new_id: &mut impl FnMut() -> SpirvWord,
|
||||
methods_to_indirectly_used_shared_globals: &HashMap<
|
||||
ast::MethodName<'input, SpirvWord>,
|
||||
BTreeSet<SpirvWord>,
|
||||
>,
|
||||
statements: Vec<ExpandedStatement>,
|
||||
remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(statements.len());
|
||||
for statement in statements {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Call {
|
||||
mut data,
|
||||
mut arguments,
|
||||
}) => {
|
||||
// We can safely skip checking call arguments,
|
||||
// because there's simply no way to pass shared ptr
|
||||
// without converting it to .b64 first
|
||||
if let Some(shared_globals_used_by_callee) =
|
||||
methods_to_indirectly_used_shared_globals
|
||||
.get(&ast::MethodName::Func(arguments.func))
|
||||
{
|
||||
for &shared_global_used_by_callee in shared_globals_used_by_callee {
|
||||
let (remapped_shared_id, type_) = remapped_globals_in_method
|
||||
.get(&shared_global_used_by_callee)
|
||||
.unwrap_or_else(|| todo!());
|
||||
data.input_arguments
|
||||
.push((type_.clone(), ast::StateSpace::Shared));
|
||||
arguments.input_arguments.push(*remapped_shared_id);
|
||||
}
|
||||
}
|
||||
result.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}))
|
||||
}
|
||||
statement => {
|
||||
let new_statement =
|
||||
statement.visit_map(&mut |id,
|
||||
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_,
|
||||
_| {
|
||||
Ok::<_, TranslateError>(
|
||||
if let Some((remapped_shared_id, _)) =
|
||||
remapped_globals_in_method.get(&id)
|
||||
{
|
||||
*remapped_shared_id
|
||||
} else {
|
||||
id
|
||||
},
|
||||
)
|
||||
})?;
|
||||
result.push(new_statement);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
524
ptx/src/pass/convert_to_stateful_memory_access.rs
Normal file
524
ptx/src/pass/convert_to_stateful_memory_access.rs
Normal file
|
@ -0,0 +1,524 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
use std::{
|
||||
collections::{BTreeSet, HashSet},
|
||||
iter,
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
/*
|
||||
Our goal here is to transform
|
||||
.visible .entry foobar(.param .u64 input) {
|
||||
.reg .b64 in_addr;
|
||||
.reg .b64 in_addr2;
|
||||
ld.param.u64 in_addr, [input];
|
||||
cvta.to.global.u64 in_addr2, in_addr;
|
||||
}
|
||||
into:
|
||||
.visible .entry foobar(.param .u8 input[]) {
|
||||
.reg .u8 in_addr[];
|
||||
.reg .u8 in_addr2[];
|
||||
ld.param.u8[] in_addr, [input];
|
||||
mov.u8[] in_addr2, in_addr;
|
||||
}
|
||||
or:
|
||||
.visible .entry foobar(.reg .u8 input[]) {
|
||||
.reg .u8 in_addr[];
|
||||
.reg .u8 in_addr2[];
|
||||
mov.u8[] in_addr, input;
|
||||
mov.u8[] in_addr2, in_addr;
|
||||
}
|
||||
or:
|
||||
.visible .entry foobar(.param ptr<u8, global> input) {
|
||||
.reg ptr<u8, global> in_addr;
|
||||
.reg ptr<u8, global> in_addr2;
|
||||
ld.param.ptr<u8, global> in_addr, [input];
|
||||
mov.ptr<u8, global> in_addr2, in_addr;
|
||||
}
|
||||
*/
|
||||
// TODO: detect more patterns (mov, call via reg, call via param)
|
||||
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
|
||||
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
|
||||
// argument expansion
|
||||
// TODO: propagate out of calls and into calls
|
||||
pub(super) fn run<'a, 'input>(
|
||||
func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
func_body: Vec<TypedStatement>,
|
||||
id_defs: &mut NumericIdResolver<'a>,
|
||||
) -> Result<
|
||||
(
|
||||
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
|
||||
Vec<TypedStatement>,
|
||||
),
|
||||
TranslateError,
|
||||
> {
|
||||
let mut method_decl = func_args.borrow_mut();
|
||||
if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
|
||||
drop(method_decl);
|
||||
return Ok((func_args, func_body));
|
||||
}
|
||||
if Rc::strong_count(&func_args) != 1 {
|
||||
return Err(error_unreachable());
|
||||
}
|
||||
let func_args_64bit = (*method_decl)
|
||||
.input_arguments
|
||||
.iter()
|
||||
.filter_map(|arg| match arg.v_type {
|
||||
ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<HashSet<_>>();
|
||||
let mut stateful_markers = Vec::new();
|
||||
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
|
||||
for statement in func_body.iter() {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Cvta {
|
||||
data:
|
||||
ast::CvtaDetails {
|
||||
state_space: ast::StateSpace::Global,
|
||||
direction: ast::CvtaDirection::GenericToExplicit,
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
if let (TypedOperand::Reg(dst), Some(src)) =
|
||||
(arguments.dst, arguments.src.underlying_register())
|
||||
{
|
||||
if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
|
||||
stateful_markers.push((dst, src));
|
||||
}
|
||||
}
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Ld {
|
||||
data:
|
||||
ast::LdDetails {
|
||||
state_space: ast::StateSpace::Param,
|
||||
typ: ast::Type::Scalar(ast::ScalarType::U64),
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Ld {
|
||||
data:
|
||||
ast::LdDetails {
|
||||
state_space: ast::StateSpace::Param,
|
||||
typ: ast::Type::Scalar(ast::ScalarType::S64),
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Ld {
|
||||
data:
|
||||
ast::LdDetails {
|
||||
state_space: ast::StateSpace::Param,
|
||||
typ: ast::Type::Scalar(ast::ScalarType::B64),
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
if let (TypedOperand::Reg(dst), Some(src)) =
|
||||
(arguments.dst, arguments.src.underlying_register())
|
||||
{
|
||||
if func_args_64bit.contains(&src) {
|
||||
multi_hash_map_append(&mut stateful_init_reg, dst, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if stateful_markers.len() == 0 {
|
||||
drop(method_decl);
|
||||
return Ok((func_args, func_body));
|
||||
}
|
||||
let mut func_args_ptr = HashSet::new();
|
||||
let mut regs_ptr_current = HashSet::new();
|
||||
for (dst, src) in stateful_markers {
|
||||
if let Some(func_args) = stateful_init_reg.get(&src) {
|
||||
for a in func_args {
|
||||
func_args_ptr.insert(*a);
|
||||
regs_ptr_current.insert(src);
|
||||
regs_ptr_current.insert(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
// BTreeSet here to have a stable order of iteration,
|
||||
// unfortunately our tests rely on it
|
||||
let mut regs_ptr_seen = BTreeSet::new();
|
||||
while regs_ptr_current.len() > 0 {
|
||||
let mut regs_ptr_new = HashSet::new();
|
||||
for statement in func_body.iter() {
|
||||
match statement {
|
||||
Statement::Instruction(ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::U64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
}) => {
|
||||
// TODO: don't mark result of double pointer sub or double
|
||||
// pointer add as ptr result
|
||||
if let (TypedOperand::Reg(dst), Some(src1)) =
|
||||
(arguments.dst, arguments.src1.underlying_register())
|
||||
{
|
||||
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
|
||||
regs_ptr_new.insert(dst);
|
||||
}
|
||||
} else if let (TypedOperand::Reg(dst), Some(src2)) =
|
||||
(arguments.dst, arguments.src2.underlying_register())
|
||||
{
|
||||
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
|
||||
regs_ptr_new.insert(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Statement::Instruction(ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::U64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
}) => {
|
||||
// TODO: don't mark result of double pointer sub or double
|
||||
// pointer add as ptr result
|
||||
if let (TypedOperand::Reg(dst), Some(src1)) =
|
||||
(arguments.dst, arguments.src1.underlying_register())
|
||||
{
|
||||
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
|
||||
regs_ptr_new.insert(dst);
|
||||
}
|
||||
} else if let (TypedOperand::Reg(dst), Some(src2)) =
|
||||
(arguments.dst, arguments.src2.underlying_register())
|
||||
{
|
||||
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
|
||||
regs_ptr_new.insert(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
for id in regs_ptr_current {
|
||||
regs_ptr_seen.insert(id);
|
||||
}
|
||||
regs_ptr_current = regs_ptr_new;
|
||||
}
|
||||
drop(regs_ptr_current);
|
||||
let mut remapped_ids = HashMap::new();
|
||||
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
|
||||
for reg in regs_ptr_seen {
|
||||
let new_id = id_defs.register_variable(
|
||||
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
ast::StateSpace::Reg,
|
||||
);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: None,
|
||||
name: new_id,
|
||||
array_init: Vec::new(),
|
||||
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
}));
|
||||
remapped_ids.insert(reg, new_id);
|
||||
}
|
||||
for arg in (*method_decl).input_arguments.iter_mut() {
|
||||
if !func_args_ptr.contains(&arg.name) {
|
||||
continue;
|
||||
}
|
||||
let new_id = id_defs.register_variable(
|
||||
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
|
||||
ast::StateSpace::Param,
|
||||
);
|
||||
let old_name = arg.name;
|
||||
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
|
||||
arg.name = new_id;
|
||||
remapped_ids.insert(old_name, new_id);
|
||||
}
|
||||
for statement in func_body {
|
||||
match statement {
|
||||
l @ Statement::Label(_) => result.push(l),
|
||||
c @ Statement::Conditional(_) => result.push(c),
|
||||
c @ Statement::Constant(..) => result.push(c),
|
||||
Statement::Variable(var) => {
|
||||
if !remapped_ids.contains_key(&var.name) {
|
||||
result.push(Statement::Variable(var));
|
||||
}
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::U64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Add {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
}) if is_add_ptr_direct(&remapped_ids, &arguments) => {
|
||||
let (ptr, offset) = match arguments.src1.underlying_register() {
|
||||
Some(src1) if remapped_ids.contains_key(&src1) => {
|
||||
(remapped_ids.get(&src1).unwrap(), arguments.src2)
|
||||
}
|
||||
Some(src2) if remapped_ids.contains_key(&src2) => {
|
||||
(remapped_ids.get(&src2).unwrap(), arguments.src1)
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let dst = arguments.dst.unwrap_reg()?;
|
||||
result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Global,
|
||||
dst: *remapped_ids.get(&dst).unwrap(),
|
||||
ptr_src: *ptr,
|
||||
offset_src: offset,
|
||||
}))
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::U64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
})
|
||||
| Statement::Instruction(ast::Instruction::Sub {
|
||||
data:
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: ast::ScalarType::S64,
|
||||
saturate: false,
|
||||
}),
|
||||
arguments,
|
||||
}) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
|
||||
let (ptr, offset) = match arguments.src1.underlying_register() {
|
||||
Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let offset_neg = id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
result.push(Statement::Instruction(ast::Instruction::Neg {
|
||||
data: ast::TypeFtz {
|
||||
type_: ast::ScalarType::S64,
|
||||
flush_to_zero: None,
|
||||
},
|
||||
arguments: ast::NegArgs {
|
||||
src: offset,
|
||||
dst: TypedOperand::Reg(offset_neg),
|
||||
},
|
||||
}));
|
||||
let dst = arguments.dst.unwrap_reg()?;
|
||||
result.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
|
||||
state_space: ast::StateSpace::Global,
|
||||
dst: *remapped_ids.get(&dst).unwrap(),
|
||||
ptr_src: *ptr,
|
||||
offset_src: TypedOperand::Reg(offset_neg),
|
||||
}))
|
||||
}
|
||||
inst @ Statement::Instruction(_) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = inst.visit_map(&mut FnVisitor::new(
|
||||
|operand, type_space, is_dst, relaxed_conversion| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
operand,
|
||||
type_space,
|
||||
is_dst,
|
||||
relaxed_conversion,
|
||||
)
|
||||
},
|
||||
))?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
repack @ Statement::RepackVector(_) => {
|
||||
let mut post_statements = Vec::new();
|
||||
let new_statement = repack.visit_map(&mut FnVisitor::new(
|
||||
|operand, type_space, is_dst, relaxed_conversion| {
|
||||
convert_to_stateful_memory_access_postprocess(
|
||||
id_defs,
|
||||
&remapped_ids,
|
||||
&mut result,
|
||||
&mut post_statements,
|
||||
operand,
|
||||
type_space,
|
||||
is_dst,
|
||||
relaxed_conversion,
|
||||
)
|
||||
},
|
||||
))?;
|
||||
result.push(new_statement);
|
||||
result.extend(post_statements);
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
drop(method_decl);
|
||||
Ok((func_args, result))
|
||||
}
|
||||
|
||||
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
|
||||
match id_defs.get_typed(id) {
|
||||
Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
|
||||
| Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
|
||||
| Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_add_ptr_direct(
|
||||
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||
arg: &ast::AddArgs<TypedOperand>,
|
||||
) -> bool {
|
||||
match arg.dst {
|
||||
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
|
||||
return false
|
||||
}
|
||||
TypedOperand::Reg(dst) => {
|
||||
if !remapped_ids.contains_key(&dst) {
|
||||
return false;
|
||||
}
|
||||
if let Some(ref src1_reg) = arg.src1.underlying_register() {
|
||||
if remapped_ids.contains_key(src1_reg) {
|
||||
// don't trigger optimization when adding two pointers
|
||||
if let Some(ref src2_reg) = arg.src2.underlying_register() {
|
||||
return !remapped_ids.contains_key(src2_reg);
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(ref src2_reg) = arg.src2.underlying_register() {
|
||||
remapped_ids.contains_key(src2_reg)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_sub_ptr_direct(
|
||||
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||
arg: &ast::SubArgs<TypedOperand>,
|
||||
) -> bool {
|
||||
match arg.dst {
|
||||
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
|
||||
return false
|
||||
}
|
||||
TypedOperand::Reg(dst) => {
|
||||
if !remapped_ids.contains_key(&dst) {
|
||||
return false;
|
||||
}
|
||||
match arg.src1.underlying_register() {
|
||||
Some(ref src1_reg) => {
|
||||
if remapped_ids.contains_key(src1_reg) {
|
||||
// don't trigger optimization when subtracting two pointers
|
||||
arg.src2
|
||||
.underlying_register()
|
||||
.map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_to_stateful_memory_access_postprocess(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
|
||||
result: &mut Vec<TypedStatement>,
|
||||
post_statements: &mut Vec<TypedStatement>,
|
||||
operand: TypedOperand,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_conversion: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
operand.map(|operand, _| {
|
||||
Ok(match remapped_ids.get(&operand) {
|
||||
Some(new_id) => {
|
||||
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
|
||||
// TODO: readd if required
|
||||
if let Some((expected_type, expected_space)) = type_space {
|
||||
let implicit_conversion = if relaxed_conversion {
|
||||
if is_dst {
|
||||
super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
super::insert_implicit_conversions::default_implicit_conversion
|
||||
};
|
||||
if implicit_conversion(
|
||||
(new_operand_space, &new_operand_type),
|
||||
(expected_space, expected_type),
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
return Ok(*new_id);
|
||||
}
|
||||
}
|
||||
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
|
||||
let converting_id = id_defs
|
||||
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
|
||||
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
|
||||
ConversionKind::Default
|
||||
} else {
|
||||
ConversionKind::PtrToPtr
|
||||
};
|
||||
if is_dst {
|
||||
post_statements.push(Statement::Conversion(ImplicitConversion {
|
||||
src: converting_id,
|
||||
dst: *new_id,
|
||||
from_type: old_operand_type,
|
||||
from_space: old_operand_space,
|
||||
to_type: new_operand_type,
|
||||
to_space: new_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
} else {
|
||||
result.push(Statement::Conversion(ImplicitConversion {
|
||||
src: *new_id,
|
||||
dst: converting_id,
|
||||
from_type: new_operand_type,
|
||||
from_space: new_operand_space,
|
||||
to_type: old_operand_type,
|
||||
to_space: old_operand_space,
|
||||
kind,
|
||||
}));
|
||||
converting_id
|
||||
}
|
||||
}
|
||||
None => operand,
|
||||
})
|
||||
})
|
||||
}
|
138
ptx/src/pass/convert_to_typed.rs
Normal file
138
ptx/src/pass/convert_to_typed.rs
Normal file
|
@ -0,0 +1,138 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run(
|
||||
func: Vec<UnconditionalStatement>,
|
||||
fn_defs: &GlobalFnDeclResolver,
|
||||
id_defs: &mut NumericIdResolver,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Mov {
|
||||
data,
|
||||
arguments:
|
||||
ast::MovArgs {
|
||||
dst: ast::ParsedOperand::Reg(dst_reg),
|
||||
src: ast::ParsedOperand::Reg(src_reg),
|
||||
},
|
||||
} if fn_defs.fns.contains_key(&src_reg) => {
|
||||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
|
||||
dst: dst_reg,
|
||||
src: src_reg,
|
||||
}));
|
||||
}
|
||||
ast::Instruction::Call { data, arguments } => {
|
||||
let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
|
||||
let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let reresolved_call =
|
||||
Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
|
||||
visitor.func.push(reresolved_call);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
inst => {
|
||||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
|
||||
let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
|
||||
visitor.func.push(instruction);
|
||||
visitor.func.extend(visitor.post_stmts);
|
||||
}
|
||||
},
|
||||
Statement::Label(i) => result.push(Statement::Label(i)),
|
||||
Statement::Variable(v) => result.push(Statement::Variable(v)),
|
||||
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct VectorRepackVisitor<'a, 'b> {
|
||||
func: &'b mut Vec<TypedStatement>,
|
||||
id_def: &'b mut NumericIdResolver<'a>,
|
||||
post_stmts: Option<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
|
||||
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
|
||||
VectorRepackVisitor {
|
||||
func,
|
||||
id_def,
|
||||
post_stmts: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_vector(
|
||||
&mut self,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
typ: &ast::Type,
|
||||
state_space: ast::StateSpace,
|
||||
idx: Vec<SpirvWord>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
// mov.u32 foobar, {a,b};
|
||||
let scalar_t = match typ {
|
||||
ast::Type::Vector(_, scalar_t) => *scalar_t,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let temp_vec = self
|
||||
.id_def
|
||||
.register_intermediate(Some((typ.clone(), state_space)));
|
||||
let statement = Statement::RepackVector(RepackVectorDetails {
|
||||
is_extract: is_dst,
|
||||
typ: scalar_t,
|
||||
packed: temp_vec,
|
||||
unpacked: idx,
|
||||
relaxed_type_check,
|
||||
});
|
||||
if is_dst {
|
||||
self.post_stmts = Some(statement);
|
||||
} else {
|
||||
self.func.push(statement);
|
||||
}
|
||||
Ok(temp_vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
|
||||
for VectorRepackVisitor<'a, 'b>
|
||||
{
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
ident: SpirvWord,
|
||||
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
_: bool,
|
||||
_: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(ident)
|
||||
}
|
||||
|
||||
fn visit(
|
||||
&mut self,
|
||||
op: ast::ParsedOperand<SpirvWord>,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
Ok(match op {
|
||||
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
|
||||
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
|
||||
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
|
||||
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
|
||||
ast::ParsedOperand::VecPack(vec) => {
|
||||
let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
|
||||
TypedOperand::Reg(self.convert_vector(
|
||||
is_dst,
|
||||
relaxed_type_check,
|
||||
type_,
|
||||
space,
|
||||
vec,
|
||||
)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
2763
ptx/src/pass/emit_spirv.rs
Normal file
2763
ptx/src/pass/emit_spirv.rs
Normal file
File diff suppressed because it is too large
Load diff
181
ptx/src/pass/expand_arguments.rs
Normal file
181
ptx/src/pass/expand_arguments.rs
Normal file
|
@ -0,0 +1,181 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(super) fn run<'a, 'b>(
|
||||
func: Vec<TypedStatement>,
|
||||
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
|
||||
Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
|
||||
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
|
||||
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
|
||||
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
|
||||
Statement::Constant(c) => result.push(Statement::Constant(c)),
|
||||
Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
|
||||
s => {
|
||||
let (new_statement, post_stmts) = {
|
||||
let mut visitor = FlattenArguments::new(&mut result, id_def);
|
||||
(s.visit_map(&mut visitor)?, visitor.post_stmts)
|
||||
};
|
||||
result.push(new_statement);
|
||||
result.extend(post_stmts);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct FlattenArguments<'a, 'b> {
|
||||
func: &'b mut Vec<ExpandedStatement>,
|
||||
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||
post_stmts: Vec<ExpandedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FlattenArguments<'a, 'b> {
|
||||
fn new(
|
||||
func: &'b mut Vec<ExpandedStatement>,
|
||||
id_def: &'b mut MutableNumericIdResolver<'a>,
|
||||
) -> Self {
|
||||
FlattenArguments {
|
||||
func,
|
||||
id_def,
|
||||
post_stmts: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
fn reg_offset(
|
||||
&mut self,
|
||||
reg: SpirvWord,
|
||||
offset: i32,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
|
||||
(type_, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
|
||||
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
|
||||
if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let reg_scalar_type = match reg_type {
|
||||
ast::Type::Scalar(underlying_type) => underlying_type,
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let id_constant_stmt = self
|
||||
.id_def
|
||||
.register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: reg_scalar_type,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let arith_details = match reg_scalar_type.kind() {
|
||||
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
|
||||
ast::ArithDetails::Integer(ast::ArithInteger {
|
||||
type_: reg_scalar_type,
|
||||
saturate: false,
|
||||
})
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
|
||||
self.func
|
||||
.push(Statement::Instruction(ast::Instruction::Add {
|
||||
data: arith_details,
|
||||
arguments: ast::AddArgs {
|
||||
dst: id_add_result,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
}));
|
||||
Ok(id_add_result)
|
||||
} else {
|
||||
let id_constant_stmt = self.id_def.register_intermediate(
|
||||
ast::Type::Scalar(ast::ScalarType::S64),
|
||||
ast::StateSpace::Reg,
|
||||
);
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: ast::ScalarType::S64,
|
||||
value: ast::ImmediateValue::S64(offset as i64),
|
||||
}));
|
||||
let dst = self
|
||||
.id_def
|
||||
.register_intermediate(type_.clone(), state_space);
|
||||
self.func.push(Statement::PtrAccess(PtrAccess {
|
||||
underlying_type: type_.clone(),
|
||||
state_space: state_space,
|
||||
dst,
|
||||
ptr_src: reg,
|
||||
offset_src: id_constant_stmt,
|
||||
}));
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
fn immediate(
|
||||
&mut self,
|
||||
value: ast::ImmediateValue,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
let (scalar_t, state_space) =
|
||||
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
|
||||
(*scalar, state_space)
|
||||
} else {
|
||||
return Err(TranslateError::UntypedSymbol);
|
||||
};
|
||||
let id = self
|
||||
.id_def
|
||||
.register_intermediate(ast::Type::Scalar(scalar_t), state_space);
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: scalar_t,
|
||||
value,
|
||||
}));
|
||||
Ok(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
|
||||
fn visit(
|
||||
&mut self,
|
||||
args: TypedOperand,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
match args {
|
||||
TypedOperand::Reg(r) => self.reg(r),
|
||||
TypedOperand::Imm(x) => self.immediate(x, type_space),
|
||||
TypedOperand::RegOffset(reg, offset) => {
|
||||
self.reg_offset(reg, offset, type_space, is_dst)
|
||||
}
|
||||
TypedOperand::VecMember(..) => Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
name: <TypedOperand as ptx_parser::Operand>::Ident,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
_is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
|
||||
self.reg(name)
|
||||
}
|
||||
}
|
282
ptx/src/pass/extract_globals.rs
Normal file
282
ptx/src/pass/extract_globals.rs
Normal file
|
@ -0,0 +1,282 @@
|
|||
use super::*;
|
||||
|
||||
pub(super) fn run<'input, 'b>(
|
||||
sorted_statements: Vec<ExpandedStatement>,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
|
||||
let mut local = Vec::with_capacity(sorted_statements.len());
|
||||
let mut global = Vec::new();
|
||||
for statement in sorted_statements {
|
||||
match statement {
|
||||
Statement::Variable(
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Shared,
|
||||
..
|
||||
},
|
||||
)
|
||||
| Statement::Variable(
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Global,
|
||||
..
|
||||
},
|
||||
) => global.push(var),
|
||||
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfe { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Bfi { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
|
||||
let fn_name: String =
|
||||
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Brev { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
|
||||
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Activemask { arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::IncrementWrap,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_inc",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::DecrementWrap,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_dec",
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom {
|
||||
data:
|
||||
data @ ast::AtomDetails {
|
||||
op: ast::AtomicOp::FloatAdd,
|
||||
semantics,
|
||||
scope,
|
||||
space,
|
||||
..
|
||||
},
|
||||
arguments,
|
||||
}) => {
|
||||
let scalar_type = match data.type_ {
|
||||
ptx_parser::Type::Scalar(scalar) => scalar,
|
||||
_ => return Err(error_unreachable()),
|
||||
};
|
||||
let fn_name = [
|
||||
ZLUDA_PTX_PREFIX,
|
||||
"atom_",
|
||||
semantics_to_ptx_name(semantics),
|
||||
"_",
|
||||
scope_to_ptx_name(scope),
|
||||
"_",
|
||||
space_to_ptx_name(space),
|
||||
"_add_",
|
||||
scalar_to_ptx_name(scalar_type),
|
||||
]
|
||||
.concat();
|
||||
local.push(instruction_to_fn_call(
|
||||
id_def,
|
||||
ptx_impl_imports,
|
||||
ast::Instruction::Atom { data, arguments },
|
||||
fn_name,
|
||||
)?);
|
||||
}
|
||||
s => local.push(s),
|
||||
}
|
||||
}
|
||||
Ok((local, global))
|
||||
}
|
||||
|
||||
fn instruction_to_fn_call(
|
||||
id_defs: &mut NumericIdResolver,
|
||||
ptx_impl_imports: &mut HashMap<String, Directive>,
|
||||
inst: ast::Instruction<SpirvWord>,
|
||||
fn_name: String,
|
||||
) -> Result<ExpandedStatement, TranslateError> {
|
||||
let mut arguments = Vec::new();
|
||||
ast::visit_map(inst, &mut |operand,
|
||||
type_space: Option<(
|
||||
&ast::Type,
|
||||
ast::StateSpace,
|
||||
)>,
|
||||
is_dst,
|
||||
_| {
|
||||
let (typ, space) = match type_space {
|
||||
Some((typ, space)) => (typ.clone(), space),
|
||||
None => return Err(error_unreachable()),
|
||||
};
|
||||
arguments.push((operand, is_dst, typ, space));
|
||||
Ok(SpirvWord(0))
|
||||
})?;
|
||||
let return_arguments_count = arguments
|
||||
.iter()
|
||||
.position(|(desc, is_dst, _, _)| !is_dst)
|
||||
.unwrap_or(arguments.len());
|
||||
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
|
||||
let fn_id = register_external_fn_call(
|
||||
id_defs,
|
||||
ptx_impl_imports,
|
||||
fn_name,
|
||||
return_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ, *state)),
|
||||
input_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ, *state)),
|
||||
)?;
|
||||
Ok(Statement::Instruction(ast::Instruction::Call {
|
||||
data: ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||
.collect::<Vec<_>>(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, _, typ, state)| (typ.clone(), *state))
|
||||
.collect::<Vec<_>>(),
|
||||
},
|
||||
arguments: ast::CallArgs {
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _, _)| *name)
|
||||
.collect::<Vec<_>>(),
|
||||
func: fn_id,
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(name, _, _, _)| *name)
|
||||
.collect::<Vec<_>>(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
|
||||
match this {
|
||||
ast::ScalarType::B8 => "b8",
|
||||
ast::ScalarType::B16 => "b16",
|
||||
ast::ScalarType::B32 => "b32",
|
||||
ast::ScalarType::B64 => "b64",
|
||||
ast::ScalarType::B128 => "b128",
|
||||
ast::ScalarType::U8 => "u8",
|
||||
ast::ScalarType::U16 => "u16",
|
||||
ast::ScalarType::U16x2 => "u16x2",
|
||||
ast::ScalarType::U32 => "u32",
|
||||
ast::ScalarType::U64 => "u64",
|
||||
ast::ScalarType::S8 => "s8",
|
||||
ast::ScalarType::S16 => "s16",
|
||||
ast::ScalarType::S16x2 => "s16x2",
|
||||
ast::ScalarType::S32 => "s32",
|
||||
ast::ScalarType::S64 => "s64",
|
||||
ast::ScalarType::F16 => "f16",
|
||||
ast::ScalarType::F16x2 => "f16x2",
|
||||
ast::ScalarType::F32 => "f32",
|
||||
ast::ScalarType::F64 => "f64",
|
||||
ast::ScalarType::BF16 => "bf16",
|
||||
ast::ScalarType::BF16x2 => "bf16x2",
|
||||
ast::ScalarType::Pred => "pred",
|
||||
}
|
||||
}
|
||||
|
||||
fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
|
||||
match this {
|
||||
ast::AtomSemantics::Relaxed => "relaxed",
|
||||
ast::AtomSemantics::Acquire => "acquire",
|
||||
ast::AtomSemantics::Release => "release",
|
||||
ast::AtomSemantics::AcqRel => "acq_rel",
|
||||
}
|
||||
}
|
||||
|
||||
fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
|
||||
match this {
|
||||
ast::MemScope::Cta => "cta",
|
||||
ast::MemScope::Gpu => "gpu",
|
||||
ast::MemScope::Sys => "sys",
|
||||
ast::MemScope::Cluster => "cluster",
|
||||
}
|
||||
}
|
||||
|
||||
fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
|
||||
match this {
|
||||
ast::StateSpace::Generic => "generic",
|
||||
ast::StateSpace::Global => "global",
|
||||
ast::StateSpace::Shared => "shared",
|
||||
ast::StateSpace::Reg => "reg",
|
||||
ast::StateSpace::Const => "const",
|
||||
ast::StateSpace::Local => "local",
|
||||
ast::StateSpace::Param => "param",
|
||||
ast::StateSpace::Sreg => "sreg",
|
||||
ast::StateSpace::SharedCluster => "shared_cluster",
|
||||
ast::StateSpace::ParamEntry => "param_entry",
|
||||
ast::StateSpace::SharedCta => "shared_cta",
|
||||
ast::StateSpace::ParamFunc => "param_func",
|
||||
}
|
||||
}
|
130
ptx/src/pass/fix_special_registers.rs
Normal file
130
ptx/src/pass/fix_special_registers.rs
Normal file
|
@ -0,0 +1,130 @@
|
|||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(super) fn run<'a, 'b, 'input>(
|
||||
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||
typed_statements: Vec<TypedStatement>,
|
||||
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let result = Vec::with_capacity(typed_statements.len());
|
||||
let mut sreg_sresolver = SpecialRegisterResolver {
|
||||
ptx_impl_imports,
|
||||
numeric_id_defs,
|
||||
result,
|
||||
};
|
||||
for statement in typed_statements {
|
||||
let statement = statement.visit_map(&mut sreg_sresolver)?;
|
||||
sreg_sresolver.result.push(statement);
|
||||
}
|
||||
Ok(sreg_sresolver.result)
|
||||
}
|
||||
|
||||
struct SpecialRegisterResolver<'a, 'b, 'input> {
|
||||
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
|
||||
numeric_id_defs: &'a mut NumericIdResolver<'b>,
|
||||
result: Vec<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||
for SpecialRegisterResolver<'a, 'b, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: TypedOperand,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.replace_sreg(args, is_dst, None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
|
||||
fn replace_sreg(
|
||||
&mut self,
|
||||
name: SpirvWord,
|
||||
is_dst: bool,
|
||||
vector_index: Option<u8>,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
|
||||
if is_dst {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
|
||||
(Some(idx), Some(inp_type)) => {
|
||||
if inp_type != ast::ScalarType::U8 {
|
||||
return Err(TranslateError::Unreachable);
|
||||
}
|
||||
let constant = self.numeric_id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
self.result.push(Statement::Constant(ConstantDefinition {
|
||||
dst: constant,
|
||||
typ: inp_type,
|
||||
value: ast::ImmediateValue::U64(idx as u64),
|
||||
}));
|
||||
vec![(
|
||||
TypedOperand::Reg(constant),
|
||||
ast::Type::Scalar(inp_type),
|
||||
ast::StateSpace::Reg,
|
||||
)]
|
||||
}
|
||||
(None, None) => Vec::new(),
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
|
||||
let return_type = sreg.get_function_return_type();
|
||||
let fn_result = self.numeric_id_defs.register_intermediate(Some((
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
let return_arguments = vec![(
|
||||
fn_result,
|
||||
ast::Type::Scalar(return_type),
|
||||
ast::StateSpace::Reg,
|
||||
)];
|
||||
let fn_call = register_external_fn_call(
|
||||
self.numeric_id_defs,
|
||||
self.ptx_impl_imports,
|
||||
ocl_fn_name.to_string(),
|
||||
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
|
||||
)?;
|
||||
let data = ast::CallDetails {
|
||||
uniform: false,
|
||||
return_arguments: return_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
input_arguments: input_arguments
|
||||
.iter()
|
||||
.map(|(_, typ, space)| (typ.clone(), *space))
|
||||
.collect(),
|
||||
};
|
||||
let arguments = ast::CallArgs {
|
||||
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
func: fn_call,
|
||||
input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
|
||||
};
|
||||
self.result
|
||||
.push(Statement::Instruction(ast::Instruction::Call {
|
||||
data,
|
||||
arguments,
|
||||
}));
|
||||
Ok(fn_result)
|
||||
} else {
|
||||
Ok(name)
|
||||
}
|
||||
}
|
||||
}
|
432
ptx/src/pass/insert_implicit_conversions.rs
Normal file
432
ptx/src/pass/insert_implicit_conversions.rs
Normal file
|
@ -0,0 +1,432 @@
|
|||
use std::mem;
|
||||
|
||||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
There are several kinds of implicit conversions in PTX:
|
||||
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
|
||||
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
|
||||
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
|
||||
semantics are to first zext/chop/bitcast `y` as needed and then do
|
||||
documented special ld/st/cvt conversion rules for destination operands
|
||||
- st.param [x] y (used as function return arguments) same rule as above applies
|
||||
- generic/global ld: for instruction `ld x, [y]`, y must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
|
||||
documented special ld/st/cvt conversion rules are applied to dst
|
||||
- generic/global st: for instruction `st [x], y`, x must be of type
|
||||
b64/u64/s64, which is bitcast to a pointer
|
||||
*/
|
||||
pub(super) fn run(
|
||||
func: Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
) -> Result<Vec<ExpandedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func.into_iter() {
|
||||
match s {
|
||||
Statement::Instruction(inst) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::Instruction(inst),
|
||||
)?;
|
||||
}
|
||||
Statement::PtrAccess(access) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::PtrAccess(access),
|
||||
)?;
|
||||
}
|
||||
Statement::RepackVector(repack) => {
|
||||
insert_implicit_conversions_impl(
|
||||
&mut result,
|
||||
id_def,
|
||||
Statement::RepackVector(repack),
|
||||
)?;
|
||||
}
|
||||
s @ Statement::Conditional(_)
|
||||
| s @ Statement::Conversion(_)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Constant(_)
|
||||
| s @ Statement::Variable(_)
|
||||
| s @ Statement::LoadVar(..)
|
||||
| s @ Statement::StoreVar(..)
|
||||
| s @ Statement::RetValue(..)
|
||||
| s @ Statement::FunctionPointer(..) => result.push(s),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_implicit_conversions_impl(
|
||||
func: &mut Vec<ExpandedStatement>,
|
||||
id_def: &mut MutableNumericIdResolver,
|
||||
stmt: ExpandedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut post_conv = Vec::new();
|
||||
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
|
||||
&mut |operand,
|
||||
type_state: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst,
|
||||
relaxed_type_check| {
|
||||
let (instr_type, instruction_space) = match type_state {
|
||||
None => return Ok(operand),
|
||||
Some(t) => t,
|
||||
};
|
||||
let (operand_type, operand_space) = id_def.get_typed(operand)?;
|
||||
let conversion_fn = if relaxed_type_check {
|
||||
if is_dst {
|
||||
should_convert_relaxed_dst_wrapper
|
||||
} else {
|
||||
should_convert_relaxed_src_wrapper
|
||||
}
|
||||
} else {
|
||||
default_implicit_conversion
|
||||
};
|
||||
match conversion_fn(
|
||||
(operand_space, &operand_type),
|
||||
(instruction_space, instr_type),
|
||||
)? {
|
||||
Some(conv_kind) => {
|
||||
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
|
||||
let mut from_type = instr_type.clone();
|
||||
let mut from_space = instruction_space;
|
||||
let mut to_type = operand_type;
|
||||
let mut to_space = operand_space;
|
||||
let mut src =
|
||||
id_def.register_intermediate(instr_type.clone(), instruction_space);
|
||||
let mut dst = operand;
|
||||
let result = Ok::<_, TranslateError>(src);
|
||||
if !is_dst {
|
||||
mem::swap(&mut src, &mut dst);
|
||||
mem::swap(&mut from_type, &mut to_type);
|
||||
mem::swap(&mut from_space, &mut to_space);
|
||||
}
|
||||
conv_output.push(Statement::Conversion(ImplicitConversion {
|
||||
src,
|
||||
dst,
|
||||
from_type,
|
||||
from_space,
|
||||
to_type,
|
||||
to_space,
|
||||
kind: conv_kind,
|
||||
}));
|
||||
result
|
||||
}
|
||||
None => Ok(operand),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
func.push(statement);
|
||||
func.append(&mut post_conv);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn default_implicit_conversion(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if instruction_space == ast::StateSpace::Reg {
|
||||
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
|
||||
(operand_type, instruction_type)
|
||||
{
|
||||
if scalar.kind() == ast::ScalarKind::Bit
|
||||
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
|
||||
{
|
||||
return Ok(Some(ConversionKind::Default));
|
||||
}
|
||||
}
|
||||
} else if is_addressable(operand_space) {
|
||||
return Ok(Some(ConversionKind::AddressOf));
|
||||
}
|
||||
}
|
||||
if !space_is_compatible(instruction_space, operand_space) {
|
||||
default_implicit_conversion_space(
|
||||
(operand_space, operand_type),
|
||||
(instruction_space, instruction_type),
|
||||
)
|
||||
} else if instruction_type != operand_type {
|
||||
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_addressable(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Const
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Global
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
|
||||
ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::SharedCta
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Space is different
|
||||
fn default_implicit_conversion_space(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|
||||
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
|
||||
{
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
|
||||
match operand_type {
|
||||
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
|
||||
if *operand_ptr_space == instruction_space =>
|
||||
{
|
||||
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
// TODO: 32 bit
|
||||
ast::Type::Scalar(ast::ScalarType::B64)
|
||||
| ast::Type::Scalar(ast::ScalarType::U64)
|
||||
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
ast::Type::Scalar(ast::ScalarType::B32)
|
||||
| ast::Type::Scalar(ast::ScalarType::U32)
|
||||
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
|
||||
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
|
||||
Ok(Some(ConversionKind::BitToPtr))
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
},
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
|
||||
match instruction_type {
|
||||
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
|
||||
if operand_space == *instruction_ptr_space =>
|
||||
{
|
||||
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
_ => Err(error_mismatched_type()),
|
||||
}
|
||||
} else {
|
||||
Err(error_mismatched_type())
|
||||
}
|
||||
}
|
||||
|
||||
// Space is same, but type is different
|
||||
fn default_implicit_conversion_type(
|
||||
space: ast::StateSpace,
|
||||
operand_type: &ast::Type,
|
||||
instruction_type: &ast::Type,
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if space_is_compatible(space, ast::StateSpace::Reg) {
|
||||
if should_bitcast(instruction_type, operand_type) {
|
||||
Ok(Some(ConversionKind::Default))
|
||||
} else {
|
||||
Err(TranslateError::MismatchedType)
|
||||
}
|
||||
} else {
|
||||
Ok(Some(ConversionKind::PtrToPtr))
|
||||
}
|
||||
}
|
||||
|
||||
fn coerces_to_generic(this: ast::StateSpace) -> bool {
|
||||
match this {
|
||||
ast::StateSpace::Global
|
||||
| ast::StateSpace::Const
|
||||
| ast::StateSpace::Local
|
||||
| ptx_parser::StateSpace::SharedCta
|
||||
| ast::StateSpace::SharedCluster
|
||||
| ast::StateSpace::Shared => true,
|
||||
ast::StateSpace::Reg
|
||||
| ast::StateSpace::Param
|
||||
| ast::StateSpace::ParamEntry
|
||||
| ast::StateSpace::ParamFunc
|
||||
| ast::StateSpace::Generic
|
||||
| ast::StateSpace::Sreg => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
|
||||
match (instr, operand) {
|
||||
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
|
||||
if inst.size_of() != operand.size_of() {
|
||||
return false;
|
||||
}
|
||||
match inst.kind() {
|
||||
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
|
||||
ast::ScalarKind::Signed => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Unsigned
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
operand.kind() == ast::ScalarKind::Bit
|
||||
|| operand.kind() == ast::ScalarKind::Signed
|
||||
}
|
||||
ast::ScalarKind::Pred => false,
|
||||
}
|
||||
}
|
||||
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
|
||||
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
|
||||
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_dst_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !space_is_compatible(operand_space, instruction_space) {
|
||||
return Err(TranslateError::MismatchedType);
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_dst(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(TranslateError::MismatchedType),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
|
||||
fn should_convert_relaxed_dst(
|
||||
dst_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if dst_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (dst_type, instr_type) {
|
||||
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed => {
|
||||
if dst_type.kind() != ast::ScalarKind::Float {
|
||||
if instr_type.size_of() == dst_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else if instr_type.size_of() < dst_type.size_of() {
|
||||
Some(ConversionKind::SignExtend)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= dst_type.size_of()
|
||||
&& dst_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_dst(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn should_convert_relaxed_src_wrapper(
|
||||
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
|
||||
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
|
||||
) -> Result<Option<ConversionKind>, TranslateError> {
|
||||
if !space_is_compatible(operand_space, instruction_space) {
|
||||
return Err(error_mismatched_type());
|
||||
}
|
||||
if operand_type == instruction_type {
|
||||
return Ok(None);
|
||||
}
|
||||
match should_convert_relaxed_src(operand_type, instruction_type) {
|
||||
conv @ Some(_) => Ok(conv),
|
||||
None => Err(error_mismatched_type()),
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
|
||||
fn should_convert_relaxed_src(
|
||||
src_type: &ast::Type,
|
||||
instr_type: &ast::Type,
|
||||
) -> Option<ConversionKind> {
|
||||
if src_type == instr_type {
|
||||
return None;
|
||||
}
|
||||
match (src_type, instr_type) {
|
||||
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
|
||||
ast::ScalarKind::Bit => {
|
||||
if instr_type.size_of() <= src_type.size_of() {
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() != ast::ScalarKind::Float
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Float => {
|
||||
if instr_type.size_of() <= src_type.size_of()
|
||||
&& src_type.kind() == ast::ScalarKind::Bit
|
||||
{
|
||||
Some(ConversionKind::Default)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::ScalarKind::Pred => None,
|
||||
},
|
||||
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
|
||||
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
|
||||
should_convert_relaxed_src(
|
||||
&ast::Type::Scalar(*dst_type),
|
||||
&ast::Type::Scalar(*instr_type),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
275
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
275
ptx/src/pass/insert_mem_ssa_statements.rs
Normal file
|
@ -0,0 +1,275 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
/*
|
||||
How do we handle arguments:
|
||||
- input .params in kernels
|
||||
.param .b64 in_arg
|
||||
get turned into this SPIR-V:
|
||||
%1 = OpFunctionParameter %ulong
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %1
|
||||
We do this for two reasons. One, common treatment for argument-declared
|
||||
.param variables and .param variables inside function (we assume that
|
||||
at SPIR-V level every .param is a pointer in Function storage class)
|
||||
- input .params in functions
|
||||
.param .b64 in_arg
|
||||
get turned into this SPIR-V:
|
||||
%1 = OpFunctionParameter %_ptr_Function_ulong
|
||||
- input .regs
|
||||
.reg .b64 in_arg
|
||||
get turned into the same SPIR-V as kernel .params:
|
||||
%1 = OpFunctionParameter %ulong
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
OpStore %2 %1
|
||||
- output .regs
|
||||
.reg .b64 out_arg
|
||||
get just a variable declaration:
|
||||
%2 = OpVariable %%_ptr_Function_ulong Function
|
||||
- output .params don't exist, they have been moved to input positions
|
||||
by an earlier pass
|
||||
Distinguishing betweem kernel .params and function .params is not the
|
||||
cleanest solution. Alternatively, we could "deparamize" all kernel .param
|
||||
arguments by turning them into .reg arguments like this:
|
||||
.param .b64 arg -> .reg ptr<.b64,.param> arg
|
||||
This has the massive downside that this transformation would have to run
|
||||
very early and would muddy up already difficult code. It's simpler to just
|
||||
have an if here
|
||||
*/
|
||||
pub(super) fn run<'a, 'b>(
|
||||
func: Vec<TypedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
|
||||
) -> Result<Vec<TypedStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for arg in fn_decl.input_arguments.iter_mut() {
|
||||
insert_mem_ssa_argument(
|
||||
id_def,
|
||||
&mut result,
|
||||
arg,
|
||||
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
|
||||
);
|
||||
}
|
||||
for arg in fn_decl.return_arguments.iter() {
|
||||
insert_mem_ssa_argument_reg_return(&mut result, arg);
|
||||
}
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Ret { data } => {
|
||||
// TODO: handle multiple output args
|
||||
match &fn_decl.return_arguments[..] {
|
||||
[return_reg] => {
|
||||
let new_id = id_def.register_intermediate(Some((
|
||||
return_reg.v_type.clone(),
|
||||
ast::StateSpace::Reg,
|
||||
)));
|
||||
result.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: ast::LdArgs {
|
||||
dst: new_id,
|
||||
src: return_reg.name,
|
||||
},
|
||||
typ: return_reg.v_type.clone(),
|
||||
member_index: None,
|
||||
}));
|
||||
result.push(Statement::RetValue(data, new_id));
|
||||
}
|
||||
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
inst => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::Instruction(inst),
|
||||
)?,
|
||||
},
|
||||
Statement::Conditional(bra) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
|
||||
}
|
||||
Statement::Conversion(conv) => {
|
||||
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
|
||||
}
|
||||
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::PtrAccess(ptr_access),
|
||||
)?,
|
||||
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::RepackVector(repack),
|
||||
)?,
|
||||
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
|
||||
id_def,
|
||||
&mut result,
|
||||
Statement::FunctionPointer(func_ptr),
|
||||
)?,
|
||||
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
|
||||
result.push(s)
|
||||
}
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_argument(
|
||||
id_def: &mut NumericIdResolver,
|
||||
func: &mut Vec<TypedStatement>,
|
||||
arg: &mut ast::Variable<SpirvWord>,
|
||||
is_kernel: bool,
|
||||
) {
|
||||
if !is_kernel && arg.state_space == ast::StateSpace::Param {
|
||||
return;
|
||||
}
|
||||
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
|
||||
func.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: ast::StateSpace::Reg,
|
||||
name: arg.name,
|
||||
array_init: Vec::new(),
|
||||
}));
|
||||
func.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::StArgs {
|
||||
src1: arg.name,
|
||||
src2: new_id,
|
||||
},
|
||||
typ: arg.v_type.clone(),
|
||||
member_index: None,
|
||||
}));
|
||||
arg.name = new_id;
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_argument_reg_return(
|
||||
func: &mut Vec<TypedStatement>,
|
||||
arg: &ast::Variable<SpirvWord>,
|
||||
) {
|
||||
func.push(Statement::Variable(ast::Variable {
|
||||
align: arg.align,
|
||||
v_type: arg.v_type.clone(),
|
||||
state_space: arg.state_space,
|
||||
name: arg.name,
|
||||
array_init: arg.array_init.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
fn insert_mem_ssa_statement_default<'a, 'input>(
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
stmt: TypedStatement,
|
||||
) -> Result<(), TranslateError> {
|
||||
let mut visitor = InsertMemSSAVisitor {
|
||||
id_def,
|
||||
func,
|
||||
post_statements: Vec::new(),
|
||||
};
|
||||
let new_stmt = stmt.visit_map(&mut visitor)?;
|
||||
visitor.func.push(new_stmt);
|
||||
visitor.func.extend(visitor.post_statements);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InsertMemSSAVisitor<'a, 'input> {
|
||||
id_def: &'a mut NumericIdResolver<'input>,
|
||||
func: &'a mut Vec<TypedStatement>,
|
||||
post_statements: Vec<TypedStatement>,
|
||||
}
|
||||
|
||||
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
|
||||
fn symbol(
|
||||
&mut self,
|
||||
symbol: SpirvWord,
|
||||
member_index: Option<u8>,
|
||||
expected: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
if expected.is_none() {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
|
||||
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
|
||||
return Ok(symbol);
|
||||
};
|
||||
let member_index = match member_index {
|
||||
Some(idx) => {
|
||||
let vector_width = match var_type {
|
||||
ast::Type::Vector(width, scalar_t) => {
|
||||
var_type = ast::Type::Scalar(scalar_t);
|
||||
width
|
||||
}
|
||||
_ => return Err(error_mismatched_type()),
|
||||
};
|
||||
Some((
|
||||
idx,
|
||||
if self.id_def.special_registers.get(symbol).is_some() {
|
||||
Some(vector_width)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let generated_id = self
|
||||
.id_def
|
||||
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
|
||||
if !is_dst {
|
||||
self.func.push(Statement::LoadVar(LoadVarDetails {
|
||||
arg: ast::LdArgs {
|
||||
dst: generated_id,
|
||||
src: symbol,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index,
|
||||
}));
|
||||
} else {
|
||||
self.post_statements
|
||||
.push(Statement::StoreVar(StoreVarDetails {
|
||||
arg: ast::StArgs {
|
||||
src1: symbol,
|
||||
src2: generated_id,
|
||||
},
|
||||
typ: var_type,
|
||||
member_index: member_index.map(|(idx, _)| idx),
|
||||
}));
|
||||
}
|
||||
Ok(generated_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
|
||||
for InsertMemSSAVisitor<'a, 'input>
|
||||
{
|
||||
fn visit(
|
||||
&mut self,
|
||||
operand: TypedOperand,
|
||||
type_space: Option<(&ast::Type, ast::StateSpace)>,
|
||||
is_dst: bool,
|
||||
_relaxed_type_check: bool,
|
||||
) -> Result<TypedOperand, TranslateError> {
|
||||
Ok(match operand {
|
||||
TypedOperand::Reg(reg) => {
|
||||
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
|
||||
}
|
||||
TypedOperand::RegOffset(reg, offset) => {
|
||||
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
|
||||
}
|
||||
op @ TypedOperand::Imm(..) => op,
|
||||
TypedOperand::VecMember(symbol, index) => {
|
||||
TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
args: SpirvWord,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, TranslateError> {
|
||||
self.symbol(args, None, type_space, is_dst)
|
||||
}
|
||||
}
|
1677
ptx/src/pass/mod.rs
Normal file
1677
ptx/src/pass/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
80
ptx/src/pass/normalize_identifiers.rs
Normal file
80
ptx/src/pass/normalize_identifiers.rs
Normal file
|
@ -0,0 +1,80 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run<'input, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'input, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
|
||||
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
|
||||
) -> Result<Vec<NormalizedStatement>, TranslateError> {
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
ast::Statement::Label(id) => {
|
||||
id_defs.add_def(*id, None, false);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
let mut result = Vec::new();
|
||||
for s in func {
|
||||
expand_map_variables(id_defs, fn_defs, &mut result, s)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn expand_map_variables<'a, 'b>(
|
||||
id_defs: &mut FnStringIdResolver<'a, 'b>,
|
||||
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
|
||||
result: &mut Vec<NormalizedStatement>,
|
||||
s: ast::Statement<ast::ParsedOperand<&'a str>>,
|
||||
) -> Result<(), TranslateError> {
|
||||
match s {
|
||||
ast::Statement::Block(block) => {
|
||||
id_defs.start_block();
|
||||
for s in block {
|
||||
expand_map_variables(id_defs, fn_defs, result, s)?;
|
||||
}
|
||||
id_defs.end_block();
|
||||
}
|
||||
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
|
||||
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
|
||||
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
|
||||
.transpose()?,
|
||||
ast::visit_map(i, &mut |id,
|
||||
_: Option<(&ast::Type, ast::StateSpace)>,
|
||||
_: bool,
|
||||
_: bool| {
|
||||
id_defs.get_id(id)
|
||||
})?,
|
||||
))),
|
||||
ast::Statement::Variable(var) => {
|
||||
let var_type = var.var.v_type.clone();
|
||||
match var.count {
|
||||
Some(count) => {
|
||||
for new_id in
|
||||
id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
|
||||
{
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let new_id =
|
||||
id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
|
||||
result.push(Statement::Variable(ast::Variable {
|
||||
align: var.var.align,
|
||||
v_type: var.var.v_type.clone(),
|
||||
state_space: var.var.state_space,
|
||||
name: new_id,
|
||||
array_init: var.var.array_init,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
48
ptx/src/pass/normalize_labels.rs
Normal file
48
ptx/src/pass/normalize_labels.rs
Normal file
|
@ -0,0 +1,48 @@
|
|||
use std::{collections::HashSet, iter};
|
||||
|
||||
use super::*;
|
||||
|
||||
pub(super) fn run(
|
||||
func: Vec<ExpandedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Vec<ExpandedStatement> {
|
||||
let mut labels_in_use = HashSet::new();
|
||||
for s in func.iter() {
|
||||
match s {
|
||||
Statement::Instruction(i) => {
|
||||
if let Some(target) = jump_target(i) {
|
||||
labels_in_use.insert(target);
|
||||
}
|
||||
}
|
||||
Statement::Conditional(cond) => {
|
||||
labels_in_use.insert(cond.if_true);
|
||||
labels_in_use.insert(cond.if_false);
|
||||
}
|
||||
Statement::Variable(..)
|
||||
| Statement::LoadVar(..)
|
||||
| Statement::StoreVar(..)
|
||||
| Statement::RetValue(..)
|
||||
| Statement::Conversion(..)
|
||||
| Statement::Constant(..)
|
||||
| Statement::Label(..)
|
||||
| Statement::PtrAccess { .. }
|
||||
| Statement::RepackVector(..)
|
||||
| Statement::FunctionPointer(..) => {}
|
||||
}
|
||||
}
|
||||
iter::once(Statement::Label(id_def.register_intermediate(None)))
|
||||
.chain(func.into_iter().filter(|s| match s {
|
||||
Statement::Label(i) => labels_in_use.contains(i),
|
||||
_ => true,
|
||||
}))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
|
||||
this: &ast::Instruction<T>,
|
||||
) -> Option<SpirvWord> {
|
||||
match this {
|
||||
ast::Instruction::Bra { arguments } => Some(arguments.src),
|
||||
_ => None,
|
||||
}
|
||||
}
|
44
ptx/src/pass/normalize_predicates.rs
Normal file
44
ptx/src/pass/normalize_predicates.rs
Normal file
|
@ -0,0 +1,44 @@
|
|||
use super::*;
|
||||
use ptx_parser as ast;
|
||||
|
||||
pub(crate) fn run(
|
||||
func: Vec<NormalizedStatement>,
|
||||
id_def: &mut NumericIdResolver,
|
||||
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
for s in func {
|
||||
match s {
|
||||
Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
Statement::Instruction((pred, inst)) => {
|
||||
if let Some(pred) = pred {
|
||||
let if_true = id_def.register_intermediate(None);
|
||||
let if_false = id_def.register_intermediate(None);
|
||||
let folded_bra = match &inst {
|
||||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
|
||||
_ => None,
|
||||
};
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
result.push(Statement::Instruction(inst));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
result.push(Statement::Instruction(inst));
|
||||
}
|
||||
}
|
||||
Statement::Variable(var) => result.push(Statement::Variable(var)),
|
||||
// Blocks are flattened when resolving ids
|
||||
_ => return Err(error_unreachable()),
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
|
@ -7,20 +7,24 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%21 = OpExtInstImport "OpenCL.std"
|
||||
OpCapability DenormFlushToZero
|
||||
OpExtension "SPV_KHR_float_controls"
|
||||
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||
%22 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "clz"
|
||||
OpExecutionMode %1 ContractionOff
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%24 = OpTypeFunction %void %ulong %ulong
|
||||
%25 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%1 = OpFunction %void None %24
|
||||
%1 = OpFunction %void None %25
|
||||
%7 = OpFunctionParameter %ulong
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%19 = OpLabel
|
||||
%20 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
|
@ -37,11 +41,12 @@
|
|||
%11 = OpLoad %uint %17 Aligned 4
|
||||
OpStore %6 %11
|
||||
%14 = OpLoad %uint %6
|
||||
%13 = OpExtInst %uint %21 clz %14
|
||||
%18 = OpExtInst %uint %22 clz %14
|
||||
%13 = OpCopyObject %uint %18
|
||||
OpStore %6 %13
|
||||
%15 = OpLoad %ulong %5
|
||||
%16 = OpLoad %uint %6
|
||||
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
OpStore %18 %16 Aligned 4
|
||||
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
OpStore %19 %16 Aligned 4
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -7,6 +7,9 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
OpCapability DenormFlushToZero
|
||||
OpExtension "SPV_KHR_float_controls"
|
||||
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||
%24 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "cvt_s16_s8"
|
||||
|
@ -45,9 +48,7 @@
|
|||
%32 = OpBitcast %uint %15
|
||||
%34 = OpUConvert %uchar %32
|
||||
%20 = OpCopyObject %uchar %34
|
||||
%35 = OpBitcast %uchar %20
|
||||
%37 = OpSConvert %ushort %35
|
||||
%19 = OpCopyObject %ushort %37
|
||||
%19 = OpSConvert %ushort %20
|
||||
%14 = OpSConvert %uint %19
|
||||
OpStore %6 %14
|
||||
%16 = OpLoad %ulong %5
|
||||
|
|
|
@ -7,9 +7,13 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
OpCapability DenormFlushToZero
|
||||
OpExtension "SPV_KHR_float_controls"
|
||||
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||
%24 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "cvt_s64_s32"
|
||||
OpExecutionMode %1 ContractionOff
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%27 = OpTypeFunction %void %ulong %ulong
|
||||
|
@ -40,9 +44,7 @@
|
|||
%12 = OpCopyObject %uint %18
|
||||
OpStore %6 %12
|
||||
%15 = OpLoad %uint %6
|
||||
%32 = OpBitcast %uint %15
|
||||
%33 = OpSConvert %ulong %32
|
||||
%14 = OpCopyObject %ulong %33
|
||||
%14 = OpSConvert %ulong %15
|
||||
OpStore %7 %14
|
||||
%16 = OpLoad %ulong %5
|
||||
%17 = OpLoad %ulong %7
|
||||
|
|
|
@ -7,9 +7,13 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
OpCapability DenormFlushToZero
|
||||
OpExtension "SPV_KHR_float_controls"
|
||||
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||
%25 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "cvt_sat_s_u"
|
||||
OpExecutionMode %1 ContractionOff
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%28 = OpTypeFunction %void %ulong %ulong
|
||||
|
@ -42,7 +46,7 @@
|
|||
%15 = OpSatConvertSToU %uint %16
|
||||
OpStore %7 %15
|
||||
%18 = OpLoad %uint %7
|
||||
%17 = OpBitcast %uint %18
|
||||
%17 = OpCopyObject %uint %18
|
||||
OpStore %8 %17
|
||||
%19 = OpLoad %ulong %5
|
||||
%20 = OpLoad %uint %8
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::pass;
|
||||
use crate::ptx;
|
||||
use crate::translate;
|
||||
use hip_runtime_sys::hipError_t;
|
||||
|
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
|
|||
spirv_txt: &'a [u8],
|
||||
spirv_file_name: &'a str,
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
|
||||
assert!(errors.len() == 0);
|
||||
let spirv_module = translate::to_spirv_module(ast)?;
|
||||
let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap();
|
||||
let spirv_module = pass::to_spirv_module(ast)?;
|
||||
let spv_context =
|
||||
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
|
||||
assert!(spv_context != ptr::null_mut());
|
||||
|
|
|
@ -7,20 +7,24 @@
|
|||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%21 = OpExtInstImport "OpenCL.std"
|
||||
OpCapability DenormFlushToZero
|
||||
OpExtension "SPV_KHR_float_controls"
|
||||
OpExtension "SPV_KHR_no_integer_wrap_decoration"
|
||||
%22 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %1 "popc"
|
||||
OpExecutionMode %1 ContractionOff
|
||||
%void = OpTypeVoid
|
||||
%ulong = OpTypeInt 64 0
|
||||
%24 = OpTypeFunction %void %ulong %ulong
|
||||
%25 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Generic_uint = OpTypePointer Generic %uint
|
||||
%1 = OpFunction %void None %24
|
||||
%1 = OpFunction %void None %25
|
||||
%7 = OpFunctionParameter %ulong
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%19 = OpLabel
|
||||
%20 = OpLabel
|
||||
%2 = OpVariable %_ptr_Function_ulong Function
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
|
@ -37,11 +41,12 @@
|
|||
%11 = OpLoad %uint %17 Aligned 4
|
||||
OpStore %6 %11
|
||||
%14 = OpLoad %uint %6
|
||||
%13 = OpBitCount %uint %14
|
||||
%18 = OpBitCount %uint %14
|
||||
%13 = OpCopyObject %uint %18
|
||||
OpStore %6 %13
|
||||
%15 = OpLoad %ulong %5
|
||||
%16 = OpLoad %uint %6
|
||||
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
OpStore %18 %16 Aligned 4
|
||||
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
|
||||
OpStore %19 %16 Aligned 4
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Excersise as many features of vector types as possible
|
||||
// Exercise as many features of vector types as possible
|
||||
|
||||
.version 6.5
|
||||
.target sm_60
|
||||
|
|
|
@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
|
|||
for statement in sorted_statements {
|
||||
match statement {
|
||||
Statement::Variable(
|
||||
var
|
||||
@
|
||||
ast::Variable {
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Shared,
|
||||
..
|
||||
},
|
||||
)
|
||||
| Statement::Variable(
|
||||
var
|
||||
@
|
||||
ast::Variable {
|
||||
var @ ast::Variable {
|
||||
state_space: ast::StateSpace::Global,
|
||||
..
|
||||
},
|
||||
|
@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
|
|||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
details @ ast::AtomDetails {
|
||||
inner:
|
||||
ast::AtomInnerDetails::Unsigned {
|
||||
op: ast::AtomUIntOp::Inc,
|
||||
|
@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
|
|||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
details @ ast::AtomDetails {
|
||||
inner:
|
||||
ast::AtomInnerDetails::Unsigned {
|
||||
op: ast::AtomUIntOp::Dec,
|
||||
|
@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
|
|||
)?);
|
||||
}
|
||||
Statement::Instruction(ast::Instruction::Atom(
|
||||
details
|
||||
@
|
||||
ast::AtomDetails {
|
||||
details @ ast::AtomDetails {
|
||||
inner:
|
||||
ast::AtomInnerDetails::Float {
|
||||
op: ast::AtomFloatOp::Add,
|
||||
|
|
17
ptx_parser/Cargo.toml
Normal file
17
ptx_parser/Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "ptx_parser"
|
||||
version = "0.0.0"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
|
||||
[dependencies]
|
||||
logos = "0.14"
|
||||
winnow = { version = "0.6.18" }
|
||||
#winnow = { version = "0.6.18", features = ["debug"] }
|
||||
ptx_parser_macros = { path = "../ptx_parser_macros" }
|
||||
thiserror = "1.0"
|
||||
bitflags = "1.2"
|
||||
rustc-hash = "2.0.0"
|
||||
derive_more = { version = "1", features = ["display"] }
|
1695
ptx_parser/src/ast.rs
Normal file
1695
ptx_parser/src/ast.rs
Normal file
File diff suppressed because it is too large
Load diff
69
ptx_parser/src/check_args.py
Normal file
69
ptx_parser/src/check_args.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import os, sys, subprocess
|
||||
|
||||
|
||||
SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
|
||||
TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
|
||||
MULTIVAR = ["", "<1>" ]
|
||||
VECTOR = ["", ".v2" ]
|
||||
|
||||
HEADER = """
|
||||
.version 8.5
|
||||
.target sm_90
|
||||
.address_size 64
|
||||
"""
|
||||
|
||||
|
||||
def directive(space, variable, multivar, vector):
|
||||
return """{3}
|
||||
{0} {4} .b32 variable{2} {1};
|
||||
""".format(space, variable, multivar, HEADER, vector)
|
||||
|
||||
def entry_arg(space, variable, multivar, vector):
|
||||
return """{3}
|
||||
.entry foobar ( {0} {4} .b32 variable{2} {1})
|
||||
{{
|
||||
ret;
|
||||
}}
|
||||
""".format(space, variable, multivar, HEADER, vector)
|
||||
|
||||
|
||||
def fn_arg(space, variable, multivar, vector):
|
||||
return """{3}
|
||||
.func foobar ( {0} {4} .b32 variable{2} {1})
|
||||
{{
|
||||
ret;
|
||||
}}
|
||||
""".format(space, variable, multivar, HEADER, vector)
|
||||
|
||||
|
||||
def fn_body(space, variable, multivar, vector):
|
||||
return """{3}
|
||||
.func foobar ()
|
||||
{{
|
||||
{0} {4} .b32 variable{2} {1};
|
||||
ret;
|
||||
}}
|
||||
""".format(space, variable, multivar, HEADER, vector)
|
||||
|
||||
|
||||
def generate(generator):
|
||||
legal = []
|
||||
for space in SPACE:
|
||||
for init in TYPE_AND_INIT:
|
||||
for multi in MULTIVAR:
|
||||
for vector in VECTOR:
|
||||
ptx = generator(space, init, multi, vector)
|
||||
if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
|
||||
legal.append((space, vector, init, multi))
|
||||
print(generator.__name__)
|
||||
print(legal)
|
||||
|
||||
|
||||
def main():
|
||||
generate(directive)
|
||||
generate(entry_arg)
|
||||
generate(fn_arg)
|
||||
generate(fn_body)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
3269
ptx_parser/src/lib.rs
Normal file
3269
ptx_parser/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
17
ptx_parser_macros/Cargo.toml
Normal file
17
ptx_parser_macros/Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "ptx_parser_macros"
|
||||
version = "0.0.0"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" }
|
||||
convert_case = "0.6.0"
|
||||
rustc-hash = "2.0.0"
|
||||
syn = "2.0.67"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0.86"
|
||||
either = "1.13.0"
|
1023
ptx_parser_macros/src/lib.rs
Normal file
1023
ptx_parser_macros/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
13
ptx_parser_macros_impl/Cargo.toml
Normal file
13
ptx_parser_macros_impl/Cargo.toml
Normal file
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "ptx_parser_macros_impl"
|
||||
version = "0.0.0"
|
||||
authors = ["Andrzej Janik <vosen@vosen.pl>"]
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
|
||||
[dependencies]
|
||||
syn = { version = "2.0.67", features = ["extra-traits", "full"] }
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0.86"
|
||||
rustc-hash = "2.0.0"
|
881
ptx_parser_macros_impl/src/lib.rs
Normal file
881
ptx_parser_macros_impl/src/lib.rs
Normal file
|
@ -0,0 +1,881 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use syn::{
|
||||
braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token,
|
||||
Type, TypeParam, Visibility,
|
||||
};
|
||||
|
||||
pub mod parser;
|
||||
|
||||
pub struct GenerateInstructionType {
|
||||
pub visibility: Option<Visibility>,
|
||||
pub name: Ident,
|
||||
pub type_parameters: Punctuated<TypeParam, Token![,]>,
|
||||
pub short_parameters: Punctuated<Ident, Token![,]>,
|
||||
pub variants: Punctuated<InstructionVariant, Token![,]>,
|
||||
}
|
||||
|
||||
impl GenerateInstructionType {
|
||||
pub fn emit_arg_types(&self, tokens: &mut TokenStream) {
|
||||
for v in self.variants.iter() {
|
||||
v.emit_type(&self.visibility, tokens);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn emit_instruction_type(&self, tokens: &mut TokenStream) {
|
||||
let vis = &self.visibility;
|
||||
let type_name = &self.name;
|
||||
let type_parameters = &self.type_parameters;
|
||||
let variants = self.variants.iter().map(|v| v.emit_variant());
|
||||
quote! {
|
||||
#vis enum #type_name<#type_parameters> {
|
||||
#(#variants),*
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
|
||||
pub fn emit_visit(&self, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
|
||||
}
|
||||
|
||||
pub fn emit_visit_mut(&self, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(
|
||||
VisitKind::RefMut,
|
||||
tokens,
|
||||
InstructionVariant::emit_visit_mut,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn emit_visit_map(&self, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map)
|
||||
}
|
||||
|
||||
fn emit_visit_impl(
|
||||
&self,
|
||||
kind: VisitKind,
|
||||
tokens: &mut TokenStream,
|
||||
mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream),
|
||||
) {
|
||||
let type_name = &self.name;
|
||||
let type_parameters = &self.type_parameters;
|
||||
let short_parameters = &self.short_parameters;
|
||||
let mut inner_tokens = TokenStream::new();
|
||||
for v in self.variants.iter() {
|
||||
fn_(v, type_name, &mut inner_tokens);
|
||||
}
|
||||
let visit_ref = kind.reference();
|
||||
let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
|
||||
let visit_fn = format_ident!("visit{}", kind.fn_suffix());
|
||||
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
|
||||
(
|
||||
quote! { <#type_parameters, To: Operand, Err> },
|
||||
quote! { <#short_parameters, To, Err> },
|
||||
quote! { std::result::Result<#type_name<To>, Err> },
|
||||
)
|
||||
} else {
|
||||
(
|
||||
quote! { <#type_parameters, Err> },
|
||||
quote! { <#short_parameters, Err> },
|
||||
quote! { std::result::Result<(), Err> },
|
||||
)
|
||||
};
|
||||
quote! {
|
||||
pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
|
||||
Ok(match i {
|
||||
#inner_tokens
|
||||
})
|
||||
}
|
||||
}.to_tokens(tokens);
|
||||
if kind == VisitKind::Map {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
enum VisitKind {
|
||||
Ref,
|
||||
RefMut,
|
||||
Map,
|
||||
}
|
||||
|
||||
impl VisitKind {
|
||||
fn fn_suffix(self) -> &'static str {
|
||||
match self {
|
||||
VisitKind::Ref => "",
|
||||
VisitKind::RefMut => "_mut",
|
||||
VisitKind::Map => "_map",
|
||||
}
|
||||
}
|
||||
|
||||
fn type_suffix(self) -> &'static str {
|
||||
match self {
|
||||
VisitKind::Ref => "",
|
||||
VisitKind::RefMut => "Mut",
|
||||
VisitKind::Map => "Map",
|
||||
}
|
||||
}
|
||||
|
||||
fn reference(self) -> Option<proc_macro2::TokenStream> {
|
||||
match self {
|
||||
VisitKind::Ref => Some(quote! { & }),
|
||||
VisitKind::RefMut => Some(quote! { &mut }),
|
||||
VisitKind::Map => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for GenerateInstructionType {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let visibility = if !input.peek(Token![enum]) {
|
||||
Some(input.parse::<Visibility>()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
input.parse::<Token![enum]>()?;
|
||||
let name = input.parse::<Ident>()?;
|
||||
input.parse::<Token![<]>()?;
|
||||
let type_parameters = Punctuated::parse_separated_nonempty(input)?;
|
||||
let short_parameters = type_parameters
|
||||
.iter()
|
||||
.map(|p: &TypeParam| p.ident.clone())
|
||||
.collect();
|
||||
input.parse::<Token![>]>()?;
|
||||
let variants_buffer;
|
||||
braced!(variants_buffer in input);
|
||||
let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?;
|
||||
Ok(Self {
|
||||
visibility,
|
||||
name,
|
||||
type_parameters,
|
||||
short_parameters,
|
||||
variants,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InstructionVariant {
|
||||
pub name: Ident,
|
||||
pub type_: Option<Option<Expr>>,
|
||||
pub space: Option<Expr>,
|
||||
pub data: Option<Type>,
|
||||
pub arguments: Option<Arguments>,
|
||||
pub visit: Option<Expr>,
|
||||
pub visit_mut: Option<Expr>,
|
||||
pub map: Option<Expr>,
|
||||
}
|
||||
|
||||
impl InstructionVariant {
|
||||
fn args_name(&self) -> Ident {
|
||||
format_ident!("{}Args", self.name)
|
||||
}
|
||||
|
||||
fn emit_variant(&self) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let data = match &self.data {
|
||||
None => {
|
||||
quote! {}
|
||||
}
|
||||
Some(data_type) => {
|
||||
quote! {
|
||||
data: #data_type,
|
||||
}
|
||||
}
|
||||
};
|
||||
let arguments = match &self.arguments {
|
||||
None => {
|
||||
quote! {}
|
||||
}
|
||||
Some(args) => {
|
||||
let args_name = self.args_name();
|
||||
match &args {
|
||||
Arguments::Def(InstructionArguments { generic: None, .. }) => {
|
||||
quote! {
|
||||
arguments: #args_name,
|
||||
}
|
||||
}
|
||||
Arguments::Def(InstructionArguments {
|
||||
generic: Some(generics),
|
||||
..
|
||||
}) => {
|
||||
quote! {
|
||||
arguments: #args_name <#generics>,
|
||||
}
|
||||
}
|
||||
Arguments::Decl(type_) => quote! {
|
||||
arguments: #type_,
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
quote! {
|
||||
#name { #data #arguments }
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
|
||||
}
|
||||
|
||||
fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(
|
||||
&self.visit_mut,
|
||||
enum_,
|
||||
tokens,
|
||||
InstructionArguments::emit_visit_mut,
|
||||
)
|
||||
}
|
||||
|
||||
fn emit_visit_impl(
|
||||
&self,
|
||||
visit_fn: &Option<Expr>,
|
||||
enum_: &Ident,
|
||||
tokens: &mut TokenStream,
|
||||
mut fn_: impl FnMut(&InstructionArguments, &Option<Option<Expr>>, &Option<Expr>) -> TokenStream,
|
||||
) {
|
||||
let name = &self.name;
|
||||
let arguments = match &self.arguments {
|
||||
None => {
|
||||
quote! {
|
||||
#enum_ :: #name { .. } => { }
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
return;
|
||||
}
|
||||
Some(Arguments::Decl(_)) => {
|
||||
quote! {
|
||||
#enum_ :: #name { data, arguments } => { #visit_fn }
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
return;
|
||||
}
|
||||
Some(Arguments::Def(args)) => args,
|
||||
};
|
||||
let data = &self.data.as_ref().map(|_| quote! { data,});
|
||||
let arg_calls = fn_(arguments, &self.type_, &self.space);
|
||||
quote! {
|
||||
#enum_ :: #name { #data arguments } => {
|
||||
#arg_calls
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
|
||||
fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||
let name = &self.name;
|
||||
let data = &self.data.as_ref().map(|_| quote! { data,});
|
||||
let arguments = match self.arguments {
|
||||
None => None,
|
||||
Some(Arguments::Decl(_)) => {
|
||||
let map = self.map.as_ref().unwrap();
|
||||
quote! {
|
||||
#enum_ :: #name { #data arguments } => {
|
||||
#map
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
return;
|
||||
}
|
||||
Some(Arguments::Def(ref def)) => Some(def),
|
||||
};
|
||||
let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,});
|
||||
let mut arg_calls = None;
|
||||
let arguments_init = arguments.as_ref().map(|arguments| {
|
||||
let arg_type = self.args_name();
|
||||
arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space));
|
||||
let arg_names = arguments.fields.iter().map(|arg| &arg.name);
|
||||
quote! {
|
||||
arguments: #arg_type { #(#arg_names),* }
|
||||
}
|
||||
});
|
||||
quote! {
|
||||
#enum_ :: #name { #data #arguments_ident } => {
|
||||
#arg_calls
|
||||
#enum_ :: #name { #data #arguments_init }
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
|
||||
fn emit_type(&self, vis: &Option<Visibility>, tokens: &mut TokenStream) {
|
||||
let arguments = match self.arguments {
|
||||
Some(Arguments::Def(ref a)) => a,
|
||||
Some(Arguments::Decl(_)) => return,
|
||||
None => return,
|
||||
};
|
||||
let name = self.args_name();
|
||||
let type_parameters = if arguments.generic.is_some() {
|
||||
Some(quote! { <T> })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let fields = arguments.fields.iter().map(|f| f.emit_field(vis));
|
||||
quote! {
|
||||
#vis struct #name #type_parameters {
|
||||
#(#fields),*
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for InstructionVariant {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let name = input.parse::<Ident>()?;
|
||||
let properties_buffer;
|
||||
braced!(properties_buffer in input);
|
||||
let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?;
|
||||
let mut type_ = None;
|
||||
let mut space = None;
|
||||
let mut data = None;
|
||||
let mut arguments = None;
|
||||
let mut visit = None;
|
||||
let mut visit_mut = None;
|
||||
let mut map = None;
|
||||
for property in properties {
|
||||
match property {
|
||||
VariantProperty::Type(t) => type_ = Some(t),
|
||||
VariantProperty::Space(s) => space = Some(s),
|
||||
VariantProperty::Data(d) => data = Some(d),
|
||||
VariantProperty::Arguments(a) => arguments = Some(a),
|
||||
VariantProperty::Visit(e) => visit = Some(e),
|
||||
VariantProperty::VisitMut(e) => visit_mut = Some(e),
|
||||
VariantProperty::Map(e) => map = Some(e),
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
name,
|
||||
type_,
|
||||
space,
|
||||
data,
|
||||
arguments,
|
||||
visit,
|
||||
visit_mut,
|
||||
map,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum VariantProperty {
|
||||
Type(Option<Expr>),
|
||||
Space(Expr),
|
||||
Data(Type),
|
||||
Arguments(Arguments),
|
||||
Visit(Expr),
|
||||
VisitMut(Expr),
|
||||
Map(Expr),
|
||||
}
|
||||
|
||||
impl VariantProperty {
|
||||
pub fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let lookahead = input.lookahead1();
|
||||
Ok(if lookahead.peek(Token![type]) {
|
||||
input.parse::<Token![type]>()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Type(if input.peek(Token![!]) {
|
||||
input.parse::<Token![!]>()?;
|
||||
None
|
||||
} else {
|
||||
Some(input.parse::<Expr>()?)
|
||||
})
|
||||
} else if lookahead.peek(Ident) {
|
||||
let key = input.parse::<Ident>()?;
|
||||
match &*key.to_string() {
|
||||
"data" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Data(input.parse::<Type>()?)
|
||||
}
|
||||
"space" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Space(input.parse::<Expr>()?)
|
||||
}
|
||||
"arguments" => {
|
||||
let generics = if input.peek(Token![<]) {
|
||||
input.parse::<Token![<]>()?;
|
||||
let gen_params =
|
||||
Punctuated::<PathSegment, syn::token::PathSep>::parse_separated_nonempty(input)?;
|
||||
input.parse::<Token![>]>()?;
|
||||
Some(gen_params)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
input.parse::<Token![:]>()?;
|
||||
if input.peek(token::Brace) {
|
||||
let fields;
|
||||
braced!(fields in input);
|
||||
VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse(
|
||||
generics, &fields,
|
||||
)?))
|
||||
} else {
|
||||
VariantProperty::Arguments(Arguments::Decl(input.parse::<Type>()?))
|
||||
}
|
||||
}
|
||||
"visit" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Visit(input.parse::<Expr>()?)
|
||||
}
|
||||
"visit_mut" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::VisitMut(input.parse::<Expr>()?)
|
||||
}
|
||||
"map" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Map(input.parse::<Expr>()?)
|
||||
}
|
||||
x => {
|
||||
return Err(syn::Error::new(
|
||||
key.span(),
|
||||
format!(
|
||||
"Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.",
|
||||
x
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Arguments {
|
||||
Decl(Type),
|
||||
Def(InstructionArguments),
|
||||
}
|
||||
|
||||
pub struct InstructionArguments {
|
||||
pub generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
|
||||
pub fields: Punctuated<ArgumentField, Token![,]>,
|
||||
}
|
||||
|
||||
impl InstructionArguments {
|
||||
pub fn parse(
|
||||
generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
|
||||
input: syn::parse::ParseStream,
|
||||
) -> syn::Result<Self> {
|
||||
let fields = Punctuated::<ArgumentField, Token![,]>::parse_terminated_with(
|
||||
input,
|
||||
ArgumentField::parse,
|
||||
)?;
|
||||
Ok(Self { generic, fields })
|
||||
}
|
||||
|
||||
fn emit_visit(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
) -> TokenStream {
|
||||
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit)
|
||||
}
|
||||
|
||||
fn emit_visit_mut(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
) -> TokenStream {
|
||||
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut)
|
||||
}
|
||||
|
||||
fn emit_visit_map(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
) -> TokenStream {
|
||||
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map)
|
||||
}
|
||||
|
||||
fn emit_visit_impl(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
mut fn_: impl FnMut(&ArgumentField, &Option<Option<Expr>>, &Option<Expr>, bool) -> TokenStream,
|
||||
) -> TokenStream {
|
||||
let is_ident = if let Some(ref generic) = self.generic {
|
||||
generic.len() > 1
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let field_calls = self
|
||||
.fields
|
||||
.iter()
|
||||
.map(|f| fn_(f, parent_type, parent_space, is_ident));
|
||||
quote! {
|
||||
#(#field_calls)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ArgumentField {
|
||||
pub name: Ident,
|
||||
pub is_dst: bool,
|
||||
pub repr: Type,
|
||||
pub space: Option<Expr>,
|
||||
pub type_: Option<Expr>,
|
||||
pub relaxed_type_check: bool,
|
||||
}
|
||||
|
||||
impl ArgumentField {
|
||||
fn parse_block(
|
||||
input: syn::parse::ParseStream,
|
||||
) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>, bool)> {
|
||||
let content;
|
||||
braced!(content in input);
|
||||
let all_fields =
|
||||
Punctuated::<ExprOrPath, Token![,]>::parse_terminated_with(&content, |content| {
|
||||
let lookahead = content.lookahead1();
|
||||
Ok(if lookahead.peek(Token![type]) {
|
||||
content.parse::<Token![type]>()?;
|
||||
content.parse::<Token![:]>()?;
|
||||
ExprOrPath::Type(content.parse::<Expr>()?)
|
||||
} else if lookahead.peek(Ident) {
|
||||
let name_ident = content.parse::<Ident>()?;
|
||||
content.parse::<Token![:]>()?;
|
||||
match &*name_ident.to_string() {
|
||||
"relaxed_type_check" => {
|
||||
ExprOrPath::RelaxedTypeCheck(content.parse::<LitBool>()?.value)
|
||||
}
|
||||
"repr" => ExprOrPath::Repr(content.parse::<Type>()?),
|
||||
"space" => ExprOrPath::Space(content.parse::<Expr>()?),
|
||||
"dst" => {
|
||||
let ident = content.parse::<LitBool>()?;
|
||||
ExprOrPath::Dst(ident.value)
|
||||
}
|
||||
name => {
|
||||
return Err(syn::Error::new(
|
||||
name_ident.span(),
|
||||
format!("Unexpected key `{}`, expected `repr` or `space", name),
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
})
|
||||
})?;
|
||||
let mut repr = None;
|
||||
let mut type_ = None;
|
||||
let mut space = None;
|
||||
let mut is_dst = None;
|
||||
let mut relaxed_type_check = false;
|
||||
for exp_or_path in all_fields {
|
||||
match exp_or_path {
|
||||
ExprOrPath::Repr(r) => repr = Some(r),
|
||||
ExprOrPath::Type(t) => type_ = Some(t),
|
||||
ExprOrPath::Space(s) => space = Some(s),
|
||||
ExprOrPath::Dst(x) => is_dst = Some(x),
|
||||
ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed,
|
||||
}
|
||||
}
|
||||
Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check))
|
||||
}
|
||||
|
||||
fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
|
||||
input.parse::<Type>()
|
||||
}
|
||||
|
||||
fn emit_visit(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
is_ident: bool,
|
||||
) -> TokenStream {
|
||||
self.emit_visit_impl(parent_type, parent_space, is_ident, false)
|
||||
}
|
||||
|
||||
fn emit_visit_mut(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
is_ident: bool,
|
||||
) -> TokenStream {
|
||||
self.emit_visit_impl(parent_type, parent_space, is_ident, true)
|
||||
}
|
||||
|
||||
fn emit_visit_impl(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
is_ident: bool,
|
||||
is_mut: bool,
|
||||
) -> TokenStream {
|
||||
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
|
||||
(Some(type_), _) => (false, Some(type_)),
|
||||
(None, None) => panic!("No type set"),
|
||||
(None, Some(None)) => (true, None),
|
||||
(None, Some(Some(type_))) => (false, Some(type_)),
|
||||
};
|
||||
let space = self
|
||||
.space
|
||||
.as_ref()
|
||||
.or(parent_space.as_ref())
|
||||
.map(|space| quote! { #space })
|
||||
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
||||
let is_dst = self.is_dst;
|
||||
let relaxed_type_check = self.relaxed_type_check;
|
||||
let name = &self.name;
|
||||
let type_space = if is_typeless {
|
||||
quote! {
|
||||
let type_space = None;
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let type_ = #type_;
|
||||
let space = #space;
|
||||
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
|
||||
}
|
||||
};
|
||||
if is_ident {
|
||||
if is_mut {
|
||||
quote! {
|
||||
{
|
||||
#type_space
|
||||
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
{
|
||||
#type_space
|
||||
visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let (operand_fn, arguments_name) = if is_mut {
|
||||
(
|
||||
quote! {
|
||||
VisitOperand::visit_mut
|
||||
},
|
||||
quote! {
|
||||
&mut arguments.#name
|
||||
},
|
||||
)
|
||||
} else {
|
||||
(
|
||||
quote! {
|
||||
VisitOperand::visit
|
||||
},
|
||||
quote! {
|
||||
& arguments.#name
|
||||
},
|
||||
)
|
||||
};
|
||||
quote! {{
|
||||
#type_space
|
||||
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?;
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_visit_map(
|
||||
&self,
|
||||
parent_type: &Option<Option<Expr>>,
|
||||
parent_space: &Option<Expr>,
|
||||
is_ident: bool,
|
||||
) -> TokenStream {
|
||||
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
|
||||
(Some(type_), _) => (false, Some(type_)),
|
||||
(None, None) => panic!("No type set"),
|
||||
(None, Some(None)) => (true, None),
|
||||
(None, Some(Some(type_))) => (false, Some(type_)),
|
||||
};
|
||||
let space = self
|
||||
.space
|
||||
.as_ref()
|
||||
.or(parent_space.as_ref())
|
||||
.map(|space| quote! { #space })
|
||||
.unwrap_or_else(|| quote! { StateSpace::Reg });
|
||||
let is_dst = self.is_dst;
|
||||
let relaxed_type_check = self.relaxed_type_check;
|
||||
let name = &self.name;
|
||||
let type_space = if is_typeless {
|
||||
quote! {
|
||||
let type_space = None;
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let type_ = #type_;
|
||||
let space = #space;
|
||||
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
|
||||
}
|
||||
};
|
||||
let map_call = if is_ident {
|
||||
quote! {
|
||||
visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)?
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?
|
||||
}
|
||||
};
|
||||
quote! {
|
||||
let #name = {
|
||||
#type_space
|
||||
#map_call
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn is_dst(name: &Ident) -> syn::Result<bool> {
|
||||
if name.to_string().starts_with("dst") {
|
||||
Ok(true)
|
||||
} else if name.to_string().starts_with("src") {
|
||||
Ok(false)
|
||||
} else {
|
||||
return Err(syn::Error::new(
|
||||
name.span(),
|
||||
format!(
|
||||
"Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`",
|
||||
name
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_field(&self, vis: &Option<Visibility>) -> TokenStream {
|
||||
let name = &self.name;
|
||||
let type_ = &self.repr;
|
||||
quote! {
|
||||
#vis #name: #type_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for ArgumentField {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let name = input.parse::<Ident>()?;
|
||||
|
||||
input.parse::<Token![:]>()?;
|
||||
let lookahead = input.lookahead1();
|
||||
let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) {
|
||||
Self::parse_block(input)?
|
||||
} else if lookahead.peek(syn::Ident) {
|
||||
(Self::parse_basic(input)?, None, None, None, false)
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
};
|
||||
let is_dst = match is_dst {
|
||||
Some(x) => x,
|
||||
None => Self::is_dst(&name)?,
|
||||
};
|
||||
Ok(Self {
|
||||
name,
|
||||
is_dst,
|
||||
repr,
|
||||
type_,
|
||||
space,
|
||||
relaxed_type_check
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum ExprOrPath {
|
||||
Repr(Type),
|
||||
Type(Expr),
|
||||
Space(Expr),
|
||||
Dst(bool),
|
||||
RelaxedTypeCheck(bool),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use proc_macro2::Span;
|
||||
use quote::{quote, ToTokens};
|
||||
|
||||
fn to_string(x: impl ToTokens) -> String {
|
||||
quote! { #x }.to_string()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_argument_field_basic() {
|
||||
let input = quote! {
|
||||
dst: P::Operand
|
||||
};
|
||||
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||
assert_eq!("dst", arg.name.to_string());
|
||||
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||
assert!(matches!(arg.type_, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_argument_field_block() {
|
||||
let input = quote! {
|
||||
dst: {
|
||||
type: ScalarType::U32,
|
||||
space: StateSpace::Global,
|
||||
repr: P::Operand,
|
||||
}
|
||||
};
|
||||
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||
assert_eq!("dst", arg.name.to_string());
|
||||
assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap()));
|
||||
assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap()));
|
||||
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_argument_field_block_untyped() {
|
||||
let input = quote! {
|
||||
dst: {
|
||||
repr: P::Operand,
|
||||
}
|
||||
};
|
||||
let arg = syn::parse2::<ArgumentField>(input).unwrap();
|
||||
assert_eq!("dst", arg.name.to_string());
|
||||
assert_eq!("P :: Operand", to_string(arg.repr));
|
||||
assert!(matches!(arg.type_, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_variant_complex() {
|
||||
let input = quote! {
|
||||
Ld {
|
||||
type: ScalarType::U32,
|
||||
space: StateSpace::Global,
|
||||
data: LdDetails,
|
||||
arguments<P>: {
|
||||
dst: {
|
||||
repr: P::Operand,
|
||||
type: ScalarType::U32,
|
||||
space: StateSpace::Shared,
|
||||
},
|
||||
src: P::Operand,
|
||||
},
|
||||
}
|
||||
};
|
||||
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
|
||||
assert_eq!("Ld", variant.name.to_string());
|
||||
assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
|
||||
assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
|
||||
assert_eq!("LdDetails", to_string(variant.data.unwrap()));
|
||||
let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
|
||||
a
|
||||
} else {
|
||||
panic!()
|
||||
};
|
||||
assert_eq!("P", to_string(arguments.generic));
|
||||
let mut fields = arguments.fields.into_iter();
|
||||
let dst = fields.next().unwrap();
|
||||
assert_eq!("P :: Operand", to_string(dst.repr));
|
||||
assert_eq!("ScalarType :: U32", to_string(dst.type_));
|
||||
assert_eq!("StateSpace :: Shared", to_string(dst.space));
|
||||
let src = fields.next().unwrap();
|
||||
assert_eq!("P :: Operand", to_string(src.repr));
|
||||
assert!(matches!(src.type_, None));
|
||||
assert!(matches!(src.space, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn visit_variant_empty() {
|
||||
let input = quote! {
|
||||
Ret {
|
||||
data: RetData
|
||||
}
|
||||
};
|
||||
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
|
||||
let mut output = TokenStream::new();
|
||||
variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output);
|
||||
assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }");
|
||||
}
|
||||
}
|
844
ptx_parser_macros_impl/src/parser.rs
Normal file
844
ptx_parser_macros_impl/src/parser.rs
Normal file
|
@ -0,0 +1,844 @@
|
|||
use proc_macro2::Span;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use quote::ToTokens;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::fmt::Write;
|
||||
use syn::bracketed;
|
||||
use syn::parse::Peek;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::LitInt;
|
||||
use syn::Type;
|
||||
use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token};
|
||||
|
||||
pub struct ParseDefinitions {
|
||||
pub token_type: ItemEnum,
|
||||
pub additional_enums: FxHashMap<Ident, ItemEnum>,
|
||||
pub definitions: Vec<OpcodeDefinition>,
|
||||
}
|
||||
|
||||
impl Parse for ParseDefinitions {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let token_type = input.parse::<ItemEnum>()?;
|
||||
let mut additional_enums = FxHashMap::default();
|
||||
while input.peek(Token![#]) {
|
||||
let enum_ = input.parse::<ItemEnum>()?;
|
||||
additional_enums.insert(enum_.ident.clone(), enum_);
|
||||
}
|
||||
let mut definitions = Vec::new();
|
||||
while !input.is_empty() {
|
||||
definitions.push(input.parse::<OpcodeDefinition>()?);
|
||||
}
|
||||
Ok(Self {
|
||||
token_type,
|
||||
additional_enums,
|
||||
definitions,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpcodeDefinition(pub Patterns, pub Vec<Rule>);
|
||||
|
||||
impl Parse for OpcodeDefinition {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let patterns = input.parse::<Patterns>()?;
|
||||
let mut rules = Vec::new();
|
||||
while Rule::peek(input) {
|
||||
rules.push(input.parse::<Rule>()?);
|
||||
input.parse::<Token![;]>()?;
|
||||
}
|
||||
Ok(Self(patterns, rules))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>);
|
||||
|
||||
impl Parse for Patterns {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let mut result = Vec::new();
|
||||
loop {
|
||||
if !OpcodeDecl::peek(input) {
|
||||
break;
|
||||
}
|
||||
let decl = input.parse::<OpcodeDecl>()?;
|
||||
let code_block = input.parse::<CodeBlock>()?;
|
||||
result.push((decl, code_block))
|
||||
}
|
||||
Ok(Self(result))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpcodeDecl(pub Instruction, pub Arguments);
|
||||
|
||||
impl OpcodeDecl {
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
Instruction::peek(input) && !input.peek2(Token![=])
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for OpcodeDecl {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
Ok(Self(
|
||||
input.parse::<Instruction>()?,
|
||||
input.parse::<Arguments>()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodeBlock {
|
||||
pub special: bool,
|
||||
pub code: proc_macro2::Group,
|
||||
}
|
||||
|
||||
impl Parse for CodeBlock {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let lookahead = input.lookahead1();
|
||||
let (special, code) = if lookahead.peek(Token![<]) {
|
||||
input.parse::<Token![<]>()?;
|
||||
input.parse::<Token![=]>()?;
|
||||
//input.parse::<Token![>]>()?;
|
||||
(true, input.parse::<proc_macro2::Group>()?)
|
||||
} else if lookahead.peek(Token![=]) {
|
||||
input.parse::<Token![=]>()?;
|
||||
input.parse::<Token![>]>()?;
|
||||
(false, input.parse::<proc_macro2::Group>()?)
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
};
|
||||
Ok(Self { special, code })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Rule {
|
||||
pub modifier: Option<DotModifier>,
|
||||
pub type_: Option<Type>,
|
||||
pub alternatives: Vec<DotModifier>,
|
||||
}
|
||||
|
||||
impl Rule {
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
DotModifier::peek(input)
|
||||
|| (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>]))
|
||||
}
|
||||
|
||||
fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result<Vec<DotModifier>> {
|
||||
let mut result = Vec::new();
|
||||
Self::parse_with_alternative(input, &mut result)?;
|
||||
loop {
|
||||
if !input.peek(Token![,]) {
|
||||
break;
|
||||
}
|
||||
input.parse::<Token![,]>()?;
|
||||
Self::parse_with_alternative(input, &mut result)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn parse_with_alternative(
|
||||
input: &syn::parse::ParseBuffer,
|
||||
result: &mut Vec<DotModifier>,
|
||||
) -> Result<(), syn::Error> {
|
||||
input.parse::<Token![.]>()?;
|
||||
let part1 = input.parse::<IdentLike>()?;
|
||||
if input.peek(token::Brace) {
|
||||
result.push(DotModifier {
|
||||
part1: part1.clone(),
|
||||
part2: None,
|
||||
});
|
||||
let suffix_content;
|
||||
braced!(suffix_content in input);
|
||||
let suffixes = Punctuated::<IdentOrTypeSuffix, Token![,]>::parse_separated_nonempty(
|
||||
&suffix_content,
|
||||
)?;
|
||||
for part2 in suffixes {
|
||||
result.push(DotModifier {
|
||||
part1: part1.clone(),
|
||||
part2: Some(part2),
|
||||
});
|
||||
}
|
||||
} else if IdentOrTypeSuffix::peek(input) {
|
||||
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
|
||||
result.push(DotModifier { part1, part2 });
|
||||
} else {
|
||||
result.push(DotModifier { part1, part2: None });
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
struct IdentOrTypeSuffix(IdentLike);
|
||||
|
||||
impl IdentOrTypeSuffix {
|
||||
fn span(&self) -> Span {
|
||||
self.0.span()
|
||||
}
|
||||
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
input.peek(Token![::])
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for IdentOrTypeSuffix {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let ident = &self.0;
|
||||
quote! { :: #ident }.to_tokens(tokens)
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for IdentOrTypeSuffix {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
input.parse::<Token![::]>()?;
|
||||
Ok(Self(input.parse::<IdentLike>()?))
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for Rule {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let (modifier, type_) = if DotModifier::peek(input) {
|
||||
let modifier = Some(input.parse::<DotModifier>()?);
|
||||
if input.peek(Token![:]) {
|
||||
input.parse::<Token![:]>()?;
|
||||
(modifier, Some(input.parse::<Type>()?))
|
||||
} else {
|
||||
(modifier, None)
|
||||
}
|
||||
} else {
|
||||
(None, Some(input.parse::<Type>()?))
|
||||
};
|
||||
input.parse::<Token![=]>()?;
|
||||
let content;
|
||||
braced!(content in input);
|
||||
let alternatives = Self::parse_alternatives(&content)?;
|
||||
Ok(Self {
|
||||
modifier,
|
||||
type_,
|
||||
alternatives,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Instruction {
|
||||
pub name: Ident,
|
||||
pub modifiers: Vec<MaybeDotModifier>,
|
||||
}
|
||||
impl Instruction {
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
input.peek(Ident)
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for Instruction {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let instruction = input.parse::<Ident>()?;
|
||||
let mut modifiers = Vec::new();
|
||||
loop {
|
||||
if !MaybeDotModifier::peek(input) {
|
||||
break;
|
||||
}
|
||||
modifiers.push(MaybeDotModifier::parse(input)?);
|
||||
}
|
||||
Ok(Self {
|
||||
name: instruction,
|
||||
modifiers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MaybeDotModifier {
|
||||
pub optional: bool,
|
||||
pub modifier: DotModifier,
|
||||
}
|
||||
|
||||
impl MaybeDotModifier {
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
input.peek(token::Brace) || DotModifier::peek(input)
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for MaybeDotModifier {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
Ok(if input.peek(token::Brace) {
|
||||
let content;
|
||||
braced!(content in input);
|
||||
let modifier = DotModifier::parse(&content)?;
|
||||
Self {
|
||||
modifier,
|
||||
optional: true,
|
||||
}
|
||||
} else {
|
||||
let modifier = DotModifier::parse(input)?;
|
||||
Self {
|
||||
modifier,
|
||||
optional: false,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
pub struct DotModifier {
|
||||
part1: IdentLike,
|
||||
part2: Option<IdentOrTypeSuffix>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DotModifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, ".")?;
|
||||
self.part1.fmt(f)?;
|
||||
if let Some(ref part2) = self.part2 {
|
||||
write!(f, "::")?;
|
||||
part2.0.fmt(f)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for DotModifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(&self, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl DotModifier {
|
||||
pub fn span(&self) -> Span {
|
||||
let part1 = self.part1.span();
|
||||
if let Some(ref part2) = self.part2 {
|
||||
part1.join(part2.span()).unwrap_or(part1)
|
||||
} else {
|
||||
part1
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident(&self) -> Ident {
|
||||
let mut result = String::new();
|
||||
write!(&mut result, "{}", self.part1).unwrap();
|
||||
if let Some(ref part2) = self.part2 {
|
||||
write!(&mut result, "_{}", part2.0).unwrap();
|
||||
} else {
|
||||
match self.part1 {
|
||||
IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
|
||||
IdentLike::Ident(_) | IdentLike::Integer(_) => {}
|
||||
}
|
||||
}
|
||||
Ident::new(&result.to_ascii_lowercase(), self.span())
|
||||
}
|
||||
|
||||
pub fn variant_capitalized(&self) -> Ident {
|
||||
self.capitalized_impl(String::new())
|
||||
}
|
||||
|
||||
pub fn dot_capitalized(&self) -> Ident {
|
||||
self.capitalized_impl("Dot".to_string())
|
||||
}
|
||||
|
||||
fn capitalized_impl(&self, prefix: String) -> Ident {
|
||||
let mut temp = String::new();
|
||||
write!(&mut temp, "{}", &self.part1).unwrap();
|
||||
if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 {
|
||||
write!(&mut temp, "_{}", part2).unwrap();
|
||||
}
|
||||
let mut result = prefix;
|
||||
let mut capitalize = true;
|
||||
for c in temp.chars() {
|
||||
if c == '_' {
|
||||
capitalize = true;
|
||||
continue;
|
||||
}
|
||||
// Special hack to emit `BF16`` instead of `Bf16``
|
||||
let c = if capitalize || c == 'f' && result.ends_with('B') {
|
||||
capitalize = false;
|
||||
c.to_ascii_uppercase()
|
||||
} else {
|
||||
c
|
||||
};
|
||||
result.push(c);
|
||||
}
|
||||
Ident::new(&result, self.span())
|
||||
}
|
||||
|
||||
pub fn tokens(&self) -> TokenStream {
|
||||
let part1 = &self.part1;
|
||||
let part2 = &self.part2;
|
||||
match self.part2 {
|
||||
None => quote! { . #part1 },
|
||||
Some(_) => quote! { . #part1 #part2 },
|
||||
}
|
||||
}
|
||||
|
||||
fn peek(input: syn::parse::ParseStream) -> bool {
|
||||
input.peek(Token![.])
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for DotModifier {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
input.parse::<Token![.]>()?;
|
||||
let part1 = input.parse::<IdentLike>()?;
|
||||
if IdentOrTypeSuffix::peek(input) {
|
||||
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
|
||||
Ok(Self { part1, part2 })
|
||||
} else {
|
||||
Ok(Self { part1, part2: None })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
enum IdentLike {
|
||||
Type(Token![type]),
|
||||
Const(Token![const]),
|
||||
Ident(Ident),
|
||||
Integer(LitInt),
|
||||
}
|
||||
|
||||
impl IdentLike {
|
||||
fn span(&self) -> Span {
|
||||
match self {
|
||||
IdentLike::Type(c) => c.span(),
|
||||
IdentLike::Const(t) => t.span(),
|
||||
IdentLike::Ident(i) => i.span(),
|
||||
IdentLike::Integer(l) => l.span(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for IdentLike {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
IdentLike::Type(_) => f.write_str("type"),
|
||||
IdentLike::Const(_) => f.write_str("const"),
|
||||
IdentLike::Ident(ident) => write!(f, "{}", ident),
|
||||
IdentLike::Integer(integer) => write!(f, "{}", integer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for IdentLike {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
match self {
|
||||
IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
|
||||
IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
|
||||
IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
|
||||
IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse for IdentLike {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let lookahead = input.lookahead1();
|
||||
Ok(if lookahead.peek(Token![const]) {
|
||||
IdentLike::Const(input.parse::<Token![const]>()?)
|
||||
} else if lookahead.peek(Token![type]) {
|
||||
IdentLike::Type(input.parse::<Token![type]>()?)
|
||||
} else if lookahead.peek(Ident) {
|
||||
IdentLike::Ident(input.parse::<Ident>()?)
|
||||
} else if lookahead.peek(LitInt) {
|
||||
IdentLike::Integer(input.parse::<LitInt>()?)
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Arguments decalaration can loook like this:
|
||||
// a{, b}
|
||||
// That's why we don't parse Arguments as Punctuated<Argument, Token![,]>
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct Arguments(pub Vec<Argument>);
|
||||
|
||||
impl Parse for Arguments {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let mut result = Vec::new();
|
||||
loop {
|
||||
if input.peek(Token![,]) {
|
||||
input.parse::<Token![,]>()?;
|
||||
}
|
||||
let mut optional = false;
|
||||
let mut can_be_negated = false;
|
||||
let mut pre_pipe = false;
|
||||
let ident;
|
||||
let lookahead = input.lookahead1();
|
||||
if lookahead.peek(token::Brace) {
|
||||
let content;
|
||||
braced!(content in input);
|
||||
let lookahead = content.lookahead1();
|
||||
if lookahead.peek(Token![!]) {
|
||||
content.parse::<Token![!]>()?;
|
||||
can_be_negated = true;
|
||||
ident = input.parse::<Ident>()?;
|
||||
} else if lookahead.peek(Token![,]) {
|
||||
optional = true;
|
||||
content.parse::<Token![,]>()?;
|
||||
ident = content.parse::<Ident>()?;
|
||||
} else {
|
||||
return Err(lookahead.error());
|
||||
}
|
||||
} else if lookahead.peek(token::Bracket) {
|
||||
let bracketed;
|
||||
bracketed!(bracketed in input);
|
||||
if bracketed.peek(Token![|]) {
|
||||
optional = true;
|
||||
bracketed.parse::<Token![|]>()?;
|
||||
pre_pipe = true;
|
||||
ident = bracketed.parse::<Ident>()?;
|
||||
} else {
|
||||
let mut sub_args = Self::parse(&bracketed)?;
|
||||
sub_args.0.first_mut().unwrap().pre_bracket = true;
|
||||
sub_args.0.last_mut().unwrap().post_bracket = true;
|
||||
if peek_brace_token(input, Token![.]) {
|
||||
let optional_suffix;
|
||||
braced!(optional_suffix in input);
|
||||
optional_suffix.parse::<Token![.]>()?;
|
||||
let unified_ident = optional_suffix.parse::<Ident>()?;
|
||||
if unified_ident.to_string() != "unified" {
|
||||
return Err(syn::Error::new(
|
||||
unified_ident.span(),
|
||||
format!("Exptected `unified`, got `{}`", unified_ident),
|
||||
));
|
||||
}
|
||||
for a in sub_args.0.iter_mut() {
|
||||
a.unified = true;
|
||||
}
|
||||
}
|
||||
result.extend(sub_args.0);
|
||||
continue;
|
||||
}
|
||||
} else if lookahead.peek(Ident) {
|
||||
ident = input.parse::<Ident>()?;
|
||||
} else if lookahead.peek(Token![|]) {
|
||||
input.parse::<Token![|]>()?;
|
||||
pre_pipe = true;
|
||||
ident = input.parse::<Ident>()?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
result.push(Argument {
|
||||
optional,
|
||||
pre_pipe,
|
||||
can_be_negated,
|
||||
pre_bracket: false,
|
||||
ident,
|
||||
post_bracket: false,
|
||||
unified: false,
|
||||
});
|
||||
}
|
||||
Ok(Self(result))
|
||||
}
|
||||
}
|
||||
|
||||
// This is effectively input.peek(token::Brace) && input.peek2(Token![.])
|
||||
// input.peek2 is supposed to skip over next token, but it skips over whole
|
||||
// braced token group. Not sure if it's a bug
|
||||
fn peek_brace_token<T: Peek>(input: syn::parse::ParseStream, _t: T) -> bool {
|
||||
use syn::token::Token;
|
||||
let cursor = input.cursor();
|
||||
cursor
|
||||
.group(proc_macro2::Delimiter::Brace)
|
||||
.map_or(false, |(content, ..)| T::Token::peek(content))
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct Argument {
|
||||
pub optional: bool,
|
||||
pub pre_bracket: bool,
|
||||
pub pre_pipe: bool,
|
||||
pub can_be_negated: bool,
|
||||
pub ident: Ident,
|
||||
pub post_bracket: bool,
|
||||
pub unified: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Arguments, DotModifier, MaybeDotModifier};
|
||||
use quote::{quote, ToTokens};
|
||||
|
||||
#[test]
|
||||
fn parse_modifier_complex() {
|
||||
let input = quote! {
|
||||
.level::eviction_priority
|
||||
};
|
||||
let modifier = syn::parse2::<DotModifier>(input).unwrap();
|
||||
assert_eq!(
|
||||
". level :: eviction_priority",
|
||||
modifier.tokens().to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_modifier_optional() {
|
||||
let input = quote! {
|
||||
{ .level::eviction_priority }
|
||||
};
|
||||
let maybe_modifider = syn::parse2::<MaybeDotModifier>(input).unwrap();
|
||||
assert_eq!(
|
||||
". level :: eviction_priority",
|
||||
maybe_modifider.modifier.tokens().to_string()
|
||||
);
|
||||
assert!(maybe_modifider.optional);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_type_token() {
|
||||
let input = quote! {
|
||||
. type
|
||||
};
|
||||
let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
|
||||
assert_eq!(". type", maybe_modifier.modifier.tokens().to_string());
|
||||
assert!(!maybe_modifier.optional);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arguments_memory() {
|
||||
let input = quote! {
|
||||
[a], b
|
||||
};
|
||||
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||
let a = &arguments.0[0];
|
||||
assert!(!a.optional);
|
||||
assert_eq!("a", a.ident.to_string());
|
||||
assert!(a.pre_bracket);
|
||||
assert!(!a.pre_pipe);
|
||||
assert!(a.post_bracket);
|
||||
assert!(!a.can_be_negated);
|
||||
let b = &arguments.0[1];
|
||||
assert!(!b.optional);
|
||||
assert_eq!("b", b.ident.to_string());
|
||||
assert!(!b.pre_bracket);
|
||||
assert!(!b.pre_pipe);
|
||||
assert!(!b.post_bracket);
|
||||
assert!(!b.can_be_negated);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arguments_optional() {
|
||||
let input = quote! {
|
||||
b{, cache_policy}
|
||||
};
|
||||
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||
let b = &arguments.0[0];
|
||||
assert!(!b.optional);
|
||||
assert_eq!("b", b.ident.to_string());
|
||||
assert!(!b.pre_bracket);
|
||||
assert!(!b.pre_pipe);
|
||||
assert!(!b.post_bracket);
|
||||
assert!(!b.can_be_negated);
|
||||
let cache_policy = &arguments.0[1];
|
||||
assert!(cache_policy.optional);
|
||||
assert_eq!("cache_policy", cache_policy.ident.to_string());
|
||||
assert!(!cache_policy.pre_bracket);
|
||||
assert!(!cache_policy.pre_pipe);
|
||||
assert!(!cache_policy.post_bracket);
|
||||
assert!(!cache_policy.can_be_negated);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arguments_optional_pred() {
|
||||
let input = quote! {
|
||||
p[|q], a
|
||||
};
|
||||
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||
assert_eq!(arguments.0.len(), 3);
|
||||
let p = &arguments.0[0];
|
||||
assert!(!p.optional);
|
||||
assert_eq!("p", p.ident.to_string());
|
||||
assert!(!p.pre_bracket);
|
||||
assert!(!p.pre_pipe);
|
||||
assert!(!p.post_bracket);
|
||||
assert!(!p.can_be_negated);
|
||||
let q = &arguments.0[1];
|
||||
assert!(q.optional);
|
||||
assert_eq!("q", q.ident.to_string());
|
||||
assert!(!q.pre_bracket);
|
||||
assert!(q.pre_pipe);
|
||||
assert!(!q.post_bracket);
|
||||
assert!(!q.can_be_negated);
|
||||
let a = &arguments.0[2];
|
||||
assert!(!a.optional);
|
||||
assert_eq!("a", a.ident.to_string());
|
||||
assert!(!a.pre_bracket);
|
||||
assert!(!a.pre_pipe);
|
||||
assert!(!a.post_bracket);
|
||||
assert!(!a.can_be_negated);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arguments_optional_with_negate() {
|
||||
let input = quote! {
|
||||
b, {!}c
|
||||
};
|
||||
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||
assert_eq!(arguments.0.len(), 2);
|
||||
let b = &arguments.0[0];
|
||||
assert!(!b.optional);
|
||||
assert_eq!("b", b.ident.to_string());
|
||||
assert!(!b.pre_bracket);
|
||||
assert!(!b.pre_pipe);
|
||||
assert!(!b.post_bracket);
|
||||
assert!(!b.can_be_negated);
|
||||
let c = &arguments.0[1];
|
||||
assert!(!c.optional);
|
||||
assert_eq!("c", c.ident.to_string());
|
||||
assert!(!c.pre_bracket);
|
||||
assert!(!c.pre_pipe);
|
||||
assert!(!c.post_bracket);
|
||||
assert!(c.can_be_negated);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn arguments_tex() {
|
||||
let input = quote! {
|
||||
d[|p], [a{, b}, c], dpdx, dpdy {, e}
|
||||
};
|
||||
let arguments = syn::parse2::<Arguments>(input).unwrap();
|
||||
assert_eq!(arguments.0.len(), 8);
|
||||
{
|
||||
let d = &arguments.0[0];
|
||||
assert!(!d.optional);
|
||||
assert_eq!("d", d.ident.to_string());
|
||||
assert!(!d.pre_bracket);
|
||||
assert!(!d.pre_pipe);
|
||||
assert!(!d.post_bracket);
|
||||
assert!(!d.can_be_negated);
|
||||
}
|
||||
{
|
||||
let p = &arguments.0[1];
|
||||
assert!(p.optional);
|
||||
assert_eq!("p", p.ident.to_string());
|
||||
assert!(!p.pre_bracket);
|
||||
assert!(p.pre_pipe);
|
||||
assert!(!p.post_bracket);
|
||||
assert!(!p.can_be_negated);
|
||||
}
|
||||
{
|
||||
let a = &arguments.0[2];
|
||||
assert!(!a.optional);
|
||||
assert_eq!("a", a.ident.to_string());
|
||||
assert!(a.pre_bracket);
|
||||
assert!(!a.pre_pipe);
|
||||
assert!(!a.post_bracket);
|
||||
assert!(!a.can_be_negated);
|
||||
}
|
||||
{
|
||||
let b = &arguments.0[3];
|
||||
assert!(b.optional);
|
||||
assert_eq!("b", b.ident.to_string());
|
||||
assert!(!b.pre_bracket);
|
||||
assert!(!b.pre_pipe);
|
||||
assert!(!b.post_bracket);
|
||||
assert!(!b.can_be_negated);
|
||||
}
|
||||
{
|
||||
let c = &arguments.0[4];
|
||||
assert!(!c.optional);
|
||||
assert_eq!("c", c.ident.to_string());
|
||||
assert!(!c.pre_bracket);
|
||||
assert!(!c.pre_pipe);
|
||||
assert!(c.post_bracket);
|
||||
assert!(!c.can_be_negated);
|
||||
}
|
||||
{
|
||||
let dpdx = &arguments.0[5];
|
||||
assert!(!dpdx.optional);
|
||||
assert_eq!("dpdx", dpdx.ident.to_string());
|
||||
assert!(!dpdx.pre_bracket);
|
||||
assert!(!dpdx.pre_pipe);
|
||||
assert!(!dpdx.post_bracket);
|
||||
assert!(!dpdx.can_be_negated);
|
||||
}
|
||||
{
|
||||
let dpdy = &arguments.0[6];
|
||||
assert!(!dpdy.optional);
|
||||
assert_eq!("dpdy", dpdy.ident.to_string());
|
||||
assert!(!dpdy.pre_bracket);
|
||||
assert!(!dpdy.pre_pipe);
|
||||
assert!(!dpdy.post_bracket);
|
||||
assert!(!dpdy.can_be_negated);
|
||||
}
|
||||
{
|
||||
let e = &arguments.0[7];
|
||||
assert!(e.optional);
|
||||
assert_eq!("e", e.ident.to_string());
|
||||
assert!(!e.pre_bracket);
|
||||
assert!(!e.pre_pipe);
|
||||
assert!(!e.post_bracket);
|
||||
assert!(!e.can_be_negated);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rule_multi() {
|
||||
let input = quote! {
|
||||
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }
|
||||
};
|
||||
let rule = syn::parse2::<super::Rule>(input).unwrap();
|
||||
assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string());
|
||||
assert_eq!(
|
||||
"StateSpace",
|
||||
rule.type_.unwrap().to_token_stream().to_string()
|
||||
);
|
||||
let alts = rule
|
||||
.alternatives
|
||||
.iter()
|
||||
.map(|m| m.tokens().to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
vec![
|
||||
". global",
|
||||
". local",
|
||||
". param",
|
||||
". param :: func",
|
||||
". shared",
|
||||
". shared :: cta",
|
||||
". shared :: cluster"
|
||||
],
|
||||
alts
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rule_multi2() {
|
||||
let input = quote! {
|
||||
.cop: StCacheOperator = { .wb, .cg, .cs, .wt }
|
||||
};
|
||||
let rule = syn::parse2::<super::Rule>(input).unwrap();
|
||||
assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string());
|
||||
assert_eq!(
|
||||
"StCacheOperator",
|
||||
rule.type_.unwrap().to_token_stream().to_string()
|
||||
);
|
||||
let alts = rule
|
||||
.alternatives
|
||||
.iter()
|
||||
.map(|m| m.tokens().to_string())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn args_unified() {
|
||||
let input = quote! {
|
||||
d, [a]{.unified}{, cache_policy}
|
||||
};
|
||||
let args = syn::parse2::<super::Arguments>(input).unwrap();
|
||||
let a = &args.0[1];
|
||||
assert!(!a.optional);
|
||||
assert_eq!("a", a.ident.to_string());
|
||||
assert!(a.pre_bracket);
|
||||
assert!(!a.pre_pipe);
|
||||
assert!(a.post_bracket);
|
||||
assert!(!a.can_be_negated);
|
||||
assert!(a.unified);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn special_block() {
|
||||
let input = quote! {
|
||||
bra <= { bra(stream) }
|
||||
};
|
||||
syn::parse2::<super::OpcodeDefinition>(input).unwrap();
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue