From 90770bbe6654b5fc69246416cfb2fd2436475351 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 25 Sep 2025 18:46:34 +0000 Subject: [PATCH] Add test for 128bit atomics --- comgr/src/lib.rs | 6 ++++++ ptx/src/pass/llvm/emit.rs | 21 +++++++++++++++++---- ptx/src/test/spirv_run/atomics_128.ptx | 24 ++++++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + 4 files changed, 48 insertions(+), 4 deletions(-) create mode 100644 ptx/src/test/spirv_run/atomics_128.ptx diff --git a/comgr/src/lib.rs b/comgr/src/lib.rs index 8546203..9e36ab6 100644 --- a/comgr/src/lib.rs +++ b/comgr/src/lib.rs @@ -219,6 +219,12 @@ pub fn compile_bitcode( compile_to_exec.set_isa_name(gcn_arch)?; compile_to_exec.set_language(Language::LlvmIr)?; let common_options = [ + // Uncomment for LLVM debug + //c"-mllvm", + //c"-debug", + // Uncomment to save passes + // c"-mllvm", + // c"-print-before-all", c"-mllvm", c"-ignore-tti-inline-compatible", // c"-mllvm", diff --git a/ptx/src/pass/llvm/emit.rs b/ptx/src/pass/llvm/emit.rs index eebbbc2..3d56f3a 100644 --- a/ptx/src/pass/llvm/emit.rs +++ b/ptx/src/pass/llvm/emit.rs @@ -540,9 +540,8 @@ impl<'a> MethodEmitContext<'a> { arguments: ast::LdArgs, ) -> Result<(), TranslateError> { let builder = self.builder; - let needs_cast = !matches!(data.typ, ast::Type::Scalar(_)) - && !matches!(data.qualifier, ast::LdStQualifier::Weak); let underlying_type = get_type(self.context, &data.typ)?; + let needs_cast = not_supported_by_atomics(data.qualifier, underlying_type); let op_type = if needs_cast { unsafe { LLVMIntTypeInContext(self.context, data.typ.layout().size() as u32 * 8) } } else { @@ -767,8 +766,8 @@ impl<'a> MethodEmitContext<'a> { arguments: ast::StArgs, ) -> Result<(), TranslateError> { let ptr = self.resolver.value(arguments.src1)?; - let needs_cast = !matches!(data.typ, ast::Type::Scalar(_)) - && !matches!(data.qualifier, ast::LdStQualifier::Weak); + let underlying_type = get_type(self.context, &data.typ)?; + let needs_cast = not_supported_by_atomics(data.qualifier, underlying_type); let mut value = self.resolver.value(arguments.src2)?; if needs_cast { value = unsafe { @@ -2940,6 +2939,20 @@ impl<'a> MethodEmitContext<'a> { */ } +fn not_supported_by_atomics(qualifier: ast::LdStQualifier, underlying_type: *mut LLVMType) -> bool { + // This is not meant to be 100% accurate, just a best-effort guess for atomics + fn is_non_scalar_type(type_: LLVMTypeRef) -> bool { + let kind = unsafe { LLVMGetTypeKind(type_) }; + matches!( + kind, + LLVMTypeKind::LLVMArrayTypeKind + | LLVMTypeKind::LLVMVectorTypeKind + | LLVMTypeKind::LLVMStructTypeKind + ) + } + !matches!(qualifier, ast::LdStQualifier::Weak) && is_non_scalar_type(underlying_type) +} + fn apply_qualifier( value: LLVMValueRef, qualifier: ptx_parser::LdStQualifier, diff --git a/ptx/src/test/spirv_run/atomics_128.ptx b/ptx/src/test/spirv_run/atomics_128.ptx new file mode 100644 index 0000000..147d350 --- /dev/null +++ b/ptx/src/test/spirv_run/atomics_128.ptx @@ -0,0 +1,24 @@ +.version 7.0 +.target sm_80 +.address_size 64 + +.visible .entry atomics_128( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.acquire.gpu.v2.u64 {temp1, temp2}, [in_addr]; + add.u64 temp1, temp1, 1; + add.u64 temp2, temp2, 1; + st.release.gpu.v2.u64 [out_addr], {temp1, temp2}; + + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c24ca1a..acad0b6 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -352,6 +352,7 @@ test_ptx!( [613065134u32] ); test_ptx!(param_is_addressable, [0xDEAD], [0u64]); +test_ptx!(atomics_128, [0xce16728dead1ceb0u64, 0xe7728e3c390b7fb7], [0xce16728dead1ceb1u64, 0xe7728e3c390b7fb8]); test_ptx!(assertfail); // TODO: not yet supported