Convert PTX predicates to a distinct conditional jump statement

This commit is contained in:
Andrzej Janik 2020-05-03 23:53:56 +02:00
parent a69c12a387
commit 3b433456a1
2 changed files with 178 additions and 77 deletions

View file

@ -6,4 +6,8 @@ members = [
"notcuda_inject",
"notcuda_redirect",
"ptx",
]
]
[patch.crates-io]
rspirv = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' }
spirv_headers = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' }

View file

@ -113,7 +113,7 @@ fn emit_function<'a>(
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
ssa_legalize(
let phis = ssa_legalize(
&mut normalized_ids,
contant_ids.len() as u32,
unique_ids,
@ -121,12 +121,35 @@ fn emit_function<'a>(
&doms,
&dom_fronts,
);
emit_function_body_ops(builder);
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
builder.end_function()?;
builder.ret()?;
builder.end_function()?;
Ok(func_id)
}
fn emit_function_args(
builder: &mut dr::Builder,
id_offset: spirv::Word,
map: &mut TypeWordMap,
args: &[ast::Argument],
) {
let mut id = id_offset;
for arg in args {
let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
Some(id),
Vec::new(),
);
builder.function.as_mut().unwrap().parameters.push(inst);
id += 1;
}
}
fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
let mut id = result.len() as u32;
for arg in args {
@ -152,13 +175,35 @@ fn collect_label_ids<'a>(
}
}
fn emit_function_body_ops(builder: &mut dr::Builder) {
todo!()
fn emit_function_body_ops(
builder: &mut dr::Builder,
id_offset: spirv::Word,
func: &[Statement],
cfg: &[BasicBlock],
) -> Result<(), dr::Error> {
for bb_idx in 0..cfg.len() {
let body = get_bb_body(func, cfg, BBIndex(bb_idx));
if body.len() == 0 {
continue;
}
let header_id = if let Statement::Label(id) = body[0] {
Some(id_offset + id)
} else {
None
};
builder.begin_block(header_id)?;
for s in body {
/*
match s {
Statement::Instruction(pred, inst) => (),
Statement::Label(_) => (),
}
*/
}
}
Ok(())
}
// This functions converts string identifiers to numeric identifiers in a normalized form, where
// - identifiers in the range [0..constant_identifiers.len()) are arguments and labels
// - identifiers in the range [constant_identifiers.len()..result.1) are variables
// TODO: support scopes
fn normalize_identifiers<'a>(
func: Vec<ast::Statement<&'a str>>,
@ -167,8 +212,8 @@ fn normalize_identifiers<'a>(
let mut result = Vec::with_capacity(func.len());
let mut id: u32 = constant_identifiers.len() as u32;
let mut remapped_ids = HashMap::new();
let mut get_or_add = |key| {
constant_identifiers.get(key).map_or_else(
let mut get_or_add = |key| match key {
Some(key) => constant_identifiers.get(key).map_or_else(
|| {
*remapped_ids.entry(key).or_insert_with(|| {
let to_insert = id;
@ -177,12 +222,15 @@ fn normalize_identifiers<'a>(
})
},
|id| *id,
)
),
None => {
let to_insert = id;
id += 1;
to_insert
}
};
for s in func {
if let Some(s) = Statement::from_ast(s, &mut get_or_add) {
result.push(s);
}
Statement::from_ast(s, &mut result, &mut get_or_add);
}
(result, id)
}
@ -195,13 +243,7 @@ fn ssa_legalize(
doms: &[BBIndex],
dom_fronts: &[HashSet<BBIndex>],
) -> Vec<Vec<PhiDef>> {
let phis = gather_phi_sets(
&func,
constant_ids,
unique_ids,
&bbs,
dom_fronts,
);
let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
}
@ -431,12 +473,12 @@ fn gather_phi_sets(
}
}
};
// We try to avoid adding labels to the global-visbility set.
// We are not 100% precise (we add jump targets in bra), but it shouldn't be a problem
for s in get_bb_body(func, cfg, BBIndex(bb)) {
match s {
Statement::Instruction(pred, inst) => {
pred.as_ref().map(|p| p.visit_id(&mut visitor));
inst.visit_id(&mut visitor);
}
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
Statement::Conditional(brc) => visitor(false, &brc.predicate),
// label redefinition is a compile-time error
Statement::Label(_) => (),
}
@ -464,20 +506,16 @@ fn gather_phi_sets(
fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
// edge signify pred/succ relationship between bbs
let mut bb_edge = HashSet::new();
let mut unresolved_bb_edge = Vec::new();
// bb start means that a bb is starting at this statement, but there's no predecessor
let mut bb_start = Vec::new();
let mut labels = HashMap::new();
for (idx, s) in fun.iter().enumerate() {
match s {
Statement::Instruction(pred, i) => {
Statement::Instruction(i) => {
if let Some(id) = i.jump_target() {
unresolved_bb_edge.push((StmtIndex(idx), id));
if idx + 1 < fun.len() {
if pred.is_some() {
bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
}
bb_start.push(StmtIndex(idx + 1));
}
} else if i.is_terminal() && idx + 1 < fun.len() {
@ -487,23 +525,32 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
Statement::Label(id) => {
labels.insert(id, StmtIndex(idx));
}
Statement::Conditional(bra) => {
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
}
};
}
let mut bb_edge = HashSet::new();
// Resolve every <jump into label> into <jump into statement index>
// TODO: handle jumps into nowhere
for (idx, id) in unresolved_bb_edge {
let target = labels[&id];
bb_edge.insert((idx, target));
bb_start.push(target);
// now check if the preceding statement forms an edge
// now check if there is an edge target-1 -> target
if target != StmtIndex(0) {
match &fun[target.0 - 1] {
Statement::Instruction(pred, i) => {
if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
Statement::Instruction(i) => {
if !(i.jump_target().is_some() || i.is_terminal()) {
bb_edge.insert((StmtIndex(target.0 - 1), target));
}
}
Statement::Label(_) => (),
Statement::Label(_) => {
bb_edge.insert((StmtIndex(target.0 - 1), target));
}
// This is already in `unresolved_bb_edge`
Statement::Conditional(_) => (),
}
}
}
@ -665,43 +712,88 @@ impl fmt::Display for BBIndex {
enum Statement {
Label(u32),
Instruction(
Option<ast::PredAt<spirv::Word>>,
ast::Instruction<spirv::Word>,
),
Instruction(ast::Instruction<spirv::Word>),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
}
struct BrachCondition {
predicate: spirv::Word,
if_true: spirv::Word,
if_false: spirv::Word,
}
impl BrachCondition {
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
f(false, &self.predicate);
f(false, &self.if_true);
f(false, &self.if_false);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.predicate);
f(false, &mut self.if_true);
f(false, &mut self.if_false);
}
}
impl Statement {
fn from_ast<'a, F: FnMut(&'a str) -> u32>(
fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
s: ast::Statement<&'a str>,
f: &mut F,
) -> Option<Self> {
out: &mut Vec<Statement>,
get_id: &mut F,
) {
match s {
ast::Statement::Label(name) => Some(Statement::Label(f(name))),
ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))),
ast::Statement::Instruction(p, i) => {
Some(Statement::Instruction(p.map(|p| p.map_id(f)), i.map_id(f)))
if let Some(pred) = p {
let predicate = get_id(Some(pred.label));
let mut if_true = get_id(None);
let mut if_false = get_id(None);
if pred.not {
std::mem::swap(&mut if_true, &mut if_false);
}
let folded_bra = match &i {
ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))),
_ => None,
};
let branch = BrachCondition {
predicate,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
out.push(Statement::Conditional(branch));
if folded_bra.is_none() {
out.push(Statement::Label(if_true));
out.push(Statement::Instruction(
i.map_id(&mut |name| get_id(Some(name))),
));
}
out.push(Statement::Label(if_false));
} else {
out.push(Statement::Instruction(
i.map_id(&mut |name| get_id(Some(name))),
));
}
}
ast::Statement::Variable(_) => None,
ast::Statement::Variable(_) => (),
}
}
fn for_dst_id<F: FnMut(spirv::Word)>(&self, f: &mut F) {
match self {
Statement::Label(id) => f(*id),
Statement::Instruction(pred, inst) => {
pred.as_ref().map(|p| p.for_dst_id(f));
Statement::Instruction(inst) => {
inst.for_dst_id(f);
}
Statement::Conditional(_) => (),
}
}
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
match self {
Statement::Label(id) => f(true, id),
Statement::Instruction(pred, inst) => {
pred.as_ref().map(|p| p.visit_id(f));
inst.visit_id(f);
}
Statement::Label(id) => f(false, id),
Statement::Instruction(inst) => inst.visit_id(f),
Statement::Conditional(bra) => bra.visit_id(f),
}
}
@ -709,11 +801,9 @@ impl Statement {
// otherwise SSA renaming will yield weird results
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Label(id) => f(true, id),
Statement::Instruction(pred, inst) => {
pred.as_mut().map(|p| p.visit_id_mut(f));
inst.visit_id_mut(f);
}
Statement::Label(id) => f(false, id),
Statement::Instruction(inst) => inst.visit_id_mut(f),
Statement::Conditional(bra) => bra.visit_id_mut(f),
}
}
}
@ -1182,10 +1272,10 @@ mod tests {
fn get_basic_blocks_miniloop() {
let func = vec![
Statement::Label(12),
Statement::Instruction(
None,
ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { src: 12 }),
),
Statement::Instruction(ast::Instruction::Bra(
ast::BraData {},
ast::Arg1 { src: 12 },
)),
];
let bbs = get_basic_blocks(&func);
assert_eq!(
@ -1390,11 +1480,11 @@ mod tests {
mov.u32 k, 0;
block_2:
setp.ge.u32 p, k, 100;
@p bra block_4;
block_3:
@p bra block_4; // conditional p block_4 if_false1
// if_false1:
setp.ge.u32 q, j, 20;
@q bra block_6;
block_5:
@q bra block_6; // conditional q block_6 if_false2
// if_false2:
mov.u32 j, i;
add.u32 k, k, 1;
bra block_7;
@ -1563,7 +1653,7 @@ mod tests {
assert_eq!(errors.len(), 0);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 6);
assert_eq!(constant_ids.len(), 4);
let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
@ -1580,11 +1670,11 @@ mod tests {
phi,
vec![
HashSet::new(),
to_hashset(vec![7, 8]),
to_hashset(vec![5, 6]),
HashSet::new(),
HashSet::new(),
HashSet::new(),
to_hashset(vec![7, 8]),
to_hashset(vec![5, 6]),
HashSet::new()
]
);
@ -1634,12 +1724,12 @@ mod tests {
let k5 = get_dst(&func[15]);
let p1 = get_dst(&func[4]);
let q1 = get_dst(&func[7]);
let block_2 = get_dst(&func[3]);
let block_3 = get_dst(&func[6]);
let block_5 = get_dst(&func[9]);
let block_6 = get_dst(&func[13]);
let block_7 = get_dst(&func[16]);
let block_4 = get_dst(&func[18]);
let block_2 = get_label(&func[3]);
let if_false1 = get_label(&func[6]);
let if_false2 = get_label(&func[9]);
let block_6 = get_label(&func[13]);
let block_7 = get_label(&func[16]);
let block_4 = get_label(&func[18]);
{
assert_eq!(get_ids(&func[0]), vec![i1]);
@ -1652,13 +1742,13 @@ mod tests {
);
assert_eq!(get_ids(&func[3]), vec![block_2]);
assert_eq!(get_ids(&func[4]), vec![p1, k2]);
assert_eq!(get_ids(&func[5]), vec![p1, block_4]);
assert_eq!(get_ids(&func[5]), vec![p1, block_4, if_false1]);
assert_eq!(get_ids(&func[6]), vec![block_3]);
assert_eq!(get_ids(&func[6]), vec![if_false1]);
assert_eq!(get_ids(&func[7]), vec![q1, j2]);
assert_eq!(get_ids(&func[8]), vec![q1, block_6]);
assert_eq!(get_ids(&func[8]), vec![q1, block_6, if_false2]);
assert_eq!(get_ids(&func[9]), vec![block_5]);
assert_eq!(get_ids(&func[9]), vec![if_false2]);
assert_eq!(get_ids(&func[10]), vec![j3, i1]);
assert_eq!(get_ids(&func[11]), vec![k3, k2]);
assert_eq!(get_ids(&func[12]), vec![block_7]);
@ -1739,6 +1829,13 @@ mod tests {
result.unwrap()
}
fn get_label(s: &Statement) -> spirv::Word {
match s {
Statement::Label(id) => *id,
_ => panic!(),
}
}
fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word {
for phi_set in phi {
if phi_set.src.contains(&src) {