mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-05 15:49:25 +00:00
Use normalize_fn
for performance libraries (#449)
The goal here is to make the performance library implementations work more like zluda.
This commit is contained in:
parent
c07d7678cd
commit
98b601d15a
11 changed files with 90 additions and 71 deletions
|
@ -203,8 +203,11 @@ const MODULES: &[&str] = &[
|
|||
"stream",
|
||||
];
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
fn normalize_fn_impl(
|
||||
prefix: &str,
|
||||
default_module: Option<&str>,
|
||||
tokens: TokenStream,
|
||||
) -> TokenStream {
|
||||
let mut path = parse_macro_input!(tokens as syn::Path);
|
||||
let fn_ = path
|
||||
.segments
|
||||
|
@ -215,14 +218,44 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
|||
.ident
|
||||
.to_string();
|
||||
let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
|
||||
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
|
||||
let fn_path = join(segments, !already_has_module);
|
||||
let segments: Vec<String> = split(&fn_[prefix.len()..]); // skip "cu"
|
||||
let fn_path = join(segments, default_module.filter(|_| !already_has_module));
|
||||
quote! {
|
||||
#path #fn_path
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("cu", Some("driver"), tokens)
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("cublas", None, tokens)
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("cublasLt", None, tokens)
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("cudnn", None, tokens)
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("cusparse", None, tokens)
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream {
|
||||
normalize_fn_impl("nvml", None, tokens)
|
||||
}
|
||||
|
||||
fn split(fn_: &str) -> Vec<String> {
|
||||
let mut result = Vec::new();
|
||||
for c in fn_.chars() {
|
||||
|
@ -235,7 +268,10 @@ fn split(fn_: &str) -> Vec<String> {
|
|||
result
|
||||
}
|
||||
|
||||
fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
||||
fn join(
|
||||
fn_: Vec<String>,
|
||||
default_module: Option<&str>,
|
||||
) -> Punctuated<Ident, Token![::]> {
|
||||
fn full_form(segment: &str) -> Option<&[&str]> {
|
||||
Some(match segment {
|
||||
"ctx" => &["context"],
|
||||
|
@ -253,13 +289,9 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
|||
None => normalized.push(&*segment),
|
||||
}
|
||||
}
|
||||
if !find_module {
|
||||
return [Ident::new(&normalized.join("_"), Span::call_site())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
}
|
||||
if !MODULES.contains(&normalized[0]) {
|
||||
let mut globalized = vec!["driver"];
|
||||
if let Some(default_module) = default_module {
|
||||
if !MODULES.contains(&normalized[0]) {
|
||||
let mut globalized = vec![default_module];
|
||||
globalized.extend(normalized);
|
||||
normalized = globalized;
|
||||
}
|
||||
|
@ -269,4 +301,10 @@ fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
|
|||
.into_iter()
|
||||
.map(|s| Ident::new(s, Span::call_site()))
|
||||
.collect()
|
||||
} else {
|
||||
return [Ident::new(&normalized.join("_"), Span::call_site())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue