From 88b01c809eb6bca5519b3ee5acf5f5026c6aeff0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 3 Sep 2025 21:23:01 +0200 Subject: [PATCH] Add small compiler fixes and a fake ptxas binary (#491) --- Cargo.lock | 11 +++ Cargo.toml | 1 + compiler/src/main.rs | 21 +++--- ptx_parser/src/lib.rs | 167 ++++++++++++++++++++++++++++++++++-------- ptxas/Cargo.toml | 18 +++++ ptxas/src/main.rs | 65 ++++++++++++++++ 6 files changed, 242 insertions(+), 41 deletions(-) create mode 100644 ptxas/Cargo.toml create mode 100644 ptxas/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 56b79c4..0a07aad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2649,6 +2649,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "ptxas" +version = "0.0.0" +dependencies = [ + "bpaf", + "comgr", + "hip_runtime-sys", + "ptx", + "ptx_parser", +] + [[package]] name = "quick-error" version = "1.2.3" diff --git a/Cargo.toml b/Cargo.toml index 9415c8f..d6fa904 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "ptx_parser", "ptx_parser_macros", "ptx_parser_macros_impl", + "ptxas", "xtask", "zluda", "zluda_bindgen", diff --git a/compiler/src/main.rs b/compiler/src/main.rs index aa5007d..fb8feb0 100644 --- a/compiler/src/main.rs +++ b/compiler/src/main.rs @@ -10,6 +10,7 @@ use bpaf::Bpaf; mod error; use error::CompilerError; +use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600, hipInit}; const DEFAULT_ARCH: &'static str = "gfx1100"; @@ -58,9 +59,14 @@ fn main_core() -> Result<(), CompilerError> { let arch: String = match opts.arch { Some(s) => s, - None => get_gpu_arch() - .map(String::from) - .unwrap_or(DEFAULT_ARCH.to_owned()), + None => { + unsafe { hipInit(0) }?; + let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; + unsafe { hipGetDevicePropertiesR0600(&mut dev_props, 0) }?; + get_gpu_arch(&mut dev_props) + .map(String::from) + .unwrap_or(DEFAULT_ARCH.to_owned()) + } }; let ptx = fs::read(&ptx_path).map_err(CompilerError::from)?; @@ -78,8 +84,8 @@ fn main_core() -> Result<(), CompilerError> { &comgr, &arch, &llvm.bitcode, - &llvm.attributes_bitcode, &llvm.linked_bitcode, + &llvm.attributes_bitcode, Some(&comgr_hook), ) .map_err(CompilerError::from)?; @@ -116,11 +122,8 @@ struct LLVMArtifacts { llvm_ir: Vec, } -fn get_gpu_arch() -> Result<&'static str, CompilerError> { - use hip_runtime_sys::*; - unsafe { hipInit(0) }?; - let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; - unsafe { hipGetDevicePropertiesR0600(&mut dev_props, 0) }?; +fn get_gpu_arch<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> Result<&'a str, CompilerError> { + unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }?; let gcn_arch_name = &dev_props.gcnArchName; let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; let gcn_arch_name = gcn_arch_name.to_str(); diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index fb924e7..6e4e167 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -182,26 +182,29 @@ fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input } fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { - any.verify_map(|(t, _)| { - Some(match t { - Token::Hex(s) => { - if s.ends_with('U') { - (&s[2..s.len() - 1], 16, true) - } else { - (&s[2..], 16, false) + trace( + "num", + any.verify_map(|(t, _)| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } } - } - Token::Decimal(s) => { - let radix = if s.starts_with('0') { 8 } else { 10 }; - if s.ends_with('U') { - (&s[..s.len() - 1], radix, true) - } else { - (s, radix, false) + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } } - } - _ => return None, - }) - }) + _ => return None, + }) + }), + ) .parse_next(stream) } @@ -290,13 +293,16 @@ fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { } fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { - take_error(num.map(|x| { - let (text, radix, _) = x; - match u32::from_str_radix(text, radix) { - Ok(x) => Ok(x), - Err(err) => Err((0, PtxError::from(err))), - } - })) + trace( + "u32", + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })), + ) .parse_next(stream) } @@ -547,7 +553,9 @@ fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { } fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { - alt((ident, dot_ident)).void().parse_next(stream) + trace("section_label", alt((ident, dot_ident))) + .void() + .parse_next(stream) } fn function<'a, 'input>( @@ -654,13 +662,13 @@ fn kernel_arguments<'a, 'input>( fn kernel_input<'a, 'input>( stream: &mut PtxParser<'a, 'input>, ) -> PResult> { - preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) + preceded(Token::DotParam, method_parameter(StateSpace::Param, true)).parse_next(stream) } fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { dispatch! { any; - (Token::DotParam, _) => method_parameter(StateSpace::Param), - (Token::DotReg, _) => method_parameter(StateSpace::Reg), + (Token::DotParam, _) => method_parameter(StateSpace::Param, false), + (Token::DotReg, _) => method_parameter(StateSpace::Reg, false), _ => fail } .parse_next(stream) @@ -820,11 +828,30 @@ fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { fn method_parameter<'a, 'input: 'a>( state_space: StateSpace, + kernel_decl_rules: bool, ) -> impl Parser, Variable<&'input str>, ContextError> { + fn nvptx_kernel_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<(Option, Option, ScalarType, &'input str)> { + trace( + "nvptx_kernel_declaration", + ( + vector_prefix, + scalar_type, + opt((Token::DotPtr, Token::DotGlobal)), + opt(align.verify(|x| x.count_ones() == 1)), + ident, + ), + ) + .map(|(vector, type_, _, align, name)| (align, vector, type_, name)) + .parse_next(stream) + } trace( "method_parameter", move |stream: &mut PtxParser<'a, 'input>| { - let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + if kernel_decl_rules {} + let (align, vector, type_, name) = + alt((variable_declaration, nvptx_kernel_declaration)).parse_next(stream)?; let array_dimensions = if state_space != StateSpace::Reg { opt(array_dimensions).parse_next(stream)? } else { @@ -1751,10 +1778,12 @@ derive_parser!( DotTarget, #[token(".address_size")] DotAddressSize, - #[token(".action")] + #[token(".section")] DotSection, #[token(".file")] - DotFile + DotFile, + #[token(".ptr")] + DotPtr } #[derive(Copy, Clone, Display, PartialEq, Eq, Hash)] @@ -3775,6 +3804,7 @@ derive_parser!( mod tests { use crate::first_optional; use crate::parse_module_checked; + use crate::section; use crate::PtxError; use super::target; @@ -3997,4 +4027,77 @@ mod tests { PtxError::UnrecognizedDirective(".global .bad_type foo;") )); } + + #[test] + fn dwarf_line() { + let text = " + .section .debug_abbrev + { +.b8 1 // Abbreviation Code +.b8 17 // DW_TAG_compile_unit +.b8 1 // DW_CHILDREN_yes +.b8 37 // DW_AT_producer +.b8 8 // DW_FORM_string +.b8 19 // DW_AT_language +.b8 5 // DW_FORM_data2 +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 16 // DW_AT_stmt_list +.b8 6 // DW_FORM_data4 +.b8 27 // DW_AT_comp_dir +.b8 8 // DW_FORM_string +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 2 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 0 // DW_CHILDREN_no +.b8 3 // DW_AT_name +.b8 8 // DW_FORM_string +.b8 32 // DW_AT_inline +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 3 // Abbreviation Code +.b8 46 // DW_TAG_subprogram +.b8 1 // DW_CHILDREN_yes +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 4 // Abbreviation Code +.b8 29 // DW_TAG_inlined_subroutine +.b8 0 // DW_CHILDREN_no +.b8 49 // DW_AT_abstract_origin +.b8 19 // DW_FORM_ref4 +.b8 17 // DW_AT_low_pc +.b8 1 // DW_FORM_addr +.b8 18 // DW_AT_high_pc +.b8 1 // DW_FORM_addr +.b8 88 // DW_AT_call_file +.b8 11 // DW_FORM_data1 +.b8 89 // DW_AT_call_line +.b8 11 // DW_FORM_data1 +.b8 87 // DW_AT_call_column +.b8 11 // DW_FORM_data1 +.b8 0 // EOM(1) +.b8 0 // EOM(2) +.b8 0 // EOM(3) + } +"; + let tokens = Token::lexer(text) + .map(|t| t.map(|t| (t, Span::default()))) + .collect::, _>>() + .unwrap(); + let mut errors = Vec::new(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(text, &mut errors), + }; + assert!(section.parse(stream).is_ok()); + assert_eq!(errors.len(), 0); + } } diff --git a/ptxas/Cargo.toml b/ptxas/Cargo.toml new file mode 100644 index 0000000..b99ac96 --- /dev/null +++ b/ptxas/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "ptxas" +version = "0.0.0" +authors = ["Andrzej Janik "] +edition = "2021" + +[[bin]] +name = "ptxas" +path = "src/main.rs" + +[dependencies] +comgr = { path = "../comgr" } +ptx = { path = "../ptx" } +ptx_parser = { path = "../ptx_parser" } +hip_runtime-sys = { path = "../ext/hip_runtime-sys" } +bpaf = { version = "0.9.19", features = ["derive"] } + +[package.metadata.zluda] diff --git a/ptxas/src/main.rs b/ptxas/src/main.rs new file mode 100644 index 0000000..d71ca6b --- /dev/null +++ b/ptxas/src/main.rs @@ -0,0 +1,65 @@ +use bpaf::{any, doc::Style, Bpaf, Parser}; +use hip_runtime_sys::{hipDeviceProp_tR0600, hipGetDevicePropertiesR0600}; +use std::{ffi::CStr, mem}; + +#[derive(Debug, Clone, Bpaf)] +#[allow(dead_code)] +#[bpaf(options, version("Cuda compilation tools, release 12.8, V12.8.0"))] +pub struct Options { + #[bpaf(short, long)] + output: String, + warn_on_spills: bool, + #[bpaf(short, long)] + verbose: bool, + #[bpaf(external)] + gpu_name: String, + #[bpaf(long, short('O'), fallback(3))] + opt_level: usize, + #[bpaf(positional)] + input: String, +} + +// #[bpaf(long, long("gpu_name"), fallback_with(default_arch))] +fn gpu_name() -> impl Parser { + any("", move |s: String| { + Some(s.strip_prefix("-arch=")?.to_owned()) + }) + .metavar(&[("-arch=", Style::Literal), ("ARG", Style::Metavar)]) + .anywhere() + .fallback_with(|| Ok::("sm_52".to_string())) +} + +fn main() { + let options = options().run(); + let comgr = comgr::Comgr::new().unwrap(); + unsafe { hip_runtime_sys::hipInit(0) }.unwrap(); + let mut dev_props: hipDeviceProp_tR0600 = unsafe { mem::zeroed() }; + let (gpu_arch, clock_rate) = get_gpu_arch_and_clock_rate(&mut dev_props); + let input = std::fs::read_to_string(options.input).unwrap(); + let ast = ptx_parser::parse_module_checked(&input).unwrap(); + let llvm = ptx::to_llvm_module( + ast, + ptx::Attributes { + clock_rate: clock_rate as u32, + }, + ) + .unwrap(); + let elf_binary = comgr::compile_bitcode( + &comgr, + gpu_arch, + &*llvm.llvm_ir.write_bitcode_to_memory(), + &*llvm.linked_bitcode(), + &*llvm.attributes_ir.write_bitcode_to_memory(), + None, + ) + .unwrap(); + std::fs::write(options.output, elf_binary).unwrap(); +} + +fn get_gpu_arch_and_clock_rate<'a>(dev_props: &'a mut hipDeviceProp_tR0600) -> (&'a str, i32) { + unsafe { hipGetDevicePropertiesR0600(dev_props, 0) }.unwrap(); + let gcn_arch_name = &dev_props.gcnArchName; + let gcn_arch_name = unsafe { CStr::from_ptr(gcn_arch_name.as_ptr()) }; + let gcn_arch_name = gcn_arch_name.to_str(); + (gcn_arch_name.unwrap(), dev_props.clockRate) +}