mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-19 16:04:44 +00:00
Generate cudnn bindgen
This commit is contained in:
parent
7cdab7abc2
commit
734db223d6
11 changed files with 361 additions and 45 deletions
|
@ -1,2 +1,4 @@
|
|||
pub mod cuda;
|
||||
pub mod nvml;
|
||||
pub mod nvml;
|
||||
pub mod cublas;
|
||||
pub mod cublaslt;
|
|
@ -5,3 +5,4 @@
|
|||
#include <cudaEGL.h>
|
||||
#include <vdpau/vdpau.h>
|
||||
#include <cudaVDPAU.h>
|
||||
#include <library_types.h>
|
||||
|
|
1
zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_adv_infer.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_adv_infer_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_adv_train.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_adv_train_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_backend.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_backend.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_backend_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_cnn_infer.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_cnn_infer_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_cnn_train.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_cnn_train_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_ops_infer.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_ops_infer_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_ops_train.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_ops_train_v8.h>
|
1
zluda_bindgen/build/cudnn_v8/cudnn_version.h
Normal file
1
zluda_bindgen/build/cudnn_v8/cudnn_version.h
Normal file
|
@ -0,0 +1 @@
|
|||
#include <cudnn_version_v8.h>
|
|
@ -1,11 +1,13 @@
|
|||
use proc_macro2::Span;
|
||||
use quote::{format_ident, quote, ToTokens};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr};
|
||||
use std::{
|
||||
borrow::Cow, collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr,
|
||||
};
|
||||
use syn::{
|
||||
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed, FnArg,
|
||||
ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path,
|
||||
PathArguments, Signature, Type, TypePath, UseTree, PathSegment
|
||||
parse, parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FieldsUnnamed,
|
||||
FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr,
|
||||
Path, PathArguments, PathSegment, Signature, Type, TypePath, UseTree,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
|
@ -14,23 +16,324 @@ fn main() {
|
|||
&crate_root,
|
||||
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
|
||||
);
|
||||
generate_ml(&crate_root);
|
||||
generate_cuda(&crate_root);
|
||||
generate_ml(&crate_root);
|
||||
generate_cublas(&crate_root);
|
||||
generate_cublaslt(&crate_root);
|
||||
generate_cudnn(&crate_root);
|
||||
}
|
||||
|
||||
fn generate_cudnn(crate_root: &PathBuf) {
|
||||
let cudnn9 = new_builder()
|
||||
.header("/usr/include/x86_64-linux-gnu/cudnn_v9.h")
|
||||
.allowlist_type("^cudnn.*")
|
||||
.allowlist_function("^cudnn.*")
|
||||
.allowlist_var("^CUDNN_.*")
|
||||
.must_use_type("cudnnStatus_t")
|
||||
.allowlist_recursively(false)
|
||||
.clang_args(["-I/usr/local/cuda/include"])
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let module: syn::File = syn::parse_str(&cudnn9).unwrap();
|
||||
generate_functions(
|
||||
&crate_root,
|
||||
"cudnn9",
|
||||
&["..", "cuda_base", "src", "cudnn9.rs"],
|
||||
&module,
|
||||
);
|
||||
let cudnn9_types = generate_types_library_impl(&module);
|
||||
let mut current_dir = PathBuf::from(file!());
|
||||
current_dir.pop();
|
||||
let cudnn8 = new_builder()
|
||||
.header("/usr/include/x86_64-linux-gnu/cudnn_v8.h")
|
||||
.allowlist_type("^cudnn.*")
|
||||
.allowlist_function("^cudnn.*")
|
||||
.allowlist_var("^CUDNN_.*")
|
||||
.must_use_type("cudnnStatus_t")
|
||||
.allowlist_recursively(false)
|
||||
.clang_args([
|
||||
"-I/usr/local/cuda/include",
|
||||
&format!("-I{}/../build/cudnn_v8", current_dir.display()),
|
||||
])
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let module: syn::File = syn::parse_str(&cudnn8).unwrap();
|
||||
generate_functions(
|
||||
&crate_root,
|
||||
"cudnn8",
|
||||
&["..", "cuda_base", "src", "cudnn8.rs"],
|
||||
&module,
|
||||
);
|
||||
let cudnn8_types = generate_types_library_impl(&module);
|
||||
merge_types(
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cudnn.rs"],
|
||||
cudnn9_types,
|
||||
&["..", "cuda_types", "src", "cudnn9.rs"],
|
||||
cudnn8_types,
|
||||
&["..", "cuda_types", "src", "cudnn8.rs"],
|
||||
);
|
||||
}
|
||||
|
||||
// This code splits types (and constants) into one of:
|
||||
// - cudnn8-specific
|
||||
// - cudnn9-specific
|
||||
// - cudnn shared
|
||||
// With the rules being:
|
||||
// - constants go to the specific files
|
||||
// - if there's conflict between types they go to specific files
|
||||
// - if the cudnn9 type is purely additive over cudnn8 then it goes into the
|
||||
// shared (and is re-exported)
|
||||
fn merge_types(
|
||||
output: &PathBuf,
|
||||
cudnn_path: &[&str],
|
||||
cudnn9_types: syn::File,
|
||||
cudnn9_path: &[&str],
|
||||
cudnn8_types: syn::File,
|
||||
cudnn8_path: &[&str],
|
||||
) {
|
||||
let cudnn_enums = merge_enums(&cudnn9_types, &cudnn8_types);
|
||||
let conflicting_types = get_conflicting_structs(&cudnn9_types, &cudnn8_types, cudnn_enums);
|
||||
let common_items = cudnn9_types.items.iter().filter_map(|item| match item {
|
||||
Item::Impl(ref impl_) => match conflicting_types.get(type_to_ident(&*impl_.self_ty)) {
|
||||
Some(CudnnEnumMergeResult::Conflict) => None,
|
||||
Some(CudnnEnumMergeResult::Same) => {
|
||||
let item: Item = parse_quote! {
|
||||
#impl_
|
||||
};
|
||||
Some(item)
|
||||
}
|
||||
Some(CudnnEnumMergeResult::Cudnn9) => None,
|
||||
None => None,
|
||||
},
|
||||
Item::Struct(ref struct_) => match conflicting_types.get(&struct_.ident) {
|
||||
Some(CudnnEnumMergeResult::Conflict) => None,
|
||||
Some(CudnnEnumMergeResult::Same) => {
|
||||
let item: Item = parse_quote! {
|
||||
#struct_
|
||||
};
|
||||
Some(item)
|
||||
}
|
||||
Some(CudnnEnumMergeResult::Cudnn9) => None,
|
||||
None => None,
|
||||
},
|
||||
Item::Enum(ref enum_) => match conflicting_types.get(&enum_.ident) {
|
||||
Some(CudnnEnumMergeResult::Conflict) => None,
|
||||
Some(CudnnEnumMergeResult::Same) => {
|
||||
let item: Item = parse_quote! {
|
||||
#enum_
|
||||
};
|
||||
Some(item)
|
||||
}
|
||||
Some(CudnnEnumMergeResult::Cudnn9) => None,
|
||||
None => None,
|
||||
},
|
||||
Item::ForeignMod(ItemForeignMod { .. }) => None,
|
||||
_ => None,
|
||||
//_ => unimplemented!(),
|
||||
});
|
||||
|
||||
let file: syn::File = parse_quote! {
|
||||
#(#common_items)*
|
||||
};
|
||||
{
|
||||
let mut output = output.clone();
|
||||
output.extend(cudnn_path);
|
||||
let text = prettyplease::unparse(&file);
|
||||
write_rust_to_file(output, &text)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_conflicting_structs<'a>(
|
||||
cudnn9_types: &'a syn::File,
|
||||
cudnn8_types: &'a syn::File,
|
||||
mut enums: FxHashMap<&'a Ident, CudnnEnumMergeResult>,
|
||||
) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> {
|
||||
let structs9 = get_structs(cudnn9_types);
|
||||
let structs8 = get_structs(cudnn8_types);
|
||||
for (struct_name8, struct8) in structs8 {
|
||||
if enums.contains_key(struct_name8) {
|
||||
continue;
|
||||
}
|
||||
match structs9.get(struct_name8) {
|
||||
Some(struct9) => {
|
||||
if struct8 != *struct9 {
|
||||
panic!("{}", struct_name8.to_string());
|
||||
}
|
||||
let has_conflicting_field = struct8.iter().any(|field| {
|
||||
let type_ = type_to_ident(&field.ty);
|
||||
enums.get(type_) == Some(&CudnnEnumMergeResult::Conflict)
|
||||
});
|
||||
let value = if has_conflicting_field {
|
||||
CudnnEnumMergeResult::Conflict
|
||||
} else {
|
||||
CudnnEnumMergeResult::Same
|
||||
};
|
||||
assert!(enums.insert(struct_name8, value).is_none());
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
enums
|
||||
}
|
||||
|
||||
fn type_to_ident<'a>(ty: &'a syn::Type) -> &'a syn::Ident {
|
||||
match ty {
|
||||
Type::Path(path) => &path.path.segments[0].ident,
|
||||
Type::Array(array) => type_to_ident(&array.elem),
|
||||
_ => unimplemented!("{}", ty.to_token_stream().to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_enums<'a>(
|
||||
cudnn9_types: &'a syn::File,
|
||||
cudnn8_types: &'a syn::File,
|
||||
) -> FxHashMap<&'a Ident, CudnnEnumMergeResult> {
|
||||
let result = {
|
||||
let enums8 = get_enums(cudnn8_types);
|
||||
let enums9 = get_enums(cudnn9_types);
|
||||
enums8
|
||||
.iter()
|
||||
.map(|(enum8_ident, enum8_vars)| {
|
||||
let merge_result = match enums9.get(enum8_ident) {
|
||||
Some(enum9_vars) => {
|
||||
let e8_has_extra = enum8_vars.difference(&enum9_vars).any(|_| true);
|
||||
let e9_has_extra = enum9_vars.difference(&enum8_vars).any(|_| true);
|
||||
match (e8_has_extra, e9_has_extra) {
|
||||
(false, false) => CudnnEnumMergeResult::Same,
|
||||
(false, true) => CudnnEnumMergeResult::Cudnn9,
|
||||
(true, true) => CudnnEnumMergeResult::Conflict,
|
||||
(true, false) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
unimplemented!()
|
||||
}
|
||||
};
|
||||
(*enum8_ident, merge_result)
|
||||
})
|
||||
.collect::<FxHashMap<_, _>>()
|
||||
};
|
||||
result
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
enum CudnnEnumMergeResult {
|
||||
// Conflicting definitions
|
||||
Conflict,
|
||||
// Identical definitions
|
||||
Same,
|
||||
// Enum present in both, but cudnn9 definition is a strict superset
|
||||
Cudnn9,
|
||||
}
|
||||
|
||||
fn get_enums<'a>(
|
||||
cudnn_module: &'a syn::File,
|
||||
) -> FxHashMap<&'a Ident, FxHashSet<&'a syn::ImplItemConst>> {
|
||||
let mut enums = FxHashMap::default();
|
||||
for item in cudnn_module.items.iter() {
|
||||
match item {
|
||||
Item::Impl(ref impl_) => match &*impl_.self_ty {
|
||||
Type::Path(path) => {
|
||||
let constant = match impl_.items[0] {
|
||||
syn::ImplItem::Const(ref impl_item_const) => impl_item_const,
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
enums
|
||||
.entry(&path.path.segments[0].ident)
|
||||
.or_insert(FxHashSet::default())
|
||||
.insert(constant);
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
enums
|
||||
}
|
||||
|
||||
fn get_structs<'a>(cudnn_module: &'a syn::File) -> FxHashMap<&'a Ident, Cow<'a, syn::Fields>> {
|
||||
let mut structs = FxHashMap::default();
|
||||
for item in cudnn_module.items.iter() {
|
||||
match item {
|
||||
Item::Struct(ref struct_) => {
|
||||
assert!(structs
|
||||
.insert(&struct_.ident, Cow::Borrowed(&struct_.fields))
|
||||
.is_none());
|
||||
}
|
||||
Item::Union(ref union_) => {
|
||||
assert!(structs
|
||||
.insert(
|
||||
&union_.ident,
|
||||
Cow::Owned(syn::Fields::Named(union_.fields.clone()))
|
||||
)
|
||||
.is_none());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
structs
|
||||
}
|
||||
|
||||
fn generate_cublas(crate_root: &PathBuf) {
|
||||
let cublas_header = new_builder()
|
||||
.header("/usr/local/cuda/include/cublas_v2.h")
|
||||
.allowlist_type("^cublas.*")
|
||||
.allowlist_function("^cublas.*")
|
||||
.allowlist_var("^CUBLAS_.*")
|
||||
.must_use_type("cublasStatus_t")
|
||||
.allowlist_recursively(false)
|
||||
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let module: syn::File = syn::parse_str(&cublas_header).unwrap();
|
||||
generate_functions(
|
||||
&crate_root,
|
||||
"cublas",
|
||||
&["..", "cuda_base", "src", "cublas.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_types_library(
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cublas.rs"],
|
||||
&module,
|
||||
)
|
||||
}
|
||||
|
||||
fn generate_cublaslt(crate_root: &PathBuf) {
|
||||
let cublas_header = new_builder()
|
||||
.header("/usr/local/cuda/include/cublasLt.h")
|
||||
.allowlist_type("^cublas.*")
|
||||
.allowlist_function("^cublasLt.*")
|
||||
.allowlist_var("^CUBLASLT_.*")
|
||||
.must_use_type("cublasStatus_t")
|
||||
.allowlist_recursively(false)
|
||||
.clang_args(["-I/usr/local/cuda/include", "-x", "c++"])
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let module: syn::File = syn::parse_str(&cublas_header).unwrap();
|
||||
generate_functions(
|
||||
&crate_root,
|
||||
"cublaslt",
|
||||
&["..", "cuda_base", "src", "cublaslt.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_types_library(
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "cublaslt.rs"],
|
||||
&module,
|
||||
)
|
||||
}
|
||||
|
||||
fn generate_cuda(crate_root: &PathBuf) {
|
||||
let cuda_header = bindgen::Builder::default()
|
||||
.use_core()
|
||||
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||
.layout_tests(false)
|
||||
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||
is_bitfield: false,
|
||||
is_global: false,
|
||||
})
|
||||
.derive_hash(true)
|
||||
.derive_eq(true)
|
||||
let cuda_header = new_builder()
|
||||
.header_contents("cuda_wrapper.h", include_str!("../build/cuda_wrapper.h"))
|
||||
.allowlist_type("^CU.*")
|
||||
.allowlist_type("^cuda.*")
|
||||
.allowlist_function("^cu.*")
|
||||
.allowlist_var("^CU.*")
|
||||
.must_use_type("cudaError_enum")
|
||||
|
@ -67,22 +370,14 @@ fn generate_cuda(crate_root: &PathBuf) {
|
|||
}
|
||||
|
||||
fn generate_ml(crate_root: &PathBuf) {
|
||||
let ml_header = bindgen::Builder::default()
|
||||
.use_core()
|
||||
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||
.layout_tests(false)
|
||||
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||
is_bitfield: false,
|
||||
is_global: false,
|
||||
})
|
||||
.derive_hash(true)
|
||||
.derive_eq(true)
|
||||
let ml_header = new_builder()
|
||||
.header("/usr/local/cuda/include/nvml.h")
|
||||
.allowlist_type("^nvml.*")
|
||||
.allowlist_function("^nvml.*")
|
||||
.allowlist_var("^NVML.*")
|
||||
.must_use_type("nvmlReturn_t")
|
||||
.constified_enum("nvmlReturn_enum")
|
||||
.clang_args(["-I/usr/local/cuda/include"])
|
||||
.generate()
|
||||
.unwrap()
|
||||
.to_string();
|
||||
|
@ -112,37 +407,34 @@ fn generate_ml(crate_root: &PathBuf) {
|
|||
&["..", "cuda_base", "src", "nvml.rs"],
|
||||
&module,
|
||||
);
|
||||
generate_types(
|
||||
generate_types_library(
|
||||
&crate_root,
|
||||
&["..", "cuda_types", "src", "nvml.rs"],
|
||||
&module,
|
||||
);
|
||||
}
|
||||
|
||||
fn generate_types(crate_root: &PathBuf, path: &[&str], module: &syn::File) {
|
||||
fn generate_types_library(crate_root: &PathBuf, path: &[&str], module: &syn::File) {
|
||||
let module = generate_types_library_impl(module);
|
||||
let mut output = crate_root.clone();
|
||||
output.extend(path);
|
||||
let text =
|
||||
prettyplease::unparse(&module).replace("self::cudaDataType", "super::cuda::cudaDataType");
|
||||
write_rust_to_file(output, &text)
|
||||
}
|
||||
|
||||
fn generate_types_library_impl(module: &syn::File) -> syn::File {
|
||||
let non_fn = module.items.iter().filter_map(|item| match item {
|
||||
Item::ForeignMod(_) => None,
|
||||
_ => Some(item),
|
||||
});
|
||||
let module: syn::File = parse_quote! {
|
||||
parse_quote! {
|
||||
#(#non_fn)*
|
||||
};
|
||||
let mut output = crate_root.clone();
|
||||
output.extend(path);
|
||||
write_rust_to_file(output, &prettyplease::unparse(&module))
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
|
||||
let hiprt_header = bindgen::Builder::default()
|
||||
.use_core()
|
||||
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||
.layout_tests(false)
|
||||
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||
is_bitfield: false,
|
||||
is_global: false,
|
||||
})
|
||||
.derive_hash(true)
|
||||
.derive_eq(true)
|
||||
let hiprt_header = new_builder()
|
||||
.header("/opt/rocm/include/hip/hip_runtime_api.h")
|
||||
.allowlist_type("^hip.*")
|
||||
.allowlist_function("^hip.*")
|
||||
|
@ -426,7 +718,7 @@ struct ExplicitReturnType;
|
|||
impl VisitMut for ExplicitReturnType {
|
||||
fn visit_return_type_mut(&mut self, i: &mut syn::ReturnType) {
|
||||
if let syn::ReturnType::Default = i {
|
||||
*i = parse_quote! { -> {} };
|
||||
*i = parse_quote! { -> () };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -563,7 +855,7 @@ fn cuda_derive_display_trait_for_item<'a>(
|
|||
state: &mut DeriveDisplayState<'a>,
|
||||
item: &'a Item,
|
||||
) -> Option<syn::Item> {
|
||||
let path_prefix = & state.types_crate;
|
||||
let path_prefix = &state.types_crate;
|
||||
let path_prefix_iter = iter::repeat(&path_prefix);
|
||||
let mut prepend_path = PrependCudaPath {
|
||||
module: Ident::new("cuda", Span::call_site()),
|
||||
|
@ -798,3 +1090,16 @@ fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn new_builder() -> bindgen::Builder {
|
||||
bindgen::Builder::default()
|
||||
.use_core()
|
||||
.rust_target(bindgen::RustTarget::Stable_1_77)
|
||||
.layout_tests(false)
|
||||
.default_enum_style(bindgen::EnumVariation::NewType {
|
||||
is_bitfield: false,
|
||||
is_global: false,
|
||||
})
|
||||
.derive_hash(true)
|
||||
.derive_eq(true)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue