diff --git a/Cargo.toml b/Cargo.toml index 5106d33..67ee0a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "level_zero-sys", + "level_zero", "notcuda", "notcuda_inject", "notcuda_redirect", diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 7f94c1b..4a61f4e 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -5,22 +5,26 @@ extern crate quick_error; extern crate bit_vec; #[cfg(test)] -extern crate level_zero_sys as l0; -#[cfg(test)] extern crate level_zero as ze; +#[cfg(test)] +extern crate level_zero_sys as l0; extern crate rspirv; extern crate spirv_headers as spirv; -lalrpop_mod!(ptx); +lalrpop_mod!( + #[allow(dead_code)] + #[allow(unused_imports)] + ptx +); +pub mod ast; #[cfg(test)] mod test; mod translate; -pub mod ast; -pub use ast::Module as Module; -pub use translate::to_spirv as to_spirv; +pub use ast::Module; +pub use translate::to_spirv; pub(crate) fn without_none(x: Vec>) -> Vec { x.into_iter().filter_map(|x| x).collect() -} \ No newline at end of file +} diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index cc2b890..63d7f7b 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt; -use rspirv::binary::{Assemble, Disassemble}; +use rspirv::binary::Assemble; #[derive(PartialEq, Eq, Hash, Clone, Copy)] enum SpirvType { @@ -130,7 +130,7 @@ fn emit_function<'a>( let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); - let (mut phis, unique_ids) = ssa_legalize( + let (_, unique_ids) = ssa_legalize( &mut normalized_ids, contant_ids.len() as u32, unique_ids, @@ -245,9 +245,13 @@ fn emit_function_body_ops( // TODO: make the cast optional let ptr_result_type = map.get_or_add( builder, - SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup), + SpirvType::Pointer( + data.typ, + spirv::StorageClass::CrossWorkgroup, + ), ); - let bitcast = builder.convert_u_to_ptr(ptr_result_type, None, src)?; + let bitcast = + builder.convert_u_to_ptr(ptr_result_type, None, src)?; builder.load( result_type, Some(arg.dst + id_offset), @@ -360,7 +364,11 @@ fn apply_ssa_renaming( let mut old_dst_id = vec![Vec::new(); bbs.len()]; for bb in 0..bbs.len() { for s in get_bb_body(func, bbs, BBIndex(bb)) { - s.for_dst_id(&mut |id| old_dst_id[bb].push(id)); + s.visit_id(&mut |is_dst, id| { + if is_dst { + old_dst_id[bb].push(*id) + } + }); } } let mut new_phi = old_phi @@ -872,16 +880,6 @@ impl Statement { } } - fn for_dst_id(&self, f: &mut F) { - match self { - Statement::Label(id) => f(*id), - Statement::Instruction(inst) => { - inst.for_dst_id(f); - } - Statement::Conditional(_) => (), - } - } - fn visit_id(&self, f: &mut F) { match self { Statement::Label(id) => f(false, id), @@ -901,27 +899,6 @@ impl Statement { } } -impl ast::PredAt { - fn map_id U>(self, f: &mut F) -> ast::PredAt { - ast::PredAt { - not: self.not, - label: f(self.label), - } - } - - fn visit_id(&self, f: &mut F) { - f(false, &self.label) - } - - fn visit_id_mut(&mut self, f: &mut F) { - f(false, &mut self.label) - } -} - -impl ast::PredAt { - fn for_dst_id(&self, _: &mut F) {} -} - impl ast::Instruction { fn map_id U>(self, f: &mut F) -> ast::Instruction { match self { @@ -1015,23 +992,6 @@ impl ast::Instruction { | ast::Instruction::Bra(_, _) => false, } } - - fn for_dst_id(&self, f: &mut F) { - match self { - ast::Instruction::Ld(_, a) => a.for_dst_id(f), - ast::Instruction::Mov(_, a) => a.for_dst_id(f), - ast::Instruction::Mul(_, a) => a.for_dst_id(f), - ast::Instruction::Add(_, a) => a.for_dst_id(f), - ast::Instruction::Setp(_, a) => a.for_dst_id(f), - ast::Instruction::SetpBool(_, a) => a.for_dst_id(f), - ast::Instruction::Not(_, a) => a.for_dst_id(f), - ast::Instruction::Cvt(_, a) => a.for_dst_id(f), - ast::Instruction::Shl(_, a) => a.for_dst_id(f), - ast::Instruction::St(_, _) => (), - ast::Instruction::Bra(_, _) => (), - ast::Instruction::Ret(_) => (), - } - } } impl ast::Arg1 { @@ -1067,12 +1027,6 @@ impl ast::Arg2 { } } -impl ast::Arg2 { - fn for_dst_id(&self, f: &mut F) { - f(self.dst); - } -} - impl ast::Arg2Mov { fn map_id U>(self, f: &mut F) -> ast::Arg2Mov { ast::Arg2Mov { @@ -1092,12 +1046,6 @@ impl ast::Arg2Mov { } } -impl ast::Arg2Mov { - fn for_dst_id(&self, f: &mut F) { - f(self.dst); - } -} - impl ast::Arg3 { fn map_id U>(self, f: &mut F) -> ast::Arg3 { ast::Arg3 { @@ -1120,12 +1068,6 @@ impl ast::Arg3 { } } -impl ast::Arg3 { - fn for_dst_id(&self, f: &mut F) { - f(self.dst); - } -} - impl ast::Arg4 { fn map_id U>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { @@ -1151,13 +1093,6 @@ impl ast::Arg4 { } } -impl ast::Arg4 { - fn for_dst_id(&self, f: &mut F) { - f(self.dst1); - self.dst2.map(|t| f(t)); - } -} - impl ast::Arg5 { fn map_id U>(self, f: &mut F) -> ast::Arg5 { ast::Arg5 { @@ -1186,13 +1121,6 @@ impl ast::Arg5 { } } -impl ast::Arg5 { - fn for_dst_id(&self, f: &mut F) { - f(self.dst1); - self.dst2.map(|t| f(t)); - } -} - impl ast::Operand { fn map_id U>(self, f: &mut F) -> ast::Operand { match self { @@ -1879,8 +1807,10 @@ mod tests { fn assert_dst_unique(func: &[Statement], phis: &[Vec]) { let mut seen = HashSet::new(); for s in func { - s.for_dst_id(&mut |id| { - assert!(seen.insert(id)); + s.visit_id(&mut |is_dst, id| { + if is_dst { + assert!(seen.insert(*id)); + } }); } for phi_set in phis {