Minor cleanup

This commit is contained in:
Andrzej Janik 2020-07-27 00:33:57 +02:00
parent ec7ab8e5c4
commit 04820fba2f

View file

@ -1,9 +1,7 @@
use crate::ast;
use bit_vec::BitVec;
use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::{borrow::Cow, fmt, iter, mem};
use std::collections::{HashMap, HashSet};
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@ -218,7 +216,7 @@ fn normalize_labels(
fn normalize_predicates(
func: Vec<ast::Statement<spirv::Word>>,
id_def: &mut NumericIdResolver,
) -> Vec<Statement<NormalizedArgs>> {
) -> Vec<NormalizedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
@ -258,9 +256,9 @@ fn normalize_predicates(
}
fn insert_mem_ssa_statements(
func: Vec<Statement<NormalizedArgs>>,
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<Statement<NormalizedArgs>> {
) -> Vec<NormalizedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
@ -318,7 +316,7 @@ fn insert_mem_ssa_statements(
}
fn expand_arguments(
func: Vec<Statement<NormalizedArgs>>,
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(func.len());
@ -608,36 +606,6 @@ fn emit_function_args(
}
}
fn collect_arg_ids<'a>(
result: &mut HashMap<&'a str, spirv::Word>,
type_check: &mut HashMap<spirv::Word, ast::Type>,
args: &'a [ast::Argument<'a>],
) {
let mut id = result.len() as u32;
for arg in args {
result.insert(arg.name, id);
type_check.insert(id, ast::Type::Scalar(arg.a_type));
id += 1;
}
}
fn collect_label_ids<'a>(
result: &mut HashMap<&'a str, spirv::Word>,
fn_body: &[ast::Statement<&'a str>],
) {
let mut id = result.len() as u32;
for s in fn_body {
match s {
ast::Statement::Label(name) => {
result.insert(name, id);
id += 1;
}
ast::Statement::Instruction(_, _) => (),
ast::Statement::Variable(_) => (),
}
}
}
fn emit_function_body_ops(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -749,7 +717,7 @@ fn emit_function_body_ops(
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
}
Statement::StoreVar(arg, typ) => {
Statement::StoreVar(arg, _) => {
builder.store(arg.src1, arg.src2, None, [])?;
}
}
@ -893,7 +861,7 @@ fn expand_map_ids<'a>(
}
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))),
i.map_id1(&mut |id| id_defs.get_id(id)),
i.map_id(&mut |id| id_defs.get_id(id)),
)),
ast::Statement::Variable(var) => match var.count {
Some(count) => {
@ -945,7 +913,6 @@ impl<'a> StringIdResolver<'a> {
self.variables[id]
}
#[must_use]
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
let numeric_id = self.current_id;
self.variables.insert(Cow::Borrowed(id), numeric_id);
@ -1023,10 +990,6 @@ impl<A: Args> Statement<A> {
Statement::Constant(cons) => cons.visit_id_mut(f),
}
}
fn get_type(&self) -> Option<ast::Type> {
todo!()
}
}
trait Args {
@ -1125,23 +1088,6 @@ impl<A: Args> Instruction<A> {
_ => todo!(),
}
}
fn is_terminal(&self) -> bool {
match self {
Instruction::Ret(_) => true,
Instruction::Ld(_, _)
| Instruction::Mov(_, _)
| Instruction::Mul(_, _)
| Instruction::Add(_, _)
| Instruction::Setp(_, _)
| Instruction::SetpBool(_, _)
| Instruction::Not(_, _)
| Instruction::Cvt(_, _)
| Instruction::Shl(_, _)
| Instruction::St(_, _)
| Instruction::Bra(_, _) => false,
}
}
}
impl Instruction<NormalizedArgs> {
@ -1164,23 +1110,6 @@ impl Instruction<NormalizedArgs> {
}
impl Instruction<ExpandedArgs> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
match self {
Instruction::Ld(_, a) => a.visit_id(f),
Instruction::Mov(_, a) => a.visit_id(f),
Instruction::Mul(_, a) => a.visit_id(f),
Instruction::Add(_, a) => a.visit_id(f),
Instruction::Setp(_, a) => a.visit_id(f),
Instruction::SetpBool(_, a) => a.visit_id(f),
Instruction::Not(_, a) => a.visit_id(f),
Instruction::Cvt(_, a) => a.visit_id(f),
Instruction::Shl(_, a) => a.visit_id(f),
Instruction::St(_, a) => a.visit_id(f),
Instruction::Bra(_, a) => a.visit_id(f),
Instruction::Ret(_) => (),
}
}
fn jump_target(&self) -> Option<spirv::Word> {
match self {
Instruction::Bra(_, a) => Some(a.src),
@ -1390,7 +1319,7 @@ impl<T> ast::PredAt<T> {
}
impl<T> ast::Instruction<T> {
fn map_id1<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@ -1406,107 +1335,6 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
}
}
fn map_id<F: FnMut(T) -> spirv::Word>(self, f: &mut F) -> Instruction<NormalizedArgs> {
match self {
ast::Instruction::Ld(d, a) => Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => Instruction::Mov(d, a.map_id(f)),
ast::Instruction::Mul(d, a) => Instruction::Mul(d, a.map_id(f)),
ast::Instruction::Add(d, a) => Instruction::Add(d, a.map_id(f)),
ast::Instruction::Setp(d, a) => Instruction::Setp(d, a.map_id(f)),
ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a.map_id(f)),
ast::Instruction::Not(d, a) => Instruction::Not(d, a.map_id(f)),
ast::Instruction::Bra(d, a) => Instruction::Bra(d, a.map_id(f)),
ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a.map_id(f)),
ast::Instruction::Shl(d, a) => Instruction::Shl(d, a.map_id(f)),
ast::Instruction::St(d, a) => Instruction::St(d, a.map_id(f)),
ast::Instruction::Ret(d) => Instruction::Ret(d),
}
}
}
impl ast::Instruction<spirv::Word> {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
match self {
ast::Instruction::Ld(_, a) => Arg::visit_id(a, f),
ast::Instruction::Mov(_, a) => a.visit_id(f),
ast::Instruction::Mul(_, a) => a.visit_id(f),
ast::Instruction::Add(_, a) => a.visit_id(f),
ast::Instruction::Setp(_, a) => a.visit_id(f),
ast::Instruction::SetpBool(_, a) => a.visit_id(f),
ast::Instruction::Not(_, a) => a.visit_id(f),
ast::Instruction::Cvt(_, a) => a.visit_id(f),
ast::Instruction::Shl(_, a) => a.visit_id(f),
ast::Instruction::St(_, a) => a.visit_id(f),
ast::Instruction::Bra(_, a) => a.visit_id(f),
ast::Instruction::Ret(_) => (),
}
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
ast::Instruction::Ld(_, a) => a.visit_id_mut(f),
ast::Instruction::Mov(_, a) => a.visit_id_mut(f),
ast::Instruction::Mul(_, a) => a.visit_id_mut(f),
ast::Instruction::Add(_, a) => a.visit_id_mut(f),
ast::Instruction::Setp(_, a) => a.visit_id_mut(f),
ast::Instruction::SetpBool(_, a) => a.visit_id_mut(f),
ast::Instruction::Not(_, a) => a.visit_id_mut(f),
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
ast::Instruction::St(_, a) => a.visit_id_mut(f),
ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
ast::Instruction::Ret(_) => (),
}
}
fn get_type(&self) -> Option<ast::Type> {
match self {
ast::Instruction::Add(add, _) => Some(add.get_type()),
ast::Instruction::Ret(_) => None,
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
ast::Instruction::Mov(mov, _) => Some(mov.typ),
ast::Instruction::Mul(mul, _) => Some(mul.get_type()),
_ => todo!(),
}
}
}
impl<T: Copy> ast::Instruction<T> {
fn jump_target(&self) -> Option<T> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_) => None,
}
}
fn is_terminal(&self) -> bool {
match self {
ast::Instruction::Ret(_) => true,
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
| ast::Instruction::SetpBool(_, _)
| ast::Instruction::Not(_, _)
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Bra(_, _) => false,
}
}
}
impl<T> ast::Arg1<T> {
@ -2146,7 +1974,6 @@ fn insert_implicit_bitcasts(
mod tests {
use super::*;
use crate::ast;
use crate::ptx;
static SCALAR_TYPES: [ast::ScalarType; 15] = [
ast::ScalarType::B8,