Add test for conversion from .f16x2 to .b32 (#479)

This commit is contained in:
Violet 2025-08-25 15:33:53 -07:00 committed by GitHub
commit de319f7c00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 49 additions and 15 deletions

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.func (.reg .b32 output) default_reg_b32_reg_f16x2 (
.reg .f16x2 input
)
{
mov.b32 output, input;
ret;
}
// %%% output %%%
.func (.reg .b32 %2) %1 (
.reg .f16x2 %3
)
{
.b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.f16x2 %3;
mov.b32 %2, %4;
ret;
}

View file

@ -20,3 +20,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String {
}
test_insert_implicit_conversions!(default);
test_insert_implicit_conversions!(default_reg_b32_reg_f16x2);

View file

@ -162,7 +162,7 @@ impl<'a> ast::VisitorMap<SpirvWord, SpirvWord, ()> for StatementFormatter<'a> {
arg: SpirvWord,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, ()> {
if is_dst {
if let Some(IdentEntry { name: None, .. }) = self.resolver.ident_map.get(&arg) {
@ -205,7 +205,7 @@ fn statement_to_string(
_ => todo!(),
};
let mut args_formatter = StatementFormatter::new(resolver);
stmt.visit_map(&mut args_formatter);
stmt.visit_map(&mut args_formatter).unwrap();
args_formatter.format(&op)
}
@ -219,29 +219,40 @@ where
F: FnOnce(ast::Module) -> D,
D: std::fmt::Display,
{
let actual_ptx_out = ast::parse_module_checked(ptx_in)
let (actual_ptx_out, errs) = ast::parse_module_checked(ptx_in)
.map(|ast| {
let result = run_pass(ast);
result.to_string()
(result.to_string(), vec![])
})
.unwrap_or("".to_string());
.unwrap_or_else(|errs| ("".to_string(), errs));
for err in errs {
eprintln!("{}", err);
}
compare_ptx(name, ptx_in, actual_ptx_out.trim(), expected_ptx_out);
Ok(())
}
fn compare_ptx(name: &str, ptx_in: &str, actual_ptx_out: &str, expected_ptx_out: &str) {
if actual_ptx_out != expected_ptx_out {
let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR");
if let Ok(output_dir) = output_dir {
let output_dir = Path::new(&output_dir);
fs::create_dir_all(&output_dir).unwrap();
let output_file = output_dir.join(format!("{}.ptx", name));
let mut output_file = File::create(output_file).unwrap();
output_file.write_all(ptx_in.as_bytes()).unwrap();
output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap();
output_file.write_all(actual_ptx_out.as_bytes()).unwrap();
}
maybe_save_output(name, ptx_in, actual_ptx_out);
let comparison = pretty_assertions::StrComparison::new(expected_ptx_out, actual_ptx_out);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
}
if actual_ptx_out == "" {
maybe_save_output(name, ptx_in, actual_ptx_out);
panic!("missing expected output");
}
}
fn maybe_save_output(name: &str, ptx_in: &str, actual_ptx_out: &str) {
let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR");
if let Ok(output_dir) = output_dir {
let output_dir = Path::new(&output_dir);
fs::create_dir_all(&output_dir).unwrap();
let output_file = output_dir.join(format!("{}.ptx", name));
let mut output_file = File::create(output_file).unwrap();
output_file.write_all(ptx_in.as_bytes()).unwrap();
output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap();
output_file.write_all(actual_ptx_out.as_bytes()).unwrap();
}
}