mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-04 15:19:49 +00:00
Parse Linux vectorAdd debug PTX kernel
This commit is contained in:
parent
2e4cadc2ab
commit
87cc72494e
10 changed files with 223 additions and 24 deletions
|
@ -4,10 +4,13 @@ use crate::{
|
||||||
cuda_impl,
|
cuda_impl,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{context, device, Decuda, Encuda};
|
use super::{context, device, module, Decuda, Encuda};
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use std::os::raw::{c_uint, c_ulong, c_ushort};
|
use std::os::raw::{c_uint, c_ulong, c_ushort};
|
||||||
use std::{ffi::c_void, ptr, slice};
|
use std::{
|
||||||
|
ffi::{c_void, CStr, CString},
|
||||||
|
ptr, slice,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn get(table: *mut *const std::os::raw::c_void, id: *const CUuuid) -> CUresult {
|
pub fn get(table: *mut *const std::os::raw::c_void, id: *const CUuuid) -> CUresult {
|
||||||
if table == ptr::null_mut() || id == ptr::null_mut() {
|
if table == ptr::null_mut() || id == ptr::null_mut() {
|
||||||
|
@ -204,6 +207,12 @@ unsafe extern "C" fn get_module_from_cubin(
|
||||||
{
|
{
|
||||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||||
}
|
}
|
||||||
|
let result = result.decuda();
|
||||||
|
let mut dev_count = 0;
|
||||||
|
let cu_result = device::get_count(&mut dev_count);
|
||||||
|
if cu_result != CUresult::CUDA_SUCCESS {
|
||||||
|
return cu_result;
|
||||||
|
}
|
||||||
let fatbin_header = (*fatbinc_wrapper).data;
|
let fatbin_header = (*fatbinc_wrapper).data;
|
||||||
if (*fatbin_header).magic != FATBIN_MAGIC || (*fatbin_header).version != FATBIN_VERSION {
|
if (*fatbin_header).magic != FATBIN_MAGIC || (*fatbin_header).version != FATBIN_VERSION {
|
||||||
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
return CUresult::CUDA_ERROR_INVALID_VALUE;
|
||||||
|
@ -219,7 +228,21 @@ unsafe extern "C" fn get_module_from_cubin(
|
||||||
);
|
);
|
||||||
let kernel_text =
|
let kernel_text =
|
||||||
lz4::block::decompress(slice, Some((*file).uncompressed_payload as i32)).unwrap();
|
lz4::block::decompress(slice, Some((*file).uncompressed_payload as i32)).unwrap();
|
||||||
return CUresult::CUDA_SUCCESS;
|
let kernel_text_string = match CStr::from_bytes_with_nul(&kernel_text) {
|
||||||
|
Ok(c_str) => match c_str.to_str() {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => continue,
|
||||||
|
},
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let module = module::Module::compile(kernel_text_string, dev_count as usize);
|
||||||
|
match module {
|
||||||
|
Ok(module) => {
|
||||||
|
*result = Box::into_raw(Box::new(module));
|
||||||
|
return CUresult::CUDA_SUCCESS;
|
||||||
|
}
|
||||||
|
Err(_) => continue,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
CUresult::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE
|
CUresult::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUresult};
|
use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUresult, CUmodule};
|
||||||
use std::{ffi::c_void, mem::ManuallyDrop, os::raw::c_int, sync::Mutex};
|
use std::{ffi::c_void, mem::ManuallyDrop, os::raw::c_int, sync::Mutex};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -8,6 +8,7 @@ pub mod context;
|
||||||
pub mod device;
|
pub mod device;
|
||||||
pub mod export_table;
|
pub mod export_table;
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
|
pub mod module;
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
pub fn unimplemented() -> CUresult {
|
pub fn unimplemented() -> CUresult {
|
||||||
|
@ -232,3 +233,8 @@ impl Decuda<*mut c_void> for CUdeviceptr {
|
||||||
self.0 as *mut _
|
self.0 as *mut _
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> CudaRepr for CUmodule {
|
||||||
|
type Impl = *mut module::Module;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
44
notcuda/src/impl/module.rs
Normal file
44
notcuda/src/impl/module.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
use ptx;
|
||||||
|
|
||||||
|
pub struct Module {
|
||||||
|
spirv_code: Vec<u32>,
|
||||||
|
compiled_code: Vec<Option<Vec<u8>>>, // size as big as the number of devices
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum ModuleCompileError<'a> {
|
||||||
|
Parse(
|
||||||
|
Vec<ptx::ast::PtxError>,
|
||||||
|
Option<ptx::ParseError<usize, ptx::Token<'a>, &'a str>>,
|
||||||
|
),
|
||||||
|
Compile(ptx::SpirvError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ModuleCompileError<'a> {
|
||||||
|
pub fn get_build_log(&self) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> From<ptx::SpirvError> for ModuleCompileError<'a> {
|
||||||
|
fn from(err: ptx::SpirvError) -> Self {
|
||||||
|
ModuleCompileError::Compile(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module {
|
||||||
|
pub fn compile(ptx_text: &str, devices: usize) -> Result<Self, ModuleCompileError> {
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
|
||||||
|
let ast = match ast {
|
||||||
|
Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
|
||||||
|
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
|
||||||
|
Ok(ast) => ast,
|
||||||
|
};
|
||||||
|
let spirv = ptx::to_spirv(ast)?;
|
||||||
|
Ok(Self {
|
||||||
|
spirv_code: spirv,
|
||||||
|
compiled_code: vec![None; devices],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ extern crate lz4;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate paste;
|
extern crate paste;
|
||||||
|
extern crate ptx;
|
||||||
|
|
||||||
#[allow(warnings)]
|
#[allow(warnings)]
|
||||||
mod cuda;
|
mod cuda;
|
||||||
|
|
|
@ -7,7 +7,7 @@ edition = "2018"
|
||||||
[lib]
|
[lib]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
lalrpop-util = "0.18.1"
|
lalrpop-util = "0.19"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
rspirv = "0.6"
|
rspirv = "0.6"
|
||||||
spirv_headers = "1.4"
|
spirv_headers = "1.4"
|
||||||
|
@ -16,7 +16,7 @@ bit-vec = "0.6"
|
||||||
half ="1.6"
|
half ="1.6"
|
||||||
|
|
||||||
[build-dependencies.lalrpop]
|
[build-dependencies.lalrpop]
|
||||||
version = "0.18.1"
|
version = "0.19"
|
||||||
features = ["lexer"]
|
features = ["lexer"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
@ -11,6 +11,7 @@ quick_error! {
|
||||||
}
|
}
|
||||||
SyntaxError {}
|
SyntaxError {}
|
||||||
NonF32Ftz {}
|
NonF32Ftz {}
|
||||||
|
WrongArrayType {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,11 +51,16 @@ pub struct Module<'a> {
|
||||||
pub functions: Vec<Function<'a>>,
|
pub functions: Vec<Function<'a>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum FunctionReturn<'a> {
|
||||||
|
Func(Vec<Argument<'a>>),
|
||||||
|
Kernel,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Function<'a> {
|
pub struct Function<'a> {
|
||||||
pub kernel: bool,
|
pub func_directive: FunctionReturn<'a>,
|
||||||
pub name: &'a str,
|
pub name: &'a str,
|
||||||
pub args: Vec<Argument<'a>>,
|
pub args: Vec<Argument<'a>>,
|
||||||
pub body: Vec<Statement<ParsedArgParams<'a>>>,
|
pub body: Option<Vec<Statement<ParsedArgParams<'a>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
@ -68,6 +74,7 @@ pub struct Argument<'a> {
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
Scalar(ScalarType),
|
Scalar(ScalarType),
|
||||||
ExtendedScalar(ExtendedScalarType),
|
ExtendedScalar(ExtendedScalarType),
|
||||||
|
Array(ScalarType, u32),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<FloatType> for Type {
|
impl From<FloatType> for Type {
|
||||||
|
@ -173,10 +180,12 @@ pub enum Statement<P: ArgParams> {
|
||||||
Label(P::ID),
|
Label(P::ID),
|
||||||
Variable(Variable<P>),
|
Variable(Variable<P>),
|
||||||
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
|
Instruction(Option<PredAt<P::ID>>, Instruction<P>),
|
||||||
|
Block(Vec<Statement<P>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Variable<P: ArgParams> {
|
pub struct Variable<P: ArgParams> {
|
||||||
pub space: StateSpace,
|
pub space: StateSpace,
|
||||||
|
pub align: Option<u32>,
|
||||||
pub v_type: Type,
|
pub v_type: Type,
|
||||||
pub name: P::ID,
|
pub name: P::ID,
|
||||||
pub count: Option<u32>,
|
pub count: Option<u32>,
|
||||||
|
@ -190,6 +199,7 @@ pub enum StateSpace {
|
||||||
Global,
|
Global,
|
||||||
Local,
|
Local,
|
||||||
Shared,
|
Shared,
|
||||||
|
Param,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PredAt<ID> {
|
pub struct PredAt<ID> {
|
||||||
|
@ -211,6 +221,23 @@ pub enum Instruction<P: ArgParams> {
|
||||||
Shl(ShlType, Arg3<P>),
|
Shl(ShlType, Arg3<P>),
|
||||||
St(StData, Arg2St<P>),
|
St(StData, Arg2St<P>),
|
||||||
Ret(RetData),
|
Ret(RetData),
|
||||||
|
Call(CallData, ArgCall<P>),
|
||||||
|
Abs(AbsDetails, Arg2<P>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CallData {
|
||||||
|
pub uniform: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AbsDetails {
|
||||||
|
pub flush_to_zero: bool,
|
||||||
|
pub typ: ScalarType
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ArgCall<P: ArgParams> {
|
||||||
|
pub ret_params: Vec<P::ID>,
|
||||||
|
pub func: P::ID,
|
||||||
|
pub param_list: Vec<P::ID>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ArgParams {
|
pub trait ArgParams {
|
||||||
|
|
|
@ -27,8 +27,11 @@ pub mod ast;
|
||||||
mod test;
|
mod test;
|
||||||
mod translate;
|
mod translate;
|
||||||
|
|
||||||
pub use ast::Module;
|
pub use lalrpop_util::ParseError as ParseError;
|
||||||
pub use translate::to_spirv;
|
pub use lalrpop_util::lexer::Token as Token;
|
||||||
|
pub use crate::ptx::ModuleParser as ModuleParser;
|
||||||
|
pub use translate::to_spirv as to_spirv;
|
||||||
|
pub use rspirv::dr::Error as SpirvError;
|
||||||
|
|
||||||
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
|
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
|
||||||
x.into_iter().filter_map(|x| x).collect()
|
x.into_iter().filter_map(|x| x).collect()
|
||||||
|
|
|
@ -24,6 +24,7 @@ match {
|
||||||
"|",
|
"|",
|
||||||
".acquire",
|
".acquire",
|
||||||
".address_size",
|
".address_size",
|
||||||
|
".align",
|
||||||
".and",
|
".and",
|
||||||
".b16",
|
".b16",
|
||||||
".b32",
|
".b32",
|
||||||
|
@ -108,8 +109,10 @@ match {
|
||||||
".xor",
|
".xor",
|
||||||
} else {
|
} else {
|
||||||
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
|
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
|
||||||
|
"abs",
|
||||||
"add",
|
"add",
|
||||||
"bra",
|
"bra",
|
||||||
|
"call",
|
||||||
"cvt",
|
"cvt",
|
||||||
"cvta",
|
"cvta",
|
||||||
"debug",
|
"debug",
|
||||||
|
@ -135,8 +138,10 @@ match {
|
||||||
}
|
}
|
||||||
|
|
||||||
ExtendedID : &'input str = {
|
ExtendedID : &'input str = {
|
||||||
|
"abs",
|
||||||
"add",
|
"add",
|
||||||
"bra",
|
"bra",
|
||||||
|
"call",
|
||||||
"cvt",
|
"cvt",
|
||||||
"cvta",
|
"cvta",
|
||||||
"debug",
|
"debug",
|
||||||
|
@ -197,9 +202,9 @@ AddressSize = {
|
||||||
|
|
||||||
Function: ast::Function<'input> = {
|
Function: ast::Function<'input> = {
|
||||||
LinkingDirective*
|
LinkingDirective*
|
||||||
<kernel:IsKernel>
|
<func_directive:FunctionReturn>
|
||||||
<name:ExtendedID>
|
<name:ExtendedID>
|
||||||
"(" <args:Comma<FunctionInput>> ")"
|
<args:Arguments>
|
||||||
<body:FunctionBody> => ast::Function{<>}
|
<body:FunctionBody> => ast::Function{<>}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -209,11 +214,15 @@ LinkingDirective = {
|
||||||
".weak"
|
".weak"
|
||||||
};
|
};
|
||||||
|
|
||||||
IsKernel: bool = {
|
FunctionReturn: ast::FunctionReturn<'input> = {
|
||||||
".entry" => true,
|
".entry" => ast::FunctionReturn::Kernel,
|
||||||
".func" => false
|
".func" <args:Arguments?> => ast::FunctionReturn::Func(args.unwrap_or_else(|| Vec::new()))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Arguments: Vec<ast::Argument<'input>> = {
|
||||||
|
"(" <args:Comma<FunctionInput>> ")" => args
|
||||||
|
}
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
|
||||||
FunctionInput: ast::Argument<'input> = {
|
FunctionInput: ast::Argument<'input> = {
|
||||||
".param" <_type:ScalarType> <name:ExtendedID> => {
|
".param" <_type:ScalarType> <name:ExtendedID> => {
|
||||||
|
@ -226,8 +235,9 @@ FunctionInput: ast::Argument<'input> = {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) FunctionBody: Vec<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>> = {
|
||||||
"{" <s:Statement*> "}" => { without_none(s) }
|
"{" <s:Statement*> "}" => { Some(without_none(s)) },
|
||||||
|
";" => { None }
|
||||||
};
|
};
|
||||||
|
|
||||||
StateSpaceSpecifier: ast::StateSpace = {
|
StateSpaceSpecifier: ast::StateSpace = {
|
||||||
|
@ -236,7 +246,8 @@ StateSpaceSpecifier: ast::StateSpace = {
|
||||||
".const" => ast::StateSpace::Const,
|
".const" => ast::StateSpace::Const,
|
||||||
".global" => ast::StateSpace::Global,
|
".global" => ast::StateSpace::Global,
|
||||||
".local" => ast::StateSpace::Local,
|
".local" => ast::StateSpace::Local,
|
||||||
".shared" => ast::StateSpace::Shared
|
".shared" => ast::StateSpace::Shared,
|
||||||
|
".param" => ast::StateSpace::Param, // used to prepare function call
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -276,7 +287,8 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
|
||||||
<l:Label> => Some(ast::Statement::Label(l)),
|
<l:Label> => Some(ast::Statement::Label(l)),
|
||||||
DebugDirective => None,
|
DebugDirective => None,
|
||||||
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
|
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
|
||||||
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i))
|
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
|
||||||
|
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
|
||||||
};
|
};
|
||||||
|
|
||||||
DebugDirective: () = {
|
DebugDirective: () = {
|
||||||
|
@ -292,10 +304,32 @@ Label: &'input str = {
|
||||||
<id:ExtendedID> ":" => id
|
<id:ExtendedID> ":" => id
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Align: u32 = {
|
||||||
|
".align" <a:Num> => {
|
||||||
|
let align = a.parse::<u32>();
|
||||||
|
align.unwrap_with(errors)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
|
Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
|
||||||
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
|
<s:StateSpaceSpecifier> <a:Align?> <t:Type> <v:VariableName> <arr: ArraySpecifier?> => {
|
||||||
let (name, count) = v;
|
let (name, count) = v;
|
||||||
ast::Variable { space: s, v_type: t, name: name, count: count }
|
let t = match (t, arr) {
|
||||||
|
(ast::Type::Scalar(st), Some(arr_size)) => ast::Type::Array(st, arr_size),
|
||||||
|
(t, Some(_)) => {
|
||||||
|
errors.push(ast::PtxError::WrongArrayType);
|
||||||
|
t
|
||||||
|
},
|
||||||
|
(t, None) => t,
|
||||||
|
};
|
||||||
|
ast::Variable { space: s, align: a, v_type: t, name: name, count: count }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ArraySpecifier: u32 = {
|
||||||
|
"[" <n:Num> "]" => {
|
||||||
|
let size = n.parse::<u32>();
|
||||||
|
size.unwrap_with(errors)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -326,6 +360,8 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
InstSt,
|
InstSt,
|
||||||
InstRet,
|
InstRet,
|
||||||
InstCvta,
|
InstCvta,
|
||||||
|
InstCall,
|
||||||
|
InstAbs,
|
||||||
};
|
};
|
||||||
|
|
||||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||||
|
@ -819,6 +855,36 @@ CvtaSize: ast::CvtaSize = {
|
||||||
".u64" => ast::CvtaSize::U64,
|
".u64" => ast::CvtaSize::U64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
|
||||||
|
InstCall: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
|
"call" <u:".uni"?> <a:ArgCall> => ast::Instruction::Call(ast::CallData { uniform: u.is_some() }, a)
|
||||||
|
};
|
||||||
|
|
||||||
|
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
|
||||||
|
InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
|
||||||
|
"abs" <t:SignedIntType> <a:Arg2> => {
|
||||||
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: t }, a)
|
||||||
|
},
|
||||||
|
"abs" <f:".ftz"?> ".f32" <a:Arg2> => {
|
||||||
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F32 }, a)
|
||||||
|
},
|
||||||
|
"abs" ".f64" <a:Arg2> => {
|
||||||
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: false, typ: ast::ScalarType::F64 }, a)
|
||||||
|
},
|
||||||
|
"abs" <f:".ftz"?> ".f16" <a:Arg2> => {
|
||||||
|
ast::Instruction::Abs(ast::AbsDetails { flush_to_zero: f.is_some(), typ: ast::ScalarType::F16 }, a)
|
||||||
|
},
|
||||||
|
"abs" <f:".ftz"?> ".f16x2" <a:Arg2> => {
|
||||||
|
todo!()
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
SignedIntType: ast::ScalarType = {
|
||||||
|
".s16" => ast::ScalarType::S16,
|
||||||
|
".s32" => ast::ScalarType::S32,
|
||||||
|
".s64" => ast::ScalarType::S64,
|
||||||
|
};
|
||||||
|
|
||||||
Operand: ast::Operand<&'input str> = {
|
Operand: ast::Operand<&'input str> = {
|
||||||
<r:ExtendedID> => ast::Operand::Reg(r),
|
<r:ExtendedID> => ast::Operand::Reg(r),
|
||||||
<r:ExtendedID> "+" <o:Num> => {
|
<r:ExtendedID> "+" <o:Num> => {
|
||||||
|
@ -873,6 +939,12 @@ Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
|
||||||
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
|
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
ArgCall: ast::ArgCall<ast::ParsedArgParams<'input>> = {
|
||||||
|
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{<>},
|
||||||
|
<func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{ret_params: Vec::new(), func, param_list},
|
||||||
|
<func:ExtendedID> => ast::ArgCall{ret_params: Vec::new(), func, param_list: Vec::new()},
|
||||||
|
};
|
||||||
|
|
||||||
OptionalDst: &'input str = {
|
OptionalDst: &'input str = {
|
||||||
"|" <dst2:ExtendedID> => dst2
|
"|" <dst2:ExtendedID> => dst2
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,3 +25,11 @@ fn operands_ptx() {
|
||||||
let vector_add = include_str!("operands.ptx");
|
let vector_add = include_str!("operands.ptx");
|
||||||
parse_and_assert(vector_add);
|
parse_and_assert(vector_add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn _Z9vectorAddPKfS0_Pfi_ptx() {
|
||||||
|
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
|
||||||
|
parse_and_assert(vector_add);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ impl SpirvType {
|
||||||
let key = match t {
|
let key = match t {
|
||||||
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
|
ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
|
||||||
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
|
ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
|
||||||
|
ast::Type::Array(_, _) => todo!(),
|
||||||
};
|
};
|
||||||
SpirvType::Pointer(key, sc)
|
SpirvType::Pointer(key, sc)
|
||||||
}
|
}
|
||||||
|
@ -26,6 +27,7 @@ impl From<ast::Type> for SpirvType {
|
||||||
match t {
|
match t {
|
||||||
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
|
||||||
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
|
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
|
||||||
|
ast::Type::Array(_, _) => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,10 +197,13 @@ fn emit_function<'a>(
|
||||||
let func_type = get_function_type(builder, map, &f.args);
|
let func_type = get_function_type(builder, map, &f.args);
|
||||||
let func_id =
|
let func_id =
|
||||||
builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
|
builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
|
||||||
if f.kernel {
|
match f.func_directive {
|
||||||
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
|
ast::FunctionReturn::Kernel => {
|
||||||
|
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[])
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
}
|
}
|
||||||
let (mut func_body, unique_ids) = to_ssa(&f.args, f.body);
|
let (mut func_body, unique_ids) = to_ssa(&f.args, f.body.unwrap_or_else(|| todo!()));
|
||||||
let id_offset = builder.reserve_ids(unique_ids);
|
let id_offset = builder.reserve_ids(unique_ids);
|
||||||
emit_function_args(builder, id_offset, map, &f.args);
|
emit_function_args(builder, id_offset, map, &f.args);
|
||||||
func_body = apply_id_offset(func_body, id_offset);
|
func_body = apply_id_offset(func_body, id_offset);
|
||||||
|
@ -266,6 +271,7 @@ fn normalize_predicates(
|
||||||
let mut result = Vec::with_capacity(func.len());
|
let mut result = Vec::with_capacity(func.len());
|
||||||
for s in func {
|
for s in func {
|
||||||
match s {
|
match s {
|
||||||
|
ast::Statement::Block(_) => todo!(),
|
||||||
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
||||||
ast::Statement::Instruction(pred, inst) => {
|
ast::Statement::Instruction(pred, inst) => {
|
||||||
if let Some(pred) = pred {
|
if let Some(pred) = pred {
|
||||||
|
@ -652,6 +658,8 @@ fn emit_function_body_ops(
|
||||||
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
||||||
}
|
}
|
||||||
Statement::Instruction(inst) => match inst {
|
Statement::Instruction(inst) => match inst {
|
||||||
|
ast::Instruction::Abs(_, _) => todo!(),
|
||||||
|
ast::Instruction::Call(_,_) => todo!(),
|
||||||
// SPIR-V does not support marking jumps as guaranteed-converged
|
// SPIR-V does not support marking jumps as guaranteed-converged
|
||||||
ast::Instruction::Bra(_, arg) => {
|
ast::Instruction::Bra(_, arg) => {
|
||||||
builder.branch(arg.src)?;
|
builder.branch(arg.src)?;
|
||||||
|
@ -1076,6 +1084,7 @@ fn expand_map_variables<'a>(
|
||||||
s: ast::Statement<ast::ParsedArgParams<'a>>,
|
s: ast::Statement<ast::ParsedArgParams<'a>>,
|
||||||
) {
|
) {
|
||||||
match s {
|
match s {
|
||||||
|
ast::Statement::Block(_) => todo!(),
|
||||||
ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))),
|
ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))),
|
||||||
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
|
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
|
||||||
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
|
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
|
||||||
|
@ -1086,6 +1095,7 @@ fn expand_map_variables<'a>(
|
||||||
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
|
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
|
||||||
result.push(ast::Statement::Variable(ast::Variable {
|
result.push(ast::Statement::Variable(ast::Variable {
|
||||||
space: var.space,
|
space: var.space,
|
||||||
|
align: var.align,
|
||||||
v_type: var.v_type,
|
v_type: var.v_type,
|
||||||
name: new_id,
|
name: new_id,
|
||||||
count: None,
|
count: None,
|
||||||
|
@ -1096,6 +1106,7 @@ fn expand_map_variables<'a>(
|
||||||
let new_id = id_defs.add_def(var.name, Some(var.v_type));
|
let new_id = id_defs.add_def(var.name, Some(var.v_type));
|
||||||
result.push(ast::Statement::Variable(ast::Variable {
|
result.push(ast::Statement::Variable(ast::Variable {
|
||||||
space: var.space,
|
space: var.space,
|
||||||
|
align: var.align,
|
||||||
v_type: var.v_type,
|
v_type: var.v_type,
|
||||||
name: new_id,
|
name: new_id,
|
||||||
count: None,
|
count: None,
|
||||||
|
@ -1307,6 +1318,8 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
|
||||||
visitor: &mut V,
|
visitor: &mut V,
|
||||||
) -> ast::Instruction<U> {
|
) -> ast::Instruction<U> {
|
||||||
match self {
|
match self {
|
||||||
|
ast::Instruction::Abs(_, _) => todo!(),
|
||||||
|
ast::Instruction::Call(_, _) => todo!(),
|
||||||
ast::Instruction::Ld(d, a) => {
|
ast::Instruction::Ld(d, a) => {
|
||||||
let inst_type = d.typ;
|
let inst_type = d.typ;
|
||||||
ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type))))
|
ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type))))
|
||||||
|
@ -1432,6 +1445,8 @@ impl ast::Instruction<ExpandedArgParams> {
|
||||||
|
|
||||||
fn jump_target(&self) -> Option<spirv::Word> {
|
fn jump_target(&self) -> Option<spirv::Word> {
|
||||||
match self {
|
match self {
|
||||||
|
ast::Instruction::Abs(_, _) => todo!(),
|
||||||
|
ast::Instruction::Call(_, _) => todo!(),
|
||||||
ast::Instruction::Bra(_, a) => Some(a.src),
|
ast::Instruction::Bra(_, a) => Some(a.src),
|
||||||
ast::Instruction::Ld(_, _)
|
ast::Instruction::Ld(_, _)
|
||||||
| ast::Instruction::Mov(_, _)
|
| ast::Instruction::Mov(_, _)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue