mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
zoc: Specify output type
This commit is contained in:
parent
c6f240e78d
commit
084339e141
3 changed files with 133 additions and 37 deletions
18
Cargo.lock
generated
18
Cargo.lock
generated
|
@ -98,9 +98,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
|
|||
|
||||
[[package]]
|
||||
name = "bpaf"
|
||||
version = "0.9.17"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd343f000c4472dc4557e093f314c5210375d3e9349ad866fee0ddb7f4ed10bf"
|
||||
checksum = "4de4d74c5891642753c67ab88f58d971a68dd98673b69689a8c24ce3ec78a412"
|
||||
dependencies = [
|
||||
"bpaf_derive",
|
||||
]
|
||||
|
@ -151,7 +151,7 @@ dependencies = [
|
|||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.11",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -924,11 +924,11 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.11"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
|
||||
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
|
||||
dependencies = [
|
||||
"thiserror-impl 2.0.11",
|
||||
"thiserror-impl 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -944,9 +944,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "2.0.11"
|
||||
version = "2.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
|
||||
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1229,5 +1229,5 @@ dependencies = [
|
|||
"hip_runtime-sys",
|
||||
"ptx",
|
||||
"ptx_parser",
|
||||
"thiserror 1.0.64",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
|
|
@ -7,9 +7,9 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
amd_comgr-sys = { path = "../ext/amd_comgr-sys" }
|
||||
bpaf = { version = "0.9.17", features = ["derive"] }
|
||||
bpaf = { version = "0.9.18", features = ["derive"] }
|
||||
comgr = { path = "../comgr" }
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
ptx = { path = "../ptx" }
|
||||
ptx_parser = { path = "../ptx_parser" }
|
||||
thiserror = "1.0"
|
||||
thiserror = "2.0.12"
|
148
zoc/src/main.rs
148
zoc/src/main.rs
|
@ -1,10 +1,11 @@
|
|||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::ffi::CStr;
|
||||
use std::ffi::{CStr, OsStr};
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, Write};
|
||||
use std::mem::MaybeUninit;
|
||||
use std::path::Path;
|
||||
use std::str;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::{self, FromStr};
|
||||
|
||||
use amd_comgr_sys::amd_comgr_status_s;
|
||||
use bpaf::Bpaf;
|
||||
|
@ -14,42 +15,71 @@ use ptx_parser::PtxError;
|
|||
#[derive(Debug, Clone, Bpaf)]
|
||||
#[bpaf(options, version)]
|
||||
pub struct Options {
|
||||
#[bpaf(positional("file"))]
|
||||
#[bpaf(external(output_type), optional)]
|
||||
output_type: Option<OutputType>,
|
||||
|
||||
#[bpaf(positional("filename"))]
|
||||
/// PTX file
|
||||
file: String,
|
||||
ptx_path: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let options = options().run();
|
||||
let ptx_path = options.file;
|
||||
let ptx_path = Path::new(&ptx_path);
|
||||
let opts = options().run();
|
||||
|
||||
let ptx = fs::read(ptx_path)?;
|
||||
let output_type = match opts.output_type {
|
||||
Some(t) => t,
|
||||
None => OutputType::Elf,
|
||||
};
|
||||
|
||||
match output_type {
|
||||
OutputType::LlvmIrLinked | OutputType::Assembly => todo!(),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let ptx_path = Path::new(&opts.ptx_path).to_path_buf();
|
||||
check_path(&ptx_path)?;
|
||||
|
||||
let output_path = get_output_path(&ptx_path, &output_type)?;
|
||||
check_path(&output_path)?;
|
||||
|
||||
let ptx = fs::read(&ptx_path)?;
|
||||
let ptx = str::from_utf8(&ptx)?;
|
||||
let llvm = ptx_to_llvm(ptx)?;
|
||||
|
||||
let ll_path = ptx_path.with_extension("ll");
|
||||
write_to_file(&llvm.llvm_ir, &ll_path)?;
|
||||
|
||||
if output_type == OutputType::LlvmIrPreLinked {
|
||||
write_to_file(&llvm.llvm_ir, &output_path)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let elf = llvm_to_elf(&llvm)?;
|
||||
let elf_path = ptx_path.with_extension("elf");
|
||||
write_to_file(&elf, &elf_path)?;
|
||||
write_to_file(&elf, &output_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
|
||||
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
|
||||
errors.join("\n")
|
||||
}
|
||||
|
||||
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, Box<dyn Error>> {
|
||||
let ast = ptx_parser::parse_module_checked(ptx).map_err(join_ptx_errors)?;
|
||||
let module = ptx::to_llvm_module(ast)?;
|
||||
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
|
||||
let linked_bitcode = module.linked_bitcode().to_vec();
|
||||
let llvm_ir = module.llvm_ir.print_module_to_string().to_bytes().to_vec();
|
||||
Ok(LLVMArtifacts { bitcode, linked_bitcode, llvm_ir })
|
||||
Ok(LLVMArtifacts {
|
||||
bitcode,
|
||||
linked_bitcode,
|
||||
llvm_ir,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LLVMArtifacts {
|
||||
bitcode: Vec<u8>,
|
||||
linked_bitcode: Vec<u8>,
|
||||
llvm_ir: Vec<u8>,
|
||||
}
|
||||
|
||||
fn join_ptx_errors(vector: Vec<PtxError>) -> String {
|
||||
let errors: Vec<String> = vector.iter().map(PtxError::to_string).collect();
|
||||
errors.join("\n")
|
||||
}
|
||||
|
||||
fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
|
||||
|
@ -63,20 +93,86 @@ fn llvm_to_elf(llvm: &LLVMArtifacts) -> Result<Vec<u8>, ElfError> {
|
|||
comgr::compile_bitcode(gcn_arch, &llvm.bitcode, &llvm.linked_bitcode).map_err(ElfError::from)
|
||||
}
|
||||
|
||||
fn check_path(path: &Path) -> Result<(), Box<dyn Error>> {
|
||||
if path.try_exists()? && !path.is_file() {
|
||||
let error = CheckPathError(path.to_path_buf());
|
||||
let error = Box::new(error);
|
||||
return Err(error);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_output_path(
|
||||
ptx_path: &PathBuf,
|
||||
output_type: &OutputType,
|
||||
) -> Result<PathBuf, Box<dyn Error>> {
|
||||
let current_dir = env::current_dir()?;
|
||||
let output_path = current_dir.join(
|
||||
ptx_path
|
||||
.as_path()
|
||||
.file_stem()
|
||||
.unwrap_or(OsStr::new("output")),
|
||||
);
|
||||
let output_path = output_path.with_extension(output_type.extension());
|
||||
Ok(output_path)
|
||||
}
|
||||
|
||||
fn write_to_file(content: &[u8], path: &Path) -> io::Result<()> {
|
||||
let mut file = File::create(path)?;
|
||||
file.write_all(content)?;
|
||||
file.flush()?;
|
||||
println!("Wrote to {}", path.to_str().unwrap());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LLVMArtifacts {
|
||||
bitcode: Vec<u8>,
|
||||
linked_bitcode: Vec<u8>,
|
||||
llvm_ir: Vec<u8>,
|
||||
#[derive(Bpaf, Clone, Copy, Debug, PartialEq)]
|
||||
enum OutputType {
|
||||
/// Produce pre-linked LLVM IR
|
||||
#[bpaf(long("ll"))]
|
||||
LlvmIrPreLinked,
|
||||
/// Produce linked LLVM IR
|
||||
#[bpaf(long("linked-ll"))]
|
||||
LlvmIrLinked,
|
||||
/// Produce ELF binary (default)
|
||||
Elf,
|
||||
/// Produce assembly
|
||||
#[bpaf(long("asm"))]
|
||||
Assembly,
|
||||
}
|
||||
|
||||
impl OutputType {
|
||||
fn extension(self) -> String {
|
||||
match self {
|
||||
OutputType::LlvmIrPreLinked | OutputType::LlvmIrLinked => "ll",
|
||||
OutputType::Assembly => "asm",
|
||||
OutputType::Elf => "elf",
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for OutputType {
|
||||
type Err = ParseOutputTypeError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"ll" => Ok(Self::LlvmIrPreLinked),
|
||||
"ll_linked" => Ok(Self::LlvmIrLinked),
|
||||
"elf" => Ok(Self::Elf),
|
||||
"asm" => Ok(Self::Assembly),
|
||||
_ => Err(ParseOutputTypeError(s.into())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Not a regular file: {0}")]
|
||||
struct CheckPathError(PathBuf);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Invalid output type: {0}")]
|
||||
struct ParseOutputTypeError(String);
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum ElfError {
|
||||
#[error("HIP error: {0:?}")]
|
||||
|
@ -95,4 +191,4 @@ impl From<amd_comgr_status_s> for ElfError {
|
|||
fn from(value: amd_comgr_status_s) -> Self {
|
||||
ElfError::AmdComgrError(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue