mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Clean up, add cuCtxSetFlags
This commit is contained in:
parent
d41fbd50ff
commit
067c923408
4 changed files with 41 additions and 6 deletions
|
@ -69,6 +69,7 @@ cuda_function_declarations!(
|
|||
cuCtxGetDevice,
|
||||
cuCtxGetLimit,
|
||||
cuCtxSetLimit,
|
||||
cuCtxSetFlags,
|
||||
cuCtxGetStreamPriorityRange,
|
||||
cuCtxSynchronize,
|
||||
cuCtxSetCacheConfig,
|
||||
|
@ -485,6 +486,10 @@ mod definitions {
|
|||
context::set_limit(limit, value)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
|
||||
context::set_flags(flags)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
|
||||
leastPriority: *mut ::std::os::raw::c_int,
|
||||
greatestPriority: *mut ::std::os::raw::c_int,
|
||||
|
|
|
@ -222,6 +222,20 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
|
||||
with_current(|ctx| match ctx.variant {
|
||||
ContextVariant::NonPrimary(ref context) => {
|
||||
context
|
||||
.flags
|
||||
.store(flags, std::sync::atomic::Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
// This looks stupid, but this is an actual CUDA behavior,
|
||||
// see primary_context.rs test
|
||||
ContextVariant::Primary(_) => Ok(()),
|
||||
})?
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
|
||||
if ctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData};
|
||||
use super::{
|
||||
context, LiveCheck, ZludaObject, GLOBAL_STATE
|
||||
context, LiveCheck, GLOBAL_STATE
|
||||
};
|
||||
use crate::r#impl::context::ContextData;
|
||||
use crate::{r#impl::IntoCuda, hip_call_cuda};
|
||||
|
@ -12,11 +12,7 @@ use paste::paste;
|
|||
use std::{
|
||||
mem,
|
||||
os::raw::{c_char, c_uint},
|
||||
ptr,
|
||||
sync::{
|
||||
atomic::AtomicU32,
|
||||
Mutex,
|
||||
}, ops::AddAssign, ffi::CString,
|
||||
ptr,ffi::CString,
|
||||
};
|
||||
|
||||
const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0";
|
||||
|
|
|
@ -18,17 +18,37 @@ unsafe fn primary_context<T: CudaDriverFns>(cuda: T) {
|
|||
cuda.cuDevicePrimaryCtxSetFlags_v2(CUdevice_v1(0), 1),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(
|
||||
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!((1, 0), (flags, active));
|
||||
let mut primary_ctx = ptr::null_mut();
|
||||
assert_eq!(
|
||||
cuda.cuDevicePrimaryCtxRetain(&mut primary_ctx, CUdevice_v1(0)),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(
|
||||
cuda.cuCtxPushCurrent_v2(primary_ctx),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(cuda.cuCtxSetFlags(2), CUresult::CUDA_SUCCESS);
|
||||
assert_eq!(
|
||||
cuda.cuCtxSetCurrent(ptr::null_mut()),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(
|
||||
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!((1, 1), (flags, active));
|
||||
assert_ne!(primary_ctx, ptr::null_mut());
|
||||
let mut active_ctx = ptr::null_mut();
|
||||
assert_eq!(
|
||||
cuda.cuCtxGetCurrent(&mut active_ctx),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(active_ctx, ptr::null_mut());
|
||||
assert_ne!(primary_ctx, active_ctx);
|
||||
assert_eq!(
|
||||
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
|
||||
|
|
Loading…
Add table
Reference in a new issue