mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Refactor host code to use one big lock
This commit is contained in:
parent
7c93997cc9
commit
a2e77fe961
15 changed files with 914 additions and 540 deletions
|
@ -173,6 +173,16 @@ impl Context {
|
|||
check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
|
||||
Ok(Context(result))
|
||||
}
|
||||
|
||||
pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
|
||||
check! {
|
||||
sys::zeMemFree(
|
||||
self.0,
|
||||
ptr,
|
||||
)
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Context {
|
||||
|
@ -239,7 +249,7 @@ pub struct Module(sys::ze_module_handle_t);
|
|||
|
||||
impl Module {
|
||||
// HACK ALERT
|
||||
// We use OpenCL for now to do SPIR-V linking, because Level0
|
||||
// We use OpenCL for now to do SPIR-V linking, because Level0
|
||||
// does not allow linking. Don't let presence of zeModuleDynamicLink fool
|
||||
// you, it's not currently possible to create non-compiled modules.
|
||||
// zeModuleCreate always compiles (builds and links).
|
||||
|
|
27
notcuda/build.rs
Normal file
27
notcuda/build.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
// HACK ALERT
|
||||
// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries
|
||||
// overriding path to OpenCL
|
||||
|
||||
fn main() {
|
||||
if cfg!(windows) {
|
||||
let known_sdk = [
|
||||
// E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\"
|
||||
("INTELOCLSDKROOT", "x64", "x86"),
|
||||
// E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\"
|
||||
("CUDA_PATH", "x64", "Win32"),
|
||||
// E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\"
|
||||
("AMDAPPSDKROOT", "x86_64", "x86"),
|
||||
];
|
||||
|
||||
for info in known_sdk.iter() {
|
||||
if let Ok(sdk) = std::env::var(info.0) {
|
||||
let mut path = std::path::PathBuf::from(sdk);
|
||||
path.push("lib");
|
||||
path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 });
|
||||
println!("cargo:rustc-link-search=native={}", path.display());
|
||||
}
|
||||
}
|
||||
|
||||
println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64");
|
||||
}
|
||||
}
|
|
@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int)
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult {
|
||||
r#impl::device::get(device.decuda(), ordinal)
|
||||
r#impl::device::get(device.decuda(), ordinal).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult {
|
||||
r#impl::device::get_count(count)
|
||||
r#impl::device::get_count(count).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult {
|
|||
cuDevicePrimaryCtxReset_v2(dev)
|
||||
}
|
||||
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
|
@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2(
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult {
|
||||
r#impl::context::destroy_v2(ctx.decuda())
|
||||
r#impl::context::destroy_v2(ctx.decuda()).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult {
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult {
|
||||
r#impl::context::get_device(device.decuda())
|
||||
r#impl::context::get_device(device.decuda()).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion(
|
|||
ctx: CUcontext,
|
||||
version: *mut ::std::os::raw::c_uint,
|
||||
) -> CUresult {
|
||||
r#impl::context::get_api_version(ctx.decuda(), version)
|
||||
r#impl::context::get_api_version(ctx.decuda(), version).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult {
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::context::attach(pctx.decuda(), flags).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::context::detach(ctx.decuda()).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData(
|
|||
module: *mut CUmodule,
|
||||
image: *const ::std::os::raw::c_void,
|
||||
) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::module::load_data(module.decuda(), image).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult {
|
||||
r#impl::memory::alloc_v2(dptr.decuda(), bytesize)
|
||||
r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate(
|
|||
phStream: *mut CUstream,
|
||||
Flags: ::std::os::raw::c_uint,
|
||||
) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::stream::create(phStream.decuda(), Flags).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags(
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult {
|
|||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult {
|
||||
r#impl::unimplemented()
|
||||
r#impl::stream::destroy_v2(hStream.decuda()).encuda()
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), no_mangle)]
|
||||
|
|
|
@ -1,18 +1,15 @@
|
|||
use super::CUresult;
|
||||
use super::{device, HasLivenessCookie, LiveCheck};
|
||||
use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck};
|
||||
use super::{CUresult, GlobalState};
|
||||
use crate::{cuda::CUcontext, cuda_impl};
|
||||
use l0::sys::ze_result_t;
|
||||
use std::mem::{self, ManuallyDrop};
|
||||
use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
num::NonZeroU32,
|
||||
os::raw::c_uint,
|
||||
ptr,
|
||||
sync::{atomic::AtomicU32, Mutex},
|
||||
collections::HashSet,
|
||||
mem::{self},
|
||||
};
|
||||
|
||||
thread_local! {
|
||||
pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new());
|
||||
pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new());
|
||||
}
|
||||
|
||||
pub type Context = LiveCheck<ContextData>;
|
||||
|
@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData {
|
|||
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
const COOKIE: usize = 0x0b643ffb;
|
||||
|
||||
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
|
||||
|
||||
fn try_drop(&mut self) -> Result<(), CUresult> {
|
||||
for stream in self.streams.iter() {
|
||||
let stream = unsafe { &mut **stream };
|
||||
stream.context = ptr::null_mut();
|
||||
Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
enum ContextRefCount {
|
||||
|
@ -67,26 +75,16 @@ impl ContextRefCount {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_primary(&self) -> bool {
|
||||
match self {
|
||||
ContextRefCount::Primary => true,
|
||||
ContextRefCount::NonPrimary(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextData {
|
||||
pub flags: AtomicU32,
|
||||
pub device_index: device::Index,
|
||||
// This pointer is null only for a moment when constructing primary context
|
||||
pub device: *const Mutex<device::Device>,
|
||||
// The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState
|
||||
pub mutable: Mutex<ContextDataMutable>,
|
||||
}
|
||||
|
||||
pub struct ContextDataMutable {
|
||||
pub device: *mut device::Device,
|
||||
ref_count: ContextRefCount,
|
||||
pub default_stream: StreamData,
|
||||
pub streams: HashSet<*mut StreamData>,
|
||||
// All the fields below are here to support internal CUDA driver API
|
||||
pub cuda_manager: *mut cuda_impl::rt::ContextStateManager,
|
||||
pub cuda_state: *mut cuda_impl::rt::ContextState,
|
||||
pub cuda_dtor_cb: Option<
|
||||
|
@ -100,63 +98,75 @@ pub struct ContextDataMutable {
|
|||
|
||||
impl ContextData {
|
||||
pub fn new(
|
||||
l0_ctx: &mut l0::Context,
|
||||
l0_dev: &l0::Device,
|
||||
flags: c_uint,
|
||||
is_primary: bool,
|
||||
dev_index: device::Index,
|
||||
dev: *const Mutex<device::Device>,
|
||||
) -> Self {
|
||||
ContextData {
|
||||
dev: *mut device::Device,
|
||||
) -> Result<Self, CUresult> {
|
||||
let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?;
|
||||
Ok(ContextData {
|
||||
flags: AtomicU32::new(flags),
|
||||
device_index: dev_index,
|
||||
device: dev,
|
||||
mutable: Mutex::new(ContextDataMutable {
|
||||
ref_count: ContextRefCount::new(is_primary),
|
||||
cuda_manager: ptr::null_mut(),
|
||||
cuda_state: ptr::null_mut(),
|
||||
cuda_dtor_cb: None,
|
||||
}),
|
||||
}
|
||||
ref_count: ContextRefCount::new(is_primary),
|
||||
default_stream,
|
||||
streams: HashSet::new(),
|
||||
cuda_manager: ptr::null_mut(),
|
||||
cuda_state: ptr::null_mut(),
|
||||
cuda_dtor_cb: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult {
|
||||
impl Context {
|
||||
pub fn late_init(&mut self) {
|
||||
let ctx_data = self.as_option_mut().unwrap();
|
||||
ctx_data.default_stream.context = ctx_data as *mut _;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_v2(
|
||||
pctx: *mut *mut Context,
|
||||
flags: u32,
|
||||
dev_idx: device::Index,
|
||||
) -> Result<(), CUresult> {
|
||||
if pctx == ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let dev = device::get_device_ref(dev_idx);
|
||||
let dev = match dev {
|
||||
Ok(d) => d,
|
||||
Err(e) => return e,
|
||||
};
|
||||
let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev)));
|
||||
let ctx_ref = ctx.as_mut() as *mut Context;
|
||||
let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
|
||||
let dev_ptr = dev as *mut _;
|
||||
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
|
||||
&mut dev.l0_context,
|
||||
&dev.base,
|
||||
flags,
|
||||
false,
|
||||
dev_ptr as *mut _,
|
||||
)?));
|
||||
ctx_box.late_init();
|
||||
Ok::<_, CUresult>(ctx_box)
|
||||
})??;
|
||||
let ctx_ref = ctx_box.as_mut() as *mut Context;
|
||||
unsafe { *pctx = ctx_ref };
|
||||
mem::forget(ctx);
|
||||
mem::forget(ctx_box);
|
||||
CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref));
|
||||
CUresult::CUDA_SUCCESS
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn destroy_v2(ctx: *mut Context) -> CUresult {
|
||||
pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> {
|
||||
if ctx == ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
CONTEXT_STACK.with(|stack| {
|
||||
let mut stack = stack.borrow_mut();
|
||||
let should_pop = match stack.last() {
|
||||
Some(active_ctx) => *active_ctx == (ctx as *const _),
|
||||
Some(active_ctx) => *active_ctx == (ctx as *mut _),
|
||||
None => false,
|
||||
};
|
||||
if should_pop {
|
||||
stack.pop();
|
||||
}
|
||||
});
|
||||
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) });
|
||||
if !ctx_box.try_drop() {
|
||||
CUresult::CUDA_ERROR_INVALID_CONTEXT
|
||||
} else {
|
||||
unsafe { ManuallyDrop::drop(&mut ctx_box) };
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
GlobalState::lock(|_| Context::destroy_impl(ctx))?
|
||||
}
|
||||
|
||||
pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
|
||||
|
@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
|
|||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
|
||||
pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> {
|
||||
CONTEXT_STACK.with(|stack| {
|
||||
stack
|
||||
.borrow()
|
||||
.last()
|
||||
.and_then(|c| unsafe { &**c }.as_ref())
|
||||
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
|
||||
.map(f)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> {
|
||||
if pctx == ptr::null_mut() {
|
||||
return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT);
|
||||
|
@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult {
|
||||
let _ctx = match unsafe { ctx.as_mut() } {
|
||||
None => return CUresult::CUDA_ERROR_INVALID_VALUE,
|
||||
Some(ctx) => match ctx.as_mut() {
|
||||
None => return CUresult::CUDA_ERROR_INVALID_CONTEXT,
|
||||
Some(ctx) => ctx,
|
||||
},
|
||||
};
|
||||
pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
|
||||
if ctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
GlobalState::lock(|_| {
|
||||
unsafe { &*ctx }.as_result()?;
|
||||
Ok::<_, CUresult>(())
|
||||
})??;
|
||||
//TODO: query device for properties roughly matching CUDA API version
|
||||
unsafe { *version = 1100 };
|
||||
CUresult::CUDA_SUCCESS
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_device(dev: *mut device::Index) -> CUresult {
|
||||
let dev_idx = with_current(|ctx| ctx.device_index);
|
||||
match dev_idx {
|
||||
Ok(idx) => {
|
||||
unsafe { *dev = idx }
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
Err(err) => err,
|
||||
pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> {
|
||||
let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?;
|
||||
unsafe { *dev = dev_idx };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> {
|
||||
if pctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
|
||||
let ctx = unchecked_ctx.as_result_mut()?;
|
||||
ctx.ref_count.incr()?;
|
||||
Ok::<_, CUresult>(unchecked_ctx as *mut _)
|
||||
})??;
|
||||
unsafe { *pctx = ctx };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
|
||||
if pctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
|
||||
let ctx = unchecked_ctx.as_result_mut()?;
|
||||
if ctx.ref_count.decr() {
|
||||
Context::destroy_impl(unchecked_ctx)?;
|
||||
}
|
||||
Ok::<_, CUresult>(())
|
||||
})?
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn is_context_stack_empty() -> bool {
|
||||
CONTEXT_STACK.with(|stack| stack.borrow().is_empty())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod test {
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
use std::{ffi::c_void, ptr};
|
||||
|
|
|
@ -1,24 +1,21 @@
|
|||
use super::{context, transmute_lifetime, CUresult, Error};
|
||||
use super::{context, CUresult, GlobalState};
|
||||
use crate::cuda;
|
||||
use cuda::{CUdevice_attribute, CUuuid_st};
|
||||
use std::{
|
||||
cmp, mem,
|
||||
os::raw::{c_char, c_int},
|
||||
ptr,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Mutex, MutexGuard,
|
||||
},
|
||||
sync::atomic::{AtomicU32, Ordering},
|
||||
};
|
||||
|
||||
const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]";
|
||||
static mut DEVICES: Option<Vec<Mutex<Device>>> = None;
|
||||
|
||||
#[repr(transparent)]
|
||||
#[derive(Clone, Copy)]
|
||||
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
|
||||
pub struct Index(pub c_int);
|
||||
|
||||
pub struct Device {
|
||||
pub index: Index,
|
||||
pub base: l0::Device,
|
||||
pub default_queue: l0::CommandQueue,
|
||||
pub l0_context: l0::Context,
|
||||
|
@ -33,17 +30,19 @@ unsafe impl Send for Device {}
|
|||
|
||||
impl Device {
|
||||
// Unsafe because it does not fully initalize primary_context
|
||||
unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result<Self> {
|
||||
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
|
||||
let mut ctx = l0::Context::new(drv)?;
|
||||
let queue = l0::CommandQueue::new(&mut ctx, &d)?;
|
||||
let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
|
||||
let primary_context = context::Context::new(context::ContextData::new(
|
||||
&mut ctx,
|
||||
&l0_dev,
|
||||
0,
|
||||
true,
|
||||
Index(idx as c_int),
|
||||
ptr::null(),
|
||||
));
|
||||
ptr::null_mut(),
|
||||
)?);
|
||||
Ok(Self {
|
||||
base: d,
|
||||
index: Index(idx as c_int),
|
||||
base: l0_dev,
|
||||
default_queue: queue,
|
||||
l0_context: ctx,
|
||||
primary_context: primary_context,
|
||||
|
@ -93,83 +92,53 @@ impl Device {
|
|||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn late_init(&mut self) {
|
||||
self.primary_context.as_option_mut().unwrap().device = self as *mut _;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(driver: &l0::Driver) -> l0::Result<()> {
|
||||
pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
|
||||
let ze_devices = driver.devices()?;
|
||||
let mut devices = ze_devices
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new))
|
||||
.map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
for d in devices.iter_mut() {
|
||||
d.get_mut()
|
||||
.unwrap()
|
||||
.primary_context
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.device = d;
|
||||
for dev in devices.iter_mut() {
|
||||
dev.late_init();
|
||||
dev.primary_context.late_init();
|
||||
}
|
||||
unsafe { DEVICES = Some(devices) };
|
||||
Ok(devices)
|
||||
}
|
||||
|
||||
pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
|
||||
let len = GlobalState::lock(|state| state.devices.len())?;
|
||||
unsafe { *count = len as c_int };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn devices() -> Result<&'static Vec<Mutex<Device>>, CUresult> {
|
||||
match unsafe { &DEVICES } {
|
||||
Some(devs) => Ok(devs),
|
||||
None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex<Device>, CUresult> {
|
||||
let devs = devices()?;
|
||||
if dev_idx < 0 || dev_idx >= devs.len() as c_int {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
|
||||
}
|
||||
Ok(&devs[dev_idx as usize])
|
||||
}
|
||||
|
||||
pub fn get_device(dev_idx: Index) -> Result<MutexGuard<'static, Device>, CUresult> {
|
||||
let dev = get_device_ref(dev_idx)?;
|
||||
dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
|
||||
}
|
||||
|
||||
pub fn get_count(count: *mut c_int) -> CUresult {
|
||||
let len = devices().map(|d| d.len());
|
||||
match len {
|
||||
Ok(len) => {
|
||||
unsafe { *count = len as c_int };
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(device: *mut Index, ordinal: c_int) -> CUresult {
|
||||
pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> {
|
||||
if device == ptr::null_mut() || ordinal < 0 {
|
||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let len = devices().map(|d| d.len());
|
||||
match len {
|
||||
Ok(len) if ordinal < (len as i32) => {
|
||||
unsafe { *device = Index(ordinal) };
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE,
|
||||
Err(e) => e,
|
||||
let len = GlobalState::lock(|state| state.devices.len())?;
|
||||
if ordinal < (len as i32) {
|
||||
unsafe { *device = Index(ordinal) };
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> {
|
||||
pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> {
|
||||
if name == ptr::null_mut() || len < 0 {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
// This is safe because devices are 'static
|
||||
let name_ptr = {
|
||||
let mut dev = get_device(dev)?;
|
||||
let props = dev.get_properties().map_err(Into::<CUresult>::into)?;
|
||||
props.name.as_ptr()
|
||||
};
|
||||
let name_ptr = GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr())
|
||||
})??;
|
||||
let name_len = (0..256)
|
||||
.position(|i| unsafe { *name_ptr.add(i) } == 0)
|
||||
.unwrap_or(256);
|
||||
|
@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult>
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> {
|
||||
pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> {
|
||||
if bytes == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
// This is safe because devices are 'static
|
||||
let mem_props = {
|
||||
let mut dev = get_device(dev)?;
|
||||
unsafe {
|
||||
transmute_lifetime(
|
||||
dev.get_memory_properties()
|
||||
.map_err(Into::<CUresult>::into)?,
|
||||
)
|
||||
}
|
||||
};
|
||||
let mem_props = GlobalState::lock_device(dev_idx, |dev| {
|
||||
let mem_props = dev.get_memory_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(mem_props)
|
||||
})??;
|
||||
let max_mem = mem_props
|
||||
.iter()
|
||||
.map(|p| p.totalSize)
|
||||
|
@ -228,56 +191,101 @@ impl CUdevice_attribute {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> {
|
||||
pub fn get_attribute(
|
||||
pi: *mut i32,
|
||||
attrib: CUdevice_attribute,
|
||||
dev_idx: Index,
|
||||
) -> Result<(), CUresult> {
|
||||
if pi == ptr::null_mut() {
|
||||
return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE));
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
if let Some(value) = attrib.get_static_value() {
|
||||
unsafe { *pi = value };
|
||||
return Ok(());
|
||||
}
|
||||
let mut dev = get_device(dev).map_err(Error::Cuda)?;
|
||||
let value = match attrib {
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
|
||||
dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => {
|
||||
let props = dev.get_properties().map_err(Error::L0)?;
|
||||
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(
|
||||
(props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32,
|
||||
)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => {
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_image_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(cmp::min(
|
||||
props.maxImageDims1D,
|
||||
c_int::max_value() as u32,
|
||||
) as c_int)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min(
|
||||
dev.get_image_properties()
|
||||
.map_err(Error::L0)?
|
||||
.maxImageDims1D,
|
||||
c_int::max_value() as u32,
|
||||
) as c_int,
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(cmp::max(
|
||||
i32::max_value() as u32,
|
||||
props.maxGroupCountX,
|
||||
) as i32)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(cmp::max(
|
||||
i32::max_value() as u32,
|
||||
props.maxGroupCountY,
|
||||
) as i32)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(cmp::max(
|
||||
i32::max_value() as u32,
|
||||
props.maxGroupCountZ,
|
||||
) as i32)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32,
|
||||
)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32,
|
||||
)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(
|
||||
cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32,
|
||||
)
|
||||
})??
|
||||
}
|
||||
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
|
||||
let props = dev.get_compute_properties().map_err(Error::L0)?;
|
||||
cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32
|
||||
GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_compute_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(cmp::max(
|
||||
i32::max_value() as u32,
|
||||
props.maxTotalGroupSize,
|
||||
) as i32)
|
||||
})??
|
||||
}
|
||||
_ => {
|
||||
// TODO: support more attributes for CUDA runtime
|
||||
|
@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
|
||||
let ze_uuid = {
|
||||
get_device(dev)
|
||||
.map_err(Error::Cuda)?
|
||||
.get_properties()
|
||||
.map_err(Error::L0)?
|
||||
.uuid
|
||||
};
|
||||
pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
|
||||
let ze_uuid = GlobalState::lock_device(dev_idx, |dev| {
|
||||
let props = dev.get_properties()?;
|
||||
Ok::<_, l0::sys::ze_result_t>(props.uuid)
|
||||
})??;
|
||||
unsafe {
|
||||
*uuid = CUuuid_st {
|
||||
bytes: mem::transmute(ze_uuid.id),
|
||||
|
@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn with_current_exclusive<F: FnOnce(&mut Device) -> R, R>(f: F) -> Result<R, CUresult> {
|
||||
let dev = super::context::with_current(|ctx| ctx.device);
|
||||
dev.and_then(|dev| {
|
||||
unsafe { &*dev }
|
||||
.try_lock()
|
||||
.map(|mut dev| f(&mut dev))
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_exclusive<F: FnOnce(&mut Device) -> R, R>(dev: Index, f: F) -> Result<R, CUresult> {
|
||||
let dev = get_device_ref(dev)?;
|
||||
dev.try_lock()
|
||||
.map(|mut dev| f(&mut dev))
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
|
||||
}
|
||||
|
||||
pub fn primary_ctx_get_state(
|
||||
idx: Index,
|
||||
dev_idx: Index,
|
||||
flags: *mut u32,
|
||||
active: *mut i32,
|
||||
) -> Result<(), CUresult> {
|
||||
let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| {
|
||||
let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| {
|
||||
// This is safe because primary context can't be dropped
|
||||
let ctx_ptr = &dev.primary_context as *const _;
|
||||
let ctx_ptr = &mut dev.primary_context as *mut _;
|
||||
let flags_ptr =
|
||||
(&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32;
|
||||
(ctx_ptr, flags_ptr)
|
||||
})?;
|
||||
let is_active = context::CONTEXT_STACK
|
||||
.with(|stack| stack.borrow().last().map(|x| *x))
|
||||
.map(|current| current == ctx_ptr)
|
||||
.unwrap_or(false);
|
||||
let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
|
||||
unsafe { *flags = flags_value };
|
||||
let is_active = context::CONTEXT_STACK
|
||||
.with(|stack| stack.borrow().last().map(|x| *x))
|
||||
.map(|current| current == ctx_ptr)
|
||||
.unwrap_or(false);
|
||||
let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
|
||||
Ok::<_, l0::sys::ze_result_t>((is_active, flags_value))
|
||||
})??;
|
||||
unsafe { *active = if is_active { 1 } else { 0 } };
|
||||
unsafe { *flags = flags_value };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> {
|
||||
let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?;
|
||||
pub fn primary_ctx_retain(
|
||||
pctx: *mut *mut context::Context,
|
||||
dev_idx: Index,
|
||||
) -> Result<(), CUresult> {
|
||||
let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?;
|
||||
unsafe { *pctx = ctx_ptr };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod test {
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
cuda_impl,
|
||||
};
|
||||
|
||||
use super::{context, device, module, Decuda, Encuda};
|
||||
use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
|
||||
use std::mem;
|
||||
use std::os::raw::{c_uint, c_ulong, c_ushort};
|
||||
use std::{
|
||||
|
@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
|
|||
VTableEntry { ptr: ptr::null() },
|
||||
];
|
||||
|
||||
unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
|
||||
cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
|
||||
}
|
||||
|
||||
fn cudart_interface_fn1_impl(
|
||||
pctx: *mut *mut context::Context,
|
||||
dev: device::Index,
|
||||
) -> Result<(), CUresult> {
|
||||
let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
|
||||
unsafe { *pctx = ctx_ptr };
|
||||
Ok(())
|
||||
unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
|
||||
super::unimplemented()
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin(
|
|||
ptr1: *mut c_void,
|
||||
ptr2: *mut c_void,
|
||||
) -> CUresult {
|
||||
// Not sure what those twoparameters are actually used for,
|
||||
// Not sure what those two parameters are actually used for,
|
||||
// they are somehow involved in __cudaRegisterHostVar
|
||||
if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
|
||||
return CUresult::CUDA_ERROR_NOT_SUPPORTED;
|
||||
|
@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin(
|
|||
},
|
||||
Err(_) => continue,
|
||||
};
|
||||
let module = module::ModuleData::compile_spirv(kernel_text_string);
|
||||
let module = module::SpirvModule::new(kernel_text_string);
|
||||
match module {
|
||||
Ok(module) => {
|
||||
*result = Box::into_raw(Box::new(module));
|
||||
match module::load_data_impl(result, module) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return err,
|
||||
}
|
||||
return CUresult::CUDA_SUCCESS;
|
||||
}
|
||||
Err(_) => continue,
|
||||
|
@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor(
|
|||
}
|
||||
|
||||
fn context_local_storage_ctor_impl(
|
||||
mut cu_ctx: *mut context::Context,
|
||||
cu_ctx: *mut context::Context,
|
||||
mgr: *mut cuda_impl::rt::ContextStateManager,
|
||||
ctx_state: *mut cuda_impl::rt::ContextState,
|
||||
dtor_cb: Option<
|
||||
|
@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl(
|
|||
),
|
||||
>,
|
||||
) -> Result<(), CUresult> {
|
||||
if cu_ctx == ptr::null_mut() {
|
||||
context::get_current(&mut cu_ctx)?;
|
||||
}
|
||||
if cu_ctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
unsafe { &*cu_ctx }
|
||||
.as_ref()
|
||||
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
|
||||
.and_then(|ctx| {
|
||||
ctx.mutable
|
||||
.try_lock()
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
|
||||
.map(|mut mutable| {
|
||||
mutable.cuda_manager = mgr;
|
||||
mutable.cuda_state = ctx_state;
|
||||
mutable.cuda_dtor_cb = dtor_cb;
|
||||
})
|
||||
})?;
|
||||
Ok(())
|
||||
lock_context(cu_ctx, |ctx: &mut ContextData| {
|
||||
ctx.cuda_manager = mgr;
|
||||
ctx.cuda_state = ctx_state;
|
||||
ctx.cuda_dtor_cb = dtor_cb;
|
||||
})
|
||||
}
|
||||
|
||||
// some kind of dtor
|
||||
|
@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state(
|
|||
|
||||
fn context_local_storage_get_state_impl(
|
||||
ctx_state: *mut *mut cuda_impl::rt::ContextState,
|
||||
mut cu_ctx: *mut context::Context,
|
||||
cu_ctx: *mut context::Context,
|
||||
_: *mut cuda_impl::rt::ContextStateManager,
|
||||
) -> Result<(), CUresult> {
|
||||
if cu_ctx == ptr::null_mut() {
|
||||
context::get_current(&mut cu_ctx)?;
|
||||
}
|
||||
if cu_ctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let cuda_state = unsafe { &*cu_ctx }
|
||||
.as_ref()
|
||||
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
|
||||
.and_then(|ctx| {
|
||||
ctx.mutable
|
||||
.try_lock()
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
|
||||
.map(|mutable| mutable.cuda_state)
|
||||
})?;
|
||||
let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
|
||||
if cuda_state == ptr::null_mut() {
|
||||
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
|
||||
} else {
|
||||
|
@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl(
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn lock_context<T>(
|
||||
cu_ctx: *mut context::Context,
|
||||
fn_impl: impl FnOnce(&mut ContextData) -> T,
|
||||
) -> Result<T, CUresult> {
|
||||
if cu_ctx == ptr::null_mut() {
|
||||
GlobalState::lock_current_context(fn_impl)
|
||||
} else {
|
||||
GlobalState::lock(|_| {
|
||||
let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
|
||||
Ok(fn_impl(ctx))
|
||||
})?
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,28 @@
|
|||
use ::std::os::raw::{c_uint, c_void};
|
||||
use std::ptr;
|
||||
|
||||
use super::{device, stream::Stream, CUresult};
|
||||
use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream};
|
||||
|
||||
pub struct Function {
|
||||
pub type Function = LiveCheck<FunctionData>;
|
||||
|
||||
impl HasLivenessCookie for FunctionData {
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
const COOKIE: usize = 0x5e2ab14d5840678e;
|
||||
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
const COOKIE: usize = 0x33e6a1e6;
|
||||
|
||||
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
|
||||
|
||||
fn try_drop(&mut self) -> Result<(), CUresult> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
pub base: l0::Kernel<'static>,
|
||||
pub arg_size: Vec<usize>,
|
||||
pub use_shared_mem: bool,
|
||||
}
|
||||
|
||||
pub fn launch_kernel(
|
||||
|
@ -17,36 +34,43 @@ pub fn launch_kernel(
|
|||
block_dim_y: c_uint,
|
||||
block_dim_z: c_uint,
|
||||
shared_mem_bytes: c_uint,
|
||||
strean: *mut Stream,
|
||||
hstream: *mut Stream,
|
||||
kernel_params: *mut *mut c_void,
|
||||
extra: *mut *mut c_void,
|
||||
) -> Result<(), CUresult> {
|
||||
if f == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() {
|
||||
if extra != ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
|
||||
}
|
||||
let func = unsafe { &*f };
|
||||
for (i, arg_size) in func.arg_size.iter().copied().enumerate() {
|
||||
unsafe {
|
||||
func.base
|
||||
.set_arg_raw(i as u32, arg_size, *kernel_params.add(i))?
|
||||
};
|
||||
}
|
||||
unsafe { &*f }
|
||||
.base
|
||||
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
|
||||
device::with_current_exclusive(|dev| {
|
||||
let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?;
|
||||
GlobalState::lock_stream(hstream, |stream| {
|
||||
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
|
||||
for (i, arg_size) in func.arg_size.iter().enumerate() {
|
||||
unsafe {
|
||||
func.base
|
||||
.set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))?
|
||||
};
|
||||
}
|
||||
if func.use_shared_mem {
|
||||
unsafe {
|
||||
func.base.set_arg_raw(
|
||||
func.arg_size.len() as u32,
|
||||
shared_mem_bytes as usize,
|
||||
ptr::null(),
|
||||
)?
|
||||
};
|
||||
}
|
||||
func.base
|
||||
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
cmd_list.append_launch_kernel(
|
||||
&unsafe { &*f }.base,
|
||||
&mut func.base,
|
||||
&[grid_dim_x, grid_dim_y, grid_dim_z],
|
||||
None,
|
||||
&mut [],
|
||||
)?;
|
||||
dev.default_queue.execute(cmd_list)?;
|
||||
l0::Result::Ok(())
|
||||
})??;
|
||||
Ok(())
|
||||
stream.queue.execute(cmd_list)?;
|
||||
Ok(())
|
||||
})?
|
||||
}
|
||||
|
|
|
@ -1,57 +1,34 @@
|
|||
use super::CUresult;
|
||||
use super::{stream, CUresult, GlobalState};
|
||||
use std::ffi::c_void;
|
||||
|
||||
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
|
||||
let alloc_result = super::device::with_current_exclusive(|dev| unsafe {
|
||||
dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0)
|
||||
});
|
||||
match alloc_result {
|
||||
Ok(Ok(alloc)) => {
|
||||
unsafe { *dptr = alloc };
|
||||
CUresult::CUDA_SUCCESS
|
||||
}
|
||||
Ok(Err(e)) => e.into(),
|
||||
Err(e) => e,
|
||||
}
|
||||
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
|
||||
let ptr = GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
|
||||
})??;
|
||||
unsafe { *dptr = ptr };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn copy_v2(
|
||||
dst: *mut c_void,
|
||||
src: *const c_void,
|
||||
bytesize: usize,
|
||||
) -> Result<Result<(), l0::sys::ze_result_t>, CUresult> {
|
||||
super::device::with_current_exclusive(|dev| unsafe {
|
||||
memcpy_impl(
|
||||
&mut dev.l0_context,
|
||||
dst,
|
||||
src,
|
||||
bytesize,
|
||||
&dev.base,
|
||||
&mut dev.default_queue,
|
||||
)
|
||||
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
|
||||
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
|
||||
let mut cmd_list = stream.command_list()?;
|
||||
unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
|
||||
stream.queue.execute(cmd_list)?;
|
||||
Ok::<_, CUresult>(())
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
|
||||
GlobalState::lock_current_context(|ctx| {
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
|
||||
})
|
||||
}
|
||||
|
||||
unsafe fn memcpy_impl(
|
||||
ctx: &mut l0::Context,
|
||||
dst: *mut c_void,
|
||||
src: *const c_void,
|
||||
bytes_count: usize,
|
||||
dev: &l0::Device,
|
||||
queue: &mut l0::CommandQueue,
|
||||
) -> l0::Result<()> {
|
||||
let mut cmd_list = l0::CommandList::new(ctx, &dev)?;
|
||||
cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?;
|
||||
queue.execute(cmd_list)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
|
||||
Ok(())
|
||||
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod test {
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
use std::ptr;
|
||||
|
@ -82,4 +59,20 @@ mod tests {
|
|||
assert_ne!(mem, ptr::null_mut());
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
cuda_driver_test!(free_without_ctx);
|
||||
|
||||
fn free_without_ctx<T: CudaDriverFns>() {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut mem = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuMemAlloc_v2(&mut mem, std::mem::size_of::<usize>()),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_ne!(mem, ptr::null_mut());
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,15 @@
|
|||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
|
||||
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
|
||||
use crate::{
|
||||
cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
|
||||
r#impl::device::Device,
|
||||
};
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
mem::{self, ManuallyDrop},
|
||||
os::raw::c_int,
|
||||
ptr,
|
||||
sync::Mutex,
|
||||
sync::TryLockError,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
|
@ -7,9 +17,9 @@ pub mod test;
|
|||
pub mod context;
|
||||
pub mod device;
|
||||
pub mod export_table;
|
||||
pub mod function;
|
||||
pub mod memory;
|
||||
pub mod module;
|
||||
pub mod function;
|
||||
pub mod stream;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
|
@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult {
|
|||
CUresult::CUDA_ERROR_NOT_SUPPORTED
|
||||
}
|
||||
|
||||
pub trait HasLivenessCookie {
|
||||
pub trait HasLivenessCookie: Sized {
|
||||
const COOKIE: usize;
|
||||
const LIVENESS_FAIL: CUresult;
|
||||
|
||||
fn try_drop(&mut self) -> Result<(), CUresult>;
|
||||
}
|
||||
|
||||
// This struct is a best-effort check if wrapped value has been dropped,
|
||||
|
@ -42,19 +55,23 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
|
||||
let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
|
||||
ctx_box.try_drop()?;
|
||||
unsafe { ManuallyDrop::drop(&mut ctx_box) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
|
||||
let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
|
||||
outer_ptr as *mut Self
|
||||
}
|
||||
|
||||
pub unsafe fn as_ref_unchecked(&self) -> &T {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn as_ref(&self) -> Option<&T> {
|
||||
if self.cookie == T::COOKIE {
|
||||
Some(&self.data)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_mut(&mut self) -> Option<&mut T> {
|
||||
pub fn as_option_mut(&mut self) -> Option<&mut T> {
|
||||
if self.cookie == T::COOKIE {
|
||||
Some(&mut self.data)
|
||||
} else {
|
||||
|
@ -62,14 +79,31 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn as_result(&self) -> Result<&T, CUresult> {
|
||||
if self.cookie == T::COOKIE {
|
||||
Ok(&self.data)
|
||||
} else {
|
||||
Err(T::LIVENESS_FAIL)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
|
||||
if self.cookie == T::COOKIE {
|
||||
Ok(&mut self.data)
|
||||
} else {
|
||||
Err(T::LIVENESS_FAIL)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn try_drop(&mut self) -> bool {
|
||||
pub fn try_drop(&mut self) -> Result<(), CUresult> {
|
||||
if self.cookie == T::COOKIE {
|
||||
self.cookie = 0;
|
||||
self.data.try_drop()?;
|
||||
unsafe { ManuallyDrop::drop(&mut self.data) };
|
||||
return true;
|
||||
return Ok(());
|
||||
}
|
||||
false
|
||||
Err(T::LIVENESS_FAIL)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> From<TryLockError<T>> for CUresult {
|
||||
fn from(_: TryLockError<T>) -> Self {
|
||||
CUresult::CUDA_ERROR_ILLEGAL_STATE
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Encuda {
|
||||
type To: Sized;
|
||||
fn encuda(self: Self) -> Self::To;
|
||||
|
@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
|
|||
}
|
||||
}
|
||||
|
||||
pub enum Error {
|
||||
L0(l0::sys::ze_result_t),
|
||||
Cuda(CUresult),
|
||||
}
|
||||
|
||||
impl Encuda for Error {
|
||||
type To = CUresult;
|
||||
fn encuda(self: Self) -> Self::To {
|
||||
match self {
|
||||
Error::L0(e) => e.into(),
|
||||
Error::Cuda(e) => e,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None);
|
||||
}
|
||||
|
||||
struct GlobalState {
|
||||
driver: l0::Driver,
|
||||
devices: Vec<Device>,
|
||||
}
|
||||
|
||||
unsafe impl Send for GlobalState {}
|
||||
|
||||
impl GlobalState {
|
||||
fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> {
|
||||
let mut mutex = GLOBAL_STATE
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?;
|
||||
Ok(f(global_state))
|
||||
}
|
||||
|
||||
fn lock_device<T>(
|
||||
device::Index(dev_idx): device::Index,
|
||||
f: impl FnOnce(&'static mut device::Device) -> T,
|
||||
) -> Result<T, CUresult> {
|
||||
if dev_idx < 0 {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
|
||||
}
|
||||
Self::lock(|global_state| {
|
||||
if dev_idx >= global_state.devices.len() as c_int {
|
||||
Err(CUresult::CUDA_ERROR_INVALID_DEVICE)
|
||||
} else {
|
||||
Ok(f(unsafe {
|
||||
transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize])
|
||||
}))
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>(
|
||||
f: F,
|
||||
) -> Result<R, CUresult> {
|
||||
Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))?
|
||||
}
|
||||
|
||||
fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>(
|
||||
f: F,
|
||||
) -> Result<R, CUresult> {
|
||||
context::CONTEXT_STACK.with(|stack| {
|
||||
stack
|
||||
.borrow_mut()
|
||||
.last_mut()
|
||||
.ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
|
||||
.map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))?
|
||||
})
|
||||
}
|
||||
|
||||
fn lock_stream<T>(
|
||||
stream: *mut stream::Stream,
|
||||
f: impl FnOnce(&mut stream::StreamData) -> T,
|
||||
) -> Result<T, CUresult> {
|
||||
if stream == ptr::null_mut()
|
||||
|| stream == stream::CU_STREAM_LEGACY
|
||||
|| stream == stream::CU_STREAM_PER_THREAD
|
||||
{
|
||||
Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))?
|
||||
} else {
|
||||
Self::lock(|_| {
|
||||
let stream = unsafe { &mut *stream }.as_result_mut()?;
|
||||
Ok(f(stream))
|
||||
})?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
fn is_intel_gpu_driver(_: &l0::Driver) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn init() -> l0::Result<()> {
|
||||
pub fn init() -> Result<(), CUresult> {
|
||||
let mut global_state = GLOBAL_STATE
|
||||
.lock()
|
||||
.map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?;
|
||||
.map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
|
||||
if global_state.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
l0::init()?;
|
||||
let drivers = l0::Driver::get()?;
|
||||
let driver = match drivers.into_iter().find(is_intel_gpu_driver) {
|
||||
None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN),
|
||||
Some(driver) => {
|
||||
device::init(&driver)?;
|
||||
driver
|
||||
}
|
||||
let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
|
||||
None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
|
||||
Some(driver) => device::init(&driver)?,
|
||||
};
|
||||
*global_state = Some(GlobalState { driver });
|
||||
*global_state = Some(GlobalState { devices });
|
||||
drop(global_state);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
|
||||
unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
|
||||
mem::transmute(t)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,79 +1,90 @@
|
|||
use std::{
|
||||
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
|
||||
collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
|
||||
os::raw::c_char, ptr, slice,
|
||||
};
|
||||
|
||||
use super::{function::Function, transmute_lifetime, CUresult};
|
||||
use super::{
|
||||
device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie,
|
||||
LiveCheck,
|
||||
};
|
||||
use ptx;
|
||||
|
||||
pub type Module = Mutex<ModuleData>;
|
||||
pub type Module = LiveCheck<ModuleData>;
|
||||
|
||||
impl HasLivenessCookie for ModuleData {
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
const COOKIE: usize = 0xf1313bd46505f98a;
|
||||
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
const COOKIE: usize = 0xbdbe3f15;
|
||||
|
||||
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
|
||||
|
||||
fn try_drop(&mut self) -> Result<(), CUresult> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ModuleData {
|
||||
base: l0::Module,
|
||||
arg_lens: HashMap<CString, Vec<usize>>,
|
||||
pub spirv: SpirvModule,
|
||||
// This should be a Vec<>, but I'm feeling lazy
|
||||
pub device_binaries: HashMap<device::Index, CompiledModule>,
|
||||
}
|
||||
|
||||
pub enum ModuleCompileError<'a> {
|
||||
Parse(
|
||||
Vec<ptx::ast::PtxError>,
|
||||
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
|
||||
),
|
||||
Compile(ptx::TranslateError),
|
||||
L0(l0::sys::ze_result_t),
|
||||
CUDA(CUresult),
|
||||
pub struct SpirvModule {
|
||||
pub binaries: Vec<u32>,
|
||||
pub kernel_info: HashMap<String, ptx::KernelInfo>,
|
||||
pub should_link_ptx_impl: Option<&'static [u8]>,
|
||||
pub build_options: CString,
|
||||
}
|
||||
|
||||
impl<'a> ModuleCompileError<'a> {
|
||||
pub fn get_build_log(&self) {}
|
||||
pub struct CompiledModule {
|
||||
pub base: l0::Module,
|
||||
pub kernels: HashMap<CString, Box<Function>>,
|
||||
}
|
||||
|
||||
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
|
||||
fn from(err: ptx::TranslateError) -> Self {
|
||||
ModuleCompileError::Compile(err)
|
||||
impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult {
|
||||
fn from(_: ptx::ParseError<L, T, E>) -> Self {
|
||||
CUresult::CUDA_ERROR_INVALID_PTX
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
|
||||
fn from(err: l0::sys::ze_result_t) -> Self {
|
||||
ModuleCompileError::L0(err)
|
||||
impl From<ptx::TranslateError> for CUresult {
|
||||
fn from(_: ptx::TranslateError) -> Self {
|
||||
CUresult::CUDA_ERROR_INVALID_PTX
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<CUresult> for ModuleCompileError<'a> {
|
||||
fn from(err: CUresult) -> Self {
|
||||
ModuleCompileError::CUDA(err)
|
||||
impl SpirvModule {
|
||||
pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> {
|
||||
let u8_text = unsafe { CStr::from_ptr(text) };
|
||||
let ptx_text = u8_text
|
||||
.to_str()
|
||||
.map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?;
|
||||
Self::new(ptx_text)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleData {
|
||||
pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
|
||||
pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
|
||||
let ast = match ast {
|
||||
Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
|
||||
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
|
||||
Ok(ast) => ast,
|
||||
};
|
||||
let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
|
||||
let spirv_module = ptx::to_spirv_module(ast)?;
|
||||
Ok(SpirvModule {
|
||||
binaries: spirv_module.assemble(),
|
||||
kernel_info: spirv_module.kernel_info,
|
||||
should_link_ptx_impl: spirv_module.should_link_ptx_impl,
|
||||
build_options: spirv_module.build_options,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
|
||||
let byte_il = unsafe {
|
||||
slice::from_raw_parts::<u8>(
|
||||
spirv.as_ptr() as *const _,
|
||||
spirv.len() * mem::size_of::<u32>(),
|
||||
slice::from_raw_parts(
|
||||
self.binaries.as_ptr() as *const u8,
|
||||
self.binaries.len() * mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let module = super::device::with_current_exclusive(|dev| {
|
||||
l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
|
||||
});
|
||||
match module {
|
||||
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
|
||||
base: module,
|
||||
arg_lens: all_arg_lens
|
||||
.into_iter()
|
||||
.map(|(k, v)| (CString::new(k).unwrap(), v))
|
||||
.collect(),
|
||||
})),
|
||||
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
|
||||
Err(err) => Err(ModuleCompileError::from(err)),
|
||||
}
|
||||
let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?;
|
||||
Ok(l0_module)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,34 +96,75 @@ pub fn get_function(
|
|||
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let name = unsafe { CStr::from_ptr(name) };
|
||||
let (mut kernel, args_len) = unsafe { &*hmod }
|
||||
.try_lock()
|
||||
.map(|module| {
|
||||
Result::<_, CUresult>::Ok((
|
||||
l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
|
||||
module
|
||||
.arg_lens
|
||||
.get(name)
|
||||
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
|
||||
.clone(),
|
||||
))
|
||||
})
|
||||
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
|
||||
kernel.set_indirect_access(
|
||||
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
|
||||
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
|
||||
)?;
|
||||
unsafe {
|
||||
*hfunc = Box::into_raw(Box::new(Function {
|
||||
base: kernel,
|
||||
arg_size: args_len,
|
||||
}))
|
||||
};
|
||||
let name = unsafe { CStr::from_ptr(name) }.to_owned();
|
||||
let function: *mut Function = GlobalState::lock_current_context(|ctx| {
|
||||
let module = unsafe { &mut *hmod }.as_result_mut()?;
|
||||
let device = unsafe { &mut *ctx.device };
|
||||
let compiled_module = match module.device_binaries.entry(device.index) {
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let new_module = CompiledModule {
|
||||
base: module.spirv.compile(&mut device.l0_context, &device.base)?,
|
||||
kernels: HashMap::new(),
|
||||
};
|
||||
entry.insert(new_module)
|
||||
}
|
||||
};
|
||||
//let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
|
||||
let kernel = match compiled_module.kernels.entry(name) {
|
||||
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
|
||||
hash_map::Entry::Vacant(entry) => {
|
||||
let kernel_info = module
|
||||
.spirv
|
||||
.kernel_info
|
||||
.get(unsafe {
|
||||
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
|
||||
})
|
||||
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
|
||||
let kernel =
|
||||
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
|
||||
entry.insert(Box::new(Function::new(FunctionData {
|
||||
base: kernel,
|
||||
arg_size: kernel_info.arguments_sizes.clone(),
|
||||
use_shared_mem: kernel_info.uses_shared_mem,
|
||||
})))
|
||||
}
|
||||
};
|
||||
Ok::<_, CUresult>(kernel as *mut _)
|
||||
})??;
|
||||
unsafe { *hfunc = function };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
|
||||
pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> {
|
||||
let spirv_data = SpirvModule::new_raw(image as *const _)?;
|
||||
load_data_impl(pmod, spirv_data)
|
||||
}
|
||||
|
||||
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
|
||||
let module = GlobalState::lock_current_context(|ctx| {
|
||||
let device = unsafe { &mut *ctx.device };
|
||||
let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
|
||||
let mut device_binaries = HashMap::new();
|
||||
let compiled_module = CompiledModule {
|
||||
base: l0_module,
|
||||
kernels: HashMap::new(),
|
||||
};
|
||||
device_binaries.insert(device.index, compiled_module);
|
||||
let module_data = ModuleData {
|
||||
spirv: spirv_data,
|
||||
device_binaries,
|
||||
};
|
||||
Ok::<_, CUresult>(module_data)
|
||||
})??;
|
||||
let module_ptr = Box::into_raw(Box::new(Module::new(module)));
|
||||
unsafe { *pmod = module_ptr };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> {
|
||||
if module == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
GlobalState::lock(|_| Module::destroy_impl(module))?
|
||||
}
|
||||
|
|
|
@ -1,36 +1,114 @@
|
|||
use std::cell::RefCell;
|
||||
use super::{
|
||||
context::{Context, ContextData},
|
||||
CUresult, GlobalState,
|
||||
};
|
||||
use std::{mem, ptr};
|
||||
|
||||
use device::Device;
|
||||
use super::{HasLivenessCookie, LiveCheck};
|
||||
|
||||
use super::device;
|
||||
pub type Stream = LiveCheck<StreamData>;
|
||||
|
||||
pub struct Stream {
|
||||
dev: *mut Device,
|
||||
}
|
||||
pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _;
|
||||
pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _;
|
||||
|
||||
pub struct DefaultStream {
|
||||
streams: Vec<Option<Stream>>,
|
||||
}
|
||||
impl HasLivenessCookie for StreamData {
|
||||
#[cfg(target_pointer_width = "64")]
|
||||
const COOKIE: usize = 0x512097354de18d35;
|
||||
|
||||
impl DefaultStream {
|
||||
fn new() -> Self {
|
||||
DefaultStream {
|
||||
streams: Vec::new(),
|
||||
#[cfg(target_pointer_width = "32")]
|
||||
const COOKIE: usize = 0x77d5cc0b;
|
||||
|
||||
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
|
||||
|
||||
fn try_drop(&mut self) -> Result<(), CUresult> {
|
||||
if self.context != ptr::null_mut() {
|
||||
let context = unsafe { &mut *self.context };
|
||||
if !context.streams.remove(&(self as *mut _)) {
|
||||
return Err(CUresult::CUDA_ERROR_UNKNOWN);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new());
|
||||
pub struct StreamData {
|
||||
pub context: *mut ContextData,
|
||||
pub queue: l0::CommandQueue,
|
||||
}
|
||||
|
||||
impl StreamData {
|
||||
pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
|
||||
Ok(StreamData {
|
||||
context: ptr::null_mut(),
|
||||
queue: l0::CommandQueue::new(ctx, dev)?,
|
||||
})
|
||||
}
|
||||
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
|
||||
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
|
||||
let l0_dev = &unsafe { &*ctx.device }.base;
|
||||
Ok(StreamData {
|
||||
context: ctx as *mut _,
|
||||
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
|
||||
let ctx = unsafe { &mut *self.context };
|
||||
let dev = unsafe { &mut *ctx.device };
|
||||
l0::CommandList::new(&mut dev.l0_context, &dev.base)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StreamData {
|
||||
fn drop(&mut self) {
|
||||
if self.context == ptr::null_mut() {
|
||||
return;
|
||||
}
|
||||
unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) };
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> {
|
||||
if pctx == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?;
|
||||
if ctx_ptr == ptr::null_mut() {
|
||||
return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED);
|
||||
}
|
||||
unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> {
|
||||
let stream_ptr = GlobalState::lock_current_context(|ctx| {
|
||||
let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?));
|
||||
let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _;
|
||||
if !ctx.streams.insert(stream_ptr) {
|
||||
return Err(CUresult::CUDA_ERROR_UNKNOWN);
|
||||
}
|
||||
mem::forget(stream_box);
|
||||
Ok::<_, CUresult>(stream_ptr)
|
||||
})??;
|
||||
unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) };
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> {
|
||||
if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD
|
||||
{
|
||||
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
|
||||
}
|
||||
GlobalState::lock(|_| Stream::destroy_impl(pstream))?
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod test {
|
||||
use crate::cuda::CUstream;
|
||||
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
use std::ptr;
|
||||
use std::{ptr, thread};
|
||||
|
||||
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
|
||||
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
|
||||
|
@ -65,5 +143,100 @@ mod tests {
|
|||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(ctx2, stream_ctx2);
|
||||
// Cleanup
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS);
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
cuda_driver_test!(stream_context_destroyed);
|
||||
|
||||
fn stream_context_destroyed<T: CudaDriverFns>() {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream = ptr::null_mut();
|
||||
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream_ctx1 = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuStreamGetCtx(stream, &mut stream_ctx1),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(stream_ctx1, ctx);
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
let mut stream_ctx2 = ptr::null_mut();
|
||||
// When a context gets destroyed, its streams are also destroyed
|
||||
let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2);
|
||||
assert!(
|
||||
cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE
|
||||
|| cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT
|
||||
|| cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED
|
||||
);
|
||||
assert_eq!(
|
||||
T::cuStreamDestroy_v2(stream),
|
||||
CUresult::CUDA_ERROR_INVALID_HANDLE
|
||||
);
|
||||
// Check if creating another context is possible
|
||||
let mut ctx2 = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
// Cleanup
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
cuda_driver_test!(stream_moves_context_to_another_thread);
|
||||
|
||||
fn stream_moves_context_to_another_thread<T: CudaDriverFns>() {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream = ptr::null_mut();
|
||||
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream_ctx1 = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuStreamGetCtx(stream, &mut stream_ctx1),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
assert_eq!(stream_ctx1, ctx);
|
||||
let stream_ptr = stream as usize;
|
||||
let stream_ctx_on_thread = thread::spawn(move || {
|
||||
let mut stream_ctx2 = ptr::null_mut();
|
||||
assert_eq!(
|
||||
T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
stream_ctx2 as usize
|
||||
})
|
||||
.join()
|
||||
.unwrap();
|
||||
assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _);
|
||||
// Cleanup
|
||||
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
cuda_driver_test!(can_destroy_stream);
|
||||
|
||||
fn can_destroy_stream<T: CudaDriverFns>() {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
let mut stream = ptr::null_mut();
|
||||
assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
|
||||
assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
|
||||
// Cleanup
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
|
||||
cuda_driver_test!(cant_destroy_default_stream);
|
||||
|
||||
fn cant_destroy_default_stream<T: CudaDriverFns>() {
|
||||
assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
|
||||
let mut ctx = ptr::null_mut();
|
||||
assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
|
||||
assert_ne!(
|
||||
T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _),
|
||||
CUresult::CUDA_SUCCESS
|
||||
);
|
||||
// Cleanup
|
||||
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::{cuda::CUstream, r#impl as notcuda};
|
||||
use crate::r#impl::CUresult;
|
||||
use crate::{cuda::CUuuid, r#impl::Encuda};
|
||||
use crate::cuda as notcuda;
|
||||
use crate::cuda::CUstream;
|
||||
use crate::cuda::CUuuid;
|
||||
use crate::{
|
||||
cuda::{CUdevice, CUdeviceptr},
|
||||
r#impl::CUresult,
|
||||
};
|
||||
use ::std::{
|
||||
ffi::c_void,
|
||||
os::raw::{c_int, c_uint},
|
||||
|
@ -37,48 +41,63 @@ pub trait CudaDriverFns {
|
|||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
|
||||
fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
|
||||
}
|
||||
|
||||
pub struct NotCuda();
|
||||
|
||||
impl CudaDriverFns for NotCuda {
|
||||
fn cuInit(_flags: c_uint) -> CUresult {
|
||||
crate::cuda::cuInit(_flags as _)
|
||||
notcuda::cuInit(_flags as _)
|
||||
}
|
||||
|
||||
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
|
||||
notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda()
|
||||
notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
|
||||
}
|
||||
|
||||
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
|
||||
notcuda::context::destroy_v2(ctx as *mut _)
|
||||
notcuda::cuCtxDestroy_v2(ctx as *mut _)
|
||||
}
|
||||
|
||||
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
|
||||
notcuda::context::pop_current_v2(pctx as *mut _)
|
||||
notcuda::cuCtxPopCurrent_v2(pctx as *mut _)
|
||||
}
|
||||
|
||||
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
|
||||
notcuda::context::get_api_version(ctx as *mut _, version)
|
||||
notcuda::cuCtxGetApiVersion(ctx as *mut _, version)
|
||||
}
|
||||
|
||||
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
|
||||
notcuda::context::get_current(pctx as *mut _).encuda()
|
||||
notcuda::cuCtxGetCurrent(pctx as *mut _)
|
||||
}
|
||||
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
|
||||
notcuda::memory::alloc_v2(dptr as *mut _, bytesize)
|
||||
notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize)
|
||||
}
|
||||
|
||||
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
|
||||
notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda()
|
||||
notcuda::cuDeviceGetUuid(uuid, CUdevice(dev))
|
||||
}
|
||||
|
||||
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
|
||||
notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda()
|
||||
notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
|
||||
}
|
||||
|
||||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
crate::cuda::cuStreamGetCtx(hStream, pctx as _)
|
||||
notcuda::cuStreamGetCtx(hStream, pctx as _)
|
||||
}
|
||||
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
|
||||
notcuda::cuStreamCreate(stream, flags)
|
||||
}
|
||||
|
||||
fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
|
||||
notcuda::cuMemFree_v2(CUdeviceptr(dptr as _))
|
||||
}
|
||||
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
|
||||
notcuda::cuStreamDestroy_v2(stream)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda {
|
|||
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
|
||||
}
|
||||
|
||||
fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
|
||||
unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser;
|
|||
pub use lalrpop_util::lexer::Token;
|
||||
pub use lalrpop_util::ParseError;
|
||||
pub use rspirv::dr::Error as SpirvError;
|
||||
pub use translate::TranslateError as TranslateError;
|
||||
pub use translate::to_spirv;
|
||||
pub use translate::to_spirv_module;
|
||||
pub use translate::KernelInfo;
|
||||
pub use translate::TranslateError;
|
||||
|
||||
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
|
||||
x.into_iter().filter_map(|x| x).collect()
|
||||
|
|
|
@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) {
|
|||
fn compile_and_assert(s: &str) -> Result<(), TranslateError> {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
|
||||
crate::to_spirv(ast)?;
|
||||
crate::to_spirv_module(ast)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::ast;
|
||||
use half::f16;
|
||||
use rspirv::{binary::Disassemble, dr};
|
||||
use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
|
||||
use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
|
||||
use std::{
|
||||
collections::{hash_map, HashMap, HashSet},
|
||||
convert::TryInto,
|
||||
|
@ -450,6 +450,11 @@ pub struct Module {
|
|||
pub should_link_ptx_impl: Option<&'static [u8]>,
|
||||
pub build_options: CString,
|
||||
}
|
||||
impl Module {
|
||||
pub fn assemble(&self) -> Vec<u32> {
|
||||
self.spirv.assemble()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KernelInfo {
|
||||
pub arguments_sizes: Vec<usize>,
|
||||
|
@ -1046,8 +1051,12 @@ fn emit_function_header<'a>(
|
|||
kernel_info: &mut HashMap<String, KernelInfo>,
|
||||
) -> Result<(), TranslateError> {
|
||||
if let MethodName::Kernel(name) = func_decl.name {
|
||||
let args_lens = func_decl
|
||||
.input
|
||||
let input_args = if !func_decl.uses_shared_mem {
|
||||
func_decl.input.as_slice()
|
||||
} else {
|
||||
&func_decl.input[0..func_decl.input.len() - 1]
|
||||
};
|
||||
let args_lens = input_args
|
||||
.iter()
|
||||
.map(|param| param.v_type.size_of())
|
||||
.collect();
|
||||
|
@ -1135,21 +1144,6 @@ fn emit_function_header<'a>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn to_spirv<'a>(
|
||||
ast: ast::Module<'a>,
|
||||
) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
|
||||
let module = to_spirv_module(ast)?;
|
||||
Ok((
|
||||
module.should_link_ptx_impl,
|
||||
module.spirv.assemble(),
|
||||
module
|
||||
.kernel_info
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.arguments_sizes))
|
||||
.collect(),
|
||||
))
|
||||
}
|
||||
|
||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||
builder.capability(spirv::Capability::GenericPointer);
|
||||
builder.capability(spirv::Capability::Linkage);
|
||||
|
|
Loading…
Add table
Reference in a new issue