Add support for cuBLASLt functions used by llm.c (#512)

This commit is contained in:
Violet 2025-09-17 11:02:21 -07:00 committed by GitHub
commit 571dad0972
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 522 additions and 9 deletions

3
Cargo.lock generated
View file

@ -3748,6 +3748,8 @@ version = "0.0.0"
dependencies = [
"cuda_macros",
"cuda_types",
"hip_runtime-sys",
"hipblaslt-sys",
"zluda_common",
]
@ -3769,6 +3771,7 @@ dependencies = [
"cuda_types",
"dark_api",
"hip_runtime-sys",
"hipblaslt-sys",
"rocblas-sys",
]

View file

@ -10,6 +10,8 @@ name = "cublaslt"
[dependencies]
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
zluda_common = { path = "../zluda_common" }
[package.metadata.zluda]

View file

@ -1,8 +1,19 @@
use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t};
use zluda_common::{from_cuda_object, ZludaObject};
use cuda_types::{cublas::*, cublaslt::*};
use hip_runtime_sys::hipStream_t;
use hipblaslt_sys::*;
use std::mem;
use zluda_common::{from_cuda_object, FromCuda, ZludaObject};
pub struct Handle {
_handle: usize,
handle: hipblasLtHandle_t,
}
impl Handle {
fn new() -> Self {
Self {
handle: unsafe { mem::zeroed() },
}
}
}
impl ZludaObject for Handle {
@ -48,11 +59,241 @@ pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> :
todo!()
}
pub(crate) fn create(handle: &mut cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
*handle = Handle { _handle: 0 }.wrap();
pub(crate) fn create(handle: &mut cublasLtHandle_t) -> cublasStatus_t {
let mut zluda_blaslt_handle = Handle::new();
unsafe { hipblasLtCreate(&mut zluda_blaslt_handle.handle) }?;
*handle = Handle::wrap(zluda_blaslt_handle);
Ok(())
}
pub(crate) fn destroy(handle: cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
pub(crate) fn destroy(handle: cublasLtHandle_t) -> cublasStatus_t {
zluda_common::drop_checked::<Handle>(handle)
}
/// Packs a hipBLASLt matmul algo into the opaque cuBLASLt algo representation.
///
/// The 16 opaque bytes of the HIP algo are stored in `data[0]` and `data[1]`
/// (native endianness) and the workspace bound is stashed in `data[2]`; the
/// `FromCuda` impl for `*const cublasLtMatmulAlgo_t` performs the inverse
/// transformation.
fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
    let low = u64::from_ne_bytes(hip.data[..8].try_into().expect("hip algo data is 16 bytes"));
    let high = u64::from_ne_bytes(hip.data[8..16].try_into().expect("hip algo data is 16 bytes"));
    let mut cuda = cublasLtMatmulAlgo_t { data: [0; 8] };
    cuda.data[0] = low;
    cuda.data[1] = high;
    cuda.data[2] = hip.max_workspace_bytes as u64;
    cuda
}
/// Implements `cublasLtMatmul` by forwarding directly to `hipblasLtMatmul`.
///
/// The handle is unwrapped from its ZLUDA wrapper; descriptors, scalars,
/// device pointers, workspace and stream are passed through unchanged, and
/// the algo is handed over by reference as the hipBLASLt signature expects.
pub(crate) fn matmul(
    light_handle: &Handle,
    compute_desc: hipblasLtMatmulDesc_t,
    alpha: *const ::core::ffi::c_void,
    a: *const ::core::ffi::c_void,
    a_desc: hipblasLtMatrixLayout_t,
    b: *const ::core::ffi::c_void,
    b_desc: hipblasLtMatrixLayout_t,
    beta: *const ::core::ffi::c_void,
    c: *const ::core::ffi::c_void,
    c_desc: hipblasLtMatrixLayout_t,
    d: *mut ::core::ffi::c_void,
    d_desc: hipblasLtMatrixLayout_t,
    algo: hipblasLtMatmulAlgo_t,
    workspace: *mut ::core::ffi::c_void,
    workspace_size_in_bytes: usize,
    stream: hipStream_t,
) -> cublasStatus_t {
    let status = unsafe {
        hipblasLtMatmul(
            light_handle.handle,
            compute_desc,
            alpha,
            a,
            a_desc,
            b,
            b_desc,
            beta,
            c,
            c_desc,
            d,
            d_desc,
            &algo,
            workspace,
            workspace_size_in_bytes,
            stream,
        )
    };
    Ok(status?)
}
/// Implements `cublasLtMatmulAlgoGetHeuristic` on top of
/// `hipblasLtMatmulAlgoGetHeuristic`.
///
/// hipBLASLt fills a temporary, zero-initialized vector of HIP heuristic
/// results which is then translated entry by entry into the caller-provided
/// cuBLASLt result array. Entries past `*return_algo_count` are left
/// untouched.
///
/// NOTE(review): `heuristic_results_array` is a reference to a single
/// element that is used below as the base pointer of an array write; this
/// relies on the caller passing a buffer of at least `requested_algo_count`
/// entries, as the cuBLASLt API documents — confirm against callers.
pub(crate) fn matmul_algo_get_heuristic(
    light_handle: &Handle,
    operation_desc: hipblasLtMatmulDesc_t,
    a_desc: hipblasLtMatrixLayout_t,
    b_desc: hipblasLtMatrixLayout_t,
    c_desc: hipblasLtMatrixLayout_t,
    d_desc: hipblasLtMatrixLayout_t,
    preference: hipblasLtMatmulPreference_t,
    requested_algo_count: ::core::ffi::c_int,
    heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
    return_algo_count: &mut ::core::ffi::c_int,
) -> cublasStatus_t {
    // Scratch buffer for the HIP-side results, zero-initialized.
    let mut hip_algos = vec![unsafe { mem::zeroed() }; requested_algo_count as usize];
    unsafe {
        hipblasLtMatmulAlgoGetHeuristic(
            light_handle.handle,
            operation_desc,
            a_desc,
            b_desc,
            c_desc,
            d_desc,
            preference,
            requested_algo_count,
            hip_algos.as_mut_ptr(),
            return_algo_count,
        )
    }?;
    // Defensive check: never read past the scratch buffer even if hipBLASLt
    // reports more results than were requested.
    if *return_algo_count as usize > hip_algos.len() {
        return cublasStatus_t::ERROR_INTERNAL_ERROR;
    }
    for (idx, hip_algo) in hip_algos
        .into_iter()
        .take(*return_algo_count as usize)
        .enumerate()
    {
        // Coerce the reference to a raw pointer so we can index into the
        // caller's array.
        let heuristic_results_array: *mut cublasLtMatmulHeuristicResult_t = heuristic_results_array;
        let result = unsafe { &mut *heuristic_results_array.add(idx) };
        // Re-encode the opaque HIP algo into the cuBLASLt algo layout.
        result.algo = cuda_algo_from_hip(hip_algo.algo);
        result.workspaceSize = hip_algo.workspaceSize;
        // Translate the per-entry status, mapping any HIP error into the
        // corresponding cuBLAS error.
        result.state = hip_algo.state.map_err(|e| cublasError_t::from(e));
        result.wavesCount = hip_algo.wavesCount;
    }
    Ok(())
}
/// Implements `cublasLtMatmulDescCreate` by forwarding to hipBLASLt; the
/// compute and scale types arrive already converted to their HIP
/// representations (see the `FromCuda` conversions).
pub(crate) fn matmul_desc_create(
    matmul_desc: &mut hipblasLtMatmulDesc_t,
    compute_type: hipblasComputeType_t,
    scale_type: hipDataType,
) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?)
}
/// Implements `cublasLtMatmulDescDestroy` by forwarding to hipBLASLt.
pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?)
}
pub(crate) fn matmul_desc_set_attribute(
matmul_desc: hipblasLtMatmulDesc_t,
attr: cublasLtMatmulDescAttributes_t,
buf: *const ::core::ffi::c_void,
size_in_bytes: usize,
) -> cublasStatus_t {
if attr == cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_SCALE_TYPE {
if size_in_bytes != 4 {
return cublasStatus_t::ERROR_INVALID_VALUE;
}
let scale_type = cudaDataType_t(unsafe { *buf.cast() });
if scale_type != cudaDataType_t::CUDA_R_32F {
return cublasStatus_t::ERROR_NOT_SUPPORTED;
}
return Ok(());
}
let hip_attr = FromCuda::<_, cublasError_t>::from_cuda(&attr)?;
match hip_attr {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA => {
convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
matmul_desc,
buf,
hip_attr,
)?
}
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB => {
convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
matmul_desc,
buf,
hip_attr,
)?
}
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE => {
convert_and_set_attribute::<cublasLtEpilogue_t, hipblasLtEpilogue_t>(
matmul_desc,
buf,
hip_attr,
)?
}
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
}
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
}
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
}
Ok(())
}
/// Reads a `CudaType` attribute value from `buf`, converts it into its HIP
/// counterpart and forwards it to `hipblasLtMatmulDescSetAttribute`.
///
/// A null `buf` is rejected with `INVALID_VALUE`. The converted value is
/// passed to hipBLASLt with `size_of::<HipType>()`, independent of whatever
/// size the original caller supplied.
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
    matmul_desc: hipblasLtMatmulDesc_t,
    buf: *const std::ffi::c_void,
    hip_attr: hipblasLtMatmulDescAttributes_t,
) -> Result<(), cublasError_t>
where
    HipType: FromCuda<'a, CudaType, cuda_types::cublas::cublasError_t>,
{
    let Some(cuda_value) = (unsafe { buf.cast::<CudaType>().as_ref() }) else {
        return Err(cublasError_t::INVALID_VALUE);
    };
    let hip_value: HipType = FromCuda::<_, cublasError_t>::from_cuda(cuda_value)?;
    unsafe {
        hipblasLtMatmulDescSetAttribute(
            matmul_desc,
            hip_attr,
            (&hip_value as *const HipType).cast(),
            mem::size_of::<HipType>(),
        )
    }?;
    Ok(())
}
/// Implements `cublasLtMatmulPreferenceCreate` by forwarding to hipBLASLt.
pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatmulPreferenceCreate(pref) }?)
}
/// Implements `cublasLtMatmulPreferenceDestroy` by forwarding to hipBLASLt.
pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?)
}
/// Implements `cublasLtMatmulPreferenceSetAttribute`; the attribute arrives
/// already converted to its HIP equivalent and the payload is forwarded
/// verbatim.
pub(crate) fn matmul_preference_set_attribute(
    pref: hipblasLtMatmulPreference_t,
    attr: hipblasLtMatmulPreferenceAttributes_t,
    buf: *const ::core::ffi::c_void,
    size_in_bytes: usize,
) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?)
}
/// Implements `cublasLtMatrixLayoutCreate` by forwarding to hipBLASLt; the
/// element type arrives already converted to a `hipDataType`.
pub(crate) fn matrix_layout_create(
    mat_layout: &mut hipblasLtMatrixLayout_t,
    type_: hipDataType,
    rows: u64,
    cols: u64,
    ld: i64,
) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?)
}
/// Implements `cublasLtMatrixLayoutDestroy` by forwarding to hipBLASLt.
pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?)
}
/// Implements `cublasLtMatrixLayoutSetAttribute`; the attribute arrives
/// already converted to its HIP equivalent and the payload is forwarded
/// verbatim.
pub(crate) fn matrix_layout_set_attribute(
    mat_layout: hipblasLtMatrixLayout_t,
    attr: hipblasLtMatrixLayoutAttribute_t,
    buf: *const ::core::ffi::c_void,
    size_in_bytes: usize,
) -> cublasStatus_t {
    Ok(unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?)
}

View file

@ -42,7 +42,22 @@ macro_rules! implemented_unmapped {
cuda_macros::cublaslt_function_declarations!(
unimplemented,
implemented <= [cublasLtCreate, cublasLtDestroy,],
implemented
<= [
cublasLtCreate,
cublasLtDestroy,
cublasLtMatmul,
cublasLtMatmulAlgoGetHeuristic,
cublasLtMatmulDescCreate,
cublasLtMatmulDescDestroy,
cublasLtMatmulDescSetAttribute,
cublasLtMatmulPreferenceCreate,
cublasLtMatmulPreferenceDestroy,
cublasLtMatmulPreferenceSetAttribute,
cublasLtMatrixLayoutCreate,
cublasLtMatrixLayoutDestroy,
cublasLtMatrixLayoutSetAttribute,
],
implemented_unmapped
<= [
cublasLtDisableCpuInstructionsSetMask,

View file

@ -8,4 +8,5 @@ edition = "2021"
cuda_types = { path = "../cuda_types" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
rocblas-sys = { path = "../ext/rocblas-sys" }
hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
dark_api = { path = "../dark_api" }

View file

@ -1,12 +1,13 @@
use cuda_types::{
cublas::*,
cublaslt::cublasLtHandle_t,
cublaslt::*,
cuda::*,
dark_api::{FatbinHeader, FatbincWrapper},
nvml::*,
};
use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule};
use hip_runtime_sys::*;
use hipblaslt_sys::*;
use rocblas_sys::*;
use std::{
ffi::{c_void, CStr},
@ -166,6 +167,10 @@ from_cuda_nop!(
nvmlFieldValue_t,
nvmlGpuFabricInfo_t,
cublasLtHandle_t,
cublasLtMatmulDesc_t,
cublasLtMatmulPreference_t,
cublasLtMatrixLayout_t,
cublasLtMatmulDescAttributes_t,
CUmemAllocationGranularity_flags,
CUmemAllocationProp,
CUresult
@ -185,7 +190,10 @@ from_cuda_transmute!(
CUstreamCaptureMode => hipStreamCaptureMode,
CUgraphNode => hipGraphNode_t,
CUgraphExec => hipGraphExec_t,
CUkernel => hipFunction_t
CUkernel => hipFunction_t,
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
);
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
@ -311,6 +319,249 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cuda_types::cublas::cublasGemmAlgo_t, E>
}
}
// These have the same values, so it might be okay to use from_cuda_transmute
impl<'a, E: CudaErrorType> FromCuda<'a, cudaDataType, E> for hipDataType {
fn from_cuda(t: &'a cudaDataType) -> Result<Self, E> {
Ok(match *t {
cudaDataType::CUDA_R_16F => hipDataType::HIP_R_16F,
cudaDataType::CUDA_C_16F => hipDataType::HIP_C_16F,
cudaDataType::CUDA_R_16BF => hipDataType::HIP_R_16BF,
cudaDataType::CUDA_C_16BF => hipDataType::HIP_C_16BF,
cudaDataType::CUDA_R_32F => hipDataType::HIP_R_32F,
cudaDataType::CUDA_C_32F => hipDataType::HIP_C_32F,
cudaDataType::CUDA_R_64F => hipDataType::HIP_R_64F,
cudaDataType::CUDA_C_64F => hipDataType::HIP_C_64F,
cudaDataType::CUDA_R_8I => hipDataType::HIP_R_8I,
cudaDataType::CUDA_C_8I => hipDataType::HIP_C_8I,
cudaDataType::CUDA_R_8U => hipDataType::HIP_R_8U,
cudaDataType::CUDA_C_8U => hipDataType::HIP_C_8U,
cudaDataType::CUDA_R_32I => hipDataType::HIP_R_32I,
cudaDataType::CUDA_C_32I => hipDataType::HIP_C_32I,
cudaDataType::CUDA_R_8F_E4M3 => hipDataType::HIP_R_8F_E4M3,
cudaDataType::CUDA_R_8F_E5M2 => hipDataType::HIP_R_8F_E5M2,
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasComputeType_t, E> for hipblasComputeType_t {
fn from_cuda(t: &'a cublasComputeType_t) -> Result<Self, E> {
Ok(match *t {
cublasComputeType_t::CUBLAS_COMPUTE_16F => hipblasComputeType_t::HIPBLAS_COMPUTE_16F,
cublasComputeType_t::CUBLAS_COMPUTE_16F_PEDANTIC => {
hipblasComputeType_t::HIPBLAS_COMPUTE_16F_PEDANTIC
}
cublasComputeType_t::CUBLAS_COMPUTE_32F => hipblasComputeType_t::HIPBLAS_COMPUTE_32F,
cublasComputeType_t::CUBLAS_COMPUTE_32F_PEDANTIC => {
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_PEDANTIC
}
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F => {
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16F
}
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF => {
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16BF
}
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 => {
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_TF32
}
cublasComputeType_t::CUBLAS_COMPUTE_64F => hipblasComputeType_t::HIPBLAS_COMPUTE_64F,
cublasComputeType_t::CUBLAS_COMPUTE_64F_PEDANTIC => {
hipblasComputeType_t::HIPBLAS_COMPUTE_64F_PEDANTIC
}
cublasComputeType_t::CUBLAS_COMPUTE_32I => hipblasComputeType_t::HIPBLAS_COMPUTE_32I,
cublasComputeType_t::CUBLAS_COMPUTE_32I_PEDANTIC => {
hipblasComputeType_t::HIPBLAS_COMPUTE_32I_PEDANTIC
}
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulDescAttributes_t, E>
for hipblasLtMatmulDescAttributes_t
{
fn from_cuda(t: &'a cuda_types::cublaslt::cublasLtMatmulDescAttributes_t) -> Result<Self, E> {
Ok(match *t {
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_POINTER_MODE => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_POINTER_MODE
}
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_AMAX_D_POINTER => {
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER
}
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatrixLayoutAttribute_t, E>
for hipblasLtMatrixLayoutAttribute_t
{
fn from_cuda(t: &'a cublasLtMatrixLayoutAttribute_t) -> Result<Self, E> {
Ok(match *t {
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_TYPE => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_TYPE
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ORDER => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ROWS => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ROWS
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_COLS => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_COLS
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_LD => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_LD
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT
}
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET => {
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
}
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulPreferenceAttributes_t, E>
for hipblasLtMatmulPreferenceAttributes_t
{
fn from_cuda(t: &'a cublasLtMatmulPreferenceAttributes_t) -> Result<Self, E> {
Ok(match *t {
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_SEARCH_MODE => {
hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_SEARCH_MODE
}
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES => {
hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
}
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, *const cublasLtMatmulAlgo_t, E> for hipblasLtMatmulAlgo_t {
    /// Reconstructs a hipBLASLt algo from the opaque cuBLASLt representation.
    ///
    /// Assumes the algo was produced by `cuda_algo_from_hip`: `data[0]` and
    /// `data[1]` hold the 16 opaque HIP bytes and `data[2]` holds the
    /// workspace bound; the remaining words are discarded. A null pointer is
    /// rejected with `INVALID_VALUE`.
    fn from_cuda(t: &'a *const cublasLtMatmulAlgo_t) -> Result<Self, E> {
        let Some(cuda_algo) = (unsafe { t.as_ref() }) else {
            return Err(E::INVALID_VALUE);
        };
        let mut data = [0u8; 16];
        data[..8].copy_from_slice(&cuda_algo.data[0].to_ne_bytes());
        data[8..].copy_from_slice(&cuda_algo.data[1].to_ne_bytes());
        Ok(hipblasLtMatmulAlgo_t {
            data,
            max_workspace_bytes: cuda_algo.data[2] as usize,
        })
    }
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for hipblasOperation_t {
fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> {
Ok(match *t {
cublasOperation_t::CUBLAS_OP_N => hipblasOperation_t::HIPBLAS_OP_N,
cublasOperation_t::CUBLAS_OP_T => hipblasOperation_t::HIPBLAS_OP_T,
cublasOperation_t::CUBLAS_OP_C => hipblasOperation_t::HIPBLAS_OP_C,
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtEpilogue_t, E> for hipblasLtEpilogue_t {
fn from_cuda(t: &'a cublasLtEpilogue_t) -> Result<Self, E> {
Ok(match *t {
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DEFAULT
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BIAS
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU_BIAS
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_BIAS
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX_BIAS
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU_BGRAD
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADA
}
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB => {
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADB
}
_ => return Err(E::NOT_SUPPORTED),
})
}
}
impl<'a, E: CudaErrorType> FromCuda<'a, *mut cublasLtMatmulHeuristicResult_t, E>
    for &'a mut cublasLtMatmulHeuristicResult_t
{
    /// Converts a caller-supplied result pointer into a mutable reference,
    /// rejecting null pointers with `INVALID_VALUE`.
    fn from_cuda(x: &'a *mut cublasLtMatmulHeuristicResult_t) -> Result<Self, E> {
        unsafe { x.as_mut() }.ok_or(E::INVALID_VALUE)
    }
}
/// Represents an object that can be sent across the API boundary.
///
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a