mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-03 16:17:11 +00:00
Add pass test mechanism for insert_implicit_conversions (#477)
This commit is contained in:
parent
e805cb72a5
commit
00eb553454
9 changed files with 461 additions and 26 deletions
|
@ -29,6 +29,9 @@ mod replace_instructions_with_functions_fp_required;
|
|||
mod replace_known_functions;
|
||||
mod resolve_function_pointers;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
|
||||
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
|
||||
|
||||
|
@ -577,7 +580,18 @@ struct ImplicitConversion {
|
|||
kind: ConversionKind,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone)]
|
||||
impl std::fmt::Display for ImplicitConversion {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"zluda.convert_implicit{}{}{}{}{}",
|
||||
self.kind, self.to_space, self.to_type, self.from_space, self.from_type
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone, strum_macros::Display)]
|
||||
#[strum(serialize_all = "snake_case", prefix = ".")]
|
||||
enum ConversionKind {
|
||||
Default,
|
||||
// zero-extend/chop/bitcast depending on types
|
||||
|
@ -617,6 +631,12 @@ struct FunctionPointerDetails {
|
|||
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
|
||||
pub struct SpirvWord(u32);
|
||||
|
||||
impl std::fmt::Display for SpirvWord {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "%{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for SpirvWord {
|
||||
fn from(value: u32) -> Self {
|
||||
Self(value)
|
||||
|
|
22
ptx/src/pass/test/insert_implicit_conversions/default.ptx
Normal file
22
ptx/src/pass/test/insert_implicit_conversions/default.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.func (.reg .b32 output) default (
|
||||
.reg .u32 input
|
||||
)
|
||||
{
|
||||
mov.b32 output, input;
|
||||
ret;
|
||||
}
|
||||
|
||||
// %%% output %%%
|
||||
|
||||
.func (.reg .b32 %2) %1 (
|
||||
.reg .u32 %3
|
||||
)
|
||||
{
|
||||
.b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.u32 %3;
|
||||
mov.b32 %2, %4;
|
||||
ret;
|
||||
}
|
22
ptx/src/pass/test/insert_implicit_conversions/mod.rs
Normal file
22
ptx/src/pass/test/insert_implicit_conversions/mod.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
use crate::pass::{test::directive2_vec_to_string, *};
|
||||
|
||||
use super::test_pass;
|
||||
|
||||
macro_rules! test_insert_implicit_conversions {
|
||||
($test_name:ident) => {
|
||||
test_pass!(run_insert_implicit_conversions, $test_name);
|
||||
};
|
||||
}
|
||||
|
||||
fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
|
||||
// We run the minimal number of passes required to produce the input expected by insert_implicit_conversions
|
||||
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
|
||||
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
|
||||
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap();
|
||||
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
|
||||
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
|
||||
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap();
|
||||
directive2_vec_to_string(&flat_resolver, directives)
|
||||
}
|
||||
|
||||
test_insert_implicit_conversions!(default);
|
247
ptx/src/pass/test/mod.rs
Normal file
247
ptx/src/pass/test/mod.rs
Normal file
|
@ -0,0 +1,247 @@
|
|||
use ptx_parser as ast;
|
||||
use std::{
|
||||
env, error,
|
||||
fs::{self, File},
|
||||
io::Write,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
mod insert_implicit_conversions;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_pass {
|
||||
($pass:expr, $test_name:ident) => {
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [<$test_name>]() -> Result<(), Box<dyn std::error::Error>> {
|
||||
use crate::test::read_test_file;
|
||||
let ptx = read_test_file!(concat!(stringify!($test_name), ".ptx"));
|
||||
let mut parts = ptx.split("// %%% output %%%");
|
||||
let ptx_in = parts.next().unwrap_or("").trim();
|
||||
let ptx_out = parts.next().unwrap_or("").trim();
|
||||
assert!(parts.next().is_none());
|
||||
crate::pass::test::test_pass_assert(stringify!($test_name), $pass, ptx_in, ptx_out)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
pub(crate) use test_pass;
|
||||
|
||||
use crate::pass::IdentEntry;
|
||||
|
||||
use super::{Directive2, Function2, GlobalStringIdentResolver2, SpirvWord, Statement};
|
||||
|
||||
fn directive2_vec_to_string(
|
||||
resolver: &GlobalStringIdentResolver2,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
) -> String {
|
||||
directives
|
||||
.into_iter()
|
||||
.map(|d| directive_to_string(resolver, d) + "\n")
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
}
|
||||
|
||||
fn directive_to_string(
|
||||
resolver: &GlobalStringIdentResolver2,
|
||||
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> String {
|
||||
match directive {
|
||||
Directive2::Variable(linking_directive, variable) => {
|
||||
let ld_string = if !linking_directive.is_empty() {
|
||||
format!("{} ", linking_directive)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
format!("{}{};", ld_string, variable)
|
||||
}
|
||||
Directive2::Method(function) => function_to_string(resolver, function),
|
||||
}
|
||||
}
|
||||
|
||||
fn function_to_string(
|
||||
resolver: &GlobalStringIdentResolver2,
|
||||
function: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> String {
|
||||
if function.import_as.is_some()
|
||||
|| function.tuning.len() > 0
|
||||
|| function.flush_to_zero_f32
|
||||
|| function.flush_to_zero_f16f64
|
||||
|| function.rounding_mode_f32 != ast::RoundingMode::NearestEven
|
||||
|| function.rounding_mode_f16f64 != ast::RoundingMode::NearestEven
|
||||
{
|
||||
todo!("Figure out some way of representing these in text");
|
||||
}
|
||||
|
||||
let linkage = if !function.linkage.is_empty() {
|
||||
format!("{} ", function.linkage)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let entry = if !function.is_kernel {
|
||||
format!(".func ")
|
||||
} else {
|
||||
format!(".entry ")
|
||||
};
|
||||
|
||||
let return_arguments = if function.return_arguments.len() > 0 {
|
||||
let args = function
|
||||
.return_arguments
|
||||
.iter()
|
||||
.map(|arg| format!("{}", arg))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
format!("({}) ", args)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let input_arguments = function
|
||||
.input_arguments
|
||||
.iter()
|
||||
.map(|arg| format!("\n {}", arg))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
let body = if let Some(stmts) = function.body {
|
||||
let stmt_strings = stmts
|
||||
.into_iter()
|
||||
.map(|stmt| format!(" {}\n", statement_to_string(resolver, stmt)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
format!("\n{{\n{}}}", stmt_strings)
|
||||
} else {
|
||||
format!(";")
|
||||
};
|
||||
|
||||
format!(
|
||||
"{}{}{}{} ({}\n){}",
|
||||
linkage, entry, return_arguments, function.name, input_arguments, body
|
||||
)
|
||||
}
|
||||
|
||||
struct StatementFormatter<'a> {
|
||||
resolver: &'a GlobalStringIdentResolver2<'a>,
|
||||
dst_strings: Vec<String>,
|
||||
other_args: Vec<SpirvWord>,
|
||||
}
|
||||
|
||||
impl<'a> StatementFormatter<'a> {
|
||||
fn new(resolver: &'a GlobalStringIdentResolver2<'a>) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
dst_strings: Vec::new(),
|
||||
other_args: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn format(&self, op: &str) -> String {
|
||||
let assign_temps = if self.dst_strings.len() > 0 {
|
||||
let temps = self.dst_strings.join(", ");
|
||||
format!("{} = ", temps)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let args = self
|
||||
.other_args
|
||||
.iter()
|
||||
.map(|arg| format!(" {}", arg))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
format!("{}{}{};", assign_temps, op, args)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ast::VisitorMap<SpirvWord, SpirvWord, ()> for StatementFormatter<'a> {
|
||||
fn visit(
|
||||
&mut self,
|
||||
arg: SpirvWord,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<SpirvWord, ()> {
|
||||
if is_dst {
|
||||
if let Some(IdentEntry { name: None, .. }) = self.resolver.ident_map.get(&arg) {
|
||||
let type_string = if let Some((type_, state_space)) = type_space {
|
||||
format!("{}{} ", type_, state_space)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
self.dst_strings.push(format!("{}{}", type_string, arg));
|
||||
|
||||
return Ok(arg);
|
||||
}
|
||||
}
|
||||
|
||||
self.other_args.push(arg);
|
||||
|
||||
Ok(arg)
|
||||
}
|
||||
|
||||
fn visit_ident(
|
||||
&mut self,
|
||||
arg: <SpirvWord as ptx_parser::Operand>::Ident,
|
||||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
|
||||
is_dst: bool,
|
||||
relaxed_type_check: bool,
|
||||
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, ()> {
|
||||
self.visit(arg, type_space, is_dst, relaxed_type_check)
|
||||
}
|
||||
}
|
||||
|
||||
fn statement_to_string(
|
||||
resolver: &GlobalStringIdentResolver2,
|
||||
stmt: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
|
||||
) -> String {
|
||||
let op = match &stmt {
|
||||
Statement::Variable(var) => format!("{}", var),
|
||||
Statement::Instruction(instr) => format!("{}", instr),
|
||||
Statement::Conversion(conv) => format!("{}", conv),
|
||||
_ => todo!(),
|
||||
};
|
||||
let mut args_formatter = StatementFormatter::new(resolver);
|
||||
stmt.visit_map(&mut args_formatter);
|
||||
args_formatter.format(&op)
|
||||
}
|
||||
|
||||
fn test_pass_assert<F, D>(
|
||||
name: &str,
|
||||
run_pass: F,
|
||||
ptx_in: &str,
|
||||
expected_ptx_out: &str,
|
||||
) -> Result<(), Box<dyn error::Error>>
|
||||
where
|
||||
F: FnOnce(ast::Module) -> D,
|
||||
D: std::fmt::Display,
|
||||
{
|
||||
let actual_ptx_out = ast::parse_module_checked(ptx_in)
|
||||
.map(|ast| {
|
||||
let result = run_pass(ast);
|
||||
result.to_string()
|
||||
})
|
||||
.unwrap_or("".to_string());
|
||||
compare_ptx(name, ptx_in, actual_ptx_out.trim(), expected_ptx_out);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compare_ptx(name: &str, ptx_in: &str, actual_ptx_out: &str, expected_ptx_out: &str) {
|
||||
if actual_ptx_out != expected_ptx_out {
|
||||
let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR");
|
||||
if let Ok(output_dir) = output_dir {
|
||||
let output_dir = Path::new(&output_dir);
|
||||
fs::create_dir_all(&output_dir).unwrap();
|
||||
let output_file = output_dir.join(format!("{}.ptx", name));
|
||||
let mut output_file = File::create(output_file).unwrap();
|
||||
output_file.write_all(ptx_in.as_bytes()).unwrap();
|
||||
output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap();
|
||||
output_file.write_all(actual_ptx_out.as_bytes()).unwrap();
|
||||
}
|
||||
let comparison = pretty_assertions::StrComparison::new(expected_ptx_out, actual_ptx_out);
|
||||
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
|
||||
}
|
||||
}
|
|
@ -3,13 +3,11 @@ use ptx_parser as ast;
|
|||
|
||||
mod spirv_run;
|
||||
|
||||
#[cfg(not(feature = "ci_build"))]
|
||||
#[macro_export]
|
||||
macro_rules! read_test_file {
|
||||
($file:expr) => {
|
||||
{
|
||||
if cfg!(feature = "ci_build") {
|
||||
include_str!($file).to_string()
|
||||
} else {
|
||||
use std::path::PathBuf;
|
||||
// CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx).
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
|
@ -19,7 +17,14 @@ macro_rules! read_test_file {
|
|||
path.push($file);
|
||||
std::fs::read_to_string(path).unwrap()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(feature = "ci_build")]
|
||||
#[macro_export]
|
||||
macro_rules! read_test_file {
|
||||
($file:expr) => {
|
||||
include_str!($file).to_string()
|
||||
};
|
||||
}
|
||||
pub(crate) use read_test_file;
|
||||
|
|
|
@ -303,7 +303,8 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
repr: T,
|
||||
space: { data.state_space },
|
||||
}
|
||||
}
|
||||
},
|
||||
display: write!(f, "<TODO:finish ld>")?
|
||||
},
|
||||
Lg2 {
|
||||
type: Type::Scalar(ScalarType::F32),
|
||||
|
@ -356,7 +357,8 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
arguments<T>: {
|
||||
dst: T,
|
||||
src: T
|
||||
}
|
||||
},
|
||||
display: write!(f, "mov{}", data.typ)?
|
||||
},
|
||||
Mul {
|
||||
type: { Type::from(data.type_()) },
|
||||
|
@ -457,7 +459,8 @@ ptx_parser_macros::generate_instruction_type!(
|
|||
}
|
||||
},
|
||||
Ret {
|
||||
data: RetData
|
||||
data: RetData,
|
||||
display: write!(f, "ret")?
|
||||
},
|
||||
Rsqrt {
|
||||
type: { Type::from(data.type_) },
|
||||
|
@ -926,6 +929,40 @@ pub struct Variable<ID> {
|
|||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.state_space)?;
|
||||
|
||||
if let Some(align) = self.align {
|
||||
write!(f, " .align {}", align)?;
|
||||
}
|
||||
|
||||
let (vector_size, scalar_type, array_dims) = match &self.v_type {
|
||||
Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]),
|
||||
Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]),
|
||||
Type::Array(vector_size, scalar_type, array_dims) => {
|
||||
(vector_size.map(|s| s.get()), *scalar_type, array_dims)
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(size) = vector_size {
|
||||
write!(f, " .v{}", size)?;
|
||||
}
|
||||
|
||||
write!(f, " {} {}", scalar_type, self.name)?;
|
||||
|
||||
for dim in array_dims {
|
||||
write!(f, "[{}]", dim)?;
|
||||
}
|
||||
|
||||
if self.array_init.len() > 0 {
|
||||
todo!("Need to interpret the array initializer data as the appropriate type");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PredAt<ID> {
|
||||
pub not: bool,
|
||||
pub label: ID,
|
||||
|
@ -941,6 +978,15 @@ pub enum Type {
|
|||
Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Type {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Type::Scalar(scalar_type) => write!(f, "{}", scalar_type),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
|
||||
match vector {
|
||||
|
@ -1387,6 +1433,20 @@ bitflags! {
|
|||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LinkingDirective {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut directives = vec![];
|
||||
if self.contains(LinkingDirective::EXTERN) {
|
||||
directives.push(".extern");
|
||||
} else if self.contains(LinkingDirective::VISIBLE) {
|
||||
directives.push(".visible");
|
||||
} else if self.contains(LinkingDirective::WEAK) {
|
||||
directives.push(".weak");
|
||||
}
|
||||
write!(f, "{}", directives.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Function<'a, ID, S> {
|
||||
pub func_directive: MethodDeclaration<'a, ID>,
|
||||
pub tuning: Vec<TuningDirective>,
|
||||
|
|
|
@ -1757,37 +1757,39 @@ derive_parser!(
|
|||
DotFile
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum StateSpace {
|
||||
#[display(".reg")]
|
||||
Reg,
|
||||
#[display("")]
|
||||
Generic,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum MemScope { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum ScalarType { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum SetpBoolPostOp { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum AtomSemantics { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum Mul24Control { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum Reduction { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum ShuffleMode { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum ShiftDirection { }
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
|
||||
pub enum FunnelShiftMode { }
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
|
||||
|
|
|
@ -416,7 +416,9 @@ fn emit_enum_types(
|
|||
if let Some(enum_) = existing_enums.get_mut(ident) {
|
||||
enum_.variants.extend(variants.into_iter().map(|modifier| {
|
||||
let ident = modifier.variant_capitalized();
|
||||
let m_string = format!("{}", modifier);
|
||||
let variant: syn::Variant = syn::parse_quote! {
|
||||
#[display(#m_string)]
|
||||
#ident
|
||||
};
|
||||
variant
|
||||
|
@ -1050,6 +1052,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro:
|
|||
let mut result = proc_macro2::TokenStream::new();
|
||||
input.emit_arg_types(&mut result);
|
||||
input.emit_instruction_type(&mut result);
|
||||
input.emit_instruction_display(&mut result);
|
||||
input.emit_visit(&mut result);
|
||||
input.emit_visit_mut(&mut result);
|
||||
input.emit_visit_map(&mut result);
|
||||
|
|
|
@ -35,6 +35,32 @@ impl GenerateInstructionType {
|
|||
.to_tokens(tokens);
|
||||
}
|
||||
|
||||
pub fn emit_instruction_display(&self, tokens: &mut TokenStream) {
|
||||
let type_name = &self.name;
|
||||
let type_parameters = &self.type_parameters;
|
||||
let type_arguments = self.type_parameters.iter().map(|p| p.ident.clone());
|
||||
let variants = self
|
||||
.variants
|
||||
.iter()
|
||||
.map(|v| v.emit_display(&self.name))
|
||||
.filter_map(|v| v);
|
||||
quote! {
|
||||
impl<#type_parameters> std::fmt::Display for #type_name<#(#type_arguments),*>
|
||||
where
|
||||
T: std::fmt::Display,
|
||||
<T as Operand>::Ident: std::fmt::Display
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
#(#variants),*
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
.to_tokens(tokens);
|
||||
}
|
||||
|
||||
pub fn emit_visit(&self, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
|
||||
}
|
||||
|
@ -163,6 +189,7 @@ pub struct InstructionVariant {
|
|||
pub arguments: Option<Arguments>,
|
||||
pub visit: Option<Expr>,
|
||||
pub visit_mut: Option<Expr>,
|
||||
pub display: Option<Expr>,
|
||||
pub map: Option<Expr>,
|
||||
}
|
||||
|
||||
|
@ -214,6 +241,25 @@ impl InstructionVariant {
|
|||
}
|
||||
}
|
||||
|
||||
fn emit_display(&self, enum_: &Ident) -> Option<TokenStream> {
|
||||
let name = &self.name;
|
||||
let enum_ = enum_;
|
||||
let arguments = self.arguments.as_ref().map(|_| quote! { arguments });
|
||||
let data = &self.data.as_ref().map(|_| quote! { data,});
|
||||
|
||||
let display_op = self
|
||||
.display
|
||||
.as_ref()
|
||||
.map(|d| quote! {#d})
|
||||
.unwrap_or(quote! { write!(f, "<{}>", stringify!(#name))? });
|
||||
|
||||
Some(quote! {
|
||||
instr @ #enum_ :: #name { #data #arguments } => {
|
||||
#display_op;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
|
||||
self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
|
||||
}
|
||||
|
@ -332,6 +378,7 @@ impl Parse for InstructionVariant {
|
|||
let mut arguments = None;
|
||||
let mut visit = None;
|
||||
let mut visit_mut = None;
|
||||
let mut display = None;
|
||||
let mut map = None;
|
||||
for property in properties {
|
||||
match property {
|
||||
|
@ -341,6 +388,7 @@ impl Parse for InstructionVariant {
|
|||
VariantProperty::Arguments(a) => arguments = Some(a),
|
||||
VariantProperty::Visit(e) => visit = Some(e),
|
||||
VariantProperty::VisitMut(e) => visit_mut = Some(e),
|
||||
VariantProperty::Display(e) => display = Some(e),
|
||||
VariantProperty::Map(e) => map = Some(e),
|
||||
}
|
||||
}
|
||||
|
@ -352,6 +400,7 @@ impl Parse for InstructionVariant {
|
|||
arguments,
|
||||
visit,
|
||||
visit_mut,
|
||||
display,
|
||||
map,
|
||||
})
|
||||
}
|
||||
|
@ -364,6 +413,7 @@ enum VariantProperty {
|
|||
Arguments(Arguments),
|
||||
Visit(Expr),
|
||||
VisitMut(Expr),
|
||||
Display(Expr),
|
||||
Map(Expr),
|
||||
}
|
||||
|
||||
|
@ -419,6 +469,10 @@ impl VariantProperty {
|
|||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::VisitMut(input.parse::<Expr>()?)
|
||||
}
|
||||
"display" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Display(input.parse::<Expr>()?)
|
||||
}
|
||||
"map" => {
|
||||
input.parse::<Token![:]>()?;
|
||||
VariantProperty::Map(input.parse::<Expr>()?)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue