diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 70395ed..7e9a1b8 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -1,16 +1,18 @@ -use std::ptr; - +use crate::r#impl::{context, driver}; use cuda_types::cuda::{CUerror, CUresult, CUresultConsts}; use hip_runtime_sys::*; +use std::{mem, ptr}; -use crate::r#impl::{context, driver}; - -pub(crate) fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult { +pub(crate) unsafe fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult { let context = context::get_current_context()?; - unsafe { hipMalloc(ptr::from_mut(dptr).cast(), bytesize) }?; + hipMalloc(ptr::from_mut(dptr).cast(), bytesize)?; add_allocation(dptr.0, bytesize, context)?; + let mut status = mem::zeroed(); + hipStreamIsCapturing(hipStream_t(ptr::null_mut()), &mut status)?; // TODO: parametrize for non-Geekbench - unsafe { hipMemsetD8(*dptr, 0, bytesize) }?; + if status != hipStreamCaptureStatus::hipStreamCaptureStatusNone { + hipMemsetD8(*dptr, 0, bytesize)?; + } Ok(()) } @@ -68,6 +70,7 @@ pub(crate) unsafe fn host_alloc( ) -> CUresult { let context = context::get_current_context()?; hipHostMalloc(pp, bytesize, flags)?; + unsafe { hipMemsetD8(hipDeviceptr_t(*pp), 0, bytesize) }?; add_allocation(*pp, bytesize, context)?; Ok(()) }