Add atomic loads and stores (#526)

And add various smaller fixes across the compiler and runtime
Andrzej Janik 2025-09-25 18:19:10 -07:00 committed by GitHub
commit 5d03261457
21 changed files with 238 additions and 104 deletions


@@ -24,15 +24,7 @@ jobs:
     name: Build (Linux)
     runs-on: ubuntu-22.04
     steps:
-      - uses: jlumbroso/free-disk-space@main
-        with:
-          # Removing Android stuff should be enough
-          android: true
-          dotnet: false
-          haskell: false
-          large-packages: false
-          docker-images: false
-          swap-storage: false
+      - uses: jlumbroso/free-disk-space@v1.3.1
       - uses: actions/checkout@v4
         with:
           submodules: true
@@ -79,15 +71,7 @@ jobs:
     outputs:
       test_package: ${{ steps.upload_artifacts.outputs.artifact-id }}
     steps:
-      - uses: jlumbroso/free-disk-space@main
-        with:
-          # Removing Android stuff should be enough
-          android: true
-          dotnet: false
-          haskell: false
-          large-packages: false
-          docker-images: false
-          swap-storage: false
+      - uses: jlumbroso/free-disk-space@v1.3.1
       - uses: actions/checkout@v4
         with:
           submodules: true


@@ -18,15 +18,7 @@ jobs:
     permissions:
       contents: write
     steps:
-      - uses: jlumbroso/free-disk-space@main
-        with:
-          # Removing Android stuff should be enough
-          android: true
-          dotnet: false
-          haskell: false
-          large-packages: false
-          docker-images: false
-          swap-storage: false
+      - uses: jlumbroso/free-disk-space@v1.3.1
       - uses: actions/checkout@v4
         # fetch-depth and fetch-tags are required to properly tag pre-release builds
         with:
@@ -117,15 +109,7 @@ jobs:
     outputs:
       test_package: ${{ steps.upload_artifacts.outputs.artifact-id }}
     steps:
-      - uses: jlumbroso/free-disk-space@main
-        with:
-          # Removing Android stuff should be enough
-          android: true
-          dotnet: false
-          haskell: false
-          large-packages: false
-          docker-images: false
-          swap-storage: false
+      - uses: jlumbroso/free-disk-space@v1.3.1
       - uses: actions/checkout@v4
         with:
           submodules: true

Cargo.lock (generated)

@@ -420,7 +420,7 @@ version = "0.0.0"
 dependencies = [
  "proc-macro2",
  "quote",
- "rustc-hash 1.1.0",
+ "rustc-hash 2.0.0",
  "syn 2.0.89",
 ]
@@ -3706,7 +3706,7 @@ dependencies = [
  "paste",
  "ptx",
  "ptx_parser",
- "rustc-hash 1.1.0",
+ "rustc-hash 2.0.0",
  "serde",
  "serde_json",
  "tempfile",
@@ -3726,7 +3726,7 @@ dependencies = [
  "prettyplease",
  "proc-macro2",
  "quote",
- "rustc-hash 1.1.0",
+ "rustc-hash 2.0.0",
  "syn 2.0.89",
 ]
@@ -3854,7 +3854,7 @@ dependencies = [
  "ptx",
  "ptx_parser",
  "regex",
- "rustc-hash 1.1.0",
+ "rustc-hash 2.0.0",
  "unwrap_or",
  "wchar",
  "winapi",


@@ -219,6 +219,12 @@ pub fn compile_bitcode(
     compile_to_exec.set_isa_name(gcn_arch)?;
     compile_to_exec.set_language(Language::LlvmIr)?;
     let common_options = [
+        // Uncomment for LLVM debug
+        //c"-mllvm",
+        //c"-debug",
+        // Uncomment to save passes
+        // c"-mllvm",
+        // c"-print-before-all",
         c"-mllvm",
         c"-ignore-tti-inline-compatible",
         // c"-mllvm",


@@ -8,7 +8,7 @@ edition = "2021"
 quote = "1.0"
 syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
 proc-macro2 = "1.0"
-rustc-hash = "1.1.0"
+rustc-hash = "2.0.0"
 
 [lib]
 proc-macro = true


@@ -196,4 +196,24 @@ void LLVMZludaBuildFence(LLVMBuilderRef B, LLVMAtomicOrdering Ordering,
                          Name);
 }
 
+void LLVMZludaSetAtomic(
+    LLVMValueRef AtomicInst,
+    LLVMAtomicOrdering Ordering,
+    char * SSID)
+{
+    auto inst = unwrap(AtomicInst);
+    if (LoadInst *LI = dyn_cast<LoadInst>(inst))
+    {
+        LI->setAtomic(mapFromLLVMOrdering(Ordering), LI->getContext().getOrInsertSyncScopeID(SSID));
+    }
+    else if (StoreInst *SI = dyn_cast<StoreInst>(inst))
+    {
+        SI->setAtomic(mapFromLLVMOrdering(Ordering), SI->getContext().getOrInsertSyncScopeID(SSID));
+    }
+    else
+    {
+        llvm_unreachable("Invalid instruction type for LLVMZludaSetAtomic");
+    }
+}
+
 LLVM_C_EXTERN_C_END


@@ -78,4 +78,10 @@ extern "C" {
         scope: *const i8,
         Name: *const i8,
     ) -> LLVMValueRef;
+
+    pub fn LLVMZludaSetAtomic(
+        AtomicInst: LLVMValueRef,
+        Ordering: LLVMAtomicOrdering,
+        SSID: *const i8,
+    );
 }

Binary file not shown.


@@ -17,6 +17,7 @@ typedef _Float16 half16 __attribute__((ext_vector_type(16)));
 typedef float float8 __attribute__((ext_vector_type(8)));
 
 #define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
+#define FUNC_CALL(NAME) __zluda_ptx_impl_##NAME
 #define ATTR(NAME) __ZLUDA_PTX_IMPL_ATTRIBUTE_##NAME
 #define DECLARE_ATTR(TYPE, NAME) \
     extern "C" __attribute__((constant)) CONSTANT_SPACE TYPE ATTR(NAME) \
@@ -58,6 +59,18 @@ extern "C"
         return __lane_id();
     }
 
+    uint32_t FUNC(sreg_lanemask_lt)()
+    {
+        uint32_t lane_idx = FUNC_CALL(sreg_laneid)();
+        return (1U << lane_idx) - 1U;
+    }
+
+    uint32_t FUNC(sreg_lanemask_ge)()
+    {
+        uint32_t lane_idx = FUNC_CALL(sreg_laneid)();
+        return (~0U) << lane_idx;
+    }
+
     uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __device__;
     uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
     {
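For intuition, the two new lanemask helpers are pure bit manipulation on the lane index: %lanemask_lt has a bit set for every lane below the current one, %lanemask_ge for the current lane and everything above it. A quick host-side check in plain Rust (ordinary integer math, nothing ZLUDA-specific; lane 3 is an arbitrary example):

fn main() {
    let lane_idx: u32 = 3; // arbitrary example lane
    let lanemask_lt = (1u32 << lane_idx) - 1; // bits 0..lane_idx set
    let lanemask_ge = !0u32 << lane_idx; // bits lane_idx..31 set
    assert_eq!(lanemask_lt, 0b0111);
    assert_eq!(lanemask_ge, 0xFFFF_FFF8);
    // The two masks partition the 32-lane wave: no overlap, nothing missing.
    assert_eq!(lanemask_lt & lanemask_ge, 0);
    assert_eq!(lanemask_lt | lanemask_ge, u32::MAX);
}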


@@ -539,17 +539,25 @@ impl<'a> MethodEmitContext<'a> {
         data: ast::LdDetails,
         arguments: ast::LdArgs<SpirvWord>,
     ) -> Result<(), TranslateError> {
-        if data.qualifier != ast::LdStQualifier::Weak {
-            todo!()
-        }
         let builder = self.builder;
-        let type_ = get_type(self.context, &data.typ)?;
-        let ptr = self.resolver.value(arguments.src)?;
-        self.resolver.with_result(arguments.dst, |dst| {
-            let load = unsafe { LLVMBuildLoad2(builder, type_, ptr, dst) };
-            unsafe { LLVMSetAlignment(load, data.typ.layout().align() as u32) };
-            load
-        });
+        let underlying_type = get_type(self.context, &data.typ)?;
+        let needs_cast = not_supported_by_atomics(data.qualifier, underlying_type);
+        let op_type = if needs_cast {
+            unsafe { LLVMIntTypeInContext(self.context, data.typ.layout().size() as u32 * 8) }
+        } else {
+            underlying_type
+        };
+        let src = self.resolver.value(arguments.src)?;
+        let load = unsafe { LLVMBuildLoad2(builder, op_type, src, LLVM_UNNAMED.as_ptr()) };
+        apply_qualifier(load, data.qualifier)?;
+        unsafe { LLVMSetAlignment(load, data.typ.layout().align() as u32) };
+        if needs_cast {
+            self.resolver.with_result(arguments.dst, |dst| unsafe {
+                LLVMBuildBitCast(builder, load, underlying_type, dst)
+            });
+        } else {
+            self.resolver.register(arguments.dst, load);
+        }
         Ok(())
     }
@@ -758,11 +766,21 @@ impl<'a> MethodEmitContext<'a> {
         arguments: ast::StArgs<SpirvWord>,
     ) -> Result<(), TranslateError> {
         let ptr = self.resolver.value(arguments.src1)?;
-        let value = self.resolver.value(arguments.src2)?;
-        if data.qualifier != ast::LdStQualifier::Weak {
-            todo!()
+        let underlying_type = get_type(self.context, &data.typ)?;
+        let needs_cast = not_supported_by_atomics(data.qualifier, underlying_type);
+        let mut value = self.resolver.value(arguments.src2)?;
+        if needs_cast {
+            value = unsafe {
+                LLVMBuildBitCast(
+                    self.builder,
+                    value,
+                    LLVMIntTypeInContext(self.context, data.typ.layout().size() as u32 * 8),
+                    LLVM_UNNAMED.as_ptr(),
+                )
+            };
         }
         let store = unsafe { LLVMBuildStore(self.builder, value, ptr) };
+        apply_qualifier(store, data.qualifier)?;
         unsafe {
             LLVMSetAlignment(store, data.typ.layout().align() as u32);
         }
@@ -1653,7 +1671,6 @@ impl<'a> MethodEmitContext<'a> {
                     .ok_or_else(|| error_mismatched_type())?,
             );
             let src2 = self.resolver.value(src2)?;
-            self.resolver.with_result(arguments.dst, |dst| {
             let vec = unsafe {
                 LLVMBuildInsertElement(
                     self.builder,
@@ -1663,7 +1680,7 @@ impl<'a> MethodEmitContext<'a> {
                     LLVM_UNNAMED.as_ptr(),
                 )
             };
-            unsafe {
+            self.resolver.with_result(arguments.dst, |dst| unsafe {
                 LLVMBuildInsertElement(
                     self.builder,
                     vec,
@@ -1671,7 +1688,6 @@ impl<'a> MethodEmitContext<'a> {
                     LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
                     dst,
                 )
-            }
             })
         } else {
             self.resolver.with_result(arguments.dst, |dst| unsafe {
@@ -2197,7 +2213,7 @@ impl<'a> MethodEmitContext<'a> {
             Some(&ast::ScalarType::F32.into()),
             vec![(
                 self.resolver.value(arguments.src)?,
-                get_scalar_type(self.context, ast::ScalarType::F32.into()),
+                get_scalar_type(self.context, ast::ScalarType::F32),
             )],
         )?;
         Ok(())
@@ -2236,7 +2252,7 @@ impl<'a> MethodEmitContext<'a> {
     }
 
     fn emit_bar_warp(&mut self) -> Result<(), TranslateError> {
-        self.emit_intrinsic(c"llvm.amdgcn.barrier.warp", None, None, vec![])?;
+        self.emit_intrinsic(c"llvm.amdgcn.wave.barrier", None, None, vec![])?;
         Ok(())
     }
@@ -2658,14 +2674,14 @@ impl<'a> MethodEmitContext<'a> {
         let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
         unsafe {
-            LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
+            LLVMSetAlignment(load, cp_size.as_u64() as u32);
         }
         let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };
-        unsafe { LLVMBuildStore(self.builder, extended, to) };
+        let store = unsafe { LLVMBuildStore(self.builder, extended, to) };
         unsafe {
-            LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
+            LLVMSetAlignment(store, cp_size.as_u64() as u32);
         }
         Ok(())
     }
@@ -2923,6 +2939,61 @@ impl<'a> MethodEmitContext<'a> {
     */
 }
 
+fn not_supported_by_atomics(qualifier: ast::LdStQualifier, underlying_type: *mut LLVMType) -> bool {
+    // This is not meant to be 100% accurate, just a best-effort guess for atomics
+    fn is_non_scalar_type(type_: LLVMTypeRef) -> bool {
+        let kind = unsafe { LLVMGetTypeKind(type_) };
+        matches!(
+            kind,
+            LLVMTypeKind::LLVMArrayTypeKind
+                | LLVMTypeKind::LLVMVectorTypeKind
+                | LLVMTypeKind::LLVMStructTypeKind
+        )
+    }
+    !matches!(qualifier, ast::LdStQualifier::Weak) && is_non_scalar_type(underlying_type)
+}
+
+fn apply_qualifier(
+    value: LLVMValueRef,
+    qualifier: ptx_parser::LdStQualifier,
+) -> Result<(), TranslateError> {
+    match qualifier {
+        ptx_parser::LdStQualifier::Weak => {}
+        ptx_parser::LdStQualifier::Volatile => unsafe {
+            LLVMSetVolatile(value, 1);
+            // The semantics of volatile operations are equivalent to a relaxed memory operation
+            // with system-scope but with the following extra implementation-specific constraints...
+            LLVMZludaSetAtomic(
+                value,
+                LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+                get_scope(ast::MemScope::Sys)?,
+            );
+        },
+        ptx_parser::LdStQualifier::Relaxed(mem_scope) => unsafe {
+            LLVMZludaSetAtomic(
+                value,
+                LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+                get_scope(mem_scope)?,
+            );
+        },
+        ptx_parser::LdStQualifier::Acquire(mem_scope) => unsafe {
+            LLVMZludaSetAtomic(
+                value,
+                LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+                get_scope(mem_scope)?,
+            );
+        },
+        ptx_parser::LdStQualifier::Release(mem_scope) => unsafe {
+            LLVMZludaSetAtomic(
+                value,
+                LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
+                get_scope(mem_scope)?,
+            );
+        },
+    }
+    Ok(())
+}
+
 fn get_pointer_type<'ctx>(
     context: LLVMContextRef,
     to_space: ast::StateSpace,
@@ -2936,7 +3007,7 @@ fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
         ast::MemScope::Cta => c"workgroup-one-as",
         ast::MemScope::Gpu => c"agent-one-as",
         ast::MemScope::Sys => c"one-as",
-        ast::MemScope::Cluster => todo!(),
+        ast::MemScope::Cluster => return Err(error_todo()),
     }
     .as_ptr())
 }
@@ -2945,8 +3016,9 @@ fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
     Ok(match scope {
         ast::MemScope::Cta => c"workgroup",
         ast::MemScope::Gpu => c"agent",
+        // Don't change to "system", this is the same as __threadfence_system, AMDGPU LLVM expects "" here
         ast::MemScope::Sys => c"",
-        ast::MemScope::Cluster => todo!(),
+        ast::MemScope::Cluster => return Err(error_todo()),
     }
     .as_ptr())
 }
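The core of the change is the qualifier-to-(ordering, sync scope) mapping that apply_qualifier installs on loads and stores. Below is a standalone sketch of that mapping using simplified local enums (not the real ptx_parser or llvm_zluda types), handy for reading the table without an LLVM context around; volatile additionally sets the LLVM volatile flag, which is not modelled here:

// Simplified stand-ins for ptx_parser::LdStQualifier, ast::MemScope and the
// LLVM atomic ordering; the real code calls LLVMZludaSetAtomic instead.
#[derive(Clone, Copy)]
enum Scope {
    Cta,
    Gpu,
    Sys,
}
#[derive(Clone, Copy)]
enum Qualifier {
    Weak,
    Volatile,
    Relaxed(Scope),
    Acquire(Scope),
    Release(Scope),
}
#[derive(Debug, PartialEq)]
enum Ordering {
    NotAtomic,
    Monotonic,
    Acquire,
    Release,
}

fn sync_scope(scope: Scope) -> &'static str {
    match scope {
        Scope::Cta => "workgroup-one-as",
        Scope::Gpu => "agent-one-as",
        Scope::Sys => "one-as",
    }
}

// Mirrors apply_qualifier: volatile lowers to a relaxed (monotonic) atomic at
// system scope, weak stays a plain load/store.
fn lower(q: Qualifier) -> (Ordering, Option<&'static str>) {
    match q {
        Qualifier::Weak => (Ordering::NotAtomic, None),
        Qualifier::Volatile => (Ordering::Monotonic, Some(sync_scope(Scope::Sys))),
        Qualifier::Relaxed(s) => (Ordering::Monotonic, Some(sync_scope(s))),
        Qualifier::Acquire(s) => (Ordering::Acquire, Some(sync_scope(s))),
        Qualifier::Release(s) => (Ordering::Release, Some(sync_scope(s))),
    }
}

fn main() {
    assert_eq!(lower(Qualifier::Acquire(Scope::Gpu)), (Ordering::Acquire, Some("agent-one-as")));
    assert_eq!(lower(Qualifier::Release(Scope::Cta)), (Ordering::Release, Some("workgroup-one-as")));
    assert_eq!(lower(Qualifier::Relaxed(Scope::Sys)), (Ordering::Monotonic, Some("one-as")));
    assert_eq!(lower(Qualifier::Volatile), (Ordering::Monotonic, Some("one-as")));
    assert_eq!(lower(Qualifier::Weak), (Ordering::NotAtomic, None));
}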


@@ -136,6 +136,7 @@ enum PtxSpecialRegister {
     Nctaid,
     Clock,
     LanemaskLt,
+    LanemaskGe,
     Laneid,
 }
@@ -148,6 +149,7 @@ impl PtxSpecialRegister {
             Self::Nctaid => "%nctaid",
             Self::Clock => "%clock",
             Self::LanemaskLt => "%lanemask_lt",
+            Self::LanemaskGe => "%lanemask_ge",
             Self::Laneid => "%laneid",
         }
     }
@@ -170,6 +172,7 @@ impl PtxSpecialRegister {
             PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
             PtxSpecialRegister::Clock => ast::ScalarType::U32,
             PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
+            PtxSpecialRegister::LanemaskGe => ast::ScalarType::U32,
             PtxSpecialRegister::Laneid => ast::ScalarType::U32,
         }
     }
@@ -182,6 +185,7 @@ impl PtxSpecialRegister {
             | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
             PtxSpecialRegister::Clock
             | PtxSpecialRegister::LanemaskLt
+            | PtxSpecialRegister::LanemaskGe
             | PtxSpecialRegister::Laneid => None,
         }
     }
@@ -194,6 +198,7 @@ impl PtxSpecialRegister {
             PtxSpecialRegister::Nctaid => "sreg_nctaid",
             PtxSpecialRegister::Clock => "sreg_clock",
             PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
+            PtxSpecialRegister::LanemaskGe => "sreg_lanemask_ge",
             PtxSpecialRegister::Laneid => "sreg_laneid",
         }
     }


@@ -0,0 +1,24 @@
+.version 7.0
+.target sm_80
+.address_size 64
+
+.visible .entry atomics_128(
+    .param .u64 input,
+    .param .u64 output
+)
+{
+    .reg .u64 in_addr;
+    .reg .u64 out_addr;
+    .reg .u64 temp1;
+    .reg .u64 temp2;
+
+    ld.param.u64 in_addr, [input];
+    ld.param.u64 out_addr, [output];
+
+    ld.acquire.gpu.v2.u64 {temp1, temp2}, [in_addr];
+    add.u64 temp1, temp1, 1;
+    add.u64 temp2, temp2, 1;
+    st.release.gpu.v2.u64 [out_addr], {temp1, temp2};
+    ret;
+}


@@ -352,6 +352,12 @@ test_ptx!(
     [613065134u32]
 );
 test_ptx!(param_is_addressable, [0xDEAD], [0u64]);
+// TODO: re-enable when we have a patched LLVM
+//test_ptx!(
+//    atomics_128,
+//    [0xce16728dead1ceb0u64, 0xe7728e3c390b7fb7],
+//    [0xce16728dead1ceb1u64, 0xe7728e3c390b7fb8]
+//);
 test_ptx!(assertfail);
 
 // TODO: not yet supported
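For reference, the expected values in the disabled test are just the inputs with each 64-bit half incremented by one, which is what the two add.u64 instructions in atomics_128.ptx do between the acquire load and the release store. A quick sanity check in plain Rust, no ZLUDA machinery involved:

fn main() {
    let input: [u64; 2] = [0xce16728dead1ceb0, 0xe7728e3c390b7fb7];
    let expected: [u64; 2] = [0xce16728dead1ceb1, 0xe7728e3c390b7fb8];
    // The kernel loads both halves with ld.acquire.gpu.v2.u64, adds 1 to each,
    // and writes them back with st.release.gpu.v2.u64.
    let output = [input[0].wrapping_add(1), input[1].wrapping_add(1)];
    assert_eq!(output, expected);
}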


@@ -226,8 +226,9 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
     take_error((opt(Token::Minus), num).map(|(neg, x)| {
         let (num, radix, is_unsigned) = x;
         if neg.is_some() {
-            match i64::from_str_radix(num, radix) {
-                Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
+            let full_number = format!("-{num}");
+            match i64::from_str_radix(&full_number, radix) {
+                Ok(x) => Ok(ast::ImmediateValue::S64(x)),
                 Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
             }
         } else if is_unsigned {
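The reason for parsing the sign together with the digits is the most negative literal: i64::MAX is 9223372036854775807, so parsing the magnitude first and negating it overflows for -9223372036854775808, while from_str_radix handles a leading minus directly. A small self-contained illustration (plain std, none of the parser types):

fn main() {
    // Negating after parsing cannot work for i64::MIN: the magnitude itself
    // does not fit into an i64...
    assert!(i64::from_str_radix("9223372036854775808", 10).is_err());
    // ...but parsing the sign together with the digits does.
    assert_eq!(
        i64::from_str_radix("-9223372036854775808", 10),
        Ok(i64::MIN)
    );
    // Same thing in hex, closer to what a PTX immediate might look like.
    assert_eq!(i64::from_str_radix("-8000000000000000", 16), Ok(i64::MIN));
}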


@@ -22,7 +22,7 @@ num_enum = "0.4"
 lz4-sys = "1.9"
 tempfile = "3"
 paste = "1.0"
-rustc-hash = "1.1"
+rustc-hash = "2.0.0"
 zluda_common = { path = "../zluda_common" }
 blake3 = "1.8.2"
 serde = "1.0.219"


@@ -1,22 +1,33 @@
+use cuda_types::cuda::CUfunction_attribute;
 use hip_runtime_sys::*;
+use std::mem;
 
 pub(crate) fn get_attribute(
     pi: &mut i32,
-    cu_attrib: hipFunction_attribute,
+    cu_attrib: CUfunction_attribute,
     func: hipFunction_t,
 ) -> hipError_t {
     // TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION
     // TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION
     match cu_attrib {
-        hipFunction_attribute::HIP_FUNC_ATTRIBUTE_PTX_VERSION
-        | hipFunction_attribute::HIP_FUNC_ATTRIBUTE_BINARY_VERSION => {
+        CUfunction_attribute::CU_FUNC_ATTRIBUTE_PTX_VERSION
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_BINARY_VERSION => {
             *pi = 120;
             return Ok(());
         }
+        CUfunction_attribute::CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE => {
+            *pi = 0;
+            return Ok(());
+        }
         _ => {}
     }
-    unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?;
-    if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS {
+    unsafe { hipFuncGetAttribute(pi, mem::transmute(cu_attrib), func) }?;
+    if cu_attrib == CUfunction_attribute::CU_FUNC_ATTRIBUTE_NUM_REGS {
         *pi = (*pi).max(1);
     }
     Ok(())
@@ -55,12 +66,12 @@ pub(crate) fn launch_kernel(
 
 pub(crate) unsafe fn set_attribute(
     func: hipFunction_t,
-    attribute: hipFunction_attribute,
+    attribute: CUfunction_attribute,
     value: i32,
 ) -> hipError_t {
     match attribute {
-        hipFunction_attribute::HIP_FUNC_ATTRIBUTE_PTX_VERSION
-        | hipFunction_attribute::HIP_FUNC_ATTRIBUTE_BINARY_VERSION => {
+        CUfunction_attribute::CU_FUNC_ATTRIBUTE_PTX_VERSION
+        | CUfunction_attribute::CU_FUNC_ATTRIBUTE_BINARY_VERSION => {
             return hipError_t::ErrorNotSupported;
         }
         _ => {}
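hipFuncGetAttribute still expects a hipFunction_attribute, so the call site now goes through mem::transmute; that is only sound if the two FFI enums have the same size and, for the attributes that are actually forwarded, the same discriminants (the CU-only cluster attributes are answered locally before the call ever happens). A toy illustration of the pattern with made-up enums, not the real CUDA/HIP types:

use std::mem;

// Hypothetical stand-ins: two FFI-style enums whose shared variants use the
// same underlying discriminants, mimicking CUfunction_attribute vs
// hipFunction_attribute for the attributes that get forwarded.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum CuAttr {
    MaxThreadsPerBlock = 0,
    SharedSizeBytes = 1,
}

#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum HipAttr {
    MaxThreadsPerBlock = 0,
    SharedSizeBytes = 1,
}

fn main() {
    let cu = CuAttr::SharedSizeBytes;
    // Sound only because both enums are #[repr(i32)], the same size, and the
    // transmuted value is a valid discriminant on the target side.
    let hip: HipAttr = unsafe { mem::transmute(cu) };
    assert_eq!(hip, HipAttr::SharedSizeBytes);
}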


@@ -1,4 +1,4 @@
-use cuda_types::cuda::CUresult;
+use cuda_types::cuda::{CUfunction_attribute, CUresult};
 use hip_runtime_sys::*;
 
 use crate::r#impl::function;
@@ -9,7 +9,7 @@ pub(crate) unsafe fn get_function(func: &mut hipFunction_t, kernel: hipFunction_
 }
 
 pub(crate) unsafe fn set_attribute(
-    attrib: hipFunction_attribute,
+    attrib: CUfunction_attribute,
     val: ::core::ffi::c_int,
     kernel: hipFunction_t,
     _dev: hipDevice_t,


@@ -1,16 +1,18 @@
-use std::ptr;
-
+use crate::r#impl::{context, driver};
 use cuda_types::cuda::{CUerror, CUresult, CUresultConsts};
 use hip_runtime_sys::*;
+use std::{mem, ptr};
 
-use crate::r#impl::{context, driver};
-
-pub(crate) fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult {
+pub(crate) unsafe fn alloc_v2(dptr: &mut hipDeviceptr_t, bytesize: usize) -> CUresult {
     let context = context::get_current_context()?;
-    unsafe { hipMalloc(ptr::from_mut(dptr).cast(), bytesize) }?;
+    hipMalloc(ptr::from_mut(dptr).cast(), bytesize)?;
     add_allocation(dptr.0, bytesize, context)?;
+    let mut status = mem::zeroed();
+    hipStreamIsCapturing(hipStream_t(ptr::null_mut()), &mut status)?;
     // TODO: parametrize for non-Geekbench
-    unsafe { hipMemsetD8(*dptr, 0, bytesize) }?;
+    if status != hipStreamCaptureStatus::hipStreamCaptureStatusNone {
+        hipMemsetD8(*dptr, 0, bytesize)?;
+    }
     Ok(())
 }


@@ -9,6 +9,6 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
 proc-macro2 = "1.0.89"
 quote = "1.0"
 prettyplease = "0.2.25"
-rustc-hash = "1.1.0"
+rustc-hash = "2.0.0"
 libloading = "0.8"
 cuda_types = { path = "../cuda_types" }


@@ -173,12 +173,12 @@ from_cuda_nop!(
     cublasLtMatmulDescAttributes_t,
     CUmemAllocationGranularity_flags,
     CUmemAllocationProp,
-    CUresult
+    CUresult,
+    CUfunction_attribute
 );
 from_cuda_transmute!(
     CUuuid => hipUUID,
     CUfunction => hipFunction_t,
-    CUfunction_attribute => hipFunction_attribute,
     CUstream => hipStream_t,
     CUpointer_attribute => hipPointer_attribute,
     CUdeviceptr_v2 => hipDeviceptr_t,


@@ -24,7 +24,7 @@ paste = "1.0"
 cuda_macros = { path = "../cuda_macros" }
 cuda_types = { path = "../cuda_types" }
 parking_lot = "0.12.3"
-rustc-hash = "1.1.0"
+rustc-hash = "2.0.0"
 cglue = "0.3.5"
 zstd-safe = { version = "7.2.4", features = ["std"] }
 unwrap_or = "1.0.1"