LLVM unit tests: Adjustments after review

This commit is contained in:
Joëlle van Essen 2025-02-17 18:27:04 +01:00
parent fc8d82860f
commit 32d421a282
No known key found for this signature in database
GPG key ID: 28D3B5CDD4B43882
4 changed files with 31 additions and 37 deletions

View file

@ -27,13 +27,11 @@
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2;
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
@ -47,7 +45,7 @@ const SHARED_ADDRESS_SPACE: u32 = 3;
const CONSTANT_ADDRESS_SPACE: u32 = 4;
const PRIVATE_ADDRESS_SPACE: u32 = 5;
struct Context(LLVMContextRef);
pub struct Context(LLVMContextRef);
impl Context {
fn new() -> Self {
@ -67,7 +65,7 @@ impl Drop for Context {
}
}
struct Module(LLVMModuleRef);
pub struct Module(LLVMModuleRef);
impl Module {
fn new(ctx: &Context, name: &CStr) -> Self {
@ -94,10 +92,15 @@ impl Module {
}
}
fn write_bitcode_to_memory(&self) -> MemoryBuffer {
pub fn write_bitcode_to_memory(&self) -> MemoryBuffer {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
pub fn print_module_to_string(&self) -> Message {
let asm = unsafe { LLVMPrintModuleToString(self.get()) };
Message(unsafe { CStr::from_ptr(asm) })
}
}
impl Drop for Module {
@ -132,7 +135,7 @@ impl Drop for Builder {
}
}
struct Message(&'static CStr);
pub struct Message(&'static CStr);
impl Drop for Message {
fn drop(&mut self) {
@ -148,23 +151,14 @@ impl std::fmt::Debug for Message {
}
}
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl MemoryBuffer {
pub fn print_as_asm(&self) -> &str {
unsafe {
let context = Context::new();
let mut module = MaybeUninit::uninit();
LLVMParseBitcodeInContext2(context.0, self.0, module.as_mut_ptr());
let module = module.assume_init();
let asm = LLVMPrintModuleToString(module);
LLVMDisposeModule(module);
let asm = CStr::from_ptr(asm);
asm.to_str().unwrap().trim()
}
impl Message {
pub fn to_str(&self) -> &str {
self.0.to_str().unwrap().trim()
}
}
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl Drop for MemoryBuffer {
fn drop(&mut self) {
unsafe {
@ -186,7 +180,7 @@ impl Deref for MemoryBuffer {
pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
) -> Result<(Module, Context), TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
@ -199,7 +193,7 @@ pub(super) fn run<'input>(
if let Err(err) = module.verify() {
panic!("{:?}", err);
}
Ok(module.write_bitcode_to_memory())
Ok((module, context))
}
struct ModuleEmitContext<'a, 'input> {

View file

@ -53,15 +53,17 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
let (llvm_ir, llvm_context) = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {
llvm_ir,
_llvm_context: llvm_context,
kernel_info: HashMap::new(),
})
}
pub struct Module {
pub llvm_ir: emit_llvm::MemoryBuffer,
pub llvm_ir: emit_llvm::Module,
_llvm_context: emit_llvm::Context,
pub kernel_info: HashMap<String, KernelInfo>,
}

View file

@ -4,12 +4,12 @@ use std::env;
use std::error;
use std::ffi::{CStr, CString};
use std::fmt::{self, Debug, Display, Formatter};
use std::fs::{create_dir_all, File};
use std::fs::{self, File};
use std::io::Write;
use std::panic::{catch_unwind, resume_unwind};
use std::mem;
use std::path::Path;
use std::{ptr, str};
use std::ptr;
use std::str;
use pretty_assertions;
macro_rules! test_ptx {
@ -247,21 +247,19 @@ fn test_llvm_assert<
) -> Result<(), Box<dyn error::Error + 'a>> {
let ast = ptx_parser::parse_module_checked(ptx_text).unwrap();
let llvm_ir = pass::to_llvm_module(ast).unwrap();
let actual_ll = llvm_ir.llvm_ir.print_as_asm();
let result = catch_unwind(||
pretty_assertions::assert_eq!(actual_ll, expected_ll));
if let Err(cause) = result {
// Write actual generated LLVM IR to directory specified by environment variable
// TEST_PTX_LLVM_FAIL_DIR if test fails
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();
let actual_ll = actual_ll.to_str();
if actual_ll != expected_ll {
let output_dir = env::var("TEST_PTX_LLVM_FAIL_DIR");
if let Ok(output_dir) = output_dir {
let output_dir = Path::new(&output_dir);
create_dir_all(&output_dir).unwrap();
fs::create_dir_all(&output_dir).unwrap();
let output_file = output_dir.join(format!("{}.ll", name));
let mut output_file = File::create(output_file).unwrap();
output_file.write_all(actual_ll.as_bytes()).unwrap();
}
resume_unwind(cause);
let comparison = pretty_assertions::StrComparison::new(actual_ll, expected_ll);
panic!("assertion failed: `(left == right)`\n\n{}", comparison);
}
Ok(())
}
@ -347,7 +345,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
let elf_module = comgr::compile_bitcode(
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
&*module.llvm_ir,
&*module.llvm_ir.write_bitcode_to_memory(),
module.linked_bitcode(),
)
.unwrap();

View file

@ -30,7 +30,7 @@ pub(crate) fn load_data(module: &mut CUmodule, image: *const std::ffi::c_void) -
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
let elf_module = comgr::compile_bitcode(
unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) },
&*llvm_module.llvm_ir,
&*llvm_module.llvm_ir.write_bitcode_to_memory(),
llvm_module.linked_bitcode(),
)
.map_err(|_| CUerror::UNKNOWN)?;