From f46b756fdc8f3890a161063668fc1e26d22f7763 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 20 Sep 2025 22:45:46 -0700 Subject: [PATCH] Fix cuCtxPopCurrent (#519) --- zluda/src/impl/context.rs | 49 ++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 8933116..e6fb35e 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -188,12 +188,25 @@ pub(crate) unsafe fn push_current_v2(ctx: CUcontext) -> CUresult { push_current(ctx) } -pub(crate) unsafe fn pop_current(ctx: &mut CUcontext) -> CUresult { - STACK.with(|stack| { - if let Some((_ctx, _)) = stack.borrow_mut().pop() { - *ctx = _ctx; - } +pub(crate) unsafe fn pop_current(result: Option<&mut CUcontext>) -> CUresult { + let old_ctx_and_new_device = STACK.with(|stack| { + let mut stack = stack.borrow_mut(); + stack + .pop() + .map(|(ctx, _)| (ctx, stack.last().map(|(_, dev)| *dev))) }); + let ctx = match old_ctx_and_new_device { + Some((old_ctx, new_device)) => { + if let Some(new_device) = new_device { + hipSetDevice(new_device)?; + } + old_ctx + } + None => return CUresult::ERROR_INVALID_CONTEXT, + }; + if let Some(out) = result { + *out = ctx; + } Ok(()) } @@ -213,7 +226,7 @@ pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult { zluda_common::drop_checked::(ctx) } -pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult { +pub(crate) unsafe fn pop_current_v2(ctx: Option<&mut CUcontext>) -> CUresult { pop_current(ctx) } @@ -241,3 +254,27 @@ pub(crate) unsafe fn get_api_version( *version = 3020; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::CudaApi; + use cuda_macros::test_cuda; + use std::mem; + + #[test_cuda] + fn empty_pop_fails(api: impl CudaApi) { + api.cuInit(0); + assert_eq!( + api.cuCtxPopCurrent_v2_unchecked(&mut unsafe { mem::zeroed() }), + CUresult::ERROR_INVALID_CONTEXT + ); + } + + #[test_cuda] + fn pop_into_null_succeeds(api: impl CudaApi) { + api.cuInit(0); + api.cuCtxCreate_v2(&mut unsafe { mem::zeroed() }, 0, 0); + api.cuCtxPopCurrent_v2(ptr::null_mut()); + } +}