Switch to pregenerated format functions

This commit is contained in:
Andrzej Janik 2024-11-18 01:54:21 +00:00
parent fa94cdfc16
commit aa6b61b414
6 changed files with 25372 additions and 116 deletions

View file

@ -1,3 +1,5 @@
// Generated automatically by zluda_bindgen
// DO NOT EDIT MANUALLY
#![allow(warnings)]
pub const CUDA_VERSION: u32 = 12040;
pub const CU_IPC_HANDLE_SIZE: u32 = 64;
@ -47,7 +49,7 @@ pub type cuuint32_t = u32;
pub type cuuint64_t = u64;
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CUdeviceptr_v2(pub ::core::ffi::c_ulonglong);
pub struct CUdeviceptr_v2(pub *mut ::core::ffi::c_void);
pub type CUdeviceptr = CUdeviceptr_v2;
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@ -187,7 +189,7 @@ pub type CUasyncCallbackHandle = *mut CUasyncCallbackEntry_st;
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CUuuid_st {
pub bytes: [::core::ffi::c_char; 16usize],
pub bytes: [::core::ffi::c_uchar; 16usize],
}
pub type CUuuid = CUuuid_st;
/** Fabric handle - An opaque handle representing a memory allocation
@ -7868,7 +7870,7 @@ impl CUerror {
});
}
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CUerror(pub ::core::num::NonZeroU32);
pub trait CUresultConsts {
const SUCCESS: CUresult = CUresult::Ok(());

View file

@ -9,3 +9,4 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
proc-macro2 = "1.0.89"
quote = "1.0"
prettyplease = "0.2.25"
rustc-hash = "1.1.0"

View file

@ -1,8 +1,11 @@
use proc_macro2::Span;
use quote::{format_ident, quote};
use std::{path::PathBuf, str::FromStr};
use quote::{format_ident, quote, ToTokens};
use rustc_hash::{FxHashMap, FxHashSet};
use std::{collections::hash_map, fs::File, io::Write, iter, path::PathBuf, str::FromStr};
use syn::{
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Item, ItemUse, LitStr, UseTree,
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Abi, Fields, FnArg, ForeignItem,
ForeignItemFn, Ident, Item, ItemConst, ItemForeignMod, ItemUse, LitStr, Path, PathArguments,
Signature, Type, UseTree,
};
fn main() {
@ -28,18 +31,18 @@ fn main() {
.generate()
.unwrap()
.to_string();
generate_types(
crate_root,
&["..", "cuda_types", "src", "lib.rs"],
cuda_header,
);
let module: syn::File = syn::parse_str(&cuda_header).unwrap();
generate_types(&crate_root, &["..", "cuda_types", "src", "lib.rs"], &module);
generate_display(
&crate_root,
&["..", "zluda_dump", "src", "format_generated.rs"],
"cuda_types",
&module,
)
}
fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) {
let mut module: syn::File = syn::parse_str(&cuda_header).unwrap();
module.attrs.push(parse_quote! {
#![allow(warnings)]
});
fn generate_types(output: &PathBuf, path: &[&str], module: &syn::File) {
let mut module = module.clone();
let mut converter = ConvertIntoRustResult {
type_: "CUresult",
underlying_type: "cudaError_enum",
@ -55,15 +58,38 @@ fn generate_types(mut output: PathBuf, path: &[&str], cuda_header: String) {
Item::ForeignMod(_) => None,
Item::Const(const_) => converter.get_const(const_).map(Item::Const),
Item::Use(use_) => converter.get_use(use_).map(Item::Use),
Item::Struct(mut struct_) => {
let ident_string = struct_.ident.to_string();
match &*ident_string {
"CUdeviceptr_v2" => {
struct_.fields = Fields::Unnamed(parse_quote! {
(pub *mut ::core::ffi::c_void)
});
}
"CUuuid_st" => {
struct_.fields = Fields::Named(parse_quote! {
{pub bytes: [::core::ffi::c_uchar; 16usize]}
});
}
_ => {}
}
Some(Item::Struct(struct_))
}
item => Some(item),
})
.collect::<Vec<_>>();
converter.flush(&mut module.items);
syn::visit_mut::visit_file_mut(&mut FixAbi, &mut module);
for segment in path {
output.push(segment);
}
std::fs::write(output, prettyplease::unparse(&module)).unwrap();
let mut output = output.clone();
output.extend(path);
write_rust_to_file(output, &prettyplease::unparse(&module))
}
/// Writes generated Rust source to `path`, preceded by the "generated file" banner.
///
/// The banner is two comment lines followed by `#![allow(warnings)]`, so callers
/// must pass `content` without inner attributes or a banner of their own.
/// An existing file at `path` is truncated and replaced.
///
/// # Panics
///
/// Panics if the file cannot be created or if any write fails.
fn write_rust_to_file(path: impl AsRef<std::path::Path>, content: &str) {
    const BANNER: &str =
        "// Generated automatically by zluda_bindgen\n// DO NOT EDIT MANUALLY\n#![allow(warnings)]\n";
    let mut file = File::create(path).unwrap();
    // `Write::write` is allowed to write only part of the buffer and reports how
    // many bytes it took; `write_all` keeps going until everything is written,
    // so the generated file can never be silently truncated.
    file.write_all(BANNER.as_bytes()).unwrap();
    file.write_all(content.as_bytes()).unwrap();
}
struct ConvertIntoRustResult {
@ -154,3 +180,386 @@ impl VisitMut for FixAbi {
}
}
}
/// Generates the formatting code (`CudaDisplay` impls and `write_*` functions)
/// for the parsed bindings in `module` and writes it to a file.
///
/// * `output` - base directory; the segments in `path` are appended to it to
///   form the destination file.
/// * `types_crate` - name of the crate that generated code uses to qualify
///   CUDA type names.
/// * `module` - parsed bindings the formatting code is derived from.
fn generate_display(
    output: &PathBuf,
    path: &[&str],
    types_crate: &'static str,
    module: &syn::File,
) {
    // Types for which no impl is generated here.
    let skipped_types = [
        "CUarrayMapInfo_st",
        "CUDA_RESOURCE_DESC_st",
        "CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st",
        "CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st",
        "CUexecAffinityParam_st",
        "CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st",
        "CUstreamBatchMemOpParams_union_CUstreamMemOpWriteValueParams_st",
        "CUuuid_st",
        "HGPUNV",
        "EGLint",
        "EGLSyncKHR",
        "EGLImageKHR",
        "EGLStreamKHR",
        "CUasyncNotificationInfo_st",
        "CUgraphNodeParams_st",
        "CUeglFrame_st",
        "CUdevResource_st",
        "CUlaunchAttribute_st",
        "CUlaunchConfig_st",
    ];
    // Functions for which no `write_*` function is generated here.
    let skipped_functions = [
        "cuGLGetDevices",
        "cuGLGetDevices_v2",
        "cuStreamSetAttribute",
        "cuStreamSetAttribute_ptsz",
        "cuStreamGetAttribute",
        "cuStreamGetAttribute_ptsz",
        "cuGraphKernelNodeGetAttribute",
        "cuGraphKernelNodeSetAttribute",
    ];
    // (function name, index of the array argument, index of the argument
    // holding the array length)
    let array_arguments = [
        ("cuCtxCreate_v3", 1, 2),
        ("cuMemMapArrayAsync", 0, 1),
        ("cuMemMapArrayAsync_ptsz", 0, 1),
        ("cuStreamBatchMemOp", 2, 1),
        ("cuStreamBatchMemOp_ptsz", 2, 1),
        ("cuStreamBatchMemOp_v2", 2, 1),
    ];
    let mut state = DeriveDisplayState::new(
        &skipped_types,
        types_crate,
        &skipped_functions,
        &array_arguments,
    );

    // Items are visited in declaration order: the visitor both emits code and
    // accumulates state that later items (and the `CUresult` impl) rely on.
    let mut generated = Vec::new();
    for item in &module.items {
        if let Some(display_item) = cuda_derive_display_trait_for_item(&mut state, item) {
            generated.push(display_item);
        }
    }
    generated.push(curesult_display_trait(&state));

    let formatted = prettyplease::unparse(&syn::File {
        shebang: None,
        attrs: Vec::new(),
        items: generated,
    });

    let mut destination = output.clone();
    for segment in path {
        destination.push(segment);
    }
    write_rust_to_file(destination, &formatted);
}
/// State shared by the display-code generator while it walks the parsed bindings.
///
/// The `'a` lifetime ties the borrowed identifiers and constants to the
/// `syn::File` being walked.
struct DeriveDisplayState<'a> {
// Name of the crate used as the path prefix for types in generated code.
types_crate: &'static str,
// Types (structs and type aliases) for which no impl is generated.
ignore_types: FxHashSet<Ident>,
// Functions for which no `write_*` function is generated.
ignore_fns: FxHashSet<Ident>,
// Variant identifiers recorded so far, keyed by the identifier of the type
// whose `impl` block declared them.
enums: FxHashMap<&'a Ident, Vec<&'a Ident>>,
// Maps (function, argument index) to the index of the argument that holds
// the element count; such arguments are printed as arrays.
array_arguments: FxHashMap<(Ident, usize), usize>,
// Constants of type `cudaError_enum`, in the order they were encountered.
result_variants: Vec<&'a ItemConst>,
}
impl<'a> DeriveDisplayState<'a> {
fn new(
ignore_types: &[&'static str],
types_crate: &'static str,
ignore_fns: &[&'static str],
count_selectors: &[(&'static str, usize, usize)],
) -> Self {
DeriveDisplayState {
types_crate,
ignore_types: ignore_types
.into_iter()
.map(|x| Ident::new(x, Span::call_site()))
.collect(),
ignore_fns: ignore_fns
.into_iter()
.map(|x| Ident::new(x, Span::call_site()))
.collect(),
array_arguments: count_selectors
.into_iter()
.map(|(name, val, count)| ((Ident::new(name, Span::call_site()), *val), *count))
.collect(),
enums: Default::default(),
result_variants: Vec::new(),
}
}
fn record_enum_variant(&mut self, enum_: &'a Ident, variant: &'a Ident) {
match self.enums.entry(enum_) {
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().push(variant);
}
hash_map::Entry::Vacant(entry) => {
entry.insert(vec![variant]);
}
}
}
}
/// Produces the formatting item (if any) that corresponds to one item of the
/// parsed bindings, updating `state` along the way.
///
/// Depending on the kind of `item`:
/// * `Item::Const` - constants whose type is `cudaError_enum` are remembered in
///   `state.result_variants`; nothing is emitted.
/// * `Item::ForeignMod` - the last foreign item must be a function; unless it
///   is listed in `state.ignore_fns`, a `write_<name>` function that prints
///   every argument is emitted.
/// * `Item::Impl` - the last item must be an associated const; it is recorded
///   as a variant of the impl's self type; nothing is emitted.
/// * `Item::Struct` - unless ignored, emits a `CudaDisplay` impl: a match on
///   variant names if variants were already recorded for the struct, otherwise
///   a field-by-field printer (structs with named fields only).
/// * `Item::Type` - unless ignored, emits a pointer-printing impl for aliases
///   of raw pointers and of leading-`::` `Option<fn ...>` paths.
/// * `Item::Union`, `Item::Use` - nothing is emitted.
///
/// # Panics
///
/// Panics via `unreachable!()`/`unwrap()` on any other item kind and on item
/// shapes other than the ones described above (for example an empty `extern`
/// block or a type alias that is neither a pointer nor a path).
fn cuda_derive_display_trait_for_item<'a>(
state: &mut DeriveDisplayState<'a>,
item: &'a Item,
) -> Option<syn::Item> {
// Path of the types crate, prepended to type names in the generated code.
let path_prefix = Path::from(Ident::new(state.types_crate, Span::call_site()));
// Endless iterator over the prefix so it can be interpolated once per
// element inside a `#(...)*` repetition.
let path_prefix_iter = iter::repeat(&path_prefix);
match item {
// Error-code constants are only collected here; they are consumed later
// when the `CUresult` impl is built.
Item::Const(const_) => {
if const_.ty.to_token_stream().to_string() == "cudaError_enum" {
state.result_variants.push(const_);
}
None
}
// Only the last item of the `extern` block is inspected; an empty block
// panics on `unwrap`.
Item::ForeignMod(ItemForeignMod { items, .. }) => match items.last().unwrap() {
ForeignItem::Fn(ForeignItemFn {
sig: Signature { ident, inputs, .. },
..
}) => {
if state.ignore_fns.contains(ident) {
return None;
}
// Re-declare every argument with its type qualified by the types
// crate path; receivers (`self`) are not expected in foreign functions.
let inputs = inputs
.iter()
.map(|fn_arg| match fn_arg {
FnArg::Typed(ref pat_type) => {
let mut pat_type = pat_type.clone();
pat_type.ty = prepend_cuda_path_to_type(&path_prefix, pat_type.ty);
FnArg::Typed(pat_type)
}
_ => unreachable!(),
})
.collect::<Vec<_>>();
let inputs_iter = inputs.iter();
let original_fn_name = ident.to_string();
// Lazily builds the printing code of each argument. Arguments registered
// in `state.array_arguments` are printed as `[a, b, ...]`, reading as many
// elements through the pointer as the associated length argument says;
// all other arguments are printed directly.
let mut write_argument = inputs.iter().enumerate().map(|(index, fn_arg)| {
let name = fn_arg_name(fn_arg);
if let Some(length_index) = state.array_arguments.get(&(ident.clone(), index)) {
let length = fn_arg_name(&inputs[*length_index]);
quote! {
writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?;
writer.write_all(b"[")?;
for i in 0..#length {
if i != 0 {
writer.write_all(b", ")?;
}
crate::format::CudaDisplay::write(unsafe { &*#name.add(i as usize) }, #original_fn_name, arg_idx, writer)?;
}
writer.write_all(b"]")?;
}
} else {
quote! {
writer.write_all(concat!(stringify!(#name), ": ").as_bytes())?;
crate::format::CudaDisplay::write(&#name, #original_fn_name, arg_idx, writer)?;
}
}
});
let fn_name = format_ident!("write_{}", ident);
// The first argument is pulled out of the iterator so it is emitted without
// a leading ", "; the remaining ones are emitted by the repetition, each
// preceded by an `arg_idx` increment and a separator. Functions without
// arguments just print "()".
Some(match write_argument.next() {
Some(first_write_argument) => parse_quote! {
pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized), #(#inputs_iter,)*) -> std::io::Result<()> {
let mut arg_idx = 0usize;
writer.write_all(b"(")?;
#first_write_argument
#(
arg_idx += 1;
writer.write_all(b", ")?;
#write_argument
)*
writer.write_all(b")")
}
},
None => parse_quote! {
pub fn #fn_name(writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
writer.write_all(b"()")
}
},
})
}
_ => unreachable!(),
},
// Records the last associated const of the impl block as a variant of the
// impl's self type (identified by the last segment of its path).
Item::Impl(ref item_impl) => {
let enum_ = match &*item_impl.self_ty {
Type::Path(ref path) => &path.path.segments.last().unwrap().ident,
_ => unreachable!(),
};
let variant_ = match item_impl.items.last().unwrap() {
syn::ImplItem::Const(item_const) => &item_const.ident,
_ => unreachable!(),
};
state.record_enum_variant(enum_, variant_);
None
}
Item::Struct(item_struct) => {
if state.ignore_types.contains(&item_struct.ident) {
return None;
}
// A struct counts as an enum only if variants were recorded for it before
// this point, i.e. by impl blocks visited earlier. Such values are printed
// by variant name, falling back to the wrapped value for unknown ones.
if state.enums.contains_key(&item_struct.ident) {
let enum_ = &item_struct.ident;
let enum_iter = iter::repeat(&item_struct.ident);
let variants = state.enums.get(&item_struct.ident).unwrap().iter();
Some(parse_quote! {
impl crate::format::CudaDisplay for #path_prefix :: #enum_ {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
match self {
#(& #path_prefix_iter :: #enum_iter :: #variants => writer.write_all(stringify!(#variants).as_bytes()),)*
_ => write!(writer, "{}", self.0)
}
}
}
})
} else {
let struct_ = &item_struct.ident;
// Only structs with named fields are handled. Fields whose name starts
// with `reserved` or equals `_unused` are left out; if no field remains,
// no impl is generated.
let (first_field, rest_of_fields) = match item_struct.fields {
Fields::Named(ref fields) => {
let mut all_idents = fields.named.iter().filter_map(|f| {
let f_ident = f.ident.as_ref().unwrap();
let name = f_ident.to_string();
if name.starts_with("reserved") || name == "_unused" {
None
} else {
Some(f_ident)
}
});
let first = match all_idents.next() {
Some(f) => f,
None => return None,
};
(first, all_idents)
}
_ => return None,
};
Some(parse_quote! {
impl crate::format::CudaDisplay for #path_prefix :: #struct_ {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
writer.write_all(concat!("{ ", stringify!(#first_field), ": ").as_bytes())?;
crate::format::CudaDisplay::write(&self.#first_field, "", 0, writer)?;
#(
writer.write_all(concat!(", ", stringify!(#rest_of_fields), ": ").as_bytes())?;
crate::format::CudaDisplay::write(&self.#rest_of_fields, "", 0, writer)?;
)*
writer.write_all(b" }")
}
}
})
}
}
Item::Type(item_type) => {
if state.ignore_types.contains(&item_type.ident) {
return None;
};
match &*item_type.ty {
// Aliases of raw pointers are printed with `{:p}`.
Type::Ptr(_) => {
let type_ = &item_type.ident;
Some(parse_quote! {
impl crate::format::CudaDisplay for #path_prefix :: #type_ {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
write!(writer, "{:p}", *self)
}
}
})
}
// Aliases of a leading-`::` path whose last segment is `Option` must wrap
// a bare function type; they are printed as a pointer after a transmute to
// `*mut c_void`. Every other path alias produces no impl.
Type::Path(type_path) => {
if type_path.path.leading_colon.is_some() {
let option_seg = type_path.path.segments.last().unwrap();
if option_seg.ident == "Option" {
match &option_seg.arguments {
PathArguments::AngleBracketed(generic) => match generic.args[0] {
syn::GenericArgument::Type(Type::BareFn(_)) => {
let type_ = &item_type.ident;
return Some(parse_quote! {
impl crate::format::CudaDisplay for #path_prefix :: #type_ {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
write!(writer, "{:p}", unsafe { std::mem::transmute::<#path_prefix :: #type_, *mut ::std::ffi::c_void>(*self) })
}
}
});
}
_ => unreachable!(),
},
_ => unreachable!(),
}
}
}
None
}
_ => unreachable!(),
}
}
Item::Union(_) => None,
Item::Use(_) => None,
_ => unreachable!(),
}
}
/// Returns the pattern (the argument name) of a typed function argument.
///
/// # Panics
///
/// Panics if `fn_arg` is not a typed argument (i.e. it is a `self` receiver).
fn fn_arg_name(fn_arg: &FnArg) -> &Box<syn::Pat> {
    match fn_arg {
        FnArg::Typed(typed) => &typed.pat,
        _ => unreachable!(),
    }
}
/// Qualifies the CUDA type named by `type_` with `base_path`.
///
/// Path types have their path rewritten by `prepend_cuda_path_to_path`;
/// pointer types are rewritten recursively through their pointee.
///
/// # Panics
///
/// Panics if `type_` is neither a path nor a raw pointer type.
fn prepend_cuda_path_to_type(base_path: &Path, type_: Box<Type>) -> Box<Type> {
    let qualified = match *type_ {
        Type::Ptr(mut pointer) => {
            pointer.elem = prepend_cuda_path_to_type(base_path, pointer.elem);
            Type::Ptr(pointer)
        }
        Type::Path(mut named) => {
            named.path = prepend_cuda_path_to_path(base_path, named.path);
            Type::Path(named)
        }
        _ => unreachable!(),
    };
    Box::new(qualified)
}
/// Prefixes `path` with `base_path` when it names a type from the bindings.
///
/// A path is rewritten only if it has no leading `::`, consists of a single
/// segment, and that segment is `HGPUNV` or starts with one of `CU`, `cu`,
/// `GL`, `EGL` or `Vdp`. Every other path is returned unchanged.
fn prepend_cuda_path_to_path(base_path: &Path, path: Path) -> Path {
    const BINDING_PREFIXES: [&str; 5] = ["CU", "cu", "GL", "EGL", "Vdp"];

    if path.leading_colon.is_some() || path.segments.len() != 1 {
        return path;
    }
    let name = path.segments[0].ident.to_string();
    let is_binding_type = BINDING_PREFIXES
        .iter()
        .any(|prefix| name.starts_with(*prefix))
        || name == "HGPUNV";
    if !is_binding_type {
        return path;
    }
    let mut qualified = base_path.clone();
    qualified.segments.extend(path.segments);
    qualified
}
fn curesult_display_trait(derive_state: &DeriveDisplayState) -> syn::Item {
let errors = derive_state.result_variants.iter().filter_map(|const_| {
let prefix = "cudaError_enum_";
let text = &const_.ident.to_string()[prefix.len()..];
if text == "CUDA_SUCCESS" {
return None;
}
let expr = &const_.expr;
Some(quote! {
#expr => writer.write_all(#text.as_bytes()),
})
});
parse_quote! {
impl crate::format::CudaDisplay for cuda_types::CUresult {
fn write(&self, _fn_name: &'static str, _index: usize, writer: &mut (impl std::io::Write + ?Sized)) -> std::io::Result<()> {
match self {
Ok(()) => writer.write_all(b"CUDA_SUCCESS"),
Err(err) => {
match err.0.get() {
#(#errors)*
err => write!(writer, "{}", err)
}
}
}
}
}
}
}

View file

@ -1,4 +1,4 @@
use cuda_base::cuda_derive_display_trait;
use cuda_types::{CUGLDeviceList, CUdevice};
use std::{
ffi::{c_void, CStr},
fmt::LowerHex,
@ -596,34 +596,26 @@ impl<T: CudaDisplay, const N: usize> CudaDisplay for [T; N] {
}
}
#[allow(non_snake_case)]
pub fn write_cuStreamBatchMemOp(
writer: &mut (impl std::io::Write + ?Sized),
stream: cuda_types::CUstream,
count: ::std::os::raw::c_uint,
paramArray: *mut cuda_types::CUstreamBatchMemOpParams,
flags: ::std::os::raw::c_uint,
) -> std::io::Result<()> {
writer.write_all(b"(stream: ")?;
CudaDisplay::write(&stream, "cuStreamBatchMemOp", 0, writer)?;
writer.write_all(b", ")?;
writer.write_all(b"count: ")?;
CudaDisplay::write(&count, "cuStreamBatchMemOp", 1, writer)?;
writer.write_all(b", paramArray: [")?;
for i in 0..count {
if i != 0 {
writer.write_all(b", ")?;
}
CudaDisplay::write(
&unsafe { paramArray.add(i as usize) },
"cuStreamBatchMemOp",
2,
writer,
)?;
impl CudaDisplay for cuda_types::CUarrayMapInfo_st {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
_writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
todo!()
}
}
impl CudaDisplay for cuda_types::CUexecAffinityParam_st {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
_writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
todo!()
}
writer.write_all(b"], flags: ")?;
CudaDisplay::write(&flags, "cuStreamBatchMemOp", 3, writer)?;
writer.write_all(b") ")
}
#[allow(non_snake_case)]
@ -762,81 +754,27 @@ pub fn write_cuStreamSetAttribute_ptsz(
}
#[allow(non_snake_case)]
pub fn write_cuCtxCreate_v3(
pub fn write_cuGLGetDevices(
_writer: &mut (impl std::io::Write + ?Sized),
_pctx: *mut cuda_types::CUcontext,
_paramsArray: *mut cuda_types::CUexecAffinityParam,
_numParams: ::std::os::raw::c_int,
_flags: ::std::os::raw::c_uint,
_dev: cuda_types::CUdevice,
_pCudaDeviceCount: *mut ::std::os::raw::c_uint,
_pCudaDevices: *mut CUdevice,
_cudaDeviceCount: ::std::os::raw::c_uint,
_deviceList: CUGLDeviceList,
) -> std::io::Result<()> {
todo!()
}
#[allow(non_snake_case)]
pub fn write_cuCtxGetExecAffinity(
pub fn write_cuGLGetDevices_v2(
_writer: &mut (impl std::io::Write + ?Sized),
_pExecAffinity: *mut cuda_types::CUexecAffinityParam,
_type_: cuda_types::CUexecAffinityType,
_pCudaDeviceCount: *mut ::std::os::raw::c_uint,
_pCudaDevices: *mut CUdevice,
_cudaDeviceCount: ::std::os::raw::c_uint,
_deviceList: CUGLDeviceList,
) -> std::io::Result<()> {
todo!()
}
#[allow(non_snake_case)]
pub fn write_cuMemMapArrayAsync(
_writer: &mut (impl std::io::Write + ?Sized),
_mapInfoList: *mut cuda_types::CUarrayMapInfo,
_count: ::std::os::raw::c_uint,
_hStream: cuda_types::CUstream,
) -> std::io::Result<()> {
todo!()
}
#[allow(non_snake_case)]
pub fn write_cuMemMapArrayAsync_ptsz(
writer: &mut (impl std::io::Write + ?Sized),
mapInfoList: *mut cuda_types::CUarrayMapInfo,
count: ::std::os::raw::c_uint,
hStream: cuda_types::CUstream,
) -> std::io::Result<()> {
write_cuMemMapArrayAsync(writer, mapInfoList, count, hStream)
}
cuda_derive_display_trait!(
cuda_types,
CudaDisplay,
[
CUarrayMapInfo_st,
CUDA_RESOURCE_DESC_st,
CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st,
CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st,
CUexecAffinityParam_st,
CUstreamBatchMemOpParams_union_CUstreamMemOpWaitValueParams_st,
CUstreamBatchMemOpParams_union_CUstreamMemOpWriteValueParams_st,
CUuuid_st,
HGPUNV,
EGLint,
EGLSyncKHR,
EGLImageKHR,
EGLStreamKHR,
CUasyncNotificationInfo_st,
CUgraphNodeParams_st,
CUeglFrame_st,
CUdevResource_st,
CUlaunchAttribute_st,
CUlaunchConfig_st
],
[
cuCtxCreate_v3,
cuCtxGetExecAffinity,
cuGraphKernelNodeGetAttribute,
cuGraphKernelNodeSetAttribute,
cuMemMapArrayAsync,
cuMemMapArrayAsync_ptsz,
cuStreamBatchMemOp,
cuStreamGetAttribute,
cuStreamGetAttribute_ptsz,
cuStreamSetAttribute,
cuStreamSetAttribute_ptsz
]
);
#[path = "format_generated.rs"]
mod format_generated;
pub(crate) use format_generated::*;

File diff suppressed because it is too large Load diff

View file

@ -299,7 +299,7 @@ where
// alternatively we could return a CUDA error, but I think it's fine to
// crash. This is a diagnostic utility, if the lock was poisoned we can't
// extract any useful trace or logging anyway
let mut global_state = &mut *global_state_mutex.lock().unwrap();
let global_state = &mut *global_state_mutex.lock().unwrap();
let (mut logger, delayed_state) = match global_state.delayed_state {
LateInit::Success(ref mut delayed_state) => (
global_state.log_factory.get_logger(func, arguments_writer),