mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
zoc: Linked LLVM IR
This commit is contained in:
parent
87b30a6604
commit
06328891c4
5 changed files with 170 additions and 68 deletions
|
@ -1,7 +1,9 @@
|
|||
use amd_comgr_sys::*;
|
||||
use std::{ffi::CStr, mem, ptr};
|
||||
use std::ffi::CStr;
|
||||
use std::mem;
|
||||
use std::ptr;
|
||||
|
||||
struct Data(amd_comgr_data_t);
|
||||
pub struct Data(amd_comgr_data_t);
|
||||
|
||||
impl Data {
|
||||
fn new(
|
||||
|
@ -20,7 +22,7 @@ impl Data {
|
|||
self.0
|
||||
}
|
||||
|
||||
fn copy_content(&self) -> Result<Vec<u8>, amd_comgr_status_s> {
|
||||
pub fn copy_content(&self) -> Result<Vec<u8>, amd_comgr_status_s> {
|
||||
let mut size = unsafe { mem::zeroed() };
|
||||
unsafe { amd_comgr_get_data(self.get(), &mut size, ptr::null_mut()) }?;
|
||||
let mut result: Vec<u8> = Vec::with_capacity(size);
|
||||
|
@ -30,7 +32,7 @@ impl Data {
|
|||
}
|
||||
}
|
||||
|
||||
struct DataSet(amd_comgr_data_set_t);
|
||||
pub struct DataSet(amd_comgr_data_set_t);
|
||||
|
||||
impl DataSet {
|
||||
fn new() -> Result<Self, amd_comgr_status_s> {
|
||||
|
@ -47,7 +49,7 @@ impl DataSet {
|
|||
self.0
|
||||
}
|
||||
|
||||
fn get_data(
|
||||
pub fn get_data(
|
||||
&self,
|
||||
kind: amd_comgr_data_kind_t,
|
||||
index: usize,
|
||||
|
@ -108,11 +110,10 @@ impl Drop for ActionInfo {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn compile_bitcode(
|
||||
gcn_arch: &CStr,
|
||||
pub fn link_bitcode(
|
||||
main_buffer: &[u8],
|
||||
ptx_impl: &[u8],
|
||||
) -> Result<Vec<u8>, amd_comgr_status_s> {
|
||||
) -> Result<DataSet, amd_comgr_status_s> {
|
||||
use amd_comgr_sys::*;
|
||||
let bitcode_data_set = DataSet::new()?;
|
||||
let main_bitcode_data = Data::new(
|
||||
|
@ -128,11 +129,21 @@ pub fn compile_bitcode(
|
|||
)?;
|
||||
bitcode_data_set.add(&stdlib_bitcode_data)?;
|
||||
let linking_info = ActionInfo::new()?;
|
||||
let linked_data_set = do_action(
|
||||
do_action(
|
||||
&bitcode_data_set,
|
||||
&linking_info,
|
||||
amd_comgr_action_kind_t::AMD_COMGR_ACTION_LINK_BC_TO_BC,
|
||||
)?;
|
||||
)
|
||||
}
|
||||
|
||||
pub fn compile_bitcode(
|
||||
gcn_arch: &CStr,
|
||||
main_buffer: &[u8],
|
||||
ptx_impl: &[u8],
|
||||
) -> Result<Vec<u8>, amd_comgr_status_s> {
|
||||
use amd_comgr_sys::*;
|
||||
|
||||
let linked_data_set = link_bitcode(main_buffer, ptx_impl)?;
|
||||
let compile_to_exec = ActionInfo::new()?;
|
||||
compile_to_exec.set_isa_name(gcn_arch)?;
|
||||
compile_to_exec.set_language(amd_comgr_language_t::AMD_COMGR_LANGUAGE_LLVM_IR)?;
|
||||
|
|
|
@ -2,5 +2,5 @@ pub(crate) mod pass;
|
|||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub use pass::to_llvm_module;
|
||||
|
||||
pub use pass::{TranslateError, to_llvm_module};
|
||||
pub use pass::emit_llvm::bitcode_to_ir;
|
|
@ -27,11 +27,13 @@
|
|||
use std::array::TryFromSliceError;
|
||||
use std::convert::TryInto;
|
||||
use std::ffi::{CStr, NulError};
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ops::Deref;
|
||||
use std::{i8, ptr};
|
||||
use std::ptr;
|
||||
|
||||
use super::*;
|
||||
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
|
||||
use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2;
|
||||
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
|
||||
use llvm_zluda::{core::*, *};
|
||||
use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
|
||||
|
@ -118,6 +120,24 @@ impl Drop for Module {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<MemoryBuffer> for Module {
|
||||
fn from(memory_buffer: MemoryBuffer) -> Self {
|
||||
let context = Context::new();
|
||||
let mut module: MaybeUninit<LLVMModuleRef> = MaybeUninit::uninit();
|
||||
unsafe {
|
||||
LLVMParseBitcodeInContext2(context.get(), memory_buffer.get(), module.as_mut_ptr());
|
||||
}
|
||||
let module = unsafe { module.assume_init() };
|
||||
Self(module, context)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bitcode_to_ir(bitcode: Vec<u8>) -> Vec<u8> {
|
||||
let memory_buffer: MemoryBuffer = bitcode.into();
|
||||
let module: Module = memory_buffer.into();
|
||||
module.print_module_to_string().to_bytes().to_vec()
|
||||
}
|
||||
|
||||
struct Builder(LLVMBuilderRef);
|
||||
|
||||
impl Builder {
|
||||
|
@ -170,6 +190,12 @@ impl Message {
|
|||
|
||||
pub struct MemoryBuffer(LLVMMemoryBufferRef);
|
||||
|
||||
impl MemoryBuffer {
|
||||
fn get(&self) -> LLVMMemoryBufferRef {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MemoryBuffer {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
|
@ -188,6 +214,26 @@ impl Deref for MemoryBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Vec<i8>> for MemoryBuffer {
|
||||
fn from(value: Vec<i8>) -> Self {
|
||||
let memory_buffer: LLVMMemoryBufferRef = unsafe {
|
||||
LLVMCreateMemoryBufferWithMemoryRangeCopy(
|
||||
value.as_ptr(),
|
||||
value.len(),
|
||||
ptr::null()
|
||||
)
|
||||
};
|
||||
Self(memory_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u8>> for MemoryBuffer {
|
||||
fn from(value: Vec<u8>) -> Self {
|
||||
let value: Vec<i8> = value.iter().map(|&v| i8::from_ne_bytes([v])).collect();
|
||||
value.into()
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn run<'input>(
|
||||
id_defs: GlobalStringIdentResolver2<'input>,
|
||||
directives: Vec<Directive2<ast::Instruction<SpirvWord>, SpirvWord>>,
|
||||
|
|
66
zoc/src/error.rs
Normal file
66
zoc/src/error.rs
Normal file
|
@ -0,0 +1,66 @@
|
|||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::str::Utf8Error;
|
||||
|
||||
use amd_comgr_sys::amd_comgr_status_s;
|
||||
use hip_runtime_sys::hipErrorCode_t;
|
||||
use ptx::TranslateError;
|
||||
use ptx_parser::PtxError;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CompilerError {
|
||||
#[error("HIP error: {0:?}")]
|
||||
HipError(hipErrorCode_t),
|
||||
#[error("amd_comgr error: {0:?}")]
|
||||
ComgrError(amd_comgr_status_s),
|
||||
#[error("Not a regular file: {0}")]
|
||||
CheckPathError(PathBuf),
|
||||
#[error("Invalid output type: {0}")]
|
||||
ParseOutputTypeError(String),
|
||||
#[error("Error parsing PTX: {0}")]
|
||||
PtxParserError(String),
|
||||
#[error("Error translating PTX: {0:?}")]
|
||||
PtxTranslateError(TranslateError),
|
||||
#[error("IO error: {0:?}")]
|
||||
IoError(io::Error),
|
||||
#[error("Error parsing file: {0:?}")]
|
||||
ParseFileError(Utf8Error),
|
||||
}
|
||||
|
||||
impl From<hipErrorCode_t> for CompilerError {
|
||||
fn from(error_code: hipErrorCode_t) -> Self {
|
||||
CompilerError::HipError(error_code)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<amd_comgr_status_s> for CompilerError {
|
||||
fn from(error_code: amd_comgr_status_s) -> Self {
|
||||
CompilerError::ComgrError(error_code)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<PtxError<'_>>> for CompilerError {
|
||||
fn from(causes: Vec<PtxError>) -> Self {
|
||||
let errors: Vec<String> = causes.iter().map(PtxError::to_string).collect();
|
||||
let msg = errors.join("\n");
|
||||
CompilerError::PtxParserError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for CompilerError {
|
||||
fn from(cause: io::Error) -> Self {
|
||||
CompilerError::IoError(cause)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Utf8Error> for CompilerError {
|
||||
fn from(cause: Utf8Error) -> Self {
|
||||
CompilerError::ParseFileError(cause)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TranslateError> for CompilerError {
|
||||
fn from(cause: TranslateError) -> Self {
|
||||
CompilerError::PtxTranslateError(cause)
|
||||
}
|
||||
}
|
|
@ -1,5 +1,4 @@
|
|||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::ffi::{CStr, OsStr};
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Write};
|
||||
|
@ -7,10 +6,11 @@ use std::mem::MaybeUninit;
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::str::{self, FromStr};
|
||||
|
||||
use amd_comgr_sys::amd_comgr_status_s;
|
||||
use amd_comgr_sys::amd_comgr_data_kind_s;
|
||||
use bpaf::Bpaf;
|
||||
use hip_runtime_sys::hipErrorCode_t;
|
||||
use ptx_parser::PtxError;
|
||||
|
||||
mod error;
|
||||
use error::CompilerError;
|
||||
|
||||
#[derive(Debug, Clone, Bpaf)]
|
||||
#[bpaf(options, version)]
|
||||
|
@ -23,13 +23,13 @@ pub struct Options {
|
|||
ptx_path: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
fn main() -> Result<(), CompilerError> {
|
||||
let opts = options().run();
|
||||
|
||||
let output_type = opts.output_type.unwrap_or_default();
|
||||
|
||||
match output_type {
|
||||
OutputType::LlvmIrLinked | OutputType::Assembly => todo!(),
|
||||
OutputType::Assembly => todo!(),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
|
@ -39,24 +39,30 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||
let output_path = get_output_path(&ptx_path, &output_type)?;
|
||||
check_path(&output_path)?;
|
||||
|
||||
let ptx = fs::read(&ptx_path)?;
|
||||
let ptx = str::from_utf8(&ptx)?;
|
||||
let llvm = ptx_to_llvm(ptx)?;
|
||||
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
|
||||
let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?;
|
||||
let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?;
|
||||
|
||||
if output_type == OutputType::LlvmIrPreLinked {
|
||||
write_to_file(&llvm.llvm_ir, &output_path)?;
|
||||
write_to_file(&llvm.llvm_ir, &output_path).map_err(CompilerError::from)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if output_type == OutputType::LlvmIrLinked {
|
||||
let linked_llvm = link_llvm(&llvm)?;
|
||||
write_to_file(&linked_llvm, &output_path).map_err(CompilerError::from)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let elf = llvm_to_elf(&llvm)?;
|
||||
write_to_file(&elf, &output_path)?;
|
||||
write_to_file(&elf, &output_path).map_err(CompilerError::from)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, Box<dyn Error>> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?;
|
||||
let module = ptx::to_llvm_module(ast)?;
|
||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from).map_err(CompilerError::from)?;
|
||||
let module = ptx::to_llvm_module(ast).map_err(CompilerError::from)?;
|
||||
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
|
||||
let linked_bitcode = module.linked_bitcode().to_vec();
|
||||
let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec();
|
||||
|
@ -74,12 +80,14 @@ struct LLVMArtifacts {
|
|||
llvm_ir: Vec<u8>,
|
||||
}
|
||||
|
||||
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
|
||||
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
|
||||
errors.join("\n")
|
||||
fn link_llvm(llvm: &LLVMArtifacts) -> Result<Vec<u8>, CompilerError> {
|
||||
let linked_bitcode = comgr::link_bitcode(&llvm.bitcode, &llvm.linked_bitcode)?;
|
||||
let data = linked_bitcode.get_data(amd_comgr_data_kind_s::AMD_COMGR_DATA_KIND_BC, 0)?;
|
||||
let linked_llvm = data.copy_content().map_err(CompilerError::from)?;
|
||||
Ok(ptx::bitcode_to_ir(linked_llvm))
|
||||
}
|
||||
|
||||
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
|
||||
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, CompilerError> {
|
||||
use hip_runtime_sys::*;
|
||||
unsafe { hipInit(0) }?;
|
||||
let mut dev_props: MaybeUninit<hipDeviceProp_tR0600> = MaybeUninit::uninit();
|
||||
|
@ -87,13 +95,12 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
|
|||
let dev_props = unsafe { dev_props.assume_init() };
|
||||
let gcn_arch = unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) };
|
||||
|
||||
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from)
|
||||
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(CompilerError::from)
|
||||
}
|
||||
|
||||
fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
|
||||
if path.try_exists()? && !path.is_file() {
|
||||
let error = CheckPathError(path.to_path_buf());
|
||||
let error = Box::new(error);
|
||||
fn check_path(path: &Path) -> Result<(), CompilerError> {
|
||||
if path.try_exists().map_err(CompilerError::from)? && !path.is_file() {
|
||||
let error = CompilerError::CheckPathError(path.to_path_buf());
|
||||
return Err(error);
|
||||
}
|
||||
Ok(())
|
||||
|
@ -102,8 +109,8 @@ fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
|
|||
fn get_output_path(
|
||||
ptx_path: &PathBuf,
|
||||
output_type: &OutputType,
|
||||
) -> Result<PathBuf, Box<dyn Error>> {
|
||||
let current_dir = env::current_dir()?;
|
||||
) -> Result<PathBuf, CompilerError> {
|
||||
let current_dir = env::current_dir().map_err(CompilerError::from)?;
|
||||
let output_path = current_dir.join(
|
||||
ptx_path
|
||||
.as_path()
|
||||
|
@ -150,7 +157,7 @@ impl OutputType {
|
|||
}
|
||||
|
||||
impl FromStr for OutputType {
|
||||
type Err = ParseOutputTypeError;
|
||||
type Err = CompilerError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
|
@ -158,35 +165,7 @@ impl FromStr for OutputType {
|
|||
"ll_linked" => Ok(Self::LlvmIrLinked),
|
||||
"elf" => Ok(Self::Elf),
|
||||
"asm" => Ok(Self::Assembly),
|
||||
_ => Err(ParseOutputTypeError(s.into())),
|
||||
_ => Err(CompilerError::ParseOutputTypeError(s.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Not a regular file: {0}")]
|
||||
struct CheckPathError(PathBuf);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Invalid output type: {0}")]
|
||||
struct ParseOutputTypeError(String);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ElfError {
|
||||
#[error("HIP error: {0:?}")]
|
||||
HipError(hipErrorCode_t),
|
||||
#[error("amd_comgr error: {0:?}")]
|
||||
AmdComgrError(amd_comgr_status_s),
|
||||
}
|
||||
|
||||
impl From<hipErrorCode_t> for ElfError {
|
||||
fn from(value: hipErrorCode_t) -> Self {
|
||||
ElfError::HipError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<amd_comgr_status_s> for ElfError {
|
||||
fn from(value: amd_comgr_status_s) -> Self {
|
||||
ElfError::AmdComgrError(value)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue