Add back erroneously removed functionality

This commit is contained in:
Andrzej Janik 2020-11-12 21:08:28 +01:00
parent a2e77fe961
commit a6765baa3a
5 changed files with 26 additions and 8 deletions

View file

@ -2281,7 +2281,7 @@ pub extern "C" fn cuDevicePrimaryCtxRelease(dev: CUdevice) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxRelease_v2(dev: CUdevice) -> CUresult {
r#impl::unimplemented()
r#impl::device::primary_ctx_release_v2(dev.decuda())
}
#[cfg_attr(not(test), no_mangle)]

View file

@ -345,6 +345,11 @@ pub fn primary_ctx_retain(
Ok(())
}
// TODO: allow for retain/reset/release of primary context
pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult {
CUresult::CUDA_SUCCESS
}
#[cfg(test)]
mod test {
use super::super::test::CudaDriverFns;

View file

@ -4,7 +4,7 @@ use crate::{
cuda_impl,
};
use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use super::{context, context::ContextData, device, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@ -110,8 +110,17 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() },
];
unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
super::unimplemented()
unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
}
fn cudart_interface_fn1_impl(
pctx: *mut *mut context::Context,
dev: device::Index,
) -> Result<(), CUresult> {
let ctx_ptr = GlobalState::lock_device(dev, |d| &mut d.primary_context as *mut _)?;
unsafe { *pctx = ctx_ptr };
Ok(())
}
/*

View file

@ -110,7 +110,6 @@ pub fn get_function(
entry.insert(new_module)
}
};
//let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
@ -121,8 +120,13 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
let kernel =
let mut kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
)?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),

View file

@ -1,6 +1,6 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
use rspirv::dr;
use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
@ -6662,7 +6662,7 @@ impl ast::ScalarType {
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float,
ast::ScalarType::F16x2 => ScalarKind::Float2,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}