diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 96e0815..b816481 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -27,13 +27,11 @@ use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; -use std::mem::MaybeUninit; use std::ops::Deref; use std::{i8, ptr}; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; -use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; @@ -47,7 +45,7 @@ const SHARED_ADDRESS_SPACE: u32 = 3; const CONSTANT_ADDRESS_SPACE: u32 = 4; const PRIVATE_ADDRESS_SPACE: u32 = 5; -struct Context(LLVMContextRef); +pub struct Context(LLVMContextRef); impl Context { fn new() -> Self { @@ -67,7 +65,7 @@ impl Drop for Context { } } -struct Module(LLVMModuleRef); +pub struct Module(LLVMModuleRef); impl Module { fn new(ctx: &Context, name: &CStr) -> Self { @@ -94,10 +92,15 @@ impl Module { } } - fn write_bitcode_to_memory(&self) -> MemoryBuffer { + pub fn write_bitcode_to_memory(&self) -> MemoryBuffer { let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; MemoryBuffer(memory_buffer) } + + pub fn print_module_to_string(&self) -> Message { + let asm = unsafe { LLVMPrintModuleToString(self.get()) }; + Message(unsafe { CStr::from_ptr(asm) }) + } } impl Drop for Module { @@ -132,7 +135,7 @@ impl Drop for Builder { } } -struct Message(&'static CStr); +pub struct Message(&'static CStr); impl Drop for Message { fn drop(&mut self) { @@ -148,23 +151,14 @@ impl std::fmt::Debug for Message { } } -pub struct MemoryBuffer(LLVMMemoryBufferRef); - -impl MemoryBuffer { - pub fn print_as_asm(&self) -> &str { - unsafe { - let context = Context::new(); - let mut module = MaybeUninit::uninit(); - LLVMParseBitcodeInContext2(context.0, self.0, module.as_mut_ptr()); - let module = module.assume_init(); - let asm = LLVMPrintModuleToString(module); - LLVMDisposeModule(module); - let asm = CStr::from_ptr(asm); - asm.to_str().unwrap().trim() - } +impl Message { + pub fn to_str(&self) -> &str { + self.0.to_str().unwrap().trim() } } +pub struct MemoryBuffer(LLVMMemoryBufferRef); + impl Drop for MemoryBuffer { fn drop(&mut self) { unsafe { @@ -186,7 +180,7 @@ impl Deref for MemoryBuffer { pub(super) fn run<'input>( id_defs: GlobalStringIdentResolver2<'input>, directives: Vec, SpirvWord>>, -) -> Result { +) -> Result<(Module, Context), TranslateError> { let context = Context::new(); let module = Module::new(&context, LLVM_UNNAMED); let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); @@ -199,7 +193,7 @@ pub(super) fn run<'input>( if let Err(err) = module.verify() { panic!("{:?}", err); } - Ok(module.write_bitcode_to_memory()) + Ok((module, context)) } struct ModuleEmitContext<'a, 'input> { diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index c32cc39..40df188 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -53,15 +53,17 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result, } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 165835f..a57f71e 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -4,12 +4,12 @@ use std::env; use std::error; use std::ffi::{CStr, CString}; use std::fmt::{self, Debug, Display, Formatter}; -use std::fs::{create_dir_all, File}; +use std::fs::{self, File}; use std::io::Write; -use std::panic::{catch_unwind, resume_unwind}; use std::mem; use std::path::Path; -use std::{ptr, str}; +use std::ptr; +use std::str; use pretty_assertions; macro_rules! test_ptx { @@ -247,21 +247,19 @@ fn test_llvm_assert< ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); let llvm_ir = pass::to_llvm_module(ast).unwrap(); - let actual_ll = llvm_ir.llvm_ir.print_as_asm(); - let result = catch_unwind(|| - pretty_assertions::assert_eq!(actual_ll, expected_ll)); - if let Err(cause) = result { - // Write actual generated LLVM IR to directory specified by environment variable - // TEST_PTX_LLVM_FAIL_DIR if test fails + let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); + let actual_ll = actual_ll.to_str(); + if actual_ll != expected_ll { let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR"); if let Ok(output_dir) = output_dir { let output_dir = Path::new(&output_dir); - create_dir_all(&output_dir).unwrap(); + fs::create_dir_all(&output_dir).unwrap(); let output_file = output_dir.join(format!("{}.ll", name)); let mut output_file = File::create(output_file).unwrap(); output_file.write_all(actual_ll.as_bytes()).unwrap(); } - resume_unwind(cause); + let comparison = pretty_assertions::StrComparison::new(actual_ll, expected_ll); + panic!("assertion failed: `(left == right)`\n\n{}", comparison); } Ok(()) } @@ -347,7 +345,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, - &*module.llvm_ir, + &*module.llvm_ir.write_bitcode_to_memory(), module.linked_bitcode(), ) .unwrap(); diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index b469a89..a881e16 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -30,7 +30,7 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) - unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?; let elf_module = comgr::compile_bitcode( unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) }, - &*llvm_module.llvm_ir, + &*llvm_module.llvm_ir.write_bitcode_to_memory(), llvm_module.linked_bitcode(), ) .map_err(|_| CUerror::UNKNOWN)?;