mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 00:14:45 +00:00
Convert PTX predicates to a distinct conditional jump statement
This commit is contained in:
parent
a69c12a387
commit
3b433456a1
2 changed files with 178 additions and 77 deletions
|
@ -6,4 +6,8 @@ members = [
|
|||
"notcuda_inject",
|
||||
"notcuda_redirect",
|
||||
"ptx",
|
||||
]
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
rspirv = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' }
|
||||
spirv_headers = { git = 'https://github.com/vosen/rspirv', branch = 'notcuda' }
|
|
@ -113,7 +113,7 @@ fn emit_function<'a>(
|
|||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
ssa_legalize(
|
||||
let phis = ssa_legalize(
|
||||
&mut normalized_ids,
|
||||
contant_ids.len() as u32,
|
||||
unique_ids,
|
||||
|
@ -121,12 +121,35 @@ fn emit_function<'a>(
|
|||
&doms,
|
||||
&dom_fronts,
|
||||
);
|
||||
emit_function_body_ops(builder);
|
||||
let id_offset = builder.reserve_ids(unique_ids);
|
||||
emit_function_args(builder, id_offset, map, &f.args);
|
||||
emit_function_body_ops(builder, id_offset, &normalized_ids, &bbs)?;
|
||||
builder.end_function()?;
|
||||
builder.ret()?;
|
||||
builder.end_function()?;
|
||||
Ok(func_id)
|
||||
}
|
||||
|
||||
fn emit_function_args(
|
||||
builder: &mut dr::Builder,
|
||||
id_offset: spirv::Word,
|
||||
map: &mut TypeWordMap,
|
||||
args: &[ast::Argument],
|
||||
) {
|
||||
let mut id = id_offset;
|
||||
for arg in args {
|
||||
let result_type = map.get_or_add(builder, SpirvType::Base(arg.a_type));
|
||||
let inst = dr::Instruction::new(
|
||||
spirv::Op::FunctionParameter,
|
||||
Some(result_type),
|
||||
Some(id),
|
||||
Vec::new(),
|
||||
);
|
||||
builder.function.as_mut().unwrap().parameters.push(inst);
|
||||
id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
|
||||
let mut id = result.len() as u32;
|
||||
for arg in args {
|
||||
|
@ -152,13 +175,35 @@ fn collect_label_ids<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn emit_function_body_ops(builder: &mut dr::Builder) {
|
||||
todo!()
|
||||
fn emit_function_body_ops(
|
||||
builder: &mut dr::Builder,
|
||||
id_offset: spirv::Word,
|
||||
func: &[Statement],
|
||||
cfg: &[BasicBlock],
|
||||
) -> Result<(), dr::Error> {
|
||||
for bb_idx in 0..cfg.len() {
|
||||
let body = get_bb_body(func, cfg, BBIndex(bb_idx));
|
||||
if body.len() == 0 {
|
||||
continue;
|
||||
}
|
||||
let header_id = if let Statement::Label(id) = body[0] {
|
||||
Some(id_offset + id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
builder.begin_block(header_id)?;
|
||||
for s in body {
|
||||
/*
|
||||
match s {
|
||||
Statement::Instruction(pred, inst) => (),
|
||||
Statement::Label(_) => (),
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// This functions converts string identifiers to numeric identifiers in a normalized form, where
|
||||
// - identifiers in the range [0..constant_identifiers.len()) are arguments and labels
|
||||
// - identifiers in the range [constant_identifiers.len()..result.1) are variables
|
||||
// TODO: support scopes
|
||||
fn normalize_identifiers<'a>(
|
||||
func: Vec<ast::Statement<&'a str>>,
|
||||
|
@ -167,8 +212,8 @@ fn normalize_identifiers<'a>(
|
|||
let mut result = Vec::with_capacity(func.len());
|
||||
let mut id: u32 = constant_identifiers.len() as u32;
|
||||
let mut remapped_ids = HashMap::new();
|
||||
let mut get_or_add = |key| {
|
||||
constant_identifiers.get(key).map_or_else(
|
||||
let mut get_or_add = |key| match key {
|
||||
Some(key) => constant_identifiers.get(key).map_or_else(
|
||||
|| {
|
||||
*remapped_ids.entry(key).or_insert_with(|| {
|
||||
let to_insert = id;
|
||||
|
@ -177,12 +222,15 @@ fn normalize_identifiers<'a>(
|
|||
})
|
||||
},
|
||||
|id| *id,
|
||||
)
|
||||
),
|
||||
None => {
|
||||
let to_insert = id;
|
||||
id += 1;
|
||||
to_insert
|
||||
}
|
||||
};
|
||||
for s in func {
|
||||
if let Some(s) = Statement::from_ast(s, &mut get_or_add) {
|
||||
result.push(s);
|
||||
}
|
||||
Statement::from_ast(s, &mut result, &mut get_or_add);
|
||||
}
|
||||
(result, id)
|
||||
}
|
||||
|
@ -195,13 +243,7 @@ fn ssa_legalize(
|
|||
doms: &[BBIndex],
|
||||
dom_fronts: &[HashSet<BBIndex>],
|
||||
) -> Vec<Vec<PhiDef>> {
|
||||
let phis = gather_phi_sets(
|
||||
&func,
|
||||
constant_ids,
|
||||
unique_ids,
|
||||
&bbs,
|
||||
dom_fronts,
|
||||
);
|
||||
let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
|
||||
apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
|
||||
}
|
||||
|
||||
|
@ -431,12 +473,12 @@ fn gather_phi_sets(
|
|||
}
|
||||
}
|
||||
};
|
||||
// We try to avoid adding labels to the global-visbility set.
|
||||
// We are not 100% precise (we add jump targets in bra), but it shouldn't be a problem
|
||||
for s in get_bb_body(func, cfg, BBIndex(bb)) {
|
||||
match s {
|
||||
Statement::Instruction(pred, inst) => {
|
||||
pred.as_ref().map(|p| p.visit_id(&mut visitor));
|
||||
inst.visit_id(&mut visitor);
|
||||
}
|
||||
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
|
||||
Statement::Conditional(brc) => visitor(false, &brc.predicate),
|
||||
// label redefinition is a compile-time error
|
||||
Statement::Label(_) => (),
|
||||
}
|
||||
|
@ -464,20 +506,16 @@ fn gather_phi_sets(
|
|||
|
||||
fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
||||
// edge signify pred/succ relationship between bbs
|
||||
let mut bb_edge = HashSet::new();
|
||||
let mut unresolved_bb_edge = Vec::new();
|
||||
// bb start means that a bb is starting at this statement, but there's no predecessor
|
||||
let mut bb_start = Vec::new();
|
||||
let mut labels = HashMap::new();
|
||||
for (idx, s) in fun.iter().enumerate() {
|
||||
match s {
|
||||
Statement::Instruction(pred, i) => {
|
||||
Statement::Instruction(i) => {
|
||||
if let Some(id) = i.jump_target() {
|
||||
unresolved_bb_edge.push((StmtIndex(idx), id));
|
||||
if idx + 1 < fun.len() {
|
||||
if pred.is_some() {
|
||||
bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
|
||||
}
|
||||
bb_start.push(StmtIndex(idx + 1));
|
||||
}
|
||||
} else if i.is_terminal() && idx + 1 < fun.len() {
|
||||
|
@ -487,23 +525,32 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
|||
Statement::Label(id) => {
|
||||
labels.insert(id, StmtIndex(idx));
|
||||
}
|
||||
Statement::Conditional(bra) => {
|
||||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
|
||||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
|
||||
}
|
||||
};
|
||||
}
|
||||
let mut bb_edge = HashSet::new();
|
||||
// Resolve every <jump into label> into <jump into statement index>
|
||||
// TODO: handle jumps into nowhere
|
||||
for (idx, id) in unresolved_bb_edge {
|
||||
let target = labels[&id];
|
||||
bb_edge.insert((idx, target));
|
||||
bb_start.push(target);
|
||||
// now check if the preceding statement forms an edge
|
||||
// now check if there is an edge target-1 -> target
|
||||
if target != StmtIndex(0) {
|
||||
match &fun[target.0 - 1] {
|
||||
Statement::Instruction(pred, i) => {
|
||||
if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
|
||||
Statement::Instruction(i) => {
|
||||
if !(i.jump_target().is_some() || i.is_terminal()) {
|
||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||
}
|
||||
}
|
||||
Statement::Label(_) => (),
|
||||
Statement::Label(_) => {
|
||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||
}
|
||||
// This is already in `unresolved_bb_edge`
|
||||
Statement::Conditional(_) => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -665,43 +712,88 @@ impl fmt::Display for BBIndex {
|
|||
|
||||
enum Statement {
|
||||
Label(u32),
|
||||
Instruction(
|
||||
Option<ast::PredAt<spirv::Word>>,
|
||||
ast::Instruction<spirv::Word>,
|
||||
),
|
||||
Instruction(ast::Instruction<spirv::Word>),
|
||||
// SPIR-V compatible replacement for PTX predicates
|
||||
Conditional(BrachCondition),
|
||||
}
|
||||
|
||||
struct BrachCondition {
|
||||
predicate: spirv::Word,
|
||||
if_true: spirv::Word,
|
||||
if_false: spirv::Word,
|
||||
}
|
||||
impl BrachCondition {
|
||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, &self.predicate);
|
||||
f(false, &self.if_true);
|
||||
f(false, &self.if_false);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.predicate);
|
||||
f(false, &mut self.if_true);
|
||||
f(false, &mut self.if_false);
|
||||
}
|
||||
}
|
||||
|
||||
impl Statement {
|
||||
fn from_ast<'a, F: FnMut(&'a str) -> u32>(
|
||||
fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
|
||||
s: ast::Statement<&'a str>,
|
||||
f: &mut F,
|
||||
) -> Option<Self> {
|
||||
out: &mut Vec<Statement>,
|
||||
get_id: &mut F,
|
||||
) {
|
||||
match s {
|
||||
ast::Statement::Label(name) => Some(Statement::Label(f(name))),
|
||||
ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))),
|
||||
ast::Statement::Instruction(p, i) => {
|
||||
Some(Statement::Instruction(p.map(|p| p.map_id(f)), i.map_id(f)))
|
||||
if let Some(pred) = p {
|
||||
let predicate = get_id(Some(pred.label));
|
||||
let mut if_true = get_id(None);
|
||||
let mut if_false = get_id(None);
|
||||
if pred.not {
|
||||
std::mem::swap(&mut if_true, &mut if_false);
|
||||
}
|
||||
let folded_bra = match &i {
|
||||
ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))),
|
||||
_ => None,
|
||||
};
|
||||
let branch = BrachCondition {
|
||||
predicate,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
out.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
out.push(Statement::Label(if_true));
|
||||
out.push(Statement::Instruction(
|
||||
i.map_id(&mut |name| get_id(Some(name))),
|
||||
));
|
||||
}
|
||||
out.push(Statement::Label(if_false));
|
||||
} else {
|
||||
out.push(Statement::Instruction(
|
||||
i.map_id(&mut |name| get_id(Some(name))),
|
||||
));
|
||||
}
|
||||
}
|
||||
ast::Statement::Variable(_) => None,
|
||||
ast::Statement::Variable(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn for_dst_id<F: FnMut(spirv::Word)>(&self, f: &mut F) {
|
||||
match self {
|
||||
Statement::Label(id) => f(*id),
|
||||
Statement::Instruction(pred, inst) => {
|
||||
pred.as_ref().map(|p| p.for_dst_id(f));
|
||||
Statement::Instruction(inst) => {
|
||||
inst.for_dst_id(f);
|
||||
}
|
||||
Statement::Conditional(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||
match self {
|
||||
Statement::Label(id) => f(true, id),
|
||||
Statement::Instruction(pred, inst) => {
|
||||
pred.as_ref().map(|p| p.visit_id(f));
|
||||
inst.visit_id(f);
|
||||
}
|
||||
Statement::Label(id) => f(false, id),
|
||||
Statement::Instruction(inst) => inst.visit_id(f),
|
||||
Statement::Conditional(bra) => bra.visit_id(f),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -709,11 +801,9 @@ impl Statement {
|
|||
// otherwise SSA renaming will yield weird results
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
match self {
|
||||
Statement::Label(id) => f(true, id),
|
||||
Statement::Instruction(pred, inst) => {
|
||||
pred.as_mut().map(|p| p.visit_id_mut(f));
|
||||
inst.visit_id_mut(f);
|
||||
}
|
||||
Statement::Label(id) => f(false, id),
|
||||
Statement::Instruction(inst) => inst.visit_id_mut(f),
|
||||
Statement::Conditional(bra) => bra.visit_id_mut(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1182,10 +1272,10 @@ mod tests {
|
|||
fn get_basic_blocks_miniloop() {
|
||||
let func = vec![
|
||||
Statement::Label(12),
|
||||
Statement::Instruction(
|
||||
None,
|
||||
ast::Instruction::Bra(ast::BraData {}, ast::Arg1 { src: 12 }),
|
||||
),
|
||||
Statement::Instruction(ast::Instruction::Bra(
|
||||
ast::BraData {},
|
||||
ast::Arg1 { src: 12 },
|
||||
)),
|
||||
];
|
||||
let bbs = get_basic_blocks(&func);
|
||||
assert_eq!(
|
||||
|
@ -1390,11 +1480,11 @@ mod tests {
|
|||
mov.u32 k, 0;
|
||||
block_2:
|
||||
setp.ge.u32 p, k, 100;
|
||||
@p bra block_4;
|
||||
block_3:
|
||||
@p bra block_4; // conditional p block_4 if_false1
|
||||
// if_false1:
|
||||
setp.ge.u32 q, j, 20;
|
||||
@q bra block_6;
|
||||
block_5:
|
||||
@q bra block_6; // conditional q block_6 if_false2
|
||||
// if_false2:
|
||||
mov.u32 j, i;
|
||||
add.u32 k, k, 1;
|
||||
bra block_7;
|
||||
|
@ -1563,7 +1653,7 @@ mod tests {
|
|||
assert_eq!(errors.len(), 0);
|
||||
let mut constant_ids = HashMap::new();
|
||||
collect_label_ids(&mut constant_ids, &fn_ast);
|
||||
assert_eq!(constant_ids.len(), 6);
|
||||
assert_eq!(constant_ids.len(), 4);
|
||||
let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
|
||||
let bbs = get_basic_blocks(&normalized_ids);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
|
@ -1580,11 +1670,11 @@ mod tests {
|
|||
phi,
|
||||
vec![
|
||||
HashSet::new(),
|
||||
to_hashset(vec![7, 8]),
|
||||
to_hashset(vec![5, 6]),
|
||||
HashSet::new(),
|
||||
HashSet::new(),
|
||||
HashSet::new(),
|
||||
to_hashset(vec![7, 8]),
|
||||
to_hashset(vec![5, 6]),
|
||||
HashSet::new()
|
||||
]
|
||||
);
|
||||
|
@ -1634,12 +1724,12 @@ mod tests {
|
|||
let k5 = get_dst(&func[15]);
|
||||
let p1 = get_dst(&func[4]);
|
||||
let q1 = get_dst(&func[7]);
|
||||
let block_2 = get_dst(&func[3]);
|
||||
let block_3 = get_dst(&func[6]);
|
||||
let block_5 = get_dst(&func[9]);
|
||||
let block_6 = get_dst(&func[13]);
|
||||
let block_7 = get_dst(&func[16]);
|
||||
let block_4 = get_dst(&func[18]);
|
||||
let block_2 = get_label(&func[3]);
|
||||
let if_false1 = get_label(&func[6]);
|
||||
let if_false2 = get_label(&func[9]);
|
||||
let block_6 = get_label(&func[13]);
|
||||
let block_7 = get_label(&func[16]);
|
||||
let block_4 = get_label(&func[18]);
|
||||
|
||||
{
|
||||
assert_eq!(get_ids(&func[0]), vec![i1]);
|
||||
|
@ -1652,13 +1742,13 @@ mod tests {
|
|||
);
|
||||
assert_eq!(get_ids(&func[3]), vec![block_2]);
|
||||
assert_eq!(get_ids(&func[4]), vec![p1, k2]);
|
||||
assert_eq!(get_ids(&func[5]), vec![p1, block_4]);
|
||||
assert_eq!(get_ids(&func[5]), vec![p1, block_4, if_false1]);
|
||||
|
||||
assert_eq!(get_ids(&func[6]), vec![block_3]);
|
||||
assert_eq!(get_ids(&func[6]), vec![if_false1]);
|
||||
assert_eq!(get_ids(&func[7]), vec![q1, j2]);
|
||||
assert_eq!(get_ids(&func[8]), vec![q1, block_6]);
|
||||
assert_eq!(get_ids(&func[8]), vec![q1, block_6, if_false2]);
|
||||
|
||||
assert_eq!(get_ids(&func[9]), vec![block_5]);
|
||||
assert_eq!(get_ids(&func[9]), vec![if_false2]);
|
||||
assert_eq!(get_ids(&func[10]), vec![j3, i1]);
|
||||
assert_eq!(get_ids(&func[11]), vec![k3, k2]);
|
||||
assert_eq!(get_ids(&func[12]), vec![block_7]);
|
||||
|
@ -1739,6 +1829,13 @@ mod tests {
|
|||
result.unwrap()
|
||||
}
|
||||
|
||||
fn get_label(s: &Statement) -> spirv::Word {
|
||||
match s {
|
||||
Statement::Label(id) => *id,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word {
|
||||
for phi_set in phi {
|
||||
if phi_set.src.contains(&src) {
|
||||
|
|
Loading…
Add table
Reference in a new issue