Use transmute instead of wrapping

This commit is contained in:
Violet 2025-09-17 05:41:09 +00:00
commit 01d52fdf80
2 changed files with 51 additions and 129 deletions

View file

@ -29,84 +29,6 @@ impl ZludaObject for Handle {
from_cuda_object!(Handle); from_cuda_object!(Handle);
pub struct MatmulDesc {
desc: hipblasLtMatmulDesc_t,
}
impl MatmulDesc {
fn new() -> Self {
Self {
desc: unsafe { mem::zeroed() },
}
}
}
impl ZludaObject for MatmulDesc {
const COOKIE: usize = 0x4406a5f4b814f52b;
type Error = cublasError_t;
type CudaHandle = cublasLtMatmulDesc_t;
fn drop_checked(&mut self) -> cublasStatus_t {
unsafe { hipblasLtMatmulDescDestroy(self.desc) }?;
Ok(())
}
}
from_cuda_object!(MatmulDesc);
pub struct MatmulPreference {
pref: hipblasLtMatmulPreference_t,
}
impl MatmulPreference {
fn new() -> Self {
Self {
pref: unsafe { mem::zeroed() },
}
}
}
impl ZludaObject for MatmulPreference {
const COOKIE: usize = 0x6a6d1c41958baa9b;
type Error = cublasError_t;
type CudaHandle = cublasLtMatmulPreference_t;
fn drop_checked(&mut self) -> cublasStatus_t {
unsafe { hipblasLtMatmulPreferenceDestroy(self.pref) }?;
Ok(())
}
}
from_cuda_object!(MatmulPreference);
pub struct MatrixLayout {
layout: hipblasLtMatrixLayout_t,
}
impl MatrixLayout {
fn new() -> Self {
Self {
layout: unsafe { mem::zeroed() },
}
}
}
impl ZludaObject for MatrixLayout {
const COOKIE: usize = 0xcf566e9656cec9b8;
type Error = cublasError_t;
type CudaHandle = cublasLtMatrixLayout_t;
fn drop_checked(&mut self) -> cublasStatus_t {
unsafe { hipblasLtMatrixLayoutDestroy(self.layout) }?;
Ok(())
}
}
from_cuda_object!(MatrixLayout);
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> cublasStatus_t { pub(crate) fn unimplemented() -> cublasStatus_t {
unimplemented!() unimplemented!()
@ -159,17 +81,17 @@ fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
pub(crate) fn matmul( pub(crate) fn matmul(
light_handle: &Handle, light_handle: &Handle,
compute_desc: &MatmulDesc, compute_desc: hipblasLtMatmulDesc_t,
alpha: *const ::core::ffi::c_void, alpha: *const ::core::ffi::c_void,
a: *const ::core::ffi::c_void, a: *const ::core::ffi::c_void,
a_desc: &MatrixLayout, a_desc: hipblasLtMatrixLayout_t,
b: *const ::core::ffi::c_void, b: *const ::core::ffi::c_void,
b_desc: &MatrixLayout, b_desc: hipblasLtMatrixLayout_t,
beta: *const ::core::ffi::c_void, beta: *const ::core::ffi::c_void,
c: *const ::core::ffi::c_void, c: *const ::core::ffi::c_void,
c_desc: &MatrixLayout, c_desc: hipblasLtMatrixLayout_t,
d: *mut ::core::ffi::c_void, d: *mut ::core::ffi::c_void,
d_desc: &MatrixLayout, d_desc: hipblasLtMatrixLayout_t,
algo: hipblasLtMatmulAlgo_t, algo: hipblasLtMatmulAlgo_t,
workspace: *mut ::core::ffi::c_void, workspace: *mut ::core::ffi::c_void,
workspace_size_in_bytes: usize, workspace_size_in_bytes: usize,
@ -178,17 +100,17 @@ pub(crate) fn matmul(
unsafe { unsafe {
hipblasLtMatmul( hipblasLtMatmul(
light_handle.handle, light_handle.handle,
compute_desc.desc, compute_desc,
alpha, alpha,
a, a,
a_desc.layout, a_desc,
b, b,
b_desc.layout, b_desc,
beta, beta,
c, c,
c_desc.layout, c_desc,
d, d,
d_desc.layout, d_desc,
&algo, &algo,
workspace, workspace,
workspace_size_in_bytes, workspace_size_in_bytes,
@ -200,12 +122,12 @@ pub(crate) fn matmul(
pub(crate) fn matmul_algo_get_heuristic( pub(crate) fn matmul_algo_get_heuristic(
light_handle: &Handle, light_handle: &Handle,
operation_desc: &MatmulDesc, operation_desc: hipblasLtMatmulDesc_t,
a_desc: &MatrixLayout, a_desc: hipblasLtMatrixLayout_t,
b_desc: &MatrixLayout, b_desc: hipblasLtMatrixLayout_t,
c_desc: &MatrixLayout, c_desc: hipblasLtMatrixLayout_t,
d_desc: &MatrixLayout, d_desc: hipblasLtMatrixLayout_t,
preference: &MatmulPreference, preference: hipblasLtMatmulPreference_t,
requested_algo_count: ::core::ffi::c_int, requested_algo_count: ::core::ffi::c_int,
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t, heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
return_algo_count: &mut ::core::ffi::c_int, return_algo_count: &mut ::core::ffi::c_int,
@ -214,12 +136,12 @@ pub(crate) fn matmul_algo_get_heuristic(
unsafe { unsafe {
hipblasLtMatmulAlgoGetHeuristic( hipblasLtMatmulAlgoGetHeuristic(
light_handle.handle, light_handle.handle,
operation_desc.desc, operation_desc,
a_desc.layout, a_desc,
b_desc.layout, b_desc,
c_desc.layout, c_desc,
d_desc.layout, d_desc,
preference.pref, preference,
requested_algo_count, requested_algo_count,
hip_algos.as_mut_ptr(), hip_algos.as_mut_ptr(),
return_algo_count, return_algo_count,
@ -244,22 +166,21 @@ pub(crate) fn matmul_algo_get_heuristic(
} }
pub(crate) fn matmul_desc_create( pub(crate) fn matmul_desc_create(
matmul_desc: &mut cublasLtMatmulDesc_t, matmul_desc: &mut hipblasLtMatmulDesc_t,
compute_type: hipblasComputeType_t, compute_type: hipblasComputeType_t,
scale_type: hipDataType, scale_type: hipDataType,
) -> cublasStatus_t { ) -> cublasStatus_t {
let mut zluda_blaslt_desc = MatmulDesc::new(); unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
unsafe { hipblasLtMatmulDescCreate(&mut zluda_blaslt_desc.desc, compute_type, scale_type) }?;
*matmul_desc = MatmulDesc::wrap(zluda_blaslt_desc);
Ok(()) Ok(())
} }
pub(crate) fn matmul_desc_destroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t { pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatmulDesc>(matmul_desc) unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
Ok(())
} }
pub(crate) fn matmul_desc_set_attribute( pub(crate) fn matmul_desc_set_attribute(
matmul_desc: &MatmulDesc, matmul_desc: hipblasLtMatmulDesc_t,
attr: cublasLtMatmulDescAttributes_t, attr: cublasLtMatmulDescAttributes_t,
buf: *const ::core::ffi::c_void, buf: *const ::core::ffi::c_void,
size_in_bytes: usize, size_in_bytes: usize,
@ -300,16 +221,16 @@ pub(crate) fn matmul_desc_set_attribute(
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => { hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)? convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
} }
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => unsafe { hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
hipblasLtMatmulDescSetAttribute(matmul_desc.desc, hip_attr, buf, size_in_bytes) unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
}?, }
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED, _ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
} }
Ok(()) Ok(())
} }
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>( fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
matmul_desc: &MatmulDesc, matmul_desc: hipblasLtMatmulDesc_t,
buf: *const std::ffi::c_void, buf: *const std::ffi::c_void,
hip_attr: hipblasLtMatmulDescAttributes_t, hip_attr: hipblasLtMatmulDescAttributes_t,
) -> Result<(), cublasError_t> ) -> Result<(), cublasError_t>
@ -322,7 +243,7 @@ where
let hip_buf: *const HipType = &hip_operation; let hip_buf: *const HipType = &hip_operation;
unsafe { unsafe {
hipblasLtMatmulDescSetAttribute( hipblasLtMatmulDescSetAttribute(
matmul_desc.desc, matmul_desc,
hip_attr, hip_attr,
hip_buf.cast(), hip_buf.cast(),
mem::size_of::<HipType>(), mem::size_of::<HipType>(),
@ -331,50 +252,48 @@ where
Ok(()) Ok(())
} }
pub(crate) fn matmul_preference_create(pref: &mut cublasLtMatmulPreference_t) -> cublasStatus_t { pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
let mut zluda_matmul_pref = MatmulPreference::new(); unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
unsafe { hipblasLtMatmulPreferenceCreate(&mut zluda_matmul_pref.pref) }?;
*pref = MatmulPreference::wrap(zluda_matmul_pref);
Ok(()) Ok(())
} }
pub(crate) fn matmul_preference_destroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t { pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatmulPreference>(pref) unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
Ok(())
} }
pub(crate) fn matmul_preference_set_attribute( pub(crate) fn matmul_preference_set_attribute(
pref: &MatmulPreference, pref: hipblasLtMatmulPreference_t,
attr: hipblasLtMatmulPreferenceAttributes_t, attr: hipblasLtMatmulPreferenceAttributes_t,
buf: *const ::core::ffi::c_void, buf: *const ::core::ffi::c_void,
size_in_bytes: usize, size_in_bytes: usize,
) -> cublasStatus_t { ) -> cublasStatus_t {
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref.pref, attr, buf, size_in_bytes) }?; unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
Ok(()) Ok(())
} }
pub(crate) fn matrix_layout_create( pub(crate) fn matrix_layout_create(
mat_layout: &mut cublasLtMatrixLayout_t, mat_layout: &mut hipblasLtMatrixLayout_t,
type_: hipDataType, type_: hipDataType,
rows: u64, rows: u64,
cols: u64, cols: u64,
ld: i64, ld: i64,
) -> cublasStatus_t { ) -> cublasStatus_t {
let mut zluda_matrix_layout = MatrixLayout::new(); unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
unsafe { hipblasLtMatrixLayoutCreate(&mut zluda_matrix_layout.layout, type_, rows, cols, ld) }?;
*mat_layout = MatrixLayout::wrap(zluda_matrix_layout);
Ok(()) Ok(())
} }
pub(crate) fn matrix_layout_destroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t { pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatrixLayout>(mat_layout) unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
Ok(())
} }
pub(crate) fn matrix_layout_set_attribute( pub(crate) fn matrix_layout_set_attribute(
mat_layout: &MatrixLayout, mat_layout: hipblasLtMatrixLayout_t,
attr: hipblasLtMatrixLayoutAttribute_t, attr: hipblasLtMatrixLayoutAttribute_t,
buf: *const ::core::ffi::c_void, buf: *const ::core::ffi::c_void,
size_in_bytes: usize, size_in_bytes: usize,
) -> cublasStatus_t { ) -> cublasStatus_t {
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout.layout, attr, buf, size_in_bytes) }?; unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
Ok(()) Ok(())
} }

View file

@ -190,7 +190,10 @@ from_cuda_transmute!(
CUstreamCaptureMode => hipStreamCaptureMode, CUstreamCaptureMode => hipStreamCaptureMode,
CUgraphNode => hipGraphNode_t, CUgraphNode => hipGraphNode_t,
CUgraphExec => hipGraphExec_t, CUgraphExec => hipGraphExec_t,
CUkernel => hipFunction_t CUkernel => hipFunction_t,
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
); );
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t { impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {