diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f18b15c..027e891 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -137,6 +137,7 @@ test_ptx!(stateful_ld_st_simple, [121u64], [121u64]); test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]); test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]); test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]); +test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); struct DisplayError { err: T, @@ -261,6 +262,7 @@ fn test_spvtxt_assert<'a>( let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; assert!(errors.len() == 0); let spirv_module = translate::to_spirv_module(ast)?; + eprintln!("{}", rspirv::binary::Disassemble::disassemble(&spirv_module.spirv)); let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut()); diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.ptx b/ptx/src/test/spirv_run/shared_ptr_take_address.ptx new file mode 100644 index 0000000..e892993 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.ptx @@ -0,0 +1,27 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .shared .align 4 .b8 shared_mem[]; + +.visible .entry shared_ptr_take_address( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 shared_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + mov.u64 shared_addr, shared_mem; + + ld.global.u64 temp1, [in_addr]; + st.shared.u64 [shared_addr], temp1; + ld.shared.u64 temp2, [shared_addr]; + st.global.u64 [out_addr], temp2; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt new file mode 100644 index 0000000..d77c2c8 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -0,0 +1,68 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %33 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 + OpDecorate %1 Alignment 4 + %void = OpTypeVoid + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar +%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar + %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup + %ulong = OpTypeInt 64 0 + %39 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar +%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %2 = OpFunction %void None %39 + %10 = OpFunctionParameter %ulong + %11 = OpFunctionParameter %ulong + %31 = OpFunctionParameter %_ptr_Workgroup_uchar + %40 = OpLabel + %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %31 + OpBranch %29 + %29 = OpLabel + OpStore %3 %10 + OpStore %4 %11 + %12 = OpLoad %ulong %3 + OpStore %5 %12 + %13 = OpLoad %ulong %4 + OpStore %6 %13 + %15 = OpLoad %_ptr_Workgroup_uchar %32 + %24 = OpConvertPtrToU %ulong %15 + %14 = OpCopyObject %ulong %24 + OpStore %7 %14 + %17 = OpLoad %ulong %5 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 + %16 = OpLoad %ulong %25 + OpStore %8 %16 + %18 = OpLoad %ulong %7 + %19 = OpLoad %ulong %8 + %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18 + OpStore %26 %19 + %21 = OpLoad %ulong %7 + %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21 + %20 = OpLoad %ulong %27 + OpStore %9 %20 + %22 = OpLoad %ulong %6 + %23 = OpLoad %ulong %9 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22 + OpStore %28 %23 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 328bf30..20c3edb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -843,24 +843,25 @@ fn replace_uses_of_shared_memory<'a>( statement => { let new_statement = statement.map_id(&mut |id, _| { if let Some(typ) = extern_shared_decls.get(&id) { - let replacement_id = new_id(); - if *typ != ast::SizedScalarType::B8 { - result.push(Statement::Conversion(ImplicitConversion { - src: shared_var_id, - dst: replacement_id, - from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::B8), - ast::LdStateSpace::Shared, - ), - to: ast::Type::Pointer( - ast::PointerType::Scalar((*typ).into()), - ast::LdStateSpace::Shared, - ), - kind: ConversionKind::PtrToPtr { spirv_ptr: true }, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, - })); + if *typ == ast::SizedScalarType::B8 { + return shared_var_id; } + let replacement_id = new_id(); + result.push(Statement::Conversion(ImplicitConversion { + src: shared_var_id, + dst: replacement_id, + from: ast::Type::Pointer( + ast::PointerType::Scalar(ast::ScalarType::B8), + ast::LdStateSpace::Shared, + ), + to: ast::Type::Pointer( + ast::PointerType::Scalar((*typ).into()), + ast::LdStateSpace::Shared, + ), + kind: ConversionKind::PtrToPtr { spirv_ptr: true }, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); replacement_id } else { id