From ed295c4083da00dd8cec2e7aa5861fb060dcb13d Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 31 Jul 2020 02:03:59 +0200 Subject: [PATCH] Add support for some most common setp variants and fix a bug with branch conditions --- ptx/src/ast.rs | 1 + ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 157 +++++++++++++++++++++++++++++++--- 3 files changed, 149 insertions(+), 10 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 9fab216..bbc5815 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -355,6 +355,7 @@ pub struct SetpData { pub cmp_op: SetpCompareOp, } +#[derive(PartialEq, Eq, Copy, Clone)] pub enum SetpCompareOp { Eq, NotEq, diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 3abcae7..e0b7d74 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -44,6 +44,7 @@ test_ptx!(mov, [1u64], [1u64]); test_ptx!(mul_lo, [1u64], [2u64]); test_ptx!(mul_hi, [u64::max_value()], [1u64]); test_ptx!(add, [1u64], [2u64]); +test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c399f0d..12b9aae 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -224,20 +224,20 @@ fn normalize_predicates( ast::Statement::Label(id) => result.push(Statement::Label(id)), ast::Statement::Instruction(pred, inst) => { if let Some(pred) = pred { - let mut if_true = id_def.new_id(None); - let mut if_false = id_def.new_id(None); - if pred.not { - std::mem::swap(&mut if_true, &mut if_false); - } + let if_true = id_def.new_id(None); + let if_false = id_def.new_id(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, }; - let branch = BrachCondition { + let mut branch = BrachCondition { predicate: pred.label, if_true: folded_bra.unwrap_or(if_true), if_false, }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } result.push(Statement::Conditional(branch)); if folded_bra.is_none() { result.push(Statement::Label(if_true)); @@ -306,9 +306,21 @@ fn insert_mem_ssa_statements( result.append(&mut post_statements); } }, - s @ Statement::Variable(_, _, _) - | s @ Statement::Label(_) - | s @ Statement::Conditional(_) => result.push(s), + Statement::Conditional(mut bra) => { + let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar( + ast::ExtendedScalarType::Pred, + ))); + result.push(Statement::LoadVar( + Arg2 { + dst: generated_id, + src: bra.predicate, + }, + ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred), + )); + bra.predicate = generated_id; + result.push(Statement::Conditional(bra)); + } + s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s), Statement::LoadVar(_, _) | Statement::StoreVar(_, _) | Statement::Conversion(_) @@ -378,7 +390,39 @@ impl<'a> ArgumentMapVisitor for FlattenA todo!() } } - _ => todo!(), + ast::Operand::RegOffset(reg, offset) => { + if let Some(typ) = t { + let scalar_t = if let ast::Type::Scalar(scalar) = typ { + scalar + } else { + todo!() + }; + let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t))); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: scalar_t, + value: offset as i128, + })); + let result_id = self.id_def.new_id(t); + let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!()); + self.func.push(Statement::Instruction( + ast::Instruction::::Add( + ast::AddDetails::Int(ast::AddIntDesc { + typ: int_type, + saturate: false, + }), + ast::Arg3 { + dst: result_id, + src1: reg, + src2: id_constant_stmt, + }, + ), + )); + result_id + } else { + todo!() + } + } } } @@ -601,6 +645,12 @@ fn emit_function_body_ops( } ast::AddDetails::Float(_) => todo!(), }, + ast::Instruction::Setp(setp, arg) => { + if arg.dst2.is_some() { + todo!() + } + emit_setp(builder, map, setp, arg)?; + } _ => todo!(), }, Statement::LoadVar(arg, typ) => { @@ -615,6 +665,81 @@ fn emit_function_body_ops( Ok(()) } +fn emit_setp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + setp: &ast::SetpData, + arg: &ast::Arg4, +) -> Result<(), dr::Error> { + if setp.flush_to_zero { + todo!() + } + let result_type = map.get_or_add(builder, SpirvType::Extended(ast::ExtendedScalarType::Pred)); + let result_id = Some(arg.dst1); + let operand_1 = arg.src1; + let operand_2 = arg.src2; + match (setp.cmp_op, setp.typ.kind()) { + (ast::SetpCompareOp::Eq, ScalarKind::Signed) + | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) + | (ast::SetpCompareOp::Eq, ScalarKind::Byte) => { + builder.i_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Eq, ScalarKind::Float) => { + builder.f_ord_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::NotEq, ScalarKind::Signed) + | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) + | (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => { + builder.i_not_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { + builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Less, ScalarKind::Unsigned) + | (ast::SetpCompareOp::Less, ScalarKind::Byte) => { + builder.u_less_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Less, ScalarKind::Signed) => { + builder.s_less_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Less, ScalarKind::Float) => { + builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) + | (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => { + builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { + builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { + builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) + | (ast::SetpCompareOp::Greater, ScalarKind::Byte) => { + builder.u_greater_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { + builder.s_greater_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::Greater, ScalarKind::Float) => { + builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) + | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => { + builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { + builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { + builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + _ => todo!(), + }?; + Ok(()) +} + fn emit_mul_int( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -1397,6 +1522,18 @@ impl ast::IntType { ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false, } } + + fn try_new(t: ast::ScalarType) -> Option { + match t { + ast::ScalarType::U16 => Some(ast::IntType::U16), + ast::ScalarType::U32 => Some(ast::IntType::U32), + ast::ScalarType::U64 => Some(ast::IntType::U64), + ast::ScalarType::S16 => Some(ast::IntType::S16), + ast::ScalarType::S32 => Some(ast::IntType::S32), + ast::ScalarType::S64 => Some(ast::IntType::S64), + _ => None, + } + } } fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {