diff --git a/Cargo.lock b/Cargo.lock index c140896..3c809f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bpaf" -version = "0.9.17" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd343f000c4472dc4557e093f314c5210375d3e9349ad866fee0ddb7f4ed10bf" +checksum = "4de4d74c5891642753c67ab88f58d971a68dd98673b69689a8c24ce3ec78a412" dependencies = [ "bpaf_derive", ] @@ -151,7 +151,7 @@ dependencies = [ "semver", "serde", "serde_json", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -924,11 +924,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -944,9 +944,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -1229,5 +1229,5 @@ dependencies = [ "hip_runtime-sys", "ptx", "ptx_parser", - "thiserror 1.0.64", + "thiserror 2.0.12", ] diff --git a/zoc/Cargo.toml b/zoc/Cargo.toml index 917d498..146252b 100644 --- a/zoc/Cargo.toml +++ b/zoc/Cargo.toml @@ -7,9 +7,9 @@ edition = "2021" [dependencies] amd_comgr-sys = { path = "../ext/amd_comgr-sys" } -bpaf = { version = "0.9.17", features = ["derive"] } +bpaf = { version = "0.9.18", features = ["derive"] } comgr = { path = "../comgr" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" } ptx = { path = "../ptx" } ptx_parser = { path = "../ptx_parser" } -thiserror = "1.0" \ No newline at end of file +thiserror = "2.0.12" \ No newline at end of file diff --git a/zoc/src/main.rs b/zoc/src/main.rs index 6d7b339..a07aadd 100644 --- a/zoc/src/main.rs +++ b/zoc/src/main.rs @@ -1,10 +1,11 @@ +use std::env; use std::error::Error; -use std::ffi::CStr; +use std::ffi::{CStr, OsStr}; use std::fs::{self, File}; use std::io::{self, Write}; use std::mem::MaybeUninit; -use std::path::Path; -use std::str; +use std::path::{Path, PathBuf}; +use std::str::{self, FromStr}; use amd_comgr_sys::amd_comgr_status_s; use bpaf::Bpaf; @@ -14,42 +15,71 @@ use ptx_parser::PtxError; #[derive(Debug, Clone, Bpaf)] #[bpaf(options, version)] pub struct Options { - #[bpaf(positional("file"))] + #[bpaf(external(output_type), optional)] + output_type: Option, + + #[bpaf(positional("filename"))] /// PTX file - file: String, + ptx_path: String, } fn main() -> Result<(), Box> { - let options = options().run(); - let ptx_path = options.file; - let ptx_path = Path::new(&ptx_path); + let opts = options().run(); - let ptx = fs::read(ptx_path)?; + let output_type = match opts.output_type { + Some(t) => t, + None => OutputType::Elf, + }; + + match output_type { + OutputType::LlvmIrLinked | OutputType::Assembly => todo!(), + _ => {} + } + + let ptx_path = Path::new(&opts.ptx_path).to_path_buf(); + check_path(&ptx_path)?; + + let output_path = get_output_path(&ptx_path, &output_type)?; + check_path(&output_path)?; + + let ptx = fs::read(&ptx_path)?; let ptx = str::from_utf8(&ptx)?; let llvm = ptx_to_llvm(ptx)?; - - let ll_path = ptx_path.with_extension("ll"); - write_to_file(&llvm.llvm_ir, &ll_path)?; + + if output_type == OutputType::LlvmIrPreLinked { + write_to_file(&llvm.llvm_ir, &output_path)?; + return Ok(()); + } let elf = llvm_to_elf(&llvm)?; - let elf_path = ptx_path.with_extension("elf"); - write_to_file(&elf, &elf_path)?; + write_to_file(&elf, &output_path)?; Ok(()) } -fn join_ptx_errors(vector: Vec) -> String { - let errors: Vec = vector.iter().map(PtxError::to_string).collect(); - errors.join("\n") -} - fn ptx_to_llvm(ptx: &str) -> Result> { let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?; let module = ptx::to_llvm_module(ast)?; let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let linked_bitcode = module.linked_bitcode().to_vec(); let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec(); - Ok(LLVMArtifacts { bitcode, linked_bitcode, llvm_ir }) + Ok(LLVMArtifacts { + bitcode, + linked_bitcode, + llvm_ir, + }) +} + +#[derive(Debug)] +struct LLVMArtifacts { + bitcode: Vec, + linked_bitcode: Vec, + llvm_ir: Vec, +} + +fn join_ptx_errors(vector: Vec) -> String { + let errors: Vec = vector.iter().map(PtxError::to_string).collect(); + errors.join("\n") } fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result, ElfError> { @@ -63,20 +93,86 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result, ElfError> { comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from) } +fn check_path(path: &Path) -> Result<(), Box> { + if path.try_exists()? && !path.is_file() { + let error = CheckPathError(path.to_path_buf()); + let error = Box::new(error); + return Err(error); + } + Ok(()) +} + +fn get_output_path( + ptx_path: &PathBuf, + output_type: &OutputType, +) -> Result> { + let current_dir = env::current_dir()?; + let output_path = current_dir.join( + ptx_path + .as_path() + .file_stem() + .unwrap_or(OsStr::new("output")), + ); + let output_path = output_path.with_extension(output_type.extension()); + Ok(output_path) +} + fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> { let mut file = File::create(path)?; file.write_all(content)?; file.flush()?; + println!("Wrote to {}", path.to_str().unwrap()); Ok(()) } -#[derive(Debug)] -struct LLVMArtifacts { - bitcode: Vec, - linked_bitcode: Vec, - llvm_ir: Vec, +#[derive(Bpaf, Clone, Copy, Debug, PartialEq)] +enum OutputType { + /// Produce pre-linked LLVM IR + #[bpaf(long("ll"))] + LlvmIrPreLinked, + /// Produce linked LLVM IR + #[bpaf(long("linked-ll"))] + LlvmIrLinked, + /// Produce ELF binary (default) + Elf, + /// Produce assembly + #[bpaf(long("asm"))] + Assembly, } +impl OutputType { + fn extension(self) -> String { + match self { + OutputType::LlvmIrPreLinked | OutputType::LlvmIrLinked => "ll", + OutputType::Assembly => "asm", + OutputType::Elf => "elf", + } + .into() + } +} + +impl FromStr for OutputType { + type Err = ParseOutputTypeError; + + fn from_str(s: &str) -> Result { + match s { + "ll" => Ok(Self::LlvmIrPreLinked), + "ll_linked" => Ok(Self::LlvmIrLinked), + "elf" => Ok(Self::Elf), + "asm" => Ok(Self::Assembly), + _ => Err(ParseOutputTypeError(s.into())), + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Not a regular file: {0}")] +struct CheckPathError(PathBuf); + +#[derive(Debug, thiserror::Error)] +#[error("Invalid output type: {0}")] +struct ParseOutputTypeError(String); + #[derive(Debug, thiserror::Error)] enum ElfError { #[error("HIP error: {0:?}")] @@ -95,4 +191,4 @@ impl From for ElfError { fn from(value: amd_comgr_status_s) -> Self { ElfError::AmdComgrError(value) } -} \ No newline at end of file +}