Use normalize_fn for performance libraries (#449)
Some checks are pending
ZLUDA / Run AMD GPU unit tests (push) Blocked by required conditions
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run
ZLUDA / Build AMD GPU unit tests (push) Waiting to run

The goal here is to make the performance library implementations work more like ZLUDA itself.
This commit is contained in:
Violet 2025-07-30 14:02:01 -07:00 committed by GitHub
commit 98b601d15a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 90 additions and 71 deletions

View file

@ -203,8 +203,11 @@ const MODULES: &[&str] = &[
"stream",
];
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
/// Shared implementation behind the `*_normalize_fn` proc macros.
///
/// Strips `prefix` from the function identifier at the end of the input path,
/// splits the remainder into segments, and re-joins it — prepending
/// `default_module` (when `Some`) unless the path already names a known module.
/// NOTE(review): this is a diff view; body lines between hunks are elided.
fn normalize_fn_impl(
prefix: &str,
default_module: Option<&str>,
tokens: TokenStream,
) -> TokenStream {
// Parse the macro input as a Rust path; the trailing segment is the fn name.
let mut path = parse_macro_input!(tokens as syn::Path);
let fn_ = path
.segments
@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident
.to_string();
// True when the segment before the fn name is already one of MODULES.
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
let fn_path = join(segments, !already_has_module);
let segments: Vec<String> = split(&fn_[prefix.len()..]); // strip the library prefix (e.g. "cu", "cublas")
// Only inject the default module when one is configured and none is present.
let fn_path = join(segments, default_module.filter(|_| !already_has_module));
quote! {
#path #fn_path
}
.into()
}
/// Normalizes a `cu…`-prefixed CUDA driver function name, defaulting to the
/// `driver` module when the input path does not already name a known module.
#[proc_macro]
pub fn cuda_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("cu", Some("driver"), input)
}
/// Normalizes a `cublas…`-prefixed cuBLAS function name; no default module is
/// injected for this library.
#[proc_macro]
pub fn cublas_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("cublas", None, input)
}
/// Normalizes a `cublasLt…`-prefixed cuBLASLt function name; no default module
/// is injected for this library.
#[proc_macro]
pub fn cublaslt_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("cublasLt", None, input)
}
/// Normalizes a `cudnn…`-prefixed cuDNN function name; no default module is
/// injected for this library.
#[proc_macro]
pub fn cudnn_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("cudnn", None, input)
}
/// Normalizes a `cusparse…`-prefixed cuSPARSE function name; no default module
/// is injected for this library.
#[proc_macro]
pub fn cusparse_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("cusparse", None, input)
}
/// Normalizes an `nvml…`-prefixed NVML function name; no default module is
/// injected for this library.
#[proc_macro]
pub fn nvml_normalize_fn(input: TokenStream) -> TokenStream {
    normalize_fn_impl("nvml", None, input)
}
/// Splits the prefix-stripped function name into `String` segments by walking
/// its characters one at a time.
/// NOTE(review): the loop body is elided between hunks in this diff view —
/// presumably it breaks at case boundaries; confirm against the full source.
fn split(fn_: &str) -> Vec<String> {
let mut result = Vec::new();
for c in fn_.chars() {
@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec<String> {
result
}
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
/// Joins the split segments back into a `::`-separated path.
///
/// With `Some(default_module)`: expands abbreviated segments via `full_form`,
/// prepends `default_module` when the first segment is not a known module, and
/// emits one `Ident` per segment. With `None`: emits a single `Ident` joining
/// all segments with underscores (no module qualification).
fn join(
fn_: Vec<String>,
default_module: Option<&str>,
) -> Punctuated<Ident, Token![::]> {
// Maps an abbreviated segment to its expanded spelling(s), if any.
fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment {
"ctx" => &["context"],
@ -253,13 +289,9 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment),
}
}
if !find_module {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
if !MODULES.contains(&normalized[0]) {
let mut globalized = vec!["driver"];
if let Some(default_module) = default_module {
if !MODULES.contains(&normalized[0]) {
let mut globalized = vec![default_module];
globalized.extend(normalized);
normalized = globalized;
}
@ -269,4 +301,10 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
.into_iter()
.map(|s| Ident::new(s, Span::call_site()))
.collect()
// NOTE(review): the explicit `return` below is redundant — the `else` branch
// is the function's tail expression; a plain expression would be idiomatic.
} else {
return [Ident::new(&normalized.join("_"), Span::call_site())]
.into_iter()
.collect();
}
}