diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 4c1c0e7..511d763 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; - let typed_statements = - convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, @@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>( // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion -// TODO: propagate through calls? +// TODO: propagate out of calls and into calls fn convert_to_stateful_memory_access<'a, 'input>( - func_args: &Rc>>, + func_args: Rc>>, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, -) -> Result, TranslateError> { - let mut func_args = func_args.borrow_mut(); - let func_args_64bit = (*func_args) +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !method_decl.name.is_kernel() { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) .input_arguments .iter() .filter_map(|arg| match arg.v_type { @@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>( })); remapped_ids.insert(reg, new_id); } + for arg in (*method_decl).input_arguments.iter_mut() { + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); + let old_name = arg.name; + if func_args_ptr.contains(&arg.name) { + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + } + remapped_ids.insert(old_name, new_id); + } for statement in func_body { match statement { l @ Statement::Label(_) => result.push(l), @@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>( convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, - &func_args_ptr, &mut result, &mut post_statements, arg_desc, @@ -4597,81 +4619,69 @@ fn convert_to_stateful_memory_access<'a, 'input>( _ => return Err(error_unreachable()), } } - for arg in (*func_args).input_arguments.iter_mut() { - if func_args_ptr.contains(&arg.name) { - arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); - arg.state_space = ast::StateSpace::Reg; - } - } - Ok(result) + drop(method_decl); + Ok((func_args, result)) } fn convert_to_stateful_memory_access_postprocess( id_defs: &mut NumericIdResolver, remapped_ids: &HashMap, - func_args_ptr: &HashSet, result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, expected_type: Option<(&ast::Type, ast::StateSpace)>, ) -> Result { - Ok( - match remapped_ids - .get(&arg_desc.op) - .or_else(|| func_args_ptr.get(&arg_desc.op)) - { - Some(new_id) => { - let (new_operand_type, new_operand_space, is_variable) = - id_defs.get_typed(*new_id)?; - if let Some((expected_type, expected_space)) = expected_type { - let implicit_conversion = arg_desc - .non_default_implicit_conversion - .unwrap_or(default_implicit_conversion); - if implicit_conversion( - (new_operand_space, &new_operand_type), - (expected_space, expected_type), - ) - .is_ok() - { - return Ok(*new_id); - } - } - let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; - let new_operand_type_clone = new_operand_type.clone(); - let converting_id = id_defs - .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { - ConversionKind::Default - } else { - ConversionKind::PtrToPtr - }; - if arg_desc.is_dst { - post_statements.push(Statement::Conversion(ImplicitConversion { - src: converting_id, - dst: *new_id, - from_type: old_operand_type, - from_space: old_operand_space, - to_type: new_operand_type, - to_space: new_operand_space, - kind, - })); - converting_id - } else { - result.push(Statement::Conversion(ImplicitConversion { - src: *new_id, - dst: converting_id, - from_type: new_operand_type, - from_space: new_operand_space, - to_type: old_operand_type, - to_space: old_operand_space, - kind, - })); - converting_id + Ok(match remapped_ids.get(&arg_desc.op) { + Some(new_id) => { + let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?; + if let Some((expected_type, expected_space)) = expected_type { + let implicit_conversion = arg_desc + .non_default_implicit_conversion + .unwrap_or(default_implicit_conversion); + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); } } - None => arg_desc.op, - }, - ) + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?; + let new_operand_type_clone = new_operand_type.clone(); + let converting_id = + id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr + }; + if arg_desc.is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } + } + None => arg_desc.op, + }) } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs index f168930..ffd1498 100644 --- a/zluda_dump/src/lib.rs +++ b/zluda_dump/src/lib.rs @@ -219,15 +219,18 @@ unsafe fn to_str(image: *const T) -> Option<&'static str> { fn directive_to_kernel(dir: &ast::Directive) -> Option<(String, Vec)> { match dir { - ast::Directive::Method(ast::Function { - func_directive: - ast::MethodDeclaration { - name: ast::MethodName::Kernel(name), - input_arguments, - .. - }, - .. - }) => { + ast::Directive::Method( + _, + ast::Function { + func_directive: + ast::MethodDeclaration { + name: ast::MethodName::Kernel(name), + input_arguments, + .. + }, + .. + }, + ) => { let arg_sizes = input_arguments .iter() .map(|arg| ast::Type::from(arg.v_type.clone()).size_of())