From bd3d440dba9a913e2214de89a151f9c2c34984fe Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 1 Oct 2020 20:28:57 +0200 Subject: [PATCH] Implement or --- ptx/src/ast.rs | 8 ++++++++ ptx/src/ptx.lalrpop | 17 ++++++++++++++++- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 16 ++++++++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index b509dfe..8c64ebf 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -345,6 +345,7 @@ pub enum Instruction { Call(CallInst

), Abs(AbsDetails, Arg2

), Mad(MulDetails, Arg4

), + Or(OrType, Arg3

), } #[derive(Copy, Clone)] @@ -802,3 +803,10 @@ pub enum StCacheOperator { pub struct RetData { pub uniform: bool, } + +sub_scalar_type!(OrType { + Pred, + B16, + B32, + B64, +}); diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index debdae7..d2d5be8 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -127,6 +127,7 @@ match { "mov", "mul", "not", + "or", "ret", "setp", "shl", @@ -155,6 +156,7 @@ ExtendedID : &'input str = { "mov", "mul", "not", + "or", "ret", "setp", "shl", @@ -445,7 +447,8 @@ Instruction: ast::Instruction> = { InstCvta, InstCall, InstAbs, - InstMad + InstMad, + InstOr }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1048,6 +1051,18 @@ SignedIntType: ast::ScalarType = { ".s64" => ast::ScalarType::S64, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or +InstOr: ast::Instruction> = { + "or" => ast::Instruction::Or(d, a), +}; + +OrType: ast::OrType = { + ".pred" => ast::OrType::Pred, + ".b16" => ast::OrType::B16, + ".b32" => ast::OrType::B32, + ".b64" => ast::OrType::B64, +} + Operand: ast::Operand<&'input str> = { => ast::Operand::Reg(r), "+" => { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 6f516fd..99785a6 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -69,6 +69,7 @@ test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]); test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]); test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]); test_ptx!(shr, [-2i32], [-1i32]); +test_ptx!(or, [1u64, 2u64], [3u64]); struct DisplayError { diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index fe6a7dc..fb1b843 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -592,6 +592,9 @@ fn convert_to_typed_statements( ast::Instruction::Shr(d, a) => { result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast()))) } + ast::Instruction::Or(d, a) => { + result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) + } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), @@ -1583,6 +1586,14 @@ fn emit_function_body_ops( } ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?, }, + ast::Instruction::Or(t, a) => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); + if *t == ast::OrType::Pred { + builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; + } else { + builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; + } + } }, Statement::LoadVar(arg, typ) => { let type_id = map.get_or_add(builder, SpirvType::from(*typ)); @@ -2905,6 +2916,10 @@ impl ast::Instruction { let is_wide = d.is_wide(); ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?) } + ast::Instruction::Or(t, a) => ast::Instruction::Or( + t, + a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?, + ), }) } } @@ -3113,6 +3128,7 @@ impl ast::Instruction { | ast::Instruction::Ret(_) | ast::Instruction::Abs(_, _) | ast::Instruction::Call(_) + | ast::Instruction::Or(_, _) | ast::Instruction::Mad(_, _) => None, } }