diff --git a/Cargo.lock b/Cargo.lock
index 32cf398..baddd3a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3748,6 +3748,8 @@ version = "0.0.0"
 dependencies = [
  "cuda_macros",
  "cuda_types",
+ "hip_runtime-sys",
+ "hipblaslt-sys",
  "zluda_common",
 ]
 
@@ -3769,6 +3771,7 @@ dependencies = [
  "cuda_types",
  "dark_api",
  "hip_runtime-sys",
+ "hipblaslt-sys",
  "rocblas-sys",
 ]
 
diff --git a/zluda_blaslt/Cargo.toml b/zluda_blaslt/Cargo.toml
index b7d3ebe..6253648 100644
--- a/zluda_blaslt/Cargo.toml
+++ b/zluda_blaslt/Cargo.toml
@@ -10,6 +10,8 @@ name = "cublaslt"
 [dependencies]
 cuda_macros = { path = "../cuda_macros" }
 cuda_types = { path = "../cuda_types" }
+hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
+hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
 zluda_common = { path = "../zluda_common" }
 
 [package.metadata.zluda]
diff --git a/zluda_blaslt/src/impl.rs b/zluda_blaslt/src/impl.rs
index 9f1c658..adf18a2 100644
--- a/zluda_blaslt/src/impl.rs
+++ b/zluda_blaslt/src/impl.rs
@@ -1,8 +1,19 @@
-use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t};
-use zluda_common::{from_cuda_object, ZludaObject};
+use cuda_types::{cublas::*, cublaslt::*};
+use hip_runtime_sys::hipStream_t;
+use hipblaslt_sys::*;
+use std::mem;
+use zluda_common::{from_cuda_object, FromCuda, ZludaObject};
 
 pub struct Handle {
-    _handle: usize,
+    handle: hipblasLtHandle_t,
+}
+
+impl Handle {
+    fn new() -> Self {
+        Self {
+            handle: unsafe { mem::zeroed() },
+        }
+    }
 }
 
 impl ZludaObject for Handle {
@@ -48,11 +59,241 @@ pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> :
     todo!()
 }
 
-pub(crate) fn create(handle: &mut cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
-    *handle = Handle { _handle: 0 }.wrap();
+pub(crate) fn create(handle: &mut cublasLtHandle_t) -> cublasStatus_t {
+    let mut zluda_blaslt_handle = Handle::new();
+    unsafe { hipblasLtCreate(&mut zluda_blaslt_handle.handle) }?;
+    *handle = Handle::wrap(zluda_blaslt_handle);
     Ok(())
 }
 
-pub(crate) fn destroy(handle: cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
+pub(crate) fn destroy(handle: cublasLtHandle_t) -> cublasStatus_t {
     zluda_common::drop_checked::<Handle>(handle)
 }
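Reviewer note: `unsafe { hipblasLtCreate(...) }?` above only compiles because, as elsewhere in this codebase, the generated status types are `Result` aliases rather than plain C enums, and `?` converts the error type at the library boundary via `From`. A minimal sketch of that pattern with hypothetical names (the real aliases live in the generated cuda_types/hipblaslt-sys crates):

```rust
// Hypothetical stand-ins for the generated status types: a C status code
// becomes a Result alias, so `?` both checks the call and converts the
// error across the cublas/hipblas boundary.
#[derive(Debug)]
struct HipblasError(i32);
type HipblasStatus = Result<(), HipblasError>;

#[derive(Debug)]
struct CublasError(i32);
type CublasStatus = Result<(), CublasError>;

impl From<HipblasError> for CublasError {
    fn from(e: HipblasError) -> Self {
        // A real mapping would translate individual status codes.
        CublasError(e.0)
    }
}

// Stand-in for an FFI call that reports success or failure.
fn hipblas_call(ok: bool) -> HipblasStatus {
    if ok { Ok(()) } else { Err(HipblasError(1)) }
}

fn cublas_entry_point() -> CublasStatus {
    hipblas_call(true)?; // `?` maps HipblasError -> CublasError
    Ok(())
}

fn main() {
    assert!(cublas_entry_point().is_ok());
}
```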
+
+fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
+    let mut cuda = cublasLtMatmulAlgo_t { data: [0; 8] };
+    let (chunks, _) = hip.data.as_chunks::<8>();
+    cuda.data[0] = u64::from_ne_bytes(chunks[0]);
+    cuda.data[1] = u64::from_ne_bytes(chunks[1]);
+    cuda.data[2] = hip.max_workspace_bytes as u64;
+    cuda
+}
+
+pub(crate) fn matmul(
+    light_handle: &Handle,
+    compute_desc: hipblasLtMatmulDesc_t,
+    alpha: *const ::core::ffi::c_void,
+    a: *const ::core::ffi::c_void,
+    a_desc: hipblasLtMatrixLayout_t,
+    b: *const ::core::ffi::c_void,
+    b_desc: hipblasLtMatrixLayout_t,
+    beta: *const ::core::ffi::c_void,
+    c: *const ::core::ffi::c_void,
+    c_desc: hipblasLtMatrixLayout_t,
+    d: *mut ::core::ffi::c_void,
+    d_desc: hipblasLtMatrixLayout_t,
+    algo: hipblasLtMatmulAlgo_t,
+    workspace: *mut ::core::ffi::c_void,
+    workspace_size_in_bytes: usize,
+    stream: hipStream_t,
+) -> cublasStatus_t {
+    unsafe {
+        hipblasLtMatmul(
+            light_handle.handle,
+            compute_desc,
+            alpha,
+            a,
+            a_desc,
+            b,
+            b_desc,
+            beta,
+            c,
+            c_desc,
+            d,
+            d_desc,
+            &algo,
+            workspace,
+            workspace_size_in_bytes,
+            stream,
+        )
+    }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_algo_get_heuristic(
+    light_handle: &Handle,
+    operation_desc: hipblasLtMatmulDesc_t,
+    a_desc: hipblasLtMatrixLayout_t,
+    b_desc: hipblasLtMatrixLayout_t,
+    c_desc: hipblasLtMatrixLayout_t,
+    d_desc: hipblasLtMatrixLayout_t,
+    preference: hipblasLtMatmulPreference_t,
+    requested_algo_count: ::core::ffi::c_int,
+    heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
+    return_algo_count: &mut ::core::ffi::c_int,
+) -> cublasStatus_t {
+    let mut hip_algos = vec![unsafe { mem::zeroed() }; requested_algo_count as usize];
+    unsafe {
+        hipblasLtMatmulAlgoGetHeuristic(
+            light_handle.handle,
+            operation_desc,
+            a_desc,
+            b_desc,
+            c_desc,
+            d_desc,
+            preference,
+            requested_algo_count,
+            hip_algos.as_mut_ptr(),
+            return_algo_count,
+        )
+    }?;
+    if *return_algo_count as usize > hip_algos.len() {
+        return cublasStatus_t::ERROR_INTERNAL_ERROR;
+    }
+    for (idx, hip_algo) in hip_algos
+        .into_iter()
+        .take(*return_algo_count as usize)
+        .enumerate()
+    {
+        let heuristic_results_array: *mut cublasLtMatmulHeuristicResult_t = heuristic_results_array;
+        let result = unsafe { &mut *heuristic_results_array.add(idx) };
+        result.algo = cuda_algo_from_hip(hip_algo.algo);
+        result.workspaceSize = hip_algo.workspaceSize;
+        result.state = hip_algo.state.map_err(|e| cublasError_t::from(e));
+        result.wavesCount = hip_algo.wavesCount;
+    }
+    Ok(())
+}
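Reviewer note: the algo-handle translation above is worth a second look. hipBLASLt's algo is 16 opaque bytes plus a workspace bound; cuBLASLt's is 8 opaque u64 words. `cuda_algo_from_hip` packs the 16 bytes into words 0-1 and the workspace bound into word 2; words 3-7 are unused, and the inverse conversion in zluda_common below reads only words 0-2. A standalone sketch with local stand-in structs (field shapes taken from this diff, not the real FFI types):

```rust
// Local mirrors of the two algo structs, shaped as in this patch.
#[derive(Clone, Copy, Debug, PartialEq)]
struct HipAlgo {
    data: [u8; 16],
    max_workspace_bytes: usize,
}

#[derive(Clone, Copy, Debug)]
struct CudaAlgo {
    data: [u64; 8],
}

// Pack: 16 opaque bytes -> words 0-1, workspace bound -> word 2.
fn pack(hip: HipAlgo) -> CudaAlgo {
    let mut cuda = CudaAlgo { data: [0; 8] };
    cuda.data[0] = u64::from_ne_bytes(hip.data[..8].try_into().unwrap());
    cuda.data[1] = u64::from_ne_bytes(hip.data[8..].try_into().unwrap());
    cuda.data[2] = hip.max_workspace_bytes as u64;
    cuda
}

// Unpack: the inverse only touches words 0-2, so the round trip is lossless.
fn unpack(cuda: CudaAlgo) -> HipAlgo {
    let mut data = [0u8; 16];
    data[..8].copy_from_slice(&cuda.data[0].to_ne_bytes());
    data[8..].copy_from_slice(&cuda.data[1].to_ne_bytes());
    HipAlgo {
        data,
        max_workspace_bytes: cuda.data[2] as usize,
    }
}

fn main() {
    let original = HipAlgo {
        data: [7; 16],
        max_workspace_bytes: 1 << 20,
    };
    assert_eq!(unpack(pack(original)), original);
}
```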
+
+pub(crate) fn matmul_desc_create(
+    matmul_desc: &mut hipblasLtMatmulDesc_t,
+    compute_type: hipblasComputeType_t,
+    scale_type: hipDataType,
+) -> cublasStatus_t {
+    unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
+    unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_desc_set_attribute(
+    matmul_desc: hipblasLtMatmulDesc_t,
+    attr: cublasLtMatmulDescAttributes_t,
+    buf: *const ::core::ffi::c_void,
+    size_in_bytes: usize,
+) -> cublasStatus_t {
+    if attr == cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_SCALE_TYPE {
+        if size_in_bytes != 4 {
+            return cublasStatus_t::ERROR_INVALID_VALUE;
+        }
+        let scale_type = cudaDataType_t(unsafe { *buf.cast() });
+        if scale_type != cudaDataType_t::CUDA_R_32F {
+            return cublasStatus_t::ERROR_NOT_SUPPORTED;
+        }
+        return Ok(());
+    }
+    let hip_attr = FromCuda::<_, cublasError_t>::from_cuda(&attr)?;
+    match hip_attr {
+        hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA => {
+            convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
+                matmul_desc,
+                buf,
+                hip_attr,
+            )?
+        }
+        hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB => {
+            convert_and_set_attribute::<cublasOperation_t, hipblasOperation_t>(
+                matmul_desc,
+                buf,
+                hip_attr,
+            )?
+        }
+        hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE => {
+            convert_and_set_attribute::<cublasLtEpilogue_t, hipblasLtEpilogue_t>(
+                matmul_desc,
+                buf,
+                hip_attr,
+            )?
+        }
+        hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
+            convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
+        }
+        hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
+            unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
+        }
+        _ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
+    }
+    Ok(())
+}
+
+fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
+    matmul_desc: hipblasLtMatmulDesc_t,
+    buf: *const std::ffi::c_void,
+    hip_attr: hipblasLtMatmulDescAttributes_t,
+) -> Result<(), cublasError_t>
+where
+    HipType: FromCuda<'a, CudaType, cuda_types::cublas::cublasError_t>,
+{
+    let cublas_operation: &CudaType =
+        unsafe { buf.cast::<CudaType>().as_ref() }.ok_or(cublasError_t::INVALID_VALUE)?;
+    let hip_operation: HipType = FromCuda::<_, cublasError_t>::from_cuda(cublas_operation)?;
+    let hip_buf: *const HipType = &hip_operation;
+    unsafe {
+        hipblasLtMatmulDescSetAttribute(
+            matmul_desc,
+            hip_attr,
+            hip_buf.cast(),
+            mem::size_of::<HipType>(),
+        )
+    }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
+    unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
+    unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
+    Ok(())
+}
+
+pub(crate) fn matmul_preference_set_attribute(
+    pref: hipblasLtMatmulPreference_t,
+    attr: hipblasLtMatmulPreferenceAttributes_t,
+    buf: *const ::core::ffi::c_void,
+    size_in_bytes: usize,
+) -> cublasStatus_t {
+    unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
+    Ok(())
+}
+
+pub(crate) fn matrix_layout_create(
+    mat_layout: &mut hipblasLtMatrixLayout_t,
+    type_: hipDataType,
+    rows: u64,
+    cols: u64,
+    ld: i64,
+) -> cublasStatus_t {
+    unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
+    Ok(())
+}
+
+pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
+    unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
+    Ok(())
+}
+
+pub(crate) fn matrix_layout_set_attribute(
+    mat_layout: hipblasLtMatrixLayout_t,
+    attr: hipblasLtMatrixLayoutAttribute_t,
+    buf: *const ::core::ffi::c_void,
+    size_in_bytes: usize,
+) -> cublasStatus_t {
+    unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
+    Ok(())
+}
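Reviewer note: `convert_and_set_attribute` captures the recurring shape of attribute setters here: the caller hands over a type-erased `*const c_void` whose pointee type is implied by the attribute enum, so each arm reinterprets the buffer at the CUDA type, converts to the HIP equivalent, and forwards a pointer to the converted value. A self-contained sketch of that pattern with illustrative types (not ZLUDA's):

```rust
use std::ffi::c_void;

#[derive(Clone, Copy)]
enum CudaOp { N, T }

#[derive(Clone, Copy, Debug, PartialEq)]
enum HipOp { N, T }

fn convert(op: &CudaOp) -> Result<HipOp, &'static str> {
    Ok(match op {
        CudaOp::N => HipOp::N,
        CudaOp::T => HipOp::T,
    })
}

// Stand-in for the FFI setter: just records what it was handed.
fn hip_set_attribute(out: &mut Option<HipOp>, buf: *const c_void) {
    *out = Some(unsafe { *buf.cast::<HipOp>() });
}

fn set_attribute(out: &mut Option<HipOp>, buf: *const c_void) -> Result<(), &'static str> {
    // Reinterpret the erased buffer at the CUDA type, rejecting null...
    let cuda: &CudaOp = unsafe { buf.cast::<CudaOp>().as_ref() }.ok_or("null buffer")?;
    // ...convert, then hand the backend a pointer to the converted value.
    let hip = convert(cuda)?;
    hip_set_attribute(out, (&hip as *const HipOp).cast());
    Ok(())
}

fn main() {
    let op = CudaOp::T;
    let mut recorded = None;
    set_attribute(&mut recorded, (&op as *const CudaOp).cast()).unwrap();
    assert_eq!(recorded, Some(HipOp::T));
}
```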
diff --git a/zluda_blaslt/src/lib.rs b/zluda_blaslt/src/lib.rs
index 3fd30d5..cfef7e4 100644
--- a/zluda_blaslt/src/lib.rs
+++ b/zluda_blaslt/src/lib.rs
@@ -42,7 +42,22 @@ macro_rules! implemented_unmapped {
 
 cuda_macros::cublaslt_function_declarations!(
     unimplemented,
-    implemented <= [cublasLtCreate, cublasLtDestroy,],
+    implemented
+        <= [
+            cublasLtCreate,
+            cublasLtDestroy,
+            cublasLtMatmul,
+            cublasLtMatmulAlgoGetHeuristic,
+            cublasLtMatmulDescCreate,
+            cublasLtMatmulDescDestroy,
+            cublasLtMatmulDescSetAttribute,
+            cublasLtMatmulPreferenceCreate,
+            cublasLtMatmulPreferenceDestroy,
+            cublasLtMatmulPreferenceSetAttribute,
+            cublasLtMatrixLayoutCreate,
+            cublasLtMatrixLayoutDestroy,
+            cublasLtMatrixLayoutSetAttribute,
+        ],
     implemented_unmapped
         <= [
             cublasLtDisableCpuInstructionsSetMask,
diff --git a/zluda_common/Cargo.toml b/zluda_common/Cargo.toml
index ca70ab8..3d5655a 100644
--- a/zluda_common/Cargo.toml
+++ b/zluda_common/Cargo.toml
@@ -8,4 +8,5 @@ edition = "2021"
 cuda_types = { path = "../cuda_types" }
 hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
 rocblas-sys = { path = "../ext/rocblas-sys" }
+hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
 dark_api = { path = "../dark_api" }
diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs
index 95ec415..4f8aef7 100644
--- a/zluda_common/src/lib.rs
+++ b/zluda_common/src/lib.rs
@@ -1,12 +1,13 @@
 use cuda_types::{
     cublas::*,
-    cublaslt::cublasLtHandle_t,
+    cublaslt::*,
     cuda::*,
     dark_api::{FatbinHeader, FatbincWrapper},
     nvml::*,
 };
 use dark_api::fatbin::{Fatbin, FatbinError, FatbinFile, FatbinSubmodule};
 use hip_runtime_sys::*;
+use hipblaslt_sys::*;
 use rocblas_sys::*;
 use std::{
     ffi::{c_void, CStr},
@@ -166,6 +167,10 @@ from_cuda_nop!(
     nvmlFieldValue_t,
     nvmlGpuFabricInfo_t,
     cublasLtHandle_t,
+    cublasLtMatmulDesc_t,
+    cublasLtMatmulPreference_t,
+    cublasLtMatrixLayout_t,
+    cublasLtMatmulDescAttributes_t,
     CUmemAllocationGranularity_flags,
     CUmemAllocationProp,
     CUresult
@@ -185,7 +190,10 @@ from_cuda_transmute!(
     CUstreamCaptureMode => hipStreamCaptureMode,
     CUgraphNode => hipGraphNode_t,
     CUgraphExec => hipGraphExec_t,
-    CUkernel => hipFunction_t
+    CUkernel => hipFunction_t,
+    cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
+    cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
+    cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
 );
 
 impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
@@ -311,6 +319,249 @@ impl<'a, E: CudaErrorType> FromCuda<'a, cuda_types::cublas::cublasGemmAlgo_t, E>
     }
 }
 
+// These have the same values, so it might be okay to use from_cuda_transmute
+impl<'a, E: CudaErrorType> FromCuda<'a, cudaDataType, E> for hipDataType {
+    fn from_cuda(t: &'a cudaDataType) -> Result<Self, E> {
+        Ok(match *t {
+            cudaDataType::CUDA_R_16F => hipDataType::HIP_R_16F,
+            cudaDataType::CUDA_C_16F => hipDataType::HIP_C_16F,
+            cudaDataType::CUDA_R_16BF => hipDataType::HIP_R_16BF,
+            cudaDataType::CUDA_C_16BF => hipDataType::HIP_C_16BF,
+            cudaDataType::CUDA_R_32F => hipDataType::HIP_R_32F,
+            cudaDataType::CUDA_C_32F => hipDataType::HIP_C_32F,
+            cudaDataType::CUDA_R_64F => hipDataType::HIP_R_64F,
+            cudaDataType::CUDA_C_64F => hipDataType::HIP_C_64F,
+            cudaDataType::CUDA_R_8I => hipDataType::HIP_R_8I,
+            cudaDataType::CUDA_C_8I => hipDataType::HIP_C_8I,
+            cudaDataType::CUDA_R_8U => hipDataType::HIP_R_8U,
+            cudaDataType::CUDA_C_8U => hipDataType::HIP_C_8U,
+            cudaDataType::CUDA_R_32I => hipDataType::HIP_R_32I,
+            cudaDataType::CUDA_C_32I => hipDataType::HIP_C_32I,
+            cudaDataType::CUDA_R_8F_E4M3 => hipDataType::HIP_R_8F_E4M3,
+            cudaDataType::CUDA_R_8F_E5M2 => hipDataType::HIP_R_8F_E5M2,
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
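Reviewer note: per the comment above, the cuda/hip values do appear to line up, so `from_cuda_transmute` might work here too. The explicit match trades a few lines for safety: every variant is checked by the compiler, and values the HIP backend cannot represent become a recoverable error instead of silently transmuting into a wrong (or invalid) HIP value. A minimal sketch with illustrative enums (not the real cudaDataType/hipDataType):

```rust
#[derive(Clone, Copy)]
#[repr(u32)]
enum CudaKind { F32 = 0, F64 = 1, C16 = 2 }

#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u32)]
enum HipKind { F32 = 0, F64 = 1 }

// A transmute-style conversion is only sound while every source value has a
// matching target value; CudaKind::C16 would transmute into a HipKind that
// does not exist. The match makes the unsupported case an explicit error.
fn by_match(t: CudaKind) -> Result<HipKind, &'static str> {
    Ok(match t {
        CudaKind::F32 => HipKind::F32,
        CudaKind::F64 => HipKind::F64,
        _ => return Err("NOT_SUPPORTED"),
    })
}

fn main() {
    assert_eq!(by_match(CudaKind::F64), Ok(HipKind::F64));
    assert!(by_match(CudaKind::C16).is_err());
}
```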
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasComputeType_t, E> for hipblasComputeType_t {
+    fn from_cuda(t: &'a cublasComputeType_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasComputeType_t::CUBLAS_COMPUTE_16F => hipblasComputeType_t::HIPBLAS_COMPUTE_16F,
+            cublasComputeType_t::CUBLAS_COMPUTE_16F_PEDANTIC => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_16F_PEDANTIC
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_32F => hipblasComputeType_t::HIPBLAS_COMPUTE_32F,
+            cublasComputeType_t::CUBLAS_COMPUTE_32F_PEDANTIC => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_32F_PEDANTIC
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16F
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_16BF
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_32F_FAST_TF32
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_64F => hipblasComputeType_t::HIPBLAS_COMPUTE_64F,
+            cublasComputeType_t::CUBLAS_COMPUTE_64F_PEDANTIC => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_64F_PEDANTIC
+            }
+            cublasComputeType_t::CUBLAS_COMPUTE_32I => hipblasComputeType_t::HIPBLAS_COMPUTE_32I,
+            cublasComputeType_t::CUBLAS_COMPUTE_32I_PEDANTIC => {
+                hipblasComputeType_t::HIPBLAS_COMPUTE_32I_PEDANTIC
+            }
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulDescAttributes_t, E>
+    for hipblasLtMatmulDescAttributes_t
+{
+    fn from_cuda(t: &'a cuda_types::cublaslt::cublasLtMatmulDescAttributes_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSA
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_TRANSB
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_POINTER_MODE => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_POINTER_MODE
+            }
+            cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_AMAX_D_POINTER => {
+                hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER
+            }
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatrixLayoutAttribute_t, E>
+    for hipblasLtMatrixLayoutAttribute_t
+{
+    fn from_cuda(t: &'a cublasLtMatrixLayoutAttribute_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_TYPE => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_TYPE
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ORDER => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ORDER
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_ROWS => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_ROWS
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_COLS => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_COLS
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_LD => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_LD
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT
+            }
+            cublasLtMatrixLayoutAttribute_t::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET => {
+                hipblasLtMatrixLayoutAttribute_t::HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
+            }
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtMatmulPreferenceAttributes_t, E>
+    for hipblasLtMatmulPreferenceAttributes_t
+{
+    fn from_cuda(t: &'a cublasLtMatmulPreferenceAttributes_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_SEARCH_MODE => {
+                hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_SEARCH_MODE
+            }
+            cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES => {
+                hipblasLtMatmulPreferenceAttributes_t::HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
+            }
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, *const cublasLtMatmulAlgo_t, E> for hipblasLtMatmulAlgo_t {
+    fn from_cuda(t: &'a *const cublasLtMatmulAlgo_t) -> Result<Self, E> {
+        // We assume the algo came from cuda_algo_from_hip, so only words 0-1
+        // carry opaque algo data and word 2 holds the workspace bound; the
+        // remaining words can be discarded.
+        let cuda_algo = match unsafe { t.as_ref() } {
+            Some(algo) => algo,
+            None => return Err(E::INVALID_VALUE),
+        };
+        let mut hip = hipblasLtMatmulAlgo_t {
+            data: [0; 16],
+            max_workspace_bytes: cuda_algo.data[2] as usize,
+        };
+        hip.data[..8].copy_from_slice(&cuda_algo.data[0].to_ne_bytes());
+        hip.data[8..].copy_from_slice(&cuda_algo.data[1].to_ne_bytes());
+        Ok(hip)
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasOperation_t, E> for hipblasOperation_t {
+    fn from_cuda(t: &'a cublasOperation_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasOperation_t::CUBLAS_OP_N => hipblasOperation_t::HIPBLAS_OP_N,
+            cublasOperation_t::CUBLAS_OP_T => hipblasOperation_t::HIPBLAS_OP_T,
+            cublasOperation_t::CUBLAS_OP_C => hipblasOperation_t::HIPBLAS_OP_C,
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
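Reviewer note: the pointer-taking conversions above (`*const cublasLtMatmulAlgo_t`, and `*mut cublasLtMatmulHeuristicResult_t` below) all follow the same pattern: raw guest pointers are turned into references up front, with null mapped to INVALID_VALUE, so the per-function code never handles raw pointers itself. A minimal sketch of that check, with an illustrative error type:

```rust
// `as_ref` yields None for null, which becomes the CUDA error; a non-null
// but otherwise bogus pointer is still the caller's problem, exactly as in
// the C API.
fn checked_deref<'a, T>(ptr: *const T) -> Result<&'a T, &'static str> {
    unsafe { ptr.as_ref() }.ok_or("CUBLAS_STATUS_INVALID_VALUE")
}

fn main() {
    let x = 42u64;
    assert_eq!(checked_deref(&x as *const u64), Ok(&42));
    assert_eq!(
        checked_deref(std::ptr::null::<u64>()),
        Err("CUBLAS_STATUS_INVALID_VALUE")
    );
}
```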
+
+impl<'a, E: CudaErrorType> FromCuda<'a, cublasLtEpilogue_t, E> for hipblasLtEpilogue_t {
+    fn from_cuda(t: &'a cublasLtEpilogue_t) -> Result<Self, E> {
+        Ok(match *t {
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DEFAULT
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BIAS
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_RELU_BIAS
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_BIAS
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_GELU_AUX_BIAS
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_DGELU_BGRAD
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADA
+            }
+            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB => {
+                hipblasLtEpilogue_t::HIPBLASLT_EPILOGUE_BGRADB
+            }
+            _ => return Err(E::NOT_SUPPORTED),
+        })
+    }
+}
+
+impl<'a, E: CudaErrorType> FromCuda<'a, *mut cublasLtMatmulHeuristicResult_t, E>
+    for &'a mut cublasLtMatmulHeuristicResult_t
+{
+    fn from_cuda(x: &'a *mut cublasLtMatmulHeuristicResult_t) -> Result<Self, E> {
+        match unsafe { x.as_mut() } {
+            Some(x) => Ok(x),
+            None => Err(E::INVALID_VALUE),
+        }
+    }
+}
+
 /// Represents an object that can be sent across the API boundary.
 ///
 /// Some CUDA calls operate on an opaque handle. For example, `cuModuleLoadData` will load a
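Reviewer note: the trailing context above introduces the ZludaObject machinery that `Handle::wrap` and `zluda_common::drop_checked::<Handle>` in zluda_blaslt rely on. A condensed sketch of that ownership flow with toy types (a simplification, not zluda_common's actual definitions): the guest-visible opaque handle is a boxed wrapper around the real HIP object, `wrap` hands ownership across the API boundary, and `drop_checked` takes it back and frees it.

```rust
struct OpaqueHandle(*mut Wrapper); // what the CUDA app sees

struct Wrapper {
    _hip_handle: usize, // stand-in for a real hipblasLtHandle_t
}

impl Wrapper {
    // Leak the box so the raw pointer can cross the C ABI boundary.
    fn wrap(self) -> OpaqueHandle {
        OpaqueHandle(Box::into_raw(Box::new(self)))
    }
}

fn drop_checked(handle: OpaqueHandle) -> Result<(), &'static str> {
    if handle.0.is_null() {
        return Err("invalid handle");
    }
    // Re-box to run the wrapper's destructor (which, for the real type,
    // would be where hipblasLtDestroy gets called).
    drop(unsafe { Box::from_raw(handle.0) });
    Ok(())
}

fn main() {
    let handle = Wrapper { _hip_handle: 0 }.wrap();
    assert!(drop_checked(handle).is_ok());
}
```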