diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 9131188..ce5452a 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,10 +1,14 @@ use crate::pass; use hip_runtime_sys::hipError_t; +use std::env; use std::error; use std::ffi::{CStr, CString}; -use std::fmt; -use std::fmt::{Debug, Display, Formatter}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::fs::{create_dir_all, File}; +use std::io::Write; +use std::panic::{catch_unwind, resume_unwind}; use std::mem; +use std::path::Path; use std::{ptr, str}; use pretty_assertions; @@ -33,10 +37,9 @@ macro_rules! test_ptx { paste::item! { #[test] fn [<$fn_name _llvm>]() -> Result<(), Box> { - let fn_name = stringify!($fn_name); let ptx = include_str!(concat!(stringify!($fn_name), ".ptx")); let ll = include_str!(concat!("../ll/", stringify!($fn_name), ".ll")).trim(); - test_llvm_assert(ptx, &ll) + test_llvm_assert(stringify!($fn_name), ptx, &ll) } } }; @@ -237,13 +240,28 @@ fn test_hip_assert< fn test_llvm_assert< 'a, >( + name: &str, ptx_text: &'a str, expected_ll: &str ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); let llvm_ir = pass::to_llvm_module(ast).unwrap(); let actual_ll = llvm_ir.llvm_ir.print_as_asm(); - pretty_assertions::assert_eq!(actual_ll, expected_ll); + let result = catch_unwind(|| + pretty_assertions::assert_eq!(actual_ll, expected_ll)); + if let Err(cause) = result { + // Write actual generated LLVM IR to directory specified by environment variable + // TEST_PTX_LLVM_FAIL_DIR if test fails + let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR"); + if let Ok(output_dir) = output_dir { + let output_dir = Path::new(&output_dir); + create_dir_all(&output_dir).unwrap(); + let output_file = output_dir.join(format!("{}.ll", name)); + let mut output_file = File::create(output_file).unwrap(); + output_file.write_all(actual_ll.as_bytes()).unwrap(); + } + resume_unwind(cause); + } Ok(()) }