Add stateless-to-stateful conversion

This commit is contained in:
Andrzej Janik 2024-08-26 18:31:06 +02:00
parent 107f1eb17f
commit 3e0a15ac84
5 changed files with 627 additions and 21 deletions

View file

@ -0,0 +1,535 @@
use super::*;
use ptx_parser as ast;
use std::{
collections::{BTreeSet, HashSet},
iter,
rc::Rc,
};
/*
Our goal here is to transform
.visible .entry foobar(.param .u64 input) {
.reg .b64 in_addr;
.reg .b64 in_addr2;
ld.param.u64 in_addr, [input];
cvta.to.global.u64 in_addr2, in_addr;
}
into:
.visible .entry foobar(.param .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
ld.param.u8[] in_addr, [input];
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.reg .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
mov.u8[] in_addr, input;
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.param ptr<u8, global> input) {
.reg ptr<u8, global> in_addr;
.reg ptr<u8, global> in_addr2;
ld.param.ptr<u8, global> in_addr, [input];
mov.ptr<u8, global> in_addr2, in_addr;
}
*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate out of calls and into calls
/// Converts stateless 64-bit address arithmetic in a kernel into operations on
/// explicitly typed `ptr<u8, global>` values.
///
/// Only kernels (`ast::MethodName::Kernel`) are transformed; any other method
/// is handed back unchanged. For a kernel the pass:
///
/// 1. records every `cvta.to.global` whose destination and source are 64-bit
///    integer registers, and every `ld.param` of a 64-bit integer kernel
///    argument into a register,
/// 2. treats an argument as a pointer when a register loaded from it is the
///    source of such a `cvta`, and marks both `cvta` registers as pointers,
/// 3. propagates pointer-ness through non-saturating `u64`/`s64` `add`/`sub`,
/// 4. declares a `ptr<u8, global>` replacement for every pointer register and
///    retypes (and renames) the affected kernel arguments,
/// 5. re-emits the body, turning pointer `add`/`sub` into `PtrAccess`
///    statements and inserting implicit conversions around every other use of
///    a remapped id.
///
/// # Errors
///
/// Returns `error_unreachable()` when `func_args` has more than one strong
/// reference, when the body contains a statement kind this pass does not
/// expect, or when an internal invariant is violated. Errors from `id_defs`
/// lookups and from operand visiting are propagated.
pub(super) fn run<'a, 'input>(
    func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
    func_body: Vec<TypedStatement>,
    id_defs: &mut NumericIdResolver<'a>,
) -> Result<
    (
        Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
        Vec<TypedStatement>,
    ),
    TranslateError,
> {
    let mut method_decl = func_args.borrow_mut();
    if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
        drop(method_decl);
        return Ok((func_args, func_body));
    }
    // The declaration is mutated in place below, so nobody else may hold it.
    if Rc::strong_count(&func_args) != 1 {
        return Err(error_unreachable());
    }
    // Kernel arguments that are wide enough to carry an address.
    let func_args_64bit = (*method_decl)
        .input_arguments
        .iter()
        .filter_map(|arg| match arg.v_type {
            ast::Type::Scalar(ast::ScalarType::U64)
            | ast::Type::Scalar(ast::ScalarType::B64)
            | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
            _ => None,
        })
        .collect::<HashSet<_>>();
    // (dst, src) register pairs of every `cvta.to.global` on 64-bit integers.
    let mut stateful_markers = Vec::new();
    // register -> kernel arguments it was loaded from with `ld.param`.
    let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
    for statement in func_body.iter() {
        match statement {
            Statement::Instruction(ast::Instruction::Cvta {
                data:
                    ast::CvtaDetails {
                        state_space: ast::StateSpace::Global,
                        direction: ast::CvtaDirection::GenericToExplicit,
                    },
                arguments,
            }) => {
                if let (TypedOperand::Reg(dst), Some(src)) =
                    (arguments.dst, arguments.src.underlying_register())
                {
                    if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
                        stateful_markers.push((dst, src));
                    }
                }
            }
            Statement::Instruction(ast::Instruction::Ld {
                data:
                    ast::LdDetails {
                        state_space: ast::StateSpace::Param,
                        typ: ast::Type::Scalar(ast::ScalarType::U64),
                        ..
                    },
                arguments,
            })
            | Statement::Instruction(ast::Instruction::Ld {
                data:
                    ast::LdDetails {
                        state_space: ast::StateSpace::Param,
                        typ: ast::Type::Scalar(ast::ScalarType::S64),
                        ..
                    },
                arguments,
            })
            | Statement::Instruction(ast::Instruction::Ld {
                data:
                    ast::LdDetails {
                        state_space: ast::StateSpace::Param,
                        typ: ast::Type::Scalar(ast::ScalarType::B64),
                        ..
                    },
                arguments,
            }) => {
                if let (TypedOperand::Reg(dst), Some(src)) =
                    (arguments.dst, arguments.src.underlying_register())
                {
                    if func_args_64bit.contains(&src) {
                        multi_hash_map_append(&mut stateful_init_reg, dst, src);
                    }
                }
            }
            _ => {}
        }
    }
    // Without a single `cvta.to.global` there is nothing to convert.
    if stateful_markers.is_empty() {
        drop(method_decl);
        return Ok((func_args, func_body));
    }
    let mut func_args_ptr = HashSet::new();
    let mut regs_ptr_current = HashSet::new();
    for (dst, src) in stateful_markers {
        if let Some(func_args) = stateful_init_reg.get(&src) {
            for a in func_args {
                func_args_ptr.insert(*a);
                regs_ptr_current.insert(src);
                regs_ptr_current.insert(dst);
            }
        }
    }
    // BTreeSet here to have a stable order of iteration,
    // unfortunately our tests rely on it
    let mut regs_ptr_seen = BTreeSet::new();
    // Worklist propagation: each round looks for add/sub results derived from
    // the registers discovered in the previous round.
    while !regs_ptr_current.is_empty() {
        let mut regs_ptr_new = HashSet::new();
        for statement in func_body.iter() {
            match statement {
                Statement::Instruction(ast::Instruction::Add {
                    data:
                        ast::ArithDetails::Integer(ast::ArithInteger {
                            type_: ast::ScalarType::U64,
                            saturate: false,
                        }),
                    arguments,
                })
                | Statement::Instruction(ast::Instruction::Add {
                    data:
                        ast::ArithDetails::Integer(ast::ArithInteger {
                            type_: ast::ScalarType::S64,
                            saturate: false,
                        }),
                    arguments,
                }) => {
                    // TODO: don't mark result of double pointer sub or double
                    // pointer add as ptr result
                    // NOTE(review): `src2` is only consulted when `src1` is
                    // not register-backed; a register `src1` that is not a
                    // pointer hides a pointer in `src2`. Left unchanged here
                    // because it alters which registers get converted.
                    if let (TypedOperand::Reg(dst), Some(src1)) =
                        (arguments.dst, arguments.src1.underlying_register())
                    {
                        if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
                            regs_ptr_new.insert(dst);
                        }
                    } else if let (TypedOperand::Reg(dst), Some(src2)) =
                        (arguments.dst, arguments.src2.underlying_register())
                    {
                        if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
                            regs_ptr_new.insert(dst);
                        }
                    }
                }
                Statement::Instruction(ast::Instruction::Sub {
                    data:
                        ast::ArithDetails::Integer(ast::ArithInteger {
                            type_: ast::ScalarType::U64,
                            saturate: false,
                        }),
                    arguments,
                })
                | Statement::Instruction(ast::Instruction::Sub {
                    data:
                        ast::ArithDetails::Integer(ast::ArithInteger {
                            type_: ast::ScalarType::S64,
                            saturate: false,
                        }),
                    arguments,
                }) => {
                    // TODO: don't mark result of double pointer sub or double
                    // pointer add as ptr result
                    if let (TypedOperand::Reg(dst), Some(src1)) =
                        (arguments.dst, arguments.src1.underlying_register())
                    {
                        if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
                            regs_ptr_new.insert(dst);
                        }
                    } else if let (TypedOperand::Reg(dst), Some(src2)) =
                        (arguments.dst, arguments.src2.underlying_register())
                    {
                        if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
                            regs_ptr_new.insert(dst);
                        }
                    }
                }
                _ => {}
            }
        }
        for id in regs_ptr_current {
            regs_ptr_seen.insert(id);
        }
        regs_ptr_current = regs_ptr_new;
    }
    drop(regs_ptr_current);
    // old id -> id of its `ptr<u8, global>` replacement.
    let mut remapped_ids = HashMap::new();
    let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
    // Declare the replacement registers up front.
    for reg in regs_ptr_seen {
        let new_id = id_defs.register_variable(
            ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
            ast::StateSpace::Reg,
        );
        result.push(Statement::Variable(ast::Variable {
            align: None,
            name: new_id,
            array_init: Vec::new(),
            v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
            state_space: ast::StateSpace::Reg,
        }));
        remapped_ids.insert(reg, new_id);
    }
    // Retype and rename the kernel arguments that were identified as pointers.
    for arg in (*method_decl).input_arguments.iter_mut() {
        if !func_args_ptr.contains(&arg.name) {
            continue;
        }
        let new_id = id_defs.register_variable(
            ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
            ast::StateSpace::Param,
        );
        let old_name = arg.name;
        arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
        arg.name = new_id;
        remapped_ids.insert(old_name, new_id);
    }
    for statement in func_body {
        match statement {
            l @ Statement::Label(_) => result.push(l),
            c @ Statement::Conditional(_) => result.push(c),
            c @ Statement::Constant(..) => result.push(c),
            Statement::Variable(var) => {
                // Declarations of remapped registers are dropped; their
                // replacements were already emitted above.
                if !remapped_ids.contains_key(&var.name) {
                    result.push(Statement::Variable(var));
                }
            }
            Statement::Instruction(ast::Instruction::Add {
                data:
                    ast::ArithDetails::Integer(ast::ArithInteger {
                        type_: ast::ScalarType::U64,
                        saturate: false,
                    }),
                arguments,
            })
            | Statement::Instruction(ast::Instruction::Add {
                data:
                    ast::ArithDetails::Integer(ast::ArithInteger {
                        type_: ast::ScalarType::S64,
                        saturate: false,
                    }),
                arguments,
            }) if is_add_ptr_direct(&remapped_ids, &arguments) => {
                // The pointer may sit in either source operand, so both have
                // to be looked up. The guard guarantees that at most one of
                // them is a remapped pointer; `src1` wins if it is.
                let src1_ptr = arguments
                    .src1
                    .underlying_register()
                    .and_then(|reg| remapped_ids.get(&reg).copied());
                let src2_ptr = arguments
                    .src2
                    .underlying_register()
                    .and_then(|reg| remapped_ids.get(&reg).copied());
                let (ptr, offset) = match (src1_ptr, src2_ptr) {
                    (Some(ptr), _) => (ptr, arguments.src2),
                    (None, Some(ptr)) => (ptr, arguments.src1),
                    (None, None) => return Err(error_unreachable()),
                };
                let dst = arguments.dst.unwrap_reg()?;
                result.push(Statement::PtrAccess(PtrAccess {
                    underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
                    state_space: ast::StateSpace::Global,
                    // Present because the guard checked `dst` is remapped.
                    dst: *remapped_ids.get(&dst).unwrap(),
                    ptr_src: ptr,
                    offset_src: offset,
                }))
            }
            Statement::Instruction(ast::Instruction::Sub {
                data:
                    ast::ArithDetails::Integer(ast::ArithInteger {
                        type_: ast::ScalarType::U64,
                        saturate: false,
                    }),
                arguments,
            })
            | Statement::Instruction(ast::Instruction::Sub {
                data:
                    ast::ArithDetails::Integer(ast::ArithInteger {
                        type_: ast::ScalarType::S64,
                        saturate: false,
                    }),
                arguments,
            }) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
                // Only `ptr - offset` is accepted by the guard, so the pointer
                // is always `src1`.
                let ptr = match arguments
                    .src1
                    .underlying_register()
                    .and_then(|reg| remapped_ids.get(&reg).copied())
                {
                    Some(ptr) => ptr,
                    None => return Err(error_unreachable()),
                };
                let offset = arguments.src2;
                // `ptr - offset` is emitted as `ptr + (-offset)`.
                let offset_neg = id_defs.register_intermediate(Some((
                    ast::Type::Scalar(ast::ScalarType::S64),
                    ast::StateSpace::Reg,
                )));
                result.push(Statement::Instruction(ast::Instruction::Neg {
                    data: ast::TypeFtz {
                        type_: ast::ScalarType::S64,
                        flush_to_zero: None,
                    },
                    arguments: ast::NegArgs {
                        src: offset,
                        dst: TypedOperand::Reg(offset_neg),
                    },
                }));
                let dst = arguments.dst.unwrap_reg()?;
                result.push(Statement::PtrAccess(PtrAccess {
                    underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
                    state_space: ast::StateSpace::Global,
                    // Present because the guard checked `dst` is remapped.
                    dst: *remapped_ids.get(&dst).unwrap(),
                    ptr_src: ptr,
                    offset_src: TypedOperand::Reg(offset_neg),
                }))
            }
            inst @ Statement::Instruction(_) => {
                // Conversions for source operands are pushed to `result`
                // before the instruction, conversions for destinations are
                // collected and appended after it.
                let mut post_statements = Vec::new();
                let new_statement = inst.visit_map(&mut FnVisitor::new(
                    |operand, type_space, is_dst, relaxed_conversion| {
                        convert_to_stateful_memory_access_postprocess(
                            id_defs,
                            &remapped_ids,
                            &mut result,
                            &mut post_statements,
                            operand,
                            type_space,
                            is_dst,
                            relaxed_conversion,
                        )
                    },
                ))?;
                result.push(new_statement);
                result.extend(post_statements);
            }
            repack @ Statement::RepackVector(_) => {
                let mut post_statements = Vec::new();
                let new_statement = repack.visit_map(&mut FnVisitor::new(
                    |operand, type_space, is_dst, relaxed_conversion| {
                        convert_to_stateful_memory_access_postprocess(
                            id_defs,
                            &remapped_ids,
                            &mut result,
                            &mut post_statements,
                            operand,
                            type_space,
                            is_dst,
                            relaxed_conversion,
                        )
                    },
                ))?;
                result.push(new_statement);
                result.extend(post_statements);
            }
            _ => return Err(error_unreachable()),
        }
    }
    drop(method_decl);
    Ok((func_args, result))
}
/// Returns `true` when `id` resolves in `id_defs` to a scalar of type `u64`,
/// `s64` or `b64`. A failed lookup or any other type yields `false`.
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
    matches!(
        id_defs.get_typed(id),
        Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
            | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
            | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _))
    )
}
/// Appends `value` to the collection stored under `key` in `m`, creating an
/// empty (default) collection first when the key is not present yet.
fn multi_hash_map_append<K, V, Collection>(m: &mut HashMap<K, Collection>, key: K, value: V)
where
    K: Eq + std::hash::Hash,
    Collection: std::iter::Extend<V> + std::default::Default,
{
    // A single entry lookup covers both the occupied and the vacant case.
    let bucket = m.entry(key).or_default();
    bucket.extend(iter::once(value));
}
/// Decides whether an `add` can be emitted directly as pointer arithmetic.
///
/// That is the case when the destination is a plain register present in
/// `remapped_ids`, `src2` is backed by a register, and exactly one of the two
/// source registers is present in `remapped_ids` (adding two pointers is
/// deliberately excluded).
fn is_add_ptr_direct(
    remapped_ids: &HashMap<SpirvWord, SpirvWord>,
    arg: &ast::AddArgs<TypedOperand>,
) -> bool {
    let dst_is_ptr = match arg.dst {
        TypedOperand::Reg(dst) => remapped_ids.contains_key(&dst),
        TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => false,
    };
    if !dst_is_ptr {
        return false;
    }
    let src1_is_ptr = arg
        .src1
        .underlying_register()
        .map_or(false, |src1_reg| remapped_ids.contains_key(&src1_reg));
    match arg.src2.underlying_register() {
        // Exactly one pointer operand: pointer + pointer does not qualify,
        // and neither does an add that involves no pointer at all.
        Some(src2_reg) => src1_is_ptr != remapped_ids.contains_key(&src2_reg),
        // A `src2` without an underlying register never qualifies.
        None => false,
    }
}
/// Decides whether a `sub` can be emitted directly as pointer arithmetic.
///
/// That is the case when the destination is a plain register present in
/// `remapped_ids`, `src1` is backed by a register present in `remapped_ids`,
/// and `src2` is not (subtracting two pointers is deliberately excluded).
fn is_sub_ptr_direct(
    remapped_ids: &HashMap<SpirvWord, SpirvWord>,
    arg: &ast::SubArgs<TypedOperand>,
) -> bool {
    let dst = match arg.dst {
        TypedOperand::Reg(dst) => dst,
        TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
            return false;
        }
    };
    if !remapped_ids.contains_key(&dst) {
        return false;
    }
    let src1_is_ptr = arg
        .src1
        .underlying_register()
        .map_or(false, |src1_reg| remapped_ids.contains_key(&src1_reg));
    if !src1_is_ptr {
        return false;
    }
    match arg.src2.underlying_register() {
        // don't trigger optimization when subtracting two pointers
        Some(src2_reg) => !remapped_ids.contains_key(&src2_reg),
        None => true,
    }
}
/// Operand visitor callback that swaps a remapped id for a fresh intermediate
/// and emits the implicit conversion linking the two.
///
/// Operands whose register is not in `remapped_ids` are returned unchanged.
/// When `type_space` is provided and `relaxed_conversion` is set, the new id
/// is substituted directly without any conversion. Otherwise an intermediate
/// with the old type and state space is registered and returned, and a
/// `Statement::Conversion` is pushed:
/// * for destinations (`is_dst`) onto `post_statements`, converting the
///   intermediate into the new id,
/// * for sources onto `result`, converting the new id into the intermediate.
///
/// Errors from `id_defs.get_typed` are propagated.
fn convert_to_stateful_memory_access_postprocess(
    id_defs: &mut NumericIdResolver,
    remapped_ids: &HashMap<SpirvWord, SpirvWord>,
    result: &mut Vec<TypedStatement>,
    post_statements: &mut Vec<TypedStatement>,
    operand: TypedOperand,
    type_space: Option<(&ast::Type, ast::StateSpace)>,
    is_dst: bool,
    relaxed_conversion: bool,
) -> Result<TypedOperand, TranslateError> {
    operand.map(|old_id, _| {
        let new_id = match remapped_ids.get(&old_id) {
            Some(id) => *id,
            None => return Ok(old_id),
        };
        let (new_type, new_space, _) = id_defs.get_typed(new_id)?;
        // TODO: re-add if required
        if type_space.is_some() && relaxed_conversion {
            return Ok(new_id);
        }
        let (old_type, old_space, _) = id_defs.get_typed(old_id)?;
        let temporary = id_defs.register_intermediate(Some((old_type.clone(), old_space)));
        let kind = if state_is_compatible(new_space, ast::StateSpace::Reg) {
            ConversionKind::Default
        } else {
            ConversionKind::PtrToPtr
        };
        if is_dst {
            // The instruction writes the temporary; convert it into the new
            // id once the instruction has been emitted.
            post_statements.push(Statement::Conversion(ImplicitConversion {
                src: temporary,
                dst: new_id,
                from_type: old_type,
                from_space: old_space,
                to_type: new_type,
                to_space: new_space,
                kind,
            }));
        } else {
            // The instruction reads the temporary; fill it from the new id
            // before the instruction is emitted.
            result.push(Statement::Conversion(ImplicitConversion {
                src: new_id,
                dst: temporary,
                from_type: new_type,
                from_space: new_space,
                to_type: old_type,
                to_space: old_space,
                kind,
            }));
        }
        Ok(temporary)
    })
}
/// Returns `true` when the two state spaces are equal, or when one is `Reg`
/// and the other is `Sreg`.
fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
    if this == other {
        return true;
    }
    let reg_then_sreg = this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg;
    let sreg_then_reg = this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg;
    reg_then_sreg || sreg_then_reg
}

View file

@ -1,7 +1,7 @@
use super::*;
use std::collections::HashMap;
fn run<'a, 'b, 'input>(
pub(super) fn run<'a, 'b, 'input>(
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,

View file

@ -5,9 +5,11 @@ use std::{
cell::RefCell,
collections::{hash_map, HashMap},
ffi::CString,
marker::PhantomData,
rc::Rc,
};
mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod fix_special_registers;
mod normalize_identifiers;
@ -169,12 +171,12 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
todo!()
/*
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@ -1035,7 +1037,7 @@ struct FunctionPointerDetails {
src: SpirvWord,
}
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct SpirvWord(spirv::Word);
impl From<spirv::Word> for SpirvWord {
@ -1117,6 +1119,20 @@ impl TypedOperand {
TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
})
}
fn underlying_register(&self) -> Option<SpirvWord> {
match self {
Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r),
Self::Imm(_) => None,
}
}
/// Returns the register of a plain `Reg` operand; every other operand kind
/// yields `error_unreachable()`.
fn unwrap_reg(&self) -> Result<SpirvWord, TranslateError> {
    if let TypedOperand::Reg(reg) = self {
        Ok(*reg)
    } else {
        Err(error_unreachable())
    }
}
}
impl ast::Operand for TypedOperand {
@ -1126,3 +1142,67 @@ impl ast::Operand for TypedOperand {
TypedOperand::Reg(ident)
}
}
impl<Fn> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for FnVisitor<TypedOperand, TypedOperand, TranslateError, Fn>
where
Fn: FnMut(
TypedOperand,
Option<(&ast::Type, ast::StateSpace)>,
bool,
bool,
) -> Result<TypedOperand, TranslateError>,
{
fn visit(
&mut self,
args: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
(self.fn_)(args, type_space, is_dst, relaxed_type_check)
}
fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match (self.fn_)(
TypedOperand::Reg(args),
type_space,
is_dst,
relaxed_type_check,
)? {
TypedOperand::Reg(reg) => Ok(reg),
_ => Err(TranslateError::Unreachable),
}
}
}
/// Adapter that stores a closure so it can be used where a visitor is
/// expected.
///
/// `T` is the closure's operand input, `U` its successful output and `E` its
/// error type; `_marker` ties those otherwise unused parameters to the struct.
struct FnVisitor<T, U, E, F>
where
    F: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, E>,
{
    fn_: F,
    _marker: PhantomData<fn(T) -> Result<U, E>>,
}
impl<
T,
U,
Err,
Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
> FnVisitor<T, U, Err, Fn>
{
fn new(fn_: Fn) -> Self {
Self {
fn_,
_marker: PhantomData,
}
}
}

View file

@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
for statement in sorted_statements {
match statement {
Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,

View file

@ -760,6 +760,7 @@ pub enum Type {
Vector(ScalarType, u8),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
Pointer(ScalarType, StateSpace)
}
impl Type {