Avoid memset in malloc during capturing

This commit is contained in:
Andrzej Janik 2025-09-25 22:44:49 +00:00
commit a8fa8089c0

View file

@ -1,16 +1,18 @@
use std::ptr;
use crate::r#impl::{context, driver};
use cuda_types::cuda::{CUerror, CUresult, CUresultConsts};
use hip_runtime_sys::*;
use std::{mem, ptr};
use crate::r#impl::{context, driver};
pub(crate) fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult {
pub(crate) unsafe fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult {
let context = context::get_current_context()?;
unsafe { hipMalloc(ptr::from_mut(dptr).cast(), bytesize) }?;
hipMalloc(ptr::from_mut(dptr).cast(), bytesize)?;
add_allocation(dptr.0, bytesize, context)?;
let mut status = mem::zeroed();
hipStreamIsCapturing(hipStream_t(ptr::null_mut()), &mut status)?;
// TODO: parametrize for non-Geekbench
unsafe { hipMemsetD8(*dptr, 0, bytesize) }?;
if status != hipStreamCaptureStatus::hipStreamCaptureStatusNone {
hipMemsetD8(*dptr, 0, bytesize)?;
}
Ok(())
}
@ -68,6 +70,7 @@ pub(crate) unsafe fn host_alloc(
) -> CUresult {
let context = context::get_current_context()?;
hipHostMalloc(pp, bytesize, flags)?;
unsafe { hipMemsetD8(hipDeviceptr_t(*pp), 0, bytesize) }?;
add_allocation(*pp, bytesize, context)?;
Ok(())
}