diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index 4b333d6..2ffbc25 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -1,5 +1,5 @@ use super::LiveCheck; -use crate::r#impl::context; +use crate::r#impl::{context, device}; use comgr::Comgr; use cuda_types::cuda::*; use hip_runtime_sys::*; @@ -84,7 +84,9 @@ struct Fn2Buffer { impl Fn2Buffer { const fn new() -> Self { - Fn2Buffer { buffer: std::cell::UnsafeCell::new([0; FN2_BUFFER_SIZE]) } + Fn2Buffer { + buffer: std::cell::UnsafeCell::new([0; FN2_BUFFER_SIZE]), + } } } @@ -104,9 +106,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { unsafe extern "system" fn cudart_interface_fn2( pctx: *mut cuda_types::cuda::CUcontext, - dev: cuda_types::cuda::CUdevice, + hip_dev: hipDevice_t, ) -> cuda_types::cuda::CUresult { - todo!() + let pctx = match pctx.as_mut() { + Some(p) => p, + None => return CUresult::ERROR_INVALID_VALUE, + }; + + device::primary_context_retain(pctx, hip_dev) } unsafe extern "system" fn get_module_from_cubin_ext1(