mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-03 16:17:11 +00:00
Implement kernel cache (#465)
This commit is contained in:
parent
d2f92e4267
commit
28eca3d75a
20 changed files with 2407 additions and 47 deletions
1931
Cargo.lock
generated
1931
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -21,6 +21,7 @@ members = [
|
|||
"zluda_bindgen",
|
||||
"zluda_blas",
|
||||
"zluda_blaslt",
|
||||
"zluda_cache",
|
||||
"zluda_common",
|
||||
"zluda_dnn",
|
||||
"zluda_trace",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use amd_comgr_sys::*;
|
||||
use std::{ffi::CStr, mem, ptr};
|
||||
use std::{ffi::CStr, iter, mem, ptr};
|
||||
|
||||
macro_rules! call_dispatch_arg {
|
||||
(2, $arg:ident) => {
|
||||
|
@ -105,9 +105,10 @@ comgr_owned!(
|
|||
);
|
||||
|
||||
impl<'a> ActionInfo<'a> {
|
||||
fn set_isa_name(&self, isa: &CStr) -> Result<(), Error> {
|
||||
fn set_isa_name(&self, isa: &str) -> Result<(), Error> {
|
||||
let mut full_isa = "amdgcn-amd-amdhsa--".to_string().into_bytes();
|
||||
full_isa.extend(isa.to_bytes_with_nul());
|
||||
full_isa.extend(isa.as_bytes());
|
||||
full_isa.push(0);
|
||||
call_dispatch!(self.comgr => amd_comgr_action_info_set_isa_name(self, { full_isa.as_ptr().cast() }));
|
||||
Ok(())
|
||||
}
|
||||
|
@ -176,7 +177,7 @@ impl Data {
|
|||
|
||||
pub fn compile_bitcode(
|
||||
comgr: &Comgr,
|
||||
gcn_arch: &CStr,
|
||||
gcn_arch: &str,
|
||||
main_buffer: &[u8],
|
||||
attributes_buffer: &[u8],
|
||||
ptx_impl: &[u8],
|
||||
|
@ -233,6 +234,48 @@ pub fn compile_bitcode(
|
|||
executable.copy_content(comgr)
|
||||
}
|
||||
|
||||
pub fn get_clang_version(comgr: &Comgr) -> Result<String, Error> {
|
||||
let version_string_set = DataSet::new(comgr)?;
|
||||
let version_string = Data::new(
|
||||
comgr,
|
||||
DataKind::Source,
|
||||
c"version.cpp",
|
||||
b"__clang_version__",
|
||||
)?;
|
||||
version_string_set.add(&version_string)?;
|
||||
let preprocessor_info = ActionInfo::new(comgr)?;
|
||||
preprocessor_info.set_language(Language::Hip)?;
|
||||
preprocessor_info.set_options(iter::once(c"-P"))?;
|
||||
let preprocessed = comgr.do_action(
|
||||
ActionKind::SourceToPreprocessor,
|
||||
&preprocessor_info,
|
||||
&version_string_set,
|
||||
)?;
|
||||
let data = preprocessed.get_data(DataKind::Source, 0)?;
|
||||
String::from_utf8(trim_whitespace_and_quotes(data.copy_content(comgr)?)?)
|
||||
.map_err(|_| Error::UNKNOWN)
|
||||
}
|
||||
|
||||
// When running the preprocessor to expand the macro the output is surrounded by
|
||||
// quotes (because it is a string literal) and has a trailing newline.
|
||||
// This function is not strictly necessary, but it makes the output cleaner
|
||||
fn trim_whitespace_and_quotes(data: Vec<u8>) -> Result<Vec<u8>, Error> {
|
||||
fn is_not_whitespace_or_quote(b: u8) -> bool {
|
||||
!(b.is_ascii_whitespace() || b == b'"')
|
||||
}
|
||||
let prefix_length = data
|
||||
.iter()
|
||||
.copied()
|
||||
.position(is_not_whitespace_or_quote)
|
||||
.ok_or(Error::UNKNOWN)?;
|
||||
let last_letter = data
|
||||
.iter()
|
||||
.copied()
|
||||
.rposition(is_not_whitespace_or_quote)
|
||||
.ok_or(Error::UNKNOWN)?;
|
||||
Ok(data[prefix_length..=last_letter].to_vec())
|
||||
}
|
||||
|
||||
pub enum Comgr {
|
||||
V2(amd_comgr_sys::comgr2::Comgr2),
|
||||
V3(amd_comgr_sys::comgr3::Comgr3),
|
||||
|
@ -356,7 +399,8 @@ impl_into!(
|
|||
amd_comgr_action_kind_t,
|
||||
[
|
||||
LinkBcToBc => AMD_COMGR_ACTION_LINK_BC_TO_BC,
|
||||
CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE
|
||||
CompileSourceToExecutable => AMD_COMGR_ACTION_COMPILE_SOURCE_TO_EXECUTABLE,
|
||||
SourceToPreprocessor => AMD_COMGR_ACTION_SOURCE_TO_PREPROCESSOR
|
||||
]
|
||||
);
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ petgraph = "0.7.1"
|
|||
microlp = "0.2.11"
|
||||
int-enum = "1.1"
|
||||
unwrap_or = "1.0.1"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
|
|
|
@ -47,6 +47,7 @@ quick_error! {
|
|||
}
|
||||
|
||||
/// GPU attributes needed at compile time.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct Attributes {
|
||||
/// Clock frequency in kHz.
|
||||
pub clock_rate: u32,
|
||||
|
|
|
@ -643,7 +643,9 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
|
|||
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
|
||||
let elf_module = comgr::compile_bitcode(
|
||||
&comgr,
|
||||
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) },
|
||||
unsafe { CStr::from_ptr(dev_props.gcnArchName.as_ptr()) }
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
&*module.llvm_ir.write_bitcode_to_memory(),
|
||||
&*module.attributes_ir.write_bitcode_to_memory(),
|
||||
module.linked_bitcode(),
|
||||
|
|
|
@ -7,8 +7,8 @@ edition = "2021"
|
|||
[dependencies]
|
||||
bpaf = { version = "0.9.15", features = ["derive"] }
|
||||
cargo_metadata = "0.19.1"
|
||||
serde = "1.0.217"
|
||||
serde_json = "1.0.137"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.142"
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
flate2 = { version = "1.1.1", features = ["zlib-rs"], default-features = false }
|
||||
|
|
|
@ -16,6 +16,7 @@ cuda_types = { path = "../cuda_types" }
|
|||
cuda_macros = { path = "../cuda_macros" }
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
dark_api = { path = "../dark_api" }
|
||||
zluda_cache = { path = "../zluda_cache" }
|
||||
lazy_static = "1.4"
|
||||
num_enum = "0.4"
|
||||
lz4-sys = "1.9"
|
||||
|
@ -23,6 +24,9 @@ tempfile = "3"
|
|||
paste = "1.0"
|
||||
rustc-hash = "1.1"
|
||||
zluda_common = { path = "../zluda_common" }
|
||||
blake3 = "1.8.2"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.142"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
winapi = { version = "0.3", features = ["heapapi", "std"] }
|
||||
|
@ -31,6 +35,9 @@ winapi = { version = "0.3", features = ["heapapi", "std"] }
|
|||
libc = "0.2"
|
||||
dtor = "0.0.7"
|
||||
|
||||
[build-dependencies]
|
||||
vergen-gix = "1.0.9"
|
||||
|
||||
[package.metadata.zluda]
|
||||
linux_symlinks = [
|
||||
"libcuda.so",
|
||||
|
|
10
zluda/build.rs
Normal file
10
zluda/build.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
use vergen_gix::{Emitter, GixBuilder};
|
||||
|
||||
fn main() {
|
||||
let git = GixBuilder::all_git().unwrap();
|
||||
Emitter::default()
|
||||
.add_instructions(&git)
|
||||
.unwrap()
|
||||
.emit()
|
||||
.unwrap();
|
||||
}
|
|
@ -97,6 +97,15 @@ impl ZludaObject for Context {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_current_device() -> Result<hipDevice_t, CUerror> {
|
||||
STACK.with(|stack| {
|
||||
stack
|
||||
.try_borrow()
|
||||
.map_err(|_| CUerror::UNKNOWN)
|
||||
.and_then(|s| s.last().ok_or(CUerror::UNKNOWN).map(|(_, dev)| *dev))
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_current_context() -> Result<CUcontext, CUerror> {
|
||||
if let Some(ctx) = STACK.with(|stack| stack.borrow().last().copied().map(|(ctx, _)| ctx)) {
|
||||
return Ok(ctx);
|
||||
|
|
|
@ -17,6 +17,8 @@ mod os;
|
|||
pub(crate) struct GlobalState {
|
||||
pub devices: Vec<Device>,
|
||||
pub comgr: Comgr,
|
||||
pub comgr_clang_version: String,
|
||||
pub cache_path: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) struct Device {
|
||||
|
@ -52,8 +54,11 @@ pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
|||
let mut device_count = 0;
|
||||
unsafe { hipGetDeviceCount(&mut device_count) }?;
|
||||
let comgr = Comgr::new().map_err(|_| CUerror::UNKNOWN)?;
|
||||
let comgr_clang_version =
|
||||
comgr::get_clang_version(&comgr).map_err(|_| CUerror::UNKNOWN)?;
|
||||
Ok(GlobalState {
|
||||
comgr,
|
||||
comgr_clang_version,
|
||||
devices: (0..device_count)
|
||||
.map(|i| {
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
|
@ -68,6 +73,7 @@ pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
|
|||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
cache_path: zluda_cache::ModuleCache::create_cache_dir_and_get_path(),
|
||||
})
|
||||
})
|
||||
.as_ref()
|
||||
|
|
|
@ -65,26 +65,95 @@ fn get_ptx(image: *const ::core::ffi::c_void) -> Result<String, CUerror> {
|
|||
pub(crate) fn load_hip_module(image: *const std::ffi::c_void) -> Result<hipModule_t, CUerror> {
|
||||
let global_state = driver::global_state()?;
|
||||
let text = get_ptx(image)?;
|
||||
let ast = ptx_parser::parse_module_checked(&text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||
let mut dev = 0;
|
||||
unsafe { hipCtxGetDevice(&mut dev) }?;
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
unsafe { hipGetDevicePropertiesR0600(&mut props, dev) }?;
|
||||
let hip_properties = get_hip_properties()?;
|
||||
let gcn_arch = get_gcn_arch(&hip_properties)?;
|
||||
let attributes = ptx::Attributes {
|
||||
clock_rate: props.clockRate as u32,
|
||||
clock_rate: hip_properties.clockRate as u32,
|
||||
};
|
||||
let mut cache_with_key = global_state.cache_path.as_ref().and_then(|p| {
|
||||
let cache = zluda_cache::ModuleCache::open(p)?;
|
||||
let key = get_cache_key(global_state, gcn_arch, &text, &attributes)?;
|
||||
Some((cache, key))
|
||||
});
|
||||
let cached_binary = load_cached_binary(&mut cache_with_key);
|
||||
let elf_module = cached_binary.ok_or(CUerror::UNKNOWN).or_else(|_| {
|
||||
compile_from_ptx_and_cache(
|
||||
&global_state.comgr,
|
||||
gcn_arch,
|
||||
attributes,
|
||||
&text,
|
||||
&mut cache_with_key,
|
||||
)
|
||||
})?;
|
||||
let mut hip_module = unsafe { mem::zeroed() };
|
||||
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
|
||||
Ok(hip_module)
|
||||
}
|
||||
|
||||
fn get_hip_properties<'a>() -> Result<hipDeviceProp_tR0600, CUerror> {
|
||||
let hip_dev = super::context::get_current_device()?;
|
||||
let mut props = unsafe { mem::zeroed() };
|
||||
unsafe { hipGetDevicePropertiesR0600(&mut props, hip_dev) }?;
|
||||
Ok(props)
|
||||
}
|
||||
|
||||
fn get_gcn_arch<'a>(props: &'a hipDeviceProp_tR0600) -> Result<&'a str, CUerror> {
|
||||
let gcn_arch = unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) };
|
||||
gcn_arch.to_str().map_err(|_| CUerror::UNKNOWN)
|
||||
}
|
||||
|
||||
fn get_cache_key<'a, 'b>(
|
||||
global_state: &'static driver::GlobalState,
|
||||
isa: &'a str,
|
||||
text: &str,
|
||||
attributes: &ptx::Attributes,
|
||||
) -> Option<zluda_cache::ModuleKey<'a>> {
|
||||
// Serialization here is deterministic. When marking a type with
|
||||
// #[derive(serde::Serialize)] the derived implementation will just write
|
||||
// fields in the order of their declaration. It's not explictly guaranteed
|
||||
// by serde, but it is the only sensible thing to do, so I feel safe
|
||||
// to rely on it
|
||||
let serialized_attributes = serde_json::to_string(attributes).ok()?;
|
||||
Some(zluda_cache::ModuleKey {
|
||||
hash: blake3::hash(text.as_bytes()).to_hex(),
|
||||
compiler_version: &*global_state.comgr_clang_version,
|
||||
zluda_version: env!("VERGEN_GIT_SHA"),
|
||||
device: isa,
|
||||
backend_key: serialized_attributes,
|
||||
last_access: zluda_cache::ModuleCache::time_now(),
|
||||
})
|
||||
}
|
||||
|
||||
fn load_cached_binary(
|
||||
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||
) -> Option<Vec<u8>> {
|
||||
cache_with_key
|
||||
.as_mut()
|
||||
.and_then(|(c, key)| c.get_module_binary(key))
|
||||
}
|
||||
|
||||
fn compile_from_ptx_and_cache(
|
||||
comgr: &comgr::Comgr,
|
||||
gcn_arch: &str,
|
||||
attributes: ptx::Attributes,
|
||||
text: &str,
|
||||
cache_with_key: &mut Option<(zluda_cache::ModuleCache, zluda_cache::ModuleKey)>,
|
||||
) -> Result<Vec<u8>, CUerror> {
|
||||
let ast = ptx_parser::parse_module_checked(text).map_err(|_| CUerror::NO_BINARY_FOR_GPU)?;
|
||||
let llvm_module = ptx::to_llvm_module(ast, attributes).map_err(|_| CUerror::UNKNOWN)?;
|
||||
let elf_module = comgr::compile_bitcode(
|
||||
&global_state.comgr,
|
||||
unsafe { CStr::from_ptr(props.gcnArchName.as_ptr()) },
|
||||
comgr,
|
||||
gcn_arch,
|
||||
&*llvm_module.llvm_ir.write_bitcode_to_memory(),
|
||||
&*llvm_module.attributes_ir.write_bitcode_to_memory(),
|
||||
llvm_module.linked_bitcode(),
|
||||
)
|
||||
.map_err(|_| CUerror::UNKNOWN)?;
|
||||
let mut hip_module = unsafe { mem::zeroed() };
|
||||
unsafe { hipModuleLoadData(&mut hip_module, elf_module.as_ptr().cast()) }?;
|
||||
Ok(hip_module)
|
||||
if let Some((cache, key)) = cache_with_key {
|
||||
key.last_access = zluda_cache::ModuleCache::time_now();
|
||||
cache.insert_module(key, &elf_module);
|
||||
}
|
||||
Ok(elf_module)
|
||||
}
|
||||
|
||||
pub(crate) fn load_data(module: &mut CUmodule, image: &std::ffi::c_void) -> CUresult {
|
||||
|
|
14
zluda_cache/Cargo.toml
Normal file
14
zluda_cache/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "zluda_cache"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
diesel = { version = "2.2.12", features = [
|
||||
"sqlite",
|
||||
"returning_clauses_for_sqlite_3_35",
|
||||
] }
|
||||
diesel_migrations = "2.2.0"
|
||||
libsqlite3-sys = { version = "0.35", features = ["bundled"] }
|
||||
dirs = "6.0.0"
|
||||
arrayvec = "0.7.6"
|
10
zluda_cache/diesel.toml
Normal file
10
zluda_cache/diesel.toml
Normal file
|
@ -0,0 +1,10 @@
|
|||
# For documentation on how to configure this file,
|
||||
# see https://diesel.rs/guides/configuring-diesel-cli
|
||||
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
custom_type_derives = ["diesel::query_builder::QueryId", "Clone"]
|
||||
sqlite_integer_primary_key_is_bigint = true
|
||||
|
||||
[migrations_directory]
|
||||
dir = "migrations"
|
0
zluda_cache/migrations/.keep
Normal file
0
zluda_cache/migrations/.keep
Normal file
|
@ -0,0 +1,2 @@
|
|||
DROP TABLE modules;
|
||||
DROP TABLE globals;
|
|
@ -0,0 +1,41 @@
|
|||
CREATE TABLE modules (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
hash TEXT NOT NULL,
|
||||
compiler_version TEXT NOT NULL,
|
||||
zluda_version TEXT NOT NULL,
|
||||
device TEXT NOT NULL,
|
||||
backend_key TEXT NOT NULL,
|
||||
binary BLOB NOT NULL,
|
||||
last_access BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS modules_index ON modules (hash, compiler_version, zluda_version, device, backend_key);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS globals (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BIGINT NOT NULL
|
||||
) WITHOUT ROWID;
|
||||
|
||||
INSERT OR IGNORE INTO globals (key, value) VALUES ('total_size', 0);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_size_on_delete
|
||||
AFTER
|
||||
DELETE ON modules FOR EACH ROW BEGIN
|
||||
UPDATE
|
||||
globals
|
||||
SET
|
||||
value = value - length(OLD.binary)
|
||||
WHERE
|
||||
key = 'total_size';
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS update_size_on_insert
|
||||
AFTER
|
||||
INSERT ON modules FOR EACH ROW BEGIN
|
||||
UPDATE
|
||||
globals
|
||||
SET
|
||||
value = value + length(NEW.binary)
|
||||
WHERE
|
||||
key = 'total_size';
|
||||
END;
|
231
zluda_cache/src/lib.rs
Normal file
231
zluda_cache/src/lib.rs
Normal file
|
@ -0,0 +1,231 @@
|
|||
use crate::schema::modules;
|
||||
use arrayvec::ArrayString;
|
||||
use diesel::{connection::SimpleConnection, prelude::*};
|
||||
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) mod models;
|
||||
pub(crate) mod schema;
|
||||
|
||||
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
|
||||
|
||||
pub struct ModuleKey<'a> {
|
||||
pub hash: ArrayString<64>,
|
||||
pub compiler_version: &'static str,
|
||||
pub zluda_version: &'static str,
|
||||
pub device: &'a str,
|
||||
pub backend_key: String,
|
||||
pub last_access: i64,
|
||||
}
|
||||
|
||||
pub struct ModuleCache(SqliteConnection);
|
||||
|
||||
impl ModuleCache {
|
||||
pub fn create_cache_dir_and_get_path() -> Option<String> {
|
||||
let mut cache_dir = dirs::cache_dir()?;
|
||||
cache_dir.extend(["zluda", "ComputeCache"]);
|
||||
// We ensure that the cache directory exists
|
||||
std::fs::create_dir_all(&cache_dir).ok()?;
|
||||
// No need to create the file, it will be created by SQLite on first access
|
||||
cache_dir.push("zluda.db");
|
||||
Some(cache_dir.to_string_lossy().into())
|
||||
}
|
||||
|
||||
pub fn open(file_path: &str) -> Option<Self> {
|
||||
let mut conn = SqliteConnection::establish(file_path).ok()?;
|
||||
conn.batch_execute("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")
|
||||
.ok()?;
|
||||
conn.run_pending_migrations(MIGRATIONS).ok()?;
|
||||
Some(Self(conn))
|
||||
}
|
||||
|
||||
pub fn get_module_binary(&mut self, key: &ModuleKey) -> Option<Vec<u8>> {
|
||||
diesel::update(modules::dsl::modules)
|
||||
.set(modules::last_access.eq(key.last_access))
|
||||
.filter(modules::hash.eq(key.hash.as_str()))
|
||||
.filter(modules::compiler_version.eq(&key.compiler_version))
|
||||
.filter(modules::zluda_version.eq(key.zluda_version))
|
||||
.filter(modules::device.eq(key.device))
|
||||
.filter(modules::backend_key.eq(&key.backend_key))
|
||||
.returning(modules::binary)
|
||||
.get_result(&mut self.0)
|
||||
.ok()
|
||||
}
|
||||
|
||||
pub fn insert_module(&mut self, key: &ModuleKey, binary: &[u8]) {
|
||||
diesel::insert_into(modules::dsl::modules)
|
||||
.values(models::AddModule {
|
||||
hash: key.hash.as_str(),
|
||||
compiler_version: &key.compiler_version,
|
||||
zluda_version: key.zluda_version,
|
||||
device: key.device,
|
||||
backend_key: &key.backend_key,
|
||||
last_access: key.last_access,
|
||||
binary,
|
||||
})
|
||||
.execute(&mut self.0)
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn time_now() -> i64 {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
.as_millis() as i64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
schema::{globals::dsl::*, modules::dsl::*},
|
||||
ModuleCache,
|
||||
};
|
||||
use arrayvec::ArrayString;
|
||||
use diesel::prelude::*;
|
||||
|
||||
#[derive(Queryable, Selectable)]
|
||||
#[diesel(table_name = crate::schema::modules)]
|
||||
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||
pub struct Module {
|
||||
pub id: i64,
|
||||
pub hash: String,
|
||||
pub binary: Vec<u8>,
|
||||
pub last_access: i64,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Selectable)]
|
||||
#[diesel(table_name = crate::schema::globals)]
|
||||
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||
pub struct Global {
|
||||
pub key: String,
|
||||
pub value: i64,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_db_returns_no_module() {
|
||||
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||
let module_binary = db.get_module_binary(&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 123,
|
||||
});
|
||||
assert!(module_binary.is_none());
|
||||
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_modules.len(), 0);
|
||||
let all_globals: Vec<Global> = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_globals[0].key, "total_size");
|
||||
assert_eq!(all_globals[0].value, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn newly_inserted_module_increments_total_size() {
|
||||
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||
db.insert_module(
|
||||
&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash1").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 123,
|
||||
},
|
||||
&[1, 2, 3, 4, 5],
|
||||
);
|
||||
db.insert_module(
|
||||
&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash2").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 124,
|
||||
},
|
||||
&[1, 2, 3],
|
||||
);
|
||||
let mut all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||
all_modules.sort_by_key(|m: &Module| m.id);
|
||||
assert_eq!(all_modules.len(), 2);
|
||||
assert_eq!(all_modules[0].hash, "test_hash1");
|
||||
assert_eq!(all_modules[0].last_access, 123);
|
||||
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||
assert_eq!(all_modules[1].hash, "test_hash2");
|
||||
assert_eq!(all_modules[1].last_access, 124);
|
||||
assert_eq!(all_modules[1].binary, &[1, 2, 3]);
|
||||
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_globals[0].key, "total_size");
|
||||
assert_eq!(all_globals[0].value, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_bumps_last_access() {
|
||||
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||
db.insert_module(
|
||||
&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 123,
|
||||
},
|
||||
&[1, 2, 3, 4, 5],
|
||||
);
|
||||
let module_binary = db
|
||||
.get_module_binary(&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 124,
|
||||
})
|
||||
.unwrap();
|
||||
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_modules.len(), 1);
|
||||
assert_eq!(all_modules[0].last_access, 124);
|
||||
assert_eq!(module_binary, &[1, 2, 3, 4, 5]);
|
||||
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_globals[0].key, "total_size");
|
||||
assert_eq!(all_globals[0].value, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn double_insert_does_not_override() {
|
||||
let mut db = ModuleCache::open(":memory:").unwrap();
|
||||
db.insert_module(
|
||||
&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 123,
|
||||
},
|
||||
&[1, 2, 3, 4, 5],
|
||||
);
|
||||
db.insert_module(
|
||||
&super::ModuleKey {
|
||||
hash: ArrayString::from("test_hash").unwrap(),
|
||||
compiler_version: "1.0.0",
|
||||
zluda_version: "1.0.0",
|
||||
device: "test_device",
|
||||
backend_key: "{}".to_string(),
|
||||
last_access: 124,
|
||||
},
|
||||
&[5, 4, 3, 2, 1],
|
||||
);
|
||||
let all_modules = modules.select(Module::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_modules.len(), 1);
|
||||
assert_eq!(all_modules[0].last_access, 123);
|
||||
assert_eq!(all_modules[0].binary, &[1, 2, 3, 4, 5]);
|
||||
let all_globals = globals.select(Global::as_select()).load(&mut db.0).unwrap();
|
||||
assert_eq!(all_globals[0].key, "total_size");
|
||||
assert_eq!(all_globals[0].value, 5);
|
||||
}
|
||||
}
|
14
zluda_cache/src/models.rs
Normal file
14
zluda_cache/src/models.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use crate::schema::modules;
|
||||
use diesel::prelude::*;
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[diesel(table_name = modules)]
|
||||
pub(crate) struct AddModule<'a> {
|
||||
pub hash: &'a str,
|
||||
pub compiler_version: &'a str,
|
||||
pub zluda_version: &'a str,
|
||||
pub device: &'a str,
|
||||
pub backend_key: &'a str,
|
||||
pub binary: &'a [u8],
|
||||
pub last_access: i64,
|
||||
}
|
23
zluda_cache/src/schema.rs
Normal file
23
zluda_cache/src/schema.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
// @generated automatically by Diesel CLI.
|
||||
|
||||
diesel::table! {
|
||||
globals (key) {
|
||||
key -> Text,
|
||||
value -> BigInt,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
modules (id) {
|
||||
id -> BigInt,
|
||||
hash -> Text,
|
||||
compiler_version -> Text,
|
||||
zluda_version -> Text,
|
||||
device -> Text,
|
||||
backend_key -> Text,
|
||||
binary -> Binary,
|
||||
last_access -> BigInt,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::allow_tables_to_appear_in_same_query!(globals, modules,);
|
Loading…
Add table
Add a link
Reference in a new issue