mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add support for top-level global variables, improve array support
This commit is contained in:
parent
9a65dd32f5
commit
27d25865af
13 changed files with 1085 additions and 378 deletions
|
@ -1,6 +1,6 @@
|
|||
use crate::sys;
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
ffi::{c_void, CStr, CString},
|
||||
fmt::Debug,
|
||||
marker::PhantomData,
|
||||
mem, ptr,
|
||||
|
@ -238,23 +238,16 @@ impl Drop for CommandQueue {
|
|||
pub struct Module(sys::ze_module_handle_t);
|
||||
|
||||
impl Module {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t {
|
||||
self.0
|
||||
}
|
||||
pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self {
|
||||
Self(x)
|
||||
}
|
||||
|
||||
pub fn new_spirv(
|
||||
ctx: &mut Context,
|
||||
d: &Device,
|
||||
bin: &[u8],
|
||||
opts: Option<&CStr>,
|
||||
) -> Result<Self> {
|
||||
) -> (Result<Self>, BuildLog) {
|
||||
Module::new(ctx, true, d, bin, opts)
|
||||
}
|
||||
|
||||
pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> Result<Self> {
|
||||
pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
|
||||
Module::new(ctx, false, d, bin, None)
|
||||
}
|
||||
|
||||
|
@ -264,7 +257,7 @@ impl Module {
|
|||
d: &Device,
|
||||
bin: &[u8],
|
||||
opts: Option<&CStr>,
|
||||
) -> Result<Self> {
|
||||
) -> (Result<Self>, BuildLog) {
|
||||
let desc = sys::ze_module_desc_t {
|
||||
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_MODULE_DESC,
|
||||
pNext: ptr::null(),
|
||||
|
@ -279,14 +272,14 @@ impl Module {
|
|||
pConstants: ptr::null(),
|
||||
};
|
||||
let mut result: sys::ze_module_handle_t = ptr::null_mut();
|
||||
check!(sys::zeModuleCreate(
|
||||
ctx.0,
|
||||
d.0,
|
||||
&desc,
|
||||
&mut result,
|
||||
ptr::null_mut()
|
||||
));
|
||||
Ok(Module(result))
|
||||
let mut log_handle = ptr::null_mut();
|
||||
let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) };
|
||||
let log = BuildLog(log_handle);
|
||||
if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
|
||||
(Result::Err(err), log)
|
||||
} else {
|
||||
(Ok(Module(result)), log)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -297,6 +290,32 @@ impl Drop for Module {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct BuildLog(sys::ze_module_build_log_handle_t);
|
||||
|
||||
impl BuildLog {
|
||||
pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t {
|
||||
self.0
|
||||
}
|
||||
pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self {
|
||||
Self(x)
|
||||
}
|
||||
|
||||
pub fn get_cstring(&self) -> Result<CString> {
|
||||
let mut size = 0;
|
||||
check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) };
|
||||
let mut str_vec = vec![0u8; size];
|
||||
check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) };
|
||||
str_vec.pop();
|
||||
Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for BuildLog {
|
||||
fn drop(&mut self) {
|
||||
check_panic!(sys::zeModuleBuildLogDestroy(self.0));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SafeRepr {}
|
||||
impl SafeRepr for u8 {}
|
||||
impl SafeRepr for i8 {}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use ::std::os::raw::{c_uint, c_void};
|
||||
use std::ptr;
|
||||
|
||||
use super::{context, device, stream::Stream, CUresult};
|
||||
use super::{device, stream::Stream, CUresult};
|
||||
|
||||
pub struct Function {
|
||||
pub base: l0::Kernel<'static>,
|
||||
|
|
|
@ -46,7 +46,7 @@ unsafe fn memcpy_impl(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> {
|
||||
pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st};
|
||||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
|
||||
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
use std::{
|
||||
collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice,
|
||||
sync::Mutex,
|
||||
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
|
||||
};
|
||||
|
||||
use super::{function::Function, transmute_lifetime, CUresult};
|
||||
use ptx;
|
||||
|
||||
use super::context;
|
||||
|
||||
pub type Module = Mutex<ModuleData>;
|
||||
|
||||
pub struct ModuleData {
|
||||
|
@ -67,14 +64,14 @@ impl ModuleData {
|
|||
l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
|
||||
});
|
||||
match module {
|
||||
Ok(Ok(module)) => Ok(Mutex::new(Self {
|
||||
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
|
||||
base: module,
|
||||
arg_lens: all_arg_lens
|
||||
.into_iter()
|
||||
.map(|(k, v)| (CString::new(k).unwrap(), v))
|
||||
.collect(),
|
||||
})),
|
||||
Ok(Err(err)) => Err(ModuleCompileError::from(err)),
|
||||
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
|
||||
Err(err) => Err(ModuleCompileError::from(err)),
|
||||
}
|
||||
}
|
||||
|
@ -116,6 +113,6 @@ pub fn get_function(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> {
|
||||
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ mod tests {
|
|||
|
||||
use super::super::test::CudaDriverFns;
|
||||
use super::super::CUresult;
|
||||
use std::{ffi::c_void, ptr};
|
||||
use std::ptr;
|
||||
|
||||
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
|
||||
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
|
||||
|
@ -41,7 +41,7 @@ mod tests {
|
|||
fn default_stream_uses_current_ctx_legacy<T: CudaDriverFns>() {
|
||||
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_LEGACY);
|
||||
}
|
||||
|
||||
|
||||
fn default_stream_uses_current_ctx_ptsd<T: CudaDriverFns>() {
|
||||
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_PER_THREAD);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda};
|
||||
use crate::{cuda::CUstream, r#impl as notcuda};
|
||||
use crate::r#impl::CUresult;
|
||||
use crate::{cuda::CUuuid, r#impl::Encuda};
|
||||
use ::std::{
|
||||
|
|
265
ptx/src/ast.rs
265
ptx/src/ast.rs
|
@ -1,6 +1,8 @@
|
|||
use std::convert::From;
|
||||
use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
|
||||
use std::{marker::PhantomData, num::ParseIntError};
|
||||
|
||||
use half::f16;
|
||||
|
||||
quick_error! {
|
||||
#[derive(Debug)]
|
||||
pub enum PtxError {
|
||||
|
@ -9,11 +11,17 @@ quick_error! {
|
|||
display("{}", err)
|
||||
cause(err)
|
||||
}
|
||||
ParseFloat (err: ParseFloatError) {
|
||||
from()
|
||||
display("{}", err)
|
||||
cause(err)
|
||||
}
|
||||
SyntaxError {}
|
||||
NonF32Ftz {}
|
||||
WrongArrayType {}
|
||||
WrongVectorElement {}
|
||||
MultiArrayVariable {}
|
||||
ZeroDimensionArray {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +61,7 @@ macro_rules! sub_scalar_type {
|
|||
|
||||
macro_rules! sub_type {
|
||||
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum $type_name {
|
||||
$(
|
||||
$variant ($($field_type),+),
|
||||
|
@ -80,11 +88,13 @@ sub_type! {
|
|||
}
|
||||
}
|
||||
|
||||
type VecU32 = Vec<u32>;
|
||||
|
||||
sub_type! {
|
||||
VariableLocalType {
|
||||
Scalar(SizedScalarType),
|
||||
Vector(SizedScalarType, u8),
|
||||
Array(SizedScalarType, u32),
|
||||
Array(SizedScalarType, VecU32),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,7 +105,7 @@ sub_type! {
|
|||
sub_type! {
|
||||
VariableParamType {
|
||||
Scalar(ParamScalarType),
|
||||
Array(SizedScalarType, u32),
|
||||
Array(SizedScalarType, VecU32),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,7 +179,12 @@ impl<
|
|||
|
||||
pub struct Module<'a> {
|
||||
pub version: (u8, u8),
|
||||
pub functions: Vec<ParsedFunction<'a>>,
|
||||
pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>,
|
||||
}
|
||||
|
||||
pub enum Directive<'a, P: ArgParams> {
|
||||
Variable(Variable<VariableType, P::Id>),
|
||||
Method(Function<'a, &'a str, Statement<P>>),
|
||||
}
|
||||
|
||||
pub enum MethodDecl<'a, ID> {
|
||||
|
@ -187,7 +202,7 @@ pub struct Function<'a, ID, S> {
|
|||
|
||||
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum FnArgumentType {
|
||||
Reg(VariableRegType),
|
||||
Param(VariableParamType),
|
||||
|
@ -202,11 +217,11 @@ impl From<FnArgumentType> for Type {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
pub enum Type {
|
||||
Scalar(ScalarType),
|
||||
Vector(ScalarType, u8),
|
||||
Array(ScalarType, u32),
|
||||
Array(ScalarType, Vec<u32>),
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
|
@ -274,6 +289,30 @@ sub_scalar_type!(FloatType {
|
|||
F64
|
||||
});
|
||||
|
||||
impl ScalarType {
|
||||
pub fn size_of(self) -> u8 {
|
||||
match self {
|
||||
ScalarType::U8 => 1,
|
||||
ScalarType::S8 => 1,
|
||||
ScalarType::B8 => 1,
|
||||
ScalarType::U16 => 2,
|
||||
ScalarType::S16 => 2,
|
||||
ScalarType::B16 => 2,
|
||||
ScalarType::F16 => 2,
|
||||
ScalarType::U32 => 4,
|
||||
ScalarType::S32 => 4,
|
||||
ScalarType::B32 => 4,
|
||||
ScalarType::F32 => 4,
|
||||
ScalarType::U64 => 8,
|
||||
ScalarType::S64 => 8,
|
||||
ScalarType::B64 => 8,
|
||||
ScalarType::F64 => 8,
|
||||
ScalarType::F16x2 => 4,
|
||||
ScalarType::Pred => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ScalarType {
|
||||
fn default() -> Self {
|
||||
ScalarType::B8
|
||||
|
@ -296,13 +335,26 @@ pub struct Variable<T, ID> {
|
|||
pub align: Option<u32>,
|
||||
pub v_type: T,
|
||||
pub name: ID,
|
||||
pub array_init: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Copy, Clone)]
|
||||
#[derive(Eq, PartialEq, Clone)]
|
||||
pub enum VariableType {
|
||||
Reg(VariableRegType),
|
||||
Local(VariableLocalType),
|
||||
Param(VariableParamType),
|
||||
Global(VariableLocalType),
|
||||
}
|
||||
|
||||
impl VariableType {
|
||||
pub fn to_type(&self) -> (StateSpace, Type) {
|
||||
match self {
|
||||
VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
|
||||
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
|
||||
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
|
||||
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<VariableType> for Type {
|
||||
|
@ -311,6 +363,7 @@ impl From<VariableType> for Type {
|
|||
VariableType::Reg(t) => t.into(),
|
||||
VariableType::Local(t) => t.into(),
|
||||
VariableType::Param(t) => t.into(),
|
||||
VariableType::Global(t) => t.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -318,7 +371,6 @@ impl From<VariableType> for Type {
|
|||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
pub enum StateSpace {
|
||||
Reg,
|
||||
Sreg,
|
||||
Const,
|
||||
Global,
|
||||
Local,
|
||||
|
@ -538,7 +590,7 @@ pub enum LdCacheOperator {
|
|||
Uncached,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Clone)]
|
||||
pub struct MovDetails {
|
||||
pub typ: Type,
|
||||
pub src_is_address: bool,
|
||||
|
@ -846,3 +898,194 @@ pub struct MinMaxFloat {
|
|||
pub nan: bool,
|
||||
pub typ: FloatType,
|
||||
}
|
||||
|
||||
pub enum NumsOrArrays<'a> {
|
||||
Nums(Vec<&'a str>),
|
||||
Arrays(Vec<NumsOrArrays<'a>>),
|
||||
}
|
||||
|
||||
impl<'a> NumsOrArrays<'a> {
|
||||
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
|
||||
self.normalize_dimensions(dimensions)?;
|
||||
let sizeof_t = ScalarType::from(typ).size_of() as usize;
|
||||
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
|
||||
let mut result = vec![0; result_size];
|
||||
self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> {
|
||||
match dimensions.first_mut() {
|
||||
Some(first) => {
|
||||
if *first == 0 {
|
||||
*first = match self {
|
||||
NumsOrArrays::Nums(v) => v.len() as u32,
|
||||
NumsOrArrays::Arrays(v) => v.len() as u32,
|
||||
};
|
||||
}
|
||||
}
|
||||
None => return Err(PtxError::ZeroDimensionArray),
|
||||
}
|
||||
for dim in dimensions {
|
||||
if *dim == 0 {
|
||||
return Err(PtxError::ZeroDimensionArray);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_and_copy(
|
||||
&self,
|
||||
t: SizedScalarType,
|
||||
size_of_t: usize,
|
||||
dimensions: &[u32],
|
||||
result: &mut [u8],
|
||||
) -> Result<(), PtxError> {
|
||||
match dimensions {
|
||||
[] => unreachable!(),
|
||||
[dim] => match self {
|
||||
NumsOrArrays::Nums(vec) => {
|
||||
if vec.len() > *dim as usize {
|
||||
return Err(PtxError::ZeroDimensionArray);
|
||||
}
|
||||
for (idx, val) in vec.iter().enumerate() {
|
||||
Self::parse_and_copy_single(t, idx, val, result)?;
|
||||
}
|
||||
}
|
||||
NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray),
|
||||
},
|
||||
[first_dim, rest @ ..] => match self {
|
||||
NumsOrArrays::Arrays(vec) => {
|
||||
if vec.len() > *first_dim as usize {
|
||||
return Err(PtxError::ZeroDimensionArray);
|
||||
}
|
||||
let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize));
|
||||
for (idx, this) in vec.iter().enumerate() {
|
||||
this.parse_and_copy(
|
||||
t,
|
||||
size_of_t,
|
||||
rest,
|
||||
&mut result[(size_of_element * idx)..],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray),
|
||||
},
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_and_copy_single(
|
||||
t: SizedScalarType,
|
||||
idx: usize,
|
||||
str_val: &str,
|
||||
output: &mut [u8],
|
||||
) -> Result<(), PtxError> {
|
||||
match t {
|
||||
SizedScalarType::B8 | SizedScalarType::U8 => {
|
||||
Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::B16 | SizedScalarType::U16 => {
|
||||
Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::B32 | SizedScalarType::U32 => {
|
||||
Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::B64 | SizedScalarType::U64 => {
|
||||
Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::S8 => {
|
||||
Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::S16 => {
|
||||
Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::S32 => {
|
||||
Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::S64 => {
|
||||
Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::F16 => {
|
||||
Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::F16x2 => todo!(),
|
||||
SizedScalarType::F32 => {
|
||||
Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?;
|
||||
}
|
||||
SizedScalarType::F64 => {
|
||||
Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_and_copy_single_t<T: Copy + FromStr>(
|
||||
idx: usize,
|
||||
str_val: &str,
|
||||
output: &mut [u8],
|
||||
) -> Result<(), PtxError>
|
||||
where
|
||||
T::Err: Into<PtxError>,
|
||||
{
|
||||
let typed_output = unsafe {
|
||||
std::slice::from_raw_parts_mut::<T>(
|
||||
output.as_mut_ptr() as *mut _,
|
||||
output.len() / mem::size_of::<T>(),
|
||||
)
|
||||
};
|
||||
typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Unit tests for turning array initializers into raw bytes.
#[cfg(test)]
mod tests {
    use super::*;

    // Only the leading dimension may be left unsized.
    #[test]
    fn array_fails_multiple_0_dimensions() {
        let inp = NumsOrArrays::Nums(Vec::new());
        assert!(inp.to_vec(SizedScalarType::B8, &mut [0, 0]).is_err());
    }

    // An unsized array with no initializer has no size to infer.
    #[test]
    fn array_fails_on_empty() {
        let inp = NumsOrArrays::Nums(Vec::new());
        assert!(inp.to_vec(SizedScalarType::B8, &mut [0]).is_err());
    }

    // The unsized leading dimension is taken from the initializer and written
    // back into the dimensions slice.
    #[test]
    fn array_auto_sizes_0_dimension() {
        let inp = NumsOrArrays::Arrays(vec![
            NumsOrArrays::Nums(vec!["1", "2"]),
            NumsOrArrays::Nums(vec!["3", "4"]),
        ]);
        let mut dimensions = [0u32, 2];
        assert_eq!(
            vec![1u8, 2, 3, 4],
            inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
        );
        assert_eq!(dimensions, [2u32, 2]);
    }

    // Nesting depth of the initializer must match the number of dimensions.
    #[test]
    fn array_fails_wrong_structure() {
        let inp = NumsOrArrays::Arrays(vec![
            NumsOrArrays::Nums(vec!["1", "2"]),
            NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]),
        ]);
        let mut dimensions = [0u32, 2];
        assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
    }

    // A row with more values than its dimension allows is rejected.
    #[test]
    fn array_fails_too_long_component() {
        let inp = NumsOrArrays::Arrays(vec![
            NumsOrArrays::Nums(vec!["1", "2", "3"]),
            NumsOrArrays::Nums(vec!["4", "5"]),
        ]);
        let mut dimensions = [0u32, 2];
        assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
    }
}
|
||||
|
|
|
@ -2,6 +2,8 @@ use crate::ast;
|
|||
use crate::ast::UnwrapWithVec;
|
||||
use crate::{without_none, vector_index};
|
||||
|
||||
use lalrpop_util::ParseError;
|
||||
|
||||
grammar<'a>(errors: &mut Vec<ast::PtxError>);
|
||||
|
||||
extern {
|
||||
|
@ -27,6 +29,7 @@ match {
|
|||
"{", "}",
|
||||
"<", ">",
|
||||
"|",
|
||||
"=",
|
||||
".acquire",
|
||||
".address_size",
|
||||
".align",
|
||||
|
@ -94,7 +97,6 @@ match {
|
|||
".sat",
|
||||
".section",
|
||||
".shared",
|
||||
".sreg",
|
||||
".sys",
|
||||
".target",
|
||||
".to",
|
||||
|
@ -176,8 +178,8 @@ ExtendedID : &'input str = {
|
|||
}
|
||||
|
||||
pub Module: ast::Module<'input> = {
|
||||
<v:Version> Target <f:Directive*> => {
|
||||
ast::Module { version: v, functions: without_none(f) }
|
||||
<v:Version> Target <d:Directive*> => {
|
||||
ast::Module { version: v, directives: without_none(d) }
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -203,11 +205,12 @@ TargetSpecifier = {
|
|||
"map_f64_to_f32"
|
||||
};
|
||||
|
||||
Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = {
|
||||
Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
|
||||
AddressSize => None,
|
||||
<f:Function> => Some(f),
|
||||
<f:Function> => Some(ast::Directive::Method(f)),
|
||||
File => None,
|
||||
Section => None
|
||||
Section => None,
|
||||
<v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)),
|
||||
};
|
||||
|
||||
AddressSize = {
|
||||
|
@ -242,9 +245,9 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
|
|||
};
|
||||
|
||||
KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
|
||||
<v:ParamVariable> => {
|
||||
<v:ParamDeclaration> => {
|
||||
let (align, v_type, name) = v;
|
||||
ast::Variable{ align, v_type, name }
|
||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -252,12 +255,12 @@ FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
|
|||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Reg(v_type);
|
||||
ast::Variable{ align, v_type, name }
|
||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
||||
},
|
||||
<v:ParamVariable> => {
|
||||
<v:ParamDeclaration> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::FnArgumentType::Param(v_type);
|
||||
ast::Variable{ align, v_type, name }
|
||||
ast::Variable{ align, v_type, name, array_init: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,7 +271,6 @@ pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>
|
|||
|
||||
StateSpaceSpecifier: ast::StateSpace = {
|
||||
".reg" => ast::StateSpace::Reg,
|
||||
".sreg" => ast::StateSpace::Sreg,
|
||||
".const" => ast::StateSpace::Const,
|
||||
".global" => ast::StateSpace::Global,
|
||||
".local" => ast::StateSpace::Local,
|
||||
|
@ -344,13 +346,13 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
|
|||
<v:RegVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Reg(v_type);
|
||||
ast::Variable {align, v_type, name}
|
||||
ast::Variable {align, v_type, name, array_init: Vec::new()}
|
||||
},
|
||||
LocalVariable,
|
||||
<v:ParamVariable> => {
|
||||
let (align, v_type, name) = v;
|
||||
let (align, array_init, v_type, name) = v;
|
||||
let v_type = ast::VariableType::Param(v_type);
|
||||
ast::Variable {align, v_type, name}
|
||||
ast::Variable {align, v_type, name, array_init}
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -366,32 +368,60 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
|
|||
}
|
||||
|
||||
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
|
||||
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
|
||||
ast::Variable {align, v_type, name}
|
||||
},
|
||||
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
|
||||
ast::Variable {align, v_type, name}
|
||||
".local" <def:LocalVariableDefinition> => {
|
||||
let (align, array_init, v_type, name) = def;
|
||||
ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }
|
||||
}
|
||||
}
|
||||
|
||||
// `.global` variable: same definition syntax as a `.local` variable, wrapped
// in the `Global` variant instead.
GlobalVariable: ast::Variable<ast::VariableType, &'input str> = {
    ".global" <def:LocalVariableDefinition> => {
        let (align, array_init, local_type, name) = def;
        let v_type = ast::VariableType::Global(local_type);
        ast::Variable { align, v_type, name, array_init }
    }
}
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
|
||||
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableParamType::Scalar(t);
|
||||
(align, Vec::new(), v_type, name)
|
||||
},
|
||||
".param" <align:Align?> <arr:ArrayDefinition> => {
|
||||
let (array_init, name, (t, dimensions)) = arr;
|
||||
let v_type = ast::VariableParamType::Array(t, dimensions);
|
||||
(align, array_init, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
|
||||
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
|
||||
let v_type = ast::VariableParamType::Scalar(t);
|
||||
(align, v_type, name)
|
||||
},
|
||||
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
|
||||
let v_type = ast::VariableParamType::Array(t, arr);
|
||||
".param" <align:Align?> <arr:ArrayDeclaration> => {
|
||||
let (name, (t, dimensions)) = arr;
|
||||
let v_type = ast::VariableParamType::Array(t, dimensions);
|
||||
(align, v_type, name)
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after the state-space keyword of a variable definition:
// (alignment, initializer bytes, type, name). Scalars and vectors carry an
// empty initializer; arrays carry the bytes produced by `ArrayDefinition`.
LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = {
    <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
        (align, Vec::new(), ast::VariableLocalType::Scalar(t), name)
    },
    <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
        (align, Vec::new(), ast::VariableLocalType::Vector(t, v_len), name)
    },
    <align:Align?> <arr:ArrayDefinition> => {
        let (array_init, name, (t, dimensions)) = arr;
        (align, array_init, ast::VariableLocalType::Array(t, dimensions), name)
    }
}
|
||||
|
||||
#[inline]
|
||||
SizedScalarType: ast::SizedScalarType = {
|
||||
".b8" => ast::SizedScalarType::B8,
|
||||
|
@ -431,12 +461,59 @@ ParamScalarType: ast::ParamScalarType = {
|
|||
".f64" => ast::ParamScalarType::F64,
|
||||
}
|
||||
|
||||
ArraySpecifier: u32 = {
|
||||
"[" <n:Num> "]" => {
|
||||
let size = n.parse::<u32>();
|
||||
size.unwrap_with(errors)
|
||||
ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = {
|
||||
<typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
|
||||
let mut dims = dims;
|
||||
let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?;
|
||||
Ok((
|
||||
array_init,
|
||||
name,
|
||||
(typ, dims)
|
||||
))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Array declaration without an initializer: element type, name and one or
// more explicit dimensions. Any dimension equal to 0 is rejected with
// `ZeroDimensionArray`.
ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = {
    <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? {
        if dims.iter().any(|&dim| dim == 0) {
            Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray })
        } else {
            Ok((name, (typ, dims)))
        }
    }
}
|
||||
|
||||
// Dimension list of an array definition. `[0]` and `[]` are treated the same:
// an empty leading `[]` is recorded as a dimension of 0.
ArrayDimensions: Vec<u32> = {
    ArrayEmptyDimension => vec![0u32],
    ArrayEmptyDimension <dims:ArrayDimension+> => {
        std::iter::once(0u32).chain(dims).collect()
    },
    <dims:ArrayDimension+> => dims
}

// Empty bracket pair, only accepted as the leading dimension.
ArrayEmptyDimension = {
    "[" "]"
}
|
||||
|
||||
// One bracketed dimension; a number that does not parse as `u32` becomes a
// `PtxError::ParseInt` user error.
ArrayDimension: u32 = {
    "[" <n:Num> "]" =>? {
        match n.parse::<u32>() {
            Ok(dim) => Ok(dim),
            Err(e) => Err(ParseError::User { error: ast::PtxError::ParseInt(e) }),
        }
    }
}
|
||||
|
||||
// `= { ... }` following an array definition.
ArrayInitializer: ast::NumsOrArrays<'input> = {
    "=" <NumsOrArraysBracket>
}

// One brace-delimited initializer level.
NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
    "{" <NumsOrArrays> "}"
}

// Contents of one initializer level: either nested braced levels or a
// non-empty, comma-separated list of numbers.
NumsOrArrays: ast::NumsOrArrays<'input> = {
    Comma<NumsOrArraysBracket> => ast::NumsOrArrays::Arrays(<>),
    CommaNonEmpty<Num> => ast::NumsOrArrays::Nums(<>),
}
|
||||
|
||||
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||
InstLd,
|
||||
|
@ -1244,3 +1321,11 @@ Comma<T>: Vec<T> = {
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Comma-separated list of `T` with at least one element and no trailing comma.
CommaNonEmpty<T>: Vec<T> = {
    <init:(<T> ",")*> <last:T> => {
        init.into_iter().chain(std::iter::once(last)).collect()
    }
};
|
||||
|
|
22
ptx/src/test/spirv_run/global_array.ptx
Normal file
22
ptx/src/test/spirv_run/global_array.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.global .s32 foobar[4] = {1};
|
||||
|
||||
.visible .entry global_array(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u32 temp;
|
||||
|
||||
mov.u64 in_addr, foobar;
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.global.u32 temp, [in_addr];
|
||||
st.global.u32 [out_addr], temp;
|
||||
ret;
|
||||
}
|
54
ptx/src/test/spirv_run/global_array.spvtxt
Normal file
54
ptx/src/test/spirv_run/global_array.spvtxt
Normal file
|
@ -0,0 +1,54 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int8
|
||||
OpCapability Int16
|
||||
OpCapability Int64
|
||||
OpCapability Float16
|
||||
OpCapability Float64
|
||||
%22 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %2 "global_array" %1
|
||||
%void = OpTypeVoid
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_4 = OpConstant %uint 4
|
||||
%_arr_uint_uint_4 = OpTypeArray %uint %uint_4
|
||||
%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4
|
||||
%uint_4_0 = OpConstant %uint 4
|
||||
%uint_1 = OpConstant %uint 1
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%31 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0
|
||||
%1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %31
|
||||
%ulong = OpTypeInt 64 0
|
||||
%33 = OpTypeFunction %void %ulong %ulong
|
||||
%_ptr_Function_ulong = OpTypePointer Function %ulong
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
|
||||
%2 = OpFunction %void None %33
|
||||
%8 = OpFunctionParameter %ulong
|
||||
%9 = OpFunctionParameter %ulong
|
||||
%20 = OpLabel
|
||||
%3 = OpVariable %_ptr_Function_ulong Function
|
||||
%4 = OpVariable %_ptr_Function_ulong Function
|
||||
%5 = OpVariable %_ptr_Function_ulong Function
|
||||
%6 = OpVariable %_ptr_Function_ulong Function
|
||||
%7 = OpVariable %_ptr_Function_uint Function
|
||||
OpStore %3 %8
|
||||
OpStore %4 %9
|
||||
%17 = OpConvertPtrToU %ulong %1
|
||||
%10 = OpCopyObject %ulong %17
|
||||
OpStore %5 %10
|
||||
%12 = OpLoad %ulong %4
|
||||
%11 = OpCopyObject %ulong %12
|
||||
OpStore %6 %11
|
||||
%14 = OpLoad %ulong %5
|
||||
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14
|
||||
%13 = OpLoad %uint %18
|
||||
OpStore %7 %13
|
||||
%15 = OpLoad %ulong %6
|
||||
%16 = OpLoad %uint %7
|
||||
%19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %15
|
||||
OpStore %19 %16
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -66,14 +66,18 @@ test_ptx!(b64tof64, [111u64], [111u64]);
|
|||
test_ptx!(implicit_param, [34u32], [34u32]);
|
||||
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
|
||||
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
|
||||
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
|
||||
test_ptx!(
|
||||
mul_wide,
|
||||
[0x01_00_00_00__01_00_00_00i64],
|
||||
[0x1_00_00_00_00_00_00i64]
|
||||
);
|
||||
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
|
||||
test_ptx!(shr, [-2i32], [-1i32]);
|
||||
test_ptx!(or, [1u64, 2u64], [3u64]);
|
||||
test_ptx!(sub, [2u64], [1u64]);
|
||||
test_ptx!(min, [555i32, 444i32], [444i32]);
|
||||
test_ptx!(max, [555i32, 444i32], [555i32]);
|
||||
|
||||
test_ptx!(global_array, [0xDEADu32], [1u32]);
|
||||
|
||||
struct DisplayError<T: Debug> {
|
||||
err: T,
|
||||
|
@ -131,7 +135,15 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
|
|||
let mut devices = drv.devices()?;
|
||||
let dev = devices.drain(0..1).next().unwrap();
|
||||
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
|
||||
let module = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None)?;
|
||||
let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
|
||||
let module = match module {
|
||||
Ok(m) => m,
|
||||
Err(err) => {
|
||||
let raw_err_string = log.get_cstring()?;
|
||||
let err_string = raw_err_string.to_string_lossy();
|
||||
panic!("{:?}\n{}", err, err_string);
|
||||
}
|
||||
};
|
||||
let mut kernel = ze::Kernel::new_resident(&module, name)?;
|
||||
kernel.set_indirect_access(
|
||||
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
|
||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue