Add cuCtx* functions (cuCtxGetLimit, cuCtxSetLimit)

This commit is contained in:
Andrzej Janik 2024-11-20 02:11:22 +00:00
parent 94e8e13425
commit 122676bb13
6 changed files with 37 additions and 31 deletions

View file

@ -1 +0,0 @@
bindgen build/wrapper.h -o src/cuda.rs --no-partialeq "CUDA_HOST_NODE_PARAMS_st" --with-derive-eq --allowlist-type="^CU.*" --allowlist-function="^cu.*" --allowlist-var="^CU.*" --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug --new-type-alias "^CUdevice_v\d+$|^CUdeviceptr_v\d+$" --must-use-type "cudaError_enum" --constified-enum "cudaError_enum" -- -I/usr/local/cuda/include

View file

@ -162,7 +162,13 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.ident
.to_string();
let known_modules = [
"context", "device", "function", "link", "memory", "module", "pointer",
("ctx", "context"),
("device", "device"),
("function", "function"),
("link", "link"),
("memory", "memory"),
("module", "module"),
("pointer", "pointer"),
];
let segments: Vec<String> = split(&fn_[2..]);
let fn_path = join(segments, &known_modules);
@ -184,10 +190,13 @@ fn split(fn_: &str) -> Vec<String> {
result
}
fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
fn join(fn_: Vec<String>, known_modules: &[(&str, &str)]) -> Punctuated<Ident, Token![::]> {
let (prefix, suffix) = fn_.split_at(1);
if known_modules.contains(&&*prefix[0]) {
[&prefix[0], &suffix.join("_")]
if let Some((_, mod_name)) = known_modules
.iter()
.find(|(mod_prefix, _)| mod_prefix == &prefix[0])
{
[*mod_name, &suffix.join("_")]
.into_iter()
.map(|seg| Ident::new(seg, Span::call_site()))
.collect()

View file

@ -1,3 +0,0 @@
bindgen /usr/local/cuda/include/cuda.h -o cuda.rs --whitelist-function="^cu.*" --size_t-is-usize --default-enum-style=newtype --no-layout-tests --no-doc-comments --no-derive-debug --new-type-alias "^CUdevice$|^CUdeviceptr$"
sed -i -e 's/extern "C" {//g' -e 's/-> CUresult;/-> CUresult { impl_::unsupported()/g' -e 's/pub fn /#[no_mangle] pub extern "system" fn /g' cuda.rs
rustfmt cuda.rs

View file

@ -1,24 +1,9 @@
use std::ptr;
use hip_runtime_sys::*;
use crate::cuda::CUlimit;
use crate::cuda::CUresult;
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: CUlimit) -> CUresult {
if pvalue == ptr::null_mut() {
return CUresult::CUDA_ERROR_INVALID_VALUE;
}
if limit == CUlimit::CU_LIMIT_STACK_SIZE {
*pvalue = 512; // GTX 1060 reports 1024
CUresult::CUDA_SUCCESS
} else {
CUresult::CUDA_ERROR_NOT_SUPPORTED
}
/// Forwards `pvalue` and `limit` to `hipDeviceGetLimit` and returns that
/// call's `hipError_t` unchanged.
///
/// # Safety
///
/// `pvalue` is handed to HIP without any check in this function.
/// NOTE(review): presumably it must be valid for a write of one `usize`
/// (and non-null) — confirm against the HIP API documentation.
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
// Raw FFI call; pointer validity is delegated to the caller of this `unsafe fn`.
unsafe { hipDeviceGetLimit(pvalue, limit) }
}
pub(crate) fn set_limit(limit: CUlimit, value: usize) -> CUresult {
if limit == CUlimit::CU_LIMIT_STACK_SIZE {
CUresult::CUDA_SUCCESS
} else {
CUresult::CUDA_ERROR_NOT_SUPPORTED
}
/// Forwards `limit` and `value` to `hipDeviceSetLimit` and returns that
/// call's `hipError_t` unchanged.
pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
// NOTE(review): both arguments are passed by value, so no pointer
// preconditions are visible here — confirm `hipDeviceSetLimit` has no other
// requirements that would make this safe wrapper unsound.
unsafe { hipDeviceSetLimit(limit, value) }
}

View file

@ -1,6 +1,7 @@
use cuda_types::*;
use hip_runtime_sys::*;
pub(super) mod context;
pub(super) mod device;
#[cfg(debug_assertions)]
@ -17,7 +18,7 @@ pub(crate) trait FromCuda<'a, T>: Sized {
fn from_cuda(t: &'a T) -> Result<Self, CUerror>;
}
macro_rules! from_cuda_noop {
macro_rules! from_cuda_nop {
($($type_:ty),*) => {
$(
impl<'a> FromCuda<'a, $type_> for $type_ {
@ -65,18 +66,31 @@ macro_rules! from_cuda_transmute {
};
}
from_cuda_noop!(
from_cuda_nop!(
*mut i8,
*mut usize,
i32,
u32,
cuda_types::CUdevprop, CUdevice_attribute
usize,
cuda_types::CUdevprop,
CUdevice_attribute
);
// Declares CUDA => HIP type pairs through `from_cuda_transmute!`.
// NOTE(review): the macro body is not visible in this hunk; presumably it
// implements `FromCuda` for each pair by transmuting, which would require the
// paired types to share size and layout — confirm against the macro definition.
from_cuda_transmute!(
CUdevice => hipDevice_t,
CUuuid => hipUUID
);
/// Converts a CUDA limit selector into the matching HIP limit selector.
///
/// Three selectors are translated: stack size, printf FIFO size and malloc
/// heap size. Any other `CUlimit` value is rejected with
/// `CUerror::NOT_SUPPORTED`.
impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
    fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
        match *limit {
            CUlimit::CU_LIMIT_STACK_SIZE => Ok(hipLimit_t::hipLimitStackSize),
            CUlimit::CU_LIMIT_PRINTF_FIFO_SIZE => Ok(hipLimit_t::hipLimitPrintfFifoSize),
            CUlimit::CU_LIMIT_MALLOC_HEAP_SIZE => Ok(hipLimit_t::hipLimitMallocHeapSize),
            // Every remaining selector has no mapping here.
            _ => Err(CUerror::NOT_SUPPORTED),
        }
    }
}
/// Forwards `flags` unchanged to `hipInit` and returns that call's
/// `hipError_t` as-is.
pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
// NOTE(review): `flags` is a plain integer, so no pointer preconditions are
// visible here — confirm `hipInit` imposes no other requirements.
unsafe { hipInit(flags) }
}

View file

@ -32,6 +32,8 @@ use cuda_base::cuda_function_declarations;
cuda_function_declarations!(
unimplemented,
implemented <= [
cuCtxGetLimit,
cuCtxSetLimit,
cuDeviceComputeCapability,
cuDeviceGet,
cuDeviceGetAttribute,