From 734db223d6b77bcdc3388418738e25895d5c7e36 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 19 Apr 2025 10:32:48 +0000 Subject: [PATCH] Generate cudnn bindgen --- cuda_types/src/lib.rs | 4 +- zluda_bindgen/build/cuda_wrapper.h | 1 + .../build/cudnn_v8/cudnn_adv_infer.h | 1 + .../build/cudnn_v8/cudnn_adv_train.h | 1 + zluda_bindgen/build/cudnn_v8/cudnn_backend.h | 1 + .../build/cudnn_v8/cudnn_cnn_infer.h | 1 + .../build/cudnn_v8/cudnn_cnn_train.h | 1 + .../build/cudnn_v8/cudnn_ops_infer.h | 1 + .../build/cudnn_v8/cudnn_ops_train.h | 1 + zluda_bindgen/build/cudnn_v8/cudnn_version.h | 1 + zluda_bindgen/src/main.rs | 393 ++++++++++++++++-- 11 files changed, 361 insertions(+), 45 deletions(-) create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_backend.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h create mode 100644 zluda_bindgen/build/cudnn_v8/cudnn_version.h diff --git a/cuda_types/src/lib.rs b/cuda_types/src/lib.rs index cd8ce24..e9a4305 100644 --- a/cuda_types/src/lib.rs +++ b/cuda_types/src/lib.rs @@ -1,2 +1,4 @@ pub mod cuda; -pub mod nvml; \ No newline at end of file +pub mod nvml; +pub mod cublas; +pub mod cublaslt; \ No newline at end of file diff --git a/zluda_bindgen/build/cuda_wrapper.h b/zluda_bindgen/build/cuda_wrapper.h index a550256..d10f32e 100644 --- a/zluda_bindgen/build/cuda_wrapper.h +++ b/zluda_bindgen/build/cuda_wrapper.h @@ -5,3 +5,4 @@ #include #include #include +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h b/zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h new file mode 100644 index 0000000..fbd527b --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h b/zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h new file mode 100644 index 0000000..15c97e7 --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_backend.h b/zluda_bindgen/build/cudnn_v8/cudnn_backend.h new file mode 100644 index 0000000..8919805 --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_backend.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h b/zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h new file mode 100644 index 0000000..4933e9d --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h b/zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h new file mode 100644 index 0000000..9921348 --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h b/zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h new file mode 100644 index 0000000..e3aa4d2 --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h b/zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h new file mode 100644 index 0000000..b7a6aad --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/build/cudnn_v8/cudnn_version.h b/zluda_bindgen/build/cudnn_v8/cudnn_version.h new file mode 100644 index 0000000..8887e0a --- /dev/null +++ b/zluda_bindgen/build/cudnn_v8/cudnn_version.h @@ -0,0 +1 @@ +#include diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs index bfa9d49..e2a6f9b 100644 --- a/zluda_bindgen/src/main.rs +++ b/zluda_bindgen/src/main.rs @@ -1,11 +1,13 @@ use proc_macro2::Span; use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; -use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr}; +use std::{ + borrow::Cow, collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr, +}; use syn::{ - parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg, - ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, - PathArguments, Signature, Type, TypePath, UseTree, PathSegment + parse, parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, + FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, + Path, PathArguments, PathSegment, Signature, Type, TypePath, UseTree, }; fn main() { @@ -14,23 +16,324 @@ fn main() { &crate_root, &["..", "ext", "hip_runtime-sys", "src", "lib.rs"], ); - generate_ml(&crate_root); generate_cuda(&crate_root); + generate_ml(&crate_root); + generate_cublas(&crate_root); + generate_cublaslt(&crate_root); + generate_cudnn(&crate_root); +} + +fn generate_cudnn(crate_root: &PathBuf) { + let cudnn9 = new_builder() + .header("/usr/include/x86_64-linux-gnu/cudnn_v9.h") + .allowlist_type("^cudnn.*") + .allowlist_function("^cudnn.*") + .allowlist_var("^CUDNN_.*") + .must_use_type("cudnnStatus_t") + .allowlist_recursively(false) + .clang_args(["-I/usr/local/cuda/include"]) + .generate() + .unwrap() + .to_string(); + let module: syn::File = syn::parse_str(&cudnn9).unwrap(); + generate_functions( + &crate_root, + "cudnn9", + &["..", "cuda_base", "src", "cudnn9.rs"], + &module, + ); + let cudnn9_types = generate_types_library_impl(&module); + let mut current_dir = PathBuf::from(file!()); + current_dir.pop(); + let cudnn8 = new_builder() + .header("/usr/include/x86_64-linux-gnu/cudnn_v8.h") + .allowlist_type("^cudnn.*") + .allowlist_function("^cudnn.*") + .allowlist_var("^CUDNN_.*") + .must_use_type("cudnnStatus_t") + .allowlist_recursively(false) + .clang_args([ + "-I/usr/local/cuda/include", + &format!("-I{}/../build/cudnn_v8", current_dir.display()), + ]) + .generate() + .unwrap() + .to_string(); + let module: syn::File = syn::parse_str(&cudnn8).unwrap(); + generate_functions( + &crate_root, + "cudnn8", + &["..", "cuda_base", "src", "cudnn8.rs"], + &module, + ); + let cudnn8_types = generate_types_library_impl(&module); + merge_types( + &crate_root, + &["..", "cuda_types", "src", "cudnn.rs"], + cudnn9_types, + &["..", "cuda_types", "src", "cudnn9.rs"], + cudnn8_types, + &["..", "cuda_types", "src", "cudnn8.rs"], + ); +} + +// This code splits types (and constants) into one of: +// - cudnn8-specific +// - cudnn9-specific +// - cudnn shared +// With the rules being: +// - constants go to the specific files +// - if there's conflict between types they go to specific files +// - if the cudnn9 type is purely additive over cudnn8 then it goes into the +// shared (and is re-exported) +fn merge_types( + output: &PathBuf, + cudnn_path: &[&str], + cudnn9_types: syn::File, + cudnn9_path: &[&str], + cudnn8_types: syn::File, + cudnn8_path: &[&str], +) { + let cudnn_enums = merge_enums(&cudnn9_types, &cudnn8_types); + let conflicting_types = get_conflicting_structs(&cudnn9_types, &cudnn8_types, cudnn_enums); + let common_items = cudnn9_types.items.iter().filter_map(|item| match item { + Item::Impl(ref impl_) => match conflicting_types.get(type_to_ident(&*impl_.self_ty)) { + Some(CudnnEnumMergeResult::Conflict) => None, + Some(CudnnEnumMergeResult::Same) => { + let item: Item = parse_quote! { + #impl_ + }; + Some(item) + } + Some(CudnnEnumMergeResult::Cudnn9) => None, + None => None, + }, + Item::Struct(ref struct_) => match conflicting_types.get(&struct_.ident) { + Some(CudnnEnumMergeResult::Conflict) => None, + Some(CudnnEnumMergeResult::Same) => { + let item: Item = parse_quote! { + #struct_ + }; + Some(item) + } + Some(CudnnEnumMergeResult::Cudnn9) => None, + None => None, + }, + Item::Enum(ref enum_) => match conflicting_types.get(&enum_.ident) { + Some(CudnnEnumMergeResult::Conflict) => None, + Some(CudnnEnumMergeResult::Same) => { + let item: Item = parse_quote! { + #enum_ + }; + Some(item) + } + Some(CudnnEnumMergeResult::Cudnn9) => None, + None => None, + }, + Item::ForeignMod(ItemForeignMod { .. }) => None, + _ => None, + //_ => unimplemented!(), + }); + + let file: syn::File = parse_quote! { + #(#common_items)* + }; + { + let mut output = output.clone(); + output.extend(cudnn_path); + let text = prettyplease::unparse(&file); + write_rust_to_file(output, &text) + } +} + +fn get_conflicting_structs<'a>( + cudnn9_types: &'a syn::File, + cudnn8_types: &'a syn::File, + mut enums: FxHashMap<&'a Ident, CudnnEnumMergeResult>, +) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> { + let structs9 = get_structs(cudnn9_types); + let structs8 = get_structs(cudnn8_types); + for (struct_name8, struct8) in structs8 { + if enums.contains_key(struct_name8) { + continue; + } + match structs9.get(struct_name8) { + Some(struct9) => { + if struct8 != *struct9 { + panic!("{}", struct_name8.to_string()); + } + let has_conflicting_field = struct8.iter().any(|field| { + let type_ = type_to_ident(&field.ty); + enums.get(type_) == Some(&CudnnEnumMergeResult::Conflict) + }); + let value = if has_conflicting_field { + CudnnEnumMergeResult::Conflict + } else { + CudnnEnumMergeResult::Same + }; + assert!(enums.insert(struct_name8, value).is_none()); + } + None => {} + } + } + enums +} + +fn type_to_ident<'a>(ty: &'a syn::Type) -> &'a syn::Ident { + match ty { + Type::Path(path) => &path.path.segments[0].ident, + Type::Array(array) => type_to_ident(&array.elem), + _ => unimplemented!("{}", ty.to_token_stream().to_string()), + } +} + +fn merge_enums<'a>( + cudnn9_types: &'a syn::File, + cudnn8_types: &'a syn::File, +) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> { + let result = { + let enums8 = get_enums(cudnn8_types); + let enums9 = get_enums(cudnn9_types); + enums8 + .iter() + .map(|(enum8_ident, enum8_vars)| { + let merge_result = match enums9.get(enum8_ident) { + Some(enum9_vars) => { + let e8_has_extra = enum8_vars.difference(&enum9_vars).any(|_| true); + let e9_has_extra = enum9_vars.difference(&enum8_vars).any(|_| true); + match (e8_has_extra, e9_has_extra) { + (false, false) => CudnnEnumMergeResult::Same, + (false, true) => CudnnEnumMergeResult::Cudnn9, + (true, true) => CudnnEnumMergeResult::Conflict, + (true, false) => unimplemented!(), + } + } + None => { + unimplemented!() + } + }; + (*enum8_ident, merge_result) + }) + .collect::>() + }; + result +} + +#[derive(Copy, Clone, PartialEq, Eq)] +enum CudnnEnumMergeResult { + // Conflicting definitions + Conflict, + // Identical definitions + Same, + // Enum present in both, but cudnn9 definition is a strict superset + Cudnn9, +} + +fn get_enums<'a>( + cudnn_module: &'a syn::File, +) -> FxHashMap<&'a Ident, FxHashSet<&'a syn::ImplItemConst>> { + let mut enums = FxHashMap::default(); + for item in cudnn_module.items.iter() { + match item { + Item::Impl(ref impl_) => match &*impl_.self_ty { + Type::Path(path) => { + let constant = match impl_.items[0] { + syn::ImplItem::Const(ref impl_item_const) => impl_item_const, + _ => unimplemented!(), + }; + enums + .entry(&path.path.segments[0].ident) + .or_insert(FxHashSet::default()) + .insert(constant); + } + _ => unimplemented!(), + }, + _ => {} + } + } + enums +} + +fn get_structs<'a>(cudnn_module: &'a syn::File) -> FxHashMap<&'a Ident, Cow<'a, syn::Fields>> { + let mut structs = FxHashMap::default(); + for item in cudnn_module.items.iter() { + match item { + Item::Struct(ref struct_) => { + assert!(structs + .insert(&struct_.ident, Cow::Borrowed(&struct_.fields)) + .is_none()); + } + Item::Union(ref union_) => { + assert!(structs + .insert( + &union_.ident, + Cow::Owned(syn::Fields::Named(union_.fields.clone())) + ) + .is_none()); + } + _ => {} + } + } + structs +} + +fn generate_cublas(crate_root: &PathBuf) { + let cublas_header = new_builder() + .header("/usr/local/cuda/include/cublas_v2.h") + .allowlist_type("^cublas.*") + .allowlist_function("^cublas.*") + .allowlist_var("^CUBLAS_.*") + .must_use_type("cublasStatus_t") + .allowlist_recursively(false) + .clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) + .generate() + .unwrap() + .to_string(); + let module: syn::File = syn::parse_str(&cublas_header).unwrap(); + generate_functions( + &crate_root, + "cublas", + &["..", "cuda_base", "src", "cublas.rs"], + &module, + ); + generate_types_library( + &crate_root, + &["..", "cuda_types", "src", "cublas.rs"], + &module, + ) +} + +fn generate_cublaslt(crate_root: &PathBuf) { + let cublas_header = new_builder() + .header("/usr/local/cuda/include/cublasLt.h") + .allowlist_type("^cublas.*") + .allowlist_function("^cublasLt.*") + .allowlist_var("^CUBLASLT_.*") + .must_use_type("cublasStatus_t") + .allowlist_recursively(false) + .clang_args(["-I/usr/local/cuda/include", "-x", "c++"]) + .generate() + .unwrap() + .to_string(); + let module: syn::File = syn::parse_str(&cublas_header).unwrap(); + generate_functions( + &crate_root, + "cublaslt", + &["..", "cuda_base", "src", "cublaslt.rs"], + &module, + ); + generate_types_library( + &crate_root, + &["..", "cuda_types", "src", "cublaslt.rs"], + &module, + ) } fn generate_cuda(crate_root: &PathBuf) { - let cuda_header = bindgen::Builder::default() - .use_core() - .rust_target(bindgen::RustTarget::Stable_1_77) - .layout_tests(false) - .default_enum_style(bindgen::EnumVariation::NewType { - is_bitfield: false, - is_global: false, - }) - .derive_hash(true) - .derive_eq(true) + let cuda_header = new_builder() .header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h")) .allowlist_type("^CU.*") + .allowlist_type("^cuda.*") .allowlist_function("^cu.*") .allowlist_var("^CU.*") .must_use_type("cudaError_enum") @@ -67,22 +370,14 @@ fn generate_cuda(crate_root: &PathBuf) { } fn generate_ml(crate_root: &PathBuf) { - let ml_header = bindgen::Builder::default() - .use_core() - .rust_target(bindgen::RustTarget::Stable_1_77) - .layout_tests(false) - .default_enum_style(bindgen::EnumVariation::NewType { - is_bitfield: false, - is_global: false, - }) - .derive_hash(true) - .derive_eq(true) + let ml_header = new_builder() .header("/usr/local/cuda/include/nvml.h") .allowlist_type("^nvml.*") .allowlist_function("^nvml.*") .allowlist_var("^NVML.*") .must_use_type("nvmlReturn_t") .constified_enum("nvmlReturn_enum") + .clang_args(["-I/usr/local/cuda/include"]) .generate() .unwrap() .to_string(); @@ -112,37 +407,34 @@ fn generate_ml(crate_root: &PathBuf) { &["..", "cuda_base", "src", "nvml.rs"], &module, ); - generate_types( + generate_types_library( &crate_root, &["..", "cuda_types", "src", "nvml.rs"], &module, ); } -fn generate_types(crate_root: &PathBuf, path: &[&str], module: &syn::File) { +fn generate_types_library(crate_root: &PathBuf, path: &[&str], module: &syn::File) { + let module = generate_types_library_impl(module); + let mut output = crate_root.clone(); + output.extend(path); + let text = + prettyplease::unparse(&module).replace("self::cudaDataType", "super::cuda::cudaDataType"); + write_rust_to_file(output, &text) +} + +fn generate_types_library_impl(module: &syn::File) -> syn::File { let non_fn = module.items.iter().filter_map(|item| match item { Item::ForeignMod(_) => None, _ => Some(item), }); - let module: syn::File = parse_quote! { + parse_quote! { #(#non_fn)* - }; - let mut output = crate_root.clone(); - output.extend(path); - write_rust_to_file(output, &prettyplease::unparse(&module)) + } } fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { - let hiprt_header = bindgen::Builder::default() - .use_core() - .rust_target(bindgen::RustTarget::Stable_1_77) - .layout_tests(false) - .default_enum_style(bindgen::EnumVariation::NewType { - is_bitfield: false, - is_global: false, - }) - .derive_hash(true) - .derive_eq(true) + let hiprt_header = new_builder() .header("/opt/rocm/include/hip/hip_runtime_api.h") .allowlist_type("^hip.*") .allowlist_function("^hip.*") @@ -426,7 +718,7 @@ struct ExplicitReturnType; impl VisitMut for ExplicitReturnType { fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) { if let syn::ReturnType::Default = i { - *i = parse_quote! { -> {} }; + *i = parse_quote! { -> () }; } } } @@ -563,7 +855,7 @@ fn cuda_derive_display_trait_for_item<'a>( state: &mut DeriveDisplayState<'a>, item: &'a Item, ) -> Option { - let path_prefix = & state.types_crate; + let path_prefix = &state.types_crate; let path_prefix_iter = iter::repeat(&path_prefix); let mut prepend_path = PrependCudaPath { module: Ident::new("cuda", Span::call_site()), @@ -798,3 +1090,16 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item { } } } + +fn new_builder() -> bindgen::Builder { + bindgen::Builder::default() + .use_core() + .rust_target(bindgen::RustTarget::Stable_1_77) + .layout_tests(false) + .default_enum_style(bindgen::EnumVariation::NewType { + is_bitfield: false, + is_global: false, + }) + .derive_hash(true) + .derive_eq(true) +}