diff --git a/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_f16x2.ptx b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_f16x2.ptx new file mode 100644 index 0000000..e467734 --- /dev/null +++ b/ptx/src/pass/test/insert_implicit_conversions/default_reg_b32_reg_f16x2.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .b32 output) default_reg_b32_reg_f16x2 ( + .reg .f16x2 input +) +{ + mov.b32 output, input; + ret; +} + +// %%% output %%% + +.func (.reg .b32 %2) %1 ( + .reg .f16x2 %3 +) +{ + .b32.reg %4 = zluda.convert_implicit.default.reg.b32.reg.f16x2 %3; + mov.b32 %2, %4; + ret; +} diff --git a/ptx/src/pass/test/insert_implicit_conversions/mod.rs b/ptx/src/pass/test/insert_implicit_conversions/mod.rs index f758148..1fb7a54 100644 --- a/ptx/src/pass/test/insert_implicit_conversions/mod.rs +++ b/ptx/src/pass/test/insert_implicit_conversions/mod.rs @@ -20,3 +20,4 @@ fn run_insert_implicit_conversions(ptx: ptx_parser::Module) -> String { } test_insert_implicit_conversions!(default); +test_insert_implicit_conversions!(default_reg_b32_reg_f16x2); diff --git a/ptx/src/pass/test/mod.rs b/ptx/src/pass/test/mod.rs index e54eed9..3a9ef1f 100644 --- a/ptx/src/pass/test/mod.rs +++ b/ptx/src/pass/test/mod.rs @@ -162,7 +162,7 @@ impl<'a> ast::VisitorMap for StatementFormatter<'a> { arg: SpirvWord, type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, is_dst: bool, - relaxed_type_check: bool, + _relaxed_type_check: bool, ) -> Result { if is_dst { if let Some(IdentEntry { name: None, .. }) = self.resolver.ident_map.get(&arg) { @@ -205,7 +205,7 @@ fn statement_to_string( _ => todo!(), }; let mut args_formatter = StatementFormatter::new(resolver); - stmt.visit_map(&mut args_formatter); + stmt.visit_map(&mut args_formatter).unwrap(); args_formatter.format(&op) } @@ -219,29 +219,40 @@ where F: FnOnce(ast::Module) -> D, D: std::fmt::Display, { - let actual_ptx_out = ast::parse_module_checked(ptx_in) + let (actual_ptx_out, errs) = ast::parse_module_checked(ptx_in) .map(|ast| { let result = run_pass(ast); - result.to_string() + (result.to_string(), vec![]) }) - .unwrap_or("".to_string()); + .unwrap_or_else(|errs| ("".to_string(), errs)); + for err in errs { + eprintln!("{}", err); + } compare_ptx(name, ptx_in, actual_ptx_out.trim(), expected_ptx_out); Ok(()) } fn compare_ptx(name: &str, ptx_in: &str, actual_ptx_out: &str, expected_ptx_out: &str) { if actual_ptx_out != expected_ptx_out { - let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR"); - if let Ok(output_dir) = output_dir { - let output_dir = Path::new(&output_dir); - fs::create_dir_all(&output_dir).unwrap(); - let output_file = output_dir.join(format!("{}.ptx", name)); - let mut output_file = File::create(output_file).unwrap(); - output_file.write_all(ptx_in.as_bytes()).unwrap(); - output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap(); - output_file.write_all(actual_ptx_out.as_bytes()).unwrap(); - } + maybe_save_output(name, ptx_in, actual_ptx_out); let comparison = pretty_assertions::StrComparison::new(expected_ptx_out, actual_ptx_out); panic!("assertion failed: `(left == right)`\n\n{}", comparison); } + if actual_ptx_out == "" { + maybe_save_output(name, ptx_in, actual_ptx_out); + panic!("missing expected output"); + } +} + +fn maybe_save_output(name: &str, ptx_in: &str, actual_ptx_out: &str) { + let output_dir = env::var("TEST_PTX_PASS_FAIL_DIR"); + if let Ok(output_dir) = output_dir { + let output_dir = Path::new(&output_dir); + fs::create_dir_all(&output_dir).unwrap(); + let output_file = output_dir.join(format!("{}.ptx", name)); + let mut output_file = File::create(output_file).unwrap(); + output_file.write_all(ptx_in.as_bytes()).unwrap(); + output_file.write_all(b"\n\n// %%% output %%%\n\n").unwrap(); + output_file.write_all(actual_ptx_out.as_bytes()).unwrap(); + } }