PTX parser rewrite (#267)
Some checks failed
Rust / Build and publish (Linux) (push) Has been cancelled
Rust / Build and publish (Windows) (push) Has been cancelled

Replaces traditional LALRPOP-based parser with winnow-based parser to handle out-of-order instruction modifer. Generate instruction type and instruction visitor from a macro instead of writing by hand. Add separate compilation path using the new parser that only works in tests for now
This commit is contained in:
Andrzej Janik 2024-09-04 15:47:42 +02:00 committed by GitHub
parent 872054ae40
commit 193eb29be8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 14776 additions and 55 deletions

View file

@ -1,5 +1,7 @@
[workspace]
resolver = "2"
members = [
"cuda_base",
"cuda_types",
@ -15,6 +17,9 @@ members = [
"zluda_redirect",
"zluda_ml",
"ptx",
"ptx_parser",
"ptx_parser_macros",
"ptx_parser_macros_impl",
]
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]

View file

@ -7,7 +7,7 @@ edition = "2018"
[lib]
[dependencies]
lalrpop-util = "0.19"
ptx_parser = { path = "../ptx_parser" }
regex = "1"
rspirv = "0.7"
spirv_headers = "1.5"
@ -17,8 +17,12 @@ bit-vec = "0.6"
half ="1.6"
bitflags = "1.2"
[dependencies.lalrpop-util]
version = "0.19.12"
features = ["lexer"]
[build-dependencies.lalrpop]
version = "0.19"
version = "0.19.12"
features = ["lexer"]
[dev-dependencies]

View file

@ -16,6 +16,8 @@ pub enum PtxError {
source: ParseFloatError,
},
#[error("")]
Unsupported32Bit,
#[error("")]
SyntaxError,
#[error("")]
NonF32Ftz,
@ -32,15 +34,9 @@ pub enum PtxError {
#[error("")]
NonExternPointer,
#[error("{start}:{end}")]
UnrecognizedStatement {
start: usize,
end: usize,
},
UnrecognizedStatement { start: usize, end: usize },
#[error("{start}:{end}")]
UnrecognizedDirective {
start: usize,
end: usize,
},
UnrecognizedDirective { start: usize, end: usize },
}
// For some weird reson this is illegal:
@ -576,11 +572,15 @@ impl CvtDetails {
if saturate {
if src.kind() == ScalarKind::Signed {
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
err.push(ParseError::User {
error: PtxError::SyntaxError,
});
}
} else {
if dst == src || dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
err.push(ParseError::User {
error: PtxError::SyntaxError,
});
}
}
}
@ -596,7 +596,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && dst != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
err.push(ParseError::from(lalrpop_util::ParseError::User {
error: PtxError::NonF32Ftz,
}));
}
CvtDetails::FloatFromInt(CvtDesc {
dst,
@ -616,7 +618,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && src != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
err.push(ParseError::from(lalrpop_util::ParseError::User {
error: PtxError::NonF32Ftz,
}));
}
CvtDetails::IntFromFloat(CvtDesc {
dst,

View file

@ -24,6 +24,7 @@ lalrpop_mod!(
);
pub mod ast;
pub(crate) mod pass;
#[cfg(test)]
mod test;
mod translate;

View file

@ -0,0 +1,299 @@
use std::collections::{BTreeMap, BTreeSet};
use super::*;
/*
PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[];
In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
And in AMD compilation
This pass looks for all uses of .extern .shared and converts them to
an additional method argument
The question is how this artificial argument should be expressed. There are
several options:
* Straight conversion:
.shared .b32 shared_mem[]
* Introduce .param_shared statespace:
.param_shared .b32 shared_mem
or
.param_shared .b32 shared_mem[]
* Introduce .shared_ptr <SCALAR> type:
.param .shared_ptr .b32 shared_mem
* Reuse .ptr hint:
.param .u64 .ptr shared_mem
This is the most tempting, but also the most nonsensical, .ptr is just a
hint, which has no semantical meaning (and the output of our
transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/
pub(super) fn run<'input>(
module: Vec<Directive<'input>>,
kernels_methods_call_map: &MethodsCallMap<'input>,
new_id: &mut impl FnMut() -> SpirvWord,
) -> Result<Vec<Directive<'input>>, TranslateError> {
let mut globals_shared = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(
_,
ast::Variable {
state_space: ast::StateSpace::Shared,
name,
v_type,
..
},
) => {
globals_shared.insert(*name, v_type.clone());
}
_ => {}
}
}
if globals_shared.len() == 0 {
return Ok(module);
}
let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
let module = module
.into_iter()
.map(|directive| match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}) => {
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| {
statement.visit_map(
&mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
if let Some(_) = globals_shared.get(&id) {
methods_to_directly_used_shared_globals
.entry(call_key)
.or_insert_with(HashSet::new)
.insert(id);
}
Ok::<_, TranslateError>(id)
},
)
})
.collect::<Result<Vec<_>, _>>()?;
Ok::<_, TranslateError>(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}))
}
directive => Ok(directive),
})
.collect::<Result<Vec<_>, _>>()?;
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
methods_to_directly_used_shared_globals,
kernels_methods_call_map,
);
// now visit every method declaration and inject those additional arguments
let mut directives = Vec::with_capacity(module.len());
for directive in module.into_iter() {
match directive {
Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}) => {
let statements = {
let func_decl_ref = &mut (*func_decl).borrow_mut();
let method_name = func_decl_ref.name;
insert_arguments_remap_statements(
new_id,
kernels_methods_call_map,
&globals_shared,
&methods_to_indirectly_used_shared_globals,
method_name,
&mut directives,
func_decl_ref,
statements,
)?
};
directives.push(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
}));
}
directive => directives.push(directive),
}
}
Ok(directives)
}
// We need to compute two kinds of information:
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
// * If it's a function -> does it use .shared global (directly or indirectly)
fn resolve_indirect_uses_of_globals_shared<'input>(
methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
kernels_methods_call_map: &MethodsCallMap<'input>,
) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
let mut result = HashMap::new();
for (method, callees) in kernels_methods_call_map.methods() {
let mut indirect_globals = methods_use_of_globals_shared
.get(&method)
.into_iter()
.flatten()
.copied()
.collect::<BTreeSet<_>>();
for &callee in callees {
indirect_globals.extend(
methods_use_of_globals_shared
.get(&ast::MethodName::Func(callee))
.into_iter()
.flatten()
.copied(),
);
}
result.insert(method, indirect_globals);
}
result
}
fn insert_arguments_remap_statements<'input>(
new_id: &mut impl FnMut() -> SpirvWord,
kernels_methods_call_map: &MethodsCallMap<'input>,
globals_shared: &HashMap<SpirvWord, ast::Type>,
methods_to_indirectly_used_shared_globals: &HashMap<
ast::MethodName<'input, SpirvWord>,
BTreeSet<SpirvWord>,
>,
method_name: ast::MethodName<SpirvWord>,
result: &mut Vec<Directive>,
func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let remapped_globals_in_method =
if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
match method_name {
ast::MethodName::Func(..) => {
let remapped_globals = method_globals
.iter()
.map(|global| {
(
*global,
(
new_id(),
globals_shared
.get(&global)
.unwrap_or_else(|| todo!())
.clone(),
),
)
})
.collect::<BTreeMap<_, _>>();
for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
func_decl_ref.input_arguments.push(ast::Variable {
align: None,
v_type: shared_global_type.clone(),
state_space: ast::StateSpace::Shared,
name: *new_shared_global_id,
array_init: Vec::new(),
});
}
remapped_globals
}
ast::MethodName::Kernel(..) => method_globals
.iter()
.map(|global| {
(
*global,
(
*global,
globals_shared
.get(&global)
.unwrap_or_else(|| todo!())
.clone(),
),
)
})
.collect::<BTreeMap<_, _>>(),
}
} else {
return Ok(statements);
};
replace_uses_of_shared_memory(
new_id,
methods_to_indirectly_used_shared_globals,
statements,
remapped_globals_in_method,
)
}
fn replace_uses_of_shared_memory<'input>(
new_id: &mut impl FnMut() -> SpirvWord,
methods_to_indirectly_used_shared_globals: &HashMap<
ast::MethodName<'input, SpirvWord>,
BTreeSet<SpirvWord>,
>,
statements: Vec<ExpandedStatement>,
remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Instruction(ast::Instruction::Call {
mut data,
mut arguments,
}) => {
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if let Some(shared_globals_used_by_callee) =
methods_to_indirectly_used_shared_globals
.get(&ast::MethodName::Func(arguments.func))
{
for &shared_global_used_by_callee in shared_globals_used_by_callee {
let (remapped_shared_id, type_) = remapped_globals_in_method
.get(&shared_global_used_by_callee)
.unwrap_or_else(|| todo!());
data.input_arguments
.push((type_.clone(), ast::StateSpace::Shared));
arguments.input_arguments.push(*remapped_shared_id);
}
}
result.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}))
}
statement => {
let new_statement =
statement.visit_map(&mut |id,
_: Option<(&ast::Type, ast::StateSpace)>,
_,
_| {
Ok::<_, TranslateError>(
if let Some((remapped_shared_id, _)) =
remapped_globals_in_method.get(&id)
{
*remapped_shared_id
} else {
id
},
)
})?;
result.push(new_statement);
}
}
}
Ok(result)
}

View file

@ -0,0 +1,524 @@
use super::*;
use ptx_parser as ast;
use std::{
collections::{BTreeSet, HashSet},
iter,
rc::Rc,
};
/*
Our goal here is to transform
.visible .entry foobar(.param .u64 input) {
.reg .b64 in_addr;
.reg .b64 in_addr2;
ld.param.u64 in_addr, [input];
cvta.to.global.u64 in_addr2, in_addr;
}
into:
.visible .entry foobar(.param .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
ld.param.u8[] in_addr, [input];
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.reg .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
mov.u8[] in_addr, input;
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.param ptr<u8, global> input) {
.reg ptr<u8, global> in_addr;
.reg ptr<u8, global> in_addr2;
ld.param.ptr<u8, global> in_addr, [input];
mov.ptr<u8, global> in_addr2, in_addr;
}
*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate out of calls and into calls
pub(super) fn run<'a, 'input>(
func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
) -> Result<
(
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
Vec<TypedStatement>,
),
TranslateError,
> {
let mut method_decl = func_args.borrow_mut();
if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
drop(method_decl);
return Ok((func_args, func_body));
}
if Rc::strong_count(&func_args) != 1 {
return Err(error_unreachable());
}
let func_args_64bit = (*method_decl)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
_ => None,
})
.collect::<HashSet<_>>();
let mut stateful_markers = Vec::new();
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Cvta {
data:
ast::CvtaDetails {
state_space: ast::StateSpace::Global,
direction: ast::CvtaDirection::GenericToExplicit,
},
arguments,
}) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(arguments.dst, arguments.src.underlying_register())
{
if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
stateful_markers.push((dst, src));
}
}
}
Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arguments,
})
| Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arguments,
})
| Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arguments,
}) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(arguments.dst, arguments.src.underlying_register())
{
if func_args_64bit.contains(&src) {
multi_hash_map_append(&mut stateful_init_reg, dst, src);
}
}
}
_ => {}
}
}
if stateful_markers.len() == 0 {
drop(method_decl);
return Ok((func_args, func_body));
}
let mut func_args_ptr = HashSet::new();
let mut regs_ptr_current = HashSet::new();
for (dst, src) in stateful_markers {
if let Some(func_args) = stateful_init_reg.get(&src) {
for a in func_args {
func_args_ptr.insert(*a);
regs_ptr_current.insert(src);
regs_ptr_current.insert(dst);
}
}
}
// BTreeSet here to have a stable order of iteration,
// unfortunately our tests rely on it
let mut regs_ptr_seen = BTreeSet::new();
while regs_ptr_current.len() > 0 {
let mut regs_ptr_new = HashSet::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) => {
// TODO: don't mark result of double pointer sub or double
// pointer add as ptr result
if let (TypedOperand::Reg(dst), Some(src1)) =
(arguments.dst, arguments.src1.underlying_register())
{
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
regs_ptr_new.insert(dst);
}
} else if let (TypedOperand::Reg(dst), Some(src2)) =
(arguments.dst, arguments.src2.underlying_register())
{
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
regs_ptr_new.insert(dst);
}
}
}
Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) => {
// TODO: don't mark result of double pointer sub or double
// pointer add as ptr result
if let (TypedOperand::Reg(dst), Some(src1)) =
(arguments.dst, arguments.src1.underlying_register())
{
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
regs_ptr_new.insert(dst);
}
} else if let (TypedOperand::Reg(dst), Some(src2)) =
(arguments.dst, arguments.src2.underlying_register())
{
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
regs_ptr_new.insert(dst);
}
}
}
_ => {}
}
}
for id in regs_ptr_current {
regs_ptr_seen.insert(id);
}
regs_ptr_current = regs_ptr_new;
}
drop(regs_ptr_current);
let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Reg,
);
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
for arg in (*method_decl).input_arguments.iter_mut() {
if !func_args_ptr.contains(&arg.name) {
continue;
}
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Param,
);
let old_name = arg.name;
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.name = new_id;
remapped_ids.insert(old_name, new_id);
}
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
c @ Statement::Conditional(_) => result.push(c),
c @ Statement::Constant(..) => result.push(c),
Statement::Variable(var) => {
if !remapped_ids.contains_key(&var.name) {
result.push(Statement::Variable(var));
}
}
Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) if is_add_ptr_direct(&remapped_ids, &arguments) => {
let (ptr, offset) = match arguments.src1.underlying_register() {
Some(src1) if remapped_ids.contains_key(&src1) => {
(remapped_ids.get(&src1).unwrap(), arguments.src2)
}
Some(src2) if remapped_ids.contains_key(&src2) => {
(remapped_ids.get(&src2).unwrap(), arguments.src1)
}
_ => return Err(error_unreachable()),
};
let dst = arguments.dst.unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
}
Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
let (ptr, offset) = match arguments.src1.underlying_register() {
Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
_ => return Err(error_unreachable()),
};
let offset_neg = id_defs.register_intermediate(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
result.push(Statement::Instruction(ast::Instruction::Neg {
data: ast::TypeFtz {
type_: ast::ScalarType::S64,
flush_to_zero: None,
},
arguments: ast::NegArgs {
src: offset,
dst: TypedOperand::Reg(offset_neg),
},
}));
let dst = arguments.dst.unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
}))
}
inst @ Statement::Instruction(_) => {
let mut post_statements = Vec::new();
let new_statement = inst.visit_map(&mut FnVisitor::new(
|operand, type_space, is_dst, relaxed_conversion| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&mut result,
&mut post_statements,
operand,
type_space,
is_dst,
relaxed_conversion,
)
},
))?;
result.push(new_statement);
result.extend(post_statements);
}
repack @ Statement::RepackVector(_) => {
let mut post_statements = Vec::new();
let new_statement = repack.visit_map(&mut FnVisitor::new(
|operand, type_space, is_dst, relaxed_conversion| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&mut result,
&mut post_statements,
operand,
type_space,
is_dst,
relaxed_conversion,
)
},
))?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
drop(method_decl);
Ok((func_args, result))
}
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
match id_defs.get_typed(id) {
Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
| Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
| Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}
fn is_add_ptr_direct(
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
arg: &ast::AddArgs<TypedOperand>,
) -> bool {
match arg.dst {
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
return false
}
TypedOperand::Reg(dst) => {
if !remapped_ids.contains_key(&dst) {
return false;
}
if let Some(ref src1_reg) = arg.src1.underlying_register() {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when adding two pointers
if let Some(ref src2_reg) = arg.src2.underlying_register() {
return !remapped_ids.contains_key(src2_reg);
}
}
}
if let Some(ref src2_reg) = arg.src2.underlying_register() {
remapped_ids.contains_key(src2_reg)
} else {
false
}
}
}
}
fn is_sub_ptr_direct(
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
arg: &ast::SubArgs<TypedOperand>,
) -> bool {
match arg.dst {
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
return false
}
TypedOperand::Reg(dst) => {
if !remapped_ids.contains_key(&dst) {
return false;
}
match arg.src1.underlying_register() {
Some(ref src1_reg) => {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when subtracting two pointers
arg.src2
.underlying_register()
.map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
} else {
false
}
}
None => false,
}
}
}
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
operand: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_conversion: bool,
) -> Result<TypedOperand, TranslateError> {
operand.map(|operand, _| {
Ok(match remapped_ids.get(&operand) {
Some(new_id) => {
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
// TODO: readd if required
if let Some((expected_type, expected_space)) = type_space {
let implicit_conversion = if relaxed_conversion {
if is_dst {
super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
} else {
super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
}
} else {
super::insert_implicit_conversions::default_implicit_conversion
};
if implicit_conversion(
(new_operand_space, &new_operand_type),
(expected_space, expected_type),
)
.is_ok()
{
return Ok(*new_id);
}
}
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
ConversionKind::Default
} else {
ConversionKind::PtrToPtr
};
if is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
from_type: old_operand_type,
from_space: old_operand_space,
to_type: new_operand_type,
to_space: new_operand_space,
kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: new_operand_type,
from_space: new_operand_space,
to_type: old_operand_type,
to_space: old_operand_space,
kind,
}));
converting_id
}
}
None => operand,
})
})
}

View file

@ -0,0 +1,138 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run(
func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
for s in func {
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Mov {
data,
arguments:
ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg),
},
} if fn_defs.fns.contains_key(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(error_mismatched_type());
}
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
src: src_reg,
}));
}
ast::Instruction::Call { data, arguments } => {
let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call =
Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
visitor.func.push(instruction);
visitor.func.extend(visitor.post_stmts);
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
_ => return Err(error_unreachable()),
}
}
Ok(result)
}
struct VectorRepackVisitor<'a, 'b> {
func: &'b mut Vec<TypedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
post_stmts: Option<TypedStatement>,
}
impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
VectorRepackVisitor {
func,
id_def,
post_stmts: None,
}
}
fn convert_vector(
&mut self,
is_dst: bool,
relaxed_type_check: bool,
typ: &ast::Type,
state_space: ast::StateSpace,
idx: Vec<SpirvWord>,
) -> Result<SpirvWord, TranslateError> {
// mov.u32 foobar, {a,b};
let scalar_t = match typ {
ast::Type::Vector(_, scalar_t) => *scalar_t,
_ => return Err(error_mismatched_type()),
};
let temp_vec = self
.id_def
.register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
relaxed_type_check,
});
if is_dst {
self.post_stmts = Some(statement);
} else {
self.func.push(statement);
}
Ok(temp_vec)
}
}
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
for VectorRepackVisitor<'a, 'b>
{
fn visit_ident(
&mut self,
ident: SpirvWord,
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
_: bool,
_: bool,
) -> Result<SpirvWord, TranslateError> {
Ok(ident)
}
fn visit(
&mut self,
op: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
Ok(match op {
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
ast::ParsedOperand::VecPack(vec) => {
let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
TypedOperand::Reg(self.convert_vector(
is_dst,
relaxed_type_check,
type_,
space,
vec,
)?)
}
})
}
}

2763
ptx/src/pass/emit_spirv.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,181 @@
use super::*;
use ptx_parser as ast;
pub(super) fn run<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &'b mut MutableNumericIdResolver<'a>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
Statement::Constant(c) => result.push(Statement::Constant(c)),
Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
s => {
let (new_statement, post_stmts) = {
let mut visitor = FlattenArguments::new(&mut result, id_def);
(s.visit_map(&mut visitor)?, visitor.post_stmts)
};
result.push(new_statement);
result.extend(post_stmts);
}
}
}
Ok(result)
}
struct FlattenArguments<'a, 'b> {
func: &'b mut Vec<ExpandedStatement>,
id_def: &'b mut MutableNumericIdResolver<'a>,
post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'b> FlattenArguments<'a, 'b> {
fn new(
func: &'b mut Vec<ExpandedStatement>,
id_def: &'b mut MutableNumericIdResolver<'a>,
) -> Self {
FlattenArguments {
func,
id_def,
post_stmts: Vec::new(),
}
}
fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
Ok(name)
}
fn reg_offset(
&mut self,
reg: SpirvWord,
offset: i32,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
_is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
let (type_, state_space) = if let Some((type_, state_space)) = type_space {
(type_, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
return Err(error_mismatched_type());
}
let reg_scalar_type = match reg_type {
ast::Type::Scalar(underlying_type) => underlying_type,
_ => return Err(error_mismatched_type()),
};
let id_constant_stmt = self
.id_def
.register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: reg_scalar_type,
value: ast::ImmediateValue::S64(offset as i64),
}));
let arith_details = match reg_scalar_type.kind() {
ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
}),
ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
ast::ArithDetails::Integer(ast::ArithInteger {
type_: reg_scalar_type,
saturate: false,
})
}
_ => return Err(error_unreachable()),
};
let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
self.func
.push(Statement::Instruction(ast::Instruction::Add {
data: arith_details,
arguments: ast::AddArgs {
dst: id_add_result,
src1: reg,
src2: id_constant_stmt,
},
}));
Ok(id_add_result)
} else {
let id_constant_stmt = self.id_def.register_intermediate(
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self
.id_def
.register_intermediate(type_.clone(), state_space);
self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: type_.clone(),
state_space: state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
}));
Ok(dst)
}
}
fn immediate(
&mut self,
value: ast::ImmediateValue,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
let (scalar_t, state_space) =
if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
(*scalar, state_space)
} else {
return Err(TranslateError::UntypedSymbol);
};
let id = self
.id_def
.register_intermediate(ast::Type::Scalar(scalar_t), state_space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
value,
}));
Ok(id)
}
}
impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
fn visit(
&mut self,
args: TypedOperand,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match args {
TypedOperand::Reg(r) => self.reg(r),
TypedOperand::Imm(x) => self.immediate(x, type_space),
TypedOperand::RegOffset(reg, offset) => {
self.reg_offset(reg, offset, type_space, is_dst)
}
TypedOperand::VecMember(..) => Err(error_unreachable()),
}
}
fn visit_ident(
&mut self,
name: <TypedOperand as ptx_parser::Operand>::Ident,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
self.reg(name)
}
}

View file

@ -0,0 +1,282 @@
use super::*;
pub(super) fn run<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
match statement {
Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfe { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Bfi { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
let fn_name: String =
[ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Brev { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Activemask { arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::IncrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_inc",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::DecrementWrap,
semantics,
scope,
space,
..
},
arguments,
}) => {
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_dec",
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
Statement::Instruction(ast::Instruction::Atom {
data:
data @ ast::AtomDetails {
op: ast::AtomicOp::FloatAdd,
semantics,
scope,
space,
..
},
arguments,
}) => {
let scalar_type = match data.type_ {
ptx_parser::Type::Scalar(scalar) => scalar,
_ => return Err(error_unreachable()),
};
let fn_name = [
ZLUDA_PTX_PREFIX,
"atom_",
semantics_to_ptx_name(semantics),
"_",
scope_to_ptx_name(scope),
"_",
space_to_ptx_name(space),
"_add_",
scalar_to_ptx_name(scalar_type),
]
.concat();
local.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
ast::Instruction::Atom { data, arguments },
fn_name,
)?);
}
s => local.push(s),
}
}
Ok((local, global))
}
fn instruction_to_fn_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
inst: ast::Instruction<SpirvWord>,
fn_name: String,
) -> Result<ExpandedStatement, TranslateError> {
let mut arguments = Vec::new();
ast::visit_map(inst, &mut |operand,
type_space: Option<(
&ast::Type,
ast::StateSpace,
)>,
is_dst,
_| {
let (typ, space) = match type_space {
Some((typ, space)) => (typ.clone(), space),
None => return Err(error_unreachable()),
};
arguments.push((operand, is_dst, typ, space));
Ok(SpirvWord(0))
})?;
let return_arguments_count = arguments
.iter()
.position(|(desc, is_dst, _, _)| !is_dst)
.unwrap_or(arguments.len());
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
let fn_id = register_external_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
return_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
input_arguments
.iter()
.map(|(_, _, typ, state)| (typ, *state)),
)?;
Ok(Statement::Instruction(ast::Instruction::Call {
data: ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
input_arguments: input_arguments
.iter()
.map(|(_, _, typ, state)| (typ.clone(), *state))
.collect::<Vec<_>>(),
},
arguments: ast::CallArgs {
return_arguments: return_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
func: fn_id,
input_arguments: input_arguments
.iter()
.map(|(name, _, _, _)| *name)
.collect::<Vec<_>>(),
},
}))
}
fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
match this {
ast::ScalarType::B8 => "b8",
ast::ScalarType::B16 => "b16",
ast::ScalarType::B32 => "b32",
ast::ScalarType::B64 => "b64",
ast::ScalarType::B128 => "b128",
ast::ScalarType::U8 => "u8",
ast::ScalarType::U16 => "u16",
ast::ScalarType::U16x2 => "u16x2",
ast::ScalarType::U32 => "u32",
ast::ScalarType::U64 => "u64",
ast::ScalarType::S8 => "s8",
ast::ScalarType::S16 => "s16",
ast::ScalarType::S16x2 => "s16x2",
ast::ScalarType::S32 => "s32",
ast::ScalarType::S64 => "s64",
ast::ScalarType::F16 => "f16",
ast::ScalarType::F16x2 => "f16x2",
ast::ScalarType::F32 => "f32",
ast::ScalarType::F64 => "f64",
ast::ScalarType::BF16 => "bf16",
ast::ScalarType::BF16x2 => "bf16x2",
ast::ScalarType::Pred => "pred",
}
}
fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
match this {
ast::AtomSemantics::Relaxed => "relaxed",
ast::AtomSemantics::Acquire => "acquire",
ast::AtomSemantics::Release => "release",
ast::AtomSemantics::AcqRel => "acq_rel",
}
}
fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
match this {
ast::MemScope::Cta => "cta",
ast::MemScope::Gpu => "gpu",
ast::MemScope::Sys => "sys",
ast::MemScope::Cluster => "cluster",
}
}
fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
match this {
ast::StateSpace::Generic => "generic",
ast::StateSpace::Global => "global",
ast::StateSpace::Shared => "shared",
ast::StateSpace::Reg => "reg",
ast::StateSpace::Const => "const",
ast::StateSpace::Local => "local",
ast::StateSpace::Param => "param",
ast::StateSpace::Sreg => "sreg",
ast::StateSpace::SharedCluster => "shared_cluster",
ast::StateSpace::ParamEntry => "param_entry",
ast::StateSpace::SharedCta => "shared_cta",
ast::StateSpace::ParamFunc => "param_func",
}
}

View file

@ -0,0 +1,130 @@
use super::*;
use std::collections::HashMap;
pub(super) fn run<'a, 'b, 'input>(
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let result = Vec::with_capacity(typed_statements.len());
let mut sreg_sresolver = SpecialRegisterResolver {
ptx_impl_imports,
numeric_id_defs,
result,
};
for statement in typed_statements {
let statement = statement.visit_map(&mut sreg_sresolver)?;
sreg_sresolver.result.push(statement);
}
Ok(sreg_sresolver.result)
}
struct SpecialRegisterResolver<'a, 'b, 'input> {
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,
result: Vec<TypedStatement>,
}
impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for SpecialRegisterResolver<'a, 'b, 'input>
{
fn visit(
&mut self,
operand: TypedOperand,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
}
fn visit_ident(
&mut self,
args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.replace_sreg(args, is_dst, None)
}
}
impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
is_dst: bool,
vector_index: Option<u8>,
) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(
TypedOperand::Reg(constant),
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)]
}
(None, None) => Vec::new(),
_ => return Err(error_mismatched_type()),
};
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let return_type = sreg.get_function_return_type();
let fn_result = self.numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let fn_call = register_external_fn_call(
self.numeric_id_defs,
self.ptx_impl_imports,
ocl_fn_name.to_string(),
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
)?;
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: fn_call,
input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(fn_result)
} else {
Ok(name)
}
}
}

View file

@ -0,0 +1,432 @@
use std::mem;
use super::*;
use ptx_parser as ast;
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- st.param [x] y (used as function return arguments) same rule as above applies
- generic/global ld: for instruction `ld x, [y]`, y must be of type
b64/u64/s64, which is bitcast to a pointer, dereferenced and then
documented special ld/st/cvt conversion rules are applied to dst
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
pub(super) fn run(
func: Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
Statement::Instruction(inst) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::Instruction(inst),
)?;
}
Statement::PtrAccess(access) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::PtrAccess(access),
)?;
}
Statement::RepackVector(repack) => {
insert_implicit_conversions_impl(
&mut result,
id_def,
Statement::RepackVector(repack),
)?;
}
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(..)
| s @ Statement::StoreVar(..)
| s @ Statement::RetValue(..)
| s @ Statement::FunctionPointer(..) => result.push(s),
}
}
Ok(result)
}
fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: ExpandedStatement,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
let statement = stmt.visit_map::<SpirvWord, TranslateError>(
&mut |operand,
type_state: Option<(&ast::Type, ast::StateSpace)>,
is_dst,
relaxed_type_check| {
let (instr_type, instruction_space) = match type_state {
None => return Ok(operand),
Some(t) => t,
};
let (operand_type, operand_space) = id_def.get_typed(operand)?;
let conversion_fn = if relaxed_type_check {
if is_dst {
should_convert_relaxed_dst_wrapper
} else {
should_convert_relaxed_src_wrapper
}
} else {
default_implicit_conversion
};
match conversion_fn(
(operand_space, &operand_type),
(instruction_space, instr_type),
)? {
Some(conv_kind) => {
let conv_output = if is_dst { &mut post_conv } else { &mut *func };
let mut from_type = instr_type.clone();
let mut from_space = instruction_space;
let mut to_type = operand_type;
let mut to_space = operand_space;
let mut src =
id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = operand;
let result = Ok::<_, TranslateError>(src);
if !is_dst {
mem::swap(&mut src, &mut dst);
mem::swap(&mut from_type, &mut to_type);
mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
from_space,
to_type,
to_space,
kind: conv_kind,
}));
result
}
None => Ok(operand),
}
},
)?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
pub(crate) fn default_implicit_conversion(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if instruction_space == ast::StateSpace::Reg {
if space_is_compatible(operand_space, ast::StateSpace::Reg) {
if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
(operand_type, instruction_type)
{
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
}
} else if is_addressable(operand_space) {
return Ok(Some(ConversionKind::AddressOf));
}
}
if !space_is_compatible(instruction_space, operand_space) {
default_implicit_conversion_space(
(operand_space, operand_type),
(instruction_space, instruction_type),
)
} else if instruction_type != operand_type {
default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
} else {
Ok(None)
}
}
fn is_addressable(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Const
| ast::StateSpace::Generic
| ast::StateSpace::Global
| ast::StateSpace::Local
| ast::StateSpace::Shared => true,
ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
ast::StateSpace::SharedCluster
| ast::StateSpace::SharedCta
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc => todo!(),
}
}
// Space is different
fn default_implicit_conversion_space(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
|| (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
{
Ok(Some(ConversionKind::PtrToPtr))
} else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
match operand_type {
ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
if *operand_ptr_space == instruction_space =>
{
if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
// TODO: 32 bit
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
ast::StateSpace::Global
| ast::StateSpace::Generic
| ast::StateSpace::Const
| ast::StateSpace::Local
| ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
_ => Err(error_mismatched_type()),
},
ast::Type::Scalar(ast::ScalarType::B32)
| ast::Type::Scalar(ast::ScalarType::U32)
| ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
Ok(Some(ConversionKind::BitToPtr))
}
_ => Err(error_mismatched_type()),
},
_ => Err(error_mismatched_type()),
}
} else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
match instruction_type {
ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
if operand_space == *instruction_ptr_space =>
{
if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
Ok(Some(ConversionKind::PtrToPtr))
} else {
Ok(None)
}
}
_ => Err(error_mismatched_type()),
}
} else {
Err(error_mismatched_type())
}
}
// Space is same, but type is different
fn default_implicit_conversion_type(
space: ast::StateSpace,
operand_type: &ast::Type,
instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
if space_is_compatible(space, ast::StateSpace::Reg) {
if should_bitcast(instruction_type, operand_type) {
Ok(Some(ConversionKind::Default))
} else {
Err(TranslateError::MismatchedType)
}
} else {
Ok(Some(ConversionKind::PtrToPtr))
}
}
fn coerces_to_generic(this: ast::StateSpace) -> bool {
match this {
ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::Local
| ptx_parser::StateSpace::SharedCta
| ast::StateSpace::SharedCluster
| ast::StateSpace::Shared => true,
ast::StateSpace::Reg
| ast::StateSpace::Param
| ast::StateSpace::ParamEntry
| ast::StateSpace::ParamFunc
| ast::StateSpace::Generic
| ast::StateSpace::Sreg => false,
}
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
ast::ScalarKind::Signed => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Unsigned
}
ast::ScalarKind::Unsigned => {
operand.kind() == ast::ScalarKind::Bit
|| operand.kind() == ast::ScalarKind::Signed
}
ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
| (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
pub(crate) fn should_convert_relaxed_dst_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !space_is_compatible(operand_space, instruction_space) {
return Err(TranslateError::MismatchedType);
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed => {
if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
}
} else {
None
}
}
ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= dst_type.size_of()
&& dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_dst(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}
pub(crate) fn should_convert_relaxed_src_wrapper(
(operand_space, operand_type): (ast::StateSpace, &ast::Type),
(instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
if !space_is_compatible(operand_space, instruction_space) {
return Err(error_mismatched_type());
}
if operand_type == instruction_type {
return Ok(None);
}
match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(error_mismatched_type()),
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: &ast::Type,
instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Float => {
if instr_type.size_of() <= src_type.size_of()
&& src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
| (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
should_convert_relaxed_src(
&ast::Type::Scalar(*dst_type),
&ast::Type::Scalar(*instr_type),
)
}
_ => None,
}
}

View file

@ -0,0 +1,275 @@
use super::*;
use ptx_parser as ast;
/*
How do we handle arguments:
- input .params in kernels
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that
at SPIR-V level every .param is a pointer in Function storage class)
- input .params in functions
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %_ptr_Function_ulong
- input .regs
.reg .b64 in_arg
get turned into the same SPIR-V as kernel .params:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
- output .regs
.reg .b64 out_arg
get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function
- output .params don't exist, they have been moved to input positions
by an earlier pass
Distinguishing betweem kernel .params and function .params is not the
cleanest solution. Alternatively, we could "deparamize" all kernel .param
arguments by turning them into .reg arguments like this:
.param .b64 arg -> .reg ptr<.b64,.param> arg
This has the massive downside that this transformation would have to run
very early and would muddy up already difficult code. It's simpler to just
have an if here
*/
pub(super) fn run<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.input_arguments.iter_mut() {
insert_mem_ssa_argument(
id_def,
&mut result,
arg,
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
);
}
for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret { data } => {
// TODO: handle multiple output args
match &fn_decl.return_arguments[..] {
[return_reg] => {
let new_id = id_def.register_intermediate(Some((
return_reg.v_type.clone(),
ast::StateSpace::Reg,
)));
result.push(Statement::LoadVar(LoadVarDetails {
arg: ast::LdArgs {
dst: new_id,
src: return_reg.name,
},
typ: return_reg.v_type.clone(),
member_index: None,
}));
result.push(Statement::RetValue(data, new_id));
}
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
_ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::Instruction(inst),
)?,
},
Statement::Conditional(bra) => {
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
}
Statement::Conversion(conv) => {
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
}
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::PtrAccess(ptr_access),
)?,
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::RepackVector(repack),
)?,
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::FunctionPointer(func_ptr),
)?,
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
result.push(s)
}
_ => return Err(error_unreachable()),
}
}
Ok(result)
}
fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<SpirvWord>,
is_kernel: bool,
) {
if !is_kernel && arg.state_space == ast::StateSpace::Param {
return;
}
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: ast::StateSpace::Reg,
name: arg.name,
array_init: Vec::new(),
}));
func.push(Statement::StoreVar(StoreVarDetails {
arg: ast::StArgs {
src1: arg.name,
src2: new_id,
},
typ: arg.v_type.clone(),
member_index: None,
}));
arg.name = new_id;
}
fn insert_mem_ssa_argument_reg_return(
func: &mut Vec<TypedStatement>,
arg: &ast::Variable<SpirvWord>,
) {
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}
fn insert_mem_ssa_statement_default<'a, 'input>(
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
stmt: TypedStatement,
) -> Result<(), TranslateError> {
let mut visitor = InsertMemSSAVisitor {
id_def,
func,
post_statements: Vec::new(),
};
let new_stmt = stmt.visit_map(&mut visitor)?;
visitor.func.push(new_stmt);
visitor.func.extend(visitor.post_statements);
Ok(())
}
struct InsertMemSSAVisitor<'a, 'input> {
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
post_statements: Vec<TypedStatement>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
symbol: SpirvWord,
member_index: Option<u8>,
expected: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if expected.is_none() {
return Ok(symbol);
};
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match member_index {
Some(idx) => {
let vector_width = match var_type {
ast::Type::Vector(width, scalar_t) => {
var_type = ast::Type::Scalar(scalar_t);
width
}
_ => return Err(error_mismatched_type()),
};
Some((
idx,
if self.id_def.special_registers.get(symbol).is_some() {
Some(vector_width)
} else {
None
},
))
}
None => None,
};
let generated_id = self
.id_def
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: ast::LdArgs {
dst: generated_id,
src: symbol,
},
typ: var_type,
member_index,
}));
} else {
self.post_statements
.push(Statement::StoreVar(StoreVarDetails {
arg: ast::StArgs {
src1: symbol,
src2: generated_id,
},
typ: var_type,
member_index: member_index.map(|(idx, _)| idx),
}));
}
Ok(generated_id)
}
}
impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
operand: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
Ok(match operand {
TypedOperand::Reg(reg) => {
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
}
TypedOperand::RegOffset(reg, offset) => {
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
}
op @ TypedOperand::Imm(..) => op,
TypedOperand::VecMember(symbol, index) => {
TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
}
})
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.symbol(args, None, type_space, is_dst)
}
}

1677
ptx/src/pass/mod.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run<'input, 'b>(
id_defs: &mut FnStringIdResolver<'input, 'b>,
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
}
_ => (),
}
}
let mut result = Vec::new();
for s in func {
expand_map_variables(id_defs, fn_defs, &mut result, s)?;
}
Ok(result)
}
fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedOperand<&'a str>>,
) -> Result<(), TranslateError> {
match s {
ast::Statement::Block(block) => {
id_defs.start_block();
for s in block {
expand_map_variables(id_defs, fn_defs, result, s)?;
}
id_defs.end_block();
}
ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
.transpose()?,
ast::visit_map(i, &mut |id,
_: Option<(&ast::Type, ast::StateSpace)>,
_: bool,
_: bool| {
id_defs.get_id(id)
})?,
))),
ast::Statement::Variable(var) => {
let var_type = var.var.v_type.clone();
match var.count {
Some(count) => {
for new_id in
id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
{
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
let new_id =
id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
}
}
}
};
Ok(())
}

View file

@ -0,0 +1,48 @@
use std::{collections::HashSet, iter};
use super::*;
pub(super) fn run(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut labels_in_use = HashSet::new();
for s in func.iter() {
match s {
Statement::Instruction(i) => {
if let Some(target) = jump_target(i) {
labels_in_use.insert(target);
}
}
Statement::Conditional(cond) => {
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
Statement::Variable(..)
| Statement::LoadVar(..)
| Statement::StoreVar(..)
| Statement::RetValue(..)
| Statement::Conversion(..)
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}
}
iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
}))
.collect::<Vec<_>>()
}
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
this: &ast::Instruction<T>,
) -> Option<SpirvWord> {
match this {
ast::Instruction::Bra { arguments } => Some(arguments.src),
_ => None,
}
}

View file

@ -0,0 +1,44 @@
use super::*;
use ptx_parser as ast;
pub(crate) fn run(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
let if_true = id_def.register_intermediate(None);
let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
};
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
result.push(Statement::Instruction(inst));
}
result.push(Statement::Label(if_false));
} else {
result.push(Statement::Instruction(inst));
}
}
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
_ => return Err(error_unreachable()),
}
}
Ok(result)
}

View file

@ -7,20 +7,24 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
%21 = OpExtInstImport "OpenCL.std"
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "clz"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %24
%1 = OpFunction %void None %25
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%19 = OpLabel
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@ -37,11 +41,12 @@
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
%13 = OpExtInst %uint %21 clz %14
%18 = OpExtInst %uint %22 clz %14
%13 = OpCopyObject %uint %18
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
OpStore %18 %16 Aligned 4
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
OpStore %19 %16 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -7,6 +7,9 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s16_s8"
@ -45,9 +48,7 @@
%32 = OpBitcast %uint %15
%34 = OpUConvert %uchar %32
%20 = OpCopyObject %uchar %34
%35 = OpBitcast %uchar %20
%37 = OpSConvert %ushort %35
%19 = OpCopyObject %ushort %37
%19 = OpSConvert %ushort %20
%14 = OpSConvert %uint %19
OpStore %6 %14
%16 = OpLoad %ulong %5

View file

@ -7,9 +7,13 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s64_s32"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%27 = OpTypeFunction %void %ulong %ulong
@ -40,9 +44,7 @@
%12 = OpCopyObject %uint %18
OpStore %6 %12
%15 = OpLoad %uint %6
%32 = OpBitcast %uint %15
%33 = OpSConvert %ulong %32
%14 = OpCopyObject %ulong %33
%14 = OpSConvert %ulong %15
OpStore %7 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %ulong %7

View file

@ -7,9 +7,13 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_sat_s_u"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
@ -42,7 +46,7 @@
%15 = OpSatConvertSToU %uint %16
OpStore %7 %15
%18 = OpLoad %uint %7
%17 = OpBitcast %uint %18
%17 = OpCopyObject %uint %18
OpStore %8 %17
%19 = OpLoad %ulong %5
%20 = OpLoad %uint %8

View file

@ -1,3 +1,4 @@
use crate::pass;
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
spirv_txt: &'a [u8],
spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
let spirv_module = translate::to_spirv_module(ast)?;
let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap();
let spirv_module = pass::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());

View file

@ -7,20 +7,24 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
%21 = OpExtInstImport "OpenCL.std"
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "popc"
OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%1 = OpFunction %void None %24
%1 = OpFunction %void None %25
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
%19 = OpLabel
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@ -37,11 +41,12 @@
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
%13 = OpBitCount %uint %14
%18 = OpBitCount %uint %14
%13 = OpCopyObject %uint %18
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6
%18 = OpConvertUToPtr %_ptr_Generic_uint %15
OpStore %18 %16 Aligned 4
%19 = OpConvertUToPtr %_ptr_Generic_uint %15
OpStore %19 %16 Aligned 4
OpReturn
OpFunctionEnd

View file

@ -1,4 +1,4 @@
// Excersise as many features of vector types as possible
// Exercise as many features of vector types as possible
.version 6.5
.target sm_60

View file

@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
for statement in sorted_statements {
match statement {
Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,

17
ptx_parser/Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "ptx_parser"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[lib]
[dependencies]
logos = "0.14"
winnow = { version = "0.6.18" }
#winnow = { version = "0.6.18", features = ["debug"] }
ptx_parser_macros = { path = "../ptx_parser_macros" }
thiserror = "1.0"
bitflags = "1.2"
rustc-hash = "2.0.0"
derive_more = { version = "1", features = ["display"] }

1695
ptx_parser/src/ast.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,69 @@
import os, sys, subprocess
SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
MULTIVAR = ["", "<1>" ]
VECTOR = ["", ".v2" ]
HEADER = """
.version 8.5
.target sm_90
.address_size 64
"""
def directive(space, variable, multivar, vector):
return """{3}
{0} {4} .b32 variable{2} {1};
""".format(space, variable, multivar, HEADER, vector)
def entry_arg(space, variable, multivar, vector):
return """{3}
.entry foobar ( {0} {4} .b32 variable{2} {1})
{{
ret;
}}
""".format(space, variable, multivar, HEADER, vector)
def fn_arg(space, variable, multivar, vector):
return """{3}
.func foobar ( {0} {4} .b32 variable{2} {1})
{{
ret;
}}
""".format(space, variable, multivar, HEADER, vector)
def fn_body(space, variable, multivar, vector):
return """{3}
.func foobar ()
{{
{0} {4} .b32 variable{2} {1};
ret;
}}
""".format(space, variable, multivar, HEADER, vector)
def generate(generator):
legal = []
for space in SPACE:
for init in TYPE_AND_INIT:
for multi in MULTIVAR:
for vector in VECTOR:
ptx = generator(space, init, multi, vector)
if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
legal.append((space, vector, init, multi))
print(generator.__name__)
print(legal)
def main():
generate(directive)
generate(entry_arg)
generate(fn_arg)
generate(fn_body)
if __name__ == "__main__":
main()

3269
ptx_parser/src/lib.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,17 @@
[package]
name = "ptx_parser_macros"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[lib]
proc-macro = true
[dependencies]
ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" }
convert_case = "0.6.0"
rustc-hash = "2.0.0"
syn = "2.0.67"
quote = "1.0"
proc-macro2 = "1.0.86"
either = "1.13.0"

1023
ptx_parser_macros/src/lib.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,13 @@
[package]
name = "ptx_parser_macros_impl"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"
[lib]
[dependencies]
syn = { version = "2.0.67", features = ["extra-traits", "full"] }
quote = "1.0"
proc-macro2 = "1.0.86"
rustc-hash = "2.0.0"

View file

@ -0,0 +1,881 @@
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{
braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token,
Type, TypeParam, Visibility,
};
pub mod parser;
pub struct GenerateInstructionType {
pub visibility: Option<Visibility>,
pub name: Ident,
pub type_parameters: Punctuated<TypeParam, Token![,]>,
pub short_parameters: Punctuated<Ident, Token![,]>,
pub variants: Punctuated<InstructionVariant, Token![,]>,
}
impl GenerateInstructionType {
pub fn emit_arg_types(&self, tokens: &mut TokenStream) {
for v in self.variants.iter() {
v.emit_type(&self.visibility, tokens);
}
}
pub fn emit_instruction_type(&self, tokens: &mut TokenStream) {
let vis = &self.visibility;
let type_name = &self.name;
let type_parameters = &self.type_parameters;
let variants = self.variants.iter().map(|v| v.emit_variant());
quote! {
#vis enum #type_name<#type_parameters> {
#(#variants),*
}
}
.to_tokens(tokens);
}
pub fn emit_visit(&self, tokens: &mut TokenStream) {
self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
}
pub fn emit_visit_mut(&self, tokens: &mut TokenStream) {
self.emit_visit_impl(
VisitKind::RefMut,
tokens,
InstructionVariant::emit_visit_mut,
)
}
pub fn emit_visit_map(&self, tokens: &mut TokenStream) {
self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map)
}
fn emit_visit_impl(
&self,
kind: VisitKind,
tokens: &mut TokenStream,
mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream),
) {
let type_name = &self.name;
let type_parameters = &self.type_parameters;
let short_parameters = &self.short_parameters;
let mut inner_tokens = TokenStream::new();
for v in self.variants.iter() {
fn_(v, type_name, &mut inner_tokens);
}
let visit_ref = kind.reference();
let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
let visit_fn = format_ident!("visit{}", kind.fn_suffix());
let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
(
quote! { <#type_parameters, To: Operand, Err> },
quote! { <#short_parameters, To, Err> },
quote! { std::result::Result<#type_name<To>, Err> },
)
} else {
(
quote! { <#type_parameters, Err> },
quote! { <#short_parameters, Err> },
quote! { std::result::Result<(), Err> },
)
};
quote! {
pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
Ok(match i {
#inner_tokens
})
}
}.to_tokens(tokens);
if kind == VisitKind::Map {
return;
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum VisitKind {
Ref,
RefMut,
Map,
}
impl VisitKind {
fn fn_suffix(self) -> &'static str {
match self {
VisitKind::Ref => "",
VisitKind::RefMut => "_mut",
VisitKind::Map => "_map",
}
}
fn type_suffix(self) -> &'static str {
match self {
VisitKind::Ref => "",
VisitKind::RefMut => "Mut",
VisitKind::Map => "Map",
}
}
fn reference(self) -> Option<proc_macro2::TokenStream> {
match self {
VisitKind::Ref => Some(quote! { & }),
VisitKind::RefMut => Some(quote! { &mut }),
VisitKind::Map => None,
}
}
}
impl Parse for GenerateInstructionType {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let visibility = if !input.peek(Token![enum]) {
Some(input.parse::<Visibility>()?)
} else {
None
};
input.parse::<Token![enum]>()?;
let name = input.parse::<Ident>()?;
input.parse::<Token![<]>()?;
let type_parameters = Punctuated::parse_separated_nonempty(input)?;
let short_parameters = type_parameters
.iter()
.map(|p: &TypeParam| p.ident.clone())
.collect();
input.parse::<Token![>]>()?;
let variants_buffer;
braced!(variants_buffer in input);
let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?;
Ok(Self {
visibility,
name,
type_parameters,
short_parameters,
variants,
})
}
}
pub struct InstructionVariant {
pub name: Ident,
pub type_: Option<Option<Expr>>,
pub space: Option<Expr>,
pub data: Option<Type>,
pub arguments: Option<Arguments>,
pub visit: Option<Expr>,
pub visit_mut: Option<Expr>,
pub map: Option<Expr>,
}
impl InstructionVariant {
fn args_name(&self) -> Ident {
format_ident!("{}Args", self.name)
}
fn emit_variant(&self) -> TokenStream {
let name = &self.name;
let data = match &self.data {
None => {
quote! {}
}
Some(data_type) => {
quote! {
data: #data_type,
}
}
};
let arguments = match &self.arguments {
None => {
quote! {}
}
Some(args) => {
let args_name = self.args_name();
match &args {
Arguments::Def(InstructionArguments { generic: None, .. }) => {
quote! {
arguments: #args_name,
}
}
Arguments::Def(InstructionArguments {
generic: Some(generics),
..
}) => {
quote! {
arguments: #args_name <#generics>,
}
}
Arguments::Decl(type_) => quote! {
arguments: #type_,
},
}
}
};
quote! {
#name { #data #arguments }
}
}
fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
}
fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) {
self.emit_visit_impl(
&self.visit_mut,
enum_,
tokens,
InstructionArguments::emit_visit_mut,
)
}
fn emit_visit_impl(
&self,
visit_fn: &Option<Expr>,
enum_: &Ident,
tokens: &mut TokenStream,
mut fn_: impl FnMut(&InstructionArguments, &Option<Option<Expr>>, &Option<Expr>) -> TokenStream,
) {
let name = &self.name;
let arguments = match &self.arguments {
None => {
quote! {
#enum_ :: #name { .. } => { }
}
.to_tokens(tokens);
return;
}
Some(Arguments::Decl(_)) => {
quote! {
#enum_ :: #name { data, arguments } => { #visit_fn }
}
.to_tokens(tokens);
return;
}
Some(Arguments::Def(args)) => args,
};
let data = &self.data.as_ref().map(|_| quote! { data,});
let arg_calls = fn_(arguments, &self.type_, &self.space);
quote! {
#enum_ :: #name { #data arguments } => {
#arg_calls
}
}
.to_tokens(tokens);
}
fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) {
let name = &self.name;
let data = &self.data.as_ref().map(|_| quote! { data,});
let arguments = match self.arguments {
None => None,
Some(Arguments::Decl(_)) => {
let map = self.map.as_ref().unwrap();
quote! {
#enum_ :: #name { #data arguments } => {
#map
}
}
.to_tokens(tokens);
return;
}
Some(Arguments::Def(ref def)) => Some(def),
};
let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,});
let mut arg_calls = None;
let arguments_init = arguments.as_ref().map(|arguments| {
let arg_type = self.args_name();
arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space));
let arg_names = arguments.fields.iter().map(|arg| &arg.name);
quote! {
arguments: #arg_type { #(#arg_names),* }
}
});
quote! {
#enum_ :: #name { #data #arguments_ident } => {
#arg_calls
#enum_ :: #name { #data #arguments_init }
}
}
.to_tokens(tokens);
}
fn emit_type(&self, vis: &Option<Visibility>, tokens: &mut TokenStream) {
let arguments = match self.arguments {
Some(Arguments::Def(ref a)) => a,
Some(Arguments::Decl(_)) => return,
None => return,
};
let name = self.args_name();
let type_parameters = if arguments.generic.is_some() {
Some(quote! { <T> })
} else {
None
};
let fields = arguments.fields.iter().map(|f| f.emit_field(vis));
quote! {
#vis struct #name #type_parameters {
#(#fields),*
}
}
.to_tokens(tokens);
}
}
impl Parse for InstructionVariant {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name = input.parse::<Ident>()?;
let properties_buffer;
braced!(properties_buffer in input);
let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?;
let mut type_ = None;
let mut space = None;
let mut data = None;
let mut arguments = None;
let mut visit = None;
let mut visit_mut = None;
let mut map = None;
for property in properties {
match property {
VariantProperty::Type(t) => type_ = Some(t),
VariantProperty::Space(s) => space = Some(s),
VariantProperty::Data(d) => data = Some(d),
VariantProperty::Arguments(a) => arguments = Some(a),
VariantProperty::Visit(e) => visit = Some(e),
VariantProperty::VisitMut(e) => visit_mut = Some(e),
VariantProperty::Map(e) => map = Some(e),
}
}
Ok(Self {
name,
type_,
space,
data,
arguments,
visit,
visit_mut,
map,
})
}
}
enum VariantProperty {
Type(Option<Expr>),
Space(Expr),
Data(Type),
Arguments(Arguments),
Visit(Expr),
VisitMut(Expr),
Map(Expr),
}
impl VariantProperty {
pub fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
Ok(if lookahead.peek(Token![type]) {
input.parse::<Token![type]>()?;
input.parse::<Token![:]>()?;
VariantProperty::Type(if input.peek(Token![!]) {
input.parse::<Token![!]>()?;
None
} else {
Some(input.parse::<Expr>()?)
})
} else if lookahead.peek(Ident) {
let key = input.parse::<Ident>()?;
match &*key.to_string() {
"data" => {
input.parse::<Token![:]>()?;
VariantProperty::Data(input.parse::<Type>()?)
}
"space" => {
input.parse::<Token![:]>()?;
VariantProperty::Space(input.parse::<Expr>()?)
}
"arguments" => {
let generics = if input.peek(Token![<]) {
input.parse::<Token![<]>()?;
let gen_params =
Punctuated::<PathSegment, syn::token::PathSep>::parse_separated_nonempty(input)?;
input.parse::<Token![>]>()?;
Some(gen_params)
} else {
None
};
input.parse::<Token![:]>()?;
if input.peek(token::Brace) {
let fields;
braced!(fields in input);
VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse(
generics, &fields,
)?))
} else {
VariantProperty::Arguments(Arguments::Decl(input.parse::<Type>()?))
}
}
"visit" => {
input.parse::<Token![:]>()?;
VariantProperty::Visit(input.parse::<Expr>()?)
}
"visit_mut" => {
input.parse::<Token![:]>()?;
VariantProperty::VisitMut(input.parse::<Expr>()?)
}
"map" => {
input.parse::<Token![:]>()?;
VariantProperty::Map(input.parse::<Expr>()?)
}
x => {
return Err(syn::Error::new(
key.span(),
format!(
"Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.",
x
),
))
}
}
} else {
return Err(lookahead.error());
})
}
}
pub enum Arguments {
Decl(Type),
Def(InstructionArguments),
}
pub struct InstructionArguments {
pub generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
pub fields: Punctuated<ArgumentField, Token![,]>,
}
impl InstructionArguments {
pub fn parse(
generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
input: syn::parse::ParseStream,
) -> syn::Result<Self> {
let fields = Punctuated::<ArgumentField, Token![,]>::parse_terminated_with(
input,
ArgumentField::parse,
)?;
Ok(Self { generic, fields })
}
fn emit_visit(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
) -> TokenStream {
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit)
}
fn emit_visit_mut(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
) -> TokenStream {
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut)
}
fn emit_visit_map(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
) -> TokenStream {
self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map)
}
fn emit_visit_impl(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
mut fn_: impl FnMut(&ArgumentField, &Option<Option<Expr>>, &Option<Expr>, bool) -> TokenStream,
) -> TokenStream {
let is_ident = if let Some(ref generic) = self.generic {
generic.len() > 1
} else {
false
};
let field_calls = self
.fields
.iter()
.map(|f| fn_(f, parent_type, parent_space, is_ident));
quote! {
#(#field_calls)*
}
}
}
pub struct ArgumentField {
pub name: Ident,
pub is_dst: bool,
pub repr: Type,
pub space: Option<Expr>,
pub type_: Option<Expr>,
pub relaxed_type_check: bool,
}
impl ArgumentField {
fn parse_block(
input: syn::parse::ParseStream,
) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>, bool)> {
let content;
braced!(content in input);
let all_fields =
Punctuated::<ExprOrPath, Token![,]>::parse_terminated_with(&content, |content| {
let lookahead = content.lookahead1();
Ok(if lookahead.peek(Token![type]) {
content.parse::<Token![type]>()?;
content.parse::<Token![:]>()?;
ExprOrPath::Type(content.parse::<Expr>()?)
} else if lookahead.peek(Ident) {
let name_ident = content.parse::<Ident>()?;
content.parse::<Token![:]>()?;
match &*name_ident.to_string() {
"relaxed_type_check" => {
ExprOrPath::RelaxedTypeCheck(content.parse::<LitBool>()?.value)
}
"repr" => ExprOrPath::Repr(content.parse::<Type>()?),
"space" => ExprOrPath::Space(content.parse::<Expr>()?),
"dst" => {
let ident = content.parse::<LitBool>()?;
ExprOrPath::Dst(ident.value)
}
name => {
return Err(syn::Error::new(
name_ident.span(),
format!("Unexpected key `{}`, expected `repr` or `space", name),
))
}
}
} else {
return Err(lookahead.error());
})
})?;
let mut repr = None;
let mut type_ = None;
let mut space = None;
let mut is_dst = None;
let mut relaxed_type_check = false;
for exp_or_path in all_fields {
match exp_or_path {
ExprOrPath::Repr(r) => repr = Some(r),
ExprOrPath::Type(t) => type_ = Some(t),
ExprOrPath::Space(s) => space = Some(s),
ExprOrPath::Dst(x) => is_dst = Some(x),
ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed,
}
}
Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check))
}
fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
input.parse::<Type>()
}
fn emit_visit(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
is_ident: bool,
) -> TokenStream {
self.emit_visit_impl(parent_type, parent_space, is_ident, false)
}
fn emit_visit_mut(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
is_ident: bool,
) -> TokenStream {
self.emit_visit_impl(parent_type, parent_space, is_ident, true)
}
fn emit_visit_impl(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
is_ident: bool,
is_mut: bool,
) -> TokenStream {
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
(Some(type_), _) => (false, Some(type_)),
(None, None) => panic!("No type set"),
(None, Some(None)) => (true, None),
(None, Some(Some(type_))) => (false, Some(type_)),
};
let space = self
.space
.as_ref()
.or(parent_space.as_ref())
.map(|space| quote! { #space })
.unwrap_or_else(|| quote! { StateSpace::Reg });
let is_dst = self.is_dst;
let relaxed_type_check = self.relaxed_type_check;
let name = &self.name;
let type_space = if is_typeless {
quote! {
let type_space = None;
}
} else {
quote! {
let type_ = #type_;
let space = #space;
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
}
};
if is_ident {
if is_mut {
quote! {
{
#type_space
visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
}
}
} else {
quote! {
{
#type_space
visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
}
}
}
} else {
let (operand_fn, arguments_name) = if is_mut {
(
quote! {
VisitOperand::visit_mut
},
quote! {
&mut arguments.#name
},
)
} else {
(
quote! {
VisitOperand::visit
},
quote! {
& arguments.#name
},
)
};
quote! {{
#type_space
#operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?;
}}
}
}
fn emit_visit_map(
&self,
parent_type: &Option<Option<Expr>>,
parent_space: &Option<Expr>,
is_ident: bool,
) -> TokenStream {
let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
(Some(type_), _) => (false, Some(type_)),
(None, None) => panic!("No type set"),
(None, Some(None)) => (true, None),
(None, Some(Some(type_))) => (false, Some(type_)),
};
let space = self
.space
.as_ref()
.or(parent_space.as_ref())
.map(|space| quote! { #space })
.unwrap_or_else(|| quote! { StateSpace::Reg });
let is_dst = self.is_dst;
let relaxed_type_check = self.relaxed_type_check;
let name = &self.name;
let type_space = if is_typeless {
quote! {
let type_space = None;
}
} else {
quote! {
let type_ = #type_;
let space = #space;
let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
}
};
let map_call = if is_ident {
quote! {
visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)?
}
} else {
quote! {
MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?
}
};
quote! {
let #name = {
#type_space
#map_call
};
}
}
fn is_dst(name: &Ident) -> syn::Result<bool> {
if name.to_string().starts_with("dst") {
Ok(true)
} else if name.to_string().starts_with("src") {
Ok(false)
} else {
return Err(syn::Error::new(
name.span(),
format!(
"Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`",
name
),
));
}
}
fn emit_field(&self, vis: &Option<Visibility>) -> TokenStream {
let name = &self.name;
let type_ = &self.repr;
quote! {
#vis #name: #type_
}
}
}
impl Parse for ArgumentField {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name = input.parse::<Ident>()?;
input.parse::<Token![:]>()?;
let lookahead = input.lookahead1();
let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) {
Self::parse_block(input)?
} else if lookahead.peek(syn::Ident) {
(Self::parse_basic(input)?, None, None, None, false)
} else {
return Err(lookahead.error());
};
let is_dst = match is_dst {
Some(x) => x,
None => Self::is_dst(&name)?,
};
Ok(Self {
name,
is_dst,
repr,
type_,
space,
relaxed_type_check
})
}
}
enum ExprOrPath {
Repr(Type),
Type(Expr),
Space(Expr),
Dst(bool),
RelaxedTypeCheck(bool),
}
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use quote::{quote, ToTokens};
fn to_string(x: impl ToTokens) -> String {
quote! { #x }.to_string()
}
#[test]
fn parse_argument_field_basic() {
let input = quote! {
dst: P::Operand
};
let arg = syn::parse2::<ArgumentField>(input).unwrap();
assert_eq!("dst", arg.name.to_string());
assert_eq!("P :: Operand", to_string(arg.repr));
assert!(matches!(arg.type_, None));
}
#[test]
fn parse_argument_field_block() {
let input = quote! {
dst: {
type: ScalarType::U32,
space: StateSpace::Global,
repr: P::Operand,
}
};
let arg = syn::parse2::<ArgumentField>(input).unwrap();
assert_eq!("dst", arg.name.to_string());
assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap()));
assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap()));
assert_eq!("P :: Operand", to_string(arg.repr));
}
#[test]
fn parse_argument_field_block_untyped() {
let input = quote! {
dst: {
repr: P::Operand,
}
};
let arg = syn::parse2::<ArgumentField>(input).unwrap();
assert_eq!("dst", arg.name.to_string());
assert_eq!("P :: Operand", to_string(arg.repr));
assert!(matches!(arg.type_, None));
}
#[test]
fn parse_variant_complex() {
let input = quote! {
Ld {
type: ScalarType::U32,
space: StateSpace::Global,
data: LdDetails,
arguments<P>: {
dst: {
repr: P::Operand,
type: ScalarType::U32,
space: StateSpace::Shared,
},
src: P::Operand,
},
}
};
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
assert_eq!("Ld", variant.name.to_string());
assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
assert_eq!("LdDetails", to_string(variant.data.unwrap()));
let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
a
} else {
panic!()
};
assert_eq!("P", to_string(arguments.generic));
let mut fields = arguments.fields.into_iter();
let dst = fields.next().unwrap();
assert_eq!("P :: Operand", to_string(dst.repr));
assert_eq!("ScalarType :: U32", to_string(dst.type_));
assert_eq!("StateSpace :: Shared", to_string(dst.space));
let src = fields.next().unwrap();
assert_eq!("P :: Operand", to_string(src.repr));
assert!(matches!(src.type_, None));
assert!(matches!(src.space, None));
}
#[test]
fn visit_variant_empty() {
let input = quote! {
Ret {
data: RetData
}
};
let variant = syn::parse2::<InstructionVariant>(input).unwrap();
let mut output = TokenStream::new();
variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output);
assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }");
}
}

View file

@ -0,0 +1,844 @@
use proc_macro2::Span;
use proc_macro2::TokenStream;
use quote::quote;
use quote::ToTokens;
use rustc_hash::FxHashMap;
use std::fmt::Write;
use syn::bracketed;
use syn::parse::Peek;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::LitInt;
use syn::Type;
use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token};
pub struct ParseDefinitions {
pub token_type: ItemEnum,
pub additional_enums: FxHashMap<Ident, ItemEnum>,
pub definitions: Vec<OpcodeDefinition>,
}
impl Parse for ParseDefinitions {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let token_type = input.parse::<ItemEnum>()?;
let mut additional_enums = FxHashMap::default();
while input.peek(Token![#]) {
let enum_ = input.parse::<ItemEnum>()?;
additional_enums.insert(enum_.ident.clone(), enum_);
}
let mut definitions = Vec::new();
while !input.is_empty() {
definitions.push(input.parse::<OpcodeDefinition>()?);
}
Ok(Self {
token_type,
additional_enums,
definitions,
})
}
}
pub struct OpcodeDefinition(pub Patterns, pub Vec<Rule>);
impl Parse for OpcodeDefinition {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let patterns = input.parse::<Patterns>()?;
let mut rules = Vec::new();
while Rule::peek(input) {
rules.push(input.parse::<Rule>()?);
input.parse::<Token![;]>()?;
}
Ok(Self(patterns, rules))
}
}
pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>);
impl Parse for Patterns {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut result = Vec::new();
loop {
if !OpcodeDecl::peek(input) {
break;
}
let decl = input.parse::<OpcodeDecl>()?;
let code_block = input.parse::<CodeBlock>()?;
result.push((decl, code_block))
}
Ok(Self(result))
}
}
pub struct OpcodeDecl(pub Instruction, pub Arguments);
impl OpcodeDecl {
fn peek(input: syn::parse::ParseStream) -> bool {
Instruction::peek(input) && !input.peek2(Token![=])
}
}
impl Parse for OpcodeDecl {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
Ok(Self(
input.parse::<Instruction>()?,
input.parse::<Arguments>()?,
))
}
}
pub struct CodeBlock {
pub special: bool,
pub code: proc_macro2::Group,
}
impl Parse for CodeBlock {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
let (special, code) = if lookahead.peek(Token![<]) {
input.parse::<Token![<]>()?;
input.parse::<Token![=]>()?;
//input.parse::<Token![>]>()?;
(true, input.parse::<proc_macro2::Group>()?)
} else if lookahead.peek(Token![=]) {
input.parse::<Token![=]>()?;
input.parse::<Token![>]>()?;
(false, input.parse::<proc_macro2::Group>()?)
} else {
return Err(lookahead.error());
};
Ok(Self { special, code })
}
}
pub struct Rule {
pub modifier: Option<DotModifier>,
pub type_: Option<Type>,
pub alternatives: Vec<DotModifier>,
}
impl Rule {
fn peek(input: syn::parse::ParseStream) -> bool {
DotModifier::peek(input)
|| (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>]))
}
fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result<Vec<DotModifier>> {
let mut result = Vec::new();
Self::parse_with_alternative(input, &mut result)?;
loop {
if !input.peek(Token![,]) {
break;
}
input.parse::<Token![,]>()?;
Self::parse_with_alternative(input, &mut result)?;
}
Ok(result)
}
fn parse_with_alternative(
input: &syn::parse::ParseBuffer,
result: &mut Vec<DotModifier>,
) -> Result<(), syn::Error> {
input.parse::<Token![.]>()?;
let part1 = input.parse::<IdentLike>()?;
if input.peek(token::Brace) {
result.push(DotModifier {
part1: part1.clone(),
part2: None,
});
let suffix_content;
braced!(suffix_content in input);
let suffixes = Punctuated::<IdentOrTypeSuffix, Token![,]>::parse_separated_nonempty(
&suffix_content,
)?;
for part2 in suffixes {
result.push(DotModifier {
part1: part1.clone(),
part2: Some(part2),
});
}
} else if IdentOrTypeSuffix::peek(input) {
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
result.push(DotModifier { part1, part2 });
} else {
result.push(DotModifier { part1, part2: None });
}
Ok(())
}
}
#[derive(PartialEq, Eq, Hash, Clone)]
struct IdentOrTypeSuffix(IdentLike);
impl IdentOrTypeSuffix {
fn span(&self) -> Span {
self.0.span()
}
fn peek(input: syn::parse::ParseStream) -> bool {
input.peek(Token![::])
}
}
impl ToTokens for IdentOrTypeSuffix {
fn to_tokens(&self, tokens: &mut TokenStream) {
let ident = &self.0;
quote! { :: #ident }.to_tokens(tokens)
}
}
impl Parse for IdentOrTypeSuffix {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
input.parse::<Token![::]>()?;
Ok(Self(input.parse::<IdentLike>()?))
}
}
impl Parse for Rule {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let (modifier, type_) = if DotModifier::peek(input) {
let modifier = Some(input.parse::<DotModifier>()?);
if input.peek(Token![:]) {
input.parse::<Token![:]>()?;
(modifier, Some(input.parse::<Type>()?))
} else {
(modifier, None)
}
} else {
(None, Some(input.parse::<Type>()?))
};
input.parse::<Token![=]>()?;
let content;
braced!(content in input);
let alternatives = Self::parse_alternatives(&content)?;
Ok(Self {
modifier,
type_,
alternatives,
})
}
}
pub struct Instruction {
pub name: Ident,
pub modifiers: Vec<MaybeDotModifier>,
}
impl Instruction {
fn peek(input: syn::parse::ParseStream) -> bool {
input.peek(Ident)
}
}
impl Parse for Instruction {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let instruction = input.parse::<Ident>()?;
let mut modifiers = Vec::new();
loop {
if !MaybeDotModifier::peek(input) {
break;
}
modifiers.push(MaybeDotModifier::parse(input)?);
}
Ok(Self {
name: instruction,
modifiers,
})
}
}
pub struct MaybeDotModifier {
pub optional: bool,
pub modifier: DotModifier,
}
impl MaybeDotModifier {
fn peek(input: syn::parse::ParseStream) -> bool {
input.peek(token::Brace) || DotModifier::peek(input)
}
}
impl Parse for MaybeDotModifier {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
Ok(if input.peek(token::Brace) {
let content;
braced!(content in input);
let modifier = DotModifier::parse(&content)?;
Self {
modifier,
optional: true,
}
} else {
let modifier = DotModifier::parse(input)?;
Self {
modifier,
optional: false,
}
})
}
}
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct DotModifier {
part1: IdentLike,
part2: Option<IdentOrTypeSuffix>,
}
impl std::fmt::Display for DotModifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, ".")?;
self.part1.fmt(f)?;
if let Some(ref part2) = self.part2 {
write!(f, "::")?;
part2.0.fmt(f)?;
}
Ok(())
}
}
impl std::fmt::Debug for DotModifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self, f)
}
}
impl DotModifier {
pub fn span(&self) -> Span {
let part1 = self.part1.span();
if let Some(ref part2) = self.part2 {
part1.join(part2.span()).unwrap_or(part1)
} else {
part1
}
}
pub fn ident(&self) -> Ident {
let mut result = String::new();
write!(&mut result, "{}", self.part1).unwrap();
if let Some(ref part2) = self.part2 {
write!(&mut result, "_{}", part2.0).unwrap();
} else {
match self.part1 {
IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
IdentLike::Ident(_) | IdentLike::Integer(_) => {}
}
}
Ident::new(&result.to_ascii_lowercase(), self.span())
}
pub fn variant_capitalized(&self) -> Ident {
self.capitalized_impl(String::new())
}
pub fn dot_capitalized(&self) -> Ident {
self.capitalized_impl("Dot".to_string())
}
fn capitalized_impl(&self, prefix: String) -> Ident {
let mut temp = String::new();
write!(&mut temp, "{}", &self.part1).unwrap();
if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 {
write!(&mut temp, "_{}", part2).unwrap();
}
let mut result = prefix;
let mut capitalize = true;
for c in temp.chars() {
if c == '_' {
capitalize = true;
continue;
}
// Special hack to emit `BF16`` instead of `Bf16``
let c = if capitalize || c == 'f' && result.ends_with('B') {
capitalize = false;
c.to_ascii_uppercase()
} else {
c
};
result.push(c);
}
Ident::new(&result, self.span())
}
pub fn tokens(&self) -> TokenStream {
let part1 = &self.part1;
let part2 = &self.part2;
match self.part2 {
None => quote! { . #part1 },
Some(_) => quote! { . #part1 #part2 },
}
}
fn peek(input: syn::parse::ParseStream) -> bool {
input.peek(Token![.])
}
}
impl Parse for DotModifier {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
input.parse::<Token![.]>()?;
let part1 = input.parse::<IdentLike>()?;
if IdentOrTypeSuffix::peek(input) {
let part2 = Some(IdentOrTypeSuffix::parse(input)?);
Ok(Self { part1, part2 })
} else {
Ok(Self { part1, part2: None })
}
}
}
#[derive(PartialEq, Eq, Hash, Clone)]
enum IdentLike {
Type(Token![type]),
Const(Token![const]),
Ident(Ident),
Integer(LitInt),
}
impl IdentLike {
fn span(&self) -> Span {
match self {
IdentLike::Type(c) => c.span(),
IdentLike::Const(t) => t.span(),
IdentLike::Ident(i) => i.span(),
IdentLike::Integer(l) => l.span(),
}
}
}
impl std::fmt::Display for IdentLike {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
IdentLike::Type(_) => f.write_str("type"),
IdentLike::Const(_) => f.write_str("const"),
IdentLike::Ident(ident) => write!(f, "{}", ident),
IdentLike::Integer(integer) => write!(f, "{}", integer),
}
}
}
impl ToTokens for IdentLike {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
}
}
}
impl Parse for IdentLike {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
Ok(if lookahead.peek(Token![const]) {
IdentLike::Const(input.parse::<Token![const]>()?)
} else if lookahead.peek(Token![type]) {
IdentLike::Type(input.parse::<Token![type]>()?)
} else if lookahead.peek(Ident) {
IdentLike::Ident(input.parse::<Ident>()?)
} else if lookahead.peek(LitInt) {
IdentLike::Integer(input.parse::<LitInt>()?)
} else {
return Err(lookahead.error());
})
}
}
// Arguments decalaration can loook like this:
// a{, b}
// That's why we don't parse Arguments as Punctuated<Argument, Token![,]>
#[derive(PartialEq, Eq)]
pub struct Arguments(pub Vec<Argument>);
impl Parse for Arguments {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut result = Vec::new();
loop {
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
let mut optional = false;
let mut can_be_negated = false;
let mut pre_pipe = false;
let ident;
let lookahead = input.lookahead1();
if lookahead.peek(token::Brace) {
let content;
braced!(content in input);
let lookahead = content.lookahead1();
if lookahead.peek(Token![!]) {
content.parse::<Token![!]>()?;
can_be_negated = true;
ident = input.parse::<Ident>()?;
} else if lookahead.peek(Token![,]) {
optional = true;
content.parse::<Token![,]>()?;
ident = content.parse::<Ident>()?;
} else {
return Err(lookahead.error());
}
} else if lookahead.peek(token::Bracket) {
let bracketed;
bracketed!(bracketed in input);
if bracketed.peek(Token![|]) {
optional = true;
bracketed.parse::<Token![|]>()?;
pre_pipe = true;
ident = bracketed.parse::<Ident>()?;
} else {
let mut sub_args = Self::parse(&bracketed)?;
sub_args.0.first_mut().unwrap().pre_bracket = true;
sub_args.0.last_mut().unwrap().post_bracket = true;
if peek_brace_token(input, Token![.]) {
let optional_suffix;
braced!(optional_suffix in input);
optional_suffix.parse::<Token![.]>()?;
let unified_ident = optional_suffix.parse::<Ident>()?;
if unified_ident.to_string() != "unified" {
return Err(syn::Error::new(
unified_ident.span(),
format!("Exptected `unified`, got `{}`", unified_ident),
));
}
for a in sub_args.0.iter_mut() {
a.unified = true;
}
}
result.extend(sub_args.0);
continue;
}
} else if lookahead.peek(Ident) {
ident = input.parse::<Ident>()?;
} else if lookahead.peek(Token![|]) {
input.parse::<Token![|]>()?;
pre_pipe = true;
ident = input.parse::<Ident>()?;
} else {
break;
}
result.push(Argument {
optional,
pre_pipe,
can_be_negated,
pre_bracket: false,
ident,
post_bracket: false,
unified: false,
});
}
Ok(Self(result))
}
}
// This is effectively input.peek(token::Brace) && input.peek2(Token![.])
// input.peek2 is supposed to skip over next token, but it skips over whole
// braced token group. Not sure if it's a bug
fn peek_brace_token<T: Peek>(input: syn::parse::ParseStream, _t: T) -> bool {
use syn::token::Token;
let cursor = input.cursor();
cursor
.group(proc_macro2::Delimiter::Brace)
.map_or(false, |(content, ..)| T::Token::peek(content))
}
#[derive(PartialEq, Eq)]
pub struct Argument {
pub optional: bool,
pub pre_bracket: bool,
pub pre_pipe: bool,
pub can_be_negated: bool,
pub ident: Ident,
pub post_bracket: bool,
pub unified: bool,
}
#[cfg(test)]
mod tests {
use super::{Arguments, DotModifier, MaybeDotModifier};
use quote::{quote, ToTokens};
#[test]
fn parse_modifier_complex() {
let input = quote! {
.level::eviction_priority
};
let modifier = syn::parse2::<DotModifier>(input).unwrap();
assert_eq!(
". level :: eviction_priority",
modifier.tokens().to_string()
);
}
#[test]
fn parse_modifier_optional() {
let input = quote! {
{ .level::eviction_priority }
};
let maybe_modifider = syn::parse2::<MaybeDotModifier>(input).unwrap();
assert_eq!(
". level :: eviction_priority",
maybe_modifider.modifier.tokens().to_string()
);
assert!(maybe_modifider.optional);
}
#[test]
fn parse_type_token() {
let input = quote! {
. type
};
let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
assert_eq!(". type", maybe_modifier.modifier.tokens().to_string());
assert!(!maybe_modifier.optional);
}
#[test]
fn arguments_memory() {
let input = quote! {
[a], b
};
let arguments = syn::parse2::<Arguments>(input).unwrap();
let a = &arguments.0[0];
assert!(!a.optional);
assert_eq!("a", a.ident.to_string());
assert!(a.pre_bracket);
assert!(!a.pre_pipe);
assert!(a.post_bracket);
assert!(!a.can_be_negated);
let b = &arguments.0[1];
assert!(!b.optional);
assert_eq!("b", b.ident.to_string());
assert!(!b.pre_bracket);
assert!(!b.pre_pipe);
assert!(!b.post_bracket);
assert!(!b.can_be_negated);
}
#[test]
fn arguments_optional() {
let input = quote! {
b{, cache_policy}
};
let arguments = syn::parse2::<Arguments>(input).unwrap();
let b = &arguments.0[0];
assert!(!b.optional);
assert_eq!("b", b.ident.to_string());
assert!(!b.pre_bracket);
assert!(!b.pre_pipe);
assert!(!b.post_bracket);
assert!(!b.can_be_negated);
let cache_policy = &arguments.0[1];
assert!(cache_policy.optional);
assert_eq!("cache_policy", cache_policy.ident.to_string());
assert!(!cache_policy.pre_bracket);
assert!(!cache_policy.pre_pipe);
assert!(!cache_policy.post_bracket);
assert!(!cache_policy.can_be_negated);
}
#[test]
fn arguments_optional_pred() {
let input = quote! {
p[|q], a
};
let arguments = syn::parse2::<Arguments>(input).unwrap();
assert_eq!(arguments.0.len(), 3);
let p = &arguments.0[0];
assert!(!p.optional);
assert_eq!("p", p.ident.to_string());
assert!(!p.pre_bracket);
assert!(!p.pre_pipe);
assert!(!p.post_bracket);
assert!(!p.can_be_negated);
let q = &arguments.0[1];
assert!(q.optional);
assert_eq!("q", q.ident.to_string());
assert!(!q.pre_bracket);
assert!(q.pre_pipe);
assert!(!q.post_bracket);
assert!(!q.can_be_negated);
let a = &arguments.0[2];
assert!(!a.optional);
assert_eq!("a", a.ident.to_string());
assert!(!a.pre_bracket);
assert!(!a.pre_pipe);
assert!(!a.post_bracket);
assert!(!a.can_be_negated);
}
#[test]
fn arguments_optional_with_negate() {
let input = quote! {
b, {!}c
};
let arguments = syn::parse2::<Arguments>(input).unwrap();
assert_eq!(arguments.0.len(), 2);
let b = &arguments.0[0];
assert!(!b.optional);
assert_eq!("b", b.ident.to_string());
assert!(!b.pre_bracket);
assert!(!b.pre_pipe);
assert!(!b.post_bracket);
assert!(!b.can_be_negated);
let c = &arguments.0[1];
assert!(!c.optional);
assert_eq!("c", c.ident.to_string());
assert!(!c.pre_bracket);
assert!(!c.pre_pipe);
assert!(!c.post_bracket);
assert!(c.can_be_negated);
}
#[test]
fn arguments_tex() {
let input = quote! {
d[|p], [a{, b}, c], dpdx, dpdy {, e}
};
let arguments = syn::parse2::<Arguments>(input).unwrap();
assert_eq!(arguments.0.len(), 8);
{
let d = &arguments.0[0];
assert!(!d.optional);
assert_eq!("d", d.ident.to_string());
assert!(!d.pre_bracket);
assert!(!d.pre_pipe);
assert!(!d.post_bracket);
assert!(!d.can_be_negated);
}
{
let p = &arguments.0[1];
assert!(p.optional);
assert_eq!("p", p.ident.to_string());
assert!(!p.pre_bracket);
assert!(p.pre_pipe);
assert!(!p.post_bracket);
assert!(!p.can_be_negated);
}
{
let a = &arguments.0[2];
assert!(!a.optional);
assert_eq!("a", a.ident.to_string());
assert!(a.pre_bracket);
assert!(!a.pre_pipe);
assert!(!a.post_bracket);
assert!(!a.can_be_negated);
}
{
let b = &arguments.0[3];
assert!(b.optional);
assert_eq!("b", b.ident.to_string());
assert!(!b.pre_bracket);
assert!(!b.pre_pipe);
assert!(!b.post_bracket);
assert!(!b.can_be_negated);
}
{
let c = &arguments.0[4];
assert!(!c.optional);
assert_eq!("c", c.ident.to_string());
assert!(!c.pre_bracket);
assert!(!c.pre_pipe);
assert!(c.post_bracket);
assert!(!c.can_be_negated);
}
{
let dpdx = &arguments.0[5];
assert!(!dpdx.optional);
assert_eq!("dpdx", dpdx.ident.to_string());
assert!(!dpdx.pre_bracket);
assert!(!dpdx.pre_pipe);
assert!(!dpdx.post_bracket);
assert!(!dpdx.can_be_negated);
}
{
let dpdy = &arguments.0[6];
assert!(!dpdy.optional);
assert_eq!("dpdy", dpdy.ident.to_string());
assert!(!dpdy.pre_bracket);
assert!(!dpdy.pre_pipe);
assert!(!dpdy.post_bracket);
assert!(!dpdy.can_be_negated);
}
{
let e = &arguments.0[7];
assert!(e.optional);
assert_eq!("e", e.ident.to_string());
assert!(!e.pre_bracket);
assert!(!e.pre_pipe);
assert!(!e.post_bracket);
assert!(!e.can_be_negated);
}
}
#[test]
fn rule_multi() {
let input = quote! {
.ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }
};
let rule = syn::parse2::<super::Rule>(input).unwrap();
assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string());
assert_eq!(
"StateSpace",
rule.type_.unwrap().to_token_stream().to_string()
);
let alts = rule
.alternatives
.iter()
.map(|m| m.tokens().to_string())
.collect::<Vec<_>>();
assert_eq!(
vec![
". global",
". local",
". param",
". param :: func",
". shared",
". shared :: cta",
". shared :: cluster"
],
alts
);
}
#[test]
fn rule_multi2() {
let input = quote! {
.cop: StCacheOperator = { .wb, .cg, .cs, .wt }
};
let rule = syn::parse2::<super::Rule>(input).unwrap();
assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string());
assert_eq!(
"StCacheOperator",
rule.type_.unwrap().to_token_stream().to_string()
);
let alts = rule
.alternatives
.iter()
.map(|m| m.tokens().to_string())
.collect::<Vec<_>>();
assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts);
}
#[test]
fn args_unified() {
let input = quote! {
d, [a]{.unified}{, cache_policy}
};
let args = syn::parse2::<super::Arguments>(input).unwrap();
let a = &args.0[1];
assert!(!a.optional);
assert_eq!("a", a.ident.to_string());
assert!(a.pre_bracket);
assert!(!a.pre_pipe);
assert!(a.post_bracket);
assert!(!a.can_be_negated);
assert!(a.unified);
}
#[test]
fn special_block() {
let input = quote! {
bra <= { bra(stream) }
};
syn::parse2::<super::OpcodeDefinition>(input).unwrap();
}
}