Add pass test mechanism for insert_implicit_conversions (#477)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled
ZLUDA / Build AMD GPU unit tests (push) Has been cancelled
ZLUDA / Run AMD GPU unit tests (push) Has been cancelled

This commit is contained in:
Violet 2025-08-22 13:01:39 -07:00 committed by GitHub
commit 00eb553454
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 461 additions and 26 deletions

View file

@ -29,6 +29,9 @@ mod replace_instructions_with_functions_fp_required;
mod replace_known_functions;
mod resolve_function_pointers;
#[cfg(test)]
mod test;
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
@ -577,7 +580,18 @@ struct ImplicitConversion {
kind: ConversionKind,
}
#[derive(PartialEq, Clone)]
impl std::fmt::Display for ImplicitConversion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"zluda.convert_implicit{}{}{}{}{}",
self.kind, self.to_space, self.to_type, self.from_space, self.from_type
)
}
}
#[derive(PartialEq, Clone, strum_macros::Display)]
#[strum(serialize_all = "snake_case", prefix = ".")]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
@ -617,6 +631,12 @@ struct FunctionPointerDetails {
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
pub struct SpirvWord(u32);
impl std::fmt::Display for SpirvWord {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "%{}", self.0)
}
}
impl From<u32> for SpirvWord {
fn from(value: u32) -> Self {
Self(value)

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.func (.reg .b32 output) default (
.reg .u32 input
)
{
mov.b32 output, input;
ret;
}
// %%% output %%%
.func (.reg .b32 %2) %1 (
.reg .u32 %3
)
{
.b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.u32 %3;
mov.b32 %2, %4;
ret;
}

View file

@ -0,0 +1,22 @@
use crate::pass::{test::directive2_vec_to_string, *};
use super::test_pass;
macro_rules! test_insert_implicit_conversions {
($test_name:ident) => {
test_pass!(run_insert_implicit_conversions, $test_name);
};
}
fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
// We run the minimal number of passes required to produce the input expected by insert_implicit_conversions
let mut flat_resolver = GlobalStringIdentResolver2::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let directives = normalize_identifiers2::run(&mut scoped_resolver, ptx.directives).unwrap();
let directives = normalize_predicates2::run(&mut flat_resolver, directives).unwrap();
let directives = expand_operands::run(&mut flat_resolver, directives).unwrap();
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives).unwrap();
directive2_vec_to_string(&flat_resolver, directives)
}
test_insert_implicit_conversions!(default);

247
ptx/src/pass/test/mod.rs Normal file
View file

@ -0,0 +1,247 @@
use ptx_parser as ast;
use std::{
env, error,
fs::{self, File},
io::Write,
path::Path,
};
mod insert_implicit_conversions;
#[macro_export]
macro_rules! test_pass {
($pass:expr, $test_name:ident) => {
paste::item! {
#[test]
fn [<$test_name>]() -> Result<(), Box<dyn std::error::Error>> {
use crate::test::read_test_file;
let ptx = read_test_file!(concat!(stringify!($test_name), ".ptx"));
let mut parts = ptx.split("// %%% output %%%");
let ptx_in = parts.next().unwrap_or("").trim();
let ptx_out = parts.next().unwrap_or("").trim();
assert!(parts.next().is_none());
crate::pass::test::test_pass_assert(stringify!($test_name), $pass, ptx_in, ptx_out)
}
}
};
}
pub(crate) use test_pass;
use crate::pass::IdentEntry;
use super::{Directive2, Function2, GlobalStringIdentResolver2, SpirvWord, Statement};
fn directive2_vec_to_string(
resolver: &GlobalStringIdentResolver2,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
) -> String {
directives
.into_iter()
.map(|d| directive_to_string(resolver, d) + "\n")
.collect::<Vec<_>>()
.join("")
}
fn directive_to_string(
resolver: &GlobalStringIdentResolver2,
directive: Directive2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> String {
match directive {
Directive2::Variable(linking_directive, variable) => {
let ld_string = if !linking_directive.is_empty() {
format!("{} ", linking_directive)
} else {
"".to_string()
};
format!("{}{};", ld_string, variable)
}
Directive2::Method(function) => function_to_string(resolver, function),
}
}
fn function_to_string(
resolver: &GlobalStringIdentResolver2,
function: Function2<ast::Instruction<SpirvWord>, SpirvWord>,
) -> String {
if function.import_as.is_some()
|| function.tuning.len() > 0
|| function.flush_to_zero_f32
|| function.flush_to_zero_f16f64
|| function.rounding_mode_f32 != ast::RoundingMode::NearestEven
|| function.rounding_mode_f16f64 != ast::RoundingMode::NearestEven
{
todo!("Figure out some way of representing these in text");
}
let linkage = if !function.linkage.is_empty() {
format!("{} ", function.linkage)
} else {
"".to_string()
};
let entry = if !function.is_kernel {
format!(".func ")
} else {
format!(".entry ")
};
let return_arguments = if function.return_arguments.len() > 0 {
let args = function
.return_arguments
.iter()
.map(|arg| format!("{}", arg))
.collect::<Vec<_>>()
.join(", ");
format!("({}) ", args)
} else {
"".to_string()
};
let input_arguments = function
.input_arguments
.iter()
.map(|arg| format!("\n {}", arg))
.collect::<Vec<_>>()
.join(",");
let body = if let Some(stmts) = function.body {
let stmt_strings = stmts
.into_iter()
.map(|stmt| format!(" {}\n", statement_to_string(resolver, stmt)))
.collect::<Vec<_>>()
.join("");
format!("\n{{\n{}}}", stmt_strings)
} else {
format!(";")
};
format!(
"{}{}{}{} ({}\n){}",
linkage, entry, return_arguments, function.name, input_arguments, body
)
}
struct StatementFormatter<'a> {
resolver: &'a GlobalStringIdentResolver2<'a>,
dst_strings: Vec<String>,
other_args: Vec<SpirvWord>,
}
impl<'a> StatementFormatter<'a> {
fn new(resolver: &'a GlobalStringIdentResolver2<'a>) -> Self {
Self {
resolver,
dst_strings: Vec::new(),
other_args: Vec::new(),
}
}
fn format(&self, op: &str) -> String {
let assign_temps = if self.dst_strings.len() > 0 {
let temps = self.dst_strings.join(", ");
format!("{} = ", temps)
} else {
"".to_string()
};
let args = self
.other_args
.iter()
.map(|arg| format!(" {}", arg))
.collect::<Vec<_>>()
.join(",");
format!("{}{}{};", assign_temps, op, args)
}
}
impl<'a> ast::VisitorMap<SpirvWord, SpirvWord, ()> for StatementFormatter<'a> {
fn visit(
&mut self,
arg: SpirvWord,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, ()> {
if is_dst {
if let Some(IdentEntry { name: None, .. }) = self.resolver.ident_map.get(&arg) {
let type_string = if let Some((type_, state_space)) = type_space {
format!("{}{} ", type_, state_space)
} else {
"".to_string()
};
self.dst_strings.push(format!("{}{}", type_string, arg));
return Ok(arg);
}
}
self.other_args.push(arg);
Ok(arg)
}
fn visit_ident(
&mut self,
arg: <SpirvWord as ptx_parser::Operand>::Ident,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, ()> {
self.visit(arg, type_space, is_dst, relaxed_type_check)
}
}
fn statement_to_string(
resolver: &GlobalStringIdentResolver2,
stmt: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> String {
let op = match &stmt {
Statement::Variable(var) => format!("{}", var),
Statement::Instruction(instr) => format!("{}", instr),
Statement::Conversion(conv) => format!("{}", conv),
_ => todo!(),
};
let mut args_formatter = StatementFormatter::new(resolver);
stmt.visit_map(&mut args_formatter);
args_formatter.format(&op)
}
fn test_pass_assert<F, D>(
name: &str,
run_pass: F,
ptx_in: &str,
expected_ptx_out: &str,
) -> Result<(), Box<dyn error::Error>>
where
F: FnOnce(ast::Module) -> D,
D: std::fmt::Display,
{
let actual_ptx_out = ast::parse_module_checked(ptx_in)
.map(|ast| {
let result = run_pass(ast);
result.to_string()
})
.unwrap_or("".to_string());
compare_ptx(name, ptx_in, actual_ptx_out.trim(), expected_ptx_out);
Ok(())
}
fn compare_ptx(name: &str, ptx_in: &str, actual_ptx_out: &str, expected_ptx_out: &str) {
if actual_ptx_out != expected_ptx_out {
let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR");
if let Ok(output_dir) = output_dir {
let output_dir = Path::new(&output_dir);
fs::create_dir_all(&output_dir).unwrap();
let output_file = output_dir.join(format!("{}.ptx", name));
let mut output_file = File::create(output_file).unwrap();
output_file.write_all(ptx_in.as_bytes()).unwrap();
output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap();
output_file.write_all(actual_ptx_out.as_bytes()).unwrap();
}
let comparison = pretty_assertions::StrComparison::new(expected_ptx_out, actual_ptx_out);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
}
}

View file

@ -3,13 +3,11 @@ use ptx_parser as ast;
mod spirv_run;
#[cfg(not(feature = "ci_build"))]
#[macro_export]
macro_rules! read_test_file {
($file:expr) => {
{
if cfg!(feature = "ci_build") {
include_str!($file).to_string()
} else {
use std::path::PathBuf;
// CARGO_MANIFEST_DIR is the crate directory (ptx), but file! is relative to the workspace root (and therefore also includes ptx).
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
@ -19,7 +17,14 @@ macro_rules! read_test_file {
path.push($file);
std::fs::read_to_string(path).unwrap()
}
};
}
#[cfg(feature = "ci_build")]
#[macro_export]
macro_rules! read_test_file {
($file:expr) => {
include_str!($file).to_string()
};
}
pub(crate) use read_test_file;

View file

@ -303,7 +303,8 @@ ptx_parser_macros::generate_instruction_type!(
repr: T,
space: { data.state_space },
}
}
},
display: write!(f, "<TODO:finish ld>")?
},
Lg2 {
type: Type::Scalar(ScalarType::F32),
@ -356,7 +357,8 @@ ptx_parser_macros::generate_instruction_type!(
arguments<T>: {
dst: T,
src: T
}
},
display: write!(f, "mov{}", data.typ)?
},
Mul {
type: { Type::from(data.type_()) },
@ -457,7 +459,8 @@ ptx_parser_macros::generate_instruction_type!(
}
},
Ret {
data: RetData
data: RetData,
display: write!(f, "ret")?
},
Rsqrt {
type: { Type::from(data.type_) },
@ -926,6 +929,40 @@ pub struct Variable<ID> {
pub array_init: Vec<u8>,
}
impl<ID: std::fmt::Display> std::fmt::Display for Variable<ID> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.state_space)?;
if let Some(align) = self.align {
write!(f, " .align {}", align)?;
}
let (vector_size, scalar_type, array_dims) = match &self.v_type {
Type::Scalar(scalar_type) => (None, *scalar_type, &vec![]),
Type::Vector(size, scalar_type) => (Some(*size), *scalar_type, &vec![]),
Type::Array(vector_size, scalar_type, array_dims) => {
(vector_size.map(|s| s.get()), *scalar_type, array_dims)
}
};
if let Some(size) = vector_size {
write!(f, " .v{}", size)?;
}
write!(f, " {} {}", scalar_type, self.name)?;
for dim in array_dims {
write!(f, "[{}]", dim)?;
}
if self.array_init.len() > 0 {
todo!("Need to interpret the array initializer data as the appropriate type");
}
Ok(())
}
}
pub struct PredAt<ID> {
pub not: bool,
pub label: ID,
@ -941,6 +978,15 @@ pub enum Type {
Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
}
impl std::fmt::Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Scalar(scalar_type) => write!(f, "{}", scalar_type),
_ => todo!(),
}
}
}
impl Type {
pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
match vector {
@ -1387,6 +1433,20 @@ bitflags! {
}
}
impl std::fmt::Display for LinkingDirective {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut directives = vec![];
if self.contains(LinkingDirective::EXTERN) {
directives.push(".extern");
} else if self.contains(LinkingDirective::VISIBLE) {
directives.push(".visible");
} else if self.contains(LinkingDirective::WEAK) {
directives.push(".weak");
}
write!(f, "{}", directives.join(" "))
}
}
pub struct Function<'a, ID, S> {
pub func_directive: MethodDeclaration<'a, ID>,
pub tuning: Vec<TuningDirective>,

View file

@ -1757,37 +1757,39 @@ derive_parser!(
DotFile
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum StateSpace {
#[display(".reg")]
Reg,
#[display("")]
Generic,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum MemScope { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum ScalarType { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum SetpBoolPostOp { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum AtomSemantics { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum Mul24Control { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum Reduction { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum ShuffleMode { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum ShiftDirection { }
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Display, PartialEq, Eq, Hash)]
pub enum FunnelShiftMode { }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov

View file

@ -416,7 +416,9 @@ fn emit_enum_types(
if let Some(enum_) = existing_enums.get_mut(ident) {
enum_.variants.extend(variants.into_iter().map(|modifier| {
let ident = modifier.variant_capitalized();
let m_string = format!("{}", modifier);
let variant: syn::Variant = syn::parse_quote! {
#[display(#m_string)]
#ident
};
variant
@ -1050,6 +1052,7 @@ pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro:
let mut result = proc_macro2::TokenStream::new();
input.emit_arg_types(&mut result);
input.emit_instruction_type(&mut result);
input.emit_instruction_display(&mut result);
input.emit_visit(&mut result);
input.emit_visit_mut(&mut result);
input.emit_visit_map(&mut result);

View file

@ -35,6 +35,32 @@ impl GenerateInstructionType {
.to_tokens(tokens);
}
pub fn emit_instruction_display(&self, tokens: &mut TokenStream) {
let type_name = &self.name;
let type_parameters = &self.type_parameters;
let type_arguments = self.type_parameters.iter().map(|p| p.ident.clone());
let variants = self
.variants
.iter()
.map(|v| v.emit_display(&self.name))
.filter_map(|v| v);
quote! {
impl<#type_parameters> std::fmt::Display for #type_name<#(#type_arguments),*>
where
T: std::fmt::Display,
<T as Operand>::Ident: std::fmt::Display
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#(#variants),*
}
Ok(())
}
}
}
.to_tokens(tokens);
}
pub fn emit_visit(&self, tokens: &mut TokenStream) {
self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
}
@ -163,6 +189,7 @@ pub struct InstructionVariant {
pub arguments: Option<Arguments>,
pub visit: Option<Expr>,
pub visit_mut: Option<Expr>,
pub display: Option<Expr>,
pub map: Option<Expr>,
}
@ -214,6 +241,25 @@ impl InstructionVariant {
}
}
fn emit_display(&self, enum_: &Ident) -> Option<TokenStream> {
let name = &self.name;
let enum_ = enum_;
let arguments = self.arguments.as_ref().map(|_| quote! { arguments });
let data = &self.data.as_ref().map(|_| quote! { data,});
let display_op = self
.display
.as_ref()
.map(|d| quote! {#d})
.unwrap_or(quote! { write!(f, "<{}>", stringify!(#name))? });
Some(quote! {
instr @ #enum_ :: #name { #data #arguments } => {
#display_op;
}
})
}
fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
}
@ -332,6 +378,7 @@ impl Parse for InstructionVariant {
let mut arguments = None;
let mut visit = None;
let mut visit_mut = None;
let mut display = None;
let mut map = None;
for property in properties {
match property {
@ -341,6 +388,7 @@ impl Parse for InstructionVariant {
VariantProperty::Arguments(a) => arguments = Some(a),
VariantProperty::Visit(e) => visit = Some(e),
VariantProperty::VisitMut(e) => visit_mut = Some(e),
VariantProperty::Display(e) => display = Some(e),
VariantProperty::Map(e) => map = Some(e),
}
}
@ -352,6 +400,7 @@ impl Parse for InstructionVariant {
arguments,
visit,
visit_mut,
display,
map,
})
}
@ -364,6 +413,7 @@ enum VariantProperty {
Arguments(Arguments),
Visit(Expr),
VisitMut(Expr),
Display(Expr),
Map(Expr),
}
@ -419,6 +469,10 @@ impl VariantProperty {
input.parse::<Token![:]>()?;
VariantProperty::VisitMut(input.parse::<Expr>()?)
}
"display" => {
input.parse::<Token![:]>()?;
VariantProperty::Display(input.parse::<Expr>()?)
}
"map" => {
input.parse::<Token![:]>()?;
VariantProperty::Map(input.parse::<Expr>()?)