Clean up, add cuCtxSetFlags

This commit is contained in:
Andrzej Janik 2024-03-28 16:53:52 +01:00
parent d41fbd50ff
commit 067c923408
4 changed files with 41 additions and 6 deletions

View file

@ -69,6 +69,7 @@ cuda_function_declarations!(
cuCtxGetDevice,
cuCtxGetLimit,
cuCtxSetLimit,
cuCtxSetFlags,
cuCtxGetStreamPriorityRange,
cuCtxSynchronize,
cuCtxSetCacheConfig,
@ -485,6 +486,10 @@ mod definitions {
context::set_limit(limit, value)
}
/// Implementation of `cuCtxSetFlags`: passes `flags` straight to
/// `context::set_flags` and returns that call's result unchanged.
///
/// NOTE(review): nothing in this body requires `unsafe`; the qualifier is
/// presumably kept for uniformity with the neighbouring entry points -- confirm.
pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
context::set_flags(flags)
}
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
leastPriority: *mut ::std::os::raw::c_int,
greatestPriority: *mut ::std::os::raw::c_int,

View file

@ -222,6 +222,20 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
Ok(())
}
/// Records `flags` on the context handed to the closure by `with_current`.
///
/// * Non-primary context: `flags` is written into the context's atomic
///   `flags` field with sequentially-consistent ordering.
/// * Primary context: nothing is stored, yet the call still reports success.
///
/// # Errors
///
/// Any error returned by `with_current` itself is propagated to the caller.
pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
    with_current(|ctx| {
        match ctx.variant {
            // Deliberate no-op: it may look wrong, but this is actual CUDA
            // behavior, see the primary_context.rs test.
            ContextVariant::Primary(_) => {}
            ContextVariant::NonPrimary(ref non_primary) => {
                non_primary
                    .flags
                    .store(flags, std::sync::atomic::Ordering::SeqCst);
            }
        }
        // Both variants report success; only the non-primary one stores anything.
        Ok(())
    })?
}
pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);

View file

@ -1,6 +1,6 @@
use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData};
use super::{
context, LiveCheck, ZludaObject, GLOBAL_STATE
context, LiveCheck, GLOBAL_STATE
};
use crate::r#impl::context::ContextData;
use crate::{r#impl::IntoCuda, hip_call_cuda};
@ -12,11 +12,7 @@ use paste::paste;
use std::{
mem,
os::raw::{c_char, c_uint},
ptr,
sync::{
atomic::AtomicU32,
Mutex,
}, ops::AddAssign, ffi::CString,
ptr,ffi::CString,
};
const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0";

View file

@ -18,17 +18,37 @@ unsafe fn primary_context<T: CudaDriverFns>(cuda: T) {
cuda.cuDevicePrimaryCtxSetFlags_v2(CUdevice_v1(0), 1),
CUresult::CUDA_SUCCESS
);
assert_eq!(
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
CUresult::CUDA_SUCCESS
);
assert_eq!((1, 0), (flags, active));
let mut primary_ctx = ptr::null_mut();
assert_eq!(
cuda.cuDevicePrimaryCtxRetain(&mut primary_ctx, CUdevice_v1(0)),
CUresult::CUDA_SUCCESS
);
assert_eq!(
cuda.cuCtxPushCurrent_v2(primary_ctx),
CUresult::CUDA_SUCCESS
);
assert_eq!(cuda.cuCtxSetFlags(2), CUresult::CUDA_SUCCESS);
assert_eq!(
cuda.cuCtxSetCurrent(ptr::null_mut()),
CUresult::CUDA_SUCCESS
);
assert_eq!(
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
CUresult::CUDA_SUCCESS
);
assert_eq!((1, 1), (flags, active));
assert_ne!(primary_ctx, ptr::null_mut());
let mut active_ctx = ptr::null_mut();
assert_eq!(
cuda.cuCtxGetCurrent(&mut active_ctx),
CUresult::CUDA_SUCCESS
);
assert_eq!(active_ctx, ptr::null_mut());
assert_ne!(primary_ctx, active_ctx);
assert_eq!(
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),