Have zoc time compiler passes

This commit is contained in:
Andrzej Janik 2025-09-10 18:21:48 +00:00
commit 5e8a930be6
6 changed files with 42 additions and 16 deletions

View file

@ -6,6 +6,7 @@ use std::io::{self, Write};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::ExitCode; use std::process::ExitCode;
use std::str; use std::str;
use std::time::Instant;
use std::{env, mem}; use std::{env, mem};
mod error; mod error;
@ -57,19 +58,12 @@ fn main_core() -> Result<(), CompilerError> {
let arch: String = match opts.arch { let arch: String = match opts.arch {
Some(s) => s, Some(s) => s,
None => { None => (|| {
(|| { let runtime = hip::Runtime::load()?;
let runtime = hip::Runtime::load()?; runtime.init()?;
runtime.init()?; get_gpu_arch(&runtime)
get_gpu_arch(&runtime) })()
})() .unwrap_or_else(|_| DEFAULT_ARCH.to_owned()),
.unwrap_or_else(|_| DEFAULT_ARCH.to_owned())
/*
get_gpu_arch(&mut dev_props)
.map(String::from)
.unwrap_or(DEFAULT_ARCH.to_owned())
*/
}
}; };
let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?;
@ -83,6 +77,7 @@ fn main_core() -> Result<(), CompilerError> {
write_to_file(bytes, &output_path).unwrap(); write_to_file(bytes, &output_path).unwrap();
}; };
let mut start = Instant::now();
comgr::compile_bitcode( comgr::compile_bitcode(
&comgr, &comgr,
&arch, &arch,
@ -92,17 +87,22 @@ fn main_core() -> Result<(), CompilerError> {
Some(&comgr_hook), Some(&comgr_hook),
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
report_pass_time("compile_bitcode", &mut start);
Ok(()) Ok(())
} }
fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> { fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?; let ast = ptx_parser::parse_module_checked(ptx).map_err(CompilerError::from)?;
let mut start = Instant::now();
let module = ptx::to_llvm_module( let module = ptx::to_llvm_module(
ast, ast,
ptx::Attributes { ptx::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|pass| {
report_pass_time(pass, &mut start);
},
) )
.map_err(CompilerError::from)?; .map_err(CompilerError::from)?;
let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec(); let bitcode = module.llvm_ir.write_bitcode_to_memory().to_vec();
@ -117,6 +117,12 @@ fn ptx_to_llvm(ptx: &str) -> Result<LLVMArtifacts, CompilerError> {
}) })
} }
/// Prints the time elapsed since `*start` for the named `pass` to stdout,
/// then resets `*start` to the current instant.
///
/// The reset happens after printing, so the time spent writing the report
/// is not counted towards whatever is measured by the next call.
fn report_pass_time(pass: &str, start: &mut Instant) {
    println!("Pass {:?} took {:?}", pass, start.elapsed());
    // Restart the stopwatch for the following measurement.
    *start = Instant::now();
}
#[derive(Debug)] #[derive(Debug)]
struct LLVMArtifacts { struct LLVMArtifacts {
bitcode: Vec<u8>, bitcode: Vec<u8>,

View file

@ -60,31 +60,48 @@ pub struct Attributes {
pub fn to_llvm_module<'input>( pub fn to_llvm_module<'input>(
ast: ast::Module<'input>, ast: ast::Module<'input>,
attributes: Attributes, attributes: Attributes,
mut on_pass_end: impl FnMut(&str),
) -> Result<Module, TranslateError> { ) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?; let sreg_map = SpecialRegistersMap::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
on_pass_end("normalize_identifiers2");
let directives = replace_known_functions::run(&mut flat_resolver, directives); let directives = replace_known_functions::run(&mut flat_resolver, directives);
on_pass_end("replace_known_functions");
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_predicates2");
let directives = resolve_function_pointers::run(directives)?; let directives = resolve_function_pointers::run(directives)?;
on_pass_end("resolve_function_pointers");
let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?; let directives = fix_special_registers::run(&mut flat_resolver, &sreg_map, directives)?;
on_pass_end("fix_special_registers");
let directives = expand_operands::run(&mut flat_resolver, directives)?; let directives = expand_operands::run(&mut flat_resolver, directives)?;
on_pass_end("expand_operands");
let directives = insert_post_saturation::run(&mut flat_resolver, directives)?; let directives = insert_post_saturation::run(&mut flat_resolver, directives)?;
on_pass_end("insert_post_saturation");
let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
on_pass_end("deparamize_functions");
let directives = let directives =
replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?; replace_instructions_with_functions_fp_required::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions_fp_required");
let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?; let directives = normalize_basic_blocks::run(&mut flat_resolver, directives)?;
on_pass_end("normalize_basic_blocks");
let directives = remove_unreachable_basic_blocks::run(directives)?; let directives = remove_unreachable_basic_blocks::run(directives)?;
on_pass_end("remove_unreachable_basic_blocks");
let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?; let directives = instruction_mode_to_global_mode::run(&mut flat_resolver, directives)?;
on_pass_end("instruction_mode_to_global_mode");
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
on_pass_end("insert_explicit_load_store");
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
on_pass_end("insert_implicit_conversions2");
let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?; let directives = replace_instructions_with_functions::run(&mut flat_resolver, directives)?;
on_pass_end("replace_instructions_with_functions");
let directives = hoist_globals::run(directives)?; let directives = hoist_globals::run(directives)?;
on_pass_end("hoist_globals");
let context = llvm::Context::new(); let context = llvm::Context::new();
let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?; let llvm_ir = llvm::emit::run(&context, flat_resolver, directives)?;
let attributes_ir = llvm::attributes::run(&context, attributes)?; let attributes_ir = llvm::attributes::run(&context, attributes)?;
on_pass_end("emit_llvm");
Ok(Module { Ok(Module {
llvm_ir, llvm_ir,
attributes_ir, attributes_ir,

View file

@ -38,7 +38,7 @@ fn compile_and_assert(ptx_text: &str) -> Result<(), TranslateError> {
let attributes = pass::Attributes { let attributes = pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}; };
crate::to_llvm_module(ast, attributes)?; crate::to_llvm_module(ast, attributes, |_| {})?;
Ok(()) Ok(())
} }

View file

@ -522,6 +522,7 @@ fn test_hip_assert<
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let name = CString::new(name)?; let name = CString::new(name)?;
@ -542,6 +543,7 @@ fn test_llvm_assert(
pass::Attributes { pass::Attributes {
clock_rate: 2124000, clock_rate: 2124000,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let actual_ll = llvm_ir.llvm_ir.print_module_to_string(); let actual_ll = llvm_ir.llvm_ir.print_module_to_string();

View file

@ -42,6 +42,7 @@ fn main() {
ptx::Attributes { ptx::Attributes {
clock_rate: clock_rate as u32, clock_rate: clock_rate as u32,
}, },
|_| {},
) )
.unwrap(); .unwrap();
let elf_binary = comgr::compile_bitcode( let elf_binary = comgr::compile_bitcode(

View file

@ -138,7 +138,7 @@ fn compile_from_ptx_and_cache(
} else { } else {
ptx_parser::parse_module_unchecked(text) ptx_parser::parse_module_unchecked(text)
}; };
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?; let llvm_module = ptx::to_llvm_module(ast, attributes, |_| {}).map_err(|_| CUerror::UNKNOWN)?;
let elf_module = comgr::compile_bitcode( let elf_module = comgr::compile_bitcode(
comgr, comgr,
gcn_arch, gcn_arch,