mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-27 19:59:08 +00:00
Use transmute instead of wrapping
This commit is contained in:
parent
86fc95b7f7
commit
01d52fdf80
2 changed files with 51 additions and 129 deletions
|
@ -29,84 +29,6 @@ impl ZludaObject for Handle {
|
|||
|
||||
from_cuda_object!(Handle);
|
||||
|
||||
pub struct MatmulDesc {
|
||||
desc: hipblasLtMatmulDesc_t,
|
||||
}
|
||||
|
||||
impl MatmulDesc {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
desc: unsafe { mem::zeroed() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZludaObject for MatmulDesc {
|
||||
const COOKIE: usize = 0x4406a5f4b814f52b;
|
||||
|
||||
type Error = cublasError_t;
|
||||
type CudaHandle = cublasLtMatmulDesc_t;
|
||||
|
||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulDescDestroy(self.desc) }?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
from_cuda_object!(MatmulDesc);
|
||||
|
||||
pub struct MatmulPreference {
|
||||
pref: hipblasLtMatmulPreference_t,
|
||||
}
|
||||
|
||||
impl MatmulPreference {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
pref: unsafe { mem::zeroed() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZludaObject for MatmulPreference {
|
||||
const COOKIE: usize = 0x6a6d1c41958baa9b;
|
||||
|
||||
type Error = cublasError_t;
|
||||
type CudaHandle = cublasLtMatmulPreference_t;
|
||||
|
||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceDestroy(self.pref) }?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
from_cuda_object!(MatmulPreference);
|
||||
|
||||
pub struct MatrixLayout {
|
||||
layout: hipblasLtMatrixLayout_t,
|
||||
}
|
||||
|
||||
impl MatrixLayout {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
layout: unsafe { mem::zeroed() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ZludaObject for MatrixLayout {
|
||||
const COOKIE: usize = 0xcf566e9656cec9b8;
|
||||
|
||||
type Error = cublasError_t;
|
||||
type CudaHandle = cublasLtMatrixLayout_t;
|
||||
|
||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutDestroy(self.layout) }?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
from_cuda_object!(MatrixLayout);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub(crate) fn unimplemented() -> cublasStatus_t {
|
||||
unimplemented!()
|
||||
|
@ -159,17 +81,17 @@ fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
|
|||
|
||||
pub(crate) fn matmul(
|
||||
light_handle: &Handle,
|
||||
compute_desc: &MatmulDesc,
|
||||
compute_desc: hipblasLtMatmulDesc_t,
|
||||
alpha: *const ::core::ffi::c_void,
|
||||
a: *const ::core::ffi::c_void,
|
||||
a_desc: &MatrixLayout,
|
||||
a_desc: hipblasLtMatrixLayout_t,
|
||||
b: *const ::core::ffi::c_void,
|
||||
b_desc: &MatrixLayout,
|
||||
b_desc: hipblasLtMatrixLayout_t,
|
||||
beta: *const ::core::ffi::c_void,
|
||||
c: *const ::core::ffi::c_void,
|
||||
c_desc: &MatrixLayout,
|
||||
c_desc: hipblasLtMatrixLayout_t,
|
||||
d: *mut ::core::ffi::c_void,
|
||||
d_desc: &MatrixLayout,
|
||||
d_desc: hipblasLtMatrixLayout_t,
|
||||
algo: hipblasLtMatmulAlgo_t,
|
||||
workspace: *mut ::core::ffi::c_void,
|
||||
workspace_size_in_bytes: usize,
|
||||
|
@ -178,17 +100,17 @@ pub(crate) fn matmul(
|
|||
unsafe {
|
||||
hipblasLtMatmul(
|
||||
light_handle.handle,
|
||||
compute_desc.desc,
|
||||
compute_desc,
|
||||
alpha,
|
||||
a,
|
||||
a_desc.layout,
|
||||
a_desc,
|
||||
b,
|
||||
b_desc.layout,
|
||||
b_desc,
|
||||
beta,
|
||||
c,
|
||||
c_desc.layout,
|
||||
c_desc,
|
||||
d,
|
||||
d_desc.layout,
|
||||
d_desc,
|
||||
&algo,
|
||||
workspace,
|
||||
workspace_size_in_bytes,
|
||||
|
@ -200,12 +122,12 @@ pub(crate) fn matmul(
|
|||
|
||||
pub(crate) fn matmul_algo_get_heuristic(
|
||||
light_handle: &Handle,
|
||||
operation_desc: &MatmulDesc,
|
||||
a_desc: &MatrixLayout,
|
||||
b_desc: &MatrixLayout,
|
||||
c_desc: &MatrixLayout,
|
||||
d_desc: &MatrixLayout,
|
||||
preference: &MatmulPreference,
|
||||
operation_desc: hipblasLtMatmulDesc_t,
|
||||
a_desc: hipblasLtMatrixLayout_t,
|
||||
b_desc: hipblasLtMatrixLayout_t,
|
||||
c_desc: hipblasLtMatrixLayout_t,
|
||||
d_desc: hipblasLtMatrixLayout_t,
|
||||
preference: hipblasLtMatmulPreference_t,
|
||||
requested_algo_count: ::core::ffi::c_int,
|
||||
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
|
||||
return_algo_count: &mut ::core::ffi::c_int,
|
||||
|
@ -214,12 +136,12 @@ pub(crate) fn matmul_algo_get_heuristic(
|
|||
unsafe {
|
||||
hipblasLtMatmulAlgoGetHeuristic(
|
||||
light_handle.handle,
|
||||
operation_desc.desc,
|
||||
a_desc.layout,
|
||||
b_desc.layout,
|
||||
c_desc.layout,
|
||||
d_desc.layout,
|
||||
preference.pref,
|
||||
operation_desc,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
d_desc,
|
||||
preference,
|
||||
requested_algo_count,
|
||||
hip_algos.as_mut_ptr(),
|
||||
return_algo_count,
|
||||
|
@ -244,22 +166,21 @@ pub(crate) fn matmul_algo_get_heuristic(
|
|||
}
|
||||
|
||||
pub(crate) fn matmul_desc_create(
|
||||
matmul_desc: &mut cublasLtMatmulDesc_t,
|
||||
matmul_desc: &mut hipblasLtMatmulDesc_t,
|
||||
compute_type: hipblasComputeType_t,
|
||||
scale_type: hipDataType,
|
||||
) -> cublasStatus_t {
|
||||
let mut zluda_blaslt_desc = MatmulDesc::new();
|
||||
unsafe { hipblasLtMatmulDescCreate(&mut zluda_blaslt_desc.desc, compute_type, scale_type) }?;
|
||||
*matmul_desc = MatmulDesc::wrap(zluda_blaslt_desc);
|
||||
unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_desc_destroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t {
|
||||
zluda_common::drop_checked::<MatmulDesc>(matmul_desc)
|
||||
pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_desc_set_attribute(
|
||||
matmul_desc: &MatmulDesc,
|
||||
matmul_desc: hipblasLtMatmulDesc_t,
|
||||
attr: cublasLtMatmulDescAttributes_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
|
@ -300,16 +221,16 @@ pub(crate) fn matmul_desc_set_attribute(
|
|||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
|
||||
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
|
||||
}
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => unsafe {
|
||||
hipblasLtMatmulDescSetAttribute(matmul_desc.desc, hip_attr, buf, size_in_bytes)
|
||||
}?,
|
||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
|
||||
unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
|
||||
}
|
||||
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
|
||||
matmul_desc: &MatmulDesc,
|
||||
matmul_desc: hipblasLtMatmulDesc_t,
|
||||
buf: *const std::ffi::c_void,
|
||||
hip_attr: hipblasLtMatmulDescAttributes_t,
|
||||
) -> Result<(), cublasError_t>
|
||||
|
@ -322,7 +243,7 @@ where
|
|||
let hip_buf: *const HipType = &hip_operation;
|
||||
unsafe {
|
||||
hipblasLtMatmulDescSetAttribute(
|
||||
matmul_desc.desc,
|
||||
matmul_desc,
|
||||
hip_attr,
|
||||
hip_buf.cast(),
|
||||
mem::size_of::<HipType>(),
|
||||
|
@ -331,50 +252,48 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_create(pref: &mut cublasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
let mut zluda_matmul_pref = MatmulPreference::new();
|
||||
unsafe { hipblasLtMatmulPreferenceCreate(&mut zluda_matmul_pref.pref) }?;
|
||||
*pref = MatmulPreference::wrap(zluda_matmul_pref);
|
||||
pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_destroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
zluda_common::drop_checked::<MatmulPreference>(pref)
|
||||
pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_preference_set_attribute(
|
||||
pref: &MatmulPreference,
|
||||
pref: hipblasLtMatmulPreference_t,
|
||||
attr: hipblasLtMatmulPreferenceAttributes_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref.pref, attr, buf, size_in_bytes) }?;
|
||||
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_create(
|
||||
mat_layout: &mut cublasLtMatrixLayout_t,
|
||||
mat_layout: &mut hipblasLtMatrixLayout_t,
|
||||
type_: hipDataType,
|
||||
rows: u64,
|
||||
cols: u64,
|
||||
ld: i64,
|
||||
) -> cublasStatus_t {
|
||||
let mut zluda_matrix_layout = MatrixLayout::new();
|
||||
unsafe { hipblasLtMatrixLayoutCreate(&mut zluda_matrix_layout.layout, type_, rows, cols, ld) }?;
|
||||
*mat_layout = MatrixLayout::wrap(zluda_matrix_layout);
|
||||
unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_destroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t {
|
||||
zluda_common::drop_checked::<MatrixLayout>(mat_layout)
|
||||
pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn matrix_layout_set_attribute(
|
||||
mat_layout: &MatrixLayout,
|
||||
mat_layout: hipblasLtMatrixLayout_t,
|
||||
attr: hipblasLtMatrixLayoutAttribute_t,
|
||||
buf: *const ::core::ffi::c_void,
|
||||
size_in_bytes: usize,
|
||||
) -> cublasStatus_t {
|
||||
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout.layout, attr, buf, size_in_bytes) }?;
|
||||
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -190,7 +190,10 @@ from_cuda_transmute!(
|
|||
CUstreamCaptureMode => hipStreamCaptureMode,
|
||||
CUgraphNode => hipGraphNode_t,
|
||||
CUgraphExec => hipGraphExec_t,
|
||||
CUkernel => hipFunction_t
|
||||
CUkernel => hipFunction_t,
|
||||
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
|
||||
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
|
||||
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
|
||||
);
|
||||
|
||||
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue