diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index adbbf2a..fb15d61 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -17,3 +17,6 @@ bit-vec = "0.6"
 [build-dependencies.lalrpop]
 version = "0.18.1"
 features = ["lexer"]
+
+[dev-dependencies]
+ocl = { version = "0.19", features = ["opencl_version_1_1", "opencl_version_1_2", "opencl_version_2_1"] }
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 190c21a..f685b7d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -189,19 +189,19 @@ pub enum MovOperand {
 
 pub enum VectorPrefix {
     V2,
-    V4
+    V4,
 }
 
 pub struct LdData {
-    pub qualifier: LdQualifier,
+    pub qualifier: LdStQualifier,
     pub state_space: LdStateSpace,
     pub caching: LdCacheOperator,
     pub vector: Option<VectorPrefix>,
-    pub typ: ScalarType
+    pub typ: ScalarType,
 }
 
 #[derive(PartialEq, Eq)]
-pub enum LdQualifier {
+pub enum LdStQualifier {
     Weak,
     Volatile,
     Relaxed(LdScope),
@@ -212,7 +212,7 @@ pub enum LdQualifier {
 pub enum LdScope {
     Cta,
     Gpu,
-    Sys
+    Sys,
 }
 
 #[derive(PartialEq, Eq)]
@@ -225,14 +225,13 @@ pub enum LdStateSpace {
     Shared,
 }
 
-
 #[derive(PartialEq, Eq)]
 pub enum LdCacheOperator {
     Cached,
     L2Only,
     Streaming,
     LastUse,
-    Uncached
+    Uncached,
 }
 
 pub struct MovData {}
@@ -248,13 +247,38 @@ pub struct SetpBoolData {}
 pub struct NotData {}
 
 pub struct BraData {
-    pub uniform: bool
+    pub uniform: bool,
 }
 
 pub struct CvtData {}
 
 pub struct ShlData {}
 
-pub struct StData {}
+pub struct StData {
+    pub qualifier: LdStQualifier,
+    pub state_space: StStateSpace,
+    pub caching: StCacheOperator,
+    pub vector: Option<VectorPrefix>,
+    pub typ: ScalarType,
+}
 
-pub struct RetData {}
+#[derive(PartialEq, Eq)]
+pub enum StStateSpace {
+    Generic,
+    Global,
+    Local,
+    Param,
+    Shared,
+}
+
+#[derive(PartialEq, Eq)]
+pub enum StCacheOperator {
+    Writeback,
+    L2Only,
+    Streaming,
+    Writethrough,
+}
+
+pub struct RetData {
+    pub uniform: bool,
+}
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index f8bb7fd..2984c89 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -4,6 +4,8 @@ extern crate lalrpop_util;
 extern crate quick_error;
 extern crate bit_vec;
+#[cfg(test)]
+extern crate ocl;
 extern crate rspirv;
 extern crate spirv_headers as spirv;
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index ded2386..999d511 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -188,10 +188,10 @@ Instruction: ast::Instruction<&'input str> = {
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
 InstLd: ast::Instruction<&'input str> = {
-    "ld" <q:LdQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
+    "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ID> "," "[" <src:Operand> "]" => {
         ast::Instruction::Ld(
             ast::LdData {
-                qualifier: q.unwrap_or(ast::LdQualifier::Weak),
+                qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
                 state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
                 caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
                 vector: v,
@@ -202,11 +202,11 @@ InstLd: ast::Instruction<&'input str> = {
     }
 };
 
-LdQualifier: ast::LdQualifier = {
-    ".weak" => ast::LdQualifier::Weak,
-    ".volatile" => ast::LdQualifier::Volatile,
-    ".relaxed" <s:LdScope> => ast::LdQualifier::Relaxed(s),
-    ".acquire" <s:LdScope> => ast::LdQualifier::Acquire(s),
+LdStQualifier: ast::LdStQualifier = {
+    ".weak" => ast::LdStQualifier::Weak,
+    ".volatile" => ast::LdStQualifier::Volatile,
+    ".relaxed" <s:LdScope> => ast::LdStQualifier::Relaxed(s),
+    ".acquire" <s:LdScope> => ast::LdStQualifier::Acquire(s),
 };
 
 LdScope: ast::LdScope = {
@@ -379,29 +379,39 @@ ShlType = {
 };
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
+// Warning: NVIDIA documentation is incorrect, you can specify scope only once
 InstSt: ast::Instruction<&'input str> = {
-    "st" LdQualifier? StStateSpace? StCacheOperator? VectorPrefix? MemoryType "[" <dst:ID> "]" "," <src:Operand> => {
-        ast::Instruction::St(ast::StData{}, ast::Arg2{dst:dst, src:src})
+    "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
+        ast::Instruction::St(
+            ast::StData {
+                qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
+                state_space: ss.unwrap_or(ast::StStateSpace::Generic),
+                caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
+                vector: v,
+                typ: t
+            },
+            ast::Arg2{dst:dst, src:src}
+        )
     }
 };
 
-StStateSpace = {
-    ".global",
-    ".local",
-    ".param",
-    ".shared",
+StStateSpace: ast::StStateSpace = {
+    ".global" => ast::StStateSpace::Global,
+    ".local" => ast::StStateSpace::Local,
+    ".param" => ast::StStateSpace::Param,
+    ".shared" => ast::StStateSpace::Shared,
 };
 
-StCacheOperator = {
-    ".wb",
-    ".cg",
-    ".cs",
-    ".wt",
+StCacheOperator: ast::StCacheOperator = {
+    ".wb" => ast::StCacheOperator::Writeback,
+    ".cg" => ast::StCacheOperator::L2Only,
+    ".cs" => ast::StCacheOperator::Streaming,
+    ".wt" => ast::StCacheOperator::Writethrough,
 };
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
 InstRet: ast::Instruction<&'input str> = {
-    "ret" ".uni"? => ast::Instruction::Ret(ast::RetData{})
+    "ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
 };
 
 Operand: ast::Operand<&'input str> = {
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index 15876ad..c421a8b 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -1,5 +1,7 @@
 use super::ptx;
 
+mod ops;
+
 fn parse_and_assert(s: &str) {
     let mut errors = Vec::new();
     ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
diff --git a/ptx/src/test/ops/ld_st/ld_st.ptx b/ptx/src/test/ops/ld_st/ld_st.ptx
new file mode 100644
index 0000000..469a219
--- /dev/null
+++ b/ptx/src/test/ops/ld_st/ld_st.ptx
@@ -0,0 +1,20 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry ld_st(
+    .param .u64 input,
+    .param .u64 output
+)
+{
+    .reg .u64 in_addr;
+    .reg .u64 out_addr;
+    .reg .u64 temp;
+
+    ld.param.u64 in_addr, [input];
+    ld.param.u64 out_addr, [output];
+
+    ld.u64 temp, [in_addr];
+    st.u64 [out_addr], temp;
+    ret;
+}
\ No newline at end of file
diff --git a/ptx/src/test/ops/ld_st/mod.rs b/ptx/src/test/ops/ld_st/mod.rs
new file mode 100644
index 0000000..ab89fd4
--- /dev/null
+++ b/ptx/src/test/ops/ld_st/mod.rs
@@ -0,0 +1 @@
+test_ptx!(ld_st, [1u64], [1u64]);
\ No newline at end of file
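
(For illustration, not part of the patch.) The single line above is the entire per-op test: the test_ptx! macro defined in ops/mod.rs below turns it into an ordinary #[test] function. Hand-expanded, with the macro arguments substituted and the elided return type filled in with my assumption of Box<dyn std::error::Error>, it amounts to roughly:

#[test]
fn ld_st() -> Result<(), Box<dyn std::error::Error>> {
    // picks up ld_st.ptx, the kernel added above, from the same directory
    let ptx = include_str!("ld_st.ptx");
    let input = [1u64];
    let mut output = [1u64];
    crate::test::ops::test_ptx_assert("ld_st", ptx, &input, &mut output)
}

The harness parses the PTX, translates it to SPIR-V, runs a kernel of the same name on an OpenCL device (for now the SPIR-V path is commented out in favor of an OpenCL C stand-in), and asserts that the bytes copied back from the device match output.
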
test_ptx { + ($fn_name:ident, $input:expr, $output:expr) => { + #[test] + fn $fn_name() -> Result<(), Box> { + let ptx = include_str!(concat!(stringify!($fn_name), ".ptx")); + let input = $input; + let mut output = $output; + crate::test::ops::test_ptx_assert(stringify!($fn_name), ptx, &input, &mut output) + } + }; +} + +mod ld_st; + +const CL_DEVICE_IL_VERSION: u32 = 0x105B; + +struct DisplayError { + err: T, +} + +impl Display for DisplayError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&self.err, f) + } +} + +impl Debug for DisplayError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Debug::fmt(&self.err, f) + } +} + +impl error::Error for DisplayError {} + +fn test_ptx_assert<'a, T: OclPrm + From>( + name: &str, + ptx_text: &'a str, + input: &[T], + output: &mut [T], +) -> Result<(), Box> { + let mut errors = Vec::new(); + let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; + assert!(errors.len() == 0); + let spirv = translate::to_spirv(ast)?; + let result = run_spirv(name, &spirv, input, output).map_err(|err| DisplayError { err })?; + assert_eq!(&output, &&*result); + Ok(()) +} + +fn run_spirv>( + name: &str, + spirv: &[u32], + input: &[T], + output: &mut [T], +) -> ocl::Result> { + let (plat, dev) = get_ocl_platform_device(); + let ctx = Context::builder().platform(plat).devices(dev).build()?; + let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap(); + let byte_il = unsafe { + slice::from_raw_parts::( + spirv.as_ptr() as *const _, + spirv.len() * mem::size_of::(), + ) + }; + let src = CString::new( + " + __kernel void ld_st(ulong a, ulong b) + { + __global ulong* a_copy = (__global ulong*)a; + __global ulong* b_copy = (__global ulong*)b; + *b_copy = *a_copy; + }", + ) + .unwrap(); + //let prog = Program::with_il(byte_il, Some(&[dev]), &empty_cstr, &ctx)?; + let prog = Program::with_source(&ctx, &[src], Some(&[dev]), &empty_cstr)?; + let queue = Queue::new(&ctx, dev, None)?; + let cl_device_mem_alloc_intel = get_cl_device_mem_alloc_intel(&plat)?; + let cl_enqueue_memcpy_intel = get_cl_enqueue_memcpy_intel(&plat)?; + let cl_enqueue_memset_intel = get_cl_enqueue_memset_intel(&plat)?; + let cl_set_kernel_arg_mem_pointer_intel = get_cl_set_kernel_arg_mem_pointer_intel(&plat)?; + let mut err_code = 0; + let inp_b = cl_device_mem_alloc_intel( + ctx.as_ptr(), + dev.as_raw(), + ptr::null_mut(), + input.len() * mem::size_of::(), + mem::align_of::() as u32, + &mut err_code, + ); + assert_eq!(err_code, 0); + let out_b = cl_device_mem_alloc_intel( + ctx.as_ptr(), + dev.as_raw(), + ptr::null_mut(), + output.len() * mem::size_of::(), + mem::align_of::() as u32, + &mut err_code, + ); + assert_eq!(err_code, 0); + err_code = cl_enqueue_memcpy_intel( + queue.as_ptr(), + 1, + inp_b as *mut _, + input.as_ptr() as *const _, + input.len() * mem::size_of::(), + 0, + ptr::null(), + ptr::null_mut(), + ); + assert_eq!(err_code, 0); + err_code = cl_enqueue_memset_intel( + queue.as_ptr(), + out_b as *mut _, + 0, + input.len() * mem::size_of::(), + 0, + ptr::null(), + ptr::null_mut(), + ); + assert_eq!(err_code, 0); + let kernel = ocl::core::create_kernel(prog.as_core(), name)?; + err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 0, inp_b); + assert_eq!(err_code, 0); + err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 1, out_b); + assert_eq!(err_code, 0); + unsafe { + ocl::core::enqueue_kernel::<(), ()>( + queue.as_core(), + &kernel, + 1, + None, + &[1, 0, 0], + None, + None, + None, + ) + }?; + let 
+    let mut result: Vec<T> = vec![0u8.into(); output.len()];
+    err_code = cl_enqueue_memcpy_intel(
+        queue.as_ptr(),
+        1,
+        result.as_mut_ptr() as *mut _,
+        inp_b,
+        result.len() * mem::size_of::<T>(),
+        0,
+        ptr::null(),
+        ptr::null_mut(),
+    );
+    assert_eq!(err_code, 0);
+    queue.finish()?;
+    Ok(result)
+}
+
+fn get_ocl_platform_device() -> (Platform, Device) {
+    for p in Platform::list() {
+        if p.extensions()
+            .unwrap()
+            .iter()
+            .find(|ext| *ext == "cl_intel_unified_shared_memory_preview")
+            .is_none()
+        {
+            continue;
+        }
+        for d in Device::list_all(p).unwrap() {
+            let typ = d.info(ocl::enums::DeviceInfo::Type).unwrap();
+            if let ocl::enums::DeviceInfoResult::Type(typ) = typ {
+                if typ.cpu() == ocl::flags::DeviceType::CPU {
+                    continue;
+                }
+            }
+            if let Ok(version) = d.info_raw(CL_DEVICE_IL_VERSION) {
+                let name = str::from_utf8(&version).unwrap();
+                if name.starts_with("SPIR-V") {
+                    return (p, d);
+                }
+            }
+        }
+    }
+    panic!("No OpenCL device with SPIR-V and USM support found")
+}
+
+fn get_cl_device_mem_alloc_intel(
+    p: &Platform,
+) -> ocl::core::Result<
+    extern "C" fn(
+        ocl::core::ffi::cl_context,
+        ocl::core::ffi::cl_device_id,
+        *const ocl::core::ffi::cl_bitfield,
+        ocl::core::ffi::size_t,
+        ocl::core::ffi::cl_uint,
+        *mut ocl::core::ffi::cl_int,
+    ) -> *const c_void,
+> {
+    let ptr = unsafe {
+        ocl::core::get_extension_function_address_for_platform(
+            p.as_core(),
+            "clDeviceMemAllocINTEL",
+            None,
+        )
+    }?;
+    Ok(unsafe { std::mem::transmute(ptr) })
+}
+
+fn get_cl_enqueue_memcpy_intel(
+    p: &Platform,
+) -> ocl::core::Result<
+    extern "C" fn(
+        ocl::core::ffi::cl_command_queue,
+        ocl::core::ffi::cl_bool,
+        *mut c_void,
+        *const c_void,
+        ocl::core::ffi::size_t,
+        ocl::core::ffi::cl_uint,
+        *const ocl::core::ffi::cl_event,
+        *mut ocl::core::ffi::cl_event,
+    ) -> ocl::core::ffi::cl_int,
+> {
+    let ptr = unsafe {
+        ocl::core::get_extension_function_address_for_platform(
+            p.as_core(),
+            "clEnqueueMemcpyINTEL",
+            None,
+        )
+    }?;
+    Ok(unsafe { std::mem::transmute(ptr) })
+}
+
+fn get_cl_enqueue_memset_intel(
+    p: &Platform,
+) -> ocl::core::Result<
+    extern "C" fn(
+        ocl::core::ffi::cl_command_queue,
+        *mut c_void,
+        ocl::core::ffi::cl_int,
+        ocl::core::ffi::size_t,
+        ocl::core::ffi::cl_uint,
+        *const ocl::core::ffi::cl_event,
+        *mut ocl::core::ffi::cl_event,
+    ) -> ocl::core::ffi::cl_int,
+> {
+    let ptr = unsafe {
+        ocl::core::get_extension_function_address_for_platform(
+            p.as_core(),
+            "clEnqueueMemsetINTEL",
+            None,
+        )
+    }?;
+    Ok(unsafe { std::mem::transmute(ptr) })
+}
+
+fn get_cl_set_kernel_arg_mem_pointer_intel(
+    p: &Platform,
+) -> ocl::core::Result<
+    extern "C" fn(
+        ocl::core::ffi::cl_kernel,
+        ocl::core::ffi::cl_uint,
+        *const c_void,
+    ) -> ocl::core::ffi::cl_int,
+> {
+    let ptr = unsafe {
+        ocl::core::get_extension_function_address_for_platform(
+            p.as_core(),
+            "clSetKernelArgMemPointerINTEL",
+            None,
+        )
+    }?;
+    Ok(unsafe { std::mem::transmute(ptr) })
+}
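
(Not part of the patch.) The harness above only runs against an OpenCL implementation that advertises the cl_intel_unified_shared_memory_preview extension and reports a SPIR-V IL version; anything else makes get_ocl_platform_device panic. A minimal sketch for checking a local stack before running the tests, reusing the same ocl calls as above (the standalone main and the printed format are mine, not from the change):

use ocl::{Device, Platform};

const CL_DEVICE_IL_VERSION: u32 = 0x105B;

fn main() {
    for p in Platform::list() {
        // same extension probe as get_ocl_platform_device
        let has_usm = p
            .extensions()
            .unwrap()
            .iter()
            .find(|ext| *ext == "cl_intel_unified_shared_memory_preview")
            .is_some();
        for d in Device::list_all(p).unwrap() {
            // CL_DEVICE_IL_VERSION is a raw string, e.g. something like "SPIR-V_1.2"
            let il = d
                .info_raw(CL_DEVICE_IL_VERSION)
                .map(|bytes| String::from_utf8_lossy(&bytes).into_owned())
                .unwrap_or_default();
            println!("USM preview: {}, IL: {:?}", has_usm, il);
        }
    }
}
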
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index f5c5107..90bd87c 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -5,6 +5,8 @@ use std::cell::RefCell;
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt;
 
+use rspirv::binary::{Assemble, Disassemble};
+
 #[derive(PartialEq, Eq, Hash, Clone, Copy)]
 enum SpirvType {
     Base(ast::ScalarType),
@@ -13,7 +15,6 @@ enum SpirvType {
 
 struct TypeWordMap {
     void: spirv::Word,
-    fn_void: spirv::Word,
     complex: HashMap<SpirvType, spirv::Word>,
 }
 
@@ -22,7 +23,6 @@ impl TypeWordMap {
         let void = b.type_void();
         TypeWordMap {
             void: void,
-            fn_void: b.type_function(void, vec![]),
             complex: HashMap::<SpirvType, spirv::Word>::new(),
         }
     }
@@ -30,32 +30,24 @@ impl TypeWordMap {
     fn void(&self) -> spirv::Word {
         self.void
     }
-    fn fn_void(&self) -> spirv::Word {
-        self.fn_void
-    }
 
     fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
-        *self.complex.entry(SpirvType::Base(t)).or_insert_with(|| match t {
-            ast::ScalarType::B8 | ast::ScalarType::U8 => {
-                b.type_int(8, 0)
-            }
-            ast::ScalarType::B16 | ast::ScalarType::U16 => {
-                b.type_int(16, 0)
-            }
-            ast::ScalarType::B32 | ast::ScalarType::U32 => {
-                b.type_int(32, 0)
-            }
-            ast::ScalarType::B64 | ast::ScalarType::U64 => {
-                b.type_int(64, 0)
-            }
-            ast::ScalarType::S8 => b.type_int(8, 1),
-            ast::ScalarType::S16 => b.type_int(16, 1),
-            ast::ScalarType::S32 => b.type_int(32, 1),
-            ast::ScalarType::S64 => b.type_int(64, 1),
-            ast::ScalarType::F16 => b.type_float(16),
-            ast::ScalarType::F32 => b.type_float(32),
-            ast::ScalarType::F64 => b.type_float(64),
-        })
+        *self
+            .complex
+            .entry(SpirvType::Base(t))
+            .or_insert_with(|| match t {
+                ast::ScalarType::B8 | ast::ScalarType::U8 => b.type_int(8, 0),
+                ast::ScalarType::B16 | ast::ScalarType::U16 => b.type_int(16, 0),
+                ast::ScalarType::B32 | ast::ScalarType::U32 => b.type_int(32, 0),
+                ast::ScalarType::B64 | ast::ScalarType::U64 => b.type_int(64, 0),
+                ast::ScalarType::S8 => b.type_int(8, 1),
+                ast::ScalarType::S16 => b.type_int(16, 1),
+                ast::ScalarType::S32 => b.type_int(32, 1),
+                ast::ScalarType::S64 => b.type_int(64, 1),
+                ast::ScalarType::F16 => b.type_float(16),
+                ast::ScalarType::F32 => b.type_float(32),
+                ast::ScalarType::F64 => b.type_float(64),
+            })
     }
 
     fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
@@ -63,15 +55,25 @@ impl TypeWordMap {
             SpirvType::Base(scalar) => self.get_or_add_scalar(b, scalar),
             SpirvType::Pointer(scalar, storage) => {
                 let base = self.get_or_add_scalar(b, scalar);
-                *self.complex.entry(t).or_insert_with(|| {
-                    b.type_pointer(None, storage, base)
-                })
+                *self
+                    .complex
+                    .entry(t)
+                    .or_insert_with(|| b.type_pointer(None, storage, base))
             }
         }
     }
+
+    fn get_or_add_fn<Args: Iterator<Item = SpirvType>>(
+        &mut self,
+        b: &mut dr::Builder,
+        args: Args,
+    ) -> spirv::Word {
+        let params = args.map(|a| self.get_or_add(b, a)).collect::<Vec<spirv::Word>>();
+        b.type_function(self.void(), params)
+    }
 }
 
-pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
+pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
     let mut builder = dr::Builder::new();
     // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
     builder.set_version(1, 0);
@@ -83,10 +85,12 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
     for f in ast.functions {
         emit_function(&mut builder, &mut map, f)?;
     }
-    Ok(vec![])
+    let module = builder.module();
+    Ok(module.assemble())
 }
 
 fn emit_capabilities(builder: &mut dr::Builder) {
+    builder.capability(spirv::Capability::GenericPointer);
     builder.capability(spirv::Capability::Linkage);
     builder.capability(spirv::Capability::Addresses);
     builder.capability(spirv::Capability::Kernel);
@@ -112,12 +116,12 @@ fn emit_function<'a>(
     map: &mut TypeWordMap,
     f: ast::Function<'a>,
 ) -> Result<spirv::Word, dr::Error> {
-    let func_id = builder.begin_function(
-        map.void(),
-        None,
-        spirv::FunctionControl::NONE,
-        map.fn_void(),
-    )?;
+    let func_type = get_function_type(builder, map, &f.args);
+    let func_id =
+        builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
+    if f.kernel {
+        builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
+    }
     let mut contant_ids = HashMap::new();
     collect_arg_ids(&mut contant_ids, &f.args);
     collect_label_ids(&mut contant_ids, &f.body);
@@ -126,7 +130,7 @@ fn emit_function<'a>(
     let rpostorder = to_reverse_postorder(&bbs);
     let doms = immediate_dominators(&bbs, &rpostorder);
     let dom_fronts = dominance_frontiers(&bbs, &doms);
-    let phis = ssa_legalize(
+    let (mut phis, unique_ids) = ssa_legalize(
         &mut normalized_ids,
         contant_ids.len() as u32,
         unique_ids,
@@ -138,11 +142,17 @@ fn emit_function<'a>(
     emit_function_args(builder, id_offset, map, &f.args);
     emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?;
     builder.end_function()?;
-    builder.ret()?;
-    builder.end_function()?;
     Ok(func_id)
 }
 
+fn get_function_type(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    args: &[ast::Argument],
+) -> spirv::Word {
+    map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::Base(arg.a_type)))
+}
+
 fn emit_function_args(
     builder: &mut dr::Builder,
     id_offset: spirv::Word,
@@ -151,7 +161,7 @@ fn emit_function_args(
 ) {
     let mut id = id_offset;
     for arg in args {
-        let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
+        let result_type = map.get_or_add_scalar(builder, arg.a_type);
         let inst = dr::Instruction::new(
             spirv::Op::FunctionParameter,
             Some(result_type),
@@ -195,6 +205,8 @@ fn emit_function_body_ops(
     func: &[Statement],
     cfg: &[BasicBlock],
 ) -> Result<(), dr::Error> {
+    // TODO: entry basic block can't be target of jumps,
+    // we need to emit additional BB for this purpose
     for bb_idx in 0..cfg.len() {
         let body = get_bb_body(func, cfg, BBIndex(bb_idx));
         if body.len() == 0 {
@@ -215,24 +227,63 @@ fn emit_function_body_ops(
                 builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
             }
             Statement::Instruction(inst) => match inst {
-                // Sadly, SPIR-V does not support marking jumps as guaranteed-converged
+                // SPIR-V does not support marking jumps as guaranteed-converged
                 ast::Instruction::Bra(_, arg) => {
-                    builder.branch(arg.src)?;
+                    builder.branch(arg.src + id_offset)?;
                 }
                 ast::Instruction::Ld(data, arg) => {
-                    if data.qualifier != ast::LdQualifier::Weak || data.vector.is_some() {
+                    if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
                         todo!()
                     }
-                    let storage_class = match data.state_space {
-                        ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
-                        ast::LdStateSpace::Param => spirv::StorageClass::CrossWorkgroup,
+                    let src = match arg.src {
+                        ast::Operand::Reg(id) => id + id_offset,
                         _ => todo!(),
                     };
-                    let result_type = map.get_or_add(builder, SpirvType::Base(data.typ));
-                    let pointer_type =
-                        map.get_or_add(builder, SpirvType::Pointer(data.typ, storage_class));
-                    builder.load(result_type, None, pointer_type, None, [])?;
+                    let result_type = map.get_or_add_scalar(builder, data.typ);
+                    match data.state_space {
+                        ast::LdStateSpace::Generic => {
+                            // TODO: make the cast optional
+                            let ptr_result_type = map.get_or_add(
+                                builder,
+                                SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
+                            );
+                            let bitcast = builder.convert_u_to_ptr(ptr_result_type, None, src - 5)?;
+                            builder.load(
+                                result_type,
+                                Some(arg.dst + id_offset),
+                                bitcast,
+                                None,
+                                [],
+                            )?;
+                        }
+                        ast::LdStateSpace::Param => {
+                            //builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
+                        }
+                        _ => todo!(),
+                    }
                 }
+                ast::Instruction::St(data, arg) => {
+                    if data.qualifier != ast::LdStQualifier::Weak
+                        || data.vector.is_some()
+                        || data.state_space != ast::StStateSpace::Generic
+                    {
+                        todo!()
+                    }
+                    let src = match arg.src {
+                        ast::Operand::Reg(id) => id + id_offset,
+                        _ => todo!(),
+                    };
+                    // TODO make cast optional
+                    let ptr_result_type = map.get_or_add(
+                        builder,
+                        SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
+                    );
+                    let bitcast =
+                        builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset - 5)?;
+                    builder.store(bitcast, src, None, &[])?;
+                }
+                // SPIR-V does not support ret as guaranteed-converged
+                ast::Instruction::Ret(_) => builder.ret()?,
                 _ => todo!(),
             },
         }
@@ -279,7 +330,7 @@ fn ssa_legalize(
     bbs: &[BasicBlock],
     doms: &[BBIndex],
     dom_fronts: &[HashSet],
-) -> Vec> {
+) -> (Vec>, spirv::Word) {
     let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
     apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
 }
@@ -301,7 +352,7 @@ fn apply_ssa_renaming(
     constant_ids: spirv::Word,
     all_ids: spirv::Word,
     old_phi: &[HashSet],
-) -> Vec> {
+) -> (Vec>, spirv::Word) {
     let mut dom_tree = vec![Vec::new(); bbs.len()];
     for (bb, idom) in doms.iter().enumerate().skip(1) {
         dom_tree[idom.0].push(BBIndex(bb));
@@ -345,7 +396,7 @@ fn apply_ssa_renaming(
             break;
         }
     }
-    new_phi
+    let phi = new_phi
         .into_iter()
         .map(|map| {
             map.into_iter()
@@ -355,7 +406,8 @@ fn apply_ssa_renaming(
                 })
                 .collect::>()
         })
-        .collect::>()
+        .collect::>();
+    (phi, ssa_state.next_id())
 }
 
 // before ssa-renaming every phi is x <- phi(x,x,x,x)
@@ -479,6 +531,10 @@ impl<'a> SSARewriteState {
             self.stack[(x - self.constant_ids) as usize].pop();
         }
     }
+
+    fn next_id(&self) -> spirv::Word {
+        self.next
+    }
 }
 
 // "Engineering a Compiler" - Figure 9.9
@@ -895,7 +951,10 @@ impl ast::Instruction {
             ast::Instruction::Not(_, a) => a.visit_id(f),
             ast::Instruction::Cvt(_, a) => a.visit_id(f),
             ast::Instruction::Shl(_, a) => a.visit_id(f),
-            ast::Instruction::St(_, a) => a.visit_id(f),
+            ast::Instruction::St(_, a) => {
+                f(false, &a.dst);
+                a.src.visit_id(f);
+            }
             ast::Instruction::Bra(_, a) => a.visit_id(f),
             ast::Instruction::Ret(_) => (),
         }
@@ -912,7 +971,10 @@ impl ast::Instruction {
             ast::Instruction::Not(_, a) => a.visit_id_mut(f),
             ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
             ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
-            ast::Instruction::St(_, a) => a.visit_id_mut(f),
+            ast::Instruction::St(_, a) => {
+                f(false, &mut a.dst);
+                a.src.visit_id_mut(f);
+            }
             ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
             ast::Instruction::Ret(_) => (),
         }
@@ -965,7 +1027,7 @@ impl ast::Instruction {
             ast::Instruction::Not(_, a) => a.for_dst_id(f),
             ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
             ast::Instruction::Shl(_, a) => a.for_dst_id(f),
-            ast::Instruction::St(_, a) => a.for_dst_id(f),
+            ast::Instruction::St(_, _) => (),
             ast::Instruction::Bra(_, _) => (),
             ast::Instruction::Ret(_) => (),
         }
@@ -1736,7 +1798,7 @@ mod tests {
         let rpostorder = to_reverse_postorder(&bbs);
         let doms = immediate_dominators(&bbs, &rpostorder);
        let dom_fronts = dominance_frontiers(&bbs, &doms);
-        let mut ssa_phis = ssa_legalize(
+        let (mut ssa_phis, _) = ssa_legalize(
            &mut func,
            constant_ids.len() as u32,
            unique_ids,
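
(Not part of the patch.) For a generic ld/st pair, the code emitted by emit_function_body_ops above boils down to an OpConvertUToPtr followed by an OpLoad or OpStore through the converted pointer. A self-contained sketch of that pattern using the same rspirv builder calls as in the diff; the helper name and the id parameters are invented for illustration, and the hard-coded "- 5" id adjustment from the real code is left out:

use rspirv::dr;

// Emits SPIR-V equivalent to `ld.u64 temp, [in_addr]; st.u64 [out_addr], temp;`
// given already-registered type ids and the ids holding the two addresses.
fn emit_generic_ld_st(
    b: &mut dr::Builder,
    u64_type: u32,     // id of OpTypeInt 64 0
    u64_ptr_type: u32, // id of OpTypePointer CrossWorkgroup %u64
    in_addr: u32,      // SSA id holding the source address
    out_addr: u32,     // SSA id holding the destination address
) -> Result<(), dr::Error> {
    // ld.u64: reinterpret the integer address as a pointer, then load through it
    let in_ptr = b.convert_u_to_ptr(u64_ptr_type, None, in_addr)?;
    let temp = b.load(u64_type, None, in_ptr, None, [])?;
    // st.u64: same conversion for the destination, then store the loaded value
    let out_ptr = b.convert_u_to_ptr(u64_ptr_type, None, out_addr)?;
    b.store(out_ptr, temp, None, &[])?;
    Ok(())
}
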