Pass correct arguments

This commit is contained in:
Andrzej Janik 2025-09-20 01:54:40 +00:00
commit 47703e6507

View file

@ -66,6 +66,7 @@ fn main() {
}); });
buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes()); buffer_param_slice.copy_from_slice(&(dev_ptr.0 as usize).to_ne_bytes());
} }
buffer
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut module = unsafe { mem::zeroed() }; let mut module = unsafe { mem::zeroed() };
@ -78,6 +79,10 @@ fn main() {
libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast()) libcuda.cuModuleGetFunction(&mut function, module, manifest.kernel_name.as_ptr().cast())
} }
.unwrap(); .unwrap();
let mut cuda_args = args
.iter_mut()
.map(|arg| arg.as_mut_ptr().cast::<std::ffi::c_void>())
.collect::<Vec<_>>();
unsafe { unsafe {
libcuda.cuLaunchKernel( libcuda.cuLaunchKernel(
function, function,
@ -89,10 +94,10 @@ fn main() {
manifest.config.block_dim.2, manifest.config.block_dim.2,
manifest.config.shared_mem_bytes, manifest.config.shared_mem_bytes,
CUstream(std::ptr::null_mut()), CUstream(std::ptr::null_mut()),
args.as_mut_ptr().cast(), cuda_args.as_mut_ptr().cast(),
std::ptr::null_mut(), std::ptr::null_mut(),
) )
} }
.unwrap(); .unwrap();
todo!(); unsafe { libcuda.cuCtxSynchronize() }.unwrap();
} }