Start macroizing host code

This commit is contained in:
Andrzej Janik 2024-11-19 16:40:28 +00:00
parent 79362757ed
commit 6c2a8576c2
7 changed files with 109 additions and 226 deletions

View file

@ -1,6 +1,7 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{quote, ToTokens};
use rustc_hash::FxHashMap;
use std::iter;
@ -148,3 +149,49 @@ impl VisitMut for FixFnSignatures {
s.inputs.pop_punct();
}
}
/// Rewrites the last segment of a path from a CUDA-style name into the
/// normalized name used on the implementation side.
///
/// The final segment is removed from the path, its first two characters are
/// dropped, the remainder is cut into lowercase words at every ASCII uppercase
/// letter (`split`), and the words are reassembled by `join`. The result is
/// emitted directly after what is left of the original path, e.g.
/// `crate::r#impl::cuInit` expands to `crate::r#impl::init`.
///
/// # Panics
///
/// Panics (surfacing as a compile error at the call site) if the path has no
/// segment to pop or if the final identifier is shorter than two bytes.
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
    // First words that `join` turns into a module path segment of their own.
    const KNOWN_MODULES: [&str; 7] = [
        "context", "device", "function", "link", "memory", "module", "pointer",
    ];

    let mut path = parse_macro_input!(tokens as syn::Path);

    // Detach the trailing segment; only its identifier text is needed.
    let last_segment = path.segments.pop().unwrap().into_tuple().0;
    let cuda_name = last_segment.ident.to_string();

    // Skip the two leading characters before splitting into words.
    let words: Vec<String> = split(&cuda_name[2..]);
    let normalized = join(words, &KNOWN_MODULES);

    let expanded = quote! {
        #path #normalized
    };
    expanded.into()
}
/// Splits a CamelCase name into lowercase words.
///
/// Every ASCII uppercase letter starts a new word and is stored lowercased;
/// any other character is appended to the word currently being built.
///
/// Characters that precede the first uppercase letter form a word of their
/// own, so `"init"` yields `["init"]` instead of panicking on a missing
/// current word. An empty input yields an empty vector.
fn split(fn_: &str) -> Vec<String> {
    let mut words: Vec<String> = Vec::new();
    for c in fn_.chars() {
        if c.is_ascii_uppercase() || words.is_empty() {
            // Word boundary, or the very first character of the input.
            words.push(c.to_ascii_lowercase().to_string());
        } else if let Some(current) = words.last_mut() {
            current.push(c);
        }
    }
    words
}
/// Builds the normalized path for a function name that was split into words.
///
/// * If the first word is listed in `known_modules` and at least one more word
///   follows, the result has two segments: the module word and the remaining
///   words joined with `_` (for example `["module", "get", "function"]`
///   becomes `module::get_function`).
/// * Otherwise all words are joined with `_` into a single identifier. This
///   includes a lone known-module word, which would otherwise require an empty
///   identifier for the second segment.
///
/// # Panics
///
/// Panics if `fn_` is empty, because no identifier can be formed from it.
fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
    let (module, rest) = fn_
        .split_first()
        .expect("cuda_normalize_fn: function name contains no words to join");
    if !rest.is_empty() && known_modules.contains(&module.as_str()) {
        let leaf = rest.join("_");
        [module.as_str(), leaf.as_str()]
            .iter()
            .map(|seg| Ident::new(seg, Span::call_site()))
            .collect()
    } else {
        iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect()
    }
}

View file

@ -6,3 +6,4 @@ edition = "2018"
[dependencies]
cuda_base = { path = "../cuda_base" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }

View file

@ -8083,3 +8083,8 @@ pub type CUresult = ::core::result::Result<(), CUerror>;
// Compile-time size guard: naming `transmute::<CUresult, u32>` is only accepted
// by the compiler when both types have the same size, so a layout change to
// `CUresult` is expected to break the build here. The closure is bound to `_`
// and never called; nothing is transmuted at runtime.
const _: fn() = || {
let _ = std::mem::transmute::<CUresult, u32>;
};
impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
Self(error.0)
}
}

View file

@ -11,6 +11,7 @@ crate-type = ["cdylib"]
[dependencies]
ptx = { path = "../ptx" }
cuda_types = { path = "../cuda_types" }
cuda_base = { path = "../cuda_base" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
lazy_static = "1.4"
num_enum = "0.4"

View file

@ -1,230 +1,26 @@
use hip_runtime_sys::hipError_t;
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
use std::{
ffi::c_void,
mem::{self, ManuallyDrop},
os::raw::c_int,
ptr,
sync::Mutex,
sync::TryLockError,
};
#[cfg(test)]
#[macro_use]
pub mod test;
pub mod device;
pub mod export_table;
pub mod function;
#[cfg_attr(windows, path = "os_win.rs")]
#[cfg_attr(not(windows), path = "os_unix.rs")]
pub(crate) mod os;
pub(crate) mod module;
pub(crate) mod context;
pub(crate) mod memory;
pub(crate) mod link;
pub(crate) mod pointer;
use cuda_types::*;
use hip_runtime_sys::*;
#[cfg(debug_assertions)]
pub fn unimplemented() -> CUresult {
pub(crate) fn unimplemented() -> CUresult {
unimplemented!()
}
#[cfg(not(debug_assertions))]
pub fn unimplemented() -> CUresult {
CUresult::CUDA_ERROR_NOT_SUPPORTED
pub(crate) fn unimplemented() -> CUresult {
CUresult::ERROR_NOT_SUPPORTED
}
#[macro_export]
macro_rules! hip_call {
($expr:expr) => {
#[allow(unused_unsafe)]
{
let err = unsafe { $expr };
if err != hip_runtime_sys::hipError_t::hipSuccess {
return Result::Err(err);
}
}
};
pub(crate) trait FromCuda<T>: Sized {
fn from_cuda(t: T) -> Result<Self, CUerror>;
}
pub trait HasLivenessCookie: Sized {
const COOKIE: usize;
const LIVENESS_FAIL: CUresult;
fn try_drop(&mut self) -> Result<(), CUresult>;
}
// This struct is a best-effort check if wrapped value has been dropped,
// while it's inherently safe, its use coming from FFI is very unsafe
#[repr(C)]
pub struct LiveCheck<T: HasLivenessCookie> {
cookie: usize,
data: ManuallyDrop<T>,
}
impl<T: HasLivenessCookie> LiveCheck<T> {
pub fn new(data: T) -> Self {
LiveCheck {
cookie: T::COOKIE,
data: ManuallyDrop::new(data),
}
}
fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
ctx_box.try_drop()?;
unsafe { ManuallyDrop::drop(&mut ctx_box) };
Ok(())
}
unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
outer_ptr as *mut Self
}
pub unsafe fn as_ref_unchecked(&self) -> &T {
&self.data
}
pub fn as_option_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE {
Some(&mut self.data)
} else {
None
}
}
pub fn as_result(&self) -> Result<&T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
if self.cookie == T::COOKIE {
Ok(&mut self.data)
} else {
Err(T::LIVENESS_FAIL)
}
}
#[must_use]
pub fn try_drop(&mut self) -> Result<(), CUresult> {
if self.cookie == T::COOKIE {
self.cookie = 0;
self.data.try_drop()?;
unsafe { ManuallyDrop::drop(&mut self.data) };
return Ok(());
}
Err(T::LIVENESS_FAIL)
impl FromCuda<u32> for u32 {
fn from_cuda(x: u32) -> Result<Self, CUerror> {
Ok(x)
}
}
impl<T: HasLivenessCookie> Drop for LiveCheck<T> {
fn drop(&mut self) {
self.cookie = 0;
}
}
pub trait CudaRepr: Sized {
type Impl: Sized;
}
impl<T: CudaRepr> CudaRepr for *mut T {
type Impl = *mut T::Impl;
}
pub trait Decuda<To> {
fn decuda(self: Self) -> To;
}
impl<T: CudaRepr> Decuda<*mut T::Impl> for *mut T {
fn decuda(self: Self) -> *mut T::Impl {
self as *mut _
}
}
impl<T> From<TryLockError<T>> for CUresult {
fn from(_: TryLockError<T>) -> Self {
CUresult::CUDA_ERROR_ILLEGAL_STATE
}
}
impl From<ocl_core::Error> for CUresult {
fn from(result: ocl_core::Error) -> Self {
match result {
_ => CUresult::CUDA_ERROR_UNKNOWN,
}
}
}
impl From<hip_runtime_sys::hipError_t> for CUresult {
fn from(result: hip_runtime_sys::hipError_t) -> Self {
match result {
hip_runtime_sys::hipError_t::hipErrorRuntimeMemory
| hip_runtime_sys::hipError_t::hipErrorRuntimeOther => CUresult::CUDA_ERROR_UNKNOWN,
hip_runtime_sys::hipError_t(e) => CUresult(e),
}
}
}
pub trait Encuda {
type To: Sized;
fn encuda(self: Self) -> Self::To;
}
impl Encuda for CUresult {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
self
}
}
impl Encuda for () {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
CUresult::CUDA_SUCCESS
}
}
impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1, T2> {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
match self {
Ok(e) => e.encuda(),
Err(e) => e.encuda(),
}
}
}
impl Encuda for hipError_t {
type To = CUresult;
fn encuda(self: Self) -> Self::To {
self.into()
}
}
unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
mem::transmute(t)
}
unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t)
}
pub fn driver_get_version() -> c_int {
i32::max_value()
}
impl<'a> CudaRepr for CUdeviceptr {
type Impl = *mut c_void;
}
impl Decuda<*mut c_void> for CUdeviceptr {
fn decuda(self) -> *mut c_void {
self.0 as *mut _
}
pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
unsafe { hipInit(flags) }
}

View file

@ -1,11 +1,37 @@
extern crate lazy_static;
#[cfg(test)]
extern crate cuda_driver_sys;
#[cfg(test)]
extern crate paste;
extern crate ptx;
#[allow(warnings)]
pub mod cuda;
mod cuda_impl;
pub(crate) mod r#impl;
// Generates one exported stub per declaration of the form
// `"abi" fn name(arg: Type, ...) -> Ret;`. Every stub ignores its arguments
// and forwards to `crate::r#impl::unimplemented()`.
macro_rules! unimplemented {
    ($($conv:literal fn $name:ident( $($param:ident : $param_ty:ty),* ) -> $ret:path;)*) => {
        $(
            // Keep the unmangled symbol name except in test builds.
            #[cfg_attr(not(test), no_mangle)]
            #[allow(improper_ctypes, improper_ctypes_definitions)]
            pub unsafe extern $conv fn $name ( $( $param : $param_ty ),* ) -> $ret {
                crate::r#impl::unimplemented()
            }
        )*
    };
}
// Generates one exported wrapper per declaration of the form
// `"abi" fn name(arg: Type, ...) -> Ret;`. Each wrapper converts every
// argument with `FromCuda::from_cuda`, calls the function that
// `cuda_normalize_fn!` derives from the CUDA name inside `crate::r#impl`,
// propagates any error with `?`, and returns `Ok(())` on success.
macro_rules! implemented {
    ($($conv:literal fn $name:ident( $($param:ident : $param_ty:ty),* ) -> $ret:path;)*) => {
        $(
            // Keep the unmangled symbol name except in test builds.
            #[cfg_attr(not(test), no_mangle)]
            #[allow(improper_ctypes, improper_ctypes_definitions)]
            pub unsafe extern $conv fn $name ( $( $param : $param_ty ),* ) -> $ret {
                cuda_base::cuda_normalize_fn!( crate::r#impl::$name ) (
                    $( crate::r#impl::FromCuda::from_cuda($param)? ),*
                )?;
                Ok(())
            }
        )*
    };
}
use cuda_base::cuda_function_declarations;
// Emits the exported CUDA entry points through the two local macros named here.
// NOTE(review): presumably the functions listed after `implemented <=` are
// passed to `implemented!` and every other declaration to `unimplemented!` —
// confirm against the `cuda_function_declarations` definition in `cuda_base`.
cuda_function_declarations!(
unimplemented,
implemented <= [
cuInit
]
);

View file

@ -183,6 +183,7 @@ impl ConvertIntoRustResult {
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct #new_error_type(pub ::core::num::NonZeroU32);
pub trait #type_trait {
#(#result_variants)*
}
@ -192,6 +193,12 @@ impl ConvertIntoRustResult {
const _: fn() = || {
let _ = std::mem::transmute::<#type_, u32>;
};
impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type {
fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
Self(error.0)
}
}
};
items.extend(extra_items);
}