Generate cudnn bindgen

This commit is contained in:
Andrzej Janik 2025-04-19 10:32:48 +00:00
parent 7cdab7abc2
commit 734db223d6
11 changed files with 361 additions and 45 deletions

View file

@ -1,2 +1,4 @@
pub mod cuda;
pub mod nvml;
pub mod nvml;
pub mod cublas;
pub mod cublaslt;

View file

@ -5,3 +5,4 @@
#include <cudaEGL.h>
#include <vdpau/vdpau.h>
#include <cudaVDPAU.h>
#include <library_types.h>

View file

@ -0,0 +1 @@
#include <cudnn_adv_infer_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_adv_train_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_backend_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_cnn_infer_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_cnn_train_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_ops_infer_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_ops_train_v8.h>

View file

@ -0,0 +1 @@
#include <cudnn_version_v8.h>

View file

@ -1,11 +1,13 @@
use proc_macro2::Span;
use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr};
use std::{
borrow::Cow, collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr,
};
use syn::{
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg,
ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path,
PathArguments, Signature, Type, TypePath, UseTree, PathSegment
parse, parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed,
FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr,
Path, PathArguments, PathSegment, Signature, Type, TypePath, UseTree,
};
fn main() {
@ -14,23 +16,324 @@ fn main() {
&crate_root,
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
);
generate_ml(&crate_root);
generate_cuda(&crate_root);
generate_ml(&crate_root);
generate_cublas(&crate_root);
generate_cublaslt(&crate_root);
generate_cudnn(&crate_root);
}
fn generate_cudnn(crate_root: &PathBuf) {
let cudnn9 = new_builder()
.header("/usr/include/x86_64-linux-gnu/cudnn_v9.h")
.allowlist_type("^cudnn.*")
.allowlist_function("^cudnn.*")
.allowlist_var("^CUDNN_.*")
.must_use_type("cudnnStatus_t")
.allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include"])
.generate()
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cudnn9).unwrap();
generate_functions(
&crate_root,
"cudnn9",
&["..", "cuda_base", "src", "cudnn9.rs"],
&module,
);
let cudnn9_types = generate_types_library_impl(&module);
let mut current_dir = PathBuf::from(file!());
current_dir.pop();
let cudnn8 = new_builder()
.header("/usr/include/x86_64-linux-gnu/cudnn_v8.h")
.allowlist_type("^cudnn.*")
.allowlist_function("^cudnn.*")
.allowlist_var("^CUDNN_.*")
.must_use_type("cudnnStatus_t")
.allowlist_recursively(false)
.clang_args([
"-I/usr/local/cuda/include",
&format!("-I{}/../build/cudnn_v8", current_dir.display()),
])
.generate()
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cudnn8).unwrap();
generate_functions(
&crate_root,
"cudnn8",
&["..", "cuda_base", "src", "cudnn8.rs"],
&module,
);
let cudnn8_types = generate_types_library_impl(&module);
merge_types(
&crate_root,
&["..", "cuda_types", "src", "cudnn.rs"],
cudnn9_types,
&["..", "cuda_types", "src", "cudnn9.rs"],
cudnn8_types,
&["..", "cuda_types", "src", "cudnn8.rs"],
);
}
// This code splits types (and constants) into one of:
// - cudnn8-specific
// - cudnn9-specific
// - cudnn shared
// With the rules being:
// - constants go to the specific files
// - if there's conflict between types they go to specific files
// - if the cudnn9 type is purely additive over cudnn8 then it goes into the
// shared (and is re-exported)
fn merge_types(
output: &PathBuf,
cudnn_path: &[&str],
cudnn9_types: syn::File,
cudnn9_path: &[&str],
cudnn8_types: syn::File,
cudnn8_path: &[&str],
) {
let cudnn_enums = merge_enums(&cudnn9_types, &cudnn8_types);
let conflicting_types = get_conflicting_structs(&cudnn9_types, &cudnn8_types, cudnn_enums);
let common_items = cudnn9_types.items.iter().filter_map(|item| match item {
Item::Impl(ref impl_) => match conflicting_types.get(type_to_ident(&*impl_.self_ty)) {
Some(CudnnEnumMergeResult::Conflict) => None,
Some(CudnnEnumMergeResult::Same) => {
let item: Item = parse_quote! {
#impl_
};
Some(item)
}
Some(CudnnEnumMergeResult::Cudnn9) => None,
None => None,
},
Item::Struct(ref struct_) => match conflicting_types.get(&struct_.ident) {
Some(CudnnEnumMergeResult::Conflict) => None,
Some(CudnnEnumMergeResult::Same) => {
let item: Item = parse_quote! {
#struct_
};
Some(item)
}
Some(CudnnEnumMergeResult::Cudnn9) => None,
None => None,
},
Item::Enum(ref enum_) => match conflicting_types.get(&enum_.ident) {
Some(CudnnEnumMergeResult::Conflict) => None,
Some(CudnnEnumMergeResult::Same) => {
let item: Item = parse_quote! {
#enum_
};
Some(item)
}
Some(CudnnEnumMergeResult::Cudnn9) => None,
None => None,
},
Item::ForeignMod(ItemForeignMod { .. }) => None,
_ => None,
//_ => unimplemented!(),
});
let file: syn::File = parse_quote! {
#(#common_items)*
};
{
let mut output = output.clone();
output.extend(cudnn_path);
let text = prettyplease::unparse(&file);
write_rust_to_file(output, &text)
}
}
fn get_conflicting_structs<'a>(
cudnn9_types: &'a syn::File,
cudnn8_types: &'a syn::File,
mut enums: FxHashMap<&'a Ident, CudnnEnumMergeResult>,
) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> {
let structs9 = get_structs(cudnn9_types);
let structs8 = get_structs(cudnn8_types);
for (struct_name8, struct8) in structs8 {
if enums.contains_key(struct_name8) {
continue;
}
match structs9.get(struct_name8) {
Some(struct9) => {
if struct8 != *struct9 {
panic!("{}", struct_name8.to_string());
}
let has_conflicting_field = struct8.iter().any(|field| {
let type_ = type_to_ident(&field.ty);
enums.get(type_) == Some(&CudnnEnumMergeResult::Conflict)
});
let value = if has_conflicting_field {
CudnnEnumMergeResult::Conflict
} else {
CudnnEnumMergeResult::Same
};
assert!(enums.insert(struct_name8, value).is_none());
}
None => {}
}
}
enums
}
fn type_to_ident<'a>(ty: &'a syn::Type) -> &'a syn::Ident {
match ty {
Type::Path(path) => &path.path.segments[0].ident,
Type::Array(array) => type_to_ident(&array.elem),
_ => unimplemented!("{}", ty.to_token_stream().to_string()),
}
}
fn merge_enums<'a>(
cudnn9_types: &'a syn::File,
cudnn8_types: &'a syn::File,
) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> {
let result = {
let enums8 = get_enums(cudnn8_types);
let enums9 = get_enums(cudnn9_types);
enums8
.iter()
.map(|(enum8_ident, enum8_vars)| {
let merge_result = match enums9.get(enum8_ident) {
Some(enum9_vars) => {
let e8_has_extra = enum8_vars.difference(&enum9_vars).any(|_| true);
let e9_has_extra = enum9_vars.difference(&enum8_vars).any(|_| true);
match (e8_has_extra, e9_has_extra) {
(false, false) => CudnnEnumMergeResult::Same,
(false, true) => CudnnEnumMergeResult::Cudnn9,
(true, true) => CudnnEnumMergeResult::Conflict,
(true, false) => unimplemented!(),
}
}
None => {
unimplemented!()
}
};
(*enum8_ident, merge_result)
})
.collect::<FxHashMap<_, _>>()
};
result
}
#[derive(Copy, Clone, PartialEq, Eq)]
enum CudnnEnumMergeResult {
// Conflicting definitions
Conflict,
// Identical definitions
Same,
// Enum present in both, but cudnn9 definition is a strict superset
Cudnn9,
}
fn get_enums<'a>(
cudnn_module: &'a syn::File,
) -> FxHashMap<&'a Ident, FxHashSet<&'a syn::ImplItemConst>> {
let mut enums = FxHashMap::default();
for item in cudnn_module.items.iter() {
match item {
Item::Impl(ref impl_) => match &*impl_.self_ty {
Type::Path(path) => {
let constant = match impl_.items[0] {
syn::ImplItem::Const(ref impl_item_const) => impl_item_const,
_ => unimplemented!(),
};
enums
.entry(&path.path.segments[0].ident)
.or_insert(FxHashSet::default())
.insert(constant);
}
_ => unimplemented!(),
},
_ => {}
}
}
enums
}
fn get_structs<'a>(cudnn_module: &'a syn::File) -> FxHashMap<&'a Ident, Cow<'a, syn::Fields>> {
let mut structs = FxHashMap::default();
for item in cudnn_module.items.iter() {
match item {
Item::Struct(ref struct_) => {
assert!(structs
.insert(&struct_.ident, Cow::Borrowed(&struct_.fields))
.is_none());
}
Item::Union(ref union_) => {
assert!(structs
.insert(
&union_.ident,
Cow::Owned(syn::Fields::Named(union_.fields.clone()))
)
.is_none());
}
_ => {}
}
}
structs
}
fn generate_cublas(crate_root: &PathBuf) {
let cublas_header = new_builder()
.header("/usr/local/cuda/include/cublas_v2.h")
.allowlist_type("^cublas.*")
.allowlist_function("^cublas.*")
.allowlist_var("^CUBLAS_.*")
.must_use_type("cublasStatus_t")
.allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
.generate()
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cublas_header).unwrap();
generate_functions(
&crate_root,
"cublas",
&["..", "cuda_base", "src", "cublas.rs"],
&module,
);
generate_types_library(
&crate_root,
&["..", "cuda_types", "src", "cublas.rs"],
&module,
)
}
fn generate_cublaslt(crate_root: &PathBuf) {
let cublas_header = new_builder()
.header("/usr/local/cuda/include/cublasLt.h")
.allowlist_type("^cublas.*")
.allowlist_function("^cublasLt.*")
.allowlist_var("^CUBLASLT_.*")
.must_use_type("cublasStatus_t")
.allowlist_recursively(false)
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
.generate()
.unwrap()
.to_string();
let module: syn::File = syn::parse_str(&cublas_header).unwrap();
generate_functions(
&crate_root,
"cublaslt",
&["..", "cuda_base", "src", "cublaslt.rs"],
&module,
);
generate_types_library(
&crate_root,
&["..", "cuda_types", "src", "cublaslt.rs"],
&module,
)
}
fn generate_cuda(crate_root: &PathBuf) {
let cuda_header = bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
.layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
.derive_hash(true)
.derive_eq(true)
let cuda_header = new_builder()
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
.allowlist_type("^CU.*")
.allowlist_type("^cuda.*")
.allowlist_function("^cu.*")
.allowlist_var("^CU.*")
.must_use_type("cudaError_enum")
@ -67,22 +370,14 @@ fn generate_cuda(crate_root: &PathBuf) {
}
fn generate_ml(crate_root: &PathBuf) {
let ml_header = bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
.layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
.derive_hash(true)
.derive_eq(true)
let ml_header = new_builder()
.header("/usr/local/cuda/include/nvml.h")
.allowlist_type("^nvml.*")
.allowlist_function("^nvml.*")
.allowlist_var("^NVML.*")
.must_use_type("nvmlReturn_t")
.constified_enum("nvmlReturn_enum")
.clang_args(["-I/usr/local/cuda/include"])
.generate()
.unwrap()
.to_string();
@ -112,37 +407,34 @@ fn generate_ml(crate_root: &PathBuf) {
&["..", "cuda_base", "src", "nvml.rs"],
&module,
);
generate_types(
generate_types_library(
&crate_root,
&["..", "cuda_types", "src", "nvml.rs"],
&module,
);
}
fn generate_types(crate_root: &PathBuf, path: &[&str], module: &syn::File) {
fn generate_types_library(crate_root: &PathBuf, path: &[&str], module: &syn::File) {
let module = generate_types_library_impl(module);
let mut output = crate_root.clone();
output.extend(path);
let text =
prettyplease::unparse(&module).replace("self::cudaDataType", "super::cuda::cudaDataType");
write_rust_to_file(output, &text)
}
fn generate_types_library_impl(module: &syn::File) -> syn::File {
let non_fn = module.items.iter().filter_map(|item| match item {
Item::ForeignMod(_) => None,
_ => Some(item),
});
let module: syn::File = parse_quote! {
parse_quote! {
#(#non_fn)*
};
let mut output = crate_root.clone();
output.extend(path);
write_rust_to_file(output, &prettyplease::unparse(&module))
}
}
fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
let hiprt_header = bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
.layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
.derive_hash(true)
.derive_eq(true)
let hiprt_header = new_builder()
.header("/opt/rocm/include/hip/hip_runtime_api.h")
.allowlist_type("^hip.*")
.allowlist_function("^hip.*")
@ -426,7 +718,7 @@ struct ExplicitReturnType;
impl VisitMut for ExplicitReturnType {
fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) {
if let syn::ReturnType::Default = i {
*i = parse_quote! { -> {} };
*i = parse_quote! { -> () };
}
}
}
@ -563,7 +855,7 @@ fn cuda_derive_display_trait_for_item<'a>(
state: &mut DeriveDisplayState<'a>,
item: &'a Item,
) -> Option<syn::Item> {
let path_prefix = & state.types_crate;
let path_prefix = &state.types_crate;
let path_prefix_iter = iter::repeat(&path_prefix);
let mut prepend_path = PrependCudaPath {
module: Ident::new("cuda", Span::call_site()),
@ -798,3 +1090,16 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
}
}
}
fn new_builder() -> bindgen::Builder {
bindgen::Builder::default()
.use_core()
.rust_target(bindgen::RustTarget::Stable_1_77)
.layout_tests(false)
.default_enum_style(bindgen::EnumVariation::NewType {
is_bitfield: false,
is_global: false,
})
.derive_hash(true)
.derive_eq(true)
}