diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs new file mode 100644 index 0000000..3060a70 --- /dev/null +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -0,0 +1,535 @@ +use super::*; +use ptx_parser as ast; +use std::{ + collections::{BTreeSet, HashSet}, + iter, + rc::Rc, +}; + +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ +// TODO: detect more patterns (mov, call via reg, call via param) +// TODO: don't convert to ptr if the register is not ultimately used for ld/st +// TODO: once insert_mem_ssa_statements is moved to later, move this pass after +// argument expansion +// TODO: propagate out of calls and into calls +pub(super) fn run<'a, 'input>( + func_args: Rc>>, + func_body: Vec, + id_defs: &mut NumericIdResolver<'a>, +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !matches!(method_decl.name, ast::MethodName::Kernel(..)) { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) + .input_arguments + .iter() + .filter_map(|arg| match arg.v_type { + ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), + _ => None, + }) + .collect::>(); + let mut stateful_markers = Vec::new(); + let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Cvta { + data: + ast::CvtaDetails { + state_space: ast::StateSpace::Global, + direction: ast::CvtaDirection::GenericToExplicit, + }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, src)); + } + } + } + Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), + .. + }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if func_args_64bit.contains(&src) { + multi_hash_map_append(&mut stateful_init_reg, dst, src); + } + } + } + _ => {} + } + } + if stateful_markers.len() == 0 { + drop(method_decl); + return Ok((func_args, func_body)); + } + let mut func_args_ptr = HashSet::new(); + let mut regs_ptr_current = HashSet::new(); + for (dst, src) in stateful_markers { + if let Some(func_args) = stateful_init_reg.get(&src) { + for a in func_args { + func_args_ptr.insert(*a); + regs_ptr_current.insert(src); + regs_ptr_current.insert(dst); + } + } + } + // BTreeSet here to have a stable order of iteration, + // unfortunately our tests rely on it + let mut regs_ptr_seen = BTreeSet::new(); + while regs_ptr_current.len() > 0 { + let mut regs_ptr_new = HashSet::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + _ => {} + } + } + for id in regs_ptr_current { + regs_ptr_seen.insert(id); + } + regs_ptr_current = regs_ptr_new; + } + drop(regs_ptr_current); + let mut remapped_ids = HashMap::new(); + let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); + for reg in regs_ptr_seen { + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); + result.push(Statement::Variable(ast::Variable { + align: None, + name: new_id, + array_init: Vec::new(), + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, + })); + remapped_ids.insert(reg, new_id); + } + for arg in (*method_decl).input_arguments.iter_mut() { + if !func_args_ptr.contains(&arg.name) { + continue; + } + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, + ); + let old_name = arg.name; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + remapped_ids.insert(old_name, new_id); + } + for statement in func_body { + match statement { + l @ Statement::Label(_) => result.push(l), + c @ Statement::Conditional(_) => result.push(c), + c @ Statement::Constant(..) => result.push(c), + Statement::Variable(var) => { + if !remapped_ids.contains_key(&var.name) { + result.push(Statement::Variable(var)); + } + } + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_add_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(src1) if remapped_ids.contains_key(&src1) => { + (remapped_ids.get(&src1).unwrap(), arguments.src2) + } + Some(src2) if remapped_ids.contains_key(&src2) => { + (remapped_ids.get(&src2).unwrap(), arguments.src1) + } + _ => return Err(error_unreachable()), + }; + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: offset, + })) + } + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_sub_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2), + _ => return Err(error_unreachable()), + }; + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + result.push(Statement::Instruction(ast::Instruction::Neg { + data: ast::TypeFtz { + type_: ast::ScalarType::S64, + flush_to_zero: None, + }, + arguments: ast::NegArgs { + src: offset, + dst: TypedOperand::Reg(offset_neg), + }, + })); + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: TypedOperand::Reg(offset_neg), + })) + } + inst @ Statement::Instruction(_) => { + let mut post_statements = Vec::new(); + let new_statement = inst.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + repack @ Statement::RepackVector(_) => { + let mut post_statements = Vec::new(); + let new_statement = repack.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + _ => return Err(error_unreachable()), + } + } + drop(method_decl); + Ok((func_args, result)) +} + +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { + match id_defs.get_typed(id) { + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, + _ => false, + } +} + +fn multi_hash_map_append< + K: Eq + std::hash::Hash, + V, + Collection: std::iter::Extend + std::default::Default, +>( + m: &mut HashMap, + key: K, + value: V, +) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().extend(iter::once(value)); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Default::default()).extend(iter::once(value)); + } + } +} + +fn is_add_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::AddArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + if let Some(ref src1_reg) = arg.src1.underlying_register() { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when adding two pointers + if let Some(ref src2_reg) = arg.src2.underlying_register() { + return !remapped_ids.contains_key(src2_reg); + } + } + } + if let Some(ref src2_reg) = arg.src2.underlying_register() { + remapped_ids.contains_key(src2_reg) + } else { + false + } + } + } +} + +fn is_sub_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::SubArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.underlying_register() { + Some(ref src1_reg) => { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when subtracting two pointers + arg.src2 + .underlying_register() + .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg)) + } else { + false + } + } + None => false, + } + } + } +} + +fn convert_to_stateful_memory_access_postprocess( + id_defs: &mut NumericIdResolver, + remapped_ids: &HashMap, + result: &mut Vec, + post_statements: &mut Vec, + operand: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_conversion: bool, +) -> Result { + operand.map(|operand, _| { + Ok(match remapped_ids.get(&operand) { + Some(new_id) => { + let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; + // TODO: readd if required + if let Some(..) = type_space { + if relaxed_conversion { + return Ok(*new_id); + } + } + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; + let converting_id = id_defs + .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if state_is_compatible(new_operand_space, ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr + }; + if is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } + } + None => operand, + }) + }) +} + +fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { + this == other + || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg +} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs index d94786e..871537d 100644 --- a/ptx/src/pass/fix_special_registers.rs +++ b/ptx/src/pass/fix_special_registers.rs @@ -1,7 +1,7 @@ use super::*; use std::collections::HashMap; -fn run<'a, 'b, 'input>( +pub(super) fn run<'a, 'b, 'input>( ptx_impl_imports: &'a mut HashMap>, typed_statements: Vec, numeric_id_defs: &'a mut NumericIdResolver<'b>, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index b3bfa72..439233a 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -5,9 +5,11 @@ use std::{ cell::RefCell, collections::{hash_map, HashMap}, ffi::CString, + marker::PhantomData, rc::Rc, }; +mod convert_to_stateful_memory_access; mod convert_to_typed; mod fix_special_registers; mod normalize_identifiers; @@ -169,12 +171,12 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?; todo!() /* - let typed_statements = - fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; - let (func_decl, typed_statements) = - convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -1035,7 +1037,7 @@ struct FunctionPointerDetails { src: SpirvWord, } -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] struct SpirvWord(spirv::Word); impl From for SpirvWord { @@ -1117,6 +1119,20 @@ impl TypedOperand { TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx), }) } + + fn underlying_register(&self) -> Option { + match self { + Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r), + Self::Imm(_) => None, + } + } + + fn unwrap_reg(&self) -> Result { + match self { + TypedOperand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), + } + } } impl ast::Operand for TypedOperand { @@ -1126,3 +1142,67 @@ impl ast::Operand for TypedOperand { TypedOperand::Reg(ident) } } + +impl ast::VisitorMap + for FnVisitor +where + Fn: FnMut( + TypedOperand, + Option<(&ast::Type, ast::StateSpace)>, + bool, + bool, + ) -> Result, +{ + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self.fn_)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match (self.fn_)( + TypedOperand::Reg(args), + type_space, + is_dst, + relaxed_type_check, + )? { + TypedOperand::Reg(reg) => Ok(reg), + _ => Err(TranslateError::Unreachable), + } + } +} + +struct FnVisitor< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, +> { + fn_: Fn, + _marker: PhantomData Result>, +} + +impl< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, + > FnVisitor +{ + fn new(fn_: Fn) -> Self { + Self { + fn_, + _marker: PhantomData, + } + } +} diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db1063b..9b422fd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>( for statement in sorted_statements { match statement { Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Shared, .. }, ) | Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Global, .. }, @@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, @@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, @@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Float { op: ast::AtomFloatOp::Add, diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 5175b2d..59815f2 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -760,6 +760,7 @@ pub enum Type { Vector(ScalarType, u8), // .param.b32 foo[4]; Array(ScalarType, Vec), + Pointer(ScalarType, StateSpace) } impl Type {