mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-08-03 06:40:21 +00:00
Implement constants in translation middle-end
This commit is contained in:
parent
3d6991e0ca
commit
872d69c714
10 changed files with 840 additions and 173 deletions
|
@ -14,6 +14,7 @@ spirv_headers = "1.4"
|
|||
quick-error = "1.2"
|
||||
bit-vec = "0.6"
|
||||
paste = "0.1"
|
||||
half ="1.6"
|
||||
|
||||
[build-dependencies.lalrpop]
|
||||
version = "0.18.1"
|
||||
|
|
|
@ -243,15 +243,74 @@ pub struct MovData {
|
|||
pub typ: Type,
|
||||
}
|
||||
|
||||
pub struct MulData {}
|
||||
pub struct MulData {
|
||||
pub typ: Type,
|
||||
pub desc: MulDescriptor,
|
||||
}
|
||||
|
||||
pub enum MulDescriptor {
|
||||
Int(MulIntControl),
|
||||
Float(MulFloatDesc),
|
||||
}
|
||||
|
||||
pub enum MulIntControl {
|
||||
Low,
|
||||
High,
|
||||
Wide
|
||||
}
|
||||
|
||||
pub struct MulFloatDesc {
|
||||
pub rounding: Option<RoundingMode>,
|
||||
pub flush_to_zero: bool,
|
||||
pub saturate: bool,
|
||||
}
|
||||
|
||||
pub enum RoundingMode {
|
||||
NearestEven,
|
||||
Zero,
|
||||
NegativeInf,
|
||||
PositiveInf
|
||||
}
|
||||
|
||||
pub struct AddData {
|
||||
pub typ: ScalarType,
|
||||
}
|
||||
|
||||
pub struct SetpData {}
|
||||
pub struct SetpData {
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: bool,
|
||||
pub cmp_op: SetpCompareOp,
|
||||
}
|
||||
|
||||
pub struct SetpBoolData {}
|
||||
pub enum SetpCompareOp {
|
||||
Eq,
|
||||
NotEq,
|
||||
Less,
|
||||
LessOrEq,
|
||||
Greater,
|
||||
GreaterOrEq,
|
||||
NanEq,
|
||||
NanNotEq,
|
||||
NanLess,
|
||||
NanLessOrEq,
|
||||
NanGreater,
|
||||
NanGreaterOrEq,
|
||||
IsNotNan,
|
||||
IsNan,
|
||||
}
|
||||
|
||||
pub enum SetpBoolPostOp {
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
}
|
||||
|
||||
pub struct SetpBoolData {
|
||||
pub typ: ScalarType,
|
||||
pub flush_to_zero: bool,
|
||||
pub cmp_op: SetpCompareOp,
|
||||
pub bool_op: SetpBoolPostOp
|
||||
}
|
||||
|
||||
pub struct NotData {}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ extern crate level_zero as ze;
|
|||
extern crate level_zero_sys as l0;
|
||||
extern crate rspirv;
|
||||
extern crate spirv_headers as spirv;
|
||||
extern crate half;
|
||||
|
||||
#[cfg(test)]
|
||||
extern crate spirv_tools_sys as spirv_tools;
|
||||
|
|
|
@ -399,20 +399,56 @@ InstMul: ast::Instruction<&'input str> = {
|
|||
};
|
||||
|
||||
InstMulMode: ast::MulData = {
|
||||
MulIntControl? IntType => ast::MulData{},
|
||||
RoundingMode? ".ftz"? ".sat"? ".f32" => ast::MulData{},
|
||||
RoundingMode? ".f64" => ast::MulData{},
|
||||
".rn"? ".ftz"? ".sat"? ".f16" => ast::MulData{},
|
||||
".rn"? ".ftz"? ".sat"? ".f16x2" => ast::MulData{}
|
||||
<ctr:MulIntControl> <t:IntType> => ast::MulData{
|
||||
typ: ast::Type::Scalar(t),
|
||||
desc: ast::MulDescriptor::Int(ctr)
|
||||
},
|
||||
<r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulData{
|
||||
typ: ast::Type::Scalar(ast::ScalarType::F32),
|
||||
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
|
||||
rounding: r,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
})
|
||||
},
|
||||
<r:RoundingMode?> ".f64" => ast::MulData{
|
||||
typ: ast::Type::Scalar(ast::ScalarType::F64),
|
||||
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
|
||||
rounding: r,
|
||||
flush_to_zero: false,
|
||||
saturate: false
|
||||
})
|
||||
},
|
||||
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulData{
|
||||
typ: ast::Type::Scalar(ast::ScalarType::F16),
|
||||
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
|
||||
rounding: r.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
})
|
||||
},
|
||||
<r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulData{
|
||||
typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2),
|
||||
desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
|
||||
rounding: r.map(|_| ast::RoundingMode::NearestEven),
|
||||
flush_to_zero: ftz.is_some(),
|
||||
saturate: s.is_some()
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
MulIntControl = {
|
||||
".hi", ".lo", ".wide"
|
||||
MulIntControl: ast::MulIntControl = {
|
||||
".hi" => ast::MulIntControl::High,
|
||||
".lo" => ast::MulIntControl::Low,
|
||||
".wide" => ast::MulIntControl::Wide
|
||||
};
|
||||
|
||||
#[inline]
|
||||
RoundingMode = {
|
||||
".rn", ".rz", ".rm", ".rp"
|
||||
RoundingMode : ast::RoundingMode = {
|
||||
".rn" => ast::RoundingMode::NearestEven,
|
||||
".rz" => ast::RoundingMode::Zero,
|
||||
".rm" => ast::RoundingMode::NegativeInf,
|
||||
".rp" => ast::RoundingMode::PositiveInf,
|
||||
};
|
||||
|
||||
IntType : ast::ScalarType = {
|
||||
|
@ -449,27 +485,61 @@ InstSetp: ast::Instruction<&'input str> = {
|
|||
};
|
||||
|
||||
SetpMode: ast::SetpData = {
|
||||
SetpCmpOp ".ftz"? SetpType => ast::SetpData{}
|
||||
<cmp_op:SetpCompareOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpData{
|
||||
typ: t,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
cmp_op: cmp_op,
|
||||
}
|
||||
};
|
||||
|
||||
SetpBoolMode: ast::SetpBoolData = {
|
||||
SetpCmpOp SetpBoolOp ".ftz"? SetpType => ast::SetpBoolData{}
|
||||
<cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpBoolData{
|
||||
typ: t,
|
||||
flush_to_zero: ftz.is_some(),
|
||||
cmp_op: cmp_op,
|
||||
bool_op: bool_op,
|
||||
}
|
||||
};
|
||||
|
||||
SetpCmpOp = {
|
||||
".eq", ".ne", ".lt", ".le", ".gt", ".ge", ".lo", ".ls", ".hi", ".hs",
|
||||
".equ", ".neu", ".ltu", ".leu", ".gtu", ".geu", ".num", ".nan"
|
||||
SetpCompareOp: ast::SetpCompareOp = {
|
||||
".eq" => ast::SetpCompareOp::Eq,
|
||||
".ne" => ast::SetpCompareOp::NotEq,
|
||||
".lt" => ast::SetpCompareOp::Less,
|
||||
".le" => ast::SetpCompareOp::LessOrEq,
|
||||
".gt" => ast::SetpCompareOp::Greater,
|
||||
".ge" => ast::SetpCompareOp::GreaterOrEq,
|
||||
".lo" => ast::SetpCompareOp::Less,
|
||||
".ls" => ast::SetpCompareOp::LessOrEq,
|
||||
".hi" => ast::SetpCompareOp::Greater,
|
||||
".hs" => ast::SetpCompareOp::GreaterOrEq,
|
||||
".equ" => ast::SetpCompareOp::NanEq,
|
||||
".neu" => ast::SetpCompareOp::NanNotEq,
|
||||
".ltu" => ast::SetpCompareOp::NanLess,
|
||||
".leu" => ast::SetpCompareOp::NanLessOrEq,
|
||||
".gtu" => ast::SetpCompareOp::NanGreater,
|
||||
".geu" => ast::SetpCompareOp::NanGreaterOrEq,
|
||||
".num" => ast::SetpCompareOp::IsNotNan,
|
||||
".nan" => ast::SetpCompareOp::IsNan,
|
||||
};
|
||||
|
||||
SetpBoolOp = {
|
||||
".and", ".or", ".xor"
|
||||
SetpBoolPostOp: ast::SetpBoolPostOp = {
|
||||
".and" => ast::SetpBoolPostOp::And,
|
||||
".or" => ast::SetpBoolPostOp::Or,
|
||||
".xor" => ast::SetpBoolPostOp::Xor,
|
||||
};
|
||||
|
||||
SetpType = {
|
||||
".b16", ".b32", ".b64",
|
||||
".u16", ".u32", ".u64",
|
||||
".s16", ".s32", ".s64",
|
||||
".f32", ".f64"
|
||||
SetpType: ast::ScalarType = {
|
||||
".b16" => ast::ScalarType::B16,
|
||||
".b32" => ast::ScalarType::B32,
|
||||
".b64" => ast::ScalarType::B64,
|
||||
".u16" => ast::ScalarType::U16,
|
||||
".u32" => ast::ScalarType::U32,
|
||||
".u64" => ast::ScalarType::U64,
|
||||
".s16" => ast::ScalarType::S16,
|
||||
".s32" => ast::ScalarType::S32,
|
||||
".s64" => ast::ScalarType::S64,
|
||||
".f32" => ast::ScalarType::F32,
|
||||
".f64" => ast::ScalarType::F64,
|
||||
};
|
||||
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
|
||||
|
|
|
@ -37,7 +37,9 @@ macro_rules! test_ptx {
|
|||
}
|
||||
|
||||
test_ptx!(ld_st, [1u64], [1u64]);
|
||||
test_ptx!(mov, [1u64], [1u64]);
|
||||
//test_ptx!(mov, [1u64], [1u64]);
|
||||
//test_ptx!(mul_lo, [1u64], [2u64]);
|
||||
//test_ptx!(mul_hi, [u64::max_value()], [1u64]);
|
||||
|
||||
struct DisplayError<T: Display + Debug> {
|
||||
err: T,
|
||||
|
|
22
ptx/src/test/spirv_run/mul_hi.ptx
Normal file
22
ptx/src/test/spirv_run/mul_hi.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mul_hi(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
mul.hi.u64 temp2, temp, 2;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
26
ptx/src/test/spirv_run/mul_hi.spvtxt
Normal file
26
ptx/src/test/spirv_run/mul_hi.spvtxt
Normal file
|
@ -0,0 +1,26 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %5 "mul_hi"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 64 0
|
||||
%4 = OpTypeFunction %2 %3 %3
|
||||
%19 = OpTypePointer Generic %3
|
||||
%5 = OpFunction %2 None %4
|
||||
%6 = OpFunctionParameter %3
|
||||
%7 = OpFunctionParameter %3
|
||||
%18 = OpLabel
|
||||
%13 = OpCopyObject %3 %6
|
||||
%14 = OpCopyObject %3 %7
|
||||
%15 = OpConvertUToPtr %19 %13
|
||||
%16 = OpLoad %3 %15
|
||||
%100 = OpCopyObject %3 %16
|
||||
%17 = OpConvertUToPtr %19 %14
|
||||
OpStore %17 %100
|
||||
OpReturn
|
||||
OpFunctionEnd
|
22
ptx/src/test/spirv_run/mul_lo.ptx
Normal file
22
ptx/src/test/spirv_run/mul_lo.ptx
Normal file
|
@ -0,0 +1,22 @@
|
|||
.version 6.5
|
||||
.target sm_30
|
||||
.address_size 64
|
||||
|
||||
.visible .entry mul_lo(
|
||||
.param .u64 input,
|
||||
.param .u64 output
|
||||
)
|
||||
{
|
||||
.reg .u64 in_addr;
|
||||
.reg .u64 out_addr;
|
||||
.reg .u64 temp;
|
||||
.reg .u64 temp2;
|
||||
|
||||
ld.param.u64 in_addr, [input];
|
||||
ld.param.u64 out_addr, [output];
|
||||
|
||||
ld.u64 temp, [in_addr];
|
||||
mul.lo.u64 temp2, temp, 2;
|
||||
st.u64 [out_addr], temp2;
|
||||
ret;
|
||||
}
|
26
ptx/src/test/spirv_run/mul_lo.spvtxt
Normal file
26
ptx/src/test/spirv_run/mul_lo.spvtxt
Normal file
|
@ -0,0 +1,26 @@
|
|||
OpCapability GenericPointer
|
||||
OpCapability Linkage
|
||||
OpCapability Addresses
|
||||
OpCapability Kernel
|
||||
OpCapability Int64
|
||||
OpCapability Int8
|
||||
%1 = OpExtInstImport "OpenCL.std"
|
||||
OpMemoryModel Physical64 OpenCL
|
||||
OpEntryPoint Kernel %5 "mul_lo"
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 64 0
|
||||
%4 = OpTypeFunction %2 %3 %3
|
||||
%19 = OpTypePointer Generic %3
|
||||
%5 = OpFunction %2 None %4
|
||||
%6 = OpFunctionParameter %3
|
||||
%7 = OpFunctionParameter %3
|
||||
%18 = OpLabel
|
||||
%13 = OpCopyObject %3 %6
|
||||
%14 = OpCopyObject %3 %7
|
||||
%15 = OpConvertUToPtr %19 %13
|
||||
%16 = OpLoad %3 %15
|
||||
%100 = OpCopyObject %3 %16
|
||||
%17 = OpConvertUToPtr %19 %14
|
||||
OpStore %17 %100
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -18,7 +18,7 @@ impl From<ast::Type> for SpirvType {
|
|||
fn from(t: ast::Type) -> Self {
|
||||
match t {
|
||||
ast::Type::Scalar(t) => SpirvType::Base(t),
|
||||
ast::Type::ExtendedScalar(t) => SpirvType::Extended(t)
|
||||
ast::Type::ExtendedScalar(t) => SpirvType::Extended(t),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -60,7 +60,11 @@ impl TypeWordMap {
|
|||
})
|
||||
}
|
||||
|
||||
fn get_or_add_extended(&mut self, b: &mut dr::Builder, t: ast::ExtendedScalarType) -> spirv::Word {
|
||||
fn get_or_add_extended(
|
||||
&mut self,
|
||||
b: &mut dr::Builder,
|
||||
t: ast::ExtendedScalarType,
|
||||
) -> spirv::Word {
|
||||
*self
|
||||
.complex
|
||||
.entry(SpirvType::Extended(t))
|
||||
|
@ -178,8 +182,9 @@ fn to_ssa<'a>(
|
|||
let registers = collect_var_definitions(&f_args, &f_body);
|
||||
let (normalized_ids, unique_ids) =
|
||||
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
|
||||
let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids);
|
||||
let (mut func_body, unique_ids) =
|
||||
insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]);
|
||||
insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]);
|
||||
let bbs = get_basic_blocks(&func_body);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
|
@ -195,6 +200,221 @@ fn to_ssa<'a>(
|
|||
(func_body, bbs, phis, unique_ids)
|
||||
}
|
||||
|
||||
fn normalize_statements(
|
||||
func: Vec<ast::Statement<spirv::Word>>,
|
||||
unique_ids: spirv::Word,
|
||||
) -> (Vec<Statement>, spirv::Word) {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
let mut id = unique_ids;
|
||||
let new_id = &mut || {
|
||||
let to_insert = id;
|
||||
id += 1;
|
||||
to_insert
|
||||
};
|
||||
for s in func {
|
||||
match s {
|
||||
ast::Statement::Label(id) => result.push(Statement::Label(id)),
|
||||
ast::Statement::Instruction(pred, inst) => {
|
||||
if let Some(pred) = pred {
|
||||
let mut if_true = new_id();
|
||||
let mut if_false = new_id();
|
||||
if pred.not {
|
||||
std::mem::swap(&mut if_true, &mut if_false);
|
||||
}
|
||||
let folded_bra = match &inst {
|
||||
ast::Instruction::Bra(_, arg) => Some(arg.src),
|
||||
_ => None,
|
||||
};
|
||||
let branch = BrachCondition {
|
||||
predicate: pred.label,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
result.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
result.push(Statement::Label(if_true));
|
||||
let instr = normalize_insert_instruction(&mut result, new_id, inst);
|
||||
result.push(Statement::Instruction(instr));
|
||||
}
|
||||
result.push(Statement::Label(if_false));
|
||||
} else {
|
||||
let instr = normalize_insert_instruction(&mut result, new_id, inst);
|
||||
result.push(Statement::Instruction(instr));
|
||||
}
|
||||
}
|
||||
ast::Statement::Variable(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
(result, id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn normalize_insert_instruction(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
instr: ast::Instruction<spirv::Word>,
|
||||
) -> Instruction {
|
||||
match instr {
|
||||
ast::Instruction::Ld(d, a) => {
|
||||
let arg = normalize_expand_arg2(func, new_id, &|| Some(d.typ), a);
|
||||
Instruction::Ld(d, arg)
|
||||
}
|
||||
ast::Instruction::Mov(d, a) => {
|
||||
let arg = normalize_expand_arg2mov(func, new_id, &|| d.typ.try_as_scalar(), a);
|
||||
Instruction::Mov(d, arg)
|
||||
}
|
||||
ast::Instruction::Mul(d, a) => {
|
||||
let arg = normalize_expand_arg3(func, new_id, &|| d.typ.try_as_scalar(), a);
|
||||
Instruction::Mul(d, arg)
|
||||
}
|
||||
ast::Instruction::Add(d, a) => {
|
||||
let arg = normalize_expand_arg3(func, new_id, &|| Some(d.typ), a);
|
||||
Instruction::Add(d, arg)
|
||||
}
|
||||
ast::Instruction::Setp(d, a) => {
|
||||
let arg = normalize_expand_arg4(func, new_id, &|| Some(d.typ), a);
|
||||
Instruction::Setp(d, arg)
|
||||
}
|
||||
ast::Instruction::SetpBool(d, a) => {
|
||||
let arg = normalize_expand_arg5(func, new_id, &|| Some(d.typ), a);
|
||||
Instruction::SetpBool(d, arg)
|
||||
}
|
||||
ast::Instruction::Not(d, a) => {
|
||||
let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
|
||||
Instruction::Not(d, arg)
|
||||
}
|
||||
ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
|
||||
ast::Instruction::Cvt(d, a) => {
|
||||
let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
|
||||
Instruction::Cvt(d, arg)
|
||||
}
|
||||
ast::Instruction::Shl(d, a) => {
|
||||
let arg = normalize_expand_arg3(func, new_id, &|| todo!(), a);
|
||||
Instruction::Shl(d, arg)
|
||||
}
|
||||
ast::Instruction::St(d, a) => {
|
||||
let arg = normalize_expand_arg2st(func, new_id, &|| todo!(), a);
|
||||
Instruction::St(d, arg)
|
||||
}
|
||||
ast::Instruction::Ret(d) => Instruction::Ret(d),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg2(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg2<spirv::Word>,
|
||||
) -> Arg2 {
|
||||
Arg2 {
|
||||
dst: a.dst,
|
||||
src: normalize_expand_operand(func, new_id, inst_type, a.src),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg2mov(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg2Mov<spirv::Word>,
|
||||
) -> Arg2 {
|
||||
Arg2 {
|
||||
dst: a.dst,
|
||||
src: normalize_expand_mov_operand(func, new_id, inst_type, a.src),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg2st(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg2St<spirv::Word>,
|
||||
) -> Arg2St {
|
||||
Arg2St {
|
||||
src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
|
||||
src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg3(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg3<spirv::Word>,
|
||||
) -> Arg3 {
|
||||
Arg3 {
|
||||
dst: a.dst,
|
||||
src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
|
||||
src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg4(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg4<spirv::Word>,
|
||||
) -> Arg4 {
|
||||
Arg4 {
|
||||
dst1: a.dst1,
|
||||
dst2: a.dst2,
|
||||
src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
|
||||
src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_arg5(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
a: ast::Arg5<spirv::Word>,
|
||||
) -> Arg5 {
|
||||
Arg5 {
|
||||
dst1: a.dst1,
|
||||
dst2: a.dst2,
|
||||
src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
|
||||
src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
|
||||
src3: normalize_expand_operand(func, new_id, inst_type, a.src3),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_operand(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
opr: ast::Operand<spirv::Word>,
|
||||
) -> spirv::Word {
|
||||
match opr {
|
||||
ast::Operand::Reg(r) => r,
|
||||
ast::Operand::Imm(x) => {
|
||||
if let Some(typ) = inst_type() {
|
||||
let id = new_id();
|
||||
func.push(Statement::Constant(ConstantDefinition {
|
||||
dst: id,
|
||||
typ: typ,
|
||||
value: x,
|
||||
}));
|
||||
id
|
||||
} else {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_expand_mov_operand(
|
||||
func: &mut Vec<Statement>,
|
||||
new_id: &mut impl FnMut() -> spirv::Word,
|
||||
inst_type: &impl Fn() -> Option<ast::ScalarType>,
|
||||
opr: ast::MovOperand<spirv::Word>,
|
||||
) -> spirv::Word {
|
||||
match opr {
|
||||
ast::MovOperand::Op(opr) => normalize_expand_operand(func, new_id, inst_type, opr),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_var_definitions<'a>(
|
||||
args: &[ast::Argument<'a>],
|
||||
body: &[ast::Statement<&'a str>],
|
||||
|
@ -249,17 +469,15 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
|||
for s in normalized_ids.into_iter() {
|
||||
match s {
|
||||
Statement::Instruction(inst) => match inst {
|
||||
ast::Instruction::Ld(ld, mut arg) => {
|
||||
arg.src = arg.src.map_id(&mut |arg_src| {
|
||||
insert_implicit_conversions_ld_src(
|
||||
Instruction::Ld(ld, mut arg) => {
|
||||
arg.src = insert_implicit_conversions_ld_src(
|
||||
&mut result,
|
||||
ast::Type::Scalar(ld.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
ld.state_space,
|
||||
arg_src,
|
||||
)
|
||||
});
|
||||
arg.src,
|
||||
);
|
||||
insert_with_implicit_conversion_dst(
|
||||
&mut result,
|
||||
ld.typ,
|
||||
|
@ -268,40 +486,35 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
|
|||
should_convert_relaxed_dst,
|
||||
arg,
|
||||
|arg| &mut arg.dst,
|
||||
|arg| ast::Instruction::Ld(ld, arg),
|
||||
|arg| Instruction::Ld(ld, arg),
|
||||
);
|
||||
}
|
||||
ast::Instruction::St(st, mut arg) => {
|
||||
arg.src2 = arg.src2.map_id(&mut |arg_src| {
|
||||
let arg_src_type = type_check(arg_src);
|
||||
if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) {
|
||||
insert_conversion_src(
|
||||
Instruction::St(st, mut arg) => {
|
||||
let arg_src2_type = type_check(arg.src2);
|
||||
if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
|
||||
arg.src2 = insert_conversion_src(
|
||||
&mut result,
|
||||
new_id,
|
||||
arg_src,
|
||||
arg_src_type,
|
||||
arg.src2,
|
||||
arg_src2_type,
|
||||
ast::Type::Scalar(st.typ),
|
||||
conv,
|
||||
)
|
||||
} else {
|
||||
arg_src
|
||||
);
|
||||
}
|
||||
});
|
||||
arg.src1 = arg.src1.map_id(&mut |arg_src| {
|
||||
insert_implicit_conversions_ld_src(
|
||||
arg.src1 = insert_implicit_conversions_ld_src(
|
||||
&mut result,
|
||||
ast::Type::Scalar(st.typ),
|
||||
type_check,
|
||||
new_id,
|
||||
st.state_space.to_ld_ss(),
|
||||
arg_src,
|
||||
)
|
||||
});
|
||||
result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
|
||||
arg.src1,
|
||||
);
|
||||
result.push(Statement::Instruction(Instruction::St(st, arg)));
|
||||
}
|
||||
inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
|
||||
},
|
||||
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
|
||||
Statement::Constant(_) => (),
|
||||
Statement::Converison(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -390,61 +603,52 @@ fn emit_function_body_ops(
|
|||
// If block starts with a label it has already been emitted,
|
||||
// all other labels in the block are unused
|
||||
Statement::Label(_) => (),
|
||||
Statement::Constant(_) => todo!(),
|
||||
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
|
||||
Statement::Conditional(bra) => {
|
||||
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
|
||||
}
|
||||
Statement::Instruction(inst) => match inst {
|
||||
// SPIR-V does not support marking jumps as guaranteed-converged
|
||||
ast::Instruction::Bra(_, arg) => {
|
||||
Instruction::Bra(_, arg) => {
|
||||
builder.branch(arg.src)?;
|
||||
}
|
||||
ast::Instruction::Ld(data, arg) => {
|
||||
Instruction::Ld(data, arg) => {
|
||||
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
|
||||
todo!()
|
||||
}
|
||||
let src = match arg.src {
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
let result_type = map.get_or_add_scalar(builder, data.typ);
|
||||
match data.state_space {
|
||||
ast::LdStateSpace::Generic => {
|
||||
builder.load(result_type, Some(arg.dst), src, None, [])?;
|
||||
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
|
||||
}
|
||||
ast::LdStateSpace::Param => {
|
||||
builder.copy_object(result_type, Some(arg.dst), src)?;
|
||||
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
ast::Instruction::St(data, arg) => {
|
||||
Instruction::St(data, arg) => {
|
||||
if data.qualifier != ast::LdStQualifier::Weak
|
||||
|| data.vector.is_some()
|
||||
|| data.state_space != ast::StStateSpace::Generic
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
let dst = match arg.src1 {
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
let src = match arg.src2 {
|
||||
ast::Operand::Reg(id) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.store(dst, src, None, &[])?;
|
||||
builder.store(arg.src1, arg.src2, None, &[])?;
|
||||
}
|
||||
// SPIR-V does not support ret as guaranteed-converged
|
||||
ast::Instruction::Ret(_) => builder.ret()?,
|
||||
ast::Instruction::Mov(mov, arg) => {
|
||||
Instruction::Ret(_) => builder.ret()?,
|
||||
Instruction::Mov(mov, arg) => {
|
||||
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
|
||||
let src = match arg.src {
|
||||
ast::MovOperand::Op(ast::Operand::Reg(id)) => id,
|
||||
_ => todo!(),
|
||||
};
|
||||
builder.copy_object(result_type, Some(arg.dst), src)?;
|
||||
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
|
||||
}
|
||||
Instruction::Mul(mul, arg) => match mul.desc {
|
||||
ast::MulDescriptor::Int(ref ctr) => {
|
||||
emit_mul_int(builder, map, mul.typ, ctr, arg)
|
||||
}
|
||||
ast::MulDescriptor::Float(_) => todo!(),
|
||||
},
|
||||
_ => todo!(),
|
||||
},
|
||||
}
|
||||
|
@ -453,6 +657,17 @@ fn emit_function_body_ops(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_mul_int(
|
||||
_builder: &mut dr::Builder,
|
||||
_map: &mut TypeWordMap,
|
||||
_typ: ast::Type,
|
||||
_ctr: &ast::MulIntControl,
|
||||
_arg: &Arg3,
|
||||
) {
|
||||
//let inst_type = map.get_or_add(builder, SpirvType::from(typ));
|
||||
//builder.i_mul(inst_type, Some(arg.dst), Some(arg.src1), Some(arg.src2));
|
||||
}
|
||||
|
||||
fn emit_implicit_conversion(
|
||||
builder: &mut dr::Builder,
|
||||
map: &mut TypeWordMap,
|
||||
|
@ -523,12 +738,11 @@ fn normalize_identifiers<'a>(
|
|||
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
|
||||
type_map: &mut HashMap<spirv::Word, ast::Type>,
|
||||
types: HashMap<Cow<'a, str>, ast::Type>,
|
||||
) -> (Vec<Statement>, spirv::Word) {
|
||||
let mut result = Vec::with_capacity(func.len());
|
||||
) -> (Vec<ast::Statement<spirv::Word>>, spirv::Word) {
|
||||
let mut id: u32 = constant_identifiers.len() as u32;
|
||||
let mut remapped_ids = HashMap::new();
|
||||
let mut get_or_add = |key| match key {
|
||||
Some(key) => constant_identifiers.get(key).map_or_else(
|
||||
let mut get_or_add = |key| {
|
||||
constant_identifiers.get(key).map_or_else(
|
||||
|| {
|
||||
*remapped_ids.entry(key).or_insert_with(|| {
|
||||
let to_insert = id;
|
||||
|
@ -537,16 +751,12 @@ fn normalize_identifiers<'a>(
|
|||
})
|
||||
},
|
||||
|id| *id,
|
||||
),
|
||||
None => {
|
||||
let to_insert = id;
|
||||
id += 1;
|
||||
to_insert
|
||||
}
|
||||
)
|
||||
};
|
||||
for s in func {
|
||||
Statement::from_ast(s, &mut result, &mut get_or_add);
|
||||
}
|
||||
let result = func
|
||||
.into_iter()
|
||||
.filter_map(|s| Statement::from_ast(s, &mut get_or_add))
|
||||
.collect::<Vec<_>>();
|
||||
type_map.extend(
|
||||
remapped_ids
|
||||
.into_iter()
|
||||
|
@ -594,7 +804,7 @@ fn apply_ssa_renaming(
|
|||
for s in get_bb_body(func, bbs, BBIndex(bb)) {
|
||||
s.visit_id(&mut |is_dst, id| {
|
||||
if is_dst {
|
||||
old_dst_id[bb].push(*id)
|
||||
old_dst_id[bb].push(id)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -787,8 +997,8 @@ fn gather_phi_sets(
|
|||
let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize];
|
||||
for bb in 0..cfg.len() {
|
||||
let mut var_kill = HashSet::new();
|
||||
let mut visitor = |is_dst, id: &u32| {
|
||||
if *id >= constant_ids {
|
||||
let mut visitor = |is_dst, id: spirv::Word| {
|
||||
if id >= constant_ids {
|
||||
let id = id - constant_ids;
|
||||
if is_dst {
|
||||
var_kill.insert(id);
|
||||
|
@ -807,8 +1017,9 @@ fn gather_phi_sets(
|
|||
for s in get_bb_body(func, cfg, BBIndex(bb)) {
|
||||
match s {
|
||||
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
|
||||
Statement::Conditional(brc) => visitor(false, &brc.predicate),
|
||||
Statement::Conditional(brc) => visitor(false, brc.predicate),
|
||||
Statement::Converison(conv) => conv.visit_id(&mut visitor),
|
||||
Statement::Constant(cons) => cons.visit_id(&mut visitor),
|
||||
// label redefinition is a compile-time error
|
||||
Statement::Label(_) => (),
|
||||
}
|
||||
|
@ -859,6 +1070,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
|||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
|
||||
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
|
||||
}
|
||||
Statement::Constant(_) => (),
|
||||
Statement::Converison(_) => (),
|
||||
};
|
||||
}
|
||||
|
@ -877,7 +1089,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
|
|||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||
}
|
||||
}
|
||||
Statement::Converison(_) | Statement::Label(_) => {
|
||||
Statement::Converison(_) | Statement::Constant(_) | Statement::Label(_) => {
|
||||
bb_edge.insert((StmtIndex(target.0 - 1), target));
|
||||
}
|
||||
// This is already in `unresolved_bb_edge`
|
||||
|
@ -1043,10 +1255,241 @@ impl fmt::Display for BBIndex {
|
|||
|
||||
enum Statement {
|
||||
Label(u32),
|
||||
Instruction(ast::Instruction<spirv::Word>),
|
||||
Instruction(Instruction),
|
||||
// SPIR-V compatible replacement for PTX predicates
|
||||
Conditional(BrachCondition),
|
||||
Converison(ImplicitConversion),
|
||||
Constant(ConstantDefinition),
|
||||
}
|
||||
|
||||
enum Instruction {
|
||||
Ld(ast::LdData, Arg2),
|
||||
Mov(ast::MovData, Arg2),
|
||||
Mul(ast::MulData, Arg3),
|
||||
Add(ast::AddData, Arg3),
|
||||
Setp(ast::SetpData, Arg4),
|
||||
SetpBool(ast::SetpBoolData, Arg5),
|
||||
Not(ast::NotData, Arg2),
|
||||
Bra(ast::BraData, Arg1),
|
||||
Cvt(ast::CvtData, Arg2),
|
||||
Shl(ast::ShlData, Arg3),
|
||||
St(ast::StData, Arg2St),
|
||||
Ret(ast::RetData),
|
||||
}
|
||||
|
||||
impl Instruction {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
match self {
|
||||
Instruction::Ld(_, a) => a.visit_id(f),
|
||||
Instruction::Mov(_, a) => a.visit_id(f),
|
||||
Instruction::Mul(_, a) => a.visit_id(f),
|
||||
Instruction::Add(_, a) => a.visit_id(f),
|
||||
Instruction::Setp(_, a) => a.visit_id(f),
|
||||
Instruction::SetpBool(_, a) => a.visit_id(f),
|
||||
Instruction::Not(_, a) => a.visit_id(f),
|
||||
Instruction::Cvt(_, a) => a.visit_id(f),
|
||||
Instruction::Shl(_, a) => a.visit_id(f),
|
||||
Instruction::St(_, a) => a.visit_id(f),
|
||||
Instruction::Bra(_, a) => a.visit_id(f),
|
||||
Instruction::Ret(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
match self {
|
||||
Instruction::Ld(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Mov(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Mul(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Add(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Setp(_, a) => a.visit_id_mut(f),
|
||||
Instruction::SetpBool(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Not(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Cvt(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Shl(_, a) => a.visit_id_mut(f),
|
||||
Instruction::St(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Bra(_, a) => a.visit_id_mut(f),
|
||||
Instruction::Ret(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ast::Type> {
|
||||
match self {
|
||||
Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)),
|
||||
Instruction::Ret(_) => None,
|
||||
Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
|
||||
Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
|
||||
Instruction::Mov(mov, _) => Some(mov.typ),
|
||||
Instruction::Mul(mul, _) => Some(mul.typ),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn jump_target(&self) -> Option<spirv::Word> {
|
||||
match self {
|
||||
Instruction::Bra(_, a) => Some(a.src),
|
||||
Instruction::Ld(_, _)
|
||||
| Instruction::Mov(_, _)
|
||||
| Instruction::Mul(_, _)
|
||||
| Instruction::Add(_, _)
|
||||
| Instruction::Setp(_, _)
|
||||
| Instruction::SetpBool(_, _)
|
||||
| Instruction::Not(_, _)
|
||||
| Instruction::Cvt(_, _)
|
||||
| Instruction::Shl(_, _)
|
||||
| Instruction::St(_, _)
|
||||
| Instruction::Ret(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_terminal(&self) -> bool {
|
||||
match self {
|
||||
Instruction::Ret(_) => true,
|
||||
Instruction::Ld(_, _)
|
||||
| Instruction::Mov(_, _)
|
||||
| Instruction::Mul(_, _)
|
||||
| Instruction::Add(_, _)
|
||||
| Instruction::Setp(_, _)
|
||||
| Instruction::SetpBool(_, _)
|
||||
| Instruction::Not(_, _)
|
||||
| Instruction::Cvt(_, _)
|
||||
| Instruction::Shl(_, _)
|
||||
| Instruction::St(_, _)
|
||||
| Instruction::Bra(_, _) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Arg1 {
|
||||
pub src: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg1 {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, self.src);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src);
|
||||
}
|
||||
}
|
||||
|
||||
struct Arg2 {
|
||||
pub dst: spirv::Word,
|
||||
pub src: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg2 {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(true, self.dst);
|
||||
f(false, self.src);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src);
|
||||
f(true, &mut self.dst);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Arg2St {
|
||||
pub src1: spirv::Word,
|
||||
pub src2: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg2St {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, self.src1);
|
||||
f(false, self.src2);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src1);
|
||||
f(false, &mut self.src2);
|
||||
}
|
||||
}
|
||||
|
||||
struct Arg3 {
|
||||
pub dst: spirv::Word,
|
||||
pub src1: spirv::Word,
|
||||
pub src2: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg3 {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(true, self.dst);
|
||||
f(false, self.src1);
|
||||
f(false, self.src2);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src1);
|
||||
f(false, &mut self.src2);
|
||||
f(true, &mut self.dst);
|
||||
}
|
||||
}
|
||||
|
||||
struct Arg4 {
|
||||
pub dst1: spirv::Word,
|
||||
pub dst2: Option<spirv::Word>,
|
||||
pub src1: spirv::Word,
|
||||
pub src2: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg4 {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(true, self.dst1);
|
||||
self.dst2.map(|dst2| f(true, dst2));
|
||||
f(false, self.src1);
|
||||
f(false, self.src2);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src1);
|
||||
f(false, &mut self.src2);
|
||||
f(true, &mut self.dst1);
|
||||
self.dst2.as_mut().map(|dst2| f(true, dst2));
|
||||
}
|
||||
}
|
||||
|
||||
struct Arg5 {
|
||||
pub dst1: spirv::Word,
|
||||
pub dst2: Option<spirv::Word>,
|
||||
pub src1: spirv::Word,
|
||||
pub src2: spirv::Word,
|
||||
pub src3: spirv::Word,
|
||||
}
|
||||
|
||||
impl Arg5 {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(true, self.dst1);
|
||||
self.dst2.map(|dst2| f(true, dst2));
|
||||
f(false, self.src1);
|
||||
f(false, self.src2);
|
||||
f(false, self.src3);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(false, &mut self.src1);
|
||||
f(false, &mut self.src2);
|
||||
f(false, &mut self.src3);
|
||||
f(true, &mut self.dst1);
|
||||
self.dst2.as_mut().map(|dst2| f(true, dst2));
|
||||
}
|
||||
}
|
||||
|
||||
struct ConstantDefinition {
|
||||
pub dst: spirv::Word,
|
||||
pub typ: ast::ScalarType,
|
||||
pub value: i128,
|
||||
}
|
||||
|
||||
impl ConstantDefinition {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(true, self.dst);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
f(true, &mut self.dst);
|
||||
}
|
||||
}
|
||||
|
||||
struct BrachCondition {
|
||||
|
@ -1056,10 +1499,10 @@ struct BrachCondition {
|
|||
}
|
||||
|
||||
impl BrachCondition {
|
||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, &self.predicate);
|
||||
f(false, &self.if_true);
|
||||
f(false, &self.if_false);
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, self.predicate);
|
||||
f(false, self.if_true);
|
||||
f(false, self.if_false);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
|
@ -1086,9 +1529,9 @@ enum ConversionKind {
|
|||
}
|
||||
|
||||
impl ImplicitConversion {
|
||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, &self.src);
|
||||
f(true, &self.dst);
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
f(false, self.src);
|
||||
f(true, self.dst);
|
||||
}
|
||||
|
||||
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
|
||||
|
@ -1098,54 +1541,27 @@ impl ImplicitConversion {
|
|||
}
|
||||
|
||||
impl Statement {
|
||||
fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
|
||||
fn from_ast<'a, F: FnMut(&'a str) -> u32>(
|
||||
s: ast::Statement<&'a str>,
|
||||
out: &mut Vec<Statement>,
|
||||
get_id: &mut F,
|
||||
) {
|
||||
) -> Option<ast::Statement<spirv::Word>> {
|
||||
match s {
|
||||
ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))),
|
||||
ast::Statement::Instruction(p, i) => {
|
||||
if let Some(pred) = p {
|
||||
let predicate = get_id(Some(pred.label));
|
||||
let mut if_true = get_id(None);
|
||||
let mut if_false = get_id(None);
|
||||
if pred.not {
|
||||
std::mem::swap(&mut if_true, &mut if_false);
|
||||
}
|
||||
let folded_bra = match &i {
|
||||
ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))),
|
||||
_ => None,
|
||||
};
|
||||
let branch = BrachCondition {
|
||||
predicate,
|
||||
if_true: folded_bra.unwrap_or(if_true),
|
||||
if_false,
|
||||
};
|
||||
out.push(Statement::Conditional(branch));
|
||||
if folded_bra.is_none() {
|
||||
out.push(Statement::Label(if_true));
|
||||
out.push(Statement::Instruction(
|
||||
i.map_id(&mut |name| get_id(Some(name))),
|
||||
));
|
||||
}
|
||||
out.push(Statement::Label(if_false));
|
||||
} else {
|
||||
out.push(Statement::Instruction(
|
||||
i.map_id(&mut |name| get_id(Some(name))),
|
||||
));
|
||||
}
|
||||
}
|
||||
ast::Statement::Variable(_) => (),
|
||||
ast::Statement::Label(name) => Some(ast::Statement::Label(get_id(name))),
|
||||
ast::Statement::Instruction(p, i) => Some(ast::Statement::Instruction(
|
||||
p.map(|p| p.map_id(get_id)),
|
||||
i.map_id(get_id),
|
||||
)),
|
||||
ast::Statement::Variable(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
|
||||
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
|
||||
match self {
|
||||
Statement::Label(id) => f(false, id),
|
||||
Statement::Label(id) => f(false, *id),
|
||||
Statement::Instruction(inst) => inst.visit_id(f),
|
||||
Statement::Conditional(bra) => bra.visit_id(f),
|
||||
Statement::Converison(conv) => conv.visit_id(f),
|
||||
Statement::Constant(cons) => cons.visit_id(f),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1157,6 +1573,16 @@ impl Statement {
|
|||
Statement::Instruction(inst) => inst.visit_id_mut(f),
|
||||
Statement::Conditional(bra) => bra.visit_id_mut(f),
|
||||
Statement::Converison(conv) => conv.visit_id_mut(f),
|
||||
Statement::Constant(cons) => cons.visit_id_mut(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ast::PredAt<T> {
|
||||
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
|
||||
ast::PredAt {
|
||||
not: self.not,
|
||||
label: f(self.label),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1220,7 +1646,8 @@ impl<T> ast::Instruction<T> {
|
|||
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
|
||||
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
|
||||
ast::Instruction::Mov(mov, _) => Some(mov.typ),
|
||||
_ => todo!()
|
||||
ast::Instruction::Mul(mul, _) => Some(mul.typ),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1476,6 +1903,15 @@ enum ScalarKind {
|
|||
Float,
|
||||
}
|
||||
|
||||
impl ast::Type {
|
||||
fn try_as_scalar(self) -> Option<ast::ScalarType> {
|
||||
match self {
|
||||
ast::Type::Scalar(s) => Some(s),
|
||||
ast::Type::ExtendedScalar(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ast::ScalarType {
|
||||
fn width(self) -> u8 {
|
||||
match self {
|
||||
|
@ -1688,7 +2124,7 @@ fn insert_with_implicit_conversion_dst<
|
|||
NewId: FnMut() -> spirv::Word,
|
||||
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
|
||||
Setter: Fn(&mut T) -> &mut spirv::Word,
|
||||
ToInstruction: FnOnce(T) -> ast::Instruction<spirv::Word>,
|
||||
ToInstruction: FnOnce(T) -> Instruction,
|
||||
>(
|
||||
func: &mut Vec<Statement>,
|
||||
instr_type: ast::ScalarType,
|
||||
|
@ -1821,7 +2257,7 @@ fn insert_implicit_bitcasts<
|
|||
func: &mut Vec<Statement>,
|
||||
type_check: &TypeCheck,
|
||||
new_id: &mut NewId,
|
||||
mut instr: ast::Instruction<spirv::Word>,
|
||||
mut instr: Instruction,
|
||||
) {
|
||||
let mut dst_coercion = None;
|
||||
if let Some(instr_type) = instr.get_type() {
|
||||
|
@ -1984,9 +2420,9 @@ mod tests {
|
|||
fn get_basic_blocks_miniloop() {
|
||||
let func = vec![
|
||||
Statement::Label(12),
|
||||
Statement::Instruction(ast::Instruction::Bra(
|
||||
Statement::Instruction(Instruction::Bra(
|
||||
ast::BraData { uniform: false },
|
||||
ast::Arg1 { src: 12 },
|
||||
Arg1 { src: 12 },
|
||||
)),
|
||||
];
|
||||
let bbs = get_basic_blocks(&func);
|
||||
|
@ -2226,9 +2662,10 @@ mod tests {
|
|||
let mut constant_ids = HashMap::new();
|
||||
collect_label_ids(&mut constant_ids, &ast);
|
||||
let registers = collect_var_definitions(&[], &ast);
|
||||
let (normalized_ids, _) =
|
||||
let (normalized_ids, unique_ids) =
|
||||
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
|
||||
let mut bbs = get_basic_blocks(&normalized_ids);
|
||||
let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids);
|
||||
let mut bbs = get_basic_blocks(&normalized_stmts);
|
||||
bbs.iter_mut().for_each(sort_pred_succ);
|
||||
assert_eq!(
|
||||
bbs,
|
||||
|
@ -2239,32 +2676,32 @@ mod tests {
|
|||
succ: vec![BBIndex(1)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(3),
|
||||
start: StmtIndex(6),
|
||||
pred: vec![BBIndex(0), BBIndex(5)],
|
||||
succ: vec![BBIndex(2), BBIndex(6)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(6),
|
||||
start: StmtIndex(10),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![BBIndex(3), BBIndex(4)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(9),
|
||||
start: StmtIndex(14),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(13),
|
||||
start: StmtIndex(19),
|
||||
pred: vec![BBIndex(2)],
|
||||
succ: vec![BBIndex(5)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(16),
|
||||
start: StmtIndex(23),
|
||||
pred: vec![BBIndex(3), BBIndex(4)],
|
||||
succ: vec![BBIndex(1)],
|
||||
},
|
||||
BasicBlock {
|
||||
start: StmtIndex(18),
|
||||
start: StmtIndex(25),
|
||||
pred: vec![BBIndex(1)],
|
||||
succ: vec![],
|
||||
},
|
||||
|
@ -2375,14 +2812,15 @@ mod tests {
|
|||
collect_label_ids(&mut constant_ids, &fn_ast);
|
||||
assert_eq!(constant_ids.len(), 4);
|
||||
let registers = collect_var_definitions(&[], &fn_ast);
|
||||
let (normalized_ids, max_id) =
|
||||
let (normalized_ids, unique_ids) =
|
||||
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
|
||||
let bbs = get_basic_blocks(&normalized_ids);
|
||||
let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids);
|
||||
let bbs = get_basic_blocks(&normalized_stmts);
|
||||
let rpostorder = to_reverse_postorder(&bbs);
|
||||
let doms = immediate_dominators(&bbs, &rpostorder);
|
||||
let dom_fronts = dominance_frontiers(&bbs, &doms);
|
||||
let phi = gather_phi_sets(
|
||||
&normalized_ids,
|
||||
&normalized_stmts,
|
||||
constant_ids.len() as u32,
|
||||
max_id,
|
||||
&bbs,
|
||||
|
@ -2490,7 +2928,7 @@ mod tests {
|
|||
for s in func {
|
||||
s.visit_id(&mut |is_dst, id| {
|
||||
if is_dst {
|
||||
assert!(seen.insert(*id));
|
||||
assert!(seen.insert(id));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -2504,7 +2942,7 @@ mod tests {
|
|||
fn get_ids(s: &Statement) -> Vec<spirv::Word> {
|
||||
let mut result = Vec::new();
|
||||
s.visit_id(&mut |_, id| {
|
||||
result.push(*id);
|
||||
result.push(id);
|
||||
});
|
||||
result
|
||||
}
|
||||
|
@ -2533,7 +2971,7 @@ mod tests {
|
|||
let mut result = None;
|
||||
s.visit_id(&mut |is_dst, id| {
|
||||
if is_dst {
|
||||
assert_eq!(result.replace(*id), None);
|
||||
assert_eq!(result.replace(id), None);
|
||||
}
|
||||
});
|
||||
result.unwrap()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue