zoc: Linked LLVM IR

This commit is contained in:
Joëlle van Essen 2025-03-26 19:53:26 +01:00
parent 87b30a6604
commit 06328891c4
No known key found for this signature in database
GPG key ID: 28D3B5CDD4B43882
5 changed files with 170 additions and 68 deletions

View file

@ -1,7 +1,9 @@
use amd_comgr_sys::*;
use std::{ffi::CStr, mem, ptr};
use std::ffi::CStr;
use std::mem;
use std::ptr;
struct Data(amd_comgr_data_t);
pub struct Data(amd_comgr_data_t);
impl Data {
fn new(
@ -20,7 +22,7 @@ impl Data {
self.0
}
fn copy_content(&self) -> Result<Vec<u8>, amd_comgr_status_s> {
pub fn copy_content(&self) -> Result<Vec<u8>, amd_comgr_status_s> {
let mut size = unsafe { mem::zeroed() };
unsafe { amd_comgr_get_data(self.get(), &mut size, ptr::null_mut()) }?;
let mut result: Vec<u8> = Vec::with_capacity(size);
@ -30,7 +32,7 @@ impl Data {
}
}
struct DataSet(amd_comgr_data_set_t);
pub struct DataSet(amd_comgr_data_set_t);
impl DataSet {
fn new() -> Result<Self, amd_comgr_status_s> {
@ -47,7 +49,7 @@ impl DataSet {
self.0
}
fn get_data(
pub fn get_data(
&self,
kind: amd_comgr_data_kind_t,
index: usize,
@ -108,11 +110,10 @@ impl Drop for ActionInfo {
}
}
pub fn compile_bitcode(
gcn_arch: &CStr,
pub fn link_bitcode(
main_buffer: &[u8],
ptx_impl: &[u8],
) -> Result<Vec<u8>, amd_comgr_status_s> {
) -> Result<DataSet, amd_comgr_status_s> {
use amd_comgr_sys::*;
let bitcode_data_set = DataSet::new()?;
let main_bitcode_data = Data::new(
@ -128,11 +129,21 @@ pub fn compile_bitcode(
)?;
bitcode_data_set.add(&stdlib_bitcode_data)?;
let linking_info = ActionInfo::new()?;
let linked_data_set = do_action(
do_action(
&bitcode_data_set,
&linking_info,
amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_BC_TO_BC,
)?;
)
}
pub fn compile_bitcode(
gcn_arch: &CStr,
main_buffer: &[u8],
ptx_impl: &[u8],
) -> Result<Vec<u8>, amd_comgr_status_s> {
use amd_comgr_sys::*;
let linked_data_set = link_bitcode(main_buffer, ptx_impl)?;
let compile_to_exec = ActionInfo::new()?;
compile_to_exec.set_isa_name(gcn_arch)?;
compile_to_exec.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?;

View file

@ -2,5 +2,5 @@ pub(crate) mod pass;
#[cfg(test)]
mod test;
pub use pass::to_llvm_module;
pub use pass::{TranslateError, to_llvm_module};
pub use pass::emit_llvm::bitcode_to_ir;

View file

@ -27,11 +27,13 @@
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::{i8, ptr};
use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2;
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
use llvm_zluda::{core::*, *};
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
@ -118,6 +120,24 @@ impl Drop for Module {
}
}
impl From<MemoryBuffer> for Module {
fn from(memory_buffer: MemoryBuffer) -> Self {
let context = Context::new();
let mut module: MaybeUninit<LLVMModuleRef> = MaybeUninit::uninit();
unsafe {
LLVMParseBitcodeInContext2(context.get(), memory_buffer.get(), module.as_mut_ptr());
}
let module = unsafe { module.assume_init() };
Self(module, context)
}
}
pub fn bitcode_to_ir(bitcode: Vec<u8>) -> Vec<u8> {
let memory_buffer: MemoryBuffer = bitcode.into();
let module: Module = memory_buffer.into();
module.print_module_to_string().to_bytes().to_vec()
}
struct Builder(LLVMBuilderRef);
impl Builder {
@ -170,6 +190,12 @@ impl Message {
pub struct MemoryBuffer(LLVMMemoryBufferRef);
impl MemoryBuffer {
fn get(&self) -> LLVMMemoryBufferRef {
self.0
}
}
impl Drop for MemoryBuffer {
fn drop(&mut self) {
unsafe {
@ -188,6 +214,26 @@ impl Deref for MemoryBuffer {
}
}
impl From<Vec<i8>> for MemoryBuffer {
fn from(value: Vec<i8>) -> Self {
let memory_buffer: LLVMMemoryBufferRef = unsafe {
LLVMCreateMemoryBufferWithMemoryRangeCopy(
value.as_ptr(),
value.len(),
ptr::null()
)
};
Self(memory_buffer)
}
}
impl From<Vec<u8>> for MemoryBuffer {
fn from(value: Vec<u8>) -> Self {
let value: Vec<i8> = value.iter().map(|&v| i8::from_ne_bytes([v])).collect();
value.into()
}
}
pub(super) fn run<'input>(
id_defs: GlobalStringIdentResolver2<'input>,
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,

66
zoc/src/error.rs Normal file
View file

@ -0,0 +1,66 @@
use std::io;
use std::path::PathBuf;
use std::str::Utf8Error;
use amd_comgr_sys::amd_comgr_status_s;
use hip_runtime_sys::hipErrorCode_t;
use ptx::TranslateError;
use ptx_parser::PtxError;
#[derive(Debug, thiserror::Error)]
pub enum CompilerError {
#[error("HIP error: {0:?}")]
HipError(hipErrorCode_t),
#[error("amd_comgr error: {0:?}")]
ComgrError(amd_comgr_status_s),
#[error("Not a regular file: {0}")]
CheckPathError(PathBuf),
#[error("Invalid output type: {0}")]
ParseOutputTypeError(String),
#[error("Error parsing PTX: {0}")]
PtxParserError(String),
#[error("Error translating PTX: {0:?}")]
PtxTranslateError(TranslateError),
#[error("IO error: {0:?}")]
IoError(io::Error),
#[error("Error parsing file: {0:?}")]
ParseFileError(Utf8Error),
}
impl From<hipErrorCode_t> for CompilerError {
fn from(error_code: hipErrorCode_t) -> Self {
CompilerError::HipError(error_code)
}
}
impl From<amd_comgr_status_s> for CompilerError {
fn from(error_code: amd_comgr_status_s) -> Self {
CompilerError::ComgrError(error_code)
}
}
impl From<Vec<PtxError<'_>>> for CompilerError {
fn from(causes: Vec<PtxError>) -> Self {
let errors: Vec<String> = causes.iter().map(PtxError::to_string).collect();
let msg = errors.join("\n");
CompilerError::PtxParserError(msg)
}
}
impl From<io::Error> for CompilerError {
fn from(cause: io::Error) -> Self {
CompilerError::IoError(cause)
}
}
impl From<Utf8Error> for CompilerError {
fn from(cause: Utf8Error) -> Self {
CompilerError::ParseFileError(cause)
}
}
impl From<TranslateError> for CompilerError {
fn from(cause: TranslateError) -> Self {
CompilerError::PtxTranslateError(cause)
}
}

View file

@ -1,5 +1,4 @@
use std::env;
use std::error::Error;
use std::ffi::{CStr, OsStr};
use std::fs::{self, File};
use std::io::{self, Write};
@ -7,10 +6,11 @@ use std::mem::MaybeUninit;
use std::path::{Path, PathBuf};
use std::str::{self, FromStr};
use amd_comgr_sys::amd_comgr_status_s;
use amd_comgr_sys::amd_comgr_data_kind_s;
use bpaf::Bpaf;
use hip_runtime_sys::hipErrorCode_t;
use ptx_parser::PtxError;
mod error;
use error::CompilerError;
#[derive(Debug, Clone, Bpaf)]
#[bpaf(options, version)]
@ -23,13 +23,13 @@ pub struct Options {
ptx_path: String,
}
fn main() -> Result<(), Box<dyn Error>> {
fn main() -> Result<(), CompilerError> {
let opts = options().run();
let output_type = opts.output_type.unwrap_or_default();
match output_type {
OutputType::LlvmIrLinked | OutputType::Assembly => todo!(),
OutputType::Assembly => todo!(),
_ => {}
}
@ -39,24 +39,30 @@ fn main() -> Result<(), Box<dyn Error>> {
let output_path = get_output_path(&ptx_path, &output_type)?;
check_path(&output_path)?;
let ptx = fs::read(&ptx_path)?;
let ptx = str::from_utf8(&ptx)?;
let llvm = ptx_to_llvm(ptx)?;
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?;
if output_type == OutputType::LlvmIrPreLinked {
write_to_file(&llvm.llvm_ir, &output_path)?;
write_to_file(&llvm.llvm_ir, &output_path).map_err(CompilerError::from)?;
return Ok(());
}
if output_type == OutputType::LlvmIrLinked {
let linked_llvm = link_llvm(&llvm)?;
write_to_file(&linked_llvm, &output_path).map_err(CompilerError::from)?;
return Ok(());
}
let elf = llvm_to_elf(&llvm)?;
write_to_file(&elf, &output_path)?;
write_to_file(&elf, &output_path).map_err(CompilerError::from)?;
Ok(())
}
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, Box<dyn Error>> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?;
let module = ptx::to_llvm_module(ast)?;
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from).map_err(CompilerError::from)?;
let module = ptx::to_llvm_module(ast).map_err(CompilerError::from)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
let linked_bitcode = module.linked_bitcode().to_vec();
let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec();
@ -74,12 +80,14 @@ struct LLVMArtifacts {
llvm_ir: Vec<u8>,
}
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
errors.join("\n")
fn link_llvm(llvm: &LLVMArtifacts) -> Result<Vec<u8>, CompilerError> {
let linked_bitcode = comgr::link_bitcode(&llvm.bitcode, &llvm.linked_bitcode)?;
let data = linked_bitcode.get_data(amd_comgr_data_kind_s::AMD_COMGR_DATA_KIND_BC, 0)?;
let linked_llvm = data.copy_content().map_err(CompilerError::from)?;
Ok(ptx::bitcode_to_ir(linked_llvm))
}
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, CompilerError> {
use hip_runtime_sys::*;
unsafe { hipInit(0) }?;
let mut dev_props: MaybeUninit<hipDeviceProp_tR0600> = MaybeUninit::uninit();
@ -87,13 +95,12 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
let dev_props = unsafe { dev_props.assume_init() };
let gcn_arch = unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) };
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from)
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(CompilerError::from)
}
fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
if path.try_exists()? && !path.is_file() {
let error = CheckPathError(path.to_path_buf());
let error = Box::new(error);
fn check_path(path: &Path) -> Result<(), CompilerError> {
if path.try_exists().map_err(CompilerError::from)? && !path.is_file() {
let error = CompilerError::CheckPathError(path.to_path_buf());
return Err(error);
}
Ok(())
@ -102,8 +109,8 @@ fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
fn get_output_path(
ptx_path: &PathBuf,
output_type: &OutputType,
) -> Result<PathBuf, Box<dyn Error>> {
let current_dir = env::current_dir()?;
) -> Result<PathBuf, CompilerError> {
let current_dir = env::current_dir().map_err(CompilerError::from)?;
let output_path = current_dir.join(
ptx_path
.as_path()
@ -150,7 +157,7 @@ impl OutputType {
}
impl FromStr for OutputType {
type Err = ParseOutputTypeError;
type Err = CompilerError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
@ -158,35 +165,7 @@ impl FromStr for OutputType {
"ll_linked" => Ok(Self::LlvmIrLinked),
"elf" => Ok(Self::Elf),
"asm" => Ok(Self::Assembly),
_ => Err(ParseOutputTypeError(s.into())),
_ => Err(CompilerError::ParseOutputTypeError(s.into())),
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Not a regular file: {0}")]
struct CheckPathError(PathBuf);
#[derive(Debug, thiserror::Error)]
#[error("Invalid output type: {0}")]
struct ParseOutputTypeError(String);
#[derive(Debug, thiserror::Error)]
enum ElfError {
#[error("HIP error: {0:?}")]
HipError(hipErrorCode_t),
#[error("amd_comgr error: {0:?}")]
AmdComgrError(amd_comgr_status_s),
}
impl From<hipErrorCode_t> for ElfError {
fn from(value: hipErrorCode_t) -> Self {
ElfError::HipError(value)
}
}
impl From<amd_comgr_status_s> for ElfError {
fn from(value: amd_comgr_status_s) -> Self {
ElfError::AmdComgrError(value)
}
}