diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index b2831a0..8dd612d 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,9 +1,7 @@ use crate::ast; -use bit_vec::BitVec; use rspirv::dr; -use std::cell::RefCell; -use std::collections::{BTreeMap, HashMap, HashSet}; -use std::{borrow::Cow, fmt, iter, mem}; +use std::collections::{HashMap, HashSet}; +use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; @@ -218,7 +216,7 @@ fn normalize_labels( fn normalize_predicates( func: Vec>, id_def: &mut NumericIdResolver, -) -> Vec> { +) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { match s { @@ -258,9 +256,9 @@ fn normalize_predicates( } fn insert_mem_ssa_statements( - func: Vec>, + func: Vec, id_def: &mut NumericIdResolver, -) -> Vec> { +) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { match s { @@ -318,7 +316,7 @@ fn insert_mem_ssa_statements( } fn expand_arguments( - func: Vec>, + func: Vec, id_def: &mut NumericIdResolver, ) -> Vec { let mut result = Vec::with_capacity(func.len()); @@ -608,36 +606,6 @@ fn emit_function_args( } } -fn collect_arg_ids<'a>( - result: &mut HashMap<&'a str, spirv::Word>, - type_check: &mut HashMap, - args: &'a [ast::Argument<'a>], -) { - let mut id = result.len() as u32; - for arg in args { - result.insert(arg.name, id); - type_check.insert(id, ast::Type::Scalar(arg.a_type)); - id += 1; - } -} - -fn collect_label_ids<'a>( - result: &mut HashMap<&'a str, spirv::Word>, - fn_body: &[ast::Statement<&'a str>], -) { - let mut id = result.len() as u32; - for s in fn_body { - match s { - ast::Statement::Label(name) => { - result.insert(name, id); - id += 1; - } - ast::Statement::Instruction(_, _) => (), - ast::Statement::Variable(_) => (), - } - } -} - fn emit_function_body_ops( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -749,7 +717,7 @@ fn emit_function_body_ops( let type_id = map.get_or_add(builder, SpirvType::from(*typ)); builder.load(type_id, Some(arg.dst), arg.src, None, [])?; } - Statement::StoreVar(arg, typ) => { + Statement::StoreVar(arg, _) => { builder.store(arg.src1, arg.src2, None, [])?; } } @@ -893,7 +861,7 @@ fn expand_map_ids<'a>( } ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))), - i.map_id1(&mut |id| id_defs.get_id(id)), + i.map_id(&mut |id| id_defs.get_id(id)), )), ast::Statement::Variable(var) => match var.count { Some(count) => { @@ -945,7 +913,6 @@ impl<'a> StringIdResolver<'a> { self.variables[id] } - #[must_use] fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { let numeric_id = self.current_id; self.variables.insert(Cow::Borrowed(id), numeric_id); @@ -1023,10 +990,6 @@ impl Statement { Statement::Constant(cons) => cons.visit_id_mut(f), } } - - fn get_type(&self) -> Option { - todo!() - } } trait Args { @@ -1125,23 +1088,6 @@ impl Instruction { _ => todo!(), } } - - fn is_terminal(&self) -> bool { - match self { - Instruction::Ret(_) => true, - Instruction::Ld(_, _) - | Instruction::Mov(_, _) - | Instruction::Mul(_, _) - | Instruction::Add(_, _) - | Instruction::Setp(_, _) - | Instruction::SetpBool(_, _) - | Instruction::Not(_, _) - | Instruction::Cvt(_, _) - | Instruction::Shl(_, _) - | Instruction::St(_, _) - | Instruction::Bra(_, _) => false, - } - } } impl Instruction { @@ -1164,23 +1110,6 @@ impl Instruction { } impl Instruction { - fn visit_id(&self, f: &mut F) { - match self { - Instruction::Ld(_, a) => a.visit_id(f), - Instruction::Mov(_, a) => a.visit_id(f), - Instruction::Mul(_, a) => a.visit_id(f), - Instruction::Add(_, a) => a.visit_id(f), - Instruction::Setp(_, a) => a.visit_id(f), - Instruction::SetpBool(_, a) => a.visit_id(f), - Instruction::Not(_, a) => a.visit_id(f), - Instruction::Cvt(_, a) => a.visit_id(f), - Instruction::Shl(_, a) => a.visit_id(f), - Instruction::St(_, a) => a.visit_id(f), - Instruction::Bra(_, a) => a.visit_id(f), - Instruction::Ret(_) => (), - } - } - fn jump_target(&self) -> Option { match self { Instruction::Bra(_, a) => Some(a.src), @@ -1390,7 +1319,7 @@ impl ast::PredAt { } impl ast::Instruction { - fn map_id1 U>(self, f: &mut F) -> ast::Instruction { + fn map_id U>(self, f: &mut F) -> ast::Instruction { match self { ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)), ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)), @@ -1406,107 +1335,6 @@ impl ast::Instruction { ast::Instruction::Ret(d) => ast::Instruction::Ret(d), } } - - fn map_id spirv::Word>(self, f: &mut F) -> Instruction { - match self { - ast::Instruction::Ld(d, a) => Instruction::Ld(d, a.map_id(f)), - ast::Instruction::Mov(d, a) => Instruction::Mov(d, a.map_id(f)), - ast::Instruction::Mul(d, a) => Instruction::Mul(d, a.map_id(f)), - ast::Instruction::Add(d, a) => Instruction::Add(d, a.map_id(f)), - ast::Instruction::Setp(d, a) => Instruction::Setp(d, a.map_id(f)), - ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a.map_id(f)), - ast::Instruction::Not(d, a) => Instruction::Not(d, a.map_id(f)), - ast::Instruction::Bra(d, a) => Instruction::Bra(d, a.map_id(f)), - ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a.map_id(f)), - ast::Instruction::Shl(d, a) => Instruction::Shl(d, a.map_id(f)), - ast::Instruction::St(d, a) => Instruction::St(d, a.map_id(f)), - ast::Instruction::Ret(d) => Instruction::Ret(d), - } - } -} - -impl ast::Instruction { - fn visit_id(&self, f: &mut F) { - match self { - ast::Instruction::Ld(_, a) => Arg::visit_id(a, f), - ast::Instruction::Mov(_, a) => a.visit_id(f), - ast::Instruction::Mul(_, a) => a.visit_id(f), - ast::Instruction::Add(_, a) => a.visit_id(f), - ast::Instruction::Setp(_, a) => a.visit_id(f), - ast::Instruction::SetpBool(_, a) => a.visit_id(f), - ast::Instruction::Not(_, a) => a.visit_id(f), - ast::Instruction::Cvt(_, a) => a.visit_id(f), - ast::Instruction::Shl(_, a) => a.visit_id(f), - ast::Instruction::St(_, a) => a.visit_id(f), - ast::Instruction::Bra(_, a) => a.visit_id(f), - ast::Instruction::Ret(_) => (), - } - } - - fn visit_id_mut(&mut self, f: &mut F) { - match self { - ast::Instruction::Ld(_, a) => a.visit_id_mut(f), - ast::Instruction::Mov(_, a) => a.visit_id_mut(f), - ast::Instruction::Mul(_, a) => a.visit_id_mut(f), - ast::Instruction::Add(_, a) => a.visit_id_mut(f), - ast::Instruction::Setp(_, a) => a.visit_id_mut(f), - ast::Instruction::SetpBool(_, a) => a.visit_id_mut(f), - ast::Instruction::Not(_, a) => a.visit_id_mut(f), - ast::Instruction::Cvt(_, a) => a.visit_id_mut(f), - ast::Instruction::Shl(_, a) => a.visit_id_mut(f), - ast::Instruction::St(_, a) => a.visit_id_mut(f), - ast::Instruction::Bra(_, a) => a.visit_id_mut(f), - ast::Instruction::Ret(_) => (), - } - } - - fn get_type(&self) -> Option { - match self { - ast::Instruction::Add(add, _) => Some(add.get_type()), - ast::Instruction::Ret(_) => None, - ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), - ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), - ast::Instruction::Mov(mov, _) => Some(mov.typ), - ast::Instruction::Mul(mul, _) => Some(mul.get_type()), - _ => todo!(), - } - } -} - -impl ast::Instruction { - fn jump_target(&self) -> Option { - match self { - ast::Instruction::Bra(_, a) => Some(a.src), - ast::Instruction::Ld(_, _) - | ast::Instruction::Mov(_, _) - | ast::Instruction::Mul(_, _) - | ast::Instruction::Add(_, _) - | ast::Instruction::Setp(_, _) - | ast::Instruction::SetpBool(_, _) - | ast::Instruction::Not(_, _) - | ast::Instruction::Cvt(_, _) - | ast::Instruction::Shl(_, _) - | ast::Instruction::St(_, _) - | ast::Instruction::Ret(_) => None, - } - } - - fn is_terminal(&self) -> bool { - match self { - ast::Instruction::Ret(_) => true, - ast::Instruction::Ld(_, _) - | ast::Instruction::Mov(_, _) - | ast::Instruction::Mul(_, _) - | ast::Instruction::Add(_, _) - | ast::Instruction::Setp(_, _) - | ast::Instruction::SetpBool(_, _) - | ast::Instruction::Not(_, _) - | ast::Instruction::Cvt(_, _) - | ast::Instruction::Shl(_, _) - | ast::Instruction::St(_, _) - | ast::Instruction::Bra(_, _) => false, - } - } } impl ast::Arg1 { @@ -2146,7 +1974,6 @@ fn insert_implicit_bitcasts( mod tests { use super::*; use crate::ast; - use crate::ptx; static SCALAR_TYPES: [ast::ScalarType; 15] = [ ast::ScalarType::B8,