diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6b9dcfb..61b255d 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - //let typed_statements = - // convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; + let typed_statements = + convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -4307,14 +4307,14 @@ fn expand_map_variables<'a, 'b>( // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion // TODO: propagate through calls? -/* -fn convert_to_stateful_memory_access<'a>( - func_args: &mut SpirvMethodDecl, +fn convert_to_stateful_memory_access<'a, 'input>( + func_args: &Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, ) -> Result, TranslateError> { - let func_args_64bit = func_args - .input + let mut func_args = func_args.borrow_mut(); + let func_args_64bit = (*func_args) + .input_arguments .iter() .filter_map(|arg| match arg.v_type { ast::Type::Scalar(ast::ScalarType::U64) @@ -4445,15 +4445,15 @@ fn convert_to_stateful_memory_access<'a>( let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { let new_id = id_defs.register_variable( - ast::Type::Pointer(ast::ScalarType::U8), - ast::StateSpace::Global, + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, ); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), - v_type: ast::Type::Pointer(ast::ScalarType::U8), - state_space: ast::StateSpace::Global, + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, })); remapped_ids.insert(reg, new_id); } @@ -4515,8 +4515,10 @@ fn convert_to_stateful_memory_access<'a>( } _ => return Err(error_unreachable()), }; - let offset_neg = - id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64))); + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, @@ -4538,9 +4540,8 @@ fn convert_to_stateful_memory_access<'a>( } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4550,16 +4551,14 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4569,16 +4568,14 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } Statement::RepackVector(pack) => { let mut post_statements = Vec::new(); - let new_statement = pack.visit( - &mut |arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>| { + let new_statement = + pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4588,18 +4585,17 @@ fn convert_to_stateful_memory_access<'a>( arg_desc, expected_type, ) - }, - )?; + })?; result.push(new_statement); result.extend(post_statements); } _ => return Err(error_unreachable()), } } - for arg in func_args.input.iter_mut() { + for arg in (*func_args).input_arguments.iter_mut() { if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8); - arg.state_space = ast::StateSpace::Global; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.state_space = ast::StateSpace::Reg; } } Ok(result) @@ -4612,43 +4608,40 @@ fn convert_to_stateful_memory_access_postprocess( result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, - expected_type: Option<&ast::Type>, + expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), + Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => { + return Ok(*new_id) + } _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.register_intermediate(Some(old_type_clone)); + let converting_id = + id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg))); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, dst: *new_id, from_type: old_type, - to_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::StateSpace::Global, - ), - kind: ConversionKind::BitToPtr(ast::StateSpace::Global), - src_ - dst_sema: arg_desc.sema, + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::BitToPtr, })); converting_id } else { result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from_type: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::StateSpace::Global, - ), + from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + from_space: ast::StateSpace::Reg, to_type: old_type, - kind: ConversionKind::PtrToBit(ast::ScalarType::U64), - src_sema: arg_desc.sema, - dst_ + to_space: ast::StateSpace::Reg, + kind: ConversionKind::AddressOf, })); converting_id } @@ -4660,22 +4653,23 @@ fn convert_to_stateful_memory_access_postprocess( } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { - Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id), + Some(( + ast::Type::Pointer(_, ast::StateSpace::Global), + ast::StateSpace::Reg, + )) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); - let converting_id = id_defs.register_intermediate(Some(old_type)); + let converting_id = + id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg))); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, - from_type: ast::Type::Pointer( - ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - ast::StateSpace::Param, - ), + from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + from_space: ast::StateSpace::Reg, to_type: old_type_clone, - kind: ConversionKind::PtrToPtr { spirv_ptr: false }, - src_sema: arg_desc.sema, - dst_ + to_space: ast::StateSpace::Reg, + kind: ConversionKind::PtrToPtr, })); converting_id } @@ -4710,7 +4704,6 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { _ => false, } } -*/ #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister {