mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-04-20 08:24:44 +00:00
Add support for some most common setp variants and fix a bug with branch conditions
This commit is contained in:
parent
9ed3dc54f2
commit
ed295c4083
3 changed files with 149 additions and 10 deletions
|
@ -355,6 +355,7 @@ pub struct SetpData {
|
|||
pub cmp_op: SetpCompareOp,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Copy, Clone)]
|
||||
pub enum SetpCompareOp {
|
||||
Eq,
|
||||
NotEq,
|
||||
|
|
|
@ -44,6 +44,7 @@ test_ptx!(mov, [1u64], [1u64]);
|
|||
test_ptx!(mul_lo, [1u64], [2u64]);
|
||||
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
|
||||
test_ptx!(add, [1u64], [2u64]);
|
||||
test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
|
||||
|
||||
struct DisplayError<T: Display + Debug> {
|
||||
err: T,
|
||||
|
|
|
@ -224,20 +224,20 @@ fn normalize_predicates(
|
|||
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
ast::Statement::Instruction(pred, inst) => {
|
||||
if let Some(pred) = pred {
|
||||
let mut if_true = id_def.new_id(None);
|
||||
let mut if_false = id_def.new_id(None);
|
||||
if pred.not {
|
||||
std::mem::swap(&mut if_true, &mut if_false);
|
||||
}
|
||||
let if_true = id_def.new_id(None);
|
||||
let if_false = id_def.new_id(None);
|
||||
let folded_bra = match &inst {
|
||||
ast::Instruction::Bra(_, arg) => Some(arg.src),
|
||||
_ => None,
|
||||
};
|
||||
let branch = BrachCondition {
|
||||
let mut branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
if pred.not {
|
||||
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
|
||||
}
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
|
@ -306,9 +306,21 @@ fn insert_mem_ssa_statements(
|
|||
result.append(&mut post_statements);
|
||||
}
|
||||
},
|
||||
s @ Statement::Variable(_, _, _)
|
||||
| s @ Statement::Label(_)
|
||||
| s @ Statement::Conditional(_) => result.push(s),
|
||||
Statement::Conditional(mut bra) => {
|
||||
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
|
||||
ast::ExtendedScalarType::Pred,
|
||||
)));
|
||||
result.push(Statement::LoadVar(
|
||||
Arg2 {
|
||||
dst: generated_id,
|
||||
src: bra.predicate,
|
||||
},
|
||||
ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
|
||||
));
|
||||
bra.predicate = generated_id;
|
||||
result.push(Statement::Conditional(bra));
|
||||
}
|
||||
s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::LoadVar(_, _)
|
||||
| Statement::StoreVar(_, _)
|
||||
| Statement::Conversion(_)
|
||||
|
@ -378,7 +390,39 @@ impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenA
|
|||
todo!()
|
||||
}
|
||||
}
|
||||
_ => todo!(),
|
||||
ast::Operand::RegOffset(reg, offset) => {
|
||||
if let Some(typ) = t {
|
||||
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
|
||||
scalar
|
||||
} else {
|
||||
todo!()
|
||||
};
|
||||
let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
|
||||
self.func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id_constant_stmt,
|
||||
typ: scalar_t,
|
||||
value: offset as i128,
|
||||
}));
|
||||
let result_id = self.id_def.new_id(t);
|
||||
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
|
||||
self.func.push(Statement::Instruction(
|
||||
ast::Instruction::<ExpandedArgParams>::Add(
|
||||
ast::AddDetails::Int(ast::AddIntDesc {
|
||||
typ: int_type,
|
||||
saturate: false,
|
||||
}),
|
||||
ast::Arg3 {
|
||||
dst: result_id,
|
||||
src1: reg,
|
||||
src2: id_constant_stmt,
|
||||
},
|
||||
),
|
||||
));
|
||||
result_id
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -601,6 +645,12 @@ fn emit_function_body_ops(
|
|||
}
|
||||
ast::AddDetails::Float(_) => todo!(),
|
||||
},
|
||||
ast::Instruction::Setp(setp, arg) => {
|
||||
if arg.dst2.is_some() {
|
||||
todo!()
|
||||
}
|
||||
emit_setp(builder, map, setp, arg)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
},
|
||||
Statement::LoadVar(arg, typ) => {
|
||||
|
@ -615,6 +665,81 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_setp(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
setp: &ast::SetpData,
|
||||
arg: &ast::Arg4<ExpandedArgParams>,
|
||||
) -> Result<(), dr::Error> {
|
||||
if setp.flush_to_zero {
|
||||
todo!()
|
||||
}
|
||||
let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred));
|
||||
let result_id = Some(arg.dst1);
|
||||
let operand_1 = arg.src1;
|
||||
let operand_2 = arg.src2;
|
||||
match (setp.cmp_op, setp.typ.kind()) {
|
||||
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
|
||||
builder.i_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
|
||||
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
|
||||
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
|
||||
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
|
||||
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
|
||||
builder.u_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
|
||||
builder.s_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Less, ScalarKind::Float) => {
|
||||
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
|
||||
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
|
||||
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
|
||||
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
|
||||
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
|
||||
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::Greater, ScalarKind::Float) => {
|
||||
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
|
||||
| (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
|
||||
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
|
||||
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
|
||||
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
|
||||
}
|
||||
_ => todo!(),
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul_int(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -1397,6 +1522,18 @@ impl ast::IntType {
|
|||
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_new(t: ast::ScalarType) -> Option<Self> {
|
||||
match t {
|
||||
ast::ScalarType::U16 => Some(ast::IntType::U16),
|
||||
ast::ScalarType::U32 => Some(ast::IntType::U32),
|
||||
ast::ScalarType::U64 => Some(ast::IntType::U64),
|
||||
ast::ScalarType::S16 => Some(ast::IntType::S16),
|
||||
ast::ScalarType::S32 => Some(ast::IntType::S32),
|
||||
ast::ScalarType::S64 => Some(ast::IntType::S64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
|
||||
|
|
Loading…
Add table
Reference in a new issue