diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 61b255d..4c1c0e7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4136,6 +4136,11 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + } _ => unreachable!(), } Ok(()) @@ -4610,72 +4615,63 @@ fn convert_to_stateful_memory_access_postprocess( arg_desc: ArgumentDescriptor, expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { - Ok(match remapped_ids.get(&arg_desc.op) { - Some(new_id) => { - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some((ast::Type::Pointer(_, ast::StateSpace::Global), ast::StateSpace::Reg)) => { - return Ok(*new_id) - } - _ => id_defs.get_typed(arg_desc.op)?.0, - }; - let old_type_clone = old_type.clone(); - let converting_id = - id_defs.register_intermediate(Some((old_type_clone, ast::StateSpace::Reg))); - if arg_desc.is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_type, - from_space: ast::StateSpace::Reg, - to_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::BitToPtr, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - from_space: ast::StateSpace::Reg, - to_type: old_type, - to_space: ast::StateSpace::Reg, - kind: ConversionKind::AddressOf, - })); - converting_id - } - } - None => match func_args_ptr.get(&arg_desc.op) { + Ok( + match remapped_ids + .get(&arg_desc.op) + .or_else(|| func_args_ptr.get(&arg_desc.op)) + { Some(new_id) => { - if arg_desc.is_dst { - return Err(error_unreachable()); + let (new_operand_type, new_operand_space, is_variable) = + id_defs.get_typed(*new_id)?; + if let Some((expected_type, expected_space)) = expected_type { + let implicit_conversion = arg_desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); + } } - // We skip conversion here to trigger PtrAcces in a later pass - let old_type = match expected_type { - Some(( - ast::Type::Pointer(_, ast::StateSpace::Global), - ast::StateSpace::Reg, - )) => return Ok(*new_id), - _ => id_defs.get_typed(arg_desc.op)?.0, + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; + let new_operand_type_clone = new_operand_type.clone(); + let converting_id = id_defs + .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr }; - let old_type_clone = old_type.clone(); - let converting_id = - id_defs.register_intermediate(Some((old_type, ast::StateSpace::Reg))); - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), - from_space: ast::StateSpace::Reg, - to_type: old_type_clone, - to_space: ast::StateSpace::Reg, - kind: ConversionKind::PtrToPtr, - })); - converting_id + if arg_desc.is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } } None => arg_desc.op, }, - }) + ) } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool {