LLVM unit tests: Adjustments after review

This commit is contained in:
Joëlle van Essen 2025-02-17 18:27:04 +01:00
commit 32d421a282
No known key found for this signature in database
GPG key ID: 28D3B5CDD4B43882
4 changed files with 31 additions and 37 deletions

View file

@ -27,13 +27,11 @@
use std::array::TryFromSliceError; use std::array::TryFromSliceError;
use std::convert::TryInto; use std::convert::TryInto;
use std::ffi::{CStr, NulError}; use std::ffi::{CStr, NulError};
use std::mem::MaybeUninit;
use std::ops::Deref; use std::ops::Deref;
use std::{i8, ptr}; use std::{i8, ptr};
use super::*; use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2;
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, *}; use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
@ -47,7 +45,7 @@ const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4; const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5; const PRIVATE_ADDRESS_SPACE: u32 = 5;
struct Context(LLVMContextRef); pub struct Context(LLVMContextRef);
impl Context { impl Context {
fn new() -> Self { fn new() -> Self {
@ -67,7 +65,7 @@ impl Drop for Context {
} }
} }
struct Module(LLVMModuleRef); pub struct Module(LLVMModuleRef);
impl Module { impl Module {
fn new(ctx: &Context, name: &CStr) -> Self { fn new(ctx: &Context, name: &CStr) -> Self {
@ -94,10 +92,15 @@ impl Module {
} }
} }
fn write_bitcode_to_memory(&self) -> MemoryBuffer { pub fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer) MemoryBuffer(memory_buffer)
} }
pub fn print_module_to_string(&self) -> Message {
let asm = unsafe { LLVMPrintModuleToString(self.get()) };
Message(unsafe { CStr::from_ptr(asm) })
}
} }
impl Drop for Module { impl Drop for Module {
@ -132,7 +135,7 @@ impl Drop for Builder {
} }
} }
struct Message(&'static CStr); pub struct Message(&'static CStr);
impl Drop for Message { impl Drop for Message {
fn drop(&mut self) { fn drop(&mut self) {
@ -148,23 +151,14 @@ impl std::fmt::Debug for Message {
} }
} }
pub struct MemoryBuffer(LLVMMemoryBufferRef); impl Message {
pub fn to_str(&self) -> &str {
impl MemoryBuffer { self.0.to_str().unwrap().trim()
pub fn print_as_asm(&self) -> &str {
unsafe {
let context = Context::new();
let mut module = MaybeUninit::uninit();
LLVMParseBitcodeInContext2(context.0, self.0, module.as_mut_ptr());
let module = module.assume_init();
let asm = LLVMPrintModuleToString(module);
LLVMDisposeModule(module);
let asm = CStr::from_ptr(asm);
asm.to_str().unwrap().trim()
}
} }
} }
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl Drop for MemoryBuffer { impl Drop for MemoryBuffer {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
@ -186,7 +180,7 @@ impl Deref for MemoryBuffer {
pub(super) fn run<'input>( pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>, id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> { ) -> Result<(Module, Context), TranslateError> {
let context = Context::new(); let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED); let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
@ -199,7 +193,7 @@ pub(super) fn run<'input>(
if let Err(err) = module.verify() { if let Err(err) = module.verify() {
panic!("{:?}", err); panic!("{:?}", err);
} }
Ok(module.write_bitcode_to_memory()) Ok((module, context))
} }
struct ModuleEmitContext<'a, 'input> { struct ModuleEmitContext<'a, 'input> {

View file

@ -53,15 +53,17 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?; let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?; let (llvm_ir, llvm_context) = emit_llvm::run(flat_resolver, directives)?;
Ok(Module { Ok(Module {
llvm_ir, llvm_ir,
_llvm_context: llvm_context,
kernel_info: HashMap::new(), kernel_info: HashMap::new(),
}) })
} }
pub struct Module { pub struct Module {
pub llvm_ir: emit_llvm::MemoryBuffer, pub llvm_ir: emit_llvm::Module,
_llvm_context: emit_llvm::Context,
pub kernel_info: HashMap<String, KernelInfo>, pub kernel_info: HashMap<String, KernelInfo>,
} }

View file

@ -4,12 +4,12 @@ use std::env;
use std::error; use std::error;
use std::ffi::{CStr, CString}; use std::ffi::{CStr, CString};
use std::fmt::{self, Debug, Display, Formatter}; use std::fmt::{self, Debug, Display, Formatter};
use std::fs::{create_dir_all, File}; use std::fs::{self, File};
use std::io::Write; use std::io::Write;
use std::panic::{catch_unwind, resume_unwind};
use std::mem; use std::mem;
use std::path::Path; use std::path::Path;
use std::{ptr, str}; use std::ptr;
use std::str;
use pretty_assertions; use pretty_assertions;
macro_rules! test_ptx { macro_rules! test_ptx {
@ -247,21 +247,19 @@ fn test_llvm_assert<
) -> Result<(), Box<dyn error::Error + 'a>> { ) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap(); let llvm_ir = pass::to_llvm_module(ast).unwrap();
let actual_ll = llvm_ir.llvm_ir.print_as_asm(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
let result = catch_unwind(|| let actual_ll = actual_ll.to_str();
pretty_assertions::assert_eq!(actual_ll, expected_ll)); if actual_ll != expected_ll {
if let Err(cause) = result {
// Write actual generated LLVM IR to directory specified by environment variable
// TEST_PTX_LLVM_FAIL_DIR if test fails
let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR"); let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR");
if let Ok(output_dir) = output_dir { if let Ok(output_dir) = output_dir {
let output_dir = Path::new(&output_dir); let output_dir = Path::new(&output_dir);
create_dir_all(&output_dir).unwrap(); fs::create_dir_all(&output_dir).unwrap();
let output_file = output_dir.join(format!("{}.ll", name)); let output_file = output_dir.join(format!("{}.ll", name));
let mut output_file = File::create(output_file).unwrap(); let mut output_file = File::create(output_file).unwrap();
output_file.write_all(actual_ll.as_bytes()).unwrap(); output_file.write_all(actual_ll.as_bytes()).unwrap();
} }
resume_unwind(cause); let comparison = pretty_assertions::StrComparison::new(actual_ll, expected_ll);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
} }
Ok(()) Ok(())
} }
@ -347,7 +345,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap(); unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }, unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
&*module.llvm_ir, &*module.llvm_ir.write_bitcode_to_memory(),
module.linked_bitcode(), module.linked_bitcode(),
) )
.unwrap(); .unwrap();

View file

@ -30,7 +30,7 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?; unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) }, unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) },
&*llvm_module.llvm_ir, &*llvm_module.llvm_ir.write_bitcode_to_memory(),
llvm_module.linked_bitcode(), llvm_module.linked_bitcode(),
) )
.map_err(|_| CUerror::UNKNOWN)?; .map_err(|_| CUerror::UNKNOWN)?;