Add support for integer addition

This commit is contained in:
Andrzej Janik 2020-07-26 03:09:05 +02:00
parent b068a89c38
commit ec7ab8e5c4
6 changed files with 270 additions and 85 deletions

View file

@ -68,6 +68,17 @@ pub enum Type {
ExtendedScalar(ExtendedScalarType),
}
impl From<FloatType> for Type {
fn from(t: FloatType) -> Self {
match t {
FloatType::F16 => Type::Scalar(ScalarType::F16),
FloatType::F16x2 => Type::ExtendedScalar(ExtendedScalarType::F16x2),
FloatType::F32 => Type::Scalar(ScalarType::F32),
FloatType::F64 => Type::Scalar(ScalarType::F64),
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ScalarType {
B8,
@ -87,6 +98,37 @@ pub enum ScalarType {
F64,
}
impl From<IntType> for ScalarType {
fn from(t: IntType) -> Self {
match t {
IntType::S16 => ScalarType::S16,
IntType::S32 => ScalarType::S32,
IntType::S64 => ScalarType::S64,
IntType::U16 => ScalarType::U16,
IntType::U32 => ScalarType::U32,
IntType::U64 => ScalarType::U64,
}
}
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum IntType {
U16,
U32,
U64,
S16,
S32,
S64,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum FloatType {
F16,
F16x2,
F32,
F64,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub enum ExtendedScalarType {
F16x2,
@ -130,8 +172,8 @@ pub struct PredAt<ID> {
pub enum Instruction<ID> {
Ld(LdData, Arg2<ID>),
Mov(MovData, Arg2Mov<ID>),
Mul(MulData, Arg3<ID>),
Add(AddData, Arg3<ID>),
Mul(MulDetails, Arg3<ID>),
Add(AddDetails, Arg3<ID>),
Setp(SetpData, Arg4<ID>),
SetpBool(SetpBoolData, Arg5<ID>),
Not(NotData, Arg2<ID>),
@ -244,23 +286,24 @@ pub struct MovData {
pub typ: Type,
}
pub struct MulData {
pub typ: Type,
pub desc: MulDescriptor,
pub enum MulDetails {
Int(MulIntDesc),
Float(MulFloatDesc),
}
pub enum MulDescriptor {
Int(MulIntControl),
Float(MulFloatDesc),
pub struct MulIntDesc {
pub typ: IntType,
pub control: MulIntControl,
}
pub enum MulIntControl {
Low,
High,
Wide
Wide,
}
pub struct MulFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
@ -270,11 +313,24 @@ pub enum RoundingMode {
NearestEven,
Zero,
NegativeInf,
PositiveInf
PositiveInf,
}
pub struct AddData {
pub typ: ScalarType,
pub enum AddDetails {
Int(AddIntDesc),
Float(AddFloatDesc),
}
pub struct AddIntDesc {
pub typ: IntType,
pub saturate: bool,
}
pub struct AddFloatDesc {
pub typ: FloatType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
pub struct SetpData {
@ -310,7 +366,7 @@ pub struct SetpBoolData {
pub typ: ScalarType,
pub flush_to_zero: bool,
pub cmp_op: SetpCompareOp,
pub bool_op: SetpBoolPostOp
pub bool_op: SetpBoolPostOp,
}
pub struct NotData {}

View file

@ -398,43 +398,35 @@ InstMul: ast::Instruction<&'input str> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
};
InstMulMode: ast::MulData = {
<ctr:MulIntControl> <t:IntType> => ast::MulData{
typ: ast::Type::Scalar(t),
desc: ast::MulDescriptor::Int(ctr)
},
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F32),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r,
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
},
<r:RoundingMode?> ".f64" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F64),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r,
flush_to_zero: false,
saturate: false
})
},
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F16),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
},
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulData{
typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
}
InstMulMode: ast::MulDetails = {
<ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc {
typ: t,
control: ctr
}),
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F32,
rounding: r,
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}),
<r:RoundingMode?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F64,
rounding: r,
flush_to_zero: false,
saturate: false
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
}),
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc {
typ: ast::FloatType::F16x2,
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
};
MulIntControl: ast::MulIntControl = {
@ -451,13 +443,13 @@ RoundingMode : ast::RoundingMode = {
".rp" => ast::RoundingMode::PositiveInf,
};
IntType : ast::ScalarType = {
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
IntType : ast::IntType = {
".u16" => ast::IntType::U16,
".u32" => ast::IntType::U32,
".u64" => ast::IntType::U64,
".s16" => ast::IntType::S16,
".s32" => ast::IntType::S32,
".s64" => ast::IntType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
@ -467,12 +459,33 @@ InstAdd: ast::Instruction<&'input str> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
};
InstAddMode: ast::AddData = {
<t:IntType> => ast::AddData{ typ: t },
".sat" ".s32" => ast::AddData{ typ: ast::ScalarType::S32 },
RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{ typ: ast::ScalarType::F32 },
RoundingMode? ".f64" => ast::AddData{ typ: ast::ScalarType::F64 },
".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{ typ: ast::ScalarType::F16 },
InstAddMode: ast::AddDetails = {
<t:IntType> => ast::AddDetails::Int(ast::AddIntDesc {
typ: t,
saturate: false,
}),
".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc {
typ: ast::IntType::S32,
saturate: true,
}),
<rn:RoundingMode?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F32,
rounding: rn,
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
<rn:RoundingMode?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F64,
rounding: rn,
flush_to_zero: false,
saturate: false,
}),
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc {
typ: ast::FloatType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: sat.is_some(),
}),
".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
};

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry add(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
add.u64 temp2, temp, 1;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,38 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "add"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%4 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_1 = OpConstant %ulong 1
%5 = OpFunction %void None %4
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%21 = OpLabel
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
%10 = OpVariable %_ptr_Function_ulong Function
%11 = OpVariable %_ptr_Function_ulong Function
OpStore %8 %6
OpStore %9 %7
%12 = OpLoad %ulong %8
%19 = OpConvertUToPtr %_ptr_Generic_ulong %12
%13 = OpLoad %ulong %19
OpStore %10 %13
%14 = OpLoad %ulong %10
%15 = OpIAdd %ulong %14 %ulong_1
OpStore %11 %15
%16 = OpLoad %ulong %9
%17 = OpLoad %ulong %11
%20 = OpConvertUToPtr %_ptr_Generic_ulong %16
OpStore %20 %17
OpReturn
OpFunctionEnd

View file

@ -43,6 +43,7 @@ test_ptx!(ld_st, [1u64], [1u64]);
test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
test_ptx!(add, [1u64], [2u64]);
struct DisplayError<T: Display + Debug> {
err: T,
@ -233,6 +234,9 @@ fn is_instr_equal(
instr2: &Instruction,
map: &mut HashMap<Word, Word>,
) -> bool {
if instr1.class.opcode != instr2.class.opcode {
return false;
}
if !is_option_equal(&instr1.result_type, &instr2.result_type, map, is_word_equal) {
return false;
}

View file

@ -355,11 +355,11 @@ fn normalize_insert_instruction(
Instruction::Mov(d, arg)
}
Instruction::Mul(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.typ.try_as_scalar(), a);
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Mul(d, arg)
}
Instruction::Add(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| Some(d.typ), a);
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
Instruction::Add(d, arg)
}
Instruction::Setp(d, a) => {
@ -731,11 +731,17 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
Instruction::Mul(mul, arg) => match mul.desc {
ast::MulDescriptor::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, mul.typ, ctr, arg)?;
Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
}
ast::MulDescriptor::Float(_) => todo!(),
ast::MulDetails::Float(_) => todo!(),
},
Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => {
emit_add_int(builder, map, desc, arg)?;
}
ast::AddDetails::Float(_) => todo!(),
},
_ => todo!(),
},
@ -755,26 +761,24 @@ fn emit_mul_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
typ: ast::Type,
ctr: &ast::MulIntControl,
desc: &ast::MulIntDesc,
arg: &Arg3,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(typ));
match ctr {
let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into()));
match desc.control {
ast::MulIntControl::Low => {
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
let ocl_mul_hi = match typ.try_as_scalar().unwrap().kind() {
ScalarKind::Signed => spirv::CLOp::s_mul_hi,
ScalarKind::Unsigned => spirv::CLOp::u_mul_hi,
ScalarKind::Float => unreachable!(),
ScalarKind::Byte => unreachable!(),
let ocl_mul_hi = if desc.typ.is_signed() {
spirv::CLOp::s_mul_hi
} else {
spirv::CLOp::u_mul_hi
};
builder.ext_inst(
inst_type,
Some(arg.dst),
1,
opencl,
ocl_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
@ -784,6 +788,17 @@ fn emit_mul_int(
Ok(())
}
fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ctr: &ast::AddIntDesc,
arg: &Arg3,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into()));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@ -1059,8 +1074,8 @@ type ExpandedStatement = Statement<ExpandedArgs>;
enum Instruction<A: Args> {
Ld(ast::LdData, A::Arg2),
Mov(ast::MovData, A::Arg2Mov),
Mul(ast::MulData, A::Arg3),
Add(ast::AddData, A::Arg3),
Mul(ast::MulDetails, A::Arg3),
Add(ast::AddDetails, A::Arg3),
Setp(ast::SetpData, A::Arg4),
SetpBool(ast::SetpBoolData, A::Arg5),
Not(ast::NotData, A::Arg2),
@ -1091,12 +1106,22 @@ impl<A: Args> Instruction<A> {
fn get_type(&self) -> Option<ast::Type> {
match self {
Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)),
Instruction::Add(add, _) => match add {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => Some((*typ).into()),
},
Instruction::Ret(_) => None,
Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
Instruction::Mov(mov, _) => Some(mov.typ),
Instruction::Mul(mul, _) => Some(mul.typ),
Instruction::Mul(mul, _) => match mul {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => {
Some(ast::Type::Scalar((*typ).into()))
}
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => Some((*typ).into()),
},
_ => todo!(),
}
}
@ -1437,12 +1462,12 @@ impl ast::Instruction<spirv::Word> {
fn get_type(&self) -> Option<ast::Type> {
match self {
ast::Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)),
ast::Instruction::Add(add, _) => Some(add.get_type()),
ast::Instruction::Ret(_) => None,
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
ast::Instruction::Mov(mov, _) => Some(mov.typ),
ast::Instruction::Mul(mul, _) => Some(mul.typ),
ast::Instruction::Mul(mul, _) => Some(mul.get_type()),
_ => todo!(),
}
}
@ -1800,6 +1825,33 @@ impl ast::ScalarType {
}
}
impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => (*typ).into(),
}
}
}
impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
match self {
ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => (*typ).into(),
}
}
}
impl ast::IntType {
fn is_signed(self) -> bool {
match self {
ast::IntType::S16 | ast::IntType::S32 | ast::IntType::S64 => true,
ast::IntType::U16 | ast::IntType::U32 | ast::IntType::U64 => false,
}
}
}
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {