Have zoc time compiler passes

This commit is contained in:
Andrzej Janik 2025-09-10 18:21:48 +00:00
commit 5e8a930be6
6 changed files with 42 additions and 16 deletions

View file

@ -6,6 +6,7 @@ use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use std::str;
use std::time::Instant;
use std::{env, mem};
mod error;
@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> {
let arch: String = match opts.arch {
Some(s) => s,
None => {
(|| {
let runtime = hip::Runtime::load()?;
runtime.init()?;
get_gpu_arch(&runtime)
})()
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned())
/*
get_gpu_arch(&mut dev_props)
.map(String::from)
.unwrap_or(DEFAULT_ARCH.to_owned())
*/
}
None => (|| {
let runtime = hip::Runtime::load()?;
runtime.init()?;
get_gpu_arch(&runtime)
})()
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned()),
};
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> {
write_to_file(bytes, &output_path).unwrap();
};
let mut start = Instant::now();
comgr::compile_bitcode(
&comgr,
&arch,
@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> {
Some(&comgr_hook),
)
.map_err(CompilerError::from)?;
report_pass_time("compile_bitcode", &mut start);
Ok(())
}
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
let mut start = Instant::now();
let module = ptx::to_llvm_module(
ast,
ptx::Attributes {
clock_rate: 2124000,
},
|pass| {
report_pass_time(pass, &mut start);
},
)
.map_err(CompilerError::from)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
})
}
fn report_pass_time(pass: &str, start: &mut Instant) {
let duration = start.elapsed();
println!("Pass {:?} took {:?}", pass, duration);
*start = Instant::now();
}
#[derive(Debug)]
struct LLVMArtifacts {
bitcode: Vec<u8>,

View file

@ -60,31 +60,48 @@ pub struct Attributes {
pub fn to_llvm_module<'input>(
ast: ast::Module<'input>,
attributes: Attributes,
mut on_pass_end: impl FnMut(&str),
) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
on_pass_end("normalize_identifiers2");
let directives = replace_known_functions::run(&mut flat_resolver, directives);
on_pass_end("replace_known_functions");
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_predicates2");
let directives = resolve_function_pointers::run(directives)?;
on_pass_end("resolve_function_pointers");
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
on_pass_end("fix_special_registers");
let directives = expand_operands::run(&mut flat_resolver, directives)?;
on_pass_end("expand_operands");
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
on_pass_end("insert_post_saturation");
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
on_pass_end("deparamize_functions");
let directives =
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions_fp_required");
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_basic_blocks");
let directives = remove_unreachable_basic_blocks::run(directives)?;
on_pass_end("remove_unreachable_basic_blocks");
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
on_pass_end("instruction_mode_to_global_mode");
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
on_pass_end("insert_explicit_load_store");
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
on_pass_end("insert_implicit_conversions2");
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions");
let directives = hoist_globals::run(directives)?;
on_pass_end("hoist_globals");
let context = llvm::Context::new();
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
let attributes_ir = llvm::attributes::run(&context, attributes)?;
on_pass_end("emit_llvm");
Ok(Module {
llvm_ir,
attributes_ir,

View file

@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let attributes = pass::Attributes {
clock_rate: 2124000,
};
crate::to_llvm_module(ast, attributes)?;
crate::to_llvm_module(ast, attributes, |_| {})?;
Ok(())
}

View file

@ -522,6 +522,7 @@ fn test_hip_assert<
pass::Attributes {
clock_rate: 2124000,
},
|_| {},
)
.unwrap();
let name = CString::new(name)?;
@ -542,6 +543,7 @@ fn test_llvm_assert(
pass::Attributes {
clock_rate: 2124000,
},
|_| {},
)
.unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string();

View file

@ -42,6 +42,7 @@ fn main() {
ptx::Attributes {
clock_rate: clock_rate as u32,
},
|_| {},
)
.unwrap();
let elf_binary = comgr::compile_bitcode(

View file

@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache(
} else {
ptx_parser::parse_module_unchecked(text)
};
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode(
comgr,
gcn_arch,