From 89bc40618bc2b410a03c71f5684b1c8004c65157 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 28 Jan 2022 16:44:46 +0100 Subject: [PATCH] Implement static typing for dynamically-loaded CUDA DLLs --- cuda_base/src/lib.rs | 69 +++++++++-------- zluda_dump/src/format.rs | 1 - zluda_dump/src/lib.rs | 135 +++++++++++++++++++++------------ zluda_dump/src/os_unix.rs | 2 +- zluda_dump/src/os_win.rs | 2 +- zluda_dump/src/side_by_side.rs | 77 +++++++++++++++++++ 6 files changed, 205 insertions(+), 81 deletions(-) create mode 100644 zluda_dump/src/side_by_side.rs diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 8b804d1..3f6f779 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -9,12 +9,11 @@ use quote::{format_ident, quote, ToTokens}; use rustc_hash::{FxHashMap, FxHashSet}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; -use syn::token::Brace; use syn::visit_mut::VisitMut; use syn::{ bracketed, parse_macro_input, Abi, Fields, File, FnArg, ForeignItem, ForeignItemFn, Ident, - Item, ItemForeignMod, ItemMacro, LitStr, Macro, MacroDelimiter, PatType, Path, PathArguments, - PathSegment, ReturnType, Signature, Token, Type, TypeArray, TypePath, TypePtr, + Item, ItemForeignMod, LitStr, PatType, Path, PathArguments, PathSegment, ReturnType, Signature, + Token, Type, TypeArray, TypePath, TypePtr, }; const CUDA_RS: &'static str = include_str! {"cuda.rs"}; @@ -109,8 +108,11 @@ impl VisitMut for FixAbi { // Then macro goes through every function in rust.rs, and for every fn `foo`: // * if `foo` is contained in `override_fns` then pass it into `override_macro` // * if `foo` is not contained in `override_fns` pass it to `normal_macro` -// Both `override_macro` and `normal_macro` expect this format: -// macro_foo!("system" fn cuCtxDetach(ctx: CUcontext) -> CUresult) +// Both `override_macro` and `normal_macro` expect semicolon-separated list: +// macro_foo!( +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult; +// "system" fn cuCtxDetach(ctx: CUcontext) -> CUresult +// ) // Additionally, it does a fixup of CUDA types so they get prefixed with `type_path` #[proc_macro] pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { @@ -121,7 +123,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { .iter() .map(ToString::to_string) .collect::>(); - cuda_module + let (normal_macro_args, override_macro_args): (Vec<_>, Vec<_>) = cuda_module .items .into_iter() .filter_map(|item| match item { @@ -136,12 +138,7 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { }, .. }) => { - let path = if override_fns.contains(&ident.to_string()) { - &input.override_macro - } else { - &input.normal_macro - } - .clone(); + let use_normal_macro = !override_fns.contains(&ident.to_string()); let inputs = inputs .into_iter() .map(|fn_arg| match fn_arg { @@ -158,30 +155,42 @@ pub fn cuda_function_declarations(tokens: TokenStream) -> TokenStream { ReturnType::Default => unreachable!(), }; let type_path = input.type_path.clone(); - let tokens = quote! { - "system" fn #ident(#inputs) -> #type_path :: #output - }; - Some(Item::Macro(ItemMacro { - attrs: Vec::new(), - ident: None, - mac: Macro { - path, - bang_token: Token![!](Span::call_site()), - delimiter: MacroDelimiter::Brace(Brace { - span: Span::call_site(), - }), - tokens, + Some(( + quote! { + "system" fn #ident(#inputs) -> #type_path :: #output }, - semi_token: None, - })) + use_normal_macro, + )) } _ => unreachable!(), }, _ => None, }) - .map(Item::into_token_stream) - .collect::() - .into() + .partition(|(_, use_normal_macro)| *use_normal_macro); + let mut result = proc_macro2::TokenStream::new(); + if !normal_macro_args.is_empty() { + let punctuated_normal_macro_args = to_punctuated::(normal_macro_args); + let macro_ = &input.normal_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_normal_macro_args); + })); + } + if !override_macro_args.is_empty() { + let punctuated_override_macro_args = to_punctuated::(override_macro_args); + let macro_ = &input.override_macro; + result.extend(iter::once(quote! { + #macro_ ! (#punctuated_override_macro_args); + })); + } + result.into() +} + +fn to_punctuated( + elms: Vec<(proc_macro2::TokenStream, bool)>, +) -> proc_macro2::TokenStream { + let mut collection = Punctuated::::new(); + collection.extend(elms.into_iter().map(|(token_stream, _)| token_stream)); + collection.into_token_stream() } fn prepend_cuda_path_to_type(base_path: &Path, type_: Box) -> Box { diff --git a/zluda_dump/src/format.rs b/zluda_dump/src/format.rs index 8080fbc..380e52d 100644 --- a/zluda_dump/src/format.rs +++ b/zluda_dump/src/format.rs @@ -1,4 +1,3 @@ -extern crate cuda_types; use std::{ ffi::{c_void, CStr}, fmt::LowerHex, diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index d79c391..04fc36e 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -2,6 +2,7 @@ use cuda_types::{ CUdevice, CUdevice_attribute, CUfunction, CUjit_option, CUmodule, CUresult, CUuuid, }; use paste::paste; +use side_by_side::CudaDynamicFns; use std::io; use std::{ collections::HashMap, env, error::Error, ffi::c_void, fs, path::PathBuf, ptr::NonNull, rc::Rc, @@ -10,47 +11,50 @@ use std::{ #[macro_use] extern crate lazy_static; +extern crate cuda_types; macro_rules! extern_redirect { - ($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => { - #[no_mangle] - pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - let original_fn = |fn_ptr| { - let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) }; - typed_fn($( $arg_id ),*) - }; - let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { - (paste! { format :: [] }) ( - writer - $(,$arg_id)* - ) - }); - crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args) - } + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => { + $( + #[no_mangle] + pub extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| { + dynamic_fns.$fn_name($( $arg_id ),*) + }; + let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { + (paste! { format :: [] }) ( + writer + $(,$arg_id)* + ) + }); + crate::handle_cuda_function_call(stringify!($fn_name), original_fn, get_formatted_args) + } + )* }; } macro_rules! extern_redirect_with_post { - ($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path) => { - #[no_mangle] - pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { - let original_fn = |fn_ptr| { - let typed_fn = unsafe { std::mem::transmute::<_, extern "system" fn( $( $arg_id : $arg_type),* ) -> $ret_type>(fn_ptr) }; - typed_fn($( $arg_id ),*) - }; - let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { - (paste! { format :: [] }) ( - writer - $(,$arg_id)* + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => { + $( + #[no_mangle] + pub extern "system" fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + let original_fn = |dynamic_fns: &mut crate::side_by_side::CudaDynamicFns| { + dynamic_fns.$fn_name($( $arg_id ),*) + }; + let get_formatted_args = Box::new(move |writer: &mut dyn std::io::Write| { + (paste! { format :: [] }) ( + writer + $(,$arg_id)* + ) + }); + crate::handle_cuda_function_call_with_probes( + stringify!($fn_name), + || (), original_fn, + get_formatted_args, + move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result ) ) - }); - crate::handle_cuda_function_call_with_probes( - stringify!($fn_name), - || (), original_fn, - get_formatted_args, - move |logger, state, _, cuda_result| paste! { [<$fn_name _Post>] } ( $( $arg_id ),* , logger, state, cuda_result ) - ) - } + } + )* }; } @@ -77,6 +81,7 @@ mod log; #[cfg_attr(windows, path = "os_win.rs")] #[cfg_attr(not(windows), path = "os_unix.rs")] mod os; +mod side_by_side; mod trace; lazy_static! { @@ -127,7 +132,8 @@ impl LateInit { struct GlobalDelayedState { settings: Settings, - libcuda_handle: NonNull, + libcuda: CudaDynamicFns, + side_by_side_lib: Option, cuda_state: trace::StateTracker, } @@ -139,9 +145,8 @@ impl GlobalDelayedState { ) -> (LateInit, log::FunctionLogger<'a>) { let (mut fn_logger, settings) = factory.get_first_logger_and_init_settings(func, arguments_writer); - let maybe_libcuda_handle = unsafe { os::load_cuda_library(&settings.libcuda_path) }; - let libcuda_handle = match NonNull::new(maybe_libcuda_handle) { - Some(h) => h, + let libcuda = match unsafe { CudaDynamicFns::load_library(&settings.libcuda_path) } { + Some(libcuda) => libcuda, None => { fn_logger.log(log::LogEntry::ErrorBox( format!("Invalid CUDA library at path {}", &settings.libcuda_path).into(), @@ -149,11 +154,30 @@ impl GlobalDelayedState { return (LateInit::Error, fn_logger); } }; + let side_by_side_lib = settings + .side_by_side_path + .as_ref() + .and_then(|side_by_side_path| { + match unsafe { CudaDynamicFns::load_library(&*side_by_side_path) } { + Some(fns) => Some(fns), + None => { + fn_logger.log(log::LogEntry::ErrorBox( + format!( + "Invalid side-by-side CUDA library at path {}", + &side_by_side_path + ) + .into(), + )); + None + } + } + }); let cuda_state = trace::StateTracker::new(&settings); let delayed_state = GlobalDelayedState { settings, - libcuda_handle, + libcuda, cuda_state, + side_by_side_lib, }; (LateInit::Success(delayed_state), fn_logger) } @@ -163,6 +187,7 @@ struct Settings { dump_dir: Option, libcuda_path: String, override_cc_major: Option, + side_by_side_path: Option, } impl Settings { @@ -179,7 +204,7 @@ impl Settings { None } }; - let libcuda_path = match env::var("ZLUDA_DUMP_LIBCUDA_FILE") { + let libcuda_path = match env::var("ZLUDA_CUDA_LIB") { Err(env::VarError::NotPresent) => os::LIBCUDA_DEFAULT_PATH.to_owned(), Err(e) => { logger.log(log::LogEntry::ErrorBox(Box::new(e) as _)); @@ -201,10 +226,19 @@ impl Settings { Ok(cc) => Some(cc), }, }; + let side_by_side_path = match env::var("ZLUDA_SIDE_BY_SIDE_LIB") { + Err(env::VarError::NotPresent) => None, + Err(e) => { + logger.log(log::LogEntry::ErrorBox(Box::new(e) as _)); + None + } + Ok(env_string) => Some(env_string), + }; Settings { dump_dir, libcuda_path, override_cc_major, + side_by_side_path, } } @@ -241,7 +275,7 @@ pub struct ModuleDump { fn handle_cuda_function_call( func: &'static str, - original_cuda_fn: impl FnOnce(NonNull) -> CUresult, + original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option, arguments_writer: Box std::io::Result<()>>, ) -> CUresult { handle_cuda_function_call_with_probes( @@ -256,7 +290,7 @@ fn handle_cuda_function_call( fn handle_cuda_function_call_with_probes( func: &'static str, pre_probe: impl FnOnce() -> T, - original_cuda_fn: impl FnOnce(NonNull) -> CUresult, + original_cuda_fn: impl FnOnce(&mut CudaDynamicFns) -> Option, arguments_writer: Box std::io::Result<()>>, post_probe: PostFn, ) -> CUresult @@ -283,13 +317,18 @@ where (logger, global_state.delayed_state.as_mut().unwrap()) } }; - let name = std::ffi::CString::new(func).unwrap(); - let fn_ptr = - unsafe { os::get_proc_address(delayed_state.libcuda_handle.as_ptr(), name.as_c_str()) }; - let fn_ptr = NonNull::new(fn_ptr).unwrap(); let pre_result = pre_probe(); - let cu_result = original_cuda_fn(fn_ptr); - logger.result = Some(cu_result); + let maybe_cu_result = original_cuda_fn(&mut delayed_state.libcuda); + let cu_result = match maybe_cu_result { + Some(result) => result, + None => { + logger.log(log::LogEntry::ErrorBox( + format!("No function {} in the underlying CUDA library", func).into(), + )); + CUresult::CUDA_ERROR_UNKNOWN + } + }; + logger.result = maybe_cu_result; post_probe( &mut logger, &mut delayed_state.cuda_state, diff --git a/zluda_dump/src/os_unix.rs b/zluda_dump/src/os_unix.rs index 3b37e74..e1e516b 100644 --- a/zluda_dump/src/os_unix.rs +++ b/zluda_dump/src/os_unix.rs @@ -4,7 +4,7 @@ use std::mem; pub(crate) const LIBCUDA_DEFAULT_PATH: &'static str = b"/usr/lib/x86_64-linux-gnu/libcuda.so.1\0"; -pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void { +pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { let libcuda_path = CString::new(libcuda_path).unwrap(); libc::dlopen( libcuda_path.as_ptr() as *const _, diff --git a/zluda_dump/src/os_win.rs b/zluda_dump/src/os_win.rs index c138cc0..ef3da44 100644 --- a/zluda_dump/src/os_win.rs +++ b/zluda_dump/src/os_win.rs @@ -73,7 +73,7 @@ impl PlatformLibrary { } } -pub unsafe fn load_cuda_library(libcuda_path: &str) -> *mut c_void { +pub unsafe fn load_library(libcuda_path: &str) -> *mut c_void { let libcuda_path_uf16 = libcuda_path .encode_utf16() .chain(std::iter::once(0)) diff --git a/zluda_dump/src/side_by_side.rs b/zluda_dump/src/side_by_side.rs new file mode 100644 index 0000000..33954b8 --- /dev/null +++ b/zluda_dump/src/side_by_side.rs @@ -0,0 +1,77 @@ +use cuda_base::cuda_function_declarations; +use std::ffi::CStr; +use std::mem; +use std::ptr; +use std::ptr::NonNull; +use std::{marker::PhantomData, os::raw::c_void}; + +use crate::os; + +struct DynamicFn { + pointer: usize, + _marker: PhantomData, +} + +impl Default for DynamicFn { + fn default() -> Self { + DynamicFn { + pointer: 0, + _marker: PhantomData, + } + } +} + +impl DynamicFn { + unsafe fn get(&mut self, lib: *mut c_void, name: &[u8]) -> Option { + match self.pointer { + 0 => { + let addr = os::get_proc_address(lib, CStr::from_bytes_with_nul_unchecked(name)); + if addr == ptr::null_mut() { + self.pointer = 1; + return None; + } else { + self.pointer = addr as _; + } + } + 1 => return None, + _ => {} + } + Some(mem::transmute_copy(&self.pointer)) + } +} + +pub(crate) struct CudaDynamicFns { + lib_handle: NonNull<::std::ffi::c_void>, + fn_table: CudaFnTable, +} + +impl CudaDynamicFns { + pub(crate) unsafe fn load_library(path: &str) -> Option { + let lib_handle = NonNull::new(os::load_library(path)); + lib_handle.map(|lib_handle| CudaDynamicFns { + lib_handle, + fn_table: CudaFnTable::default(), + }) + } +} + +macro_rules! emit_cuda_fn_table { + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path);*) => { + #[derive(Default)] + struct CudaFnTable { + $($fn_name: DynamicFn $ret_type>),* + } + + impl CudaDynamicFns { + $( + #[allow(dead_code)] + pub(crate) fn $fn_name(&mut self, $($arg_id : $arg_type),*) -> Option<$ret_type> { + let func = unsafe { self.fn_table.$fn_name.get(self.lib_handle.as_ptr(), concat!(stringify!($fn_name), "\0").as_bytes()) }; + func.map(|f| f($($arg_id),*) ) + } + )* + } + }; +} + +cuda_function_declarations!(cuda_types, emit_cuda_fn_table, emit_cuda_fn_table, []);