From df035592ede497581999c9918e757c3b98162f1c Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 11 Sep 2025 01:12:47 +0000 Subject: [PATCH] Make .param addressable --- ptx/src/pass/insert_explicit_load_store.rs | 40 ++++++++++++++++++++ ptx/src/pass/insert_implicit_conversions2.rs | 9 +++-- ptx/src/pass/llvm/mod.rs | 5 ++- ptx/src/test/ll/param_is_addressable.ll | 34 +++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 ptx/src/test/ll/param_is_addressable.ll diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 32597c5..2805dfa 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -114,6 +114,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction)); result.extend(visitor.post.drain(..).map(Statement::Instruction)); } + Statement::Instruction(ast::Instruction::Mov { data, arguments }) => { + let instruction = visitor.visit_mov(data, arguments); + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } Statement::PtrAccess(ptr_access) => { let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?); let statement = statement.visit_map(visitor)?; @@ -293,6 +300,39 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }) } + fn visit_mov( + &mut self, + data: ptx_parser::MovDetails, + mut arguments: ptx_parser::MovArgs, + ) -> ast::Instruction { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + name, + new_space, + old_space, + } => { + let generic_var = self + .resolver + .register_unnamed(Some((data.typ.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Cvta { + data: ast::CvtaDetails { + state_space: *new_space, + direction: ast::CvtaDirection::ExplicitToGeneric, + }, + arguments: ast::CvtaArgs { + dst: generic_var, + src: *name, + }, + }); + arguments.src = generic_var; + } + } + } + ast::Instruction::Mov { data, arguments } + } + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { let old_space = match var.state_space { space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space, diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs index bd1675c..b1e473b 100644 --- a/ptx/src/pass/insert_implicit_conversions2.rs +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -152,11 +152,11 @@ fn is_addressable(this: ast::StateSpace) -> bool { | ast::StateSpace::Generic | ast::StateSpace::Global | ast::StateSpace::Local - | ast::StateSpace::Shared => true, + | ast::StateSpace::Shared + | ast::StateSpace::ParamEntry => true, ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::SharedCluster | ast::StateSpace::SharedCta - | ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => todo!(), } } @@ -180,7 +180,8 @@ fn default_implicit_conversion_space( | ast::StateSpace::Generic | ast::StateSpace::Const | ast::StateSpace::Local - | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + | ast::StateSpace::Shared + | ast::StateSpace::Param => Ok(Some(ConversionKind::BitToPtr)), _ => Err(error_mismatched_type()), }, ast::Type::Scalar(ast::ScalarType::B32) @@ -220,7 +221,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool { ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Local - | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCta | ast::StateSpace::SharedCluster | ast::StateSpace::Shared => true, ast::StateSpace::Reg diff --git a/ptx/src/pass/llvm/mod.rs b/ptx/src/pass/llvm/mod.rs index 5e5705c..cd1814d 100644 --- a/ptx/src/pass/llvm/mod.rs +++ b/ptx/src/pass/llvm/mod.rs @@ -181,7 +181,10 @@ fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Param => Err(error_todo()), + // This is dodgy, we try our best to convert all .param into either + // .param::entry or .local, but we can't always succeed. + // In those cases we convert .param into generic address space + ast::StateSpace::Param => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::ParamFunc => Err(error_todo()), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), diff --git a/ptx/src/test/ll/param_is_addressable.ll b/ptx/src/test/ll/param_is_addressable.ll new file mode 100644 index 0000000..6f75d5e --- /dev/null +++ b/ptx/src/test/ll/param_is_addressable.ll @@ -0,0 +1,34 @@ +define amdgpu_kernel void @param_is_addressable(ptr addrspace(4) byref(i64) %"33", ptr addrspace(4) byref(i64) %"34") #0 { + %"35" = alloca i64, align 8, addrspace(5) + %"36" = alloca i64, align 8, addrspace(5) + %"37" = alloca i64, align 8, addrspace(5) + br label %1 + +1: ; preds = %0 + br label %"32" + +"32": ; preds = %1 + %"38" = load i64, ptr addrspace(4) %"33", align 8 + store i64 %"38", ptr addrspace(5) %"35", align 8 + %"39" = load i64, ptr addrspace(4) %"34", align 8 + store i64 %"39", ptr addrspace(5) %"36", align 8 + %"49" = ptrtoint ptr addrspace(4) %"33" to i64 + %2 = inttoptr i64 %"49" to ptr addrspace(4) + %"40" = addrspacecast ptr addrspace(4) %2 to ptr + store ptr %"40", ptr addrspace(5) %"37", align 8 + %"43" = load i64, ptr addrspace(5) %"37", align 8 + %"50" = inttoptr i64 %"43" to ptr + %"42" = load i64, ptr %"50", align 8 + store i64 %"42", ptr addrspace(5) %"37", align 8 + %"45" = load i64, ptr addrspace(5) %"37", align 8 + %"46" = load i64, ptr addrspace(5) %"35", align 8 + %"51" = sub i64 %"45", %"46" + store i64 %"51", ptr addrspace(5) %"37", align 8 + %"47" = load i64, ptr addrspace(5) %"36", align 8 + %"48" = load i64, ptr addrspace(5) %"37", align 8 + %"53" = inttoptr i64 %"47" to ptr + store i64 %"48", ptr %"53", align 8 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee" "denormal-fp-math-f32"="ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } \ No newline at end of file