Add support for top-level global variables, improve array support

This commit is contained in:
Andrzej Janik 2020-10-02 20:34:45 +02:00
parent 9a65dd32f5
commit 27d25865af
13 changed files with 1085 additions and 378 deletions

View file

@ -1,6 +1,6 @@
use crate::sys;
use std::{
ffi::{c_void, CStr},
ffi::{c_void, CStr, CString},
fmt::Debug,
marker::PhantomData,
mem, ptr,
@ -238,23 +238,16 @@ impl Drop for CommandQueue {
pub struct Module(sys::ze_module_handle_t);
impl Module {
pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t {
self.0
}
pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self {
Self(x)
}
pub fn new_spirv(
ctx: &mut Context,
d: &Device,
bin: &[u8],
opts: Option<&CStr>,
) -> Result<Self> {
) -> (Result<Self>, BuildLog) {
Module::new(ctx, true, d, bin, opts)
}
pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> Result<Self> {
pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
Module::new(ctx, false, d, bin, None)
}
@ -264,7 +257,7 @@ impl Module {
d: &Device,
bin: &[u8],
opts: Option<&CStr>,
) -> Result<Self> {
) -> (Result<Self>, BuildLog) {
let desc = sys::ze_module_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_MODULE_DESC,
pNext: ptr::null(),
@ -279,14 +272,14 @@ impl Module {
pConstants: ptr::null(),
};
let mut result: sys::ze_module_handle_t = ptr::null_mut();
check!(sys::zeModuleCreate(
ctx.0,
d.0,
&desc,
&mut result,
ptr::null_mut()
));
Ok(Module(result))
let mut log_handle = ptr::null_mut();
let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) };
let log = BuildLog(log_handle);
if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
(Result::Err(err), log)
} else {
(Ok(Module(result)), log)
}
}
}
@ -297,6 +290,32 @@ impl Drop for Module {
}
}
pub struct BuildLog(sys::ze_module_build_log_handle_t);
impl BuildLog {
pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t {
self.0
}
pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self {
Self(x)
}
pub fn get_cstring(&self) -> Result<CString> {
let mut size = 0;
check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) };
let mut str_vec = vec![0u8; size];
check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) };
str_vec.pop();
Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?)
}
}
impl Drop for BuildLog {
fn drop(&mut self) {
check_panic!(sys::zeModuleBuildLogDestroy(self.0));
}
}
pub trait SafeRepr {}
impl SafeRepr for u8 {}
impl SafeRepr for i8 {}

View file

@ -1,7 +1,7 @@
use ::std::os::raw::{c_uint, c_void};
use std::ptr;
use super::{context, device, stream::Stream, CUresult};
use super::{device, stream::Stream, CUresult};
pub struct Function {
pub base: l0::Kernel<'static>,

View file

@ -46,7 +46,7 @@ unsafe fn memcpy_impl(
Ok(())
}
pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> {
pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
Ok(())
}

View file

@ -1,4 +1,4 @@
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st};
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
#[cfg(test)]

View file

@ -1,13 +1,10 @@
use std::{
collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice,
sync::Mutex,
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
};
use super::{function::Function, transmute_lifetime, CUresult};
use ptx;
use super::context;
pub type Module = Mutex<ModuleData>;
pub struct ModuleData {
@ -67,14 +64,14 @@ impl ModuleData {
l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
});
match module {
Ok(Ok(module)) => Ok(Mutex::new(Self {
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
base: module,
arg_lens: all_arg_lens
.into_iter()
.map(|(k, v)| (CString::new(k).unwrap(), v))
.collect(),
})),
Ok(Err(err)) => Err(ModuleCompileError::from(err)),
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
Err(err) => Err(ModuleCompileError::from(err)),
}
}
@ -116,6 +113,6 @@ pub fn get_function(
Ok(())
}
pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> {
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
Ok(())
}

View file

@ -30,7 +30,7 @@ mod tests {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::{ffi::c_void, ptr};
use std::ptr;
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
@ -41,7 +41,7 @@ mod tests {
fn default_stream_uses_current_ctx_legacy<T: CudaDriverFns>() {
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_LEGACY);
}
fn default_stream_uses_current_ctx_ptsd<T: CudaDriverFns>() {
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_PER_THREAD);
}

View file

@ -1,6 +1,6 @@
#![allow(non_snake_case)]
use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda};
use crate::{cuda::CUstream, r#impl as notcuda};
use crate::r#impl::CUresult;
use crate::{cuda::CUuuid, r#impl::Encuda};
use ::std::{

View file

@ -1,6 +1,8 @@
use std::convert::From;
use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
use half::f16;
quick_error! {
#[derive(Debug)]
pub enum PtxError {
@ -9,11 +11,17 @@ quick_error! {
display("{}", err)
cause(err)
}
ParseFloat (err: ParseFloatError) {
from()
display("{}", err)
cause(err)
}
SyntaxError {}
NonF32Ftz {}
WrongArrayType {}
WrongVectorElement {}
MultiArrayVariable {}
ZeroDimensionArray {}
}
}
@ -53,7 +61,7 @@ macro_rules! sub_scalar_type {
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone, Copy)]
#[derive(PartialEq, Eq, Clone)]
pub enum $type_name {
$(
$variant ($($field_type),+),
@ -80,11 +88,13 @@ sub_type! {
}
}
type VecU32 = Vec<u32>;
sub_type! {
VariableLocalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
Array(SizedScalarType, u32),
Array(SizedScalarType, VecU32),
}
}
@ -95,7 +105,7 @@ sub_type! {
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
Array(SizedScalarType, u32),
Array(SizedScalarType, VecU32),
}
}
@ -169,7 +179,12 @@ impl<
pub struct Module<'a> {
pub version: (u8, u8),
pub functions: Vec<ParsedFunction<'a>>,
pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>,
}
pub enum Directive<'a, P: ArgParams> {
Variable(Variable<VariableType, P::Id>),
Method(Function<'a, &'a str, Statement<P>>),
}
pub enum MethodDecl<'a, ID> {
@ -187,7 +202,7 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone, Copy)]
#[derive(PartialEq, Eq, Clone)]
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
@ -202,11 +217,11 @@ impl From<FnArgumentType> for Type {
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, u32),
Array(ScalarType, Vec<u32>),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@ -274,6 +289,30 @@ sub_scalar_type!(FloatType {
F64
});
impl ScalarType {
pub fn size_of(self) -> u8 {
match self {
ScalarType::U8 => 1,
ScalarType::S8 => 1,
ScalarType::B8 => 1,
ScalarType::U16 => 2,
ScalarType::S16 => 2,
ScalarType::B16 => 2,
ScalarType::F16 => 2,
ScalarType::U32 => 4,
ScalarType::S32 => 4,
ScalarType::B32 => 4,
ScalarType::F32 => 4,
ScalarType::U64 => 8,
ScalarType::S64 => 8,
ScalarType::B64 => 8,
ScalarType::F64 => 8,
ScalarType::F16x2 => 4,
ScalarType::Pred => 1,
}
}
}
impl Default for ScalarType {
fn default() -> Self {
ScalarType::B8
@ -296,13 +335,26 @@ pub struct Variable<T, ID> {
pub align: Option<u32>,
pub v_type: T,
pub name: ID,
pub array_init: Vec<u8>,
}
#[derive(Eq, PartialEq, Copy, Clone)]
#[derive(Eq, PartialEq, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
Global(VariableLocalType),
}
impl VariableType {
pub fn to_type(&self) -> (StateSpace, Type) {
match self {
VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
}
}
}
impl From<VariableType> for Type {
@ -311,6 +363,7 @@ impl From<VariableType> for Type {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
VariableType::Global(t) => t.into(),
}
}
}
@ -318,7 +371,6 @@ impl From<VariableType> for Type {
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
Sreg,
Const,
Global,
Local,
@ -538,7 +590,7 @@ pub enum LdCacheOperator {
Uncached,
}
#[derive(Copy, Clone)]
#[derive(Clone)]
pub struct MovDetails {
pub typ: Type,
pub src_is_address: bool,
@ -846,3 +898,194 @@ pub struct MinMaxFloat {
pub nan: bool,
pub typ: FloatType,
}
pub enum NumsOrArrays<'a> {
Nums(Vec<&'a str>),
Arrays(Vec<NumsOrArrays<'a>>),
}
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;
let sizeof_t = ScalarType::from(typ).size_of() as usize;
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
let mut result = vec![0; result_size];
self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?;
Ok(result)
}
fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> {
match dimensions.first_mut() {
Some(first) => {
if *first == 0 {
*first = match self {
NumsOrArrays::Nums(v) => v.len() as u32,
NumsOrArrays::Arrays(v) => v.len() as u32,
};
}
}
None => return Err(PtxError::ZeroDimensionArray),
}
for dim in dimensions {
if *dim == 0 {
return Err(PtxError::ZeroDimensionArray);
}
}
Ok(())
}
fn parse_and_copy(
&self,
t: SizedScalarType,
size_of_t: usize,
dimensions: &[u32],
result: &mut [u8],
) -> Result<(), PtxError> {
match dimensions {
[] => unreachable!(),
[dim] => match self {
NumsOrArrays::Nums(vec) => {
if vec.len() > *dim as usize {
return Err(PtxError::ZeroDimensionArray);
}
for (idx, val) in vec.iter().enumerate() {
Self::parse_and_copy_single(t, idx, val, result)?;
}
}
NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray),
},
[first_dim, rest @ ..] => match self {
NumsOrArrays::Arrays(vec) => {
if vec.len() > *first_dim as usize {
return Err(PtxError::ZeroDimensionArray);
}
let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize));
for (idx, this) in vec.iter().enumerate() {
this.parse_and_copy(
t,
size_of_t,
rest,
&mut result[(size_of_element * idx)..],
)?;
}
}
NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray),
},
}
Ok(())
}
fn parse_and_copy_single(
t: SizedScalarType,
idx: usize,
str_val: &str,
output: &mut [u8],
) -> Result<(), PtxError> {
match t {
SizedScalarType::B8 | SizedScalarType::U8 => {
Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?;
}
SizedScalarType::B16 | SizedScalarType::U16 => {
Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?;
}
SizedScalarType::B32 | SizedScalarType::U32 => {
Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?;
}
SizedScalarType::B64 | SizedScalarType::U64 => {
Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?;
}
SizedScalarType::S8 => {
Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?;
}
SizedScalarType::S16 => {
Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?;
}
SizedScalarType::S32 => {
Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?;
}
SizedScalarType::S64 => {
Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?;
}
SizedScalarType::F16 => {
Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?;
}
SizedScalarType::F16x2 => todo!(),
SizedScalarType::F32 => {
Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?;
}
SizedScalarType::F64 => {
Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?;
}
}
Ok(())
}
fn parse_and_copy_single_t<T: Copy + FromStr>(
idx: usize,
str_val: &str,
output: &mut [u8],
) -> Result<(), PtxError>
where
T::Err: Into<PtxError>,
{
let typed_output = unsafe {
std::slice::from_raw_parts_mut::<T>(
output.as_mut_ptr() as *mut _,
output.len() / mem::size_of::<T>(),
)
};
typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn array_fails_multiple_0_dmiensions() {
let inp = NumsOrArrays::Nums(Vec::new());
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err());
}
#[test]
fn array_fails_on_empty() {
let inp = NumsOrArrays::Nums(Vec::new());
assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err());
}
#[test]
fn array_auto_sizes_0_dimension() {
let inp = NumsOrArrays::Arrays(vec![
NumsOrArrays::Nums(vec!["1", "2"]),
NumsOrArrays::Nums(vec!["3", "4"]),
]);
let mut dimensions = vec![0u32, 2];
assert_eq!(
vec![1u8, 2, 3, 4],
inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
);
assert_eq!(dimensions, vec![2u32, 2]);
}
#[test]
fn array_fails_wrong_structure() {
let inp = NumsOrArrays::Arrays(vec![
NumsOrArrays::Nums(vec!["1", "2"]),
NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
}
#[test]
fn array_fails_too_long_component() {
let inp = NumsOrArrays::Arrays(vec![
NumsOrArrays::Nums(vec!["1", "2", "3"]),
NumsOrArrays::Nums(vec!["4", "5"]),
]);
let mut dimensions = vec![0u32, 2];
assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
}
}

View file

@ -2,6 +2,8 @@ use crate::ast;
use crate::ast::UnwrapWithVec;
use crate::{without_none, vector_index};
use lalrpop_util::ParseError;
grammar<'a>(errors: &mut Vec<ast::PtxError>);
extern {
@ -27,6 +29,7 @@ match {
"{", "}",
"<", ">",
"|",
"=",
".acquire",
".address_size",
".align",
@ -94,7 +97,6 @@ match {
".sat",
".section",
".shared",
".sreg",
".sys",
".target",
".to",
@ -176,8 +178,8 @@ ExtendedID : &'input str = {
}
pub Module: ast::Module<'input> = {
<v:Version> Target <f:Directive*> => {
ast::Module { version: v, functions: without_none(f) }
<v:Version> Target <d:Directive*> => {
ast::Module { version: v, directives: without_none(d) }
}
};
@ -203,11 +205,12 @@ TargetSpecifier = {
"map_f64_to_f32"
};
Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = {
Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
AddressSize => None,
<f:Function> => Some(f),
<f:Function> => Some(ast::Directive::Method(f)),
File => None,
Section => None
Section => None,
<v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)),
};
AddressSize = {
@ -242,9 +245,9 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
};
KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
<v:ParamVariable> => {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
ast::Variable{ align, v_type, name }
ast::Variable{ align, v_type, name, array_init: Vec::new() }
}
}
@ -252,12 +255,12 @@ FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type);
ast::Variable{ align, v_type, name }
ast::Variable{ align, v_type, name, array_init: Vec::new() }
},
<v:ParamVariable> => {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Param(v_type);
ast::Variable{ align, v_type, name }
ast::Variable{ align, v_type, name, array_init: Vec::new() }
}
}
@ -268,7 +271,6 @@ pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>
StateSpaceSpecifier: ast::StateSpace = {
".reg" => ast::StateSpace::Reg,
".sreg" => ast::StateSpace::Sreg,
".const" => ast::StateSpace::Const,
".global" => ast::StateSpace::Global,
".local" => ast::StateSpace::Local,
@ -344,13 +346,13 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
ast::Variable {align, v_type, name}
ast::Variable {align, v_type, name, array_init: Vec::new()}
},
LocalVariable,
<v:ParamVariable> => {
let (align, v_type, name) = v;
let (align, array_init, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type);
ast::Variable {align, v_type, name}
ast::Variable {align, v_type, name, array_init}
},
};
@ -366,32 +368,60 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
}
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
ast::Variable {align, v_type, name}
},
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
ast::Variable {align, v_type, name}
".local" <def:LocalVariableDefinition> => {
let (align, array_init, v_type, name) = def;
ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }
}
}
GlobalVariable: ast::Variable<ast::VariableType, &'input str> = {
".global" <def:LocalVariableDefinition> => {
let (align, array_init, v_type, name) = def;
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
let v_type = ast::VariableParamType::Scalar(t);
(align, Vec::new(), v_type, name)
},
".param" <align:Align?> <arr:ArrayDefinition> => {
let (array_init, name, (t, dimensions)) = arr;
let v_type = ast::VariableParamType::Array(t, dimensions);
(align, array_init, v_type, name)
}
}
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
let v_type = ast::VariableParamType::Scalar(t);
(align, v_type, name)
},
".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
let v_type = ast::VariableParamType::Array(t, arr);
".param" <align:Align?> <arr:ArrayDeclaration> => {
let (name, (t, dimensions)) = arr;
let v_type = ast::VariableParamType::Array(t, dimensions);
(align, v_type, name)
}
}
LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = {
<align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableLocalType::Scalar(t);
(align, Vec::new(), v_type, name)
},
<align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableLocalType::Vector(t, v_len);
(align, Vec::new(), v_type, name)
},
<align:Align?> <arr:ArrayDefinition> => {
let (array_init, name, (t, dimensions)) = arr;
let v_type = ast::VariableLocalType::Array(t, dimensions);
(align, array_init, v_type, name)
}
}
#[inline]
SizedScalarType: ast::SizedScalarType = {
".b8" => ast::SizedScalarType::B8,
@ -431,12 +461,59 @@ ParamScalarType: ast::ParamScalarType = {
".f64" => ast::ParamScalarType::F64,
}
ArraySpecifier: u32 = {
"[" <n:Num> "]" => {
let size = n.parse::<u32>();
size.unwrap_with(errors)
ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = {
<typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
let mut dims = dims;
let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?;
Ok((
array_init,
name,
(typ, dims)
))
}
};
}
ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = {
<typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? {
let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::<Result<_,_>>()?;
Ok((name, (typ, dims)))
}
}
// [0] and [] are treated the same
ArrayDimensions: Vec<u32> = {
ArrayEmptyDimension => vec![0u32],
ArrayEmptyDimension <dims:ArrayDimension+> => {
let mut dims = dims;
let mut result = vec![0u32];
result.append(&mut dims);
result
},
<dims:ArrayDimension+> => dims
}
ArrayEmptyDimension = {
"[" "]"
}
ArrayDimension: u32 = {
"[" <n:Num> "]" =>? {
str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
}
}
ArrayInitializer: ast::NumsOrArrays<'input> = {
"=" <nums:NumsOrArraysBracket> => nums
}
NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
"{" <nums:NumsOrArrays> "}" => nums
}
NumsOrArrays: ast::NumsOrArrays<'input> = {
<n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
<n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
}
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
@ -1244,3 +1321,11 @@ Comma<T>: Vec<T> = {
}
}
};
CommaNonEmpty<T>: Vec<T> = {
<v:(<T> ",")*> <e:T> => {
let mut v = v;
v.push(e);
v
}
};

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.global .s32 foobar[4] = {1};
.visible .entry global_array(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u32 temp;
mov.u64 in_addr, foobar;
ld.param.u64 out_addr, [output];
ld.global.u32 temp, [in_addr];
st.global.u32 [out_addr], temp;
ret;
}

View file

@ -0,0 +1,54 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
%22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "global_array" %1
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uint_4 = OpConstant %uint 4
%_arr_uint_uint_4 = OpTypeArray %uint %uint_4
%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4
%uint_4_0 = OpConstant %uint 4
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
%31 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0
%1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %31
%ulong = OpTypeInt 64 0
%33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%2 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%20 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_uint Function
OpStore %3 %8
OpStore %4 %9
%17 = OpConvertPtrToU %ulong %1
%10 = OpCopyObject %ulong %17
OpStore %5 %10
%12 = OpLoad %ulong %4
%11 = OpCopyObject %ulong %12
OpStore %6 %11
%14 = OpLoad %ulong %5
%18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14
%13 = OpLoad %uint %18
OpStore %7 %13
%15 = OpLoad %ulong %6
%16 = OpLoad %uint %7
%19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %15
OpStore %19 %16
OpReturn
OpFunctionEnd

View file

@ -66,14 +66,18 @@ test_ptx!(b64tof64, [111u64], [111u64]);
test_ptx!(implicit_param, [34u32], [34u32]);
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
test_ptx!(
mul_wide,
[0x01_00_00_00__01_00_00_00i64],
[0x1_00_00_00_00_00_00i64]
);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
test_ptx!(global_array, [0xDEADu32], [1u32]);
struct DisplayError<T: Debug> {
err: T,
@ -131,7 +135,15 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
let module = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None)?;
let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
let module = match module {
Ok(m) => m,
Err(err) => {
let raw_err_string = log.get_cstring()?;
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
};
let mut kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,

File diff suppressed because it is too large Load diff