diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index d3c4b73..e7be3b7 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -14,6 +14,7 @@ spirv_headers = "1.4" quick-error = "1.2" bit-vec = "0.6" paste = "0.1" +half ="1.6" [build-dependencies.lalrpop] version = "0.18.1" diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index c7cb7f7..0efc37c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -243,15 +243,74 @@ pub struct MovData { pub typ: Type, } -pub struct MulData {} +pub struct MulData { + pub typ: Type, + pub desc: MulDescriptor, +} + +pub enum MulDescriptor { + Int(MulIntControl), + Float(MulFloatDesc), +} + +pub enum MulIntControl { + Low, + High, + Wide +} + +pub struct MulFloatDesc { + pub rounding: Option, + pub flush_to_zero: bool, + pub saturate: bool, +} + +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf +} pub struct AddData { pub typ: ScalarType, } -pub struct SetpData {} +pub struct SetpData { + pub typ: ScalarType, + pub flush_to_zero: bool, + pub cmp_op: SetpCompareOp, +} -pub struct SetpBoolData {} +pub enum SetpCompareOp { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, + NanEq, + NanNotEq, + NanLess, + NanLessOrEq, + NanGreater, + NanGreaterOrEq, + IsNotNan, + IsNan, +} + +pub enum SetpBoolPostOp { + And, + Or, + Xor, +} + +pub struct SetpBoolData { + pub typ: ScalarType, + pub flush_to_zero: bool, + pub cmp_op: SetpCompareOp, + pub bool_op: SetpBoolPostOp +} pub struct NotData {} diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 5402326..15302ff 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -12,6 +12,7 @@ extern crate level_zero as ze; extern crate level_zero_sys as l0; extern crate rspirv; extern crate spirv_headers as spirv; +extern crate half; #[cfg(test)] extern crate spirv_tools_sys as spirv_tools; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 64d7725..b44702d 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -399,20 +399,56 @@ InstMul: ast::Instruction<&'input str> = { }; InstMulMode: ast::MulData = { - MulIntControl? IntType => ast::MulData{}, - RoundingMode? ".ftz"? ".sat"? ".f32" => ast::MulData{}, - RoundingMode? ".f64" => ast::MulData{}, - ".rn"? ".ftz"? ".sat"? ".f16" => ast::MulData{}, - ".rn"? ".ftz"? ".sat"? ".f16x2" => ast::MulData{} + => ast::MulData{ + typ: ast::Type::Scalar(t), + desc: ast::MulDescriptor::Int(ctr) + }, + ".f32" => ast::MulData{ + typ: ast::Type::Scalar(ast::ScalarType::F32), + desc: ast::MulDescriptor::Float(ast::MulFloatDesc { + rounding: r, + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }) + }, + ".f64" => ast::MulData{ + typ: ast::Type::Scalar(ast::ScalarType::F64), + desc: ast::MulDescriptor::Float(ast::MulFloatDesc { + rounding: r, + flush_to_zero: false, + saturate: false + }) + }, + ".f16" => ast::MulData{ + typ: ast::Type::Scalar(ast::ScalarType::F16), + desc: ast::MulDescriptor::Float(ast::MulFloatDesc { + rounding: r.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }) + }, + ".f16x2" => ast::MulData{ + typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2), + desc: ast::MulDescriptor::Float(ast::MulFloatDesc { + rounding: r.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: s.is_some() + }) + } }; -MulIntControl = { - ".hi", ".lo", ".wide" +MulIntControl: ast::MulIntControl = { + ".hi" => ast::MulIntControl::High, + ".lo" => ast::MulIntControl::Low, + ".wide" => ast::MulIntControl::Wide }; #[inline] -RoundingMode = { - ".rn", ".rz", ".rm", ".rp" +RoundingMode : ast::RoundingMode = { + ".rn" => ast::RoundingMode::NearestEven, + ".rz" => ast::RoundingMode::Zero, + ".rm" => ast::RoundingMode::NegativeInf, + ".rp" => ast::RoundingMode::PositiveInf, }; IntType : ast::ScalarType = { @@ -449,27 +485,61 @@ InstSetp: ast::Instruction<&'input str> = { }; SetpMode: ast::SetpData = { - SetpCmpOp ".ftz"? SetpType => ast::SetpData{} + => ast::SetpData{ + typ: t, + flush_to_zero: ftz.is_some(), + cmp_op: cmp_op, + } }; SetpBoolMode: ast::SetpBoolData = { - SetpCmpOp SetpBoolOp ".ftz"? SetpType => ast::SetpBoolData{} + => ast::SetpBoolData{ + typ: t, + flush_to_zero: ftz.is_some(), + cmp_op: cmp_op, + bool_op: bool_op, + } }; -SetpCmpOp = { - ".eq", ".ne", ".lt", ".le", ".gt", ".ge", ".lo", ".ls", ".hi", ".hs", - ".equ", ".neu", ".ltu", ".leu", ".gtu", ".geu", ".num", ".nan" +SetpCompareOp: ast::SetpCompareOp = { + ".eq" => ast::SetpCompareOp::Eq, + ".ne" => ast::SetpCompareOp::NotEq, + ".lt" => ast::SetpCompareOp::Less, + ".le" => ast::SetpCompareOp::LessOrEq, + ".gt" => ast::SetpCompareOp::Greater, + ".ge" => ast::SetpCompareOp::GreaterOrEq, + ".lo" => ast::SetpCompareOp::Less, + ".ls" => ast::SetpCompareOp::LessOrEq, + ".hi" => ast::SetpCompareOp::Greater, + ".hs" => ast::SetpCompareOp::GreaterOrEq, + ".equ" => ast::SetpCompareOp::NanEq, + ".neu" => ast::SetpCompareOp::NanNotEq, + ".ltu" => ast::SetpCompareOp::NanLess, + ".leu" => ast::SetpCompareOp::NanLessOrEq, + ".gtu" => ast::SetpCompareOp::NanGreater, + ".geu" => ast::SetpCompareOp::NanGreaterOrEq, + ".num" => ast::SetpCompareOp::IsNotNan, + ".nan" => ast::SetpCompareOp::IsNan, }; -SetpBoolOp = { - ".and", ".or", ".xor" +SetpBoolPostOp: ast::SetpBoolPostOp = { + ".and" => ast::SetpBoolPostOp::And, + ".or" => ast::SetpBoolPostOp::Or, + ".xor" => ast::SetpBoolPostOp::Xor, }; -SetpType = { - ".b16", ".b32", ".b64", - ".u16", ".u32", ".u64", - ".s16", ".s32", ".s64", - ".f32", ".f64" +SetpType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index b573f2c..b374324 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -37,7 +37,9 @@ macro_rules! test_ptx { } test_ptx!(ld_st, [1u64], [1u64]); -test_ptx!(mov, [1u64], [1u64]); +//test_ptx!(mov, [1u64], [1u64]); +//test_ptx!(mul_lo, [1u64], [2u64]); +//test_ptx!(mul_hi, [u64::max_value()], [1u64]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/mul_hi.ptx b/ptx/src/test/spirv_run/mul_hi.ptx new file mode 100644 index 0000000..1dc1572 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_hi.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul_hi( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + mul.hi.u64 temp2, temp, 2; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/mul_hi.spvtxt b/ptx/src/test/spirv_run/mul_hi.spvtxt new file mode 100644 index 0000000..db8943f --- /dev/null +++ b/ptx/src/test/spirv_run/mul_hi.spvtxt @@ -0,0 +1,26 @@ +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpCapability Int8 +%1 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %5 "mul_hi" +%2 = OpTypeVoid +%3 = OpTypeInt 64 0 +%4 = OpTypeFunction %2 %3 %3 +%19 = OpTypePointer Generic %3 +%5 = OpFunction %2 None %4 +%6 = OpFunctionParameter %3 +%7 = OpFunctionParameter %3 +%18 = OpLabel +%13 = OpCopyObject %3 %6 +%14 = OpCopyObject %3 %7 +%15 = OpConvertUToPtr %19 %13 +%16 = OpLoad %3 %15 +%100 = OpCopyObject %3 %16 +%17 = OpConvertUToPtr %19 %14 +OpStore %17 %100 +OpReturn +OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_lo.ptx b/ptx/src/test/spirv_run/mul_lo.ptx new file mode 100644 index 0000000..cae3b57 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_lo.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mul_lo( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + mul.lo.u64 temp2, temp, 2; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/mul_lo.spvtxt b/ptx/src/test/spirv_run/mul_lo.spvtxt new file mode 100644 index 0000000..66e7bc1 --- /dev/null +++ b/ptx/src/test/spirv_run/mul_lo.spvtxt @@ -0,0 +1,26 @@ +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int64 +OpCapability Int8 +%1 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %5 "mul_lo" +%2 = OpTypeVoid +%3 = OpTypeInt 64 0 +%4 = OpTypeFunction %2 %3 %3 +%19 = OpTypePointer Generic %3 +%5 = OpFunction %2 None %4 +%6 = OpFunctionParameter %3 +%7 = OpFunctionParameter %3 +%18 = OpLabel +%13 = OpCopyObject %3 %6 +%14 = OpCopyObject %3 %7 +%15 = OpConvertUToPtr %19 %13 +%16 = OpLoad %3 %15 +%100 = OpCopyObject %3 %16 +%17 = OpConvertUToPtr %19 %14 +OpStore %17 %100 +OpReturn +OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ee28bb7..6620666 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -18,7 +18,7 @@ impl From for SpirvType { fn from(t: ast::Type) -> Self { match t { ast::Type::Scalar(t) => SpirvType::Base(t), - ast::Type::ExtendedScalar(t) => SpirvType::Extended(t) + ast::Type::ExtendedScalar(t) => SpirvType::Extended(t), } } } @@ -60,7 +60,11 @@ impl TypeWordMap { }) } - fn get_or_add_extended(&mut self, b: &mut dr::Builder, t: ast::ExtendedScalarType) -> spirv::Word { + fn get_or_add_extended( + &mut self, + b: &mut dr::Builder, + t: ast::ExtendedScalarType, + ) -> spirv::Word { *self .complex .entry(SpirvType::Extended(t)) @@ -178,8 +182,9 @@ fn to_ssa<'a>( let registers = collect_var_definitions(&f_args, &f_body); let (normalized_ids, unique_ids) = normalize_identifiers(f_body, &contant_ids, &mut type_check, registers); + let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids); let (mut func_body, unique_ids) = - insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]); + insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]); let bbs = get_basic_blocks(&func_body); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); @@ -195,6 +200,221 @@ fn to_ssa<'a>( (func_body, bbs, phis, unique_ids) } +fn normalize_statements( + func: Vec>, + unique_ids: spirv::Word, +) -> (Vec, spirv::Word) { + let mut result = Vec::with_capacity(func.len()); + let mut id = unique_ids; + let new_id = &mut || { + let to_insert = id; + id += 1; + to_insert + }; + for s in func { + match s { + ast::Statement::Label(id) => result.push(Statement::Label(id)), + ast::Statement::Instruction(pred, inst) => { + if let Some(pred) = pred { + let mut if_true = new_id(); + let mut if_false = new_id(); + if pred.not { + std::mem::swap(&mut if_true, &mut if_false); + } + let folded_bra = match &inst { + ast::Instruction::Bra(_, arg) => Some(arg.src), + _ => None, + }; + let branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + let instr = normalize_insert_instruction(&mut result, new_id, inst); + result.push(Statement::Instruction(instr)); + } + result.push(Statement::Label(if_false)); + } else { + let instr = normalize_insert_instruction(&mut result, new_id, inst); + result.push(Statement::Instruction(instr)); + } + } + ast::Statement::Variable(_) => unreachable!(), + } + } + (result, id) +} + +#[must_use] +fn normalize_insert_instruction( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + instr: ast::Instruction, +) -> Instruction { + match instr { + ast::Instruction::Ld(d, a) => { + let arg = normalize_expand_arg2(func, new_id, &|| Some(d.typ), a); + Instruction::Ld(d, arg) + } + ast::Instruction::Mov(d, a) => { + let arg = normalize_expand_arg2mov(func, new_id, &|| d.typ.try_as_scalar(), a); + Instruction::Mov(d, arg) + } + ast::Instruction::Mul(d, a) => { + let arg = normalize_expand_arg3(func, new_id, &|| d.typ.try_as_scalar(), a); + Instruction::Mul(d, arg) + } + ast::Instruction::Add(d, a) => { + let arg = normalize_expand_arg3(func, new_id, &|| Some(d.typ), a); + Instruction::Add(d, arg) + } + ast::Instruction::Setp(d, a) => { + let arg = normalize_expand_arg4(func, new_id, &|| Some(d.typ), a); + Instruction::Setp(d, arg) + } + ast::Instruction::SetpBool(d, a) => { + let arg = normalize_expand_arg5(func, new_id, &|| Some(d.typ), a); + Instruction::SetpBool(d, arg) + } + ast::Instruction::Not(d, a) => { + let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a); + Instruction::Not(d, arg) + } + ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }), + ast::Instruction::Cvt(d, a) => { + let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a); + Instruction::Cvt(d, arg) + } + ast::Instruction::Shl(d, a) => { + let arg = normalize_expand_arg3(func, new_id, &|| todo!(), a); + Instruction::Shl(d, arg) + } + ast::Instruction::St(d, a) => { + let arg = normalize_expand_arg2st(func, new_id, &|| todo!(), a); + Instruction::St(d, arg) + } + ast::Instruction::Ret(d) => Instruction::Ret(d), + } +} + +fn normalize_expand_arg2( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg2, +) -> Arg2 { + Arg2 { + dst: a.dst, + src: normalize_expand_operand(func, new_id, inst_type, a.src), + } +} + +fn normalize_expand_arg2mov( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg2Mov, +) -> Arg2 { + Arg2 { + dst: a.dst, + src: normalize_expand_mov_operand(func, new_id, inst_type, a.src), + } +} + +fn normalize_expand_arg2st( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg2St, +) -> Arg2St { + Arg2St { + src1: normalize_expand_operand(func, new_id, inst_type, a.src1), + src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + } +} + +fn normalize_expand_arg3( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg3, +) -> Arg3 { + Arg3 { + dst: a.dst, + src1: normalize_expand_operand(func, new_id, inst_type, a.src1), + src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + } +} + +fn normalize_expand_arg4( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg4, +) -> Arg4 { + Arg4 { + dst1: a.dst1, + dst2: a.dst2, + src1: normalize_expand_operand(func, new_id, inst_type, a.src1), + src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + } +} + +fn normalize_expand_arg5( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + a: ast::Arg5, +) -> Arg5 { + Arg5 { + dst1: a.dst1, + dst2: a.dst2, + src1: normalize_expand_operand(func, new_id, inst_type, a.src1), + src2: normalize_expand_operand(func, new_id, inst_type, a.src2), + src3: normalize_expand_operand(func, new_id, inst_type, a.src3), + } +} + +fn normalize_expand_operand( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + opr: ast::Operand, +) -> spirv::Word { + match opr { + ast::Operand::Reg(r) => r, + ast::Operand::Imm(x) => { + if let Some(typ) = inst_type() { + let id = new_id(); + func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: typ, + value: x, + })); + id + } else { + todo!() + } + } + _ => todo!(), + } +} + +fn normalize_expand_mov_operand( + func: &mut Vec, + new_id: &mut impl FnMut() -> spirv::Word, + inst_type: &impl Fn() -> Option, + opr: ast::MovOperand, +) -> spirv::Word { + match opr { + ast::MovOperand::Op(opr) => normalize_expand_operand(func, new_id, inst_type, opr), + _ => todo!(), + } +} + fn collect_var_definitions<'a>( args: &[ast::Argument<'a>], body: &[ast::Statement<&'a str>], @@ -249,17 +469,15 @@ fn insert_implicit_conversions ast::Type>( for s in normalized_ids.into_iter() { match s { Statement::Instruction(inst) => match inst { - ast::Instruction::Ld(ld, mut arg) => { - arg.src = arg.src.map_id(&mut |arg_src| { - insert_implicit_conversions_ld_src( - &mut result, - ast::Type::Scalar(ld.typ), - type_check, - new_id, - ld.state_space, - arg_src, - ) - }); + Instruction::Ld(ld, mut arg) => { + arg.src = insert_implicit_conversions_ld_src( + &mut result, + ast::Type::Scalar(ld.typ), + type_check, + new_id, + ld.state_space, + arg.src, + ); insert_with_implicit_conversion_dst( &mut result, ld.typ, @@ -268,40 +486,35 @@ fn insert_implicit_conversions ast::Type>( should_convert_relaxed_dst, arg, |arg| &mut arg.dst, - |arg| ast::Instruction::Ld(ld, arg), + |arg| Instruction::Ld(ld, arg), ); } - ast::Instruction::St(st, mut arg) => { - arg.src2 = arg.src2.map_id(&mut |arg_src| { - let arg_src_type = type_check(arg_src); - if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) { - insert_conversion_src( - &mut result, - new_id, - arg_src, - arg_src_type, - ast::Type::Scalar(st.typ), - conv, - ) - } else { - arg_src - } - }); - arg.src1 = arg.src1.map_id(&mut |arg_src| { - insert_implicit_conversions_ld_src( + Instruction::St(st, mut arg) => { + let arg_src2_type = type_check(arg.src2); + if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) { + arg.src2 = insert_conversion_src( &mut result, - ast::Type::Scalar(st.typ), - type_check, new_id, - st.state_space.to_ld_ss(), - arg_src, - ) - }); - result.push(Statement::Instruction(ast::Instruction::St(st, arg))); + arg.src2, + arg_src2_type, + ast::Type::Scalar(st.typ), + conv, + ); + } + arg.src1 = insert_implicit_conversions_ld_src( + &mut result, + ast::Type::Scalar(st.typ), + type_check, + new_id, + st.state_space.to_ld_ss(), + arg.src1, + ); + result.push(Statement::Instruction(Instruction::St(st, arg))); } inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst), }, s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), + Statement::Constant(_) => (), Statement::Converison(_) => unreachable!(), } } @@ -390,61 +603,52 @@ fn emit_function_body_ops( // If block starts with a label it has already been emitted, // all other labels in the block are unused Statement::Label(_) => (), + Statement::Constant(_) => todo!(), Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } Statement::Instruction(inst) => match inst { // SPIR-V does not support marking jumps as guaranteed-converged - ast::Instruction::Bra(_, arg) => { + Instruction::Bra(_, arg) => { builder.branch(arg.src)?; } - ast::Instruction::Ld(data, arg) => { + Instruction::Ld(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() { todo!() } - let src = match arg.src { - ast::Operand::Reg(id) => id, - _ => todo!(), - }; let result_type = map.get_or_add_scalar(builder, data.typ); match data.state_space { ast::LdStateSpace::Generic => { - builder.load(result_type, Some(arg.dst), src, None, [])?; + builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param => { - builder.copy_object(result_type, Some(arg.dst), src)?; + builder.copy_object(result_type, Some(arg.dst), arg.src)?; } _ => todo!(), } } - ast::Instruction::St(data, arg) => { + Instruction::St(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() || data.state_space != ast::StStateSpace::Generic { todo!() } - let dst = match arg.src1 { - ast::Operand::Reg(id) => id, - _ => todo!(), - }; - let src = match arg.src2 { - ast::Operand::Reg(id) => id, - _ => todo!(), - }; - builder.store(dst, src, None, &[])?; + builder.store(arg.src1, arg.src2, None, &[])?; } // SPIR-V does not support ret as guaranteed-converged - ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(mov, arg) => { + Instruction::Ret(_) => builder.ret()?, + Instruction::Mov(mov, arg) => { let result_type = map.get_or_add(builder, SpirvType::from(mov.typ)); - let src = match arg.src { - ast::MovOperand::Op(ast::Operand::Reg(id)) => id, - _ => todo!(), - }; - builder.copy_object(result_type, Some(arg.dst), src)?; + builder.copy_object(result_type, Some(arg.dst), arg.src)?; } + Instruction::Mul(mul, arg) => match mul.desc { + ast::MulDescriptor::Int(ref ctr) => { + emit_mul_int(builder, map, mul.typ, ctr, arg) + } + ast::MulDescriptor::Float(_) => todo!(), + }, _ => todo!(), }, } @@ -453,6 +657,17 @@ fn emit_function_body_ops( Ok(()) } +fn emit_mul_int( + _builder: &mut dr::Builder, + _map: &mut TypeWordMap, + _typ: ast::Type, + _ctr: &ast::MulIntControl, + _arg: &Arg3, +) { + //let inst_type = map.get_or_add(builder, SpirvType::from(typ)); + //builder.i_mul(inst_type, Some(arg.dst), Some(arg.src1), Some(arg.src2)); +} + fn emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -523,12 +738,11 @@ fn normalize_identifiers<'a>( constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined type_map: &mut HashMap, types: HashMap, ast::Type>, -) -> (Vec, spirv::Word) { - let mut result = Vec::with_capacity(func.len()); +) -> (Vec>, spirv::Word) { let mut id: u32 = constant_identifiers.len() as u32; let mut remapped_ids = HashMap::new(); - let mut get_or_add = |key| match key { - Some(key) => constant_identifiers.get(key).map_or_else( + let mut get_or_add = |key| { + constant_identifiers.get(key).map_or_else( || { *remapped_ids.entry(key).or_insert_with(|| { let to_insert = id; @@ -537,16 +751,12 @@ fn normalize_identifiers<'a>( }) }, |id| *id, - ), - None => { - let to_insert = id; - id += 1; - to_insert - } + ) }; - for s in func { - Statement::from_ast(s, &mut result, &mut get_or_add); - } + let result = func + .into_iter() + .filter_map(|s| Statement::from_ast(s, &mut get_or_add)) + .collect::>(); type_map.extend( remapped_ids .into_iter() @@ -594,7 +804,7 @@ fn apply_ssa_renaming( for s in get_bb_body(func, bbs, BBIndex(bb)) { s.visit_id(&mut |is_dst, id| { if is_dst { - old_dst_id[bb].push(*id) + old_dst_id[bb].push(id) } }); } @@ -787,8 +997,8 @@ fn gather_phi_sets( let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize]; for bb in 0..cfg.len() { let mut var_kill = HashSet::new(); - let mut visitor = |is_dst, id: &u32| { - if *id >= constant_ids { + let mut visitor = |is_dst, id: spirv::Word| { + if id >= constant_ids { let id = id - constant_ids; if is_dst { var_kill.insert(id); @@ -807,8 +1017,9 @@ fn gather_phi_sets( for s in get_bb_body(func, cfg, BBIndex(bb)) { match s { Statement::Instruction(inst) => inst.visit_id(&mut visitor), - Statement::Conditional(brc) => visitor(false, &brc.predicate), + Statement::Conditional(brc) => visitor(false, brc.predicate), Statement::Converison(conv) => conv.visit_id(&mut visitor), + Statement::Constant(cons) => cons.visit_id(&mut visitor), // label redefinition is a compile-time error Statement::Label(_) => (), } @@ -859,6 +1070,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec { unresolved_bb_edge.push((StmtIndex(idx), bra.if_false)); unresolved_bb_edge.push((StmtIndex(idx), bra.if_true)); } + Statement::Constant(_) => (), Statement::Converison(_) => (), }; } @@ -877,7 +1089,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec { bb_edge.insert((StmtIndex(target.0 - 1), target)); } } - Statement::Converison(_) | Statement::Label(_) => { + Statement::Converison(_) | Statement::Constant(_) | Statement::Label(_) => { bb_edge.insert((StmtIndex(target.0 - 1), target)); } // This is already in `unresolved_bb_edge` @@ -1043,10 +1255,241 @@ impl fmt::Display for BBIndex { enum Statement { Label(u32), - Instruction(ast::Instruction), + Instruction(Instruction), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Converison(ImplicitConversion), + Constant(ConstantDefinition), +} + +enum Instruction { + Ld(ast::LdData, Arg2), + Mov(ast::MovData, Arg2), + Mul(ast::MulData, Arg3), + Add(ast::AddData, Arg3), + Setp(ast::SetpData, Arg4), + SetpBool(ast::SetpBoolData, Arg5), + Not(ast::NotData, Arg2), + Bra(ast::BraData, Arg1), + Cvt(ast::CvtData, Arg2), + Shl(ast::ShlData, Arg3), + St(ast::StData, Arg2St), + Ret(ast::RetData), +} + +impl Instruction { + fn visit_id(&self, f: &mut F) { + match self { + Instruction::Ld(_, a) => a.visit_id(f), + Instruction::Mov(_, a) => a.visit_id(f), + Instruction::Mul(_, a) => a.visit_id(f), + Instruction::Add(_, a) => a.visit_id(f), + Instruction::Setp(_, a) => a.visit_id(f), + Instruction::SetpBool(_, a) => a.visit_id(f), + Instruction::Not(_, a) => a.visit_id(f), + Instruction::Cvt(_, a) => a.visit_id(f), + Instruction::Shl(_, a) => a.visit_id(f), + Instruction::St(_, a) => a.visit_id(f), + Instruction::Bra(_, a) => a.visit_id(f), + Instruction::Ret(_) => (), + } + } + + fn visit_id_mut(&mut self, f: &mut F) { + match self { + Instruction::Ld(_, a) => a.visit_id_mut(f), + Instruction::Mov(_, a) => a.visit_id_mut(f), + Instruction::Mul(_, a) => a.visit_id_mut(f), + Instruction::Add(_, a) => a.visit_id_mut(f), + Instruction::Setp(_, a) => a.visit_id_mut(f), + Instruction::SetpBool(_, a) => a.visit_id_mut(f), + Instruction::Not(_, a) => a.visit_id_mut(f), + Instruction::Cvt(_, a) => a.visit_id_mut(f), + Instruction::Shl(_, a) => a.visit_id_mut(f), + Instruction::St(_, a) => a.visit_id_mut(f), + Instruction::Bra(_, a) => a.visit_id_mut(f), + Instruction::Ret(_) => (), + } + } + + fn get_type(&self) -> Option { + match self { + Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)), + Instruction::Ret(_) => None, + Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), + Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), + Instruction::Mov(mov, _) => Some(mov.typ), + Instruction::Mul(mul, _) => Some(mul.typ), + _ => todo!(), + } + } + + fn jump_target(&self) -> Option { + match self { + Instruction::Bra(_, a) => Some(a.src), + Instruction::Ld(_, _) + | Instruction::Mov(_, _) + | Instruction::Mul(_, _) + | Instruction::Add(_, _) + | Instruction::Setp(_, _) + | Instruction::SetpBool(_, _) + | Instruction::Not(_, _) + | Instruction::Cvt(_, _) + | Instruction::Shl(_, _) + | Instruction::St(_, _) + | Instruction::Ret(_) => None, + } + } + + fn is_terminal(&self) -> bool { + match self { + Instruction::Ret(_) => true, + Instruction::Ld(_, _) + | Instruction::Mov(_, _) + | Instruction::Mul(_, _) + | Instruction::Add(_, _) + | Instruction::Setp(_, _) + | Instruction::SetpBool(_, _) + | Instruction::Not(_, _) + | Instruction::Cvt(_, _) + | Instruction::Shl(_, _) + | Instruction::St(_, _) + | Instruction::Bra(_, _) => false, + } + } +} + +struct Arg1 { + pub src: spirv::Word, +} + +impl Arg1 { + fn visit_id(&self, f: &mut F) { + f(false, self.src); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src); + } +} + +struct Arg2 { + pub dst: spirv::Word, + pub src: spirv::Word, +} + +impl Arg2 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); + f(false, self.src); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src); + f(true, &mut self.dst); + } +} + +pub struct Arg2St { + pub src1: spirv::Word, + pub src2: spirv::Word, +} + +impl Arg2St { + fn visit_id(&self, f: &mut F) { + f(false, self.src1); + f(false, self.src2); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src1); + f(false, &mut self.src2); + } +} + +struct Arg3 { + pub dst: spirv::Word, + pub src1: spirv::Word, + pub src2: spirv::Word, +} + +impl Arg3 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); + f(false, self.src1); + f(false, self.src2); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src1); + f(false, &mut self.src2); + f(true, &mut self.dst); + } +} + +struct Arg4 { + pub dst1: spirv::Word, + pub dst2: Option, + pub src1: spirv::Word, + pub src2: spirv::Word, +} + +impl Arg4 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst1); + self.dst2.map(|dst2| f(true, dst2)); + f(false, self.src1); + f(false, self.src2); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src1); + f(false, &mut self.src2); + f(true, &mut self.dst1); + self.dst2.as_mut().map(|dst2| f(true, dst2)); + } +} + +struct Arg5 { + pub dst1: spirv::Word, + pub dst2: Option, + pub src1: spirv::Word, + pub src2: spirv::Word, + pub src3: spirv::Word, +} + +impl Arg5 { + fn visit_id(&self, f: &mut F) { + f(true, self.dst1); + self.dst2.map(|dst2| f(true, dst2)); + f(false, self.src1); + f(false, self.src2); + f(false, self.src3); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src1); + f(false, &mut self.src2); + f(false, &mut self.src3); + f(true, &mut self.dst1); + self.dst2.as_mut().map(|dst2| f(true, dst2)); + } +} + +struct ConstantDefinition { + pub dst: spirv::Word, + pub typ: ast::ScalarType, + pub value: i128, +} + +impl ConstantDefinition { + fn visit_id(&self, f: &mut F) { + f(true, self.dst); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(true, &mut self.dst); + } } struct BrachCondition { @@ -1056,10 +1499,10 @@ struct BrachCondition { } impl BrachCondition { - fn visit_id(&self, f: &mut F) { - f(false, &self.predicate); - f(false, &self.if_true); - f(false, &self.if_false); + fn visit_id(&self, f: &mut F) { + f(false, self.predicate); + f(false, self.if_true); + f(false, self.if_false); } fn visit_id_mut(&mut self, f: &mut F) { @@ -1086,9 +1529,9 @@ enum ConversionKind { } impl ImplicitConversion { - fn visit_id(&self, f: &mut F) { - f(false, &self.src); - f(true, &self.dst); + fn visit_id(&self, f: &mut F) { + f(false, self.src); + f(true, self.dst); } fn visit_id_mut(&mut self, f: &mut F) { @@ -1098,54 +1541,27 @@ impl ImplicitConversion { } impl Statement { - fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>( + fn from_ast<'a, F: FnMut(&'a str) -> u32>( s: ast::Statement<&'a str>, - out: &mut Vec, get_id: &mut F, - ) { + ) -> Option> { match s { - ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))), - ast::Statement::Instruction(p, i) => { - if let Some(pred) = p { - let predicate = get_id(Some(pred.label)); - let mut if_true = get_id(None); - let mut if_false = get_id(None); - if pred.not { - std::mem::swap(&mut if_true, &mut if_false); - } - let folded_bra = match &i { - ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))), - _ => None, - }; - let branch = BrachCondition { - predicate, - if_true: folded_bra.unwrap_or(if_true), - if_false, - }; - out.push(Statement::Conditional(branch)); - if folded_bra.is_none() { - out.push(Statement::Label(if_true)); - out.push(Statement::Instruction( - i.map_id(&mut |name| get_id(Some(name))), - )); - } - out.push(Statement::Label(if_false)); - } else { - out.push(Statement::Instruction( - i.map_id(&mut |name| get_id(Some(name))), - )); - } - } - ast::Statement::Variable(_) => (), + ast::Statement::Label(name) => Some(ast::Statement::Label(get_id(name))), + ast::Statement::Instruction(p, i) => Some(ast::Statement::Instruction( + p.map(|p| p.map_id(get_id)), + i.map_id(get_id), + )), + ast::Statement::Variable(_) => None, } } - fn visit_id(&self, f: &mut F) { + fn visit_id(&self, f: &mut F) { match self { - Statement::Label(id) => f(false, id), + Statement::Label(id) => f(false, *id), Statement::Instruction(inst) => inst.visit_id(f), Statement::Conditional(bra) => bra.visit_id(f), Statement::Converison(conv) => conv.visit_id(f), + Statement::Constant(cons) => cons.visit_id(f), } } @@ -1157,6 +1573,16 @@ impl Statement { Statement::Instruction(inst) => inst.visit_id_mut(f), Statement::Conditional(bra) => bra.visit_id_mut(f), Statement::Converison(conv) => conv.visit_id_mut(f), + Statement::Constant(cons) => cons.visit_id_mut(f), + } + } +} + +impl ast::PredAt { + fn map_id U>(self, f: &mut F) -> ast::PredAt { + ast::PredAt { + not: self.not, + label: f(self.label), } } } @@ -1220,7 +1646,8 @@ impl ast::Instruction { ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)), ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)), ast::Instruction::Mov(mov, _) => Some(mov.typ), - _ => todo!() + ast::Instruction::Mul(mul, _) => Some(mul.typ), + _ => todo!(), } } } @@ -1476,6 +1903,15 @@ enum ScalarKind { Float, } +impl ast::Type { + fn try_as_scalar(self) -> Option { + match self { + ast::Type::Scalar(s) => Some(s), + ast::Type::ExtendedScalar(_) => None, + } + } +} + impl ast::ScalarType { fn width(self) -> u8 { match self { @@ -1688,7 +2124,7 @@ fn insert_with_implicit_conversion_dst< NewId: FnMut() -> spirv::Word, ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, Setter: Fn(&mut T) -> &mut spirv::Word, - ToInstruction: FnOnce(T) -> ast::Instruction, + ToInstruction: FnOnce(T) -> Instruction, >( func: &mut Vec, instr_type: ast::ScalarType, @@ -1821,7 +2257,7 @@ fn insert_implicit_bitcasts< func: &mut Vec, type_check: &TypeCheck, new_id: &mut NewId, - mut instr: ast::Instruction, + mut instr: Instruction, ) { let mut dst_coercion = None; if let Some(instr_type) = instr.get_type() { @@ -1984,9 +2420,9 @@ mod tests { fn get_basic_blocks_miniloop() { let func = vec![ Statement::Label(12), - Statement::Instruction(ast::Instruction::Bra( + Statement::Instruction(Instruction::Bra( ast::BraData { uniform: false }, - ast::Arg1 { src: 12 }, + Arg1 { src: 12 }, )), ]; let bbs = get_basic_blocks(&func); @@ -2226,9 +2662,10 @@ mod tests { let mut constant_ids = HashMap::new(); collect_label_ids(&mut constant_ids, &ast); let registers = collect_var_definitions(&[], &ast); - let (normalized_ids, _) = + let (normalized_ids, unique_ids) = normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers); - let mut bbs = get_basic_blocks(&normalized_ids); + let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids); + let mut bbs = get_basic_blocks(&normalized_stmts); bbs.iter_mut().for_each(sort_pred_succ); assert_eq!( bbs, @@ -2239,32 +2676,32 @@ mod tests { succ: vec![BBIndex(1)], }, BasicBlock { - start: StmtIndex(3), + start: StmtIndex(6), pred: vec![BBIndex(0), BBIndex(5)], succ: vec![BBIndex(2), BBIndex(6)], }, BasicBlock { - start: StmtIndex(6), + start: StmtIndex(10), pred: vec![BBIndex(1)], succ: vec![BBIndex(3), BBIndex(4)], }, BasicBlock { - start: StmtIndex(9), + start: StmtIndex(14), pred: vec![BBIndex(2)], succ: vec![BBIndex(5)], }, BasicBlock { - start: StmtIndex(13), + start: StmtIndex(19), pred: vec![BBIndex(2)], succ: vec![BBIndex(5)], }, BasicBlock { - start: StmtIndex(16), + start: StmtIndex(23), pred: vec![BBIndex(3), BBIndex(4)], succ: vec![BBIndex(1)], }, BasicBlock { - start: StmtIndex(18), + start: StmtIndex(25), pred: vec![BBIndex(1)], succ: vec![], }, @@ -2375,14 +2812,15 @@ mod tests { collect_label_ids(&mut constant_ids, &fn_ast); assert_eq!(constant_ids.len(), 4); let registers = collect_var_definitions(&[], &fn_ast); - let (normalized_ids, max_id) = + let (normalized_ids, unique_ids) = normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers); - let bbs = get_basic_blocks(&normalized_ids); + let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids); + let bbs = get_basic_blocks(&normalized_stmts); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); let phi = gather_phi_sets( - &normalized_ids, + &normalized_stmts, constant_ids.len() as u32, max_id, &bbs, @@ -2490,7 +2928,7 @@ mod tests { for s in func { s.visit_id(&mut |is_dst, id| { if is_dst { - assert!(seen.insert(*id)); + assert!(seen.insert(id)); } }); } @@ -2504,7 +2942,7 @@ mod tests { fn get_ids(s: &Statement) -> Vec { let mut result = Vec::new(); s.visit_id(&mut |_, id| { - result.push(*id); + result.push(id); }); result } @@ -2533,7 +2971,7 @@ mod tests { let mut result = None; s.visit_id(&mut |is_dst, id| { if is_dst { - assert_eq!(result.replace(*id), None); + assert_eq!(result.replace(id), None); } }); result.unwrap()