
Redo primary context and fix various long-standing bugs around this API
commit 7d4147c8b2
Authored by Andrzej Janik, 2024-03-28 17:12:10 +01:00; committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 288 additions and 160 deletions
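Before the per-file hunks, a minimal hedged sketch (not code from this commit; names are illustrative stand-ins for the real types in impl/context.rs) of the primary-context lifecycle the changes below implement: the primary context is a per-device singleton whose retain/release calls only move a reference count, and releasing the last reference resets its state instead of freeing the object.

// Hedged sketch of the cuDevicePrimaryCtxRetain/Release/GetState semantics
// targeted by this commit; PrimaryCtxSketch is a hypothetical stand-in for
// the real PrimaryContextData introduced below.
struct PrimaryCtxSketch {
    ref_count: u32,
    flags: u32,
}

impl PrimaryCtxSketch {
    fn retain(&mut self) {
        // First retain activates the context; later retains just bump the count.
        self.ref_count += 1;
    }

    fn release(&mut self) -> Result<(), &'static str> {
        if self.ref_count == 0 {
            return Err("CUDA_ERROR_INVALID_CONTEXT");
        }
        self.ref_count -= 1;
        if self.ref_count == 0 {
            // Back to the inactive state: flags cleared, per-context resources dropped.
            self.flags = 0;
        }
        Ok(())
    }

    fn state(&self) -> (u32, bool) {
        // cuDevicePrimaryCtxGetState reports the flags and whether the context is active.
        (self.flags, self.ref_count > 0)
    }
}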

View file

@@ -69,6 +69,7 @@ cuda_function_declarations!(
     cuCtxGetDevice,
     cuCtxGetLimit,
     cuCtxSetLimit,
+    cuCtxSetFlags,
     cuCtxGetStreamPriorityRange,
     cuCtxSynchronize,
     cuCtxSetCacheConfig,
@@ -485,6 +486,10 @@ mod definitions {
         context::set_limit(limit, value)
     }
 
+    pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
+        context::set_flags(flags)
+    }
+
     pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
         leastPriority: *mut ::std::os::raw::c_int,
         greatestPriority: *mut ::std::os::raw::c_int,

View file

@@ -7,7 +7,7 @@ use cuda_types::*;
 use hip_runtime_sys::*;
 use rustc_hash::{FxHashMap, FxHashSet};
 use std::ptr;
-use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::atomic::AtomicU32;
 use std::sync::Mutex;
 use std::{cell::RefCell, ffi::c_void};
@@ -28,57 +28,104 @@ impl ZludaObject for ContextData {
     const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
 
     fn drop_with_result(&mut self, _: bool) -> Result<(), CUresult> {
-        let mutable = self
-            .mutable
-            .get_mut()
-            .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-        fold_cuda_errors(mutable.streams.iter().copied().map(|s| {
-            unsafe { LiveCheck::drop_box_with_result(s, true)? };
-            Ok(())
-        }))
+        self.with_inner_mut(|mutable| {
+            fold_cuda_errors(
+                mutable
+                    .streams
+                    .iter()
+                    .copied()
+                    .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+            )
+        })?
     }
 }
 
 pub(crate) struct ContextData {
-    pub(crate) flags: AtomicU32,
-    is_primary: bool,
-    pub(crate) ref_count: AtomicU32,
     pub(crate) device: hipDevice_t,
-    pub(crate) mutable: Mutex<ContextDataMutable>,
+    pub(crate) variant: ContextVariant,
+}
+
+pub(crate) enum ContextVariant {
+    NonPrimary(NonPrimaryContextData),
+    Primary(Mutex<PrimaryContextData>),
+}
+
+pub(crate) struct PrimaryContextData {
+    pub(crate) ref_count: u32,
+    pub(crate) flags: u32,
+    pub(crate) mutable: ContextInnerMutable,
+}
+
+pub(crate) struct NonPrimaryContextData {
+    flags: AtomicU32,
+    mutable: Mutex<ContextInnerMutable>,
 }
 
 impl ContextData {
-    pub(crate) fn new(
-        flags: u32,
-        device: hipDevice_t,
-        is_primary: bool,
-        initial_refcount: u32,
-    ) -> Result<Self, CUresult> {
-        Ok(ContextData {
-            flags: AtomicU32::new(flags),
-            device,
-            ref_count: AtomicU32::new(initial_refcount),
-            is_primary,
-            mutable: Mutex::new(ContextDataMutable::new()),
-        })
-    }
+    pub(crate) fn new_non_primary(flags: u32, device: hipDevice_t) -> Self {
+        Self {
+            device,
+            variant: ContextVariant::NonPrimary(NonPrimaryContextData {
+                flags: AtomicU32::new(flags),
+                mutable: Mutex::new(ContextInnerMutable::new()),
+            }),
+        }
+    }
+
+    pub(crate) fn new_primary(device: hipDevice_t) -> Self {
+        Self {
+            device,
+            variant: ContextVariant::Primary(Mutex::new(PrimaryContextData {
+                ref_count: 0,
+                flags: 0,
+                mutable: ContextInnerMutable::new(),
+            })),
+        }
+    }
+
+    pub(crate) fn with_inner_mut<T>(
+        &self,
+        fn_: impl FnOnce(&mut ContextInnerMutable) -> T,
+    ) -> Result<T, CUresult> {
+        Ok(match self.variant {
+            ContextVariant::Primary(ref mutex_over_primary_ctx_data) => {
+                let mut primary_ctx_data = mutex_over_primary_ctx_data
+                    .lock()
+                    .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+                fn_(&mut primary_ctx_data.mutable)
+            }
+            ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => {
+                let mut ctx_data_mutable =
+                    mutable.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+                fn_(&mut ctx_data_mutable)
+            }
+        })
+    }
 }
 
-pub(crate) struct ContextDataMutable {
+pub(crate) struct ContextInnerMutable {
     pub(crate) streams: FxHashSet<*mut stream::Stream>,
     pub(crate) modules: FxHashSet<*mut module::Module>,
     // Field below is here to support CUDA Driver Dark API
     pub(crate) local_storage: FxHashMap<*mut c_void, LocalStorageValue>,
 }
 
-impl ContextDataMutable {
-    fn new() -> Self {
-        ContextDataMutable {
+impl ContextInnerMutable {
+    pub(crate) fn new() -> Self {
+        ContextInnerMutable {
             streams: FxHashSet::default(),
             modules: FxHashSet::default(),
             local_storage: FxHashMap::default(),
         }
     }
+
+    pub(crate) fn drop_with_result(&mut self) -> Result<(), CUresult> {
+        fold_cuda_errors(
+            self.streams
+                .iter()
+                .copied()
+                .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+        )
+    }
 }
 
 pub(crate) struct LocalStorageValue {
@@ -94,7 +141,7 @@ pub(crate) unsafe fn create(
     if pctx == ptr::null_mut() {
         return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
     }
-    let context_box = Box::new(LiveCheck::new(ContextData::new(flags, dev, false, 1)?));
+    let context_box = Box::new(LiveCheck::new(ContextData::new_non_primary(flags, dev)));
     let context_ptr = Box::into_raw(context_box);
     *pctx = context_ptr;
     push_context_stack(context_ptr)
@@ -105,7 +152,7 @@ pub(crate) unsafe fn destroy(ctx: *mut Context) -> Result<(), CUresult> {
         return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
     }
     let ctx_ref = LiveCheck::as_result(ctx)?;
-    if ctx_ref.is_primary {
+    if let ContextVariant::Primary { .. } = ctx_ref.variant {
         return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
     }
     CONTEXT_STACK.with(|stack| {
@@ -175,14 +222,25 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
     Ok(())
 }
 
+pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
+    with_current(|ctx| match ctx.variant {
+        ContextVariant::NonPrimary(ref context) => {
+            context
+                .flags
+                .store(flags, std::sync::atomic::Ordering::SeqCst);
+            Ok(())
+        }
+        // This looks stupid, but this is an actual CUDA behavior,
+        // see primary_context.rs test
+        ContextVariant::Primary(_) => Ok(()),
+    })?
+}
+
 pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
     if ctx == ptr::null_mut() {
         return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
     }
-    let ctx = LiveCheck::as_result(ctx)?;
-    if ctx.ref_count.load(Ordering::Acquire) == 0 {
-        return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
-    }
+    //let ctx = LiveCheck::as_result(ctx)?;
     //TODO: query device for properties roughly matching CUDA API version
     *version = 3020;
     Ok(())
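A hedged usage sketch of the with_inner_mut accessor added above: callers no longer lock ctx.mutable directly, and both context variants go through the same helper, which also maps a poisoned mutex to CUDA_ERROR_UNKNOWN. register_stream is a hypothetical helper, not part of the commit; the other types are the crate's own.

fn register_stream(
    ctx: &ContextData,
    stream_ptr: *mut stream::Stream,
) -> Result<(), CUresult> {
    // with_inner_mut picks the right lock for Primary vs NonPrimary contexts.
    ctx.with_inner_mut(|inner| {
        inner.streams.insert(stream_ptr);
    })
}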

View file

@@ -121,11 +121,16 @@ impl CudaDarkApi for CudaDarkApiZluda {
         value: *mut c_void,
         dtor_callback: Option<extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void)>,
     ) -> CUresult {
-        with_context_or_current(cu_ctx, |ctx| {
-            let mut ctx_mutable = ctx
-                .mutable
-                .lock()
-                .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-            ctx_mutable.local_storage.insert(
-                key,
-                LocalStorageValue {
+        unsafe fn context_local_storage_insert_impl(
+            cu_ctx: cuda_types::CUcontext,
+            key: *mut c_void,
+            value: *mut c_void,
+            dtor_callback: Option<
+                extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void),
+            >,
+        ) -> Result<(), CUresult> {
+            with_context_or_current(cu_ctx, |ctx| {
+                ctx.with_inner_mut(|ctx_mutable| {
+                    ctx_mutable.local_storage.insert(
+                        key,
+                        LocalStorageValue {
@@ -133,8 +138,10 @@ impl CudaDarkApi for CudaDarkApiZluda {
                             _dtor_callback: dtor_callback,
                         },
                     );
-            Ok(())
-        })
+                })
+            })?
+        }
+        context_local_storage_insert_impl(cu_ctx, key, value, dtor_callback).into_cuda()
     }
 
     // TODO
@@ -143,29 +150,30 @@ impl CudaDarkApi for CudaDarkApiZluda {
     }
 
     unsafe extern "system" fn context_local_storage_get(
-        result: *mut *mut c_void,
+        cu_result: *mut *mut c_void,
         cu_ctx: cuda_types::CUcontext,
         key: *mut c_void,
     ) -> CUresult {
-        let mut cu_result = None;
-        let query_cu_result = with_context_or_current(cu_ctx, |ctx| {
-            let ctx_mutable = ctx
-                .mutable
-                .lock()
-                .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-            cu_result = ctx_mutable.local_storage.get(&key).map(|v| v.value);
-            Ok(())
-        });
-        if query_cu_result != CUresult::CUDA_SUCCESS {
-            query_cu_result
-        } else {
-            match cu_result {
-                Some(value) => {
-                    *result = value;
-                    CUresult::CUDA_SUCCESS
-                }
-                None => CUresult::CUDA_ERROR_INVALID_VALUE,
-            }
-        }
+        unsafe fn context_local_storage_get_impl(
+            cu_ctx: cuda_types::CUcontext,
+            key: *mut c_void,
+        ) -> Result<*mut c_void, CUresult> {
+            with_context_or_current(cu_ctx, |ctx| {
+                ctx.with_inner_mut(|ctx_mutable| {
+                    ctx_mutable
+                        .local_storage
+                        .get(&key)
+                        .map(|v| v.value)
+                        .ok_or(CUresult::CUDA_ERROR_INVALID_VALUE)
+                })?
+            })?
+        }
+        match context_local_storage_get_impl(cu_ctx, key) {
+            Ok(result) => {
+                *cu_result = result;
+                CUresult::CUDA_SUCCESS
+            }
+            Err(err) => err,
+        }
     }
@@ -386,14 +394,14 @@
     }
 }
 
-unsafe fn with_context_or_current(
+unsafe fn with_context_or_current<T>(
     ctx: CUcontext,
-    f: impl FnOnce(&context::ContextData) -> Result<(), CUresult>,
-) -> CUresult {
+    fn_: impl FnOnce(&context::ContextData) -> T,
+) -> Result<T, CUresult> {
     if ctx == ptr::null_mut() {
-        context::with_current(|c| f(c)).into_cuda()
+        context::with_current(|c| fn_(c))
     } else {
         let ctx = FromCuda::from_cuda(ctx);
-        LiveCheck::as_result(ctx).map(f).into_cuda()
+        Ok(fn_(LiveCheck::as_result(ctx)?))
     }
 }
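The dark-API changes above all follow the same shape: the extern "system" entry point stays a thin shim, the fallible logic moves into an inner function returning Result, and errors are flattened to a CUresult only at the FFI boundary. A hedged sketch of that pattern with a hypothetical entry point name, assuming the imports already present in this file (c_void, ptr, CUresult, IntoCuda):

unsafe extern "system" fn example_entry(key: *mut c_void) -> CUresult {
    // All `?`-based error handling lives in the inner function.
    unsafe fn example_entry_impl(key: *mut c_void) -> Result<(), CUresult> {
        if key == ptr::null_mut() {
            return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
        }
        Ok(())
    }
    // The Result is converted once, at the boundary, via the crate's IntoCuda helper.
    example_entry_impl(key).into_cuda()
}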

View file

@@ -1,6 +1,8 @@
+use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData};
 use super::{
-    context, LiveCheck, GLOBAL_STATE,
+    context, LiveCheck, GLOBAL_STATE
 };
+use crate::r#impl::context::ContextData;
 use crate::{r#impl::IntoCuda, hip_call_cuda};
 use crate::hip_call;
 use cuda_types::{CUdevice_attribute, CUdevprop, CUuuid_st, CUresult};
@@ -10,11 +12,7 @@ use paste::paste;
 use std::{
     mem,
     os::raw::{c_char, c_uint},
-    ptr,
-    sync::{
-        atomic::AtomicU32,
-        Mutex,
-    }, ops::AddAssign, ffi::CString,
+    ptr, ffi::CString,
 };
 
 const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0";
@@ -28,9 +26,7 @@ pub const COMPUTE_CAPABILITY_MINOR: u32 = 8;
 pub(crate) struct Device {
     pub(crate) compilation_mode: CompilationMode,
     pub(crate) comgr_isa: CString,
-    // Primary context is lazy-initialized, the mutex is here to secure retain
-    // from multiple threads
-    primary_context: Mutex<Option<context::Context>>,
+    primary_context: context::Context,
 }
 
 impl Device {
@@ -48,7 +44,7 @@ impl Device {
         Ok(Self {
             compilation_mode,
             comgr_isa,
-            primary_context: Mutex::new(None),
+            primary_context: LiveCheck::new(ContextData::new_primary(index as i32)),
         })
     }
 }
@@ -516,38 +512,29 @@ unsafe fn primary_ctx_get_or_retain(
     if pctx == ptr::null_mut() {
         return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
     }
-    let ctx = primary_ctx(hip_dev, |ctx| {
-        let ctx = match ctx {
-            Some(ref mut ctx) => ctx,
-            None => {
-                ctx.insert(LiveCheck::new(context::ContextData::new(0, hip_dev, true, 0)?))
-            },
-        };
-        if increment_refcount {
-            ctx.as_mut_unchecked().ref_count.get_mut().add_assign(1);
-        }
-        Ok(ctx as *mut _)
+    let ctx = primary_ctx(hip_dev, |ctx, raw_ctx| {
+        if increment_refcount || ctx.ref_count == 0 {
+            ctx.ref_count += 1;
+        }
+        Ok(raw_ctx.cast_mut())
     })??;
     *pctx = ctx;
     Ok(())
 }
 
 pub(crate) unsafe fn primary_ctx_release(hip_dev: hipDevice_t) -> Result<(), CUresult> {
-    primary_ctx(hip_dev, move |maybe_ctx| {
-        if let Some(ctx) = maybe_ctx {
-            let ctx_data = ctx.as_mut_unchecked();
-            let ref_count = ctx_data.ref_count.get_mut();
-            *ref_count -= 1;
-            if *ref_count == 0 {
-                //TODO: fix
-                //ctx.try_drop(false)
-                Ok(())
-            } else {
-                Ok(())
-            }
-        } else {
-            Err(CUresult::CUDA_ERROR_INVALID_CONTEXT)
-        }
+    primary_ctx(hip_dev, |ctx, _| {
+        if ctx.ref_count == 0 {
+            return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
+        }
+        ctx.ref_count -= 1;
+        if ctx.ref_count == 0 {
+            // Even if we encounter errors we can't really surface them
+            ctx.mutable.drop_with_result().ok();
+            ctx.mutable = ContextInnerMutable::new();
+            ctx.flags = 0;
+        }
+        Ok(())
     })?
 }
@@ -566,53 +553,43 @@ pub(crate) unsafe fn primary_ctx_set_flags(
     hip_dev: hipDevice_t,
     flags: ::std::os::raw::c_uint,
 ) -> Result<(), CUresult> {
-    primary_ctx(hip_dev, move |maybe_ctx| {
-        if let Some(ctx) = maybe_ctx {
-            let ctx = ctx.as_mut_unchecked();
-            ctx.flags = AtomicU32::new(flags);
-            Ok(())
-        } else {
-            Err(CUresult::CUDA_ERROR_INVALID_CONTEXT)
-        }
+    primary_ctx(hip_dev, |ctx, _| {
+        ctx.flags = flags;
+        // TODO: actually use flags
+        Ok(())
     })?
 }
 
 pub(crate) unsafe fn primary_ctx_get_state(
     hip_dev: hipDevice_t,
-    flags_ptr: *mut ::std::os::raw::c_uint,
-    active_ptr: *mut ::std::os::raw::c_int,
+    flags_ptr: *mut u32,
+    active_ptr: *mut i32,
 ) -> Result<(), CUresult> {
     if flags_ptr == ptr::null_mut() || active_ptr == ptr::null_mut() {
         return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
     }
-    let maybe_flags = primary_ctx(hip_dev, move |maybe_ctx| {
-        if let Some(ctx) = maybe_ctx {
-            let ctx = ctx.as_mut_unchecked();
-            Some(*ctx.flags.get_mut())
-        } else {
-            None
-        }
+    let (flags, active) = primary_ctx(hip_dev, |ctx, _| {
+        (ctx.flags, (ctx.ref_count > 0) as i32)
     })?;
-    if let Some(flags) = maybe_flags {
-        *flags_ptr = flags;
-        *active_ptr = 1;
-    } else {
-        *flags_ptr = 0;
-        *active_ptr = 0;
-    }
+    *flags_ptr = flags;
+    *active_ptr = active;
     Ok(())
 }
 
 pub(crate) unsafe fn primary_ctx<T>(
     dev: hipDevice_t,
-    f: impl FnOnce(&mut Option<context::Context>) -> T,
+    fn_: impl FnOnce(&mut PrimaryContextData, *const LiveCheck<ContextData>) -> T,
 ) -> Result<T, CUresult> {
     let device = GLOBAL_STATE.get()?.device(dev)?;
-    let mut maybe_primary_context = device
-        .primary_context
-        .lock()
-        .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-    Ok(f(&mut maybe_primary_context))
+    let raw_ptr = &device.primary_context as *const _;
+    let context = device.primary_context.as_ref_unchecked();
+    match context.variant {
+        ContextVariant::Primary(ref mutex_over_primary_ctx) => {
+            let mut primary_ctx = mutex_over_primary_ctx
+                .lock()
+                .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+            Ok(fn_(&mut primary_ctx, raw_ptr))
+        }
+        ContextVariant::NonPrimary(..) => Err(CUresult::CUDA_ERROR_UNKNOWN),
+    }
 }
 
 pub(crate) unsafe fn get_name(name: *mut i8, len: i32, device: i32) -> hipError_t {
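A hedged usage sketch of the reworked primary_ctx helper above: the closure now receives the locked PrimaryContextData plus a raw pointer to the enclosing LiveCheck<ContextData>, so refcount updates and returning the CUcontext handle happen under a single lock. peek_primary_refcount is a hypothetical example, not part of the commit:

unsafe fn peek_primary_refcount(dev: hipDevice_t) -> Result<u32, CUresult> {
    // The raw context pointer is unused here; primary_ctx_get_or_retain above
    // uses it to hand the CUcontext back to the caller.
    primary_ctx(dev, |primary, _raw_ctx| primary.ref_count)
}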

View file

@@ -148,6 +148,10 @@ impl<T: ZludaObject> LiveCheck<T> {
         outer_ptr as *mut Self
     }
 
+    pub unsafe fn as_ref_unchecked(&self) -> &T {
+        &self.data
+    }
+
     pub unsafe fn as_mut_unchecked(&mut self) -> &mut T {
         &mut self.data
     }

View file

@@ -31,13 +31,11 @@ impl ZludaObject for ModuleData {
         let deregistration_err = if !by_owner {
             if let Some(ctx) = self.owner {
                 let ctx = unsafe { LiveCheck::as_result(ctx.as_ptr())? };
-                let mut ctx_mutable = ctx
-                    .mutable
-                    .lock()
-                    .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-                ctx_mutable
-                    .modules
-                    .remove(&unsafe { LiveCheck::from_raw(self) });
+                ctx.with_inner_mut(|ctx_mutable| {
+                    ctx_mutable
+                        .modules
+                        .remove(&unsafe { LiveCheck::from_raw(self) });
+                })?;
             }
             Ok(())
         } else {
@@ -104,11 +102,9 @@ pub(crate) unsafe fn load_impl(
             isa,
             input,
         )?);
-        let mut ctx_mutable = ctx
-            .mutable
-            .lock()
-            .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-        ctx_mutable.modules.insert(module);
+        ctx.with_inner_mut(|ctx_mutable| {
+            ctx_mutable.modules.insert(module);
+        })?;
         *output = module;
         Ok(())
     })?

View file

@@ -21,13 +21,11 @@ impl ZludaObject for StreamData {
         if !by_owner {
             let ctx = unsafe { LiveCheck::as_result(self.ctx)? };
             {
-                let mut ctx_mutable = ctx
-                    .mutable
-                    .lock()
-                    .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-                ctx_mutable
-                    .streams
-                    .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+                ctx.with_inner_mut(|ctx_mutable| {
+                    ctx_mutable
+                        .streams
+                        .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+                })?;
             }
         }
         hip_call_cuda!(hipStreamDestroy(self.base));
@@ -59,11 +57,9 @@ pub(crate) unsafe fn create_with_priority(
         ctx: ptr::null_mut(),
     })));
     let ctx = context::with_current(|ctx| {
-        let mut ctx_mutable = ctx
-            .mutable
-            .lock()
-            .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
-        ctx_mutable.streams.insert(stream);
+        ctx.with_inner_mut(|ctx_mutable| {
+            ctx_mutable.streams.insert(stream);
+        })?;
         Ok(LiveCheck::from_raw(ctx as *const _ as _))
     })??;
     (*stream).as_mut_unchecked().ctx = ctx;

View file

@@ -0,0 +1,84 @@
use crate::common::CudaDriverFns;
use cuda_types::*;
use std::{mem, ptr};

mod common;

cuda_driver_test!(primary_context);

unsafe fn primary_context<T: CudaDriverFns>(cuda: T) {
    assert_eq!(cuda.cuInit(0), CUresult::CUDA_SUCCESS);
    let mut flags = 0;
    let mut active = 0;
    assert_eq!(
        cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!((0, 0), (flags, active));
    assert_eq!(
        cuda.cuDevicePrimaryCtxSetFlags_v2(CUdevice_v1(0), 1),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(
        cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!((1, 0), (flags, active));
    let mut primary_ctx = ptr::null_mut();
    assert_eq!(
        cuda.cuDevicePrimaryCtxRetain(&mut primary_ctx, CUdevice_v1(0)),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(
        cuda.cuCtxPushCurrent_v2(primary_ctx),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(cuda.cuCtxSetFlags(2), CUresult::CUDA_SUCCESS);
    assert_eq!(
        cuda.cuCtxSetCurrent(ptr::null_mut()),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(
        cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!((1, 1), (flags, active));
    assert_ne!(primary_ctx, ptr::null_mut());
    let mut active_ctx = ptr::null_mut();
    assert_eq!(
        cuda.cuCtxGetCurrent(&mut active_ctx),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(active_ctx, ptr::null_mut());
    assert_ne!(primary_ctx, active_ctx);
    assert_eq!(
        cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!((1, 1), (flags, active));
    let mut buffer = mem::zeroed();
    assert_eq!(
        cuda.cuCtxPushCurrent_v2(primary_ctx),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(cuda.cuMemAlloc_v2(&mut buffer, 4), CUresult::CUDA_SUCCESS);
    assert_eq!(
        cuda.cuDevicePrimaryCtxRelease_v2(CUdevice_v1(0)),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!(
        cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
        CUresult::CUDA_SUCCESS
    );
    assert_ne!(
        cuda.cuDevicePrimaryCtxRelease_v2(CUdevice_v1(0)),
        CUresult::CUDA_SUCCESS
    );
    assert_eq!((0, 0), (flags, active));
    // Already freed on context destruction
    // TODO: reenable when we start tracking allocations inside context
    //assert_ne!(cuda.cuMemFree_v2(buffer), CUresult::CUDA_SUCCESS);
    assert_eq!(
        cuda.cuDevicePrimaryCtxReset_v2(CUdevice_v1(0)),
        CUresult::CUDA_SUCCESS
    );
}