Fix cuCtxPopCurrent (#519)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled
ZLUDA / Build AMD GPU unit tests (push) Has been cancelled
ZLUDA / Run AMD GPU unit tests (push) Has been cancelled

This commit is contained in:
Andrzej Janik 2025-09-20 22:45:46 -07:00 committed by GitHub
commit f46b756fdc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -188,12 +188,25 @@ pub(crate) unsafe fn push_current_v2(ctx: CUcontext) -> CUresult {
push_current(ctx) push_current(ctx)
} }
pub(crate) unsafe fn pop_current(ctx: &mut CUcontext) -> CUresult { pub(crate) unsafe fn pop_current(result: Option<&mut CUcontext>) -> CUresult {
STACK.with(|stack| { let old_ctx_and_new_device = STACK.with(|stack| {
if let Some((_ctx, _)) = stack.borrow_mut().pop() { let mut stack = stack.borrow_mut();
*ctx = _ctx; stack
} .pop()
.map(|(ctx, _)| (ctx, stack.last().map(|(_, dev)| *dev)))
}); });
let ctx = match old_ctx_and_new_device {
Some((old_ctx, new_device)) => {
if let Some(new_device) = new_device {
hipSetDevice(new_device)?;
}
old_ctx
}
None => return CUresult::ERROR_INVALID_CONTEXT,
};
if let Some(out) = result {
*out = ctx;
}
Ok(()) Ok(())
} }
@ -213,7 +226,7 @@ pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult {
zluda_common::drop_checked::<Context>(ctx) zluda_common::drop_checked::<Context>(ctx)
} }
pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult { pub(crate) unsafe fn pop_current_v2(ctx: Option<&mut CUcontext>) -> CUresult {
pop_current(ctx) pop_current(ctx)
} }
@ -241,3 +254,27 @@ pub(crate) unsafe fn get_api_version(
*version = 3020; *version = 3020;
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use std::mem;
#[test_cuda]
fn empty_pop_fails(api: impl CudaApi) {
api.cuInit(0);
assert_eq!(
api.cuCtxPopCurrent_v2_unchecked(&mut unsafe { mem::zeroed() }),
CUresult::ERROR_INVALID_CONTEXT
);
}
#[test_cuda]
fn pop_into_null_succeeds(api: impl CudaApi) {
api.cuInit(0);
api.cuCtxCreate_v2(&mut unsafe { mem::zeroed() }, 0, 0);
api.cuCtxPopCurrent_v2(ptr::null_mut());
}
}