Make cudart_interface_fn2 behave more in line with CUDA behavior

This commit is contained in:
Andrzej Janik 2025-09-10 01:13:20 +00:00
commit afb184e10e
2 changed files with 35 additions and 1 deletions

Binary file not shown.

View file

@ -166,7 +166,9 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
None => return CUresult::ERROR_INVALID_VALUE, None => return CUresult::ERROR_INVALID_VALUE,
}; };
device::primary_context_retain(pctx, hip_dev) let (_, cu_ctx) = device::get_primary_context(hip_dev)?;
*pctx = cu_ctx;
Ok(())
} }
unsafe extern "system" fn get_module_from_cubin_ext1( unsafe extern "system" fn get_module_from_cubin_ext1(
@ -527,6 +529,8 @@ pub(crate) unsafe fn launch_kernel_ex(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::i32;
use crate::r#impl::driver::AllocationInfo; use crate::r#impl::driver::AllocationInfo;
use crate::tests::CudaApi; use crate::tests::CudaApi;
use cuda_macros::test_cuda; use cuda_macros::test_cuda;
@ -571,4 +575,34 @@ mod tests {
} }
assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None); assert_eq!(alloc_info.get_offset_and_info(0x2000 + 8), None);
} }
#[test_cuda]
fn primary_context_is_inactive_on_init(api: impl CudaApi) {
api.cuInit(0);
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
}
#[test_cuda]
unsafe fn cudart_interface_fn2_creates_inactive_primary_ctx(api: impl CudaApi) {
api.cuInit(0);
let mut table_ptr = std::ptr::null();
api.cuGetExportTable(&mut table_ptr, &dark_api::cuda::CudartInterface::GUID);
let cuda_rt_iface = dark_api::cuda::CudartInterface::new(table_ptr);
let mut dark_ctx = std::mem::zeroed();
cuda_rt_iface
.cudart_interface_fn2(&mut dark_ctx, 0)
.unwrap();
let mut flags = u32::MAX;
let mut active = i32::MAX;
api.cuDevicePrimaryCtxGetState(0, &mut flags, &mut active);
assert_eq!(flags, 0);
assert_eq!(active, 0);
let mut primary_ctx = std::mem::zeroed();
api.cuDevicePrimaryCtxRetain(&mut primary_ctx, 0);
assert_eq!(dark_ctx.0, primary_ctx.0);
}
} }