Update llama.cpp support (#102)

Add sign extension support to prmt, allow set.<op>.f16x2.f16x2, add more BLAS mappings
This commit is contained in:
Andrzej Janik 2024-02-16 00:01:21 +01:00 committed by GitHub
parent 9f7be97ef6
commit 4a81dbffb5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 368 additions and 95 deletions

View file

@ -127,6 +127,16 @@ If an application fails to start under ZLUDA or crashes please check [Known Issu
### Applications
#### llama.cpp
If you are building llama.cpp with cmake and don't want it to crash on ZLUDA then you should use `CUDA_DOCKER_ARCH=compute_61` like this:
```
make CUDA_DOCKER_ARCH=compute_61
```
Alternatively, building with cmake should work with no changes.
Performance is currently much lower than the native HIP backend, see the discussion in #102.
#### Arnold
* ZLUDA implements minimum of OptiX framework to support Arnold. ZLUDA's OptiX is buggy, unoptimized and incomplete. It's been tested with Arnold 7.1.4.1 command line rendering on Linux.

View file

@ -1176,15 +1176,26 @@ fn emit_inst_set(
) -> Result<(), TranslateError> {
let builder = ctx.builder.get();
let temp_result = emit_inst_setp_float(ctx, details.cmp_op, None, arg.src1, arg.src2)?;
if details.src_type != ast::ScalarType::F16x2 || details.dst_type == ast::ScalarType::F16x2 {
if details.src_type != ast::ScalarType::F16x2 {
return Err(TranslateError::todo());
}
if details.dst_type.is_integer() && details.dst_type.size_of() == mem::size_of::<u32>() as u8 {
let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?;
let b16vec2_result =
unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) };
let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
ctx.names.register_result(arg.dst, |dst_name| unsafe {
LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name)
});
} else if matches!(details.dst_type, ast::ScalarType::F16x2) {
let f16x2_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::F16x2))?;
ctx.names.register_result(arg.dst, |dst_name| unsafe {
LLVMBuildUIToFP(builder, temp_result, f16x2_type, dst_name)
});
} else {
return Err(TranslateError::todo());
}
let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?;
let b16vec2_result = unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) };
let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
ctx.names.register_result(arg.dst, |dst_name| unsafe {
LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name)
});
Ok(())
}
@ -1654,14 +1665,17 @@ fn emit_inst_prmt(
) -> Result<(), TranslateError> {
let builder = ctx.builder.get();
let components = [
((control >> 0) & 0b1111) as u32,
((control >> 4) & 0b1111) as u32,
((control >> 8) & 0b1111) as u32,
((control >> 12) & 0b1111) as u32,
((control >> 0) & 0b0111) as u32,
((control >> 4) & 0b0111) as u32,
((control >> 8) & 0b0111) as u32,
((control >> 12) & 0b0111) as u32,
];
let sext_components = [
((control >> 0) & 0b1000) != 0,
((control >> 4) & 0b1000) != 0,
((control >> 8) & 0b1000) != 0,
((control >> 12) & 0b1000) != 0,
];
if components.iter().any(|&c| c > 7) {
return Err(TranslateError::todo());
}
let llvm_i32 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
let llvm_vec4_i8 = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::U8, 4))?;
let src1 = ctx.names.value(arg.src1)?;
@ -1674,9 +1688,24 @@ fn emit_inst_prmt(
unsafe { LLVMConstInt(llvm_i32, components[2] as _, 0) },
unsafe { LLVMConstInt(llvm_i32, components[3] as _, 0) },
];
let mask = unsafe { LLVMConstVector(components_llvm.as_mut_ptr(), 4) };
let shuffle_result =
let mask =
unsafe { LLVMConstVector(components_llvm.as_mut_ptr(), components_llvm.len() as u32) };
let mut shuffle_result =
unsafe { LLVMBuildShuffleVector(builder, src1_vector, src2_vector, mask, LLVM_UNNAMED) };
// In sext case I'd prefer to just emit V_PERM_B32 directly and be done with it,
// but V_PERM_B32 can sext only odd-indexed bytes.
let llvm_i8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U8))?;
let const_7 = unsafe { LLVMConstInt(llvm_i8, 7, 0) };
for (idx, requires_sext) in sext_components.iter().copied().enumerate() {
if !requires_sext {
continue;
}
let idx = unsafe { LLVMConstInt(llvm_i32, idx as u64, 0) };
let scalar = unsafe { LLVMBuildExtractElement(builder, shuffle_result, idx, LLVM_UNNAMED) };
let shift = unsafe { LLVMBuildAShr(builder, scalar, const_7, LLVM_UNNAMED) };
shuffle_result =
unsafe { LLVMBuildInsertElement(builder, shuffle_result, shift, idx, LLVM_UNNAMED) };
}
ctx.names.register_result(arg.dst, |dst_name| unsafe {
LLVMBuildBitCast(builder, shuffle_result, llvm_i32, dst_name)
});

View file

@ -1097,6 +1097,15 @@ InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-set
InstSet: ast::Instruction<ast::ParsedArgParams<'input>> = {
"set" <cmp_op:SetpCompareOp> <ftz:".ftz"?> ".f16x2" ".f16x2" <arg:Arg3> => {
let data = ast::SetData {
dst_type: ast::ScalarType::F16x2,
src_type: ast::ScalarType::F16x2,
flush_to_zero: ftz.is_some(),
cmp_op: cmp_op,
};
ast::Instruction::Set(data, arg)
},
"set" <cmp_op:SetpCompareOp> <ftz:".ftz"?> ".u32" ".f16x2" <arg:Arg3> => {
let data = ast::SetData {
dst_type: ast::ScalarType::U32,

View file

@ -271,7 +271,7 @@ test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(cvt_f32_f16, [0xa1u16], [0x37210000u32]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32, 0x6FFFD600]);
test_ptx!(
prmt_non_immediate,
[0x70c507d6u32, 0x6fbd4b5cu32],
@ -336,7 +336,7 @@ test_ptx!(
[f16::from_f32(2.0), f16::from_f32(3.0)],
[f16::from_f32(2.0), f16::from_f32(5.0)]
);
test_ptx!(st_f16x2, [0xc1690e6eu32, 0x13739444u32], [0xffffu32]);
test_ptx!(set_f16x2, [0xc1690e6eu32, 0x13739444u32, 0x424834CC, 0x4248B4CC], [0xffffu32, 0x3C000000]);
test_ptx!(
dp4a,
[0xde3032f5u32, 0x2474fe15, 0xf51d8d6c],

View file

@ -1,40 +1,60 @@
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"
define protected amdgpu_kernel void @prmt(ptr addrspace(4) byref(i64) %"23", ptr addrspace(4) byref(i64) %"24") #0 {
"31":
%"8" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"8", align 1
%"9" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"9", align 1
define protected amdgpu_kernel void @prmt(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 {
"44":
%"10" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"10", align 1
%"11" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"11", align 1
%"4" = alloca i64, align 8, addrspace(5)
%"5" = alloca i64, align 8, addrspace(5)
%"6" = alloca i32, align 4, addrspace(5)
%"7" = alloca i32, align 4, addrspace(5)
%"10" = load i64, ptr addrspace(4) %"23", align 8
store i64 %"10", ptr addrspace(5) %"4", align 8
%"11" = load i64, ptr addrspace(4) %"24", align 8
store i64 %"11", ptr addrspace(5) %"5", align 8
%"13" = load i64, ptr addrspace(5) %"4", align 8
%"25" = inttoptr i64 %"13" to ptr
%"12" = load i32, ptr %"25", align 4
store i32 %"12", ptr addrspace(5) %"6", align 4
%"8" = alloca i32, align 4, addrspace(5)
%"9" = alloca i32, align 4, addrspace(5)
%"12" = load i64, ptr addrspace(4) %"32", align 8
store i64 %"12", ptr addrspace(5) %"4", align 8
%"13" = load i64, ptr addrspace(4) %"33", align 8
store i64 %"13", ptr addrspace(5) %"5", align 8
%"15" = load i64, ptr addrspace(5) %"4", align 8
%"26" = inttoptr i64 %"15" to ptr
%"33" = getelementptr inbounds i8, ptr %"26", i64 4
%"14" = load i32, ptr %"33", align 4
store i32 %"14", ptr addrspace(5) %"7", align 4
%"17" = load i32, ptr addrspace(5) %"6", align 4
%"18" = load i32, ptr addrspace(5) %"7", align 4
%0 = bitcast i32 %"17" to <4 x i8>
%1 = bitcast i32 %"18" to <4 x i8>
%2 = shufflevector <4 x i8> %0, <4 x i8> %1, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
%"27" = bitcast <4 x i8> %2 to i32
store i32 %"27", ptr addrspace(5) %"7", align 4
%"19" = load i64, ptr addrspace(5) %"5", align 8
%"34" = inttoptr i64 %"15" to ptr
%"14" = load i32, ptr %"34", align 4
store i32 %"14", ptr addrspace(5) %"6", align 4
%"17" = load i64, ptr addrspace(5) %"4", align 8
%"35" = inttoptr i64 %"17" to ptr
%"46" = getelementptr inbounds i8, ptr %"35", i64 4
%"16" = load i32, ptr %"46", align 4
store i32 %"16", ptr addrspace(5) %"7", align 4
%"19" = load i32, ptr addrspace(5) %"6", align 4
%"20" = load i32, ptr addrspace(5) %"7", align 4
%"30" = inttoptr i64 %"19" to ptr
store i32 %"20", ptr %"30", align 4
%0 = bitcast i32 %"19" to <4 x i8>
%1 = bitcast i32 %"20" to <4 x i8>
%2 = shufflevector <4 x i8> %0, <4 x i8> %1, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
%"36" = bitcast <4 x i8> %2 to i32
store i32 %"36", ptr addrspace(5) %"8", align 4
%"22" = load i32, ptr addrspace(5) %"6", align 4
%"23" = load i32, ptr addrspace(5) %"7", align 4
%3 = bitcast i32 %"22" to <4 x i8>
%4 = bitcast i32 %"23" to <4 x i8>
%5 = shufflevector <4 x i8> %3, <4 x i8> %4, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
%6 = extractelement <4 x i8> %5, i32 0
%7 = ashr i8 %6, 7
%8 = insertelement <4 x i8> %5, i8 %7, i32 0
%9 = extractelement <4 x i8> %8, i32 2
%10 = ashr i8 %9, 7
%11 = insertelement <4 x i8> %8, i8 %10, i32 2
%"39" = bitcast <4 x i8> %11 to i32
store i32 %"39", ptr addrspace(5) %"9", align 4
%"24" = load i64, ptr addrspace(5) %"5", align 8
%"25" = load i32, ptr addrspace(5) %"8", align 4
%"42" = inttoptr i64 %"24" to ptr
store i32 %"25", ptr %"42", align 4
%"26" = load i64, ptr addrspace(5) %"5", align 8
%"27" = load i32, ptr addrspace(5) %"9", align 4
%"43" = inttoptr i64 %"26" to ptr
%"48" = getelementptr inbounds i8, ptr %"43", i64 4
store i32 %"27", ptr %"48", align 4
ret void
}

View file

@ -11,13 +11,17 @@
.reg .u64 out_addr;
.reg .u32 temp1;
.reg .u32 temp2;
.reg .u32 temp3;
.reg .u32 temp4;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp1, [in_addr];
ld.u32 temp2, [in_addr+4];
prmt.b32 temp2, temp1, temp2, 30212;
st.u32 [out_addr], temp2;
prmt.b32 temp3, temp1, temp2, 30212;
prmt.b32 temp4, temp1, temp2, 32268;
st.u32 [out_addr], temp3;
st.u32 [out_addr+4], temp4;
ret;
}

View file

@ -0,0 +1,68 @@
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"
define protected amdgpu_kernel void @set_f16x2(ptr addrspace(4) byref(i64) %"41", ptr addrspace(4) byref(i64) %"42") #0 {
"59":
%"11" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"11", align 1
%"12" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"12", align 1
%"4" = alloca i64, align 8, addrspace(5)
%"5" = alloca i64, align 8, addrspace(5)
%"6" = alloca i32, align 4, addrspace(5)
%"7" = alloca i32, align 4, addrspace(5)
%"8" = alloca i32, align 4, addrspace(5)
%"9" = alloca i32, align 4, addrspace(5)
%"10" = alloca <2 x half>, align 4, addrspace(5)
%"13" = load i64, ptr addrspace(4) %"41", align 8
store i64 %"13", ptr addrspace(5) %"4", align 8
%"14" = load i64, ptr addrspace(4) %"42", align 8
store i64 %"14", ptr addrspace(5) %"5", align 8
%"16" = load i64, ptr addrspace(5) %"4", align 8
%"44" = inttoptr i64 %"16" to ptr
%"43" = load i32, ptr %"44", align 4
store i32 %"43", ptr addrspace(5) %"6", align 4
%"18" = load i64, ptr addrspace(5) %"4", align 8
%"45" = inttoptr i64 %"18" to ptr
%"61" = getelementptr inbounds i8, ptr %"45", i64 4
%"46" = load i32, ptr %"61", align 4
store i32 %"46", ptr addrspace(5) %"7", align 4
%"20" = load i64, ptr addrspace(5) %"4", align 8
%"47" = inttoptr i64 %"20" to ptr
%"63" = getelementptr inbounds i8, ptr %"47", i64 8
%"48" = load i32, ptr %"63", align 4
store i32 %"48", ptr addrspace(5) %"8", align 4
%"22" = load i64, ptr addrspace(5) %"4", align 8
%"49" = inttoptr i64 %"22" to ptr
%"65" = getelementptr inbounds i8, ptr %"49", i64 12
%"50" = load i32, ptr %"65", align 4
store i32 %"50", ptr addrspace(5) %"9", align 4
%"24" = load i32, ptr addrspace(5) %"6", align 4
%"25" = load i32, ptr addrspace(5) %"7", align 4
%"52" = bitcast i32 %"24" to <2 x half>
%"53" = bitcast i32 %"25" to <2 x half>
%0 = fcmp ugt <2 x half> %"52", %"53"
%1 = sext <2 x i1> %0 to <2 x i16>
%"51" = bitcast <2 x i16> %1 to i32
store i32 %"51", ptr addrspace(5) %"6", align 4
%"27" = load i32, ptr addrspace(5) %"8", align 4
%"28" = load i32, ptr addrspace(5) %"9", align 4
%"55" = bitcast i32 %"27" to <2 x half>
%"56" = bitcast i32 %"28" to <2 x half>
%2 = fcmp oeq <2 x half> %"55", %"56"
%"54" = uitofp <2 x i1> %2 to <2 x half>
%"26" = bitcast <2 x half> %"54" to i32
store i32 %"26", ptr addrspace(5) %"8", align 4
%"29" = load i64, ptr addrspace(5) %"5", align 8
%"30" = load i32, ptr addrspace(5) %"6", align 4
%"57" = inttoptr i64 %"29" to ptr
store i32 %"30", ptr %"57", align 4
%"31" = load i64, ptr addrspace(5) %"5", align 8
%"32" = load i32, ptr addrspace(5) %"8", align 4
%"58" = inttoptr i64 %"31" to ptr
%"67" = getelementptr inbounds i8, ptr %"58", i64 4
store i32 %"32", ptr %"67", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -2,7 +2,7 @@
.target sm_53
.address_size 64
.visible .entry st_f16x2(
.visible .entry set_f16x2(
.param .u64 input,
.param .u64 output
)
@ -11,6 +11,8 @@
.reg .u64 out_addr;
.reg .b32 temp0;
.reg .b32 temp1;
.reg .b32 temp2;
.reg .b32 temp3;
.reg .f16x2 sela;
ld.param.u64 in_addr, [input];
@ -18,7 +20,11 @@
ld.u32 temp0, [in_addr];
ld.u32 temp1, [in_addr+4];
ld.u32 temp2, [in_addr+8];
ld.u32 temp3, [in_addr+12];
set.gtu.u32.f16x2 temp0, temp0, temp1;
set.eq.f16x2.f16x2 temp2, temp2, temp3;
st.b32 [out_addr], temp0;
st.b32 [out_addr+4], temp2;
ret;
}

View file

@ -1,43 +0,0 @@
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"
define protected amdgpu_kernel void @st_f16x2(ptr addrspace(4) byref(i64) %"24", ptr addrspace(4) byref(i64) %"25") #0 {
"34":
%"9" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"9", align 1
%"10" = alloca i1, align 1, addrspace(5)
store i1 false, ptr addrspace(5) %"10", align 1
%"4" = alloca i64, align 8, addrspace(5)
%"5" = alloca i64, align 8, addrspace(5)
%"6" = alloca i32, align 4, addrspace(5)
%"7" = alloca i32, align 4, addrspace(5)
%"8" = alloca <2 x half>, align 4, addrspace(5)
%"11" = load i64, ptr addrspace(4) %"24", align 8
store i64 %"11", ptr addrspace(5) %"4", align 8
%"12" = load i64, ptr addrspace(4) %"25", align 8
store i64 %"12", ptr addrspace(5) %"5", align 8
%"14" = load i64, ptr addrspace(5) %"4", align 8
%"27" = inttoptr i64 %"14" to ptr
%"26" = load i32, ptr %"27", align 4
store i32 %"26", ptr addrspace(5) %"6", align 4
%"16" = load i64, ptr addrspace(5) %"4", align 8
%"28" = inttoptr i64 %"16" to ptr
%"36" = getelementptr inbounds i8, ptr %"28", i64 4
%"29" = load i32, ptr %"36", align 4
store i32 %"29", ptr addrspace(5) %"7", align 4
%"18" = load i32, ptr addrspace(5) %"6", align 4
%"19" = load i32, ptr addrspace(5) %"7", align 4
%"31" = bitcast i32 %"18" to <2 x half>
%"32" = bitcast i32 %"19" to <2 x half>
%0 = fcmp ugt <2 x half> %"31", %"32"
%1 = sext <2 x i1> %0 to <2 x i16>
%"30" = bitcast <2 x i16> %1 to i32
store i32 %"30", ptr addrspace(5) %"6", align 4
%"20" = load i64, ptr addrspace(5) %"5", align 8
%"21" = load i32, ptr addrspace(5) %"6", align 4
%"33" = inttoptr i64 %"20" to ptr
store i32 %"21", ptr %"33", align 4
ret void
}
attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }

View file

@ -162,7 +162,9 @@ pub(crate) unsafe fn get_attribute(
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS => {
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS
// Possibly true, used by llama.cpp
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED => {
*pi = 0;
return Ok(());
}

View file

@ -3926,7 +3926,28 @@ pub unsafe extern "system" fn cublasGemmBatchedEx(
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
crate::unsupported()
crate::gemm_batched_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
Aarray,
Atype,
lda,
Barray,
Btype,
ldb,
beta,
Carray,
Ctype,
ldc,
batchCount,
computeType,
algo,
)
}
#[no_mangle]
@ -3955,7 +3976,31 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx(
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
crate::unsupported()
crate::gemm_strided_batched_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
Atype,
lda,
strideA,
B,
Btype,
ldb,
strideB,
beta,
C,
Ctype,
ldc,
strideC,
batchCount,
computeType,
algo,
)
}
#[no_mangle]

View file

@ -916,3 +916,126 @@ unsafe fn dtrsm(
ldb,
))
}
unsafe fn gemm_batched_ex(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const *const std::ffi::c_void,
atype: cudaDataType_t,
lda: i32,
b: *const *const std::ffi::c_void,
btype: cudaDataType_t,
ldb: i32,
beta: *const std::ffi::c_void,
c: *const *mut std::ffi::c_void,
ctype: cudaDataType_t,
ldc: i32,
batch_count: i32,
compute_type: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
let transa = op_from_cuda(transa);
let transb = op_from_cuda(transb);
let atype = type_from_cuda(atype);
let btype = type_from_cuda(btype);
let ctype = type_from_cuda(ctype);
let compute_type = to_compute_type(compute_type);
let algo = to_algo(algo);
to_cuda(rocblas_gemm_batched_ex(
handle.cast(),
transa,
transb,
m,
n,
k,
alpha,
a.cast(),
atype,
lda,
b.cast(),
btype,
ldb,
beta,
c.cast(),
ctype,
ldc,
c.cast_mut().cast(),
ctype,
ldc,
batch_count,
compute_type,
algo,
0,
rocblas_gemm_flags::rocblas_gemm_flags_none.0,
))
}
unsafe fn gemm_strided_batched_ex(
handle: cublasHandle_t,
transa: cublasOperation_t,
transb: cublasOperation_t,
m: ::std::os::raw::c_int,
n: ::std::os::raw::c_int,
k: ::std::os::raw::c_int,
alpha: *const ::std::os::raw::c_void,
a: *const ::std::os::raw::c_void,
atype: cudaDataType,
lda: ::std::os::raw::c_int,
stride_a: ::std::os::raw::c_longlong,
b: *const ::std::os::raw::c_void,
btype: cudaDataType,
ldb: ::std::os::raw::c_int,
stride_b: ::std::os::raw::c_longlong,
beta: *const ::std::os::raw::c_void,
c: *mut ::std::os::raw::c_void,
ctype: cudaDataType,
ldc: ::std::os::raw::c_int,
stride_c: ::std::os::raw::c_longlong,
batch_count: ::std::os::raw::c_int,
compute_type: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
let transa = op_from_cuda(transa);
let transb = op_from_cuda(transb);
let atype = type_from_cuda(atype);
let btype = type_from_cuda(btype);
let ctype = type_from_cuda(ctype);
let compute_type = to_compute_type(compute_type);
let algo = to_algo(algo);
to_cuda(rocblas_gemm_strided_batched_ex(
handle.cast(),
transa,
transb,
m,
n,
k,
alpha,
a,
atype,
lda,
stride_a,
b,
btype,
ldb,
stride_b,
beta,
c,
ctype,
ldc,
stride_c,
c,
ctype,
ldc,
stride_c,
batch_count,
compute_type,
algo,
0,
rocblas_gemm_flags::rocblas_gemm_flags_none.0,
))
}