zoc: Specify output type

This commit is contained in:
Joëlle van Essen 2025-03-13 17:12:43 +01:00
parent c6f240e78d
commit 084339e141
No known key found for this signature in database
GPG key ID: 28D3B5CDD4B43882
3 changed files with 133 additions and 37 deletions

18
Cargo.lock generated
View file

@ -98,9 +98,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]]
name = "bpaf"
version = "0.9.17"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd343f000c4472dc4557e093f314c5210375d3e9349ad866fee0ddb7f4ed10bf"
checksum = "4de4d74c5891642753c67ab88f58d971a68dd98673b69689a8c24ce3ec78a412"
dependencies = [
"bpaf_derive",
]
@ -151,7 +151,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@ -924,11 +924,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.11"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl 2.0.11",
"thiserror-impl 2.0.12",
]
[[package]]
@ -944,9 +944,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.11"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
@ -1229,5 +1229,5 @@ dependencies = [
"hip_runtime-sys",
"ptx",
"ptx_parser",
"thiserror 1.0.64",
"thiserror 2.0.12",
]

View file

@ -7,9 +7,9 @@ edition = "2021"
[dependencies]
amd_comgr-sys = { path = "../ext/amd_comgr-sys" }
bpaf = { version = "0.9.17", features = ["derive"] }
bpaf = { version = "0.9.18", features = ["derive"] }
comgr = { path = "../comgr" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
ptx = { path = "../ptx" }
ptx_parser = { path = "../ptx_parser" }
thiserror = "1.0"
thiserror = "2.0.12"

View file

@ -1,10 +1,11 @@
use std::env;
use std::error::Error;
use std::ffi::CStr;
use std::ffi::{CStr, OsStr};
use std::fs::{self, File};
use std::io::{self, Write};
use std::mem::MaybeUninit;
use std::path::Path;
use std::str;
use std::path::{Path, PathBuf};
use std::str::{self, FromStr};
use amd_comgr_sys::amd_comgr_status_s;
use bpaf::Bpaf;
@ -14,42 +15,71 @@ use ptx_parser::PtxError;
#[derive(Debug, Clone, Bpaf)]
#[bpaf(options, version)]
pub struct Options {
#[bpaf(positional("file"))]
#[bpaf(external(output_type), optional)]
output_type: Option<OutputType>,
#[bpaf(positional("filename"))]
/// PTX file
file: String,
ptx_path: String,
}
fn main() -> Result<(), Box<dyn Error>> {
let options = options().run();
let ptx_path = options.file;
let ptx_path = Path::new(&ptx_path);
let opts = options().run();
let ptx = fs::read(ptx_path)?;
let output_type = match opts.output_type {
Some(t) => t,
None => OutputType::Elf,
};
match output_type {
OutputType::LlvmIrLinked | OutputType::Assembly => todo!(),
_ => {}
}
let ptx_path = Path::new(&opts.ptx_path).to_path_buf();
check_path(&ptx_path)?;
let output_path = get_output_path(&ptx_path, &output_type)?;
check_path(&output_path)?;
let ptx = fs::read(&ptx_path)?;
let ptx = str::from_utf8(&ptx)?;
let llvm = ptx_to_llvm(ptx)?;
let ll_path = ptx_path.with_extension("ll");
write_to_file(&llvm.llvm_ir, &ll_path)?;
if output_type == OutputType::LlvmIrPreLinked {
write_to_file(&llvm.llvm_ir, &output_path)?;
return Ok(());
}
let elf = llvm_to_elf(&llvm)?;
let elf_path = ptx_path.with_extension("elf");
write_to_file(&elf, &elf_path)?;
write_to_file(&elf, &output_path)?;
Ok(())
}
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
errors.join("\n")
}
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, Box<dyn Error>> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?;
let module = ptx::to_llvm_module(ast)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
let linked_bitcode = module.linked_bitcode().to_vec();
let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec();
Ok(LLVMArtifacts { bitcode, linked_bitcode, llvm_ir })
Ok(LLVMArtifacts {
bitcode,
linked_bitcode,
llvm_ir,
})
}
#[derive(Debug)]
struct LLVMArtifacts {
bitcode: Vec<u8>,
linked_bitcode: Vec<u8>,
llvm_ir: Vec<u8>,
}
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
errors.join("\n")
}
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
@ -63,20 +93,86 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from)
}
fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
if path.try_exists()? && !path.is_file() {
let error = CheckPathError(path.to_path_buf());
let error = Box::new(error);
return Err(error);
}
Ok(())
}
fn get_output_path(
ptx_path: &PathBuf,
output_type: &OutputType,
) -> Result<PathBuf, Box<dyn Error>> {
let current_dir = env::current_dir()?;
let output_path = current_dir.join(
ptx_path
.as_path()
.file_stem()
.unwrap_or(OsStr::new("output")),
);
let output_path = output_path.with_extension(output_type.extension());
Ok(output_path)
}
fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> {
let mut file = File::create(path)?;
file.write_all(content)?;
file.flush()?;
println!("Wrote to {}", path.to_str().unwrap());
Ok(())
}
#[derive(Debug)]
struct LLVMArtifacts {
bitcode: Vec<u8>,
linked_bitcode: Vec<u8>,
llvm_ir: Vec<u8>,
#[derive(Bpaf, Clone, Copy, Debug, PartialEq)]
enum OutputType {
/// Produce pre-linked LLVM IR
#[bpaf(long("ll"))]
LlvmIrPreLinked,
/// Produce linked LLVM IR
#[bpaf(long("linked-ll"))]
LlvmIrLinked,
/// Produce ELF binary (default)
Elf,
/// Produce assembly
#[bpaf(long("asm"))]
Assembly,
}
impl OutputType {
fn extension(self) -> String {
match self {
OutputType::LlvmIrPreLinked | OutputType::LlvmIrLinked => "ll",
OutputType::Assembly => "asm",
OutputType::Elf => "elf",
}
.into()
}
}
impl FromStr for OutputType {
type Err = ParseOutputTypeError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"ll" => Ok(Self::LlvmIrPreLinked),
"ll_linked" => Ok(Self::LlvmIrLinked),
"elf" => Ok(Self::Elf),
"asm" => Ok(Self::Assembly),
_ => Err(ParseOutputTypeError(s.into())),
}
}
}
#[derive(Debug, thiserror::Error)]
#[error("Not a regular file: {0}")]
struct CheckPathError(PathBuf);
#[derive(Debug, thiserror::Error)]
#[error("Invalid output type: {0}")]
struct ParseOutputTypeError(String);
#[derive(Debug, thiserror::Error)]
enum ElfError {
#[error("HIP error: {0:?}")]
@ -95,4 +191,4 @@ impl From<amd_comgr_status_s> for ElfError {
fn from(value: amd_comgr_status_s) -> Self {
ElfError::AmdComgrError(value)
}
}
}