diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index ac407ef..e1a872b 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -1,7 +1,9 @@ use amd_comgr_sys::*; -use std::{ffi::CStr, mem, ptr}; +use std::ffi::CStr; +use std::mem; +use std::ptr; -struct Data(amd_comgr_data_t); +pub struct Data(amd_comgr_data_t); impl Data { fn new( @@ -20,7 +22,7 @@ impl Data { self.0 } - fn copy_content(&self) -> Result, amd_comgr_status_s> { + pub fn copy_content(&self) -> Result, amd_comgr_status_s> { let mut size = unsafe { mem::zeroed() }; unsafe { amd_comgr_get_data(self.get(), &mut size, ptr::null_mut()) }?; let mut result: Vec = Vec::with_capacity(size); @@ -30,7 +32,7 @@ impl Data { } } -struct DataSet(amd_comgr_data_set_t); +pub struct DataSet(amd_comgr_data_set_t); impl DataSet { fn new() -> Result { @@ -47,7 +49,7 @@ impl DataSet { self.0 } - fn get_data( + pub fn get_data( &self, kind: amd_comgr_data_kind_t, index: usize, @@ -108,11 +110,10 @@ impl Drop for ActionInfo { } } -pub fn compile_bitcode( - gcn_arch: &CStr, +pub fn link_bitcode( main_buffer: &[u8], ptx_impl: &[u8], -) -> Result, amd_comgr_status_s> { +) -> Result { use amd_comgr_sys::*; let bitcode_data_set = DataSet::new()?; let main_bitcode_data = Data::new( @@ -128,11 +129,21 @@ pub fn compile_bitcode( )?; bitcode_data_set.add(&stdlib_bitcode_data)?; let linking_info = ActionInfo::new()?; - let linked_data_set = do_action( + do_action( &bitcode_data_set, &linking_info, amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_BC_TO_BC, - )?; + ) +} + +pub fn compile_bitcode( + gcn_arch: &CStr, + main_buffer: &[u8], + ptx_impl: &[u8], +) -> Result, amd_comgr_status_s> { + use amd_comgr_sys::*; + + let linked_data_set = link_bitcode(main_buffer, ptx_impl)?; let compile_to_exec = ActionInfo::new()?; compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?; diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index da972f6..aea554e 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -2,5 +2,5 @@ pub(crate) mod pass; #[cfg(test)] mod test; -pub use pass::to_llvm_module; - +pub use pass::{TranslateError, to_llvm_module}; +pub use pass::emit_llvm::bitcode_to_ir; \ No newline at end of file diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 072903b..2b6b68b 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -27,11 +27,13 @@ use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; +use std::mem::MaybeUninit; use std::ops::Deref; -use std::{i8, ptr}; +use std::ptr; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; +use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; @@ -118,6 +120,24 @@ impl Drop for Module { } } +impl From for Module { + fn from(memory_buffer: MemoryBuffer) -> Self { + let context = Context::new(); + let mut module: MaybeUninit = MaybeUninit::uninit(); + unsafe { + LLVMParseBitcodeInContext2(context.get(), memory_buffer.get(), module.as_mut_ptr()); + } + let module = unsafe { module.assume_init() }; + Self(module, context) + } +} + +pub fn bitcode_to_ir(bitcode: Vec) -> Vec { + let memory_buffer: MemoryBuffer = bitcode.into(); + let module: Module = memory_buffer.into(); + module.print_module_to_string().to_bytes().to_vec() +} + struct Builder(LLVMBuilderRef); impl Builder { @@ -170,6 +190,12 @@ impl Message { pub struct MemoryBuffer(LLVMMemoryBufferRef); +impl MemoryBuffer { + fn get(&self) -> LLVMMemoryBufferRef { + self.0 + } +} + impl Drop for MemoryBuffer { fn drop(&mut self) { unsafe { @@ -188,6 +214,26 @@ impl Deref for MemoryBuffer { } } +impl From> for MemoryBuffer { + fn from(value: Vec) -> Self { + let memory_buffer: LLVMMemoryBufferRef = unsafe { + LLVMCreateMemoryBufferWithMemoryRangeCopy( + value.as_ptr(), + value.len(), + ptr::null() + ) + }; + Self(memory_buffer) + } +} + +impl From> for MemoryBuffer { + fn from(value: Vec) -> Self { + let value: Vec = value.iter().map(|&v| i8::from_ne_bytes([v])).collect(); + value.into() + } +} + pub(super) fn run<'input>( id_defs: GlobalStringIdentResolver2<'input>, directives: Vec, SpirvWord>>, diff --git a/zoc/src/error.rs b/zoc/src/error.rs new file mode 100644 index 0000000..9c268c3 --- /dev/null +++ b/zoc/src/error.rs @@ -0,0 +1,66 @@ +use std::io; +use std::path::PathBuf; +use std::str::Utf8Error; + +use amd_comgr_sys::amd_comgr_status_s; +use hip_runtime_sys::hipErrorCode_t; +use ptx::TranslateError; +use ptx_parser::PtxError; + +#[derive(Debug, thiserror::Error)] +pub enum CompilerError { + #[error("HIP error: {0:?}")] + HipError(hipErrorCode_t), + #[error("amd_comgr error: {0:?}")] + ComgrError(amd_comgr_status_s), + #[error("Not a regular file: {0}")] + CheckPathError(PathBuf), + #[error("Invalid output type: {0}")] + ParseOutputTypeError(String), + #[error("Error parsing PTX: {0}")] + PtxParserError(String), + #[error("Error translating PTX: {0:?}")] + PtxTranslateError(TranslateError), + #[error("IO error: {0:?}")] + IoError(io::Error), + #[error("Error parsing file: {0:?}")] + ParseFileError(Utf8Error), +} + +impl From for CompilerError { + fn from(error_code: hipErrorCode_t) -> Self { + CompilerError::HipError(error_code) + } +} + +impl From for CompilerError { + fn from(error_code: amd_comgr_status_s) -> Self { + CompilerError::ComgrError(error_code) + } +} + +impl From>> for CompilerError { + fn from(causes: Vec) -> Self { + let errors: Vec = causes.iter().map(PtxError::to_string).collect(); + let msg = errors.join("\n"); + CompilerError::PtxParserError(msg) + } +} + +impl From for CompilerError { + fn from(cause: io::Error) -> Self { + CompilerError::IoError(cause) + } +} + +impl From for CompilerError { + fn from(cause: Utf8Error) -> Self { + CompilerError::ParseFileError(cause) + } +} + +impl From for CompilerError { + fn from(cause: TranslateError) -> Self { + CompilerError::PtxTranslateError(cause) + } +} \ No newline at end of file diff --git a/zoc/src/main.rs b/zoc/src/main.rs index 0f7b91e..b34bad1 100644 --- a/zoc/src/main.rs +++ b/zoc/src/main.rs @@ -1,5 +1,4 @@ use std::env; -use std::error::Error; use std::ffi::{CStr, OsStr}; use std::fs::{self, File}; use std::io::{self, Write}; @@ -7,10 +6,11 @@ use std::mem::MaybeUninit; use std::path::{Path, PathBuf}; use std::str::{self, FromStr}; -use amd_comgr_sys::amd_comgr_status_s; +use amd_comgr_sys::amd_comgr_data_kind_s; use bpaf::Bpaf; -use hip_runtime_sys::hipErrorCode_t; -use ptx_parser::PtxError; + +mod error; +use error::CompilerError; #[derive(Debug, Clone, Bpaf)] #[bpaf(options, version)] @@ -23,13 +23,13 @@ pub struct Options { ptx_path: String, } -fn main() -> Result<(), Box> { +fn main() -> Result<(), CompilerError> { let opts = options().run(); let output_type = opts.output_type.unwrap_or_default(); match output_type { - OutputType::LlvmIrLinked | OutputType::Assembly => todo!(), + OutputType::Assembly => todo!(), _ => {} } @@ -39,24 +39,30 @@ fn main() -> Result<(), Box> { let output_path = get_output_path(&ptx_path, &output_type)?; check_path(&output_path)?; - let ptx = fs::read(&ptx_path)?; - let ptx = str::from_utf8(&ptx)?; - let llvm = ptx_to_llvm(ptx)?; + let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; + let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?; + let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?; if output_type == OutputType::LlvmIrPreLinked { - write_to_file(&llvm.llvm_ir, &output_path)?; + write_to_file(&llvm.llvm_ir, &output_path).map_err(CompilerError::from)?; + return Ok(()); + } + + if output_type == OutputType::LlvmIrLinked { + let linked_llvm = link_llvm(&llvm)?; + write_to_file(&linked_llvm, &output_path).map_err(CompilerError::from)?; return Ok(()); } let elf = llvm_to_elf(&llvm)?; - write_to_file(&elf, &output_path)?; + write_to_file(&elf, &output_path).map_err(CompilerError::from)?; Ok(()) } -fn ptx_to_llvm(ptx: &str) -> Result> { - let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?; - let module = ptx::to_llvm_module(ast)?; +fn ptx_to_llvm(ptx: &str) -> Result { + let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from).map_err(CompilerError::from)?; + let module = ptx::to_llvm_module(ast).map_err(CompilerError::from)?; let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let linked_bitcode = module.linked_bitcode().to_vec(); let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec(); @@ -74,12 +80,14 @@ struct LLVMArtifacts { llvm_ir: Vec, } -fn join_ptx_errors(vector: Vec) -> String { - let errors: Vec = vector.iter().map(PtxError::to_string).collect(); - errors.join("\n") +fn link_llvm(llvm: &LLVMArtifacts) -> Result, CompilerError> { + let linked_bitcode = comgr::link_bitcode(&llvm.bitcode, &llvm.linked_bitcode)?; + let data = linked_bitcode.get_data(amd_comgr_data_kind_s::AMD_COMGR_DATA_KIND_BC, 0)?; + let linked_llvm = data.copy_content().map_err(CompilerError::from)?; + Ok(ptx::bitcode_to_ir(linked_llvm)) } -fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result, ElfError> { +fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result, CompilerError> { use hip_runtime_sys::*; unsafe { hipInit(0) }?; let mut dev_props: MaybeUninit = MaybeUninit::uninit(); @@ -87,13 +95,12 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result, ElfError> { let dev_props = unsafe { dev_props.assume_init() }; let gcn_arch = unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }; - comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from) + comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(CompilerError::from) } -fn check_path(path: &Path) -> Result<(), Box> { - if path.try_exists()? && !path.is_file() { - let error = CheckPathError(path.to_path_buf()); - let error = Box::new(error); +fn check_path(path: &Path) -> Result<(), CompilerError> { + if path.try_exists().map_err(CompilerError::from)? && !path.is_file() { + let error = CompilerError::CheckPathError(path.to_path_buf()); return Err(error); } Ok(()) @@ -102,8 +109,8 @@ fn check_path(path: &Path) -> Result<(), Box> { fn get_output_path( ptx_path: &PathBuf, output_type: &OutputType, -) -> Result> { - let current_dir = env::current_dir()?; +) -> Result { + let current_dir = env::current_dir().map_err(CompilerError::from)?; let output_path = current_dir.join( ptx_path .as_path() @@ -150,7 +157,7 @@ impl OutputType { } impl FromStr for OutputType { - type Err = ParseOutputTypeError; + type Err = CompilerError; fn from_str(s: &str) -> Result { match s { @@ -158,35 +165,7 @@ impl FromStr for OutputType { "ll_linked" => Ok(Self::LlvmIrLinked), "elf" => Ok(Self::Elf), "asm" => Ok(Self::Assembly), - _ => Err(ParseOutputTypeError(s.into())), + _ => Err(CompilerError::ParseOutputTypeError(s.into())), } } } - -#[derive(Debug, thiserror::Error)] -#[error("Not a regular file: {0}")] -struct CheckPathError(PathBuf); - -#[derive(Debug, thiserror::Error)] -#[error("Invalid output type: {0}")] -struct ParseOutputTypeError(String); - -#[derive(Debug, thiserror::Error)] -enum ElfError { - #[error("HIP error: {0:?}")] - HipError(hipErrorCode_t), - #[error("amd_comgr error: {0:?}")] - AmdComgrError(amd_comgr_status_s), -} - -impl From for ElfError { - fn from(value: hipErrorCode_t) -> Self { - ElfError::HipError(value) - } -} - -impl From for ElfError { - fn from(value: amd_comgr_status_s) -> Self { - ElfError::AmdComgrError(value) - } -}