mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Add a simple (and failing) PTX end-to-end test
This commit is contained in:
parent
0c0f0e5a6b
commit
d0aa5ba564
9 changed files with 493 additions and 89 deletions
|
@ -17,3 +17,6 @@ bit-vec = "0.6"
|
|||
[build-dependencies.lalrpop]
|
||||
version = "0.18.1"
|
||||
features = ["lexer"]
|
||||
|
||||
[dev-dependencies]
|
||||
ocl = { version = "0.19", features = ["opencl_version_1_1", "opencl_version_1_2", "opencl_version_2_1"] }
|
||||
|
|
|
@ -189,19 +189,19 @@ pub enum MovOperand<ID> {
|
|||
|
||||
pub enum VectorPrefix {
|
||||
V2,
|
||||
V4
|
||||
V4,
|
||||
}
|
||||
|
||||
pub struct LdData {
|
||||
pub qualifier: LdQualifier,
|
||||
pub qualifier: LdStQualifier,
|
||||
pub state_space: LdStateSpace,
|
||||
pub caching: LdCacheOperator,
|
||||
pub vector: Option<VectorPrefix>,
|
||||
pub typ: ScalarType
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum LdQualifier {
|
||||
pub enum LdStQualifier {
|
||||
Weak,
|
||||
Volatile,
|
||||
Relaxed(LdScope),
|
||||
|
@ -212,7 +212,7 @@ pub enum LdQualifier {
|
|||
pub enum LdScope {
|
||||
Cta,
|
||||
Gpu,
|
||||
Sys
|
||||
Sys,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
|
@ -225,14 +225,13 @@ pub enum LdStateSpace {
|
|||
Shared,
|
||||
}
|
||||
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum LdCacheOperator {
|
||||
Cached,
|
||||
L2Only,
|
||||
Streaming,
|
||||
LastUse,
|
||||
Uncached
|
||||
Uncached,
|
||||
}
|
||||
|
||||
pub struct MovData {}
|
||||
|
@ -248,13 +247,38 @@ pub struct SetpBoolData {}
|
|||
pub struct NotData {}
|
||||
|
||||
pub struct BraData {
|
||||
pub uniform: bool
|
||||
pub uniform: bool,
|
||||
}
|
||||
|
||||
pub struct CvtData {}
|
||||
|
||||
pub struct ShlData {}
|
||||
|
||||
pub struct StData {}
|
||||
pub struct StData {
|
||||
pub qualifier: LdStQualifier,
|
||||
pub state_space: StStateSpace,
|
||||
pub caching: StCacheOperator,
|
||||
pub vector: Option<VectorPrefix>,
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
pub struct RetData {}
|
||||
/// State space qualifier of a PTX `st` instruction.
///
/// The grammar maps the `.global`, `.local`, `.param` and `.shared` suffixes
/// to the corresponding variants; `Generic` is what the parser substitutes
/// when the instruction carries no state space suffix at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StStateSpace {
    /// No explicit state space was written on the instruction.
    Generic,
    /// `.global`
    Global,
    /// `.local`
    Local,
    /// `.param`
    Param,
    /// `.shared`
    Shared,
}
|
||||
|
||||
/// Cache operator of a PTX `st` instruction.
///
/// Each variant corresponds to one suffix accepted by the grammar. The parser
/// uses `Writeback` when the instruction carries no cache operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StCacheOperator {
    /// `.wb`
    Writeback,
    /// `.cg`
    L2Only,
    /// `.cs`
    Streaming,
    /// `.wt`
    Writethrough,
}
|
||||
|
||||
/// Modifiers attached to a PTX `ret` instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetData {
    /// `true` when the instruction was written as `ret.uni`.
    pub uniform: bool,
}
|
||||
|
|
|
@ -4,6 +4,8 @@ extern crate lalrpop_util;
|
|||
extern crate quick_error;
|
||||
|
||||
extern crate bit_vec;
|
||||
#[cfg(test)]
|
||||
extern crate ocl;
|
||||
extern crate rspirv;
|
||||
extern crate spirv_headers as spirv;
|
||||
|
||||
|
|
|
@ -188,10 +188,10 @@ Instruction: ast::Instruction<&'input str> = {
|
|||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
|
||||
InstLd: ast::Instruction<&'input str> = {
|
||||
"ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
|
||||
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
|
||||
ast::Instruction::Ld(
|
||||
ast::LdData {
|
||||
qualifier: q.unwrap_or(ast::LdQualifier::Weak),
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
|
||||
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
|
||||
vector: v,
|
||||
|
@ -202,11 +202,11 @@ InstLd: ast::Instruction<&'input str> = {
|
|||
}
|
||||
};
|
||||
|
||||
LdQualifier: ast::LdQualifier = {
|
||||
".weak" => ast::LdQualifier::Weak,
|
||||
".volatile" => ast::LdQualifier::Volatile,
|
||||
".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s),
|
||||
".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s),
|
||||
LdStQualifier: ast::LdStQualifier = {
|
||||
".weak" => ast::LdStQualifier::Weak,
|
||||
".volatile" => ast::LdStQualifier::Volatile,
|
||||
".relaxed" <s:LdScope> => ast::LdStQualifier::Relaxed(s),
|
||||
".acquire" <s:LdScope> => ast::LdStQualifier::Acquire(s),
|
||||
};
|
||||
|
||||
LdScope: ast::LdScope = {
|
||||
|
@ -379,29 +379,39 @@ ShlType = {
|
|||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
|
||||
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
|
||||
InstSt: ast::Instruction<&'input str> = {
|
||||
"st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => {
|
||||
ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src})
|
||||
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
|
||||
ast::Instruction::St(
|
||||
ast::StData {
|
||||
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
|
||||
state_space: ss.unwrap_or(ast::StStateSpace::Generic),
|
||||
caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
|
||||
vector: v,
|
||||
typ: t
|
||||
},
|
||||
ast::Arg2{dst:dst, src:src}
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
StStateSpace = {
|
||||
".global",
|
||||
".local",
|
||||
".param",
|
||||
".shared",
|
||||
StStateSpace: ast::StStateSpace = {
|
||||
".global" => ast::StStateSpace::Global,
|
||||
".local" => ast::StStateSpace::Local,
|
||||
".param" => ast::StStateSpace::Param,
|
||||
".shared" => ast::StStateSpace::Shared,
|
||||
};
|
||||
|
||||
StCacheOperator = {
|
||||
".wb",
|
||||
".cg",
|
||||
".cs",
|
||||
".wt",
|
||||
StCacheOperator: ast::StCacheOperator = {
|
||||
".wb" => ast::StCacheOperator::Writeback,
|
||||
".cg" => ast::StCacheOperator::L2Only,
|
||||
".cs" => ast::StCacheOperator::Streaming,
|
||||
".wt" => ast::StCacheOperator::Writethrough,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
|
||||
InstRet: ast::Instruction<&'input str> = {
|
||||
"ret" ".uni"? => ast::Instruction::Ret(ast::RetData{})
|
||||
"ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
|
||||
};
|
||||
|
||||
Operand: ast::Operand<&'input str> = {
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use super::ptx;
|
||||
|
||||
mod ops;
|
||||
|
||||
fn parse_and_assert(s: &str) {
|
||||
let mut errors = Vec::new();
|
||||
ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
|
||||
|
|
20
ptx/src/test/ops/ld_st/ld_st.ptx
Normal file
20
ptx/src/test/ops/ld_st/ld_st.ptx
Normal file
|
@ -0,0 +1,20 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry ld_st(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
st.u64 [out_addr], temp;
|
||||
ret;
|
||||
}
|
1
ptx/src/test/ops/ld_st/mod.rs
Normal file
1
ptx/src/test/ops/ld_st/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
test_ptx!(ld_st, [1u64], [1u64]);
|
280
ptx/src/test/ops/mod.rs
Normal file
280
ptx/src/test/ops/mod.rs
Normal file
|
@ -0,0 +1,280 @@
|
|||
use crate::ptx;
|
||||
use crate::translate;
|
||||
use ocl::{Buffer, Context, Device, Kernel, OclPrm, Platform, Program, Queue};
|
||||
use std::error;
|
||||
use std::ffi::{c_void, CString};
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
use std::mem;
|
||||
use std::slice;
|
||||
use std::{ptr, str};
|
||||
|
||||
/// Declares a `#[test]` function called `$fn_name`.
///
/// The generated test embeds the file `<$fn_name>.ptx` that sits next to the
/// invoking module, binds `$input` and `$output` to locals and forwards
/// everything to `test_ptx_assert`, using the function name as the kernel
/// name.
macro_rules! test_ptx {
    ($fn_name:ident, $input:expr, $output:expr) => {
        #[test]
        fn $fn_name() -> Result<(), Box<dyn std::error::Error>> {
            let kernel_name = stringify!($fn_name);
            let ptx_source = include_str!(concat!(stringify!($fn_name), ".ptx"));
            let input_values = $input;
            let mut expected = $output;
            crate::test::ops::test_ptx_assert(
                kernel_name,
                ptx_source,
                &input_values,
                &mut expected,
            )
        }
    };
}
|
||||
|
||||
mod ld_st;
|
||||
|
||||
const CL_DEVICE_IL_VERSION: u32 = 0x105B;
|
||||
|
||||
/// Adapter that turns any `Display + Debug` value into a `std::error::Error`.
///
/// It is used to box errors that implement the formatting traits but cannot
/// be converted into `Box<dyn Error>` directly. The trait bounds live on the
/// impls rather than on the struct, so the type itself can be named and
/// constructed for any `T`.
struct DisplayError<T> {
    /// The wrapped value; both formatting impls delegate to it.
    err: T,
}

impl<T: Display + Debug> Display for DisplayError<T> {
    /// Formats exactly like the wrapped value's `Display` impl.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Display::fmt(&self.err, f)
    }
}

impl<T: Display + Debug> Debug for DisplayError<T> {
    /// Formats exactly like the wrapped value's `Debug` impl (no wrapper name).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.err, f)
    }
}

impl<T: Display + Debug> error::Error for DisplayError<T> {}
|
||||
|
||||
fn test_ptx_assert<'a, T: OclPrm + From<u8>>(
|
||||
name: &str,
|
||||
ptx_text: &'a str,
|
||||
input: &[T],
|
||||
output: &mut [T],
|
||||
) -> Result<(), Box<dyn error::Error + 'a>> {
|
||||
let mut errors = Vec::new();
|
||||
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
|
||||
assert!(errors.len() == 0);
|
||||
let spirv = translate::to_spirv(ast)?;
|
||||
let result = run_spirv(name, &spirv, input, output).map_err(|err| DisplayError { err })?;
|
||||
assert_eq!(&output, &&*result);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_spirv<T: OclPrm + From<u8>>(
|
||||
name: &str,
|
||||
spirv: &[u32],
|
||||
input: &[T],
|
||||
output: &mut [T],
|
||||
) -> ocl::Result<Vec<T>> {
|
||||
let (plat, dev) = get_ocl_platform_device();
|
||||
let ctx = Context::builder().platform(plat).devices(dev).build()?;
|
||||
let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap();
|
||||
let byte_il = unsafe {
|
||||
slice::from_raw_parts::<u8>(
|
||||
spirv.as_ptr() as *const _,
|
||||
spirv.len() * mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let src = CString::new(
|
||||
"
|
||||
__kernel void ld_st(ulong a, ulong b)
|
||||
{
|
||||
__global ulong* a_copy = (__global ulong*)a;
|
||||
__global ulong* b_copy = (__global ulong*)b;
|
||||
*b_copy = *a_copy;
|
||||
}",
|
||||
)
|
||||
.unwrap();
|
||||
//let prog = Program::with_il(byte_il, Some(&[dev]), &empty_cstr, &ctx)?;
|
||||
let prog = Program::with_source(&ctx, &[src], Some(&[dev]), &empty_cstr)?;
|
||||
let queue = Queue::new(&ctx, dev, None)?;
|
||||
let cl_device_mem_alloc_intel = get_cl_device_mem_alloc_intel(&plat)?;
|
||||
let cl_enqueue_memcpy_intel = get_cl_enqueue_memcpy_intel(&plat)?;
|
||||
let cl_enqueue_memset_intel = get_cl_enqueue_memset_intel(&plat)?;
|
||||
let cl_set_kernel_arg_mem_pointer_intel = get_cl_set_kernel_arg_mem_pointer_intel(&plat)?;
|
||||
let mut err_code = 0;
|
||||
let inp_b = cl_device_mem_alloc_intel(
|
||||
ctx.as_ptr(),
|
||||
dev.as_raw(),
|
||||
ptr::null_mut(),
|
||||
input.len() * mem::size_of::<T>(),
|
||||
mem::align_of::<T>() as u32,
|
||||
&mut err_code,
|
||||
);
|
||||
assert_eq!(err_code, 0);
|
||||
let out_b = cl_device_mem_alloc_intel(
|
||||
ctx.as_ptr(),
|
||||
dev.as_raw(),
|
||||
ptr::null_mut(),
|
||||
output.len() * mem::size_of::<T>(),
|
||||
mem::align_of::<T>() as u32,
|
||||
&mut err_code,
|
||||
);
|
||||
assert_eq!(err_code, 0);
|
||||
err_code = cl_enqueue_memcpy_intel(
|
||||
queue.as_ptr(),
|
||||
1,
|
||||
inp_b as *mut _,
|
||||
input.as_ptr() as *const _,
|
||||
input.len() * mem::size_of::<T>(),
|
||||
0,
|
||||
ptr::null(),
|
||||
ptr::null_mut(),
|
||||
);
|
||||
assert_eq!(err_code, 0);
|
||||
err_code = cl_enqueue_memset_intel(
|
||||
queue.as_ptr(),
|
||||
out_b as *mut _,
|
||||
0,
|
||||
input.len() * mem::size_of::<T>(),
|
||||
0,
|
||||
ptr::null(),
|
||||
ptr::null_mut(),
|
||||
);
|
||||
assert_eq!(err_code, 0);
|
||||
let kernel = ocl::core::create_kernel(prog.as_core(), name)?;
|
||||
err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 0, inp_b);
|
||||
assert_eq!(err_code, 0);
|
||||
err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 1, out_b);
|
||||
assert_eq!(err_code, 0);
|
||||
unsafe {
|
||||
ocl::core::enqueue_kernel::<(), ()>(
|
||||
queue.as_core(),
|
||||
&kernel,
|
||||
1,
|
||||
None,
|
||||
&[1, 0, 0],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}?;
|
||||
let mut result: Vec<T> = vec![0u8.into(); output.len()];
|
||||
err_code = cl_enqueue_memcpy_intel(
|
||||
queue.as_ptr(),
|
||||
1,
|
||||
result.as_mut_ptr() as *mut _,
|
||||
inp_b,
|
||||
result.len() * mem::size_of::<T>(),
|
||||
0,
|
||||
ptr::null(),
|
||||
ptr::null_mut(),
|
||||
);
|
||||
assert_eq!(err_code, 0);
|
||||
queue.finish()?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_ocl_platform_device() -> (Platform, Device) {
|
||||
for p in Platform::list() {
|
||||
if p.extensions()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.find(|ext| *ext == "cl_intel_unified_shared_memory_preview")
|
||||
.is_none()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for d in Device::list_all(p).unwrap() {
|
||||
let typ = d.info(ocl::enums::DeviceInfo::Type).unwrap();
|
||||
if let ocl::enums::DeviceInfoResult::Type(typ) = typ {
|
||||
if typ.cpu() == ocl::flags::DeviceType::CPU {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Ok(version) = d.info_raw(CL_DEVICE_IL_VERSION) {
|
||||
let name = str::from_utf8(&version).unwrap();
|
||||
if name.starts_with("SPIR-V") {
|
||||
return (p, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
panic!("No OpenCL device with SPIR-V and USM support found")
|
||||
}
|
||||
|
||||
fn get_cl_device_mem_alloc_intel(
|
||||
p: &Platform,
|
||||
) -> ocl::core::Result<
|
||||
extern "C" fn(
|
||||
ocl::core::ffi::cl_context,
|
||||
ocl::core::ffi::cl_device_id,
|
||||
*const ocl::core::ffi::cl_bitfield,
|
||||
ocl::core::ffi::size_t,
|
||||
ocl::core::ffi::cl_uint,
|
||||
*mut ocl::core::ffi::cl_int,
|
||||
) -> *const c_void,
|
||||
> {
|
||||
let ptr = unsafe {
|
||||
ocl::core::get_extension_function_address_for_platform(
|
||||
p.as_core(),
|
||||
"clDeviceMemAllocINTEL",
|
||||
None,
|
||||
)
|
||||
}?;
|
||||
Ok(unsafe { std::mem::transmute(ptr) })
|
||||
}
|
||||
|
||||
fn get_cl_enqueue_memcpy_intel(
|
||||
p: &Platform,
|
||||
) -> ocl::core::Result<
|
||||
extern "C" fn(
|
||||
ocl::core::ffi::cl_command_queue,
|
||||
ocl::core::ffi::cl_bool,
|
||||
*mut c_void,
|
||||
*const c_void,
|
||||
ocl::core::ffi::size_t,
|
||||
ocl::core::ffi::cl_uint,
|
||||
*const ocl::core::ffi::cl_event,
|
||||
*mut ocl::core::ffi::cl_event,
|
||||
) -> ocl::core::ffi::cl_int,
|
||||
> {
|
||||
let ptr = unsafe {
|
||||
ocl::core::get_extension_function_address_for_platform(
|
||||
p.as_core(),
|
||||
"clEnqueueMemcpyINTEL",
|
||||
None,
|
||||
)
|
||||
}?;
|
||||
Ok(unsafe { std::mem::transmute(ptr) })
|
||||
}
|
||||
|
||||
fn get_cl_enqueue_memset_intel(
|
||||
p: &Platform,
|
||||
) -> ocl::core::Result<
|
||||
extern "C" fn(
|
||||
ocl::core::ffi::cl_command_queue,
|
||||
*mut c_void,
|
||||
ocl::core::ffi::cl_int,
|
||||
ocl::core::ffi::size_t,
|
||||
ocl::core::ffi::cl_uint,
|
||||
*const ocl::core::ffi::cl_event,
|
||||
*mut ocl::core::ffi::cl_event,
|
||||
) -> ocl::core::ffi::cl_int,
|
||||
> {
|
||||
let ptr = unsafe {
|
||||
ocl::core::get_extension_function_address_for_platform(
|
||||
p.as_core(),
|
||||
"clEnqueueMemsetINTEL",
|
||||
None,
|
||||
)
|
||||
}?;
|
||||
Ok(unsafe { std::mem::transmute(ptr) })
|
||||
}
|
||||
|
||||
fn get_cl_set_kernel_arg_mem_pointer_intel(
|
||||
p: &Platform,
|
||||
) -> ocl::core::Result<
|
||||
extern "C" fn(
|
||||
ocl::core::ffi::cl_kernel,
|
||||
ocl::core::ffi::cl_uint,
|
||||
*const c_void,
|
||||
) -> ocl::core::ffi::cl_int,
|
||||
> {
|
||||
let ptr = unsafe {
|
||||
ocl::core::get_extension_function_address_for_platform(
|
||||
p.as_core(),
|
||||
"clSetKernelArgMemPointerINTEL",
|
||||
None,
|
||||
)
|
||||
}?;
|
||||
Ok(unsafe { std::mem::transmute(ptr) })
|
||||
}
|
|
@ -5,6 +5,8 @@ use std::cell::RefCell;
|
|||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::fmt;
|
||||
|
||||
use rspirv::binary::{Assemble, Disassemble};
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
||||
enum SpirvType {
|
||||
Base(ast::ScalarType),
|
||||
|
@ -13,7 +15,6 @@ enum SpirvType {
|
|||
|
||||
struct TypeWordMap {
|
||||
void: spirv::Word,
|
||||
fn_void: spirv::Word,
|
||||
complex: HashMap<SpirvType, spirv::Word>,
|
||||
}
|
||||
|
||||
|
@ -22,7 +23,6 @@ impl TypeWordMap {
|
|||
let void = b.type_void();
|
||||
TypeWordMap {
|
||||
void: void,
|
||||
fn_void: b.type_function(void, vec![]),
|
||||
complex: HashMap::<SpirvType, spirv::Word>::new(),
|
||||
}
|
||||
}
|
||||
|
@ -30,32 +30,24 @@ impl TypeWordMap {
|
|||
fn void(&self) -> spirv::Word {
|
||||
self.void
|
||||
}
|
||||
fn fn_void(&self) -> spirv::Word {
|
||||
self.fn_void
|
||||
}
|
||||
|
||||
fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
|
||||
*self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
|
||||
ast::ScalarType::B8 | ast::ScalarType::U8 => {
|
||||
b.type_int(8, 0)
|
||||
}
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 => {
|
||||
b.type_int(16, 0)
|
||||
}
|
||||
ast::ScalarType::B32 | ast::ScalarType::U32 => {
|
||||
b.type_int(32, 0)
|
||||
}
|
||||
ast::ScalarType::B64 | ast::ScalarType::U64 => {
|
||||
b.type_int(64, 0)
|
||||
}
|
||||
ast::ScalarType::S8 => b.type_int(8, 1),
|
||||
ast::ScalarType::S16 => b.type_int(16, 1),
|
||||
ast::ScalarType::S32 => b.type_int(32, 1),
|
||||
ast::ScalarType::S64 => b.type_int(64, 1),
|
||||
ast::ScalarType::F16 => b.type_float(16),
|
||||
ast::ScalarType::F32 => b.type_float(32),
|
||||
ast::ScalarType::F64 => b.type_float(64),
|
||||
})
|
||||
*self
|
||||
.complex
|
||||
.entry(SpirvType::Base(t))
|
||||
.or_insert_with(|| match t {
|
||||
ast::ScalarType::B8 | ast::ScalarType::U8 => b.type_int(8, 0),
|
||||
ast::ScalarType::B16 | ast::ScalarType::U16 => b.type_int(16, 0),
|
||||
ast::ScalarType::B32 | ast::ScalarType::U32 => b.type_int(32, 0),
|
||||
ast::ScalarType::B64 | ast::ScalarType::U64 => b.type_int(64, 0),
|
||||
ast::ScalarType::S8 => b.type_int(8, 1),
|
||||
ast::ScalarType::S16 => b.type_int(16, 1),
|
||||
ast::ScalarType::S32 => b.type_int(32, 1),
|
||||
ast::ScalarType::S64 => b.type_int(64, 1),
|
||||
ast::ScalarType::F16 => b.type_float(16),
|
||||
ast::ScalarType::F32 => b.type_float(32),
|
||||
ast::ScalarType::F64 => b.type_float(64),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
|
||||
|
@ -63,15 +55,25 @@ impl TypeWordMap {
|
|||
SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
|
||||
SpirvType::Pointer(scalar, storage) => {
|
||||
let base = self.get_or_add_scalar(b, scalar);
|
||||
*self.complex.entry(t).or_insert_with(|| {
|
||||
b.type_pointer(None, storage, base)
|
||||
})
|
||||
*self
|
||||
.complex
|
||||
.entry(t)
|
||||
.or_insert_with(|| b.type_pointer(None, storage, base))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_or_add_fn<Args: Iterator<Item = SpirvType>>(
|
||||
&mut self,
|
||||
b: &mut dr::Builder,
|
||||
args: Args,
|
||||
) -> spirv::Word {
|
||||
let params = args.map(|a| self.get_or_add(b, a)).collect::<Vec<_>>();
|
||||
b.type_function(self.void(), params)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
|
||||
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
|
||||
let mut builder = dr::Builder::new();
|
||||
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
|
||||
builder.set_version(1, 0);
|
||||
|
@ -83,10 +85,12 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
|
|||
for f in ast.functions {
|
||||
emit_function(&mut builder, &mut map, f)?;
|
||||
}
|
||||
Ok(vec![])
|
||||
let module = builder.module();
|
||||
Ok(module.assemble())
|
||||
}
|
||||
|
||||
fn emit_capabilities(builder: &mut dr::Builder) {
|
||||
builder.capability(spirv::Capability::GenericPointer);
|
||||
builder.capability(spirv::Capability::Linkage);
|
||||
builder.capability(spirv::Capability::Addresses);
|
||||
builder.capability(spirv::Capability::Kernel);
|
||||
|
@ -112,12 +116,12 @@ fn emit_function<'a>(
|
|||
map: &mut TypeWordMap,
|
||||
f: ast::Function<'a>,
|
||||
) -> Result<spirv::Word, rspirv::dr::Error> {
|
||||
let func_id = builder.begin_function(
|
||||
map.void(),
|
||||
None,
|
||||
spirv::FunctionControl::NONE,
|
||||
map.fn_void(),
|
||||
)?;
|
||||
let func_type = get_function_type(builder, map, &f.args);
|
||||
let func_id =
|
||||
builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
|
||||
if f.kernel {
|
||||
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
|
||||
}
|
||||
let mut contant_ids = HashMap::new();
|
||||
collect_arg_ids(&mut contant_ids, &f.args);
|
||||
collect_label_ids(&mut contant_ids, &f.body);
|
||||
|
@ -126,7 +130,7 @@ fn emit_function<'a>(
|
|||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
let phis = ssa_legalize(
|
||||
let (mut phis, unique_ids) = ssa_legalize(
|
||||
&mut normalized_ids,
|
||||
contant_ids.len() as u32,
|
||||
unique_ids,
|
||||
|
@ -138,11 +142,17 @@ fn emit_function<'a>(
|
|||
emit_function_args(builder, id_offset, map, &f.args);
|
||||
emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
|
||||
builder.end_function()?;
|
||||
builder.ret()?;
|
||||
builder.end_function()?;
|
||||
Ok(func_id)
|
||||
}
|
||||
|
||||
fn get_function_type(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
args: &[ast::Argument],
|
||||
) -> spirv::Word {
|
||||
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::Base(arg.a_type)))
|
||||
}
|
||||
|
||||
fn emit_function_args(
|
||||
builder: &mut dr::Builder,
|
||||
id_offset: spirv::Word,
|
||||
|
@ -151,7 +161,7 @@ fn emit_function_args(
|
|||
) {
|
||||
let mut id = id_offset;
|
||||
for arg in args {
|
||||
let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
|
||||
let result_type = map.get_or_add_scalar(builder, arg.a_type);
|
||||
let inst = dr::Instruction::new(
|
||||
spirv::Op::FunctionParameter,
|
||||
Some(result_type),
|
||||
|
@ -195,6 +205,8 @@ fn emit_function_body_ops(
|
|||
func: &[Statement],
|
||||
cfg: &[BasicBlock],
|
||||
) -> Result<(), dr::Error> {
|
||||
// TODO: entry basic block can't be target of jumps,
|
||||
// we need to emit additional BB for this purpose
|
||||
for bb_idx in 0..cfg.len() {
|
||||
let body = get_bb_body(func, cfg, BBIndex(bb_idx));
|
||||
if body.len() == 0 {
|
||||
|
@ -215,24 +227,63 @@ fn emit_function_body_ops(
|
|||
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
||||
}
|
||||
Statement::Instruction(inst) => match inst {
|
||||
// Sadly, SPIR-V does not support marking jumps as guaranteed-converged
|
||||
// SPIR-V does not support marking jumps as guaranteed-converged
|
||||
ast::Instruction::Bra(_, arg) => {
|
||||
builder.branch(arg.src)?;
|
||||
builder.branch(arg.src + id_offset)?;
|
||||
}
|
||||
ast::Instruction::Ld(data, arg) => {
|
||||
if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
|
||||
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
|
||||
todo!()
|
||||
}
|
||||
let storage_class = match data.state_space {
|
||||
ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
|
||||
ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
|
||||
let src = match arg.src {
|
||||
ast::Operand::Reg(id) => id + id_offset,
|
||||
_ => todo!(),
|
||||
};
|
||||
let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
|
||||
let pointer_type =
|
||||
map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
|
||||
builder.load(result_type, None, pointer_type, None, [])?;
|
||||
let result_type = map.get_or_add_scalar(builder, data.typ);
|
||||
match data.state_space {
|
||||
ast::LdStateSpace::Generic => {
|
||||
// TODO: make the cast optional
|
||||
let ptr_result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
|
||||
);
|
||||
let bitcast = builder.convert_u_to_ptr(ptr_result_type, None, src - 5)?;
|
||||
builder.load(
|
||||
result_type,
|
||||
Some(arg.dst + id_offset),
|
||||
bitcast,
|
||||
None,
|
||||
[],
|
||||
)?;
|
||||
}
|
||||
ast::LdStateSpace::Param => {
|
||||
//builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
ast::Instruction::St(data, arg) => {
|
||||
if data.qualifier != ast::LdStQualifier::Weak
|
||||
|| data.vector.is_some()
|
||||
|| data.state_space != ast::StStateSpace::Generic
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
let src = match arg.src {
|
||||
ast::Operand::Reg(id) => id + id_offset,
|
||||
_ => todo!(),
|
||||
};
|
||||
// TODO make cast optional
|
||||
let ptr_result_type = map.get_or_add(
|
||||
builder,
|
||||
SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
|
||||
);
|
||||
let bitcast =
|
||||
builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset - 5)?;
|
||||
builder.store(bitcast, src, None, &[])?;
|
||||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
ast::Instruction::Ret(_) => builder.ret()?,
|
||||
_ => todo!(),
|
||||
},
|
||||
}
|
||||
|
@ -279,7 +330,7 @@ fn ssa_legalize(
|
|||
bbs: &[BasicBlock],
|
||||
doms: &[BBIndex],
|
||||
dom_fronts: &[HashSet<BBIndex>],
|
||||
) -> Vec<Vec<PhiDef>> {
|
||||
) -> (Vec<Vec<PhiDef>>, spirv::Word) {
|
||||
let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
|
||||
apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
|
||||
}
|
||||
|
@ -301,7 +352,7 @@ fn apply_ssa_renaming(
|
|||
constant_ids: spirv::Word,
|
||||
all_ids: spirv::Word,
|
||||
old_phi: &[HashSet<spirv::Word>],
|
||||
) -> Vec<Vec<PhiDef>> {
|
||||
) -> (Vec<Vec<PhiDef>>, spirv::Word) {
|
||||
let mut dom_tree = vec![Vec::new(); bbs.len()];
|
||||
for (bb, idom) in doms.iter().enumerate().skip(1) {
|
||||
dom_tree[idom.0].push(BBIndex(bb));
|
||||
|
@ -345,7 +396,7 @@ fn apply_ssa_renaming(
|
|||
break;
|
||||
}
|
||||
}
|
||||
new_phi
|
||||
let phi = new_phi
|
||||
.into_iter()
|
||||
.map(|map| {
|
||||
map.into_iter()
|
||||
|
@ -355,7 +406,8 @@ fn apply_ssa_renaming(
|
|||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.collect::<Vec<_>>();
|
||||
(phi, ssa_state.next_id())
|
||||
}
|
||||
|
||||
// before ssa-renaming every phi is x <- phi(x,x,x,x)
|
||||
|
@ -479,6 +531,10 @@ impl<'a> SSARewriteState {
|
|||
self.stack[(x - self.constant_ids) as usize].pop();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the current value of the `next` counter.
///
/// NOTE(review): presumably the first id not yet handed out by the SSA
/// renaming — confirm against where `next` is incremented.
fn next_id(&self) -> spirv::Word {
    self.next
}
|
||||
}
|
||||
|
||||
// "Engineering a Compiler" - Figure 9.9
|
||||
|
@ -895,7 +951,10 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Not(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Cvt(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Shl(_, a) => a.visit_id(f),
|
||||
ast::Instruction::St(_, a) => a.visit_id(f),
|
||||
ast::Instruction::St(_, a) => {
|
||||
f(false, &a.dst);
|
||||
a.src.visit_id(f);
|
||||
}
|
||||
ast::Instruction::Bra(_, a) => a.visit_id(f),
|
||||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
|
@ -912,7 +971,10 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Not(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::St(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::St(_, a) => {
|
||||
f(false, &mut a.dst);
|
||||
a.src.visit_id_mut(f);
|
||||
}
|
||||
ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
|
||||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
|
@ -965,7 +1027,7 @@ impl<T: Copy> ast::Instruction<T> {
|
|||
ast::Instruction::Not(_, a) => a.for_dst_id(f),
|
||||
ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
|
||||
ast::Instruction::Shl(_, a) => a.for_dst_id(f),
|
||||
ast::Instruction::St(_, a) => a.for_dst_id(f),
|
||||
ast::Instruction::St(_, _) => (),
|
||||
ast::Instruction::Bra(_, _) => (),
|
||||
ast::Instruction::Ret(_) => (),
|
||||
}
|
||||
|
@ -1736,7 +1798,7 @@ mod tests {
|
|||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
let mut ssa_phis = ssa_legalize(
|
||||
let (mut ssa_phis, _) = ssa_legalize(
|
||||
&mut func,
|
||||
constant_ids.len() as u32,
|
||||
unique_ids,
|
||||
|
|
Loading…
Add table
Reference in a new issue