mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-28 20:29:11 +00:00
Use transmute instead of wrapping
This commit is contained in:
parent
86fc95b7f7
commit
01d52fdf80
2 changed files with 51 additions and 129 deletions
|
@ -29,84 +29,6 @@ impl ZludaObject for Handle {
|
||||||
|
|
||||||
from_cuda_object!(Handle);
|
from_cuda_object!(Handle);
|
||||||
|
|
||||||
pub struct MatmulDesc {
|
|
||||||
desc: hipblasLtMatmulDesc_t,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MatmulDesc {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
desc: unsafe { mem::zeroed() },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ZludaObject for MatmulDesc {
|
|
||||||
const COOKIE: usize = 0x4406a5f4b814f52b;
|
|
||||||
|
|
||||||
type Error = cublasError_t;
|
|
||||||
type CudaHandle = cublasLtMatmulDesc_t;
|
|
||||||
|
|
||||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
|
||||||
unsafe { hipblasLtMatmulDescDestroy(self.desc) }?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
from_cuda_object!(MatmulDesc);
|
|
||||||
|
|
||||||
pub struct MatmulPreference {
|
|
||||||
pref: hipblasLtMatmulPreference_t,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MatmulPreference {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
pref: unsafe { mem::zeroed() },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ZludaObject for MatmulPreference {
|
|
||||||
const COOKIE: usize = 0x6a6d1c41958baa9b;
|
|
||||||
|
|
||||||
type Error = cublasError_t;
|
|
||||||
type CudaHandle = cublasLtMatmulPreference_t;
|
|
||||||
|
|
||||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
|
||||||
unsafe { hipblasLtMatmulPreferenceDestroy(self.pref) }?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
from_cuda_object!(MatmulPreference);
|
|
||||||
|
|
||||||
pub struct MatrixLayout {
|
|
||||||
layout: hipblasLtMatrixLayout_t,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MatrixLayout {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
layout: unsafe { mem::zeroed() },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ZludaObject for MatrixLayout {
|
|
||||||
const COOKIE: usize = 0xcf566e9656cec9b8;
|
|
||||||
|
|
||||||
type Error = cublasError_t;
|
|
||||||
type CudaHandle = cublasLtMatrixLayout_t;
|
|
||||||
|
|
||||||
fn drop_checked(&mut self) -> cublasStatus_t {
|
|
||||||
unsafe { hipblasLtMatrixLayoutDestroy(self.layout) }?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
from_cuda_object!(MatrixLayout);
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
pub(crate) fn unimplemented() -> cublasStatus_t {
|
pub(crate) fn unimplemented() -> cublasStatus_t {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
|
@ -159,17 +81,17 @@ fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t {
|
||||||
|
|
||||||
pub(crate) fn matmul(
|
pub(crate) fn matmul(
|
||||||
light_handle: &Handle,
|
light_handle: &Handle,
|
||||||
compute_desc: &MatmulDesc,
|
compute_desc: hipblasLtMatmulDesc_t,
|
||||||
alpha: *const ::core::ffi::c_void,
|
alpha: *const ::core::ffi::c_void,
|
||||||
a: *const ::core::ffi::c_void,
|
a: *const ::core::ffi::c_void,
|
||||||
a_desc: &MatrixLayout,
|
a_desc: hipblasLtMatrixLayout_t,
|
||||||
b: *const ::core::ffi::c_void,
|
b: *const ::core::ffi::c_void,
|
||||||
b_desc: &MatrixLayout,
|
b_desc: hipblasLtMatrixLayout_t,
|
||||||
beta: *const ::core::ffi::c_void,
|
beta: *const ::core::ffi::c_void,
|
||||||
c: *const ::core::ffi::c_void,
|
c: *const ::core::ffi::c_void,
|
||||||
c_desc: &MatrixLayout,
|
c_desc: hipblasLtMatrixLayout_t,
|
||||||
d: *mut ::core::ffi::c_void,
|
d: *mut ::core::ffi::c_void,
|
||||||
d_desc: &MatrixLayout,
|
d_desc: hipblasLtMatrixLayout_t,
|
||||||
algo: hipblasLtMatmulAlgo_t,
|
algo: hipblasLtMatmulAlgo_t,
|
||||||
workspace: *mut ::core::ffi::c_void,
|
workspace: *mut ::core::ffi::c_void,
|
||||||
workspace_size_in_bytes: usize,
|
workspace_size_in_bytes: usize,
|
||||||
|
@ -178,17 +100,17 @@ pub(crate) fn matmul(
|
||||||
unsafe {
|
unsafe {
|
||||||
hipblasLtMatmul(
|
hipblasLtMatmul(
|
||||||
light_handle.handle,
|
light_handle.handle,
|
||||||
compute_desc.desc,
|
compute_desc,
|
||||||
alpha,
|
alpha,
|
||||||
a,
|
a,
|
||||||
a_desc.layout,
|
a_desc,
|
||||||
b,
|
b,
|
||||||
b_desc.layout,
|
b_desc,
|
||||||
beta,
|
beta,
|
||||||
c,
|
c,
|
||||||
c_desc.layout,
|
c_desc,
|
||||||
d,
|
d,
|
||||||
d_desc.layout,
|
d_desc,
|
||||||
&algo,
|
&algo,
|
||||||
workspace,
|
workspace,
|
||||||
workspace_size_in_bytes,
|
workspace_size_in_bytes,
|
||||||
|
@ -200,12 +122,12 @@ pub(crate) fn matmul(
|
||||||
|
|
||||||
pub(crate) fn matmul_algo_get_heuristic(
|
pub(crate) fn matmul_algo_get_heuristic(
|
||||||
light_handle: &Handle,
|
light_handle: &Handle,
|
||||||
operation_desc: &MatmulDesc,
|
operation_desc: hipblasLtMatmulDesc_t,
|
||||||
a_desc: &MatrixLayout,
|
a_desc: hipblasLtMatrixLayout_t,
|
||||||
b_desc: &MatrixLayout,
|
b_desc: hipblasLtMatrixLayout_t,
|
||||||
c_desc: &MatrixLayout,
|
c_desc: hipblasLtMatrixLayout_t,
|
||||||
d_desc: &MatrixLayout,
|
d_desc: hipblasLtMatrixLayout_t,
|
||||||
preference: &MatmulPreference,
|
preference: hipblasLtMatmulPreference_t,
|
||||||
requested_algo_count: ::core::ffi::c_int,
|
requested_algo_count: ::core::ffi::c_int,
|
||||||
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
|
heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t,
|
||||||
return_algo_count: &mut ::core::ffi::c_int,
|
return_algo_count: &mut ::core::ffi::c_int,
|
||||||
|
@ -214,12 +136,12 @@ pub(crate) fn matmul_algo_get_heuristic(
|
||||||
unsafe {
|
unsafe {
|
||||||
hipblasLtMatmulAlgoGetHeuristic(
|
hipblasLtMatmulAlgoGetHeuristic(
|
||||||
light_handle.handle,
|
light_handle.handle,
|
||||||
operation_desc.desc,
|
operation_desc,
|
||||||
a_desc.layout,
|
a_desc,
|
||||||
b_desc.layout,
|
b_desc,
|
||||||
c_desc.layout,
|
c_desc,
|
||||||
d_desc.layout,
|
d_desc,
|
||||||
preference.pref,
|
preference,
|
||||||
requested_algo_count,
|
requested_algo_count,
|
||||||
hip_algos.as_mut_ptr(),
|
hip_algos.as_mut_ptr(),
|
||||||
return_algo_count,
|
return_algo_count,
|
||||||
|
@ -244,22 +166,21 @@ pub(crate) fn matmul_algo_get_heuristic(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_desc_create(
|
pub(crate) fn matmul_desc_create(
|
||||||
matmul_desc: &mut cublasLtMatmulDesc_t,
|
matmul_desc: &mut hipblasLtMatmulDesc_t,
|
||||||
compute_type: hipblasComputeType_t,
|
compute_type: hipblasComputeType_t,
|
||||||
scale_type: hipDataType,
|
scale_type: hipDataType,
|
||||||
) -> cublasStatus_t {
|
) -> cublasStatus_t {
|
||||||
let mut zluda_blaslt_desc = MatmulDesc::new();
|
unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?;
|
||||||
unsafe { hipblasLtMatmulDescCreate(&mut zluda_blaslt_desc.desc, compute_type, scale_type) }?;
|
|
||||||
*matmul_desc = MatmulDesc::wrap(zluda_blaslt_desc);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_desc_destroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t {
|
pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t {
|
||||||
zluda_common::drop_checked::<MatmulDesc>(matmul_desc)
|
unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_desc_set_attribute(
|
pub(crate) fn matmul_desc_set_attribute(
|
||||||
matmul_desc: &MatmulDesc,
|
matmul_desc: hipblasLtMatmulDesc_t,
|
||||||
attr: cublasLtMatmulDescAttributes_t,
|
attr: cublasLtMatmulDescAttributes_t,
|
||||||
buf: *const ::core::ffi::c_void,
|
buf: *const ::core::ffi::c_void,
|
||||||
size_in_bytes: usize,
|
size_in_bytes: usize,
|
||||||
|
@ -300,16 +221,16 @@ pub(crate) fn matmul_desc_set_attribute(
|
||||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
|
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => {
|
||||||
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
|
convert_and_set_attribute::<cudaDataType, hipDataType>(matmul_desc, buf, hip_attr)?
|
||||||
}
|
}
|
||||||
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => unsafe {
|
hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => {
|
||||||
hipblasLtMatmulDescSetAttribute(matmul_desc.desc, hip_attr, buf, size_in_bytes)
|
unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }?
|
||||||
}?,
|
}
|
||||||
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
|
_ => return cublasStatus_t::ERROR_NOT_SUPPORTED,
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
|
fn convert_and_set_attribute<'a, CudaType: 'a, HipType>(
|
||||||
matmul_desc: &MatmulDesc,
|
matmul_desc: hipblasLtMatmulDesc_t,
|
||||||
buf: *const std::ffi::c_void,
|
buf: *const std::ffi::c_void,
|
||||||
hip_attr: hipblasLtMatmulDescAttributes_t,
|
hip_attr: hipblasLtMatmulDescAttributes_t,
|
||||||
) -> Result<(), cublasError_t>
|
) -> Result<(), cublasError_t>
|
||||||
|
@ -322,7 +243,7 @@ where
|
||||||
let hip_buf: *const HipType = &hip_operation;
|
let hip_buf: *const HipType = &hip_operation;
|
||||||
unsafe {
|
unsafe {
|
||||||
hipblasLtMatmulDescSetAttribute(
|
hipblasLtMatmulDescSetAttribute(
|
||||||
matmul_desc.desc,
|
matmul_desc,
|
||||||
hip_attr,
|
hip_attr,
|
||||||
hip_buf.cast(),
|
hip_buf.cast(),
|
||||||
mem::size_of::<HipType>(),
|
mem::size_of::<HipType>(),
|
||||||
|
@ -331,50 +252,48 @@ where
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_preference_create(pref: &mut cublasLtMatmulPreference_t) -> cublasStatus_t {
|
pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||||
let mut zluda_matmul_pref = MatmulPreference::new();
|
unsafe { hipblasLtMatmulPreferenceCreate(pref) }?;
|
||||||
unsafe { hipblasLtMatmulPreferenceCreate(&mut zluda_matmul_pref.pref) }?;
|
|
||||||
*pref = MatmulPreference::wrap(zluda_matmul_pref);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_preference_destroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t {
|
pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t {
|
||||||
zluda_common::drop_checked::<MatmulPreference>(pref)
|
unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matmul_preference_set_attribute(
|
pub(crate) fn matmul_preference_set_attribute(
|
||||||
pref: &MatmulPreference,
|
pref: hipblasLtMatmulPreference_t,
|
||||||
attr: hipblasLtMatmulPreferenceAttributes_t,
|
attr: hipblasLtMatmulPreferenceAttributes_t,
|
||||||
buf: *const ::core::ffi::c_void,
|
buf: *const ::core::ffi::c_void,
|
||||||
size_in_bytes: usize,
|
size_in_bytes: usize,
|
||||||
) -> cublasStatus_t {
|
) -> cublasStatus_t {
|
||||||
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref.pref, attr, buf, size_in_bytes) }?;
|
unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matrix_layout_create(
|
pub(crate) fn matrix_layout_create(
|
||||||
mat_layout: &mut cublasLtMatrixLayout_t,
|
mat_layout: &mut hipblasLtMatrixLayout_t,
|
||||||
type_: hipDataType,
|
type_: hipDataType,
|
||||||
rows: u64,
|
rows: u64,
|
||||||
cols: u64,
|
cols: u64,
|
||||||
ld: i64,
|
ld: i64,
|
||||||
) -> cublasStatus_t {
|
) -> cublasStatus_t {
|
||||||
let mut zluda_matrix_layout = MatrixLayout::new();
|
unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?;
|
||||||
unsafe { hipblasLtMatrixLayoutCreate(&mut zluda_matrix_layout.layout, type_, rows, cols, ld) }?;
|
|
||||||
*mat_layout = MatrixLayout::wrap(zluda_matrix_layout);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matrix_layout_destroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t {
|
pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t {
|
||||||
zluda_common::drop_checked::<MatrixLayout>(mat_layout)
|
unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn matrix_layout_set_attribute(
|
pub(crate) fn matrix_layout_set_attribute(
|
||||||
mat_layout: &MatrixLayout,
|
mat_layout: hipblasLtMatrixLayout_t,
|
||||||
attr: hipblasLtMatrixLayoutAttribute_t,
|
attr: hipblasLtMatrixLayoutAttribute_t,
|
||||||
buf: *const ::core::ffi::c_void,
|
buf: *const ::core::ffi::c_void,
|
||||||
size_in_bytes: usize,
|
size_in_bytes: usize,
|
||||||
) -> cublasStatus_t {
|
) -> cublasStatus_t {
|
||||||
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout.layout, attr, buf, size_in_bytes) }?;
|
unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -190,7 +190,10 @@ from_cuda_transmute!(
|
||||||
CUstreamCaptureMode => hipStreamCaptureMode,
|
CUstreamCaptureMode => hipStreamCaptureMode,
|
||||||
CUgraphNode => hipGraphNode_t,
|
CUgraphNode => hipGraphNode_t,
|
||||||
CUgraphExec => hipGraphExec_t,
|
CUgraphExec => hipGraphExec_t,
|
||||||
CUkernel => hipFunction_t
|
CUkernel => hipFunction_t,
|
||||||
|
cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t,
|
||||||
|
cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t,
|
||||||
|
cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t
|
||||||
);
|
);
|
||||||
|
|
||||||
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
|
impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue