Use transmute instead of wrapping

This commit is contained in:
Violet 2025-09-17 05:41:09 +00:00
commit 01d52fdf80
2 changed files with 51 additions and 129 deletions

View file

@ -29,84 +29,6 @@ impl ZludaObject for Handle {
from_cuda_object!(Handle);
/// ZLUDA-side object that owns a hipBLASLt matmul descriptor handle.
///
/// Its CUDA-facing handle type is `cublasLtMatmulDesc_t` (see the
/// `ZludaObject` impl below).
pub struct MatmulDesc {
    /// The wrapped hipBLASLt handle; handed to `hipblasLtMatmulDescDestroy`
    /// by `drop_checked`.
    desc: hipblasLtMatmulDesc_t,
}

impl MatmulDesc {
    /// Builds a wrapper whose handle is zero-initialized.
    fn new() -> Self {
        // SAFETY: relies on the all-zero bit pattern being a valid
        // `hipblasLtMatmulDesc_t` (presumably a raw pointer, i.e. null) --
        // TODO confirm against the hipBLASLt bindings.
        let desc = unsafe { mem::zeroed() };
        MatmulDesc { desc }
    }
}

impl ZludaObject for MatmulDesc {
    type CudaHandle = cublasLtMatmulDesc_t;
    type Error = cublasError_t;
    // Per-type tag value. NOTE(review): presumably used by the `ZludaObject`
    // machinery to recognize handles of this type -- confirm in zluda_common.
    const COOKIE: usize = 0x4406a5f4b814f52b;

    /// Destroys the wrapped descriptor with `hipblasLtMatmulDescDestroy`,
    /// propagating a failure status to the caller via `?`.
    fn drop_checked(&mut self) -> cublasStatus_t {
        let status = unsafe { hipblasLtMatmulDescDestroy(self.desc) };
        status?;
        Ok(())
    }
}

from_cuda_object!(MatmulDesc);
/// ZLUDA-side object that owns a hipBLASLt matmul preference handle.
///
/// Its CUDA-facing handle type is `cublasLtMatmulPreference_t` (see the
/// `ZludaObject` impl below).
pub struct MatmulPreference {
    /// The wrapped hipBLASLt handle; handed to
    /// `hipblasLtMatmulPreferenceDestroy` by `drop_checked`.
    pref: hipblasLtMatmulPreference_t,
}

impl MatmulPreference {
    /// Builds a wrapper whose handle is zero-initialized.
    fn new() -> Self {
        // SAFETY: relies on the all-zero bit pattern being a valid
        // `hipblasLtMatmulPreference_t` (presumably a raw pointer, i.e. null)
        // -- TODO confirm against the hipBLASLt bindings.
        let pref = unsafe { mem::zeroed() };
        MatmulPreference { pref }
    }
}

impl ZludaObject for MatmulPreference {
    type CudaHandle = cublasLtMatmulPreference_t;
    type Error = cublasError_t;
    // Per-type tag value. NOTE(review): presumably used by the `ZludaObject`
    // machinery to recognize handles of this type -- confirm in zluda_common.
    const COOKIE: usize = 0x6a6d1c41958baa9b;

    /// Destroys the wrapped preference with `hipblasLtMatmulPreferenceDestroy`,
    /// propagating a failure status to the caller via `?`.
    fn drop_checked(&mut self) -> cublasStatus_t {
        let status = unsafe { hipblasLtMatmulPreferenceDestroy(self.pref) };
        status?;
        Ok(())
    }
}

from_cuda_object!(MatmulPreference);
/// ZLUDA-side object that owns a hipBLASLt matrix layout handle.
///
/// Its CUDA-facing handle type is `cublasLtMatrixLayout_t` (see the
/// `ZludaObject` impl below).
pub struct MatrixLayout {
    /// The wrapped hipBLASLt handle; handed to `hipblasLtMatrixLayoutDestroy`
    /// by `drop_checked`.
    layout: hipblasLtMatrixLayout_t,
}

impl MatrixLayout {
    /// Builds a wrapper whose handle is zero-initialized.
    fn new() -> Self {
        // SAFETY: relies on the all-zero bit pattern being a valid
        // `hipblasLtMatrixLayout_t` (presumably a raw pointer, i.e. null) --
        // TODO confirm against the hipBLASLt bindings.
        let layout = unsafe { mem::zeroed() };
        MatrixLayout { layout }
    }
}

impl ZludaObject for MatrixLayout {
    type CudaHandle = cublasLtMatrixLayout_t;
    type Error = cublasError_t;
    // Per-type tag value. NOTE(review): presumably used by the `ZludaObject`
    // machinery to recognize handles of this type -- confirm in zluda_common.
    const COOKIE: usize = 0xcf566e9656cec9b8;

    /// Destroys the wrapped layout with `hipblasLtMatrixLayoutDestroy`,
    /// propagating a failure status to the caller via `?`.
    fn drop_checked(&mut self) -> cublasStatus_t {
        let status = unsafe { hipblasLtMatrixLayoutDestroy(self.layout) };
        status?;
        Ok(())
    }
}

from_cuda_object!(MatrixLayout);
#[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> cublasStatus_t {
unimplemented!()
@ -159,17 +81,17 @@ fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
pub(crate) fn matmul(
light_handle: &Handle,
compute_desc: &MatmulDesc,
compute_desc: hipblasLtMatmulDesc_t,
alpha: *const ::core::ffi::c_void,
a: *const ::core::ffi::c_void,
a_desc: &MatrixLayout,
a_desc: hipblasLtMatrixLayout_t,
b: *const ::core::ffi::c_void,
b_desc: &MatrixLayout,
b_desc: hipblasLtMatrixLayout_t,
beta: *const ::core::ffi::c_void,
c: *const ::core::ffi::c_void,
c_desc: &MatrixLayout,
c_desc: hipblasLtMatrixLayout_t,
d: *mut ::core::ffi::c_void,
d_desc: &MatrixLayout,
d_desc: hipblasLtMatrixLayout_t,
algo: hipblasLtMatmulAlgo_t,
workspace: *mut ::core::ffi::c_void,
workspace_size_in_bytes: usize,
@ -178,17 +100,17 @@ pub(crate) fn matmul(
unsafe {
hipblasLtMatmul(
light_handle.handle,
compute_desc.desc,
compute_desc,
alpha,
a,
a_desc.layout,
a_desc,
b,
b_desc.layout,
b_desc,
beta,
c,
c_desc.layout,
c_desc,
d,
d_desc.layout,
d_desc,
&algo,
workspace,
workspace_size_in_bytes,
@ -200,12 +122,12 @@ pub(crate) fn matmul(
pub(crate) fn matmul_algo_get_heuristic(
light_handle: &Handle,
operation_desc: &MatmulDesc,
a_desc: &MatrixLayout,
b_desc: &MatrixLayout,
c_desc: &MatrixLayout,
d_desc: &MatrixLayout,
preference: &MatmulPreference,
operation_desc: hipblasLtMatmulDesc_t,
a_desc: hipblasLtMatrixLayout_t,
b_desc: hipblasLtMatrixLayout_t,
c_desc: hipblasLtMatrixLayout_t,
d_desc: hipblasLtMatrixLayout_t,
preference: hipblasLtMatmulPreference_t,
requested_algo_count: ::core::ffi::c_int,
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
return_algo_count: &mut ::core::ffi::c_int,
@ -214,12 +136,12 @@ pub(crate) fn matmul_algo_get_heuristic(
unsafe {
hipblasLtMatmulAlgoGetHeuristic(
light_handle.handle,
operation_desc.desc,
a_desc.layout,
b_desc.layout,
c_desc.layout,
d_desc.layout,
preference.pref,
operation_desc,
a_desc,
b_desc,
c_desc,
d_desc,
preference,
requested_algo_count,
hip_algos.as_mut_ptr(),
return_algo_count,
@ -244,22 +166,21 @@ pub(crate) fn matmul_algo_get_heuristic(
}
pub(crate) fn matmul_desc_create(
matmul_desc: &mut cublasLtMatmulDesc_t,
matmul_desc: &mut hipblasLtMatmulDesc_t,
compute_type: hipblasComputeType_t,
scale_type: hipDataType,
) -> cublasStatus_t {
let mut zluda_blaslt_desc = MatmulDesc::new();
unsafe { hipblasLtMatmulDescCreate(&mut zluda_blaslt_desc.desc, compute_type, scale_type) }?;
*matmul_desc = MatmulDesc::wrap(zluda_blaslt_desc);
unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
Ok(())
}
pub(crate) fn matmul_desc_destroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatmulDesc>(matmul_desc)
pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
Ok(())
}
pub(crate) fn matmul_desc_set_attribute(
matmul_desc: &MatmulDesc,
matmul_desc: hipblasLtMatmulDesc_t,
attr: cublasLtMatmulDescAttributes_t,
buf: *const ::core::ffi::c_void,
size_in_bytes: usize,
@ -300,16 +221,16 @@ pub(crate) fn matmul_desc_set_attribute(
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
}
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => unsafe {
hipblasLtMatmulDescSetAttribute(matmul_desc.desc, hip_attr, buf, size_in_bytes)
}?,
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
}
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
}
Ok(())
}
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
matmul_desc: &MatmulDesc,
matmul_desc: hipblasLtMatmulDesc_t,
buf: *const std::ffi::c_void,
hip_attr: hipblasLtMatmulDescAttributes_t,
) -> Result<(), cublasError_t>
@ -322,7 +243,7 @@ where
let hip_buf: *const HipType = &hip_operation;
unsafe {
hipblasLtMatmulDescSetAttribute(
matmul_desc.desc,
matmul_desc,
hip_attr,
hip_buf.cast(),
mem::size_of::<HipType>(),
@ -331,50 +252,48 @@ where
Ok(())
}
pub(crate) fn matmul_preference_create(pref: &mut cublasLtMatmulPreference_t) -> cublasStatus_t {
let mut zluda_matmul_pref = MatmulPreference::new();
unsafe { hipblasLtMatmulPreferenceCreate(&mut zluda_matmul_pref.pref) }?;
*pref = MatmulPreference::wrap(zluda_matmul_pref);
pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
Ok(())
}
pub(crate) fn matmul_preference_destroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatmulPreference>(pref)
pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
Ok(())
}
pub(crate) fn matmul_preference_set_attribute(
pref: &MatmulPreference,
pref: hipblasLtMatmulPreference_t,
attr: hipblasLtMatmulPreferenceAttributes_t,
buf: *const ::core::ffi::c_void,
size_in_bytes: usize,
) -> cublasStatus_t {
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref.pref, attr, buf, size_in_bytes) }?;
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
Ok(())
}
pub(crate) fn matrix_layout_create(
mat_layout: &mut cublasLtMatrixLayout_t,
mat_layout: &mut hipblasLtMatrixLayout_t,
type_: hipDataType,
rows: u64,
cols: u64,
ld: i64,
) -> cublasStatus_t {
let mut zluda_matrix_layout = MatrixLayout::new();
unsafe { hipblasLtMatrixLayoutCreate(&mut zluda_matrix_layout.layout, type_, rows, cols, ld) }?;
*mat_layout = MatrixLayout::wrap(zluda_matrix_layout);
unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
Ok(())
}
pub(crate) fn matrix_layout_destroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t {
zluda_common::drop_checked::<MatrixLayout>(mat_layout)
pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
Ok(())
}
pub(crate) fn matrix_layout_set_attribute(
mat_layout: &MatrixLayout,
mat_layout: hipblasLtMatrixLayout_t,
attr: hipblasLtMatrixLayoutAttribute_t,
buf: *const ::core::ffi::c_void,
size_in_bytes: usize,
) -> cublasStatus_t {
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout.layout, attr, buf, size_in_bytes) }?;
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
Ok(())
}

View file

@ -190,7 +190,10 @@ from_cuda_transmute!(
CUstreamCaptureMode => hipStreamCaptureMode,
CUgraphNode => hipGraphNode_t,
CUgraphExec => hipGraphExec_t,
CUkernel => hipFunction_t
CUkernel => hipFunction_t,
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
);
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {