diff --git a/Cargo.lock b/Cargo.lock index 131b41b..7480124 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -186,18 +186,18 @@ dependencies = [ [[package]] name = "bpaf" -version = "0.9.15" +version = "0.9.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50fd5174866dc2fa2ddc96e8fb800852d37f064f32a45c7b7c2f8fa2c64c77fa" +checksum = "4848ed5727d39a7573551c205bcb1ccd88c8cad4ed2c80f62e2316f208196b8d" dependencies = [ "bpaf_derive", ] [[package]] name = "bpaf_derive" -version = "0.5.13" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf95d9c7e6aba67f8fc07761091e93254677f4db9e27197adecebc7039a58722" +checksum = "fefb4feeec9a091705938922f26081aad77c64cd2e76cd1c4a9ece8e42e1618a" dependencies = [ "proc-macro2", "quote", @@ -256,7 +256,7 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -355,6 +355,21 @@ version = "0.0.0" dependencies = [ "amd_comgr-sys", "libloading", + "ptx", + "thiserror 2.0.12", +] + +[[package]] +name = "compiler" +version = "0.0.0" +dependencies = [ + "amd_comgr-sys", + "bpaf", + "comgr", + "hip_runtime-sys", + "ptx", + "ptx_parser", + "thiserror 2.0.12", ] [[package]] @@ -960,7 +975,7 @@ dependencies = [ "parking_lot", "signal-hook", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -973,7 +988,7 @@ dependencies = [ "gix-date", "gix-utils 0.2.0", "itoa", - "thiserror 2.0.11", + "thiserror 2.0.12", "winnow 0.7.10", ] @@ -990,7 +1005,7 @@ dependencies = [ "gix-trace", "kstring", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", "unicode-bom", ] @@ -1000,7 +1015,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1db9765c69502650da68f0804e3dc2b5f8ccc6a2d104ca6c85bc40700d37540" dependencies = [ - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1009,7 +1024,7 @@ version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b1f1d8764958699dc764e3f727cef280ff4d1bd92c107bbf8acd85b30c1bd6f" dependencies = [ - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1035,7 +1050,7 @@ dependencies = [ "gix-chunk", "gix-hash 0.17.0", "memmap2 0.9.7", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1054,7 +1069,7 @@ dependencies = [ "memchr", "once_cell", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", "unicode-bom", "winnow 0.7.10", ] @@ -1069,7 +1084,7 @@ dependencies = [ "bstr", "gix-path", "libc", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1081,7 +1096,7 @@ dependencies = [ "bstr", "itoa", "jiff", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1105,7 +1120,7 @@ dependencies = [ "gix-traverse", "gix-worktree", "imara-diff", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1125,7 +1140,7 @@ dependencies = [ "gix-trace", "gix-utils 0.2.0", "gix-worktree", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1141,7 +1156,7 @@ dependencies = [ "gix-path", "gix-ref", "gix-sec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1158,7 +1173,7 @@ dependencies = [ "libc", "once_cell", "prodash", - "thiserror 2.0.11", + "thiserror 2.0.12", "walkdir", ] @@ -1192,7 +1207,7 @@ dependencies = [ "gix-trace", "gix-utils 0.2.0", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1206,7 +1221,7 @@ dependencies = [ "gix-features 0.41.1", "gix-path", "gix-utils 0.2.0", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1220,7 +1235,7 @@ dependencies = [ "gix-features 0.42.1", "gix-path", "gix-utils 0.3.0", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1244,7 +1259,7 @@ dependencies = [ "faster-hex 0.9.0", "gix-features 0.41.1", "sha1-checked", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1256,7 +1271,7 @@ dependencies = [ "faster-hex 0.10.0", "gix-features 0.42.1", "sha1-checked", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1308,7 +1323,7 @@ dependencies = [ "memmap2 0.9.7", "rustix 0.38.37", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1319,7 +1334,7 @@ checksum = "570f8b034659f256366dc90f1a24924902f20acccd6a15be96d44d1269e7a796" dependencies = [ "gix-tempfile", "gix-utils 0.3.0", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1339,7 +1354,7 @@ dependencies = [ "gix-validate 0.9.4", "itoa", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", "winnow 0.7.10", ] @@ -1361,7 +1376,7 @@ dependencies = [ "gix-quote", "parking_lot", "tempfile", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1379,7 +1394,7 @@ dependencies = [ "gix-path", "memmap2 0.9.7", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1391,7 +1406,7 @@ dependencies = [ "bstr", "faster-hex 0.9.0", "gix-trace", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1403,7 +1418,7 @@ dependencies = [ "bstr", "faster-hex 0.9.0", "gix-trace", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1417,7 +1432,7 @@ dependencies = [ "gix-validate 0.10.0", "home", "once_cell", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1432,7 +1447,7 @@ dependencies = [ "gix-config-value", "gix-glob", "gix-path", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1450,7 +1465,7 @@ dependencies = [ "gix-transport", "gix-utils 0.2.0", "maybe-async", - "thiserror 2.0.11", + "thiserror 2.0.12", "winnow 0.7.10", ] @@ -1462,7 +1477,7 @@ checksum = "1b005c550bf84de3b24aa5e540a23e6146a1c01c7d30470e35d75a12f827f969" dependencies = [ "bstr", "gix-utils 0.2.0", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1482,7 +1497,7 @@ dependencies = [ "gix-utils 0.2.0", "gix-validate 0.9.4", "memmap2 0.9.7", - "thiserror 2.0.11", + "thiserror 2.0.12", "winnow 0.7.10", ] @@ -1497,7 +1512,7 @@ dependencies = [ "gix-revision", "gix-validate 0.9.4", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1515,7 +1530,7 @@ dependencies = [ "gix-object", "gix-revwalk", "gix-trace", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1530,7 +1545,7 @@ dependencies = [ "gix-hashtable", "gix-object", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1554,7 +1569,7 @@ dependencies = [ "bstr", "gix-hash 0.17.0", "gix-lock", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1577,7 +1592,7 @@ dependencies = [ "gix-pathspec", "gix-worktree", "portable-atomic", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1592,7 +1607,7 @@ dependencies = [ "gix-pathspec", "gix-refspec", "gix-url", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1630,7 +1645,7 @@ dependencies = [ "gix-quote", "gix-sec", "gix-url", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1647,7 +1662,7 @@ dependencies = [ "gix-object", "gix-revwalk", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1660,7 +1675,7 @@ dependencies = [ "gix-features 0.41.1", "gix-path", "percent-encoding", - "thiserror 2.0.11", + "thiserror 2.0.12", "url", ] @@ -1692,7 +1707,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34b5f1253109da6c79ed7cf6e1e38437080bb6d704c76af14c93e2f255234084" dependencies = [ "bstr", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -1702,7 +1717,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77b9e00cacde5b51388d28ed746c493b18a6add1f19b5e01d686b3b9ece66d4d" dependencies = [ "bstr", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -2588,8 +2603,8 @@ dependencies = [ "quick-error", "rustc-hash 2.0.0", "serde", - "strum", - "strum_macros", + "strum 0.26.3", + "strum_macros 0.26.4", "tempfile", "thiserror 1.0.64", "unwrap_or", @@ -2604,6 +2619,7 @@ dependencies = [ "logos", "ptx_parser_macros", "rustc-hash 2.0.0", + "strum 0.27.1", "thiserror 1.0.64", "winnow 0.6.20", ] @@ -2675,7 +2691,7 @@ checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ "getrandom 0.2.16", "libredox", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -2964,6 +2980,15 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +[[package]] +name = "strum" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" +dependencies = [ + "strum_macros 0.27.1", +] + [[package]] name = "strum_macros" version = "0.26.4" @@ -2977,6 +3002,19 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "strum_macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.89", +] + [[package]] name = "syn" version = "1.0.109" @@ -3051,11 +3089,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -3071,9 +3109,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2ec35b2..f680001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,9 +37,10 @@ members = [ "zluda_preload", "zluda_redirect", "zluda_sparse", + "compiler", ] -default-members = ["zluda", "zluda_ml", "zluda_inject", "zluda_redirect"] +default-members = ["zluda", "zluda_ml", "zluda_inject", "zluda_redirect", "compiler"] [profile.release-lto] inherits = "release" @@ -47,4 +48,4 @@ codegen-units = 1 lto = true [profile.dev.package.xtask] -opt-level = 2 +opt-level = 2 \ No newline at end of file diff --git a/comgr/Cargo.toml b/comgr/Cargo.toml index 171002e..7c432cd 100644 --- a/comgr/Cargo.toml +++ b/comgr/Cargo.toml @@ -9,3 +9,5 @@ edition = "2021" [dependencies] amd_comgr-sys = { path = "../ext/amd_comgr-sys" } libloading = "0.8" +ptx = { path = "../ptx" } +thiserror = "2.0.12" diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 08f3928..f5e4e30 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -181,6 +181,7 @@ pub fn compile_bitcode( main_buffer: &[u8], attributes_buffer: &[u8], ptx_impl: &[u8], + compiler_hook: Option<&dyn Fn(&Vec, String)>, ) -> Result, Error> { let bitcode_data_set = DataSet::new(comgr)?; let main_bitcode_data = Data::new(comgr, DataKind::Bc, c"zluda.bc", main_buffer)?; @@ -193,6 +194,14 @@ pub fn compile_bitcode( let linking_info = ActionInfo::new(comgr)?; let linked_data_set = comgr.do_action(ActionKind::LinkBcToBc, &linking_info, &bitcode_data_set)?; + if let Some(hook) = compiler_hook { + // Run compiler hook on human-readable LLVM IR + let data = linked_data_set.get_data(DataKind::Bc, 0)?; + let data = data.copy_content(comgr)?; + let data = ptx::bitcode_to_ir(data); + hook(&data, String::from("linked.ll")); + } + let compile_to_exec = ActionInfo::new(comgr)?; compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_language(Language::LlvmIr)?; @@ -231,7 +240,27 @@ pub fn compile_bitcode( &linked_data_set, )?; let executable = exec_data_set.get_data(DataKind::Executable, 0)?; - executable.copy_content(comgr) + let executable = executable.copy_content(comgr); + if let Some(hook) = compiler_hook { + // Run compiler hook for executable + hook( + executable.as_ref().unwrap_or(&Vec::new()), + String::from("elf"), + ); + + // Disassemble executable and run compiler hook + let action_info = ActionInfo::new(comgr)?; + action_info.set_isa_name(gcn_arch)?; + let disassembly = comgr.do_action( + ActionKind::DisassembleExecutableToSource, + &action_info, + &exec_data_set, + )?; + let disassembly = disassembly.get_data(DataKind::Source, 0)?; + let disassembly = disassembly.copy_content(comgr); + hook(&disassembly.unwrap_or(Vec::new()), String::from("asm")) + } + executable } pub fn get_clang_version(comgr: &Comgr) -> Result { @@ -334,7 +363,8 @@ impl Comgr { } } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] +#[error("Comgr error: {0:?}")] pub struct Error(pub ::std::num::NonZeroU32); impl Error { @@ -400,6 +430,7 @@ impl_into!( [ LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC, CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE, + DisassembleExecutableToSource => AMD_COMGR_ACTION_DISASSEMBLE_EXECUTABLE_TO_SOURCE, SourceToPreprocessor => AMD_COMGR_ACTION_SOURCE_TO_PREPROCESSOR ] ); diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml new file mode 100644 index 0000000..16dca14 --- /dev/null +++ b/compiler/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "compiler" +description = "ZLUDA offline compiler" +version = "0.0.0" +authors = ["Joƫlle van Essen "] +edition = "2021" + +[[bin]] +name = "zoc" +path = "src/main.rs" + +[dependencies] +amd_comgr-sys = { path = "../ext/amd_comgr-sys" } +bpaf = { version = "0.9.19", features = ["derive"] } +comgr = { path = "../comgr" } +hip_runtime-sys = { path = "../ext/hip_runtime-sys" } +ptx = { path = "../ptx" } +ptx_parser = { path = "../ptx_parser" } +thiserror = "2.0.12" + +[package.metadata.zluda] +debug_only = true \ No newline at end of file diff --git a/compiler/src/error.rs b/compiler/src/error.rs new file mode 100644 index 0000000..f5bfe11 --- /dev/null +++ b/compiler/src/error.rs @@ -0,0 +1,62 @@ +use std::ffi::FromBytesUntilNulError; +use std::io; +use std::str::Utf8Error; + +use hip_runtime_sys::hipErrorCode_t; +use ptx::TranslateError; +use ptx_parser::PtxError; + +#[derive(Debug, thiserror::Error)] +pub enum CompilerError { + #[error("HIP error code: {0:?}")] + HipError(hipErrorCode_t), + #[error(transparent)] + ComgrError(#[from] comgr::Error), + #[error(transparent)] + IoError(#[from] io::Error), + #[error(transparent)] + Utf8Error(#[from] Utf8Error), + #[error(transparent)] + FromBytesUntilNulError(#[from] FromBytesUntilNulError), + #[error("{message}")] + GenericError { + #[source] + cause: Option>, + message: String, + }, +} + +impl From for CompilerError { + fn from(error_code: hipErrorCode_t) -> Self { + CompilerError::HipError(error_code) + } +} + +impl From>> for CompilerError { + fn from(causes: Vec) -> Self { + let errors: Vec = causes + .iter() + .map(|e| { + let msg = match e { + PtxError::UnrecognizedStatement(value) + | PtxError::UnrecognizedDirective(value) => value.to_string(), + other => other.to_string(), + }; + format!("PtxError::{}: {}", e.as_ref(), msg) + }) + .collect(); + let message = errors.join("\n"); + CompilerError::GenericError { + cause: None, + message, + } + } +} + +impl From for CompilerError { + fn from(cause: TranslateError) -> Self { + let message = format!("PTX TranslateError::{}", cause.as_ref()); + let cause = Some(Box::new(cause) as Box); + CompilerError::GenericError { cause, message } + } +} diff --git a/compiler/src/main.rs b/compiler/src/main.rs new file mode 100644 index 0000000..aa5007d --- /dev/null +++ b/compiler/src/main.rs @@ -0,0 +1,136 @@ +use std::ffi::CStr; +use std::fs::{self, File}; +use std::io::{self, Write}; +use std::path::{Path, PathBuf}; +use std::process::ExitCode; +use std::str; +use std::{env, mem}; + +use bpaf::Bpaf; + +mod error; +use error::CompilerError; + +const DEFAULT_ARCH: &'static str = "gfx1100"; + +#[derive(Debug, Clone, Bpaf)] +#[bpaf(options, version)] +pub struct Options { + #[bpaf(argument("output-dir"))] + /// Output directory + output_dir: Option, + + #[bpaf(long("arch"))] + /// Target architecture + arch: Option, + + #[bpaf(positional("filename"))] + /// PTX file + ptx_path: String, +} + +fn main() -> ExitCode { + if let Err(e) = main_core() { + eprintln!("Error: {}", e); + return ExitCode::FAILURE; + } + ExitCode::SUCCESS +} + +fn main_core() -> Result<(), CompilerError> { + let opts = options().run(); + let comgr = comgr::Comgr::new()?; + + let ptx_path = Path::new(&opts.ptx_path).to_path_buf(); + let filename_base = ptx_path + .file_name() + .map(|osstr| osstr.to_str().unwrap_or("output")) + .unwrap_or("output"); + + let mut output_path = match opts.output_dir { + Some(value) => value, + None => match ptx_path.parent() { + Some(dir) => dir.to_path_buf(), + None => env::current_dir()?, + }, + }; + output_path.push(filename_base); + + let arch: String = match opts.arch { + Some(s) => s, + None => get_gpu_arch() + .map(String::from) + .unwrap_or(DEFAULT_ARCH.to_owned()), + }; + + let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; + let ptx = str::from_utf8(&ptx).map_err(CompilerError::from)?; + let llvm = ptx_to_llvm(ptx).map_err(CompilerError::from)?; + + write_to_file(&llvm.llvm_ir, output_path.with_extension("ll").as_path())?; + + let comgr_hook = |bytes: &Vec, extension: String| { + let output_path = output_path.with_extension(extension); + write_to_file(bytes, &output_path).unwrap(); + }; + + comgr::compile_bitcode( + &comgr, + &arch, + &llvm.bitcode, + &llvm.attributes_bitcode, + &llvm.linked_bitcode, + Some(&comgr_hook), + ) + .map_err(CompilerError::from)?; + + Ok(()) +} + +fn ptx_to_llvm(ptx: &str) -> Result { + let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; + let module = ptx::to_llvm_module( + ast, + ptx::Attributes { + clock_rate: 2124000, + }, + ) + .map_err(CompilerError::from)?; + let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); + let linked_bitcode = module.linked_bitcode().to_vec(); + let attributes_bitcode = module.attributes_ir.write_bitcode_to_memory().to_vec(); + let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec(); + Ok(LLVMArtifacts { + bitcode, + linked_bitcode, + attributes_bitcode, + llvm_ir, + }) +} + +#[derive(Debug)] +struct LLVMArtifacts { + bitcode: Vec, + linked_bitcode: Vec, + attributes_bitcode: Vec, + llvm_ir: Vec, +} + +fn get_gpu_arch() -> Result<&'static str, CompilerError> { + use hip_runtime_sys::*; + unsafe { hipInit(0) }?; + let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; + unsafe { hipGetDevicePropertiesR0600(&mut dev_props, 0) }?; + let gcn_arch_name = &dev_props.gcnArchName; + let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; + let gcn_arch_name = gcn_arch_name.to_str(); + gcn_arch_name.map_err(CompilerError::from) +} + +fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> { + let mut file = File::create(path)?; + file.write_all(content)?; + file.flush()?; + println!("Wrote to {}", path.to_str().unwrap()); + Ok(()) +} diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 9edc7aa..8a49881 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -2,5 +2,7 @@ pub(crate) mod pass; #[cfg(test)] mod test; +pub use pass::llvm::bitcode_to_ir; pub use pass::to_llvm_module; pub use pass::Attributes; +pub use pass::TranslateError; diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index a40e38a..3513e88 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -2,11 +2,13 @@ pub(super) mod attributes; pub(super) mod emit; use std::ffi::CStr; +use std::mem; use std::ops::Deref; use std::ptr; use crate::pass::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; +use llvm_zluda::bit_reader::LLVMParseBitcodeInContext2; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::core::*; use llvm_zluda::prelude::*; @@ -86,6 +88,20 @@ impl Drop for Module { } } +pub fn bitcode_to_ir(bitcode: Vec) -> Vec { + let bitcode: Vec = bitcode.iter().map(|&v| i8::from_ne_bytes([v])).collect(); + let memory_buffer: LLVMMemoryBufferRef = unsafe { + LLVMCreateMemoryBufferWithMemoryRangeCopy(bitcode.as_ptr(), bitcode.len(), ptr::null()) + }; + let context = Context::new(); + let mut module: LLVMModuleRef = unsafe { mem::zeroed() }; + unsafe { + LLVMParseBitcodeInContext2(context.get(), memory_buffer, &mut module); + } + let module = Module(module); + module.print_module_to_string().to_bytes().to_vec() +} + pub struct Message(&'static CStr); impl Drop for Message { @@ -103,6 +119,10 @@ impl std::fmt::Debug for Message { } impl Message { + pub fn to_bytes(&self) -> &[u8] { + self.0.to_bytes() + } + pub fn to_str(&self) -> &str { self.0.to_str().unwrap().trim() } diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 9139c49..eeb2c7f 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -19,7 +19,7 @@ mod insert_explicit_load_store; mod insert_implicit_conversions2; mod insert_post_saturation; mod instruction_mode_to_global_mode; -mod llvm; +pub mod llvm; mod normalize_basic_blocks; mod normalize_identifiers2; mod normalize_predicates2; @@ -32,7 +32,7 @@ static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl. const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_"; quick_error! { - #[derive(Debug)] + #[derive(Debug, strum_macros::AsRefStr)] pub enum TranslateError { UnknownSymbol(symbol: String) { display("Unknown symbol: \"{}\"", symbol) diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 04f37ae..4eb4c99 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -649,6 +649,7 @@ fn run_hip + Copy + Debug, Output: From + Copy + Debug + Def &*module.llvm_ir.write_bitcode_to_memory(), &*module.attributes_ir.write_bitcode_to_memory(), module.linked_bitcode(), + None, ) .unwrap(); let mut module = unsafe { mem::zeroed() }; diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml index 9032de5..3b96ac0 100644 --- a/ptx_parser/Cargo.toml +++ b/ptx_parser/Cargo.toml @@ -7,11 +7,12 @@ edition = "2021" [lib] [dependencies] +bitflags = "1.2" +derive_more = { version = "1", features = ["display"] } logos = "0.14" +ptx_parser_macros = { path = "../ptx_parser_macros" } +rustc-hash = "2.0.0" +strum = { version = "0.27.1", features = ["derive"] } +thiserror = "1.0" winnow = { version = "0.6.18" } #winnow = { version = "0.6.18", features = ["debug"] } -ptx_parser_macros = { path = "../ptx_parser_macros" } -thiserror = "1.0" -bitflags = "1.2" -rustc-hash = "2.0.0" -derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index 1908031..d788604 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1344,7 +1344,7 @@ impl ast::ParsedOperand { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, strum::AsRefStr)] pub enum PtxError<'input> { #[error("{source}")] ParseInt { diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index f7d6763..0c2c5b2 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -147,6 +147,7 @@ fn compile_from_ptx_and_cache( &*llvm_module.llvm_ir.write_bitcode_to_memory(), &*llvm_module.attributes_ir.write_bitcode_to_memory(), llvm_module.linked_bitcode(), + None, ) .map_err(|_| CUerror::UNKNOWN)?; if let Some((cache, key)) = cache_with_key {