diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index fb15d61..8fbc82a 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -19,4 +19,5 @@ version = "0.18.1" features = ["lexer"] [dev-dependencies] +level_zero-sys = { path = "../level_zero-sys" } ocl = { version = "0.19", features = ["opencl_version_1_1", "opencl_version_1_2", "opencl_version_2_1"] } diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 2984c89..022fa97 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -6,6 +6,8 @@ extern crate quick_error; extern crate bit_vec; #[cfg(test)] extern crate ocl; +#[cfg(test)] +extern crate level_zero_sys as l0; extern crate rspirv; extern crate spirv_headers as spirv; diff --git a/ptx/src/test/ops/mod.rs b/ptx/src/test/ops/mod.rs index 1ea60b8..2537cf9 100644 --- a/ptx/src/test/ops/mod.rs +++ b/ptx/src/test/ops/mod.rs @@ -64,6 +64,139 @@ fn run_spirv>( input: &[T], output: &mut [T], ) -> ocl::Result> { + let (drv, device, queue) = unsafe { l0_init() }; + let (ocl_plat, ocl_dev) = get_ocl_platform_device(); + let ocl_ctx = Context::builder() + .platform(ocl_plat) + .devices(ocl_dev) + .build()?; + let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap(); + let src = CString::new( + " + __kernel void ld_st(ulong a, ulong b) + { + __global ulong* a_copy = (__global ulong*)a; + __global ulong* b_copy = (__global ulong*)b; + *b_copy = *a_copy; + }", + )?; + let prog = Program::with_source(&ocl_ctx, &[src], None, &empty_cstr)?; + let binaries_wrapped = prog.info(ocl::core::ProgramInfo::Binaries)?; + let binaries = if let ocl::core::ProgramInfoResult::Binaries(bins) = binaries_wrapped { + bins + } else { + panic!() + }; + let module = l0_create_module(device, &binaries[0]); + let kernel_desc = l0::ze_kernel_desc_t { + version: l0::ze_kernel_desc_version_t::ZE_KERNEL_DESC_VERSION_CURRENT, + flags: l0::ze_kernel_flag_t::ZE_KERNEL_FLAG_NONE, + pKernelName: "ld_st".as_ptr() as *const _, + }; + let mut kernel: l0::ze_kernel_handle_t = ptr::null_mut(); + let mut err = unsafe { l0::zeKernelCreate(module, &kernel_desc, &mut kernel) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let inp_b = l0_allocate_buffer(drv, device, &input); + let out_b = l0_allocate_buffer(drv, device, &output); + println!("inp_b: {:?}", inp_b); + println!("out_b: {:?}", out_b); + let mut cmd_list = l0_create_cmd_list(device); + println!("input: {:?}", input); + err = unsafe { + l0::zeCommandListAppendMemoryCopy( + cmd_list, + inp_b, + input.as_ptr() as *const _, + input.len() * mem::size_of::(), + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let pattern = 0u8; + err = unsafe { + l0::zeCommandListAppendMemoryFill( + cmd_list, + out_b, + &pattern as *const u8 as *const _, + 1, + input.len() * mem::size_of::(), + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { l0::zeKernelSetGroupSize(kernel, 1, 1, 1) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let wg_size = l0::ze_group_count_t { + groupCountX: 1, + groupCountY: 1, + groupCountZ: 1, + }; + err = unsafe { + l0::zeKernelSetArgumentValue( + kernel, + 0, + mem::size_of::<*mut c_void>(), + &inp_b as *const *mut _ as *const _, + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { + l0::zeKernelSetArgumentValue( + kernel, + 1, + mem::size_of::<*mut c_void>(), + &out_b as *const *mut _ as *const _, + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { + l0::zeCommandListAppendBarrier( + cmd_list, + ptr::null_mut(), + 0, + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { + l0::zeCommandListAppendLaunchKernel( + cmd_list, + kernel, + &wg_size, + ptr::null_mut(), + 0, + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { + l0::zeCommandListAppendBarrier( + cmd_list, + ptr::null_mut(), + 0, + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let mut result: Vec = vec![0u8.into(); output.len()]; + err = unsafe { + l0::zeCommandListAppendMemoryCopy( + cmd_list, + result.as_mut_ptr() as *mut _, + out_b, + result.len() * mem::size_of::(), + ptr::null_mut(), + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { l0::zeCommandListClose(cmd_list) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = + unsafe { l0::zeCommandQueueExecuteCommandLists(queue, 1, &mut cmd_list, ptr::null_mut()) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + err = unsafe { l0::zeCommandQueueSynchronize(queue, u32::max_value()) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + /* let (plat, dev) = get_ocl_platform_device(); let ctx = Context::builder().platform(plat).devices(dev).build()?; let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap(); @@ -160,6 +293,7 @@ fn run_spirv>( ); assert_eq!(err_code, 0); queue.finish()?; + */ Ok(result) } @@ -278,3 +412,82 @@ fn get_cl_set_kernel_arg_mem_pointer_intel( }?; Ok(unsafe { std::mem::transmute(ptr) }) } + +unsafe fn l0_init() -> ( + l0::ze_driver_handle_t, + l0::ze_device_handle_t, + l0::ze_command_queue_handle_t, +) { + let mut err = l0::ze_result_t::ZE_RESULT_SUCCESS; + err = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_NONE); + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let mut len = 1; + let mut driver: l0::ze_driver_handle_t = ptr::null_mut(); + err = l0::zeDriverGet(&mut len, &mut driver); + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let mut device: l0::ze_device_handle_t = ptr::null_mut(); + err = l0::zeDeviceGet(driver, &mut len, &mut device); + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + let que_desc = l0::ze_command_queue_desc_t { + version: l0::ze_command_queue_desc_version_t::ZE_COMMAND_QUEUE_DESC_VERSION_CURRENT, + flags: l0::ze_command_queue_flag_t::ZE_COMMAND_QUEUE_FLAG_NONE, + mode: l0::ze_command_queue_mode_t::ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, + priority: l0::ze_command_queue_priority_t::ZE_COMMAND_QUEUE_PRIORITY_NORMAL, + ordinal: 0, + }; + let mut queue: l0::ze_command_queue_handle_t = ptr::null_mut(); + err = l0::zeCommandQueueCreate(device, &que_desc, &mut queue); + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + (driver, device, queue) +} + +fn l0_create_module(dev: l0::ze_device_handle_t, bin: &[u8]) -> l0::ze_module_handle_t { + let desc = l0::ze_module_desc_t { + version: l0::ze_module_desc_version_t::ZE_MODULE_DESC_VERSION_CURRENT, + format: l0::ze_module_format_t::ZE_MODULE_FORMAT_NATIVE, + inputSize: bin.len(), + pInputModule: bin.as_ptr(), + pBuildFlags: ptr::null(), + pConstants: ptr::null(), + }; + let mut result: l0::ze_module_handle_t = ptr::null_mut(); + let err = unsafe { l0::zeModuleCreate(dev, &desc, &mut result, ptr::null_mut()) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + result +} + +fn l0_allocate_buffer( + drv: l0::ze_driver_handle_t, + dev: l0::ze_device_handle_t, + based: &[T], +) -> *mut c_void { + let desc = l0::_ze_device_mem_alloc_desc_t { + version: l0::ze_device_mem_alloc_desc_version_t::ZE_DEVICE_MEM_ALLOC_DESC_VERSION_CURRENT, + flags: l0::_ze_device_mem_alloc_flag_t::ZE_DEVICE_MEM_ALLOC_FLAG_DEFAULT, + ordinal: 0, + }; + let mut result = ptr::null_mut(); + let err = unsafe { + l0::zeDriverAllocDeviceMem( + drv, + &desc, + based.len() * mem::size_of::(), + mem::align_of::(), + dev, + &mut result, + ) + }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + result +} + +fn l0_create_cmd_list(dev: l0::ze_device_handle_t) -> l0::ze_command_list_handle_t { + let desc = l0::_ze_command_list_desc_t { + version: l0::ze_command_list_desc_version_t::ZE_COMMAND_LIST_DESC_VERSION_CURRENT, + flags: l0::ze_command_list_flag_t::ZE_COMMAND_LIST_FLAG_EXPLICIT_ONLY, + }; + let mut result: l0::ze_command_list_handle_t = ptr::null_mut(); + let err = unsafe { l0::zeCommandListCreate(dev, &desc, &mut result) }; + assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS); + result +}