Add support for some most common setp variants and fix a bug with branch conditions

This commit is contained in:
Andrzej Janik 2020-07-31 02:03:59 +02:00
parent 9ed3dc54f2
commit ed295c4083
3 changed files with 149 additions and 10 deletions

View file

@ -355,6 +355,7 @@ pub struct SetpData {
pub cmp_op: SetpCompareOp,
}
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum SetpCompareOp {
Eq,
NotEq,

View file

@ -44,6 +44,7 @@ test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
test_ptx!(add, [1u64], [2u64]);
test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
struct DisplayError<T: Display + Debug> {
err: T,

View file

@ -224,20 +224,20 @@ fn normalize_predicates(
ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred {
let mut if_true = id_def.new_id(None);
let mut if_false = id_def.new_id(None);
if pred.not {
std::mem::swap(&mut if_true, &mut if_false);
}
let if_true = id_def.new_id(None);
let if_false = id_def.new_id(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
};
let branch = BrachCondition {
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
@ -306,9 +306,21 @@ fn insert_mem_ssa_statements(
result.append(&mut post_statements);
}
},
s @ Statement::Variable(_, _, _)
| s @ Statement::Label(_)
| s @ Statement::Conditional(_) => result.push(s),
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
ast::ExtendedScalarType::Pred,
)));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: bra.predicate,
},
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
));
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
@ -378,7 +390,39 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
todo!()
}
}
_ => todo!(),
ast::Operand::RegOffset(reg, offset) => {
if let Some(typ) = t {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
value: offset as i128,
}));
let result_id = self.id_def.new_id(t);
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Add(
ast::AddDetails::Int(ast::AddIntDesc {
typ: int_type,
saturate: false,
}),
ast::Arg3 {
dst: result_id,
src1: reg,
src2: id_constant_stmt,
},
),
));
result_id
} else {
todo!()
}
}
}
}
@ -601,6 +645,12 @@ fn emit_function_body_ops(
}
ast::AddDetails::Float(_) => todo!(),
},
ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() {
todo!()
}
emit_setp(builder, map, setp, arg)?;
}
_ => todo!(),
},
Statement::LoadVar(arg, typ) => {
@ -615,6 +665,81 @@ fn emit_function_body_ops(
Ok(())
}
fn emit_setp(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
setp: &ast::SetpData,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if setp.flush_to_zero {
todo!()
}
let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred));
let result_id = Some(arg.dst1);
let operand_1 = arg.src1;
let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
| (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
_ => todo!(),
}?;
Ok(())
}
fn emit_mul_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1397,6 +1522,18 @@ impl ast::IntType {
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
}
}
fn try_new(t: ast::ScalarType) -> Option<Self> {
match t {
ast::ScalarType::U16 => Some(ast::IntType::U16),
ast::ScalarType::U32 => Some(ast::IntType::U32),
ast::ScalarType::U64 => Some(ast::IntType::U64),
ast::ScalarType::S16 => Some(ast::IntType::S16),
ast::ScalarType::S32 => Some(ast::IntType::S32),
ast::ScalarType::S64 => Some(ast::IntType::S64),
_ => None,
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {