mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-21 16:59:04 +00:00
Add support for cuBLASLt functions used by llm.c (#512)
This commit is contained in:
parent
5185138596
commit
571dad0972
6 changed files with 522 additions and 9 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -3748,6 +3748,8 @@ version = "0.0.0"
|
|||
dependencies = [
|
||||
"cuda_macros",
|
||||
"cuda_types",
|
||||
"hip_runtime-sys",
|
||||
"hipblaslt-sys",
|
||||
"zluda_common",
|
||||
]
|
||||
|
||||
|
@ -3769,6 +3771,7 @@ dependencies = [
|
|||
"cuda_types",
|
||||
"dark_api",
|
||||
"hip_runtime-sys",
|
||||
"hipblaslt-sys",
|
||||
"rocblas-sys",
|
||||
]
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@ name = "cublaslt"
|
|||
[dependencies]
|
||||
cuda_macros = { path = "../cuda_macros" }
|
||||
cuda_types = { path = "../cuda_types" }
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
|
||||
zluda_common = { path = "../zluda_common" }
|
||||
|
||||
[package.metadata.zluda]
|
||||
|
|
|
@ -1,8 +1,19 @@
|
|||
use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t};
|
||||
use zluda_common::{from_cuda_object, ZludaObject};
|
||||
use cuda_types::{cublas::*, cublaslt::*};
|
||||
use hip_runtime_sys::hipStream_t;
|
||||
use hipblaslt_sys::*;
|
||||
use std::mem;
|
||||
use zluda_common::{from_cuda_object, FromCuda, ZludaObject};
|
||||
|
||||
pub struct Handle {
|
||||
_handle: usize,
|
||||
handle: hipblasLtHandle_t,
|
||||
}
|
||||
|
||||
impl Handle {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
handle: unsafe { mem::zeroed() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZludaObject for Handle {
|
||||
|
@ -48,11 +59,241 @@ pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> :
|
|||
todo!()
|
||||
}
|
||||
|
||||
pub(crate) fn create(handle: &mut cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
|
||||
*handle = Handle { _handle: 0 }.wrap();
|
||||
pub(crate) fn create(handle: &mut cublasLtHandle_t) -> cublasStatus_t {
|
||||
let mut zluda_blaslt_handle = Handle::new();
|
||||
unsafe { hipblasLtCreate(&mut zluda_blaslt_handle.handle) }?;
|
||||
*handle = Handle::wrap(zluda_blaslt_handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn destroy(handle: cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
|
||||
pub(crate) fn destroy(handle: cublasLtHandle_t) -> cublasStatus_t {
|
||||
zluda_common::drop_checked::<Handle>(handle)
|
||||
}
|
||||
|
||||
fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
|
||||
let mut cuda = cublasLtMatmulAlgo_t { data: [0; 8] };
|
||||
let (chunks, _) = hip.data.as_chunks::<8>();
|
||||
cuda.data[0] = u64::from_ne_bytes(chunks[0]);
|
||||
cuda.data[1] = u64::from_ne_bytes(chunks[1]);
|
||||
cuda.data[2] = hip.max_workspace_bytes as u64;
|
||||
cuda
|
||||
}
|
||||
|
||||
pub(crate) fn matmul(
|
||||
light_handle: &Handle,
|
||||
compute_desc: hipblasLtMatmulDesc_t,
|
||||
alpha: *const ::core::ffi::c_void,
|
||||
a: *const ::core::ffi::c_void,
|
||||
a_desc: hipblasLtMatrixLayout_t,
|
||||
b: *const ::core::ffi::c_void,
|
||||
b_desc: hipblasLtMatrixLayout_t,
|
||||
beta: *const ::core::ffi::c_void,
|
||||
c: *const ::core::ffi::c_void,
|
||||
c_desc: hipblasLtMatrixLayout_t,
|
||||
d: *mut ::core::ffi::c_void,
|
||||
d_desc: hipblasLtMatrixLayout_t,
|
||||
algo: hipblasLtMatmulAlgo_t,
|
||||
workspace: *mut ::core::ffi::c_void,
|
||||
workspace_size_in_bytes: usize,
|
||||
stream: hipStream_t,
|
||||
) -> cublasStatus_t {
|
||||
unsafe {
|
||||
hipblasLtMatmul(
|
||||
light_handle.handle,
|
||||
compute_desc,
|
||||
alpha,
|
||||
a,
|
||||
a_desc,
|
||||
b,
|
||||
b_desc,
|
||||
beta,
|
||||
c,
|
||||
c_desc,
|
||||
d,
|
||||
d_desc,
|
||||
&algo,
|
||||
workspace,
|
||||
workspace_size_in_bytes,
|
||||
stream,
|
||||
)
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_algo_get_heuristic(
|
||||
light_handle: &Handle,
|
||||
operation_desc: hipblasLtMatmulDesc_t,
|
||||
a_desc: hipblasLtMatrixLayout_t,
|
||||
b_desc: hipblasLtMatrixLayout_t,
|
||||
c_desc: hipblasLtMatrixLayout_t,
|
||||
d_desc: hipblasLtMatrixLayout_t,
|
||||
preference: hipblasLtMatmulPreference_t,
|
||||
requested_algo_count: ::core::ffi::c_int,
|
||||
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
|
||||
return_algo_count: &mut ::core::ffi::c_int,
|
||||
) -> cublasStatus_t {
|
||||
let mut hip_algos = vec![unsafe { mem::zeroed() }; requested_algo_count as usize];
|
||||
unsafe {
|
||||
hipblasLtMatmulAlgoGetHeuristic(
|
||||
light_handle.handle,
|
||||
operation_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
d_desc,
|
||||
preference,
|
||||
requested_algo_count,
|
||||
hip_algos.as_mut_ptr(),
|
||||
return_algo_count,
|
||||
)
|
||||
}?;
|
||||
if *return_algo_count as usize > hip_algos.len() {
|
||||
return cublasStatus_t::ERROR_INTERNAL_ERROR;
|
||||
}
|
||||
for (idx, hip_algo) in hip_algos
|
||||
.into_iter()
|
||||
.take(*return_algo_count as usize)
|
||||
.enumerate()
|
||||
{
|
||||
let heuristic_results_array: *mut cublasLtMatmulHeuristicResult_t = heuristic_results_array;
|
||||
let result = unsafe { &mut *heuristic_results_array.add(idx) };
|
||||
result.algo = cuda_algo_from_hip(hip_algo.algo);
|
||||
result.workspaceSize = hip_algo.workspaceSize;
|
||||
result.state = hip_algo.state.map_err(|e| cublasError_t::from(e));
|
||||
result.wavesCount = hip_algo.wavesCount;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_desc_create(
|
||||
matmul_desc: &mut hipblasLtMatmulDesc_t,
|
||||
compute_type: hipblasComputeType_t,
|
||||
scale_type: hipDataType,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_desc_set_attribute(
|
||||
matmul_desc: hipblasLtMatmulDesc_t,
|
||||
attr: cublasLtMatmulDescAttributes_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
) -> cublasStatus_t {
|
||||
if attr == cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_SCALE_TYPE {
|
||||
if size_in_bytes != 4 {
|
||||
return cublasStatus_t::ERROR_INVALID_VALUE;
|
||||
}
|
||||
let scale_type = cudaDataType_t(unsafe { *buf.cast() });
|
||||
if scale_type != cudaDataType_t::CUDA_R_32F {
|
||||
return cublasStatus_t::ERROR_NOT_SUPPORTED;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
let hip_attr = FromCuda::<_, cublasError_t>::from_cuda(&attr)?;
|
||||
match hip_attr {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA => {
|
||||
convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
|
||||
matmul_desc,
|
||||
buf,
|
||||
hip_attr,
|
||||
)?
|
||||
}
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB => {
|
||||
convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
|
||||
matmul_desc,
|
||||
buf,
|
||||
hip_attr,
|
||||
)?
|
||||
}
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE => {
|
||||
convert_and_set_attribute::<cublasLtEpilogue_t, hipblasLtEpilogue_t>(
|
||||
matmul_desc,
|
||||
buf,
|
||||
hip_attr,
|
||||
)?
|
||||
}
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
|
||||
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
|
||||
}
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
|
||||
unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
|
||||
}
|
||||
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
|
||||
matmul_desc: hipblasLtMatmulDesc_t,
|
||||
buf: *const std::ffi::c_void,
|
||||
hip_attr: hipblasLtMatmulDescAttributes_t,
|
||||
) -> Result<(), cublasError_t>
|
||||
where
|
||||
HipType: FromCuda<'a, CudaType, cuda_types::cublas::cublasError_t>,
|
||||
{
|
||||
let cublas_operation: &CudaType =
|
||||
unsafe { buf.cast::<CudaType>().as_ref() }.ok_or(cublasError_t::INVALID_VALUE)?;
|
||||
let hip_operation: HipType = FromCuda::<_, cublasError_t>::from_cuda(cublas_operation)?;
|
||||
let hip_buf: *const HipType = &hip_operation;
|
||||
unsafe {
|
||||
hipblasLtMatmulDescSetAttribute(
|
||||
matmul_desc,
|
||||
hip_attr,
|
||||
hip_buf.cast(),
|
||||
mem::size_of::<HipType>(),
|
||||
)
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_set_attribute(
|
||||
pref: hipblasLtMatmulPreference_t,
|
||||
attr: hipblasLtMatmulPreferenceAttributes_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_create(
|
||||
mat_layout: &mut hipblasLtMatrixLayout_t,
|
||||
type_: hipDataType,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
ld: i64,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_set_attribute(
|
||||
mat_layout: hipblasLtMatrixLayout_t,
|
||||
attr: hipblasLtMatrixLayoutAttribute_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -42,7 +42,22 @@ macro_rules! implemented_unmapped {
|
|||
|
||||
cuda_macros::cublaslt_function_declarations!(
|
||||
unimplemented,
|
||||
implemented <= [cublasLtCreate, cublasLtDestroy,],
|
||||
implemented
|
||||
<= [
|
||||
cublasLtCreate,
|
||||
cublasLtDestroy,
|
||||
cublasLtMatmul,
|
||||
cublasLtMatmulAlgoGetHeuristic,
|
||||
cublasLtMatmulDescCreate,
|
||||
cublasLtMatmulDescDestroy,
|
||||
cublasLtMatmulDescSetAttribute,
|
||||
cublasLtMatmulPreferenceCreate,
|
||||
cublasLtMatmulPreferenceDestroy,
|
||||
cublasLtMatmulPreferenceSetAttribute,
|
||||
cublasLtMatrixLayoutCreate,
|
||||
cublasLtMatrixLayoutDestroy,
|
||||
cublasLtMatrixLayoutSetAttribute,
|
||||
],
|
||||
implemented_unmapped
|
||||
<= [
|
||||
cublasLtDisableCpuInstructionsSetMask,
|
||||
|
|
|
@ -8,4 +8,5 @@ edition = "2021"
|
|||
cuda_types = { path = "../cuda_types" }
|
||||
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
|
||||
rocblas-sys = { path = "../ext/rocblas-sys" }
|
||||
hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
|
||||
dark_api = { path = "../dark_api" }
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use cuda_types::{
|
||||
cublas::*,
|
||||
cublaslt::cublasLtHandle_t,
|
||||
cublaslt::*,
|
||||
cuda::*,
|
||||
dark_api::{FatbinHeader, FatbincWrapper},
|
||||
nvml::*,
|
||||
};
|
||||
use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule};
|
||||
use hip_runtime_sys::*;
|
||||
use hipblaslt_sys::*;
|
||||
use rocblas_sys::*;
|
||||
use std::{
|
||||
ffi::{c_void, CStr},
|
||||
|
@ -166,6 +167,10 @@ from_cuda_nop!(
|
|||
nvmlFieldValue_t,
|
||||
nvmlGpuFabricInfo_t,
|
||||
cublasLtHandle_t,
|
||||
cublasLtMatmulDesc_t,
|
||||
cublasLtMatmulPreference_t,
|
||||
cublasLtMatrixLayout_t,
|
||||
cublasLtMatmulDescAttributes_t,
|
||||
CUmemAllocationGranularity_flags,
|
||||
CUmemAllocationProp,
|
||||
CUresult
|
||||
|
@ -185,7 +190,10 @@ from_cuda_transmute!(
|
|||
CUstreamCaptureMode => hipStreamCaptureMode,
|
||||
CUgraphNode => hipGraphNode_t,
|
||||
CUgraphExec => hipGraphExec_t,
|
||||
CUkernel => hipFunction_t
|
||||
CUkernel => hipFunction_t,
|
||||
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
|
||||
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
|
||||
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
|
||||
);
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
|
||||
|
@ -311,6 +319,249 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cuda_types::cublas::cublasGemmAlgo_t, E>
|
|||
}
|
||||
}
|
||||
|
||||
// These have the same values, so it might be okay to use from_cuda_transmute
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cudaDataType, E> for hipDataType {
|
||||
fn from_cuda(t: &'a cudaDataType) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cudaDataType::CUDA_R_16F => hipDataType::HIP_R_16F,
|
||||
cudaDataType::CUDA_C_16F => hipDataType::HIP_C_16F,
|
||||
cudaDataType::CUDA_R_16BF => hipDataType::HIP_R_16BF,
|
||||
cudaDataType::CUDA_C_16BF => hipDataType::HIP_C_16BF,
|
||||
cudaDataType::CUDA_R_32F => hipDataType::HIP_R_32F,
|
||||
cudaDataType::CUDA_C_32F => hipDataType::HIP_C_32F,
|
||||
cudaDataType::CUDA_R_64F => hipDataType::HIP_R_64F,
|
||||
cudaDataType::CUDA_C_64F => hipDataType::HIP_C_64F,
|
||||
cudaDataType::CUDA_R_8I => hipDataType::HIP_R_8I,
|
||||
cudaDataType::CUDA_C_8I => hipDataType::HIP_C_8I,
|
||||
cudaDataType::CUDA_R_8U => hipDataType::HIP_R_8U,
|
||||
cudaDataType::CUDA_C_8U => hipDataType::HIP_C_8U,
|
||||
cudaDataType::CUDA_R_32I => hipDataType::HIP_R_32I,
|
||||
cudaDataType::CUDA_C_32I => hipDataType::HIP_C_32I,
|
||||
cudaDataType::CUDA_R_8F_E4M3 => hipDataType::HIP_R_8F_E4M3,
|
||||
cudaDataType::CUDA_R_8F_E5M2 => hipDataType::HIP_R_8F_E5M2,
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasComputeType_t, E> for hipblasComputeType_t {
|
||||
fn from_cuda(t: &'a cublasComputeType_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_16F => hipblasComputeType_t::HIPBLAS_COMPUTE_16F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_16F_PEDANTIC => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_16F_PEDANTIC
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F => hipblasComputeType_t::HIPBLAS_COMPUTE_32F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_PEDANTIC => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_PEDANTIC
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16F
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16BF
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_TF32
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_64F => hipblasComputeType_t::HIPBLAS_COMPUTE_64F,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_64F_PEDANTIC => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_64F_PEDANTIC
|
||||
}
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32I => hipblasComputeType_t::HIPBLAS_COMPUTE_32I,
|
||||
cublasComputeType_t::CUBLAS_COMPUTE_32I_PEDANTIC => {
|
||||
hipblasComputeType_t::HIPBLAS_COMPUTE_32I_PEDANTIC
|
||||
}
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulDescAttributes_t, E>
|
||||
for hipblasLtMatmulDescAttributes_t
|
||||
{
|
||||
fn from_cuda(t: &'a cuda_types::cublaslt::cublasLtMatmulDescAttributes_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_POINTER_MODE => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_POINTER_MODE
|
||||
}
|
||||
cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_AMAX_D_POINTER => {
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER
|
||||
}
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatrixLayoutAttribute_t, E>
|
||||
for hipblasLtMatrixLayoutAttribute_t
|
||||
{
|
||||
fn from_cuda(t: &'a cublasLtMatrixLayoutAttribute_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_TYPE => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_TYPE
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ORDER => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ROWS => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ROWS
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_COLS => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_COLS
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_LD => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_LD
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT
|
||||
}
|
||||
cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET => {
|
||||
hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
|
||||
}
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulPreferenceAttributes_t, E>
|
||||
for hipblasLtMatmulPreferenceAttributes_t
|
||||
{
|
||||
fn from_cuda(t: &'a cublasLtMatmulPreferenceAttributes_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_SEARCH_MODE => {
|
||||
hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_SEARCH_MODE
|
||||
}
|
||||
cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES => {
|
||||
hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
|
||||
}
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, *const cublasLtMatmulAlgo_t, E> for hipblasLtMatmulAlgo_t {
|
||||
fn from_cuda(t: &'a *const cublasLtMatmulAlgo_t) -> Result<Self, E> {
|
||||
// We assume the algo came from hip_algo_to_cuda so we can discard the last six bytes
|
||||
let cuda_algo = match unsafe { t.as_ref() } {
|
||||
Some(algo) => algo,
|
||||
None => return Err(E::INVALID_VALUE),
|
||||
};
|
||||
let mut hip = hipblasLtMatmulAlgo_t {
|
||||
data: [0; 16],
|
||||
max_workspace_bytes: cuda_algo.data[2] as usize,
|
||||
};
|
||||
hip.data[..8].copy_from_slice(&cuda_algo.data[0].to_ne_bytes());
|
||||
hip.data[8..].copy_from_slice(&cuda_algo.data[1].to_ne_bytes());
|
||||
Ok(hip)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for hipblasOperation_t {
|
||||
fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasOperation_t::CUBLAS_OP_N => hipblasOperation_t::HIPBLAS_OP_N,
|
||||
cublasOperation_t::CUBLAS_OP_T => hipblasOperation_t::HIPBLAS_OP_T,
|
||||
cublasOperation_t::CUBLAS_OP_C => hipblasOperation_t::HIPBLAS_OP_C,
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtEpilogue_t, E> for hipblasLtEpilogue_t {
|
||||
fn from_cuda(t: &'a cublasLtEpilogue_t) -> Result<Self, E> {
|
||||
Ok(match *t {
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DEFAULT
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BIAS
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU_BIAS
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_BIAS
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX_BIAS
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU_BGRAD
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADA
|
||||
}
|
||||
cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB => {
|
||||
hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADB
|
||||
}
|
||||
_ => return Err(E::NOT_SUPPORTED),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, *mut cublasLtMatmulHeuristicResult_t, E>
|
||||
for &'a mut cublasLtMatmulHeuristicResult_t
|
||||
{
|
||||
fn from_cuda(x: &'a *mut cublasLtMatmulHeuristicResult_t) -> Result<Self, E> {
|
||||
match unsafe { x.as_mut() } {
|
||||
Some(x) => Ok(x),
|
||||
None => Err(E::INVALID_VALUE),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an object that can be sent across the API boundary.
|
||||
///
|
||||
/// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue