From 00eb553454099962d94546e83ca0f2eb5e0bae49 Mon Sep 17 00:00:00 2001 From: Violet Date: Fri, 22 Aug 2025 13:01:39 -0700 Subject: [PATCH] Add pass test mechanism for insert_implicit_conversions (#477) --- ptx/src/pass/mod.rs | 22 +- .../insert_implicit_conversions/default.ptx | 22 ++ .../test/insert_implicit_conversions/mod.rs | 22 ++ ptx/src/pass/test/mod.rs | 247 ++++++++++++++++++ ptx/src/test/mod.rs | 29 +- ptx_parser/src/ast.rs | 66 ++++- ptx_parser/src/lib.rs | 22 +- ptx_parser_macros/src/lib.rs | 3 + ptx_parser_macros_impl/src/lib.rs | 54 ++++ 9 files changed, 461 insertions(+), 26 deletions(-) create mode 100644 ptx/src/pass/test/insert_implicit_conversions/default.ptx create mode 100644 ptx/src/pass/test/insert_implicit_conversions/mod.rs create mode 100644 ptx/src/pass/test/mod.rs diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index d31c0ec..79f5e99 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -29,6 +29,9 @@ mod replace_instructions_with_functions_fp_required; mod replace_known_functions; mod resolve_function_pointers; +#[cfg(test)] +mod test; + static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; @@ -577,7 +580,18 @@ struct ImplicitConversion { kind: ConversionKind, } -#[derive(PartialEq, Clone)] +impl std::fmt::Display for ImplicitConversion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "zluda.convert_implicit{}{}{}{}{}", + self.kind, self.to_space, self.to_type, self.from_space, self.from_type + ) + } +} + +#[derive(PartialEq, Clone, strum_macros::Display)] +#[strum(serialize_all = "snake_case", prefix = ".")] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types @@ -617,6 +631,12 @@ struct FunctionPointerDetails { #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] pub struct SpirvWord(u32); +impl std::fmt::Display for SpirvWord { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "%{}", self.0) + } +} + impl From for SpirvWord { fn from(value: u32) -> Self { Self(value) diff --git a/ptx/src/pass/test/insert_implicit_conversions/default.ptx b/ptx/src/pass/test/insert_implicit_conversions/default.ptx new file mode 100644 index 0000000..6718884 --- /dev/null +++ b/ptx/src/pass/test/insert_implicit_conversions/default.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .b32 output) default ( + .reg .u32 input +) +{ + mov.b32 output, input; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .u32 %3 +) +{ + .b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.u32 %3; + mov.b32 %2, %4; + ret; +} \ No newline at end of file diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs new file mode 100644 index 0000000..f758148 --- /dev/null +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -0,0 +1,22 @@ +use crate::pass::{test::directive2_vec_to_string, *}; + +use super::test_pass; + +macro_rules! test_insert_implicit_conversions { + ($test_name:ident) => { + test_pass!(run_insert_implicit_conversions, $test_name); + }; +} + +fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { + // We run the minimal number of passes required to produce the input expected by insert_implicit_conversions + let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap(); + let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap(); + let directives = expand_operands::run(&mut flat_resolver, directives).unwrap(); + let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap(); + directive2_vec_to_string(&flat_resolver, directives) +} + +test_insert_implicit_conversions!(default); diff --git a/ptx/src/pass/test/mod.rs b/ptx/src/pass/test/mod.rs new file mode 100644 index 0000000..e54eed9 --- /dev/null +++ b/ptx/src/pass/test/mod.rs @@ -0,0 +1,247 @@ +use ptx_parser as ast; +use std::{ + env, error, + fs::{self, File}, + io::Write, + path::Path, +}; + +mod insert_implicit_conversions; + +#[macro_export] +macro_rules! test_pass { + ($pass:expr, $test_name:ident) => { + paste::item! { + #[test] + fn [<$test_name>]() -> Result<(), Box> { + use crate::test::read_test_file; + let ptx = read_test_file!(concat!(stringify!($test_name), ".ptx")); + let mut parts = ptx.split("// %%% output %%%"); + let ptx_in = parts.next().unwrap_or("").trim(); + let ptx_out = parts.next().unwrap_or("").trim(); + assert!(parts.next().is_none()); + crate::pass::test::test_pass_assert(stringify!($test_name), $pass, ptx_in, ptx_out) + } + } + }; +} +pub(crate) use test_pass; + +use crate::pass::IdentEntry; + +use super::{Directive2, Function2, GlobalStringIdentResolver2, SpirvWord, Statement}; + +fn directive2_vec_to_string( + resolver: &GlobalStringIdentResolver2, + directives: Vec, SpirvWord>>, +) -> String { + directives + .into_iter() + .map(|d| directive_to_string(resolver, d) + "\n") + .collect::>() + .join("") +} + +fn directive_to_string( + resolver: &GlobalStringIdentResolver2, + directive: Directive2, SpirvWord>, +) -> String { + match directive { + Directive2::Variable(linking_directive, variable) => { + let ld_string = if !linking_directive.is_empty() { + format!("{} ", linking_directive) + } else { + "".to_string() + }; + format!("{}{};", ld_string, variable) + } + Directive2::Method(function) => function_to_string(resolver, function), + } +} + +fn function_to_string( + resolver: &GlobalStringIdentResolver2, + function: Function2, SpirvWord>, +) -> String { + if function.import_as.is_some() + || function.tuning.len() > 0 + || function.flush_to_zero_f32 + || function.flush_to_zero_f16f64 + || function.rounding_mode_f32 != ast::RoundingMode::NearestEven + || function.rounding_mode_f16f64 != ast::RoundingMode::NearestEven + { + todo!("Figure out some way of representing these in text"); + } + + let linkage = if !function.linkage.is_empty() { + format!("{} ", function.linkage) + } else { + "".to_string() + }; + + let entry = if !function.is_kernel { + format!(".func ") + } else { + format!(".entry ") + }; + + let return_arguments = if function.return_arguments.len() > 0 { + let args = function + .return_arguments + .iter() + .map(|arg| format!("{}", arg)) + .collect::>() + .join(", "); + + format!("({}) ", args) + } else { + "".to_string() + }; + + let input_arguments = function + .input_arguments + .iter() + .map(|arg| format!("\n {}", arg)) + .collect::>() + .join(","); + + let body = if let Some(stmts) = function.body { + let stmt_strings = stmts + .into_iter() + .map(|stmt| format!(" {}\n", statement_to_string(resolver, stmt))) + .collect::>() + .join(""); + format!("\n{{\n{}}}", stmt_strings) + } else { + format!(";") + }; + + format!( + "{}{}{}{} ({}\n){}", + linkage, entry, return_arguments, function.name, input_arguments, body + ) +} + +struct StatementFormatter<'a> { + resolver: &'a GlobalStringIdentResolver2<'a>, + dst_strings: Vec, + other_args: Vec, +} + +impl<'a> StatementFormatter<'a> { + fn new(resolver: &'a GlobalStringIdentResolver2<'a>) -> Self { + Self { + resolver, + dst_strings: Vec::new(), + other_args: Vec::new(), + } + } + + fn format(&self, op: &str) -> String { + let assign_temps = if self.dst_strings.len() > 0 { + let temps = self.dst_strings.join(", "); + format!("{} = ", temps) + } else { + "".to_string() + }; + + let args = self + .other_args + .iter() + .map(|arg| format!(" {}", arg)) + .collect::>() + .join(","); + + format!("{}{}{};", assign_temps, op, args) + } +} + +impl<'a> ast::VisitorMap for StatementFormatter<'a> { + fn visit( + &mut self, + arg: SpirvWord, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + if is_dst { + if let Some(IdentEntry { name: None, .. }) = self.resolver.ident_map.get(&arg) { + let type_string = if let Some((type_, state_space)) = type_space { + format!("{}{} ", type_, state_space) + } else { + "".to_string() + }; + + self.dst_strings.push(format!("{}{}", type_string, arg)); + + return Ok(arg); + } + } + + self.other_args.push(arg); + + Ok(arg) + } + + fn visit_ident( + &mut self, + arg: ::Ident, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<::Ident, ()> { + self.visit(arg, type_space, is_dst, relaxed_type_check) + } +} + +fn statement_to_string( + resolver: &GlobalStringIdentResolver2, + stmt: Statement, SpirvWord>, +) -> String { + let op = match &stmt { + Statement::Variable(var) => format!("{}", var), + Statement::Instruction(instr) => format!("{}", instr), + Statement::Conversion(conv) => format!("{}", conv), + _ => todo!(), + }; + let mut args_formatter = StatementFormatter::new(resolver); + stmt.visit_map(&mut args_formatter); + args_formatter.format(&op) +} + +fn test_pass_assert( + name: &str, + run_pass: F, + ptx_in: &str, + expected_ptx_out: &str, +) -> Result<(), Box> +where + F: FnOnce(ast::Module) -> D, + D: std::fmt::Display, +{ + let actual_ptx_out = ast::parse_module_checked(ptx_in) + .map(|ast| { + let result = run_pass(ast); + result.to_string() + }) + .unwrap_or("".to_string()); + compare_ptx(name, ptx_in, actual_ptx_out.trim(), expected_ptx_out); + Ok(()) +} + +fn compare_ptx(name: &str, ptx_in: &str, actual_ptx_out: &str, expected_ptx_out: &str) { + if actual_ptx_out != expected_ptx_out { + let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR"); + if let Ok(output_dir) = output_dir { + let output_dir = Path::new(&output_dir); + fs::create_dir_all(&output_dir).unwrap(); + let output_file = output_dir.join(format!("{}.ptx", name)); + let mut output_file = File::create(output_file).unwrap(); + output_file.write_all(ptx_in.as_bytes()).unwrap(); + output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap(); + output_file.write_all(actual_ptx_out.as_bytes()).unwrap(); + } + let comparison = pretty_assertions::StrComparison::new(expected_ptx_out, actual_ptx_out); + panic!("assertion failed: `(left == right)`\n\n{}", comparison); + } +} diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 2de24b7..f746d63 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -3,25 +3,30 @@ use ptx_parser as ast; mod spirv_run; +#[cfg(not(feature = "ci_build"))] #[macro_export] macro_rules! read_test_file { ($file:expr) => { { - if cfg!(feature = "ci_build") { - include_str!($file).to_string() - } else { - use std::path::PathBuf; - // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). - let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - path.pop(); - path.push(file!()); - path.pop(); - path.push($file); - std::fs::read_to_string(path).unwrap() - } + use std::path::PathBuf; + // CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx). + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.pop(); + path.push(file!()); + path.pop(); + path.push($file); + std::fs::read_to_string(path).unwrap() } }; } + +#[cfg(feature = "ci_build")] +#[macro_export] +macro_rules! read_test_file { + ($file:expr) => { + include_str!($file).to_string() + }; +} pub(crate) use read_test_file; fn parse_and_assert(ptx_text: &str) { diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 108c2d3..f198795 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -303,7 +303,8 @@ ptx_parser_macros::generate_instruction_type!( repr: T, space: { data.state_space }, } - } + }, + display: write!(f, "")? }, Lg2 { type: Type::Scalar(ScalarType::F32), @@ -356,7 +357,8 @@ ptx_parser_macros::generate_instruction_type!( arguments: { dst: T, src: T - } + }, + display: write!(f, "mov{}", data.typ)? }, Mul { type: { Type::from(data.type_()) }, @@ -457,7 +459,8 @@ ptx_parser_macros::generate_instruction_type!( } }, Ret { - data: RetData + data: RetData, + display: write!(f, "ret")? }, Rsqrt { type: { Type::from(data.type_) }, @@ -926,6 +929,40 @@ pub struct Variable { pub array_init: Vec, } +impl std::fmt::Display for Variable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.state_space)?; + + if let Some(align) = self.align { + write!(f, " .align {}", align)?; + } + + let (vector_size, scalar_type, array_dims) = match &self.v_type { + Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]), + Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]), + Type::Array(vector_size, scalar_type, array_dims) => { + (vector_size.map(|s| s.get()), *scalar_type, array_dims) + } + }; + + if let Some(size) = vector_size { + write!(f, " .v{}", size)?; + } + + write!(f, " {} {}", scalar_type, self.name)?; + + for dim in array_dims { + write!(f, "[{}]", dim)?; + } + + if self.array_init.len() > 0 { + todo!("Need to interpret the array initializer data as the appropriate type"); + } + + Ok(()) + } +} + pub struct PredAt { pub not: bool, pub label: ID, @@ -941,6 +978,15 @@ pub enum Type { Array(Option, ScalarType, Vec), } +impl std::fmt::Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Scalar(scalar_type) => write!(f, "{}", scalar_type), + _ => todo!(), + } + } +} + impl Type { pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { match vector { @@ -1387,6 +1433,20 @@ bitflags! { } } +impl std::fmt::Display for LinkingDirective { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut directives = vec![]; + if self.contains(LinkingDirective::EXTERN) { + directives.push(".extern"); + } else if self.contains(LinkingDirective::VISIBLE) { + directives.push(".visible"); + } else if self.contains(LinkingDirective::WEAK) { + directives.push(".weak"); + } + write!(f, "{}", directives.join(" ")) + } +} + pub struct Function<'a, ID, S> { pub func_directive: MethodDeclaration<'a, ID>, pub tuning: Vec, diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 56331a6..9c08f95 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1757,37 +1757,39 @@ derive_parser!( DotFile } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum StateSpace { + #[display(".reg")] Reg, + #[display("")] Generic, } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum MemScope { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum ScalarType { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum SetpBoolPostOp { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum AtomSemantics { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum Mul24Control { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum Reduction { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum ShuffleMode { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum ShiftDirection { } - #[derive(Copy, Clone, PartialEq, Eq, Hash)] + #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] pub enum FunnelShiftMode { } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs index 87e3de0..6c65af6 100644 --- a/ptx_parser_macros/src/lib.rs +++ b/ptx_parser_macros/src/lib.rs @@ -416,7 +416,9 @@ fn emit_enum_types( if let Some(enum_) = existing_enums.get_mut(ident) { enum_.variants.extend(variants.into_iter().map(|modifier| { let ident = modifier.variant_capitalized(); + let m_string = format!("{}", modifier); let variant: syn::Variant = syn::parse_quote! { + #[display(#m_string)] #ident }; variant @@ -1050,6 +1052,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro: let mut result = proc_macro2::TokenStream::new(); input.emit_arg_types(&mut result); input.emit_instruction_type(&mut result); + input.emit_instruction_display(&mut result); input.emit_visit(&mut result); input.emit_visit_mut(&mut result); input.emit_visit_map(&mut result); diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs index 34d97da..ed72f35 100644 --- a/ptx_parser_macros_impl/src/lib.rs +++ b/ptx_parser_macros_impl/src/lib.rs @@ -35,6 +35,32 @@ impl GenerateInstructionType { .to_tokens(tokens); } + pub fn emit_instruction_display(&self, tokens: &mut TokenStream) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let type_arguments = self.type_parameters.iter().map(|p| p.ident.clone()); + let variants = self + .variants + .iter() + .map(|v| v.emit_display(&self.name)) + .filter_map(|v| v); + quote! { + impl<#type_parameters> std::fmt::Display for #type_name<#(#type_arguments),*> + where + T: std::fmt::Display, + ::Ident: std::fmt::Display + { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #(#variants),* + } + Ok(()) + } + } + } + .to_tokens(tokens); + } + pub fn emit_visit(&self, tokens: &mut TokenStream) { self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) } @@ -163,6 +189,7 @@ pub struct InstructionVariant { pub arguments: Option, pub visit: Option, pub visit_mut: Option, + pub display: Option, pub map: Option, } @@ -214,6 +241,25 @@ impl InstructionVariant { } } + fn emit_display(&self, enum_: &Ident) -> Option { + let name = &self.name; + let enum_ = enum_; + let arguments = self.arguments.as_ref().map(|_| quote! { arguments }); + let data = &self.data.as_ref().map(|_| quote! { data,}); + + let display_op = self + .display + .as_ref() + .map(|d| quote! {#d}) + .unwrap_or(quote! { write!(f, "<{}>", stringify!(#name))? }); + + Some(quote! { + instr @ #enum_ :: #name { #data #arguments } => { + #display_op; + } + }) + } + fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit) } @@ -332,6 +378,7 @@ impl Parse for InstructionVariant { let mut arguments = None; let mut visit = None; let mut visit_mut = None; + let mut display = None; let mut map = None; for property in properties { match property { @@ -341,6 +388,7 @@ impl Parse for InstructionVariant { VariantProperty::Arguments(a) => arguments = Some(a), VariantProperty::Visit(e) => visit = Some(e), VariantProperty::VisitMut(e) => visit_mut = Some(e), + VariantProperty::Display(e) => display = Some(e), VariantProperty::Map(e) => map = Some(e), } } @@ -352,6 +400,7 @@ impl Parse for InstructionVariant { arguments, visit, visit_mut, + display, map, }) } @@ -364,6 +413,7 @@ enum VariantProperty { Arguments(Arguments), Visit(Expr), VisitMut(Expr), + Display(Expr), Map(Expr), } @@ -419,6 +469,10 @@ impl VariantProperty { input.parse::()?; VariantProperty::VisitMut(input.parse::()?) } + "display" => { + input.parse::()?; + VariantProperty::Display(input.parse::()?) + } "map" => { input.parse::()?; VariantProperty::Map(input.parse::()?)