// Procedural macros that turn bundled binding sources into macro invocations,
// map vendor-style function names to `snake_case` paths, and generate paired tests.
//
// NOTE(review): throughout this file several generic argument lists appear to be
// missing (e.g. `syn::parse_str::(module)`, `Punctuated,`, `input.parse::()`,
// `FxHashMap>`), as if everything between `<` and `>` had been stripped. Each
// affected spot is marked below; restore them from the pristine source before
// relying on this file.

extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote, ToTokens};
use rustc_hash::FxHashMap;
use std::iter;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::visit_mut::VisitMut;
use syn::{
    bracketed, parse_macro_input, File, ForeignItem, ForeignItemFn, Ident, Item, Path, Signature,
    Token
};

// Text of each binding file, embedded at compile time and parsed on every macro call.
const CUDA_RS: &'static str = include_str! {"cuda.rs"};
const NVML_RS: &'static str = include_str! {"nvml.rs"};
const CUBLAS_RS: &'static str = include_str! {"cublas.rs"};
const CUBLASLT_RS: &'static str = include_str! {"cublaslt.rs"};
const CUBLASLT_INTERNAL_RS: &'static str = include_str! {"cublaslt_internal.rs"};
const CUFFT_RS: &'static str = include_str! {"cufft.rs"};
const CUSPARSE_RS: &'static str = include_str! {"cusparse.rs"};
const CUDNN9_RS: &'static str = include_str! {"cudnn9.rs"};

// Each `*_function_declarations` macro accepts the following arguments:
// * `normal_macro`: path of the default macro
// * zero or more:
//     * `override_macro`: path of an override macro
//     * `override_fns`: list of functions routed to that override macro
// The macro then goes through every function in the embedded binding file
// (e.g. cuda.rs), and for every fn `foo`:
// * if `foo` is contained in some `override_fns`, it is passed to the matching `override_macro`
// * otherwise it is passed to `normal_macro`
// Both `override_macro` and `normal_macro` expect a semicolon-separated list:
//     macro_foo!(
//         "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult;
//         "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult
//     )
// Additionally, it does a fixup of CUDA types so they get prefixed with `type_path`
// NOTE(review): no `type_path` handling is visible in this file; the sentence above
// may be outdated - confirm.

/// Expands the declarations embedded from `cuda.rs`.
#[proc_macro]
pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUDA_RS)
}

/// Expands the declarations embedded from `cublas.rs`.
#[proc_macro]
pub fn cublas_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUBLAS_RS)
}

/// Expands the declarations embedded from `cublaslt.rs`.
#[proc_macro]
pub fn cublaslt_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUBLASLT_RS)
}

/// Expands the declarations embedded from `cublaslt_internal.rs`.
#[proc_macro]
pub fn cublaslt_internal_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUBLASLT_INTERNAL_RS)
}

/// Expands the declarations embedded from `cufft.rs`.
#[proc_macro]
pub fn cufft_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUFFT_RS)
}

/// Expands the declarations embedded from `cusparse.rs`.
#[proc_macro]
pub fn cusparse_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUSPARSE_RS)
}

/// Expands the declarations embedded from `cudnn9.rs`.
#[proc_macro]
pub fn cudnn9_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, CUDNN9_RS)
}

/// Shared implementation of every `*_function_declarations` macro.
///
/// `tokens` is the macro input (see [`FnDeclInput`]); `module` is the source text of
/// a binding file. Every foreign function found in `module` is routed to either the
/// default macro or one of the override macros, and one `path! { ... }` invocation is
/// emitted per macro that received at least one function.
///
/// # Panics
/// Panics if `module` fails to parse, or (via `unreachable!`) if it contains a
/// top-level item that is not an `extern` block or a foreign item that is not a function.
fn function_declarations(tokens: TokenStream, module: &str) -> TokenStream {
    let input = parse_macro_input!(tokens as FnDeclInput);
    // NOTE(review): the turbofish type is missing; `visit_file_mut` below and the
    // `.items` access suggest it was `File` - confirm.
    let mut cuda_module = syn::parse_str::(module).unwrap();
    let mut choose_macro = ChooseMacro::new(input);
    // Strip trailing punctuation from every signature before re-emitting it.
    syn::visit_mut::visit_file_mut(&mut FixFnSignatures, &mut cuda_module);
    for item in cuda_module.items {
        let extern_ = if let Item::ForeignMod(extern_) = item {
            extern_
        } else {
            unreachable!()
        };
        let abi = extern_.abi.name;
        for mut item in extern_.items {
            if let ForeignItem::Fn(ForeignItemFn {
                sig: Signature { ref ident, .. },
                ref mut attrs,
                ..
            }) = item
            {
                // Attributes on the declaration are discarded before it is emitted.
                *attrs = Vec::new();
                // Each entry is the ABI name followed by the whole foreign fn item.
                choose_macro.add(ident, quote! { #abi #item });
            } else {
                unreachable!()
            }
        }
    }
    let mut result = proc_macro2::TokenStream::new();
    // The default macro is emitted first; override macros follow in hash-map
    // iteration order.
    for (path, items) in
        iter::once(choose_macro.default).chain(choose_macro.override_sets.into_iter())
    {
        // Do not emit an invocation for a macro that received no functions.
        if items.is_empty() {
            continue;
        }
        quote! {
            #path ! { #(#items)* }
        }
        .to_tokens(&mut result);
    }
    result.into()
}

/// Expands the declarations embedded from `nvml.rs`.
#[proc_macro]
pub fn nvml_function_declarations(tokens: TokenStream) -> TokenStream {
    function_declarations(tokens, NVML_RS)
}

/// Parsed input of a `*_function_declarations` macro: the default macro path
/// followed by zero or more override entries.
struct FnDeclInput {
    normal_macro: Path,
    // NOTE(review): generic arguments missing; `parse_terminated(OverrideMacro::parse,
    // Token![,])` below suggests a comma-punctuated list of `OverrideMacro` - confirm.
    overrides: Punctuated,
}

impl Parse for FnDeclInput {
    // NOTE(review): the `syn::Result` type argument is missing (the body returns
    // `Ok(Self { .. })`).
    fn parse(input: ParseStream) -> syn::Result {
        // NOTE(review): turbofish type missing; the value is stored in a `Path` field.
        let normal_macro = input.parse::()?;
        let overrides = if input.is_empty() {
            Punctuated::new()
        } else {
            // Consumes the separator between the default macro and the overrides.
            // NOTE(review): the token type is missing here - confirm against upstream.
            input.parse::()?;
            input.parse_terminated(OverrideMacro::parse, Token![,])?
        };
        Ok(Self {
            normal_macro,
            overrides,
        })
    }
}

/// One override entry: a macro path plus the bracketed list of function names
/// that should be routed to it.
struct OverrideMacro {
    macro_: Path,
    // NOTE(review): generic arguments missing; `parse_terminated(Ident::parse,
    // Token![,])` below suggests a comma-punctuated list of `Ident` - confirm.
    functions: Punctuated,
}

impl Parse for OverrideMacro {
    // NOTE(review): the `syn::Result` type argument is missing.
    fn parse(input: ParseStream) -> syn::Result {
        // NOTE(review): turbofish type missing; the value is stored in a `Path` field.
        let macro_ = input.parse::()?;
        // Consumes the token separating the macro path from the `[...]` list.
        // NOTE(review): the token type cannot be determined from this file - confirm.
        input.parse::()?;
        let functions_content;
        bracketed!(functions_content in input);
        let functions = functions_content.parse_terminated(Ident::parse, Token![,])?;
        Ok(Self { macro_, functions })
    }
}

/// Routes each function's tokens to the default macro or to an override macro.
// NOTE(review): the generic arguments of all three fields are missing. From `new`
// and `add` below: `default.1` and the values of `override_sets` collect
// `proc_macro2::TokenStream`s, `override_lookup` maps `Ident` to `Path`, and
// `override_sets` is keyed by `Path` - confirm.
struct ChooseMacro {
    default: (Path, Vec),
    override_lookup: FxHashMap,
    override_sets: FxHashMap>,
}

impl ChooseMacro {
    /// Builds the function-name lookup and an empty token bucket per override macro.
    fn new(input: FnDeclInput) -> Self {
        let mut override_lookup = FxHashMap::default();
        let mut override_sets = FxHashMap::default();
        for OverrideMacro { macro_, functions } in input.overrides {
            for ident in functions {
                // A name listed under several macros ends up mapped to the last one.
                override_lookup.insert(ident, macro_.clone());
                // Re-inserting an empty Vec per function is harmless at this point.
                // A macro with an empty function list never gets a bucket, which is
                // fine because no function can be routed to it.
                override_sets.insert(macro_.clone(), Vec::new());
            }
        }
        Self {
            default: (input.normal_macro, Vec::new()),
            override_lookup,
            override_sets,
        }
    }

    /// Appends `tokens` to the bucket of the macro responsible for `ident`.
    fn add(&mut self, ident: &Ident, tokens: proc_macro2::TokenStream) {
        match self.override_lookup.get(ident) {
            Some(override_macro) => {
                // Cannot fail: `new` inserts a bucket for every macro it puts
                // into `override_lookup`.
                self.override_sets
                    .get_mut(override_macro)
                    .unwrap()
                    .push(tokens);
            }
            None => self.default.1.push(tokens),
        }
    }
}

// For some reason prettyplease will append a trailing comma *only*
// if there are two or more arguments
/// Visitor that removes trailing punctuation from every function signature's inputs.
struct FixFnSignatures;

impl VisitMut for FixFnSignatures {
    fn visit_signature_mut(&mut self, s: &mut syn::Signature) {
        s.inputs.pop_punct();
    }
}

// Module names recognized by `normalize_fn_impl` (as the segment preceding the
// function name) and by `join` (as the first word of a normalized name).
const MODULES: &[&str] = &[
    "context", "device", "driver", "event", "function", "graph", "kernel", "library", "link",
    "memory", "module", "pointer", "stream"
];

/// Shared implementation of the `*_normalize_fn` macros.
///
/// `tokens` must parse as a path whose last segment is a vendor-style function
/// name. That segment has its first `prefix.len()` bytes dropped, is split into
/// lowercase words, and is re-emitted after the remaining path. For example,
/// tracing `cuCtxDetach` with prefix `"cu"` and default module `"driver"` through
/// `split` and `join` gives the segments `context::detach`.
///
/// `default_module` is only applied when the segment preceding the function name
/// is not already one of `MODULES`.
///
/// # Panics
/// NOTE(review): several inputs panic instead of yielding a compile error:
/// * a path with a single segment (`path.segments.last().unwrap()` finds nothing),
/// * a function name shorter than `prefix`, or one where `prefix.len()` is not a
///   char boundary (string slicing),
/// * a name whose first character after the prefix is not ASCII uppercase (see `split`).
/// The name is also never checked to actually start with `prefix`.
fn normalize_fn_impl(
    prefix: &str,
    default_module: Option<&str>,
    tokens: TokenStream,
) -> TokenStream {
    let mut path = parse_macro_input!(tokens as syn::Path);
    // Detach the final segment; `path` now holds only the leading part.
    let fn_ = path
        .segments
        .pop()
        .unwrap()
        .into_tuple()
        .0
        .ident
        .to_string();
    let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
    // NOTE(review): `Vec` is missing its element type; `split` pushes `String`s.
    let segments: Vec = split(&fn_[prefix.len()..]); // skip the library prefix, e.g. "cu"
    let fn_path = join(segments, default_module.filter(|_| !already_has_module));
    quote! {
        #path #fn_path
    }
    .into()
}

/// Normalizes a `cu*` function name; names default to the `driver` module.
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("cu", Some("driver"), tokens)
}

/// Normalizes a `cublas*` function name into a single identifier.
#[proc_macro]
pub fn cublas_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("cublas", None, tokens)
}

/// Normalizes a `cublasLt*` function name into a single identifier.
#[proc_macro]
pub fn cublaslt_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("cublasLt", None, tokens)
}

/// Normalizes a `cudnn*` function name into a single identifier.
#[proc_macro]
pub fn cudnn_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("cudnn", None, tokens)
}

/// Normalizes a `cusparse*` function name into a single identifier.
#[proc_macro]
pub fn cusparse_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("cusparse", None, tokens)
}

/// Normalizes an `nvml*` function name into a single identifier.
#[proc_macro]
pub fn nvml_normalize_fn(tokens: TokenStream) -> TokenStream {
    normalize_fn_impl("nvml", None, tokens)
}

/// Splits a camel-case name into lowercase words.
///
/// Every ASCII uppercase letter starts a new word (stored lowercased); any other
/// character is appended to the current word. Consecutive uppercase letters
/// therefore each start their own word.
///
/// NOTE(review): panics on `unwrap` if the first character is not ASCII uppercase,
/// because there is no current word to append to. An empty input yields an empty list.
// NOTE(review): the return type is missing its element type; the body builds `String`s.
fn split(fn_: &str) -> Vec {
    let mut result = Vec::new();
    for c in fn_.chars() {
        if c.is_ascii_uppercase() {
            result.push(c.to_ascii_lowercase().to_string());
        } else {
            result.last_mut().unwrap().push(c);
        }
    }
    result
}

/// Builds the normalized identifier(s) from the words produced by `split`.
///
/// Known abbreviations are expanded first (`ctx` -> `context`, `memcpy` ->
/// `memory`, `copy`, ...). Then:
/// * with a `default_module`: the first word is used as the module if it is one of
///   `MODULES`, otherwise `default_module` is prepended; the result is two
///   identifiers, the module and the remaining words joined with `_`;
/// * without one: the result is a single identifier of all words joined with `_`.
///
/// NOTE(review): with a `default_module`, an empty `fn_` panics on `normalized[0]`.
// NOTE(review): generic arguments are missing on the `fn_` parameter (elements are
// used as strings) and on the returned `Punctuated` (it collects `Ident`s; the
// separator type is not visible here) - confirm.
fn join(
    fn_: Vec,
    default_module: Option<&str>,
) -> Punctuated {
    // Maps an abbreviated word to its expanded word(s); `None` means "keep as is".
    fn full_form(segment: &str) -> Option<&[&str]> {
        Some(match segment {
            "ctx" => &["context"],
            "func" => &["function"],
            "mem" => &["memory"],
            "memcpy" => &["memory", "copy"],
            "memset" => &["memory", "set"],
            _ => return None,
        })
    }
    let mut normalized: Vec<&str> = Vec::new();
    for segment in fn_.iter() {
        match full_form(segment) {
            Some(segments) => normalized.extend(segments.into_iter()),
            None => normalized.push(&*segment),
        }
    }
    if let Some(default_module) = default_module {
        if !MODULES.contains(&normalized[0]) {
            // The name does not begin with a module word: put it in the default module.
            let mut globalized = vec![default_module];
            globalized.extend(normalized);
            normalized = globalized;
        }
        let (module, path) = normalized.split_first().unwrap();
        let path = path.join("_");
        // NOTE(review): if only a module word is present, `path` is empty here;
        // verify how `Ident::new` treats an empty string.
        [module, &&*path]
            .into_iter()
            .map(|s| Ident::new(s, Span::call_site()))
            .collect()
    } else {
        return [Ident::new(&normalized.join("_"), Span::call_site())]
            .into_iter()
            .collect();
    }
}

/// Attribute macro that keeps the annotated function and adds two `#[test]`
/// wrappers named `<fn>_nvidia` and `<fn>_zluda`, each calling it inside `unsafe`.
///
/// The attribute's own arguments (`_attr`) are ignored.
///
/// NOTE(review): both wrappers pass `::new()` with nothing in front of the `::`;
/// the type expressions that distinguish the two tests appear to be missing and
/// cannot be recovered from this file - restore from upstream.
#[proc_macro_attribute]
pub fn test_cuda(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let fn_ = parse_macro_input!(item as syn::ItemFn);
    let cuda_fn = format_ident!("{}{}", fn_.sig.ident, "_nvidia");
    let zluda_fn = format_ident!("{}{}", fn_.sig.ident, "_zluda");
    let fn_name = fn_.sig.ident.clone();
    quote! {
        #[test]
        fn #cuda_fn() {
            unsafe { #fn_name(::new()) }
        }
        #[test]
        fn #zluda_fn() {
            unsafe { #fn_name(::new()) }
        }
        #fn_
    }.into()
}