mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-10-01 21:59:38 +00:00
Make Library objects actually lazy
This commit is contained in:
parent
6b90b5acba
commit
e8bcb1ae33
8 changed files with 210 additions and 110 deletions
|
@ -71,6 +71,7 @@ fn parse_fatbin_header<T: Sized>(ptr: &*const T) -> Result<&FatbinHeader, ParseE
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Fatbin<'a> {
|
pub struct Fatbin<'a> {
|
||||||
pub wrapper: &'a FatbincWrapper,
|
pub wrapper: &'a FatbincWrapper,
|
||||||
}
|
}
|
||||||
|
@ -83,7 +84,7 @@ impl<'a> Fatbin<'a> {
|
||||||
Ok(Fatbin { wrapper })
|
Ok(Fatbin { wrapper })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_submodules(&self) -> Result<FatbinIter<'a>, FatbinError> {
|
pub fn get_submodules(self) -> Result<FatbinIter<'a>, FatbinError> {
|
||||||
match self.wrapper.version {
|
match self.wrapper.version {
|
||||||
FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator {
|
FatbincWrapper::VERSION_V2 => Ok(FatbinIter::V2(FatbinSubmoduleIterator {
|
||||||
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
|
fatbins: self.wrapper.filename_or_fatbins as *const *const std::ffi::c_void,
|
||||||
|
@ -105,6 +106,10 @@ impl<'a> Fatbin<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for Fatbin<'static> {}
|
||||||
|
unsafe impl Sync for Fatbin<'static> {}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct FatbinSubmodule<'a> {
|
pub struct FatbinSubmodule<'a> {
|
||||||
pub header: &'a FatbinHeader, // TODO: maybe make private
|
pub header: &'a FatbinHeader, // TODO: maybe make private
|
||||||
}
|
}
|
||||||
|
@ -159,6 +164,7 @@ impl<'a> FatbinSubmoduleIterator<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct FatbinFile<'a> {
|
pub struct FatbinFile<'a> {
|
||||||
pub header: &'a FatbinFileHeader,
|
pub header: &'a FatbinFileHeader,
|
||||||
}
|
}
|
||||||
|
@ -176,14 +182,14 @@ impl<'a> FatbinFile<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_payload(&'a self) -> &'a [u8] {
|
pub unsafe fn get_payload(self) -> &'a [u8] {
|
||||||
let start = std::ptr::from_ref(self.header)
|
let start = std::ptr::from_ref(self.header)
|
||||||
.cast::<u8>()
|
.cast::<u8>()
|
||||||
.add(self.header.header_size as usize);
|
.add(self.header.header_size as usize);
|
||||||
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
std::slice::from_raw_parts(start, self.header.payload_size as usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_or_decompress_content(&'a self) -> Result<Cow<'a, [u8]>, FatbinError> {
|
pub unsafe fn get_or_decompress_content(self) -> Result<Cow<'a, [u8]>, FatbinError> {
|
||||||
let mut payload = if self
|
let mut payload = if self
|
||||||
.header
|
.header
|
||||||
.flags
|
.flags
|
||||||
|
@ -256,7 +262,7 @@ impl<'a> FatbinFileIterator<'a> {
|
||||||
|
|
||||||
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
|
const MAX_MODULE_DECOMPRESSION_BOUND: usize = 64 * 1024 * 1024;
|
||||||
|
|
||||||
pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
pub unsafe fn decompress_lz4(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||||
let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize);
|
let decompressed_size = usize::max(1024, file.header.uncompressed_payload as usize);
|
||||||
let mut decompressed_vec = vec![0u8; decompressed_size];
|
let mut decompressed_vec = vec![0u8; decompressed_size];
|
||||||
loop {
|
loop {
|
||||||
|
@ -281,7 +287,7 @@ pub unsafe fn decompress_lz4(file: &FatbinFile) -> Result<Vec<u8>, FatbinError>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn decompress_zstd(file: &FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
pub unsafe fn decompress_zstd(file: FatbinFile) -> Result<Vec<u8>, FatbinError> {
|
||||||
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
let mut result = Vec::with_capacity(file.header.uncompressed_payload as usize);
|
||||||
let payload = file.get_payload();
|
let payload = file.get_payload();
|
||||||
match zstd_safe::decompress(&mut result, payload) {
|
match zstd_safe::decompress(&mut result, payload) {
|
||||||
|
|
|
@ -154,7 +154,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_module: *mut cuda_types::cuda::CUmodule,
|
_module: *mut cuda_types::cuda::CUmodule,
|
||||||
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
_fatbinc_wrapper: *const cuda_types::dark_api::FatbincWrapper,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn cudart_interface_fn2(
|
unsafe extern "system" fn cudart_interface_fn2(
|
||||||
|
@ -176,11 +176,11 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_arg4: *mut std::ffi::c_void,
|
_arg4: *mut std::ffi::c_void,
|
||||||
_arg5: u32,
|
_arg5: u32,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
unsafe extern "system" fn cudart_interface_fn7(_arg1: usize) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_module_from_cubin_ext2(
|
unsafe extern "system" fn get_module_from_cubin_ext2(
|
||||||
|
@ -190,7 +190,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_arg4: *mut std::ffi::c_void,
|
_arg4: *mut std::ffi::c_void,
|
||||||
_arg5: u32,
|
_arg5: u32,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn get_unknown_buffer1(
|
unsafe extern "system" fn get_unknown_buffer1(
|
||||||
|
@ -276,7 +276,7 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_flags: ::std::os::raw::c_uint,
|
_flags: ::std::os::raw::c_uint,
|
||||||
_dev: cuda_types::cuda::CUdevice,
|
_dev: cuda_types::cuda::CUdevice,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn heap_alloc(
|
unsafe extern "system" fn heap_alloc(
|
||||||
|
@ -284,14 +284,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_arg2: usize,
|
_arg2: usize,
|
||||||
_arg3: usize,
|
_arg3: usize,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn heap_free(
|
unsafe extern "system" fn heap_free(
|
||||||
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
_heap_alloc_record_ptr: *const std::ffi::c_void,
|
||||||
_arg2: *mut usize,
|
_arg2: *mut usize,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn device_get_attribute_ext(
|
unsafe extern "system" fn device_get_attribute_ext(
|
||||||
|
@ -300,14 +300,14 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi {
|
||||||
_unknown: std::ffi::c_int,
|
_unknown: std::ffi::c_int,
|
||||||
_result: *mut [usize; 2],
|
_result: *mut [usize; 2],
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn device_get_something(
|
unsafe extern "system" fn device_get_something(
|
||||||
_result: *mut std::ffi::c_uchar,
|
_result: *mut std::ffi::c_uchar,
|
||||||
_dev: cuda_types::cuda::CUdevice,
|
_dev: cuda_types::cuda::CUdevice,
|
||||||
) -> cuda_types::cuda::CUresult {
|
) -> cuda_types::cuda::CUresult {
|
||||||
r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "system" fn integrity_check(
|
unsafe extern "system" fn integrity_check(
|
||||||
|
|
|
@ -1,10 +1,38 @@
|
||||||
use super::module;
|
use crate::r#impl::{context, driver, module};
|
||||||
use cuda_types::cuda::*;
|
use cuda_types::cuda::*;
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
use zluda_common::ZludaObject;
|
use std::{ffi::c_void, sync::OnceLock};
|
||||||
|
use zluda_common::{CodeLibraryRef, ZludaObject};
|
||||||
|
|
||||||
pub(crate) struct Library {
|
pub(crate) struct Library {
|
||||||
base: hipModule_t,
|
data: LibraryData,
|
||||||
|
modules: Vec<OnceLock<Result<hipModule_t, CUerror>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Library {
|
||||||
|
pub(crate) fn get_module_for_device(&self, device: usize) -> Result<hipModule_t, CUerror> {
|
||||||
|
let module_lock = self.modules.get(device).ok_or(CUerror::INVALID_DEVICE)?;
|
||||||
|
*module_lock.get_or_init(|| match self.data {
|
||||||
|
LibraryData::Lazy(lib) => module::load_hip_module(lib),
|
||||||
|
LibraryData::Eager(()) => Err(CUerror::NOT_SUPPORTED),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum LibraryData {
|
||||||
|
Lazy(CodeLibraryRef<'static>),
|
||||||
|
Eager(()),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LibraryData {
|
||||||
|
unsafe fn new(ptr: *mut c_void, static_lifetime: bool) -> Result<Self, CUerror> {
|
||||||
|
if static_lifetime {
|
||||||
|
let lib = CodeLibraryRef::try_load(ptr).map_err(|_| CUerror::INVALID_VALUE)?;
|
||||||
|
Ok(LibraryData::Lazy(lib))
|
||||||
|
} else {
|
||||||
|
Err(CUerror::NOT_SUPPORTED)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ZludaObject for Library {
|
impl ZludaObject for Library {
|
||||||
|
@ -14,34 +42,89 @@ impl ZludaObject for Library {
|
||||||
type CudaHandle = CUlibrary;
|
type CudaHandle = CUlibrary;
|
||||||
|
|
||||||
fn drop_checked(&mut self) -> CUresult {
|
fn drop_checked(&mut self) -> CUresult {
|
||||||
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice.
|
// TODO: implement unloading
|
||||||
unsafe { hipModuleUnload(self.base) }?;
|
// TODO: we will want to test that we handle `cuModuleUnload` on a module that came from a library correctly, without calling `hipModuleUnload` twice
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This implementation simply loads the code as a HIP module for now. The various JIT and library options are ignored.
|
pub(crate) unsafe fn load_data(
|
||||||
pub(crate) fn load_data(
|
result: &mut CUlibrary,
|
||||||
library: &mut CUlibrary,
|
|
||||||
code: *const ::core::ffi::c_void,
|
code: *const ::core::ffi::c_void,
|
||||||
_jit_options: Option<&mut CUjit_option>,
|
_jit_options: Option<&mut CUjit_option>,
|
||||||
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
|
_jit_options_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||||
_num_jit_options: ::core::ffi::c_uint,
|
_num_jit_options: ::core::ffi::c_uint,
|
||||||
_library_options: Option<&mut CUlibraryOption>,
|
library_options: Option<&mut CUlibraryOption>,
|
||||||
_library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||||
_num_library_options: ::core::ffi::c_uint,
|
num_library_options: ::core::ffi::c_uint,
|
||||||
) -> CUresult {
|
) -> CUresult {
|
||||||
let hip_module = module::load_hip_module(code)?;
|
let global_state = driver::global_state()?;
|
||||||
*library = Library { base: hip_module }.wrap();
|
let options =
|
||||||
|
LibraryOptions::load(library_options, library_option_values, num_library_options)?;
|
||||||
|
let library = Library {
|
||||||
|
data: LibraryData::new(code as *mut c_void, options.preserve_binary)?,
|
||||||
|
modules: vec![OnceLock::new(); global_state.devices.len()],
|
||||||
|
};
|
||||||
|
*result = library.wrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LibraryOptions {
|
||||||
|
preserve_binary: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LibraryOptions {
|
||||||
|
unsafe fn load(
|
||||||
|
library_options: Option<&mut CUlibraryOption>,
|
||||||
|
library_option_values: Option<&mut *mut ::core::ffi::c_void>,
|
||||||
|
num_library_options: ::core::ffi::c_uint,
|
||||||
|
) -> Result<Self, CUerror> {
|
||||||
|
if num_library_options == 0 {
|
||||||
|
return Ok(LibraryOptions {
|
||||||
|
preserve_binary: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (library_options, library_option_values) =
|
||||||
|
match (library_options, library_option_values) {
|
||||||
|
(Some(library_options), Some(library_option_values)) => {
|
||||||
|
let library_options =
|
||||||
|
std::slice::from_raw_parts(library_options, num_library_options as usize);
|
||||||
|
let library_option_values = std::slice::from_raw_parts(
|
||||||
|
library_option_values,
|
||||||
|
num_library_options as usize,
|
||||||
|
);
|
||||||
|
(library_options, library_option_values)
|
||||||
|
}
|
||||||
|
_ => return Err(CUerror::INVALID_VALUE),
|
||||||
|
};
|
||||||
|
let mut preserve_binary = false;
|
||||||
|
for (option, value) in library_options
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.zip(library_option_values.iter())
|
||||||
|
{
|
||||||
|
match option {
|
||||||
|
CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED => {
|
||||||
|
preserve_binary = *(value.cast::<bool>());
|
||||||
|
}
|
||||||
|
_ => return Err(CUerror::NOT_SUPPORTED),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(LibraryOptions { preserve_binary })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult {
|
pub(crate) unsafe fn unload(library: CUlibrary) -> CUresult {
|
||||||
zluda_common::drop_checked::<Library>(library)
|
zluda_common::drop_checked::<Library>(library)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult {
|
pub(crate) unsafe fn get_module(out: &mut CUmodule, library: &Library) -> CUresult {
|
||||||
*out = module::Module { base: library.base }.wrap();
|
let device = context::get_current_device()?;
|
||||||
|
// TODO: lifetime is very wrong here
|
||||||
|
let library = module::Module {
|
||||||
|
base: library.get_module_for_device(device as usize)?,
|
||||||
|
};
|
||||||
|
*out = library.wrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,8 +132,11 @@ pub(crate) unsafe fn get_kernel(
|
||||||
kernel: *mut hipFunction_t,
|
kernel: *mut hipFunction_t,
|
||||||
library: &Library,
|
library: &Library,
|
||||||
name: *const ::core::ffi::c_char,
|
name: *const ::core::ffi::c_char,
|
||||||
) -> hipError_t {
|
) -> CUresult {
|
||||||
hipModuleGetFunction(kernel, library.base, name)
|
let device = context::get_current_device()?;
|
||||||
|
let module = library.get_module_for_device(device as usize)?;
|
||||||
|
hipModuleGetFunction(kernel, module, name)?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn get_global(
|
pub(crate) unsafe fn get_global(
|
||||||
|
@ -58,16 +144,22 @@ pub(crate) unsafe fn get_global(
|
||||||
bytes: *mut usize,
|
bytes: *mut usize,
|
||||||
library: &Library,
|
library: &Library,
|
||||||
name: *const ::core::ffi::c_char,
|
name: *const ::core::ffi::c_char,
|
||||||
) -> hipError_t {
|
) -> CUresult {
|
||||||
hipModuleGetGlobal(dptr, bytes, library.base, name)
|
let device = context::get_current_device()?;
|
||||||
|
let module = library.get_module_for_device(device as usize)?;
|
||||||
|
hipModuleGetGlobal(dptr, bytes, module, name)?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::tests::CudaApi;
|
use crate::tests::CudaApi;
|
||||||
use cuda_macros::test_cuda;
|
use cuda_macros::test_cuda;
|
||||||
use cuda_types::cuda::{CUresult, CUresultConsts};
|
use cuda_types::cuda::{CUlibraryOption, CUresult, CUresultConsts};
|
||||||
use std::{ffi::CStr, mem, ptr};
|
use std::{
|
||||||
|
ffi::{c_void, CStr},
|
||||||
|
mem, ptr,
|
||||||
|
};
|
||||||
|
|
||||||
#[test_cuda]
|
#[test_cuda]
|
||||||
unsafe fn library_loads_without_context(api: impl CudaApi) {
|
unsafe fn library_loads_without_context(api: impl CudaApi) {
|
||||||
|
@ -98,9 +190,13 @@ mod tests {
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
ptr::null_mut(),
|
ptr::null_mut(),
|
||||||
0,
|
0,
|
||||||
ptr::null_mut(),
|
[CUlibraryOption::CU_LIBRARY_BINARY_IS_PRESERVED].as_mut_ptr(),
|
||||||
ptr::null_mut(),
|
[(&true as *const bool) as *mut c_void].as_mut_ptr(),
|
||||||
0,
|
1,
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
CUresult::ERROR_INVALID_CONTEXT,
|
||||||
|
api.cuLibraryGetModule_unchecked(&mut mem::zeroed(), library)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,13 +15,13 @@ pub(super) mod pointer;
|
||||||
pub(super) mod stream;
|
pub(super) mod stream;
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
pub(crate) fn unimplemented() -> CUresult {
|
pub(crate) fn unimplemented() -> CUerror {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(debug_assertions))]
|
#[cfg(not(debug_assertions))]
|
||||||
pub(crate) fn unimplemented() -> CUresult {
|
pub(crate) fn unimplemented() -> CUerror {
|
||||||
CUresult::ERROR_NOT_SUPPORTED
|
CUerror::NOT_SUPPORTED
|
||||||
}
|
}
|
||||||
|
|
||||||
from_cuda_object!(module::Module, context::Context, library::Library);
|
from_cuda_object!(module::Module, context::Context, library::Library);
|
||||||
|
|
|
@ -1,12 +1,8 @@
|
||||||
use super::driver;
|
use super::driver;
|
||||||
use cuda_types::{
|
use cuda_types::{cuda::*, dark_api::FatbinFileHeader};
|
||||||
cuda::*,
|
|
||||||
dark_api::{FatbinFileHeader, FatbincWrapper},
|
|
||||||
};
|
|
||||||
use dark_api::fatbin::Fatbin;
|
|
||||||
use hip_runtime_sys::*;
|
use hip_runtime_sys::*;
|
||||||
use std::{ffi::CStr, mem};
|
use std::{borrow::Cow, ffi::CStr, mem};
|
||||||
use zluda_common::ZludaObject;
|
use zluda_common::{CodeLibraryRef, CodeModuleRef, ZludaObject};
|
||||||
|
|
||||||
pub(crate) struct Module {
|
pub(crate) struct Module {
|
||||||
pub(crate) base: hipModule_t,
|
pub(crate) base: hipModule_t,
|
||||||
|
@ -24,47 +20,45 @@ impl ZludaObject for Module {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_ptx_from_wrapped_fatbin(image: *const ::core::ffi::c_void) -> Result<Vec<u8>, CUerror> {
|
// get_ptx takes an `image` that can be anything we support and returns a
|
||||||
let fatbin = Fatbin::new(&image).map_err(|_| CUerror::UNKNOWN)?;
|
// String containing a ptx extracted from `image`.
|
||||||
let mut submodules = fatbin.get_submodules().map_err(|_| CUerror::UNKNOWN)?;
|
fn get_ptx<'a>(image: CodeLibraryRef<'a>) -> Result<Cow<'a, str>, CUerror> {
|
||||||
|
let mut ptx_modules = Vec::new();
|
||||||
while let Some(current) = submodules.next().map_err(|_| CUerror::UNKNOWN)? {
|
unsafe {
|
||||||
let mut files = current.get_files();
|
CodeLibraryRef::iterate_modules(image, |_, module| match module {
|
||||||
while let Some(file) = unsafe { files.next().map_err(|_| CUerror::UNKNOWN)? } {
|
Ok(CodeModuleRef::Text(ptx)) => {
|
||||||
if file.header.kind == FatbinFileHeader::HEADER_KIND_PTX {
|
ptx_modules.push(Cow::<'a, _>::Borrowed(ptx));
|
||||||
let decompressed = unsafe { file.decompress() }.map_err(|_| CUerror::UNKNOWN)?;
|
}
|
||||||
return Ok(decompressed);
|
Ok(CodeModuleRef::<'a>::File(file)) => {
|
||||||
|
if file.header.kind != FatbinFileHeader::HEADER_KIND_PTX {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Ok(text) = file.get_or_decompress_content() {
|
||||||
|
if let Some(text) = cow_bytes_to_str(text) {
|
||||||
|
ptx_modules.push(text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
Err(CUerror::NO_BINARY_FOR_GPU)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
/// get_ptx takes an `image` that can be either a fatbin or a NULL-terminated ptx, and returns a String containing a ptx extracted from `image`.
|
|
||||||
fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
|
|
||||||
if image.is_null() {
|
|
||||||
return Err(CUerror::INVALID_VALUE);
|
|
||||||
}
|
|
||||||
|
|
||||||
let ptx = if unsafe { *(image as *const [u8; 4]) } == FatbincWrapper::MAGIC {
|
|
||||||
let ptx_bytes = get_ptx_from_wrapped_fatbin(image)?;
|
|
||||||
std::str::from_utf8(&ptx_bytes)
|
|
||||||
.map_err(|_| CUerror::UNKNOWN)?
|
|
||||||
.to_owned()
|
|
||||||
} else {
|
|
||||||
unsafe { CStr::from_ptr(image.cast()) }
|
|
||||||
.to_str()
|
|
||||||
.map_err(|_| CUerror::INVALID_VALUE)?
|
|
||||||
.to_owned()
|
|
||||||
};
|
};
|
||||||
|
// TODO: instead of getting first PTX module, try and get the best match
|
||||||
Ok(ptx)
|
ptx_modules
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or(CUerror::NO_BINARY_FOR_GPU)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
|
fn cow_bytes_to_str<'a>(data: Cow<'a, [u8]>) -> Option<Cow<'a, str>> {
|
||||||
|
match data {
|
||||||
|
Cow::Borrowed(bytes) => std::str::from_utf8(bytes).map(Cow::Borrowed).ok(),
|
||||||
|
Cow::Owned(bytes) => String::from_utf8(bytes).map(Cow::Owned).ok(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn load_hip_module(library: CodeLibraryRef) -> Result<hipModule_t, CUerror> {
|
||||||
let global_state = driver::global_state()?;
|
let global_state = driver::global_state()?;
|
||||||
let text = get_ptx(image)?;
|
let text = get_ptx(library)?;
|
||||||
let hip_properties = get_hip_properties()?;
|
let hip_properties = get_hip_properties()?;
|
||||||
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
||||||
let attributes = ptx::Attributes {
|
let attributes = ptx::Attributes {
|
||||||
|
@ -162,7 +156,9 @@ fn compile_from_ptx_and_cache(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
||||||
let hip_module = load_hip_module(image)?;
|
let library =
|
||||||
|
unsafe { CodeLibraryRef::try_load(image) }.map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||||
|
let hip_module = load_hip_module(library)?;
|
||||||
*module = Module { base: hip_module }.wrap();
|
*module = Module { base: hip_module }.wrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -189,6 +185,6 @@ pub(crate) fn get_global_v2(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_loading_mode(mode: &mut cuda_types::cuda::CUmoduleLoadingMode) -> CUresult {
|
pub(crate) fn get_loading_mode(mode: &mut cuda_types::cuda::CUmoduleLoadingMode) -> CUresult {
|
||||||
*mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_EAGER_LOADING;
|
*mode = cuda_types::cuda::CUmoduleLoadingMode::CU_MODULE_LAZY_LOADING;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ macro_rules! unimplemented {
|
||||||
#[allow(improper_ctypes)]
|
#[allow(improper_ctypes)]
|
||||||
#[allow(improper_ctypes_definitions)]
|
#[allow(improper_ctypes_definitions)]
|
||||||
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
|
||||||
crate::r#impl::unimplemented()
|
Err(r#impl::unimplemented())
|
||||||
}
|
}
|
||||||
)*
|
)*
|
||||||
};
|
};
|
||||||
|
|
|
@ -473,15 +473,16 @@ impl<'a> CodeModule<'a> {
|
||||||
// to load it directly from the pointer
|
// to load it directly from the pointer
|
||||||
// * The consumer is zluda_trace, it wants to compute the length of
|
// * The consumer is zluda_trace, it wants to compute the length of
|
||||||
// the ELF and save it to a file
|
// the ELF and save it to a file
|
||||||
pub enum CodeLibaryRef<'a> {
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum CodeLibraryRef<'a> {
|
||||||
FatbincWrapper(Fatbin<'a>),
|
FatbincWrapper(Fatbin<'a>),
|
||||||
FatbinHeader(FatbinSubmodule<'a>),
|
FatbinHeader(FatbinSubmodule<'a>),
|
||||||
Text(&'a str),
|
Text(&'a str),
|
||||||
Elf(*const c_void),
|
Elf(&'a c_void),
|
||||||
Archive(*const c_void),
|
Archive(&'a c_void),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> CodeLibaryRef<'a> {
|
impl<'a> CodeLibraryRef<'a> {
|
||||||
const ELFMAG: [u8; 4] = *b"\x7FELF";
|
const ELFMAG: [u8; 4] = *b"\x7FELF";
|
||||||
const AR_MAGIC: [u8; 8] = *b"!<arch>\x0A";
|
const AR_MAGIC: [u8; 8] = *b"!<arch>\x0A";
|
||||||
|
|
||||||
|
@ -493,20 +494,20 @@ impl<'a> CodeLibaryRef<'a> {
|
||||||
FatbinHeader::MAGIC => Self::FatbinHeader(FatbinSubmodule {
|
FatbinHeader::MAGIC => Self::FatbinHeader(FatbinSubmodule {
|
||||||
header: &*(ptr.cast()),
|
header: &*(ptr.cast()),
|
||||||
}),
|
}),
|
||||||
Self::ELFMAG => Self::Elf(ptr),
|
Self::ELFMAG => Self::Elf(&*ptr),
|
||||||
_ => match *ptr.cast::<[u8; 8]>() {
|
_ => match *ptr.cast::<[u8; 8]>() {
|
||||||
Self::AR_MAGIC => Self::Archive(ptr),
|
Self::AR_MAGIC => Self::Archive(&*ptr),
|
||||||
_ => CStr::from_ptr(ptr.cast()).to_str().map(Self::Text)?,
|
_ => CStr::from_ptr(ptr.cast()).to_str().map(Self::Text)?,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn iterate_modules(
|
pub unsafe fn iterate_modules(
|
||||||
&self,
|
self,
|
||||||
mut fn_: impl FnMut(Option<(usize, Option<usize>)>, Result<CodeModule, FatbinError>),
|
mut fn_: impl FnMut(Option<(usize, Option<usize>)>, Result<CodeModuleRef<'a>, FatbinError>),
|
||||||
) {
|
) {
|
||||||
match self {
|
match self {
|
||||||
CodeLibaryRef::FatbincWrapper(fatbin) => {
|
CodeLibraryRef::FatbincWrapper(fatbin) => {
|
||||||
let module_iter = fatbin.get_submodules();
|
let module_iter = fatbin.get_submodules();
|
||||||
match module_iter {
|
match module_iter {
|
||||||
Ok(mut iter) => {
|
Ok(mut iter) => {
|
||||||
|
@ -525,7 +526,7 @@ impl<'a> CodeLibaryRef<'a> {
|
||||||
};
|
};
|
||||||
fn_(Some(index), module)
|
fn_(Some(index), module)
|
||||||
},
|
},
|
||||||
&submodule,
|
submodule,
|
||||||
),
|
),
|
||||||
Err(err) => fn_(
|
Err(err) => fn_(
|
||||||
module_index.map(|module_index| (module_index, None)),
|
module_index.map(|module_index| (module_index, None)),
|
||||||
|
@ -538,34 +539,35 @@ impl<'a> CodeLibaryRef<'a> {
|
||||||
Err(err) => fn_(None, Err(err)),
|
Err(err) => fn_(None, Err(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CodeLibaryRef::FatbinHeader(submodule) => iterate_modules_fatbin_header(
|
CodeLibraryRef::FatbinHeader(submodule) => iterate_modules_fatbin_header(
|
||||||
|index, module| fn_(Some((index, None)), module),
|
|index, module| fn_(Some((index, None)), module),
|
||||||
submodule,
|
submodule,
|
||||||
),
|
),
|
||||||
CodeLibaryRef::Text(text) => fn_(None, Ok(CodeModule::Text(*text))),
|
CodeLibraryRef::Text(text) => fn_(None, Ok(CodeModuleRef::Text(text))),
|
||||||
CodeLibaryRef::Elf(elf) => fn_(None, Ok(CodeModule::Elf(*elf))),
|
CodeLibraryRef::Elf(elf) => fn_(None, Ok(CodeModuleRef::Elf(elf))),
|
||||||
CodeLibaryRef::Archive(ar) => fn_(None, Ok(CodeModule::Archive(*ar))),
|
CodeLibraryRef::Archive(ar) => fn_(None, Ok(CodeModuleRef::Archive(ar))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn iterate_modules_fatbin_header(
|
unsafe fn iterate_modules_fatbin_header<'x>(
|
||||||
mut fn_: impl FnMut(usize, Result<CodeModule, FatbinError>),
|
mut fn_: impl FnMut(usize, Result<CodeModuleRef<'x>, FatbinError>),
|
||||||
submodule: &FatbinSubmodule<'_>,
|
submodule: FatbinSubmodule<'x>,
|
||||||
) {
|
) {
|
||||||
let mut iter = submodule.get_files();
|
let mut iter = submodule.get_files();
|
||||||
let mut index = 0;
|
let mut index = 0;
|
||||||
while let Some(file) = iter.next() {
|
while let Some(file) = iter.next() {
|
||||||
fn_(
|
fn_(
|
||||||
index,
|
index,
|
||||||
file.map(CodeModule::File)
|
file.map(CodeModuleRef::File)
|
||||||
.map_err(FatbinError::ParseFailure),
|
.map_err(FatbinError::ParseFailure),
|
||||||
);
|
);
|
||||||
index += 1;
|
index += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum CodeModule<'a> {
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum CodeModuleRef<'a> {
|
||||||
File(FatbinFile<'a>),
|
File(FatbinFile<'a>),
|
||||||
Text(&'a str),
|
Text(&'a str),
|
||||||
Elf(*const c_void),
|
Elf(*const c_void),
|
||||||
|
|
|
@ -82,14 +82,14 @@ impl StateTracker {
|
||||||
self.saved_modules.insert(cu_module);
|
self.saved_modules.insert(cu_module);
|
||||||
self.library_counter += 1;
|
self.library_counter += 1;
|
||||||
let code_ref = fn_logger.try_return(|| {
|
let code_ref = fn_logger.try_return(|| {
|
||||||
unsafe { zluda_common::CodeLibaryRef::try_load(raw_image) }
|
unsafe { zluda_common::CodeLibraryRef::try_load(raw_image) }
|
||||||
.map_err(ErrorEntry::NonUtf8ModuleText)
|
.map_err(ErrorEntry::NonUtf8ModuleText)
|
||||||
});
|
});
|
||||||
let code_ref = unwrap_some_or!(code_ref, return);
|
let code_ref = unwrap_some_or!(code_ref, return);
|
||||||
unsafe {
|
unsafe {
|
||||||
code_ref.iterate_modules(|index, module| match module {
|
code_ref.iterate_modules(|index, module| match module {
|
||||||
Err(err) => fn_logger.log(ErrorEntry::from(err)),
|
Err(err) => fn_logger.log(ErrorEntry::from(err)),
|
||||||
Ok(zluda_common::CodeModule::Elf(elf)) => match get_elf_size(elf) {
|
Ok(zluda_common::CodeModuleRef::Elf(elf)) => match get_elf_size(elf) {
|
||||||
Some(len) => {
|
Some(len) => {
|
||||||
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
|
let elf_image = std::slice::from_raw_parts(elf.cast::<u8>(), len);
|
||||||
self.record_new_submodule(index, elf_image, fn_logger, "elf");
|
self.record_new_submodule(index, elf_image, fn_logger, "elf");
|
||||||
|
@ -100,21 +100,21 @@ impl StateTracker {
|
||||||
kind: "ELF",
|
kind: "ELF",
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
Ok(zluda_common::CodeModule::Archive(archive)) => {
|
Ok(zluda_common::CodeModuleRef::Archive(archive)) => {
|
||||||
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
fn_logger.log(log::ErrorEntry::UnsupportedModule {
|
||||||
module: cu_module,
|
module: cu_module,
|
||||||
raw_image: archive,
|
raw_image: archive,
|
||||||
kind: "archive",
|
kind: "archive",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Ok(zluda_common::CodeModule::File(file)) => {
|
Ok(zluda_common::CodeModuleRef::File(file)) => {
|
||||||
if let Some(buffer) = fn_logger
|
if let Some(buffer) = fn_logger
|
||||||
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
|
.try_(|_| file.get_or_decompress_content().map_err(ErrorEntry::from))
|
||||||
{
|
{
|
||||||
self.record_new_submodule(index, &*buffer, fn_logger, file.kind());
|
self.record_new_submodule(index, &*buffer, fn_logger, file.kind());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(zluda_common::CodeModule::Text(ptx)) => {
|
Ok(zluda_common::CodeModuleRef::Text(ptx)) => {
|
||||||
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx");
|
self.record_new_submodule(index, ptx.as_bytes(), fn_logger, "ptx");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue