Implement constants in translation middle-end

This commit is contained in:
Andrzej Janik 2020-07-20 00:01:03 +02:00
parent 3d6991e0ca
commit 872d69c714
10 changed files with 840 additions and 173 deletions

View file

@ -14,6 +14,7 @@ spirv_headers = "1.4"
quick-error = "1.2"
bit-vec = "0.6"
paste = "0.1"
half ="1.6"
[build-dependencies.lalrpop]
version = "0.18.1"

View file

@ -243,15 +243,74 @@ pub struct MovData {
pub typ: Type,
}
pub struct MulData {}
pub struct MulData {
pub typ: Type,
pub desc: MulDescriptor,
}
pub enum MulDescriptor {
Int(MulIntControl),
Float(MulFloatDesc),
}
pub enum MulIntControl {
Low,
High,
Wide
}
pub struct MulFloatDesc {
pub rounding: Option<RoundingMode>,
pub flush_to_zero: bool,
pub saturate: bool,
}
pub enum RoundingMode {
NearestEven,
Zero,
NegativeInf,
PositiveInf
}
pub struct AddData {
pub typ: ScalarType,
}
pub struct SetpData {}
pub struct SetpData {
pub typ: ScalarType,
pub flush_to_zero: bool,
pub cmp_op: SetpCompareOp,
}
pub struct SetpBoolData {}
pub enum SetpCompareOp {
Eq,
NotEq,
Less,
LessOrEq,
Greater,
GreaterOrEq,
NanEq,
NanNotEq,
NanLess,
NanLessOrEq,
NanGreater,
NanGreaterOrEq,
IsNotNan,
IsNan,
}
pub enum SetpBoolPostOp {
And,
Or,
Xor,
}
pub struct SetpBoolData {
pub typ: ScalarType,
pub flush_to_zero: bool,
pub cmp_op: SetpCompareOp,
pub bool_op: SetpBoolPostOp
}
pub struct NotData {}

View file

@ -12,6 +12,7 @@ extern crate level_zero as ze;
extern crate level_zero_sys as l0;
extern crate rspirv;
extern crate spirv_headers as spirv;
extern crate half;
#[cfg(test)]
extern crate spirv_tools_sys as spirv_tools;

View file

@ -399,20 +399,56 @@ InstMul: ast::Instruction<&'input str> = {
};
InstMulMode: ast::MulData = {
MulIntControl? IntType => ast::MulData{},
RoundingMode? ".ftz"? ".sat"? ".f32" => ast::MulData{},
RoundingMode? ".f64" => ast::MulData{},
".rn"? ".ftz"? ".sat"? ".f16" => ast::MulData{},
".rn"? ".ftz"? ".sat"? ".f16x2" => ast::MulData{}
<ctr:MulIntControl> <t:IntType> => ast::MulData{
typ: ast::Type::Scalar(t),
desc: ast::MulDescriptor::Int(ctr)
},
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F32),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r,
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
},
<r:RoundingMode?> ".f64" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F64),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r,
flush_to_zero: false,
saturate: false
})
},
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulData{
typ: ast::Type::Scalar(ast::ScalarType::F16),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
},
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulData{
typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2),
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
rounding: r.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: ftz.is_some(),
saturate: s.is_some()
})
}
};
MulIntControl = {
".hi", ".lo", ".wide"
MulIntControl: ast::MulIntControl = {
".hi" => ast::MulIntControl::High,
".lo" => ast::MulIntControl::Low,
".wide" => ast::MulIntControl::Wide
};
#[inline]
RoundingMode = {
".rn", ".rz", ".rm", ".rp"
RoundingMode : ast::RoundingMode = {
".rn" => ast::RoundingMode::NearestEven,
".rz" => ast::RoundingMode::Zero,
".rm" => ast::RoundingMode::NegativeInf,
".rp" => ast::RoundingMode::PositiveInf,
};
IntType : ast::ScalarType = {
@ -449,27 +485,61 @@ InstSetp: ast::Instruction<&'input str> = {
};
SetpMode: ast::SetpData = {
SetpCmpOp ".ftz"? SetpType => ast::SetpData{}
<cmp_op:SetpCompareOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpData{
typ: t,
flush_to_zero: ftz.is_some(),
cmp_op: cmp_op,
}
};
SetpBoolMode: ast::SetpBoolData = {
SetpCmpOp SetpBoolOp ".ftz"? SetpType => ast::SetpBoolData{}
<cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpBoolData{
typ: t,
flush_to_zero: ftz.is_some(),
cmp_op: cmp_op,
bool_op: bool_op,
}
};
SetpCmpOp = {
".eq", ".ne", ".lt", ".le", ".gt", ".ge", ".lo", ".ls", ".hi", ".hs",
".equ", ".neu", ".ltu", ".leu", ".gtu", ".geu", ".num", ".nan"
SetpCompareOp: ast::SetpCompareOp = {
".eq" => ast::SetpCompareOp::Eq,
".ne" => ast::SetpCompareOp::NotEq,
".lt" => ast::SetpCompareOp::Less,
".le" => ast::SetpCompareOp::LessOrEq,
".gt" => ast::SetpCompareOp::Greater,
".ge" => ast::SetpCompareOp::GreaterOrEq,
".lo" => ast::SetpCompareOp::Less,
".ls" => ast::SetpCompareOp::LessOrEq,
".hi" => ast::SetpCompareOp::Greater,
".hs" => ast::SetpCompareOp::GreaterOrEq,
".equ" => ast::SetpCompareOp::NanEq,
".neu" => ast::SetpCompareOp::NanNotEq,
".ltu" => ast::SetpCompareOp::NanLess,
".leu" => ast::SetpCompareOp::NanLessOrEq,
".gtu" => ast::SetpCompareOp::NanGreater,
".geu" => ast::SetpCompareOp::NanGreaterOrEq,
".num" => ast::SetpCompareOp::IsNotNan,
".nan" => ast::SetpCompareOp::IsNan,
};
SetpBoolOp = {
".and", ".or", ".xor"
SetpBoolPostOp: ast::SetpBoolPostOp = {
".and" => ast::SetpBoolPostOp::And,
".or" => ast::SetpBoolPostOp::Or,
".xor" => ast::SetpBoolPostOp::Xor,
};
SetpType = {
".b16", ".b32", ".b64",
".u16", ".u32", ".u64",
".s16", ".s32", ".s64",
".f32", ".f64"
SetpType: ast::ScalarType = {
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
".b64" => ast::ScalarType::B64,
".u16" => ast::ScalarType::U16,
".u32" => ast::ScalarType::U32,
".u64" => ast::ScalarType::U64,
".s16" => ast::ScalarType::S16,
".s32" => ast::ScalarType::S32,
".s64" => ast::ScalarType::S64,
".f32" => ast::ScalarType::F32,
".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not

View file

@ -37,7 +37,9 @@ macro_rules! test_ptx {
}
test_ptx!(ld_st, [1u64], [1u64]);
test_ptx!(mov, [1u64], [1u64]);
//test_ptx!(mov, [1u64], [1u64]);
//test_ptx!(mul_lo, [1u64], [2u64]);
//test_ptx!(mul_hi, [u64::max_value()], [1u64]);
struct DisplayError<T: Display + Debug> {
err: T,

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry mul_hi(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
mul.hi.u64 temp2, temp, 2;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,26 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "mul_hi"
%2 = OpTypeVoid
%3 = OpTypeInt 64 0
%4 = OpTypeFunction %2 %3 %3
%19 = OpTypePointer Generic %3
%5 = OpFunction %2 None %4
%6 = OpFunctionParameter %3
%7 = OpFunctionParameter %3
%18 = OpLabel
%13 = OpCopyObject %3 %6
%14 = OpCopyObject %3 %7
%15 = OpConvertUToPtr %19 %13
%16 = OpLoad %3 %15
%100 = OpCopyObject %3 %16
%17 = OpConvertUToPtr %19 %14
OpStore %17 %100
OpReturn
OpFunctionEnd

View file

@ -0,0 +1,22 @@
.version 6.5
.target sm_30
.address_size 64
.visible .entry mul_lo(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp;
.reg .u64 temp2;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u64 temp, [in_addr];
mul.lo.u64 temp2, temp, 2;
st.u64 [out_addr], temp2;
ret;
}

View file

@ -0,0 +1,26 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int64
OpCapability Int8
%1 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %5 "mul_lo"
%2 = OpTypeVoid
%3 = OpTypeInt 64 0
%4 = OpTypeFunction %2 %3 %3
%19 = OpTypePointer Generic %3
%5 = OpFunction %2 None %4
%6 = OpFunctionParameter %3
%7 = OpFunctionParameter %3
%18 = OpLabel
%13 = OpCopyObject %3 %6
%14 = OpCopyObject %3 %7
%15 = OpConvertUToPtr %19 %13
%16 = OpLoad %3 %15
%100 = OpCopyObject %3 %16
%17 = OpConvertUToPtr %19 %14
OpStore %17 %100
OpReturn
OpFunctionEnd

File diff suppressed because it is too large Load diff