mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-07 18:19:07 +00:00
Implement kernel cache (#465)
This commit is contained in:
parent
d2f92e4267
commit
28eca3d75a
20 changed files with 2407 additions and 47 deletions
1931
Cargo.lock
generated
1931
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -21,6 +21,7 @@ members = [
|
||||||
"zluda_bindgen",
|
"zluda_bindgen",
|
||||||
"zluda_blas",
|
"zluda_blas",
|
||||||
"zluda_blaslt",
|
"zluda_blaslt",
|
||||||
|
"zluda_cache",
|
||||||
"zluda_common",
|
"zluda_common",
|
||||||
"zluda_dnn",
|
"zluda_dnn",
|
||||||
"zluda_trace",
|
"zluda_trace",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use amd_comgr_sys::*;
|
use amd_comgr_sys::*;
|
||||||
use std::{ffi::CStr, mem, ptr};
|
use std::{ffi::CStr, iter, mem, ptr};
|
||||||
|
|
||||||
macro_rules! call_dispatch_arg {
|
macro_rules! call_dispatch_arg {
|
||||||
(2, $arg:ident) => {
|
(2, $arg:ident) => {
|
||||||
|
@ -105,9 +105,10 @@ comgr_owned!(
|
||||||
);
|
);
|
||||||
|
|
||||||
impl<'a> ActionInfo<'a> {
|
impl<'a> ActionInfo<'a> {
|
||||||
fn set_isa_name(&self, isa: &CStr) -> Result<(), Error> {
|
fn set_isa_name(&self, isa: &str) -> Result<(), Error> {
|
||||||
let mut full_isa = "amdgcn-amd-amdhsa--".to_string().into_bytes();
|
let mut full_isa = "amdgcn-amd-amdhsa--".to_string().into_bytes();
|
||||||
full_isa.extend(isa.to_bytes_with_nul());
|
full_isa.extend(isa.as_bytes());
|
||||||
|
full_isa.push(0);
|
||||||
call_dispatch!(self.comgr => amd_comgr_action_info_set_isa_name(self, { full_isa.as_ptr().cast() }));
|
call_dispatch!(self.comgr => amd_comgr_action_info_set_isa_name(self, { full_isa.as_ptr().cast() }));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ impl Data {
|
||||||
|
|
||||||
pub fn compile_bitcode(
|
pub fn compile_bitcode(
|
||||||
comgr: &Comgr,
|
comgr: &Comgr,
|
||||||
gcn_arch: &CStr,
|
gcn_arch: &str,
|
||||||
main_buffer: &[u8],
|
main_buffer: &[u8],
|
||||||
attributes_buffer: &[u8],
|
attributes_buffer: &[u8],
|
||||||
ptx_impl: &[u8],
|
ptx_impl: &[u8],
|
||||||
|
@ -233,6 +234,48 @@ pub fn compile_bitcode(
|
||||||
executable.copy_content(comgr)
|
executable.copy_content(comgr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_clang_version(comgr: &Comgr) -> Result<String, Error> {
|
||||||
|
let version_string_set = DataSet::new(comgr)?;
|
||||||
|
let version_string = Data::new(
|
||||||
|
comgr,
|
||||||
|
DataKind::Source,
|
||||||
|
c"version.cpp",
|
||||||
|
b"__clang_version__",
|
||||||
|
)?;
|
||||||
|
version_string_set.add(&version_string)?;
|
||||||
|
let preprocessor_info = ActionInfo::new(comgr)?;
|
||||||
|
preprocessor_info.set_language(Language::Hip)?;
|
||||||
|
preprocessor_info.set_options(iter::once(c"-P"))?;
|
||||||
|
let preprocessed = comgr.do_action(
|
||||||
|
ActionKind::SourceToPreprocessor,
|
||||||
|
&preprocessor_info,
|
||||||
|
&version_string_set,
|
||||||
|
)?;
|
||||||
|
let data = preprocessed.get_data(DataKind::Source, 0)?;
|
||||||
|
String::from_utf8(trim_whitespace_and_quotes(data.copy_content(comgr)?)?)
|
||||||
|
.map_err(|_| Error::UNKNOWN)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When running the preprocessor to expand the macro the output is surrounded by
|
||||||
|
// quotes (because it is a string literal) and has a trailing newline.
|
||||||
|
// This function is not strictly necessary, but it makes the output cleaner
|
||||||
|
fn trim_whitespace_and_quotes(data: Vec<u8>) -> Result<Vec<u8>, Error> {
|
||||||
|
fn is_not_whitespace_or_quote(b: u8) -> bool {
|
||||||
|
!(b.is_ascii_whitespace() || b == b'"')
|
||||||
|
}
|
||||||
|
let prefix_length = data
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.position(is_not_whitespace_or_quote)
|
||||||
|
.ok_or(Error::UNKNOWN)?;
|
||||||
|
let last_letter = data
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.rposition(is_not_whitespace_or_quote)
|
||||||
|
.ok_or(Error::UNKNOWN)?;
|
||||||
|
Ok(data[prefix_length..=last_letter].to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
pub enum Comgr {
|
pub enum Comgr {
|
||||||
V2(amd_comgr_sys::comgr2::Comgr2),
|
V2(amd_comgr_sys::comgr2::Comgr2),
|
||||||
V3(amd_comgr_sys::comgr3::Comgr3),
|
V3(amd_comgr_sys::comgr3::Comgr3),
|
||||||
|
@ -356,7 +399,8 @@ impl_into!(
|
||||||
amd_comgr_action_kind_t,
|
amd_comgr_action_kind_t,
|
||||||
[
|
[
|
||||||
LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC,
|
LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC,
|
||||||
CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE
|
CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE,
|
||||||
|
SourceToPreprocessor => AMD_COMGR_ACTION_SOURCE_TO_PREPROCESSOR
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ petgraph = "0.7.1"
|
||||||
microlp = "0.2.11"
|
microlp = "0.2.11"
|
||||||
int-enum = "1.1"
|
int-enum = "1.1"
|
||||||
unwrap_or = "1.0.1"
|
unwrap_or = "1.0.1"
|
||||||
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||||
|
|
|
@ -47,6 +47,7 @@ quick_error! {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GPU attributes needed at compile time.
|
/// GPU attributes needed at compile time.
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
pub struct Attributes {
|
pub struct Attributes {
|
||||||
/// Clock frequency in kHz.
|
/// Clock frequency in kHz.
|
||||||
pub clock_rate: u32,
|
pub clock_rate: u32,
|
||||||
|
|
|
@ -643,7 +643,9 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
|
||||||
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
|
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
|
||||||
let elf_module = comgr::compile_bitcode(
|
let elf_module = comgr::compile_bitcode(
|
||||||
&comgr,
|
&comgr,
|
||||||
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
|
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
&*module.llvm_ir.write_bitcode_to_memory(),
|
&*module.llvm_ir.write_bitcode_to_memory(),
|
||||||
&*module.attributes_ir.write_bitcode_to_memory(),
|
&*module.attributes_ir.write_bitcode_to_memory(),
|
||||||
module.linked_bitcode(),
|
module.linked_bitcode(),
|
||||||
|
|
|
@ -7,8 +7,8 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bpaf = { version = "0.9.15", features = ["derive"] }
|
bpaf = { version = "0.9.15", features = ["derive"] }
|
||||||
cargo_metadata = "0.19.1"
|
cargo_metadata = "0.19.1"
|
||||||
serde = "1.0.217"
|
serde = "1.0.219"
|
||||||
serde_json = "1.0.137"
|
serde_json = "1.0.142"
|
||||||
|
|
||||||
[target.'cfg(unix)'.dependencies]
|
[target.'cfg(unix)'.dependencies]
|
||||||
flate2 = { version = "1.1.1", features = ["zlib-rs"], default-features = false }
|
flate2 = { version = "1.1.1", features = ["zlib-rs"], default-features = false }
|
||||||
|
|
|
@ -16,6 +16,7 @@ cuda_types = { path = "../cuda_types" }
|
||||||
cuda_macros = { path = "../cuda_macros" }
|
cuda_macros = { path = "../cuda_macros" }
|
||||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||||
dark_api = { path = "../dark_api" }
|
dark_api = { path = "../dark_api" }
|
||||||
|
zluda_cache = { path = "../zluda_cache" }
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4"
|
||||||
num_enum = "0.4"
|
num_enum = "0.4"
|
||||||
lz4-sys = "1.9"
|
lz4-sys = "1.9"
|
||||||
|
@ -23,6 +24,9 @@ tempfile = "3"
|
||||||
paste = "1.0"
|
paste = "1.0"
|
||||||
rustc-hash = "1.1"
|
rustc-hash = "1.1"
|
||||||
zluda_common = { path = "../zluda_common" }
|
zluda_common = { path = "../zluda_common" }
|
||||||
|
blake3 = "1.8.2"
|
||||||
|
serde = "1.0.219"
|
||||||
|
serde_json = "1.0.142"
|
||||||
|
|
||||||
[target.'cfg(windows)'.dependencies]
|
[target.'cfg(windows)'.dependencies]
|
||||||
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||||
|
@ -31,6 +35,9 @@ winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||||
libc = "0.2"
|
libc = "0.2"
|
||||||
dtor = "0.0.7"
|
dtor = "0.0.7"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
vergen-gix = "1.0.9"
|
||||||
|
|
||||||
[package.metadata.zluda]
|
[package.metadata.zluda]
|
||||||
linux_symlinks = [
|
linux_symlinks = [
|
||||||
"libcuda.so",
|
"libcuda.so",
|
||||||
|
|
10
zluda/build.rs
Normal file
10
zluda/build.rs
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
use vergen_gix::{Emitter, GixBuilder};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let git = GixBuilder::all_git().unwrap();
|
||||||
|
Emitter::default()
|
||||||
|
.add_instructions(&git)
|
||||||
|
.unwrap()
|
||||||
|
.emit()
|
||||||
|
.unwrap();
|
||||||
|
}
|
|
@ -97,6 +97,15 @@ impl ZludaObject for Context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_current_device() -> Result<hipDevice_t, CUerror> {
|
||||||
|
STACK.with(|stack| {
|
||||||
|
stack
|
||||||
|
.try_borrow()
|
||||||
|
.map_err(|_| CUerror::UNKNOWN)
|
||||||
|
.and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn get_current_context() -> Result<CUcontext, CUerror> {
|
pub(crate) fn get_current_context() -> Result<CUcontext, CUerror> {
|
||||||
if let Some(ctx) = STACK.with(|stack| stack.borrow().last().copied().map(|(ctx, _)| ctx)) {
|
if let Some(ctx) = STACK.with(|stack| stack.borrow().last().copied().map(|(ctx, _)| ctx)) {
|
||||||
return Ok(ctx);
|
return Ok(ctx);
|
||||||
|
|
|
@ -17,6 +17,8 @@ mod os;
|
||||||
pub(crate) struct GlobalState {
|
pub(crate) struct GlobalState {
|
||||||
pub devices: Vec<Device>,
|
pub devices: Vec<Device>,
|
||||||
pub comgr: Comgr,
|
pub comgr: Comgr,
|
||||||
|
pub comgr_clang_version: String,
|
||||||
|
pub cache_path: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Device {
|
pub(crate) struct Device {
|
||||||
|
@ -52,8 +54,11 @@ pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||||
let mut device_count = 0;
|
let mut device_count = 0;
|
||||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||||
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
||||||
|
let comgr_clang_version =
|
||||||
|
comgr::get_clang_version(&comgr).map_err(|_| CUerror::UNKNOWN)?;
|
||||||
Ok(GlobalState {
|
Ok(GlobalState {
|
||||||
comgr,
|
comgr,
|
||||||
|
comgr_clang_version,
|
||||||
devices: (0..device_count)
|
devices: (0..device_count)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let mut props = unsafe { mem::zeroed() };
|
let mut props = unsafe { mem::zeroed() };
|
||||||
|
@ -68,6 +73,7 @@ pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?,
|
.collect::<Result<Vec<_>, _>>()?,
|
||||||
|
cache_path: zluda_cache::ModuleCache::create_cache_dir_and_get_path(),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
|
@ -65,26 +65,95 @@ fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
|
||||||
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
|
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
|
||||||
let global_state = driver::global_state()?;
|
let global_state = driver::global_state()?;
|
||||||
let text = get_ptx(image)?;
|
let text = get_ptx(image)?;
|
||||||
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
let hip_properties = get_hip_properties()?;
|
||||||
let mut dev = 0;
|
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
||||||
unsafe { hipCtxGetDevice(&mut dev) }?;
|
|
||||||
let mut props = unsafe { mem::zeroed() };
|
|
||||||
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
|
|
||||||
let attributes = ptx::Attributes {
|
let attributes = ptx::Attributes {
|
||||||
clock_rate: props.clockRate as u32,
|
clock_rate: hip_properties.clockRate as u32,
|
||||||
};
|
};
|
||||||
|
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
|
||||||
|
let cache = zluda_cache::ModuleCache::open(p)?;
|
||||||
|
let key = get_cache_key(global_state, gcn_arch, &text, &attributes)?;
|
||||||
|
Some((cache, key))
|
||||||
|
});
|
||||||
|
let cached_binary = load_cached_binary(&mut cache_with_key);
|
||||||
|
let elf_module = cached_binary.ok_or(CUerror::UNKNOWN).or_else(|_| {
|
||||||
|
compile_from_ptx_and_cache(
|
||||||
|
&global_state.comgr,
|
||||||
|
gcn_arch,
|
||||||
|
attributes,
|
||||||
|
&text,
|
||||||
|
&mut cache_with_key,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let mut hip_module = unsafe { mem::zeroed() };
|
||||||
|
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
|
||||||
|
Ok(hip_module)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
|
||||||
|
let hip_dev = super::context::get_current_device()?;
|
||||||
|
let mut props = unsafe { mem::zeroed() };
|
||||||
|
unsafe { hipGetDevicePropertiesR0600(&mut props, hip_dev) }?;
|
||||||
|
Ok(props)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_gcn_arch<'a>(props: &'a hipDeviceProp_tR0600) -> Result<&'a str, CUerror> {
|
||||||
|
let gcn_arch = unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) };
|
||||||
|
gcn_arch.to_str().map_err(|_| CUerror::UNKNOWN)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_cache_key<'a, 'b>(
|
||||||
|
global_state: &'static driver::GlobalState,
|
||||||
|
isa: &'a str,
|
||||||
|
text: &str,
|
||||||
|
attributes: &ptx::Attributes,
|
||||||
|
) -> Option<zluda_cache::ModuleKey<'a>> {
|
||||||
|
// Serialization here is deterministic. When marking a type with
|
||||||
|
// #[derive(serde::Serialize)] the derived implementation will just write
|
||||||
|
// fields in the order of their declaration. It's not explicitly guaranteed
|
||||||
|
// by serde, but it is the only sensible thing to do, so I feel safe
|
||||||
|
// to rely on it
|
||||||
|
let serialized_attributes = serde_json::to_string(attributes).ok()?;
|
||||||
|
Some(zluda_cache::ModuleKey {
|
||||||
|
hash: blake3::hash(text.as_bytes()).to_hex(),
|
||||||
|
compiler_version: &*global_state.comgr_clang_version,
|
||||||
|
zluda_version: env!("VERGEN_GIT_SHA"),
|
||||||
|
device: isa,
|
||||||
|
backend_key: serialized_attributes,
|
||||||
|
last_access: zluda_cache::ModuleCache::time_now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_cached_binary(
|
||||||
|
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||||
|
) -> Option<Vec<u8>> {
|
||||||
|
cache_with_key
|
||||||
|
.as_mut()
|
||||||
|
.and_then(|(c, key)| c.get_module_binary(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_from_ptx_and_cache(
|
||||||
|
comgr: &comgr::Comgr,
|
||||||
|
gcn_arch: &str,
|
||||||
|
attributes: ptx::Attributes,
|
||||||
|
text: &str,
|
||||||
|
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||||
|
) -> Result<Vec<u8>, CUerror> {
|
||||||
|
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||||
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
|
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
|
||||||
let elf_module = comgr::compile_bitcode(
|
let elf_module = comgr::compile_bitcode(
|
||||||
&global_state.comgr,
|
comgr,
|
||||||
unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) },
|
gcn_arch,
|
||||||
&*llvm_module.llvm_ir.write_bitcode_to_memory(),
|
&*llvm_module.llvm_ir.write_bitcode_to_memory(),
|
||||||
&*llvm_module.attributes_ir.write_bitcode_to_memory(),
|
&*llvm_module.attributes_ir.write_bitcode_to_memory(),
|
||||||
llvm_module.linked_bitcode(),
|
llvm_module.linked_bitcode(),
|
||||||
)
|
)
|
||||||
.map_err(|_| CUerror::UNKNOWN)?;
|
.map_err(|_| CUerror::UNKNOWN)?;
|
||||||
let mut hip_module = unsafe { mem::zeroed() };
|
if let Some((cache, key)) = cache_with_key {
|
||||||
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
|
key.last_access = zluda_cache::ModuleCache::time_now();
|
||||||
Ok(hip_module)
|
cache.insert_module(key, &elf_module);
|
||||||
|
}
|
||||||
|
Ok(elf_module)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
||||||
|
|
14
zluda_cache/Cargo.toml
Normal file
14
zluda_cache/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "zluda_cache"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
diesel = { version = "2.2.12", features = [
|
||||||
|
"sqlite",
|
||||||
|
"returning_clauses_for_sqlite_3_35",
|
||||||
|
] }
|
||||||
|
diesel_migrations = "2.2.0"
|
||||||
|
libsqlite3-sys = { version = "0.35", features = ["bundled"] }
|
||||||
|
dirs = "6.0.0"
|
||||||
|
arrayvec = "0.7.6"
|
10
zluda_cache/diesel.toml
Normal file
10
zluda_cache/diesel.toml
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# For documentation on how to configure this file,
|
||||||
|
# see https://diesel.rs/guides/configuring-diesel-cli
|
||||||
|
|
||||||
|
[print_schema]
|
||||||
|
file = "src/schema.rs"
|
||||||
|
custom_type_derives = ["diesel::query_builder::QueryId", "Clone"]
|
||||||
|
sqlite_integer_primary_key_is_bigint = true
|
||||||
|
|
||||||
|
[migrations_directory]
|
||||||
|
dir = "migrations"
|
0
zluda_cache/migrations/.keep
Normal file
0
zluda_cache/migrations/.keep
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
DROP TABLE modules;
|
||||||
|
DROP TABLE globals;
|
|
@ -0,0 +1,41 @@
|
||||||
|
CREATE TABLE modules (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
hash TEXT NOT NULL,
|
||||||
|
compiler_version TEXT NOT NULL,
|
||||||
|
zluda_version TEXT NOT NULL,
|
||||||
|
device TEXT NOT NULL,
|
||||||
|
backend_key TEXT NOT NULL,
|
||||||
|
binary BLOB NOT NULL,
|
||||||
|
last_access BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS modules_index ON modules (hash, compiler_version, zluda_version, device, backend_key);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS globals (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value BIGINT NOT NULL
|
||||||
|
) WITHOUT ROWID;
|
||||||
|
|
||||||
|
INSERT OR IGNORE INTO globals (key, value) VALUES ('total_size', 0);
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS update_size_on_delete
|
||||||
|
AFTER
|
||||||
|
DELETE ON modules FOR EACH ROW BEGIN
|
||||||
|
UPDATE
|
||||||
|
globals
|
||||||
|
SET
|
||||||
|
value = value - length(OLD.binary)
|
||||||
|
WHERE
|
||||||
|
key = 'total_size';
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS update_size_on_insert
|
||||||
|
AFTER
|
||||||
|
INSERT ON modules FOR EACH ROW BEGIN
|
||||||
|
UPDATE
|
||||||
|
globals
|
||||||
|
SET
|
||||||
|
value = value + length(NEW.binary)
|
||||||
|
WHERE
|
||||||
|
key = 'total_size';
|
||||||
|
END;
|
231
zluda_cache/src/lib.rs
Normal file
231
zluda_cache/src/lib.rs
Normal file
|
@ -0,0 +1,231 @@
|
||||||
|
use crate::schema::modules;
|
||||||
|
use arrayvec::ArrayString;
|
||||||
|
use diesel::{connection::SimpleConnection, prelude::*};
|
||||||
|
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
pub(crate) mod models;
|
||||||
|
pub(crate) mod schema;
|
||||||
|
|
||||||
|
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
|
||||||
|
|
||||||
|
pub struct ModuleKey<'a> {
|
||||||
|
pub hash: ArrayString<64>,
|
||||||
|
pub compiler_version: &'static str,
|
||||||
|
pub zluda_version: &'static str,
|
||||||
|
pub device: &'a str,
|
||||||
|
pub backend_key: String,
|
||||||
|
pub last_access: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModuleCache(SqliteConnection);
|
||||||
|
|
||||||
|
impl ModuleCache {
|
||||||
|
pub fn create_cache_dir_and_get_path() -> Option<String> {
|
||||||
|
let mut cache_dir = dirs::cache_dir()?;
|
||||||
|
cache_dir.extend(["zluda", "ComputeCache"]);
|
||||||
|
// We ensure that the cache directory exists
|
||||||
|
std::fs::create_dir_all(&cache_dir).ok()?;
|
||||||
|
// No need to create the file, it will be created by SQLite on first access
|
||||||
|
cache_dir.push("zluda.db");
|
||||||
|
Some(cache_dir.to_string_lossy().into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn open(file_path: &str) -> Option<Self> {
|
||||||
|
let mut conn = SqliteConnection::establish(file_path).ok()?;
|
||||||
|
conn.batch_execute("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")
|
||||||
|
.ok()?;
|
||||||
|
conn.run_pending_migrations(MIGRATIONS).ok()?;
|
||||||
|
Some(Self(conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_module_binary(&mut self, key: &ModuleKey) -> Option<Vec<u8>> {
|
||||||
|
diesel::update(modules::dsl::modules)
|
||||||
|
.set(modules::last_access.eq(key.last_access))
|
||||||
|
.filter(modules::hash.eq(key.hash.as_str()))
|
||||||
|
.filter(modules::compiler_version.eq(&key.compiler_version))
|
||||||
|
.filter(modules::zluda_version.eq(key.zluda_version))
|
||||||
|
.filter(modules::device.eq(key.device))
|
||||||
|
.filter(modules::backend_key.eq(&key.backend_key))
|
||||||
|
.returning(modules::binary)
|
||||||
|
.get_result(&mut self.0)
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert_module(&mut self, key: &ModuleKey, binary: &[u8]) {
|
||||||
|
diesel::insert_into(modules::dsl::modules)
|
||||||
|
.values(models::AddModule {
|
||||||
|
hash: key.hash.as_str(),
|
||||||
|
compiler_version: &key.compiler_version,
|
||||||
|
zluda_version: key.zluda_version,
|
||||||
|
device: key.device,
|
||||||
|
backend_key: &key.backend_key,
|
||||||
|
last_access: key.last_access,
|
||||||
|
binary,
|
||||||
|
})
|
||||||
|
.execute(&mut self.0)
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn time_now() -> i64 {
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or(Duration::ZERO)
|
||||||
|
.as_millis() as i64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::{
|
||||||
|
schema::{globals::dsl::*, modules::dsl::*},
|
||||||
|
ModuleCache,
|
||||||
|
};
|
||||||
|
use arrayvec::ArrayString;
|
||||||
|
use diesel::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Queryable, Selectable)]
|
||||||
|
#[diesel(table_name = crate::schema::modules)]
|
||||||
|
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||||
|
pub struct Module {
|
||||||
|
pub id: i64,
|
||||||
|
pub hash: String,
|
||||||
|
pub binary: Vec<u8>,
|
||||||
|
pub last_access: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Queryable, Selectable)]
|
||||||
|
#[diesel(table_name = crate::schema::globals)]
|
||||||
|
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||||
|
pub struct Global {
|
||||||
|
pub key: String,
|
||||||
|
pub value: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_db_returns_no_module() {
|
||||||
|
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||||
|
let module_binary = db.get_module_binary(&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 123,
|
||||||
|
});
|
||||||
|
assert!(module_binary.is_none());
|
||||||
|
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_modules.len(), 0);
|
||||||
|
let all_globals: Vec<Global> = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_globals[0].key, "total_size");
|
||||||
|
assert_eq!(all_globals[0].value, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn newly_inserted_module_increments_total_size() {
|
||||||
|
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||||
|
db.insert_module(
|
||||||
|
&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash1").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 123,
|
||||||
|
},
|
||||||
|
&[1, 2, 3, 4, 5],
|
||||||
|
);
|
||||||
|
db.insert_module(
|
||||||
|
&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash2").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 124,
|
||||||
|
},
|
||||||
|
&[1, 2, 3],
|
||||||
|
);
|
||||||
|
let mut all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||||
|
all_modules.sort_by_key(|m: &Module| m.id);
|
||||||
|
assert_eq!(all_modules.len(), 2);
|
||||||
|
assert_eq!(all_modules[0].hash, "test_hash1");
|
||||||
|
assert_eq!(all_modules[0].last_access, 123);
|
||||||
|
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||||
|
assert_eq!(all_modules[1].hash, "test_hash2");
|
||||||
|
assert_eq!(all_modules[1].last_access, 124);
|
||||||
|
assert_eq!(all_modules[1].binary, &[1, 2, 3]);
|
||||||
|
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_globals[0].key, "total_size");
|
||||||
|
assert_eq!(all_globals[0].value, 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_bumps_last_access() {
|
||||||
|
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||||
|
db.insert_module(
|
||||||
|
&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 123,
|
||||||
|
},
|
||||||
|
&[1, 2, 3, 4, 5],
|
||||||
|
);
|
||||||
|
let module_binary = db
|
||||||
|
.get_module_binary(&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 124,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_modules.len(), 1);
|
||||||
|
assert_eq!(all_modules[0].last_access, 124);
|
||||||
|
assert_eq!(module_binary, &[1, 2, 3, 4, 5]);
|
||||||
|
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||||
|
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_globals[0].key, "total_size");
|
||||||
|
assert_eq!(all_globals[0].value, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn double_insert_does_not_override() {
|
||||||
|
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||||
|
db.insert_module(
|
||||||
|
&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 123,
|
||||||
|
},
|
||||||
|
&[1, 2, 3, 4, 5],
|
||||||
|
);
|
||||||
|
db.insert_module(
|
||||||
|
&super::ModuleKey {
|
||||||
|
hash: ArrayString::from("test_hash").unwrap(),
|
||||||
|
compiler_version: "1.0.0",
|
||||||
|
zluda_version: "1.0.0",
|
||||||
|
device: "test_device",
|
||||||
|
backend_key: "{}".to_string(),
|
||||||
|
last_access: 124,
|
||||||
|
},
|
||||||
|
&[5, 4, 3, 2, 1],
|
||||||
|
);
|
||||||
|
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_modules.len(), 1);
|
||||||
|
assert_eq!(all_modules[0].last_access, 123);
|
||||||
|
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||||
|
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||||
|
assert_eq!(all_globals[0].key, "total_size");
|
||||||
|
assert_eq!(all_globals[0].value, 5);
|
||||||
|
}
|
||||||
|
}
|
14
zluda_cache/src/models.rs
Normal file
14
zluda_cache/src/models.rs
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
use crate::schema::modules;
|
||||||
|
use diesel::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Insertable)]
|
||||||
|
#[diesel(table_name = modules)]
|
||||||
|
pub(crate) struct AddModule<'a> {
|
||||||
|
pub hash: &'a str,
|
||||||
|
pub compiler_version: &'a str,
|
||||||
|
pub zluda_version: &'a str,
|
||||||
|
pub device: &'a str,
|
||||||
|
pub backend_key: &'a str,
|
||||||
|
pub binary: &'a [u8],
|
||||||
|
pub last_access: i64,
|
||||||
|
}
|
23
zluda_cache/src/schema.rs
Normal file
23
zluda_cache/src/schema.rs
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
// @generated automatically by Diesel CLI.
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
globals (key) {
|
||||||
|
key -> Text,
|
||||||
|
value -> BigInt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
modules (id) {
|
||||||
|
id -> BigInt,
|
||||||
|
hash -> Text,
|
||||||
|
compiler_version -> Text,
|
||||||
|
zluda_version -> Text,
|
||||||
|
device -> Text,
|
||||||
|
backend_key -> Text,
|
||||||
|
binary -> Binary,
|
||||||
|
last_access -> BigInt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
diesel::allow_tables_to_appear_in_same_query!(globals, modules,);
|
Loading…
Add table
Add a link
Reference in a new issue