diff --git a/zluda_blaslt/src/impl.rs b/zluda_blaslt/src/impl.rs index c153538..adf18a2 100644 --- a/zluda_blaslt/src/impl.rs +++ b/zluda_blaslt/src/impl.rs @@ -29,84 +29,6 @@ impl ZludaObject for Handle { from_cuda_object!(Handle); -pub struct MatmulDesc { - desc: hipblasLtMatmulDesc_t, -} - -impl MatmulDesc { - fn new() -> Self { - Self { - desc: unsafe { mem::zeroed() }, - } - } -} - -impl ZludaObject for MatmulDesc { - const COOKIE: usize = 0x4406a5f4b814f52b; - - type Error = cublasError_t; - type CudaHandle = cublasLtMatmulDesc_t; - - fn drop_checked(&mut self) -> cublasStatus_t { - unsafe { hipblasLtMatmulDescDestroy(self.desc) }?; - Ok(()) - } -} - -from_cuda_object!(MatmulDesc); - -pub struct MatmulPreference { - pref: hipblasLtMatmulPreference_t, -} - -impl MatmulPreference { - fn new() -> Self { - Self { - pref: unsafe { mem::zeroed() }, - } - } -} - -impl ZludaObject for MatmulPreference { - const COOKIE: usize = 0x6a6d1c41958baa9b; - - type Error = cublasError_t; - type CudaHandle = cublasLtMatmulPreference_t; - - fn drop_checked(&mut self) -> cublasStatus_t { - unsafe { hipblasLtMatmulPreferenceDestroy(self.pref) }?; - Ok(()) - } -} - -from_cuda_object!(MatmulPreference); - -pub struct MatrixLayout { - layout: hipblasLtMatrixLayout_t, -} - -impl MatrixLayout { - fn new() -> Self { - Self { - layout: unsafe { mem::zeroed() }, - } - } -} - -impl ZludaObject for MatrixLayout { - const COOKIE: usize = 0xcf566e9656cec9b8; - - type Error = cublasError_t; - type CudaHandle = cublasLtMatrixLayout_t; - - fn drop_checked(&mut self) -> cublasStatus_t { - unsafe { hipblasLtMatrixLayoutDestroy(self.layout) }?; - Ok(()) - } -} - -from_cuda_object!(MatrixLayout); - #[cfg(debug_assertions)] pub(crate) fn unimplemented() -> cublasStatus_t { unimplemented!() @@ -159,17 +81,17 @@ fn cuda_algo_from_hip(hip: hipblasLtMatmulAlgo_t) -> cublasLtMatmulAlgo_t { pub(crate) fn matmul( light_handle: &Handle, - compute_desc: &MatmulDesc, + compute_desc: hipblasLtMatmulDesc_t, alpha: *const ::core::ffi::c_void, a: *const ::core::ffi::c_void, - a_desc: &MatrixLayout, + a_desc: hipblasLtMatrixLayout_t, b: *const ::core::ffi::c_void, - b_desc: &MatrixLayout, + b_desc: hipblasLtMatrixLayout_t, beta: *const ::core::ffi::c_void, c: *const ::core::ffi::c_void, - c_desc: &MatrixLayout, + c_desc: hipblasLtMatrixLayout_t, d: *mut ::core::ffi::c_void, - d_desc: &MatrixLayout, + d_desc: hipblasLtMatrixLayout_t, algo: hipblasLtMatmulAlgo_t, workspace: *mut ::core::ffi::c_void, workspace_size_in_bytes: usize, @@ -178,17 +100,17 @@ pub(crate) fn matmul( unsafe { hipblasLtMatmul( light_handle.handle, - compute_desc.desc, + compute_desc, alpha, a, - a_desc.layout, + a_desc, b, - b_desc.layout, + b_desc, beta, c, - c_desc.layout, + c_desc, d, - d_desc.layout, + d_desc, &algo, workspace, workspace_size_in_bytes, @@ -200,12 +122,12 @@ pub(crate) fn matmul( pub(crate) fn matmul_algo_get_heuristic( light_handle: &Handle, - operation_desc: &MatmulDesc, - a_desc: &MatrixLayout, - b_desc: &MatrixLayout, - c_desc: &MatrixLayout, - d_desc: &MatrixLayout, - preference: &MatmulPreference, + operation_desc: hipblasLtMatmulDesc_t, + a_desc: hipblasLtMatrixLayout_t, + b_desc: hipblasLtMatrixLayout_t, + c_desc: hipblasLtMatrixLayout_t, + d_desc: hipblasLtMatrixLayout_t, + preference: hipblasLtMatmulPreference_t, requested_algo_count: ::core::ffi::c_int, heuristic_results_array: &mut cublasLtMatmulHeuristicResult_t, return_algo_count: &mut ::core::ffi::c_int, @@ -214,12 +136,12 @@ pub(crate) fn matmul_algo_get_heuristic( unsafe { hipblasLtMatmulAlgoGetHeuristic( light_handle.handle, - operation_desc.desc, - a_desc.layout, - b_desc.layout, - c_desc.layout, - d_desc.layout, - preference.pref, + operation_desc, + a_desc, + b_desc, + c_desc, + d_desc, + preference, requested_algo_count, hip_algos.as_mut_ptr(), return_algo_count, @@ -244,22 +166,21 @@ pub(crate) fn matmul_algo_get_heuristic( } pub(crate) fn matmul_desc_create( - matmul_desc: &mut cublasLtMatmulDesc_t, + matmul_desc: &mut hipblasLtMatmulDesc_t, compute_type: hipblasComputeType_t, scale_type: hipDataType, ) -> cublasStatus_t { - let mut zluda_blaslt_desc = MatmulDesc::new(); - unsafe { hipblasLtMatmulDescCreate(&mut zluda_blaslt_desc.desc, compute_type, scale_type) }?; - *matmul_desc = MatmulDesc::wrap(zluda_blaslt_desc); + unsafe { hipblasLtMatmulDescCreate(matmul_desc, compute_type, scale_type) }?; Ok(()) } -pub(crate) fn matmul_desc_destroy(matmul_desc: cublasLtMatmulDesc_t) -> cublasStatus_t { - zluda_common::drop_checked::(matmul_desc) +pub(crate) fn matmul_desc_destroy(matmul_desc: hipblasLtMatmulDesc_t) -> cublasStatus_t { + unsafe { hipblasLtMatmulDescDestroy(matmul_desc) }?; + Ok(()) } pub(crate) fn matmul_desc_set_attribute( - matmul_desc: &MatmulDesc, + matmul_desc: hipblasLtMatmulDesc_t, attr: cublasLtMatmulDescAttributes_t, buf: *const ::core::ffi::c_void, size_in_bytes: usize, @@ -300,16 +221,16 @@ pub(crate) fn matmul_desc_set_attribute( hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE => { convert_and_set_attribute::(matmul_desc, buf, hip_attr)? } - hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => unsafe { - hipblasLtMatmulDescSetAttribute(matmul_desc.desc, hip_attr, buf, size_in_bytes) - }?, + hipblasLtMatmulDescAttributes_t::HIPBLASLT_MATMUL_DESC_BIAS_POINTER => { + unsafe { hipblasLtMatmulDescSetAttribute(matmul_desc, hip_attr, buf, size_in_bytes) }? + } _ => return cublasStatus_t::ERROR_NOT_SUPPORTED, } Ok(()) } fn convert_and_set_attribute<'a, CudaType: 'a, HipType>( - matmul_desc: &MatmulDesc, + matmul_desc: hipblasLtMatmulDesc_t, buf: *const std::ffi::c_void, hip_attr: hipblasLtMatmulDescAttributes_t, ) -> Result<(), cublasError_t> @@ -322,7 +243,7 @@ where let hip_buf: *const HipType = &hip_operation; unsafe { hipblasLtMatmulDescSetAttribute( - matmul_desc.desc, + matmul_desc, hip_attr, hip_buf.cast(), mem::size_of::(), @@ -331,50 +252,48 @@ where Ok(()) } -pub(crate) fn matmul_preference_create(pref: &mut cublasLtMatmulPreference_t) -> cublasStatus_t { - let mut zluda_matmul_pref = MatmulPreference::new(); - unsafe { hipblasLtMatmulPreferenceCreate(&mut zluda_matmul_pref.pref) }?; - *pref = MatmulPreference::wrap(zluda_matmul_pref); +pub(crate) fn matmul_preference_create(pref: &mut hipblasLtMatmulPreference_t) -> cublasStatus_t { + unsafe { hipblasLtMatmulPreferenceCreate(pref) }?; Ok(()) } -pub(crate) fn matmul_preference_destroy(pref: cublasLtMatmulPreference_t) -> cublasStatus_t { - zluda_common::drop_checked::(pref) +pub(crate) fn matmul_preference_destroy(pref: hipblasLtMatmulPreference_t) -> cublasStatus_t { + unsafe { hipblasLtMatmulPreferenceDestroy(pref) }?; + Ok(()) } pub(crate) fn matmul_preference_set_attribute( - pref: &MatmulPreference, + pref: hipblasLtMatmulPreference_t, attr: hipblasLtMatmulPreferenceAttributes_t, buf: *const ::core::ffi::c_void, size_in_bytes: usize, ) -> cublasStatus_t { - unsafe { hipblasLtMatmulPreferenceSetAttribute(pref.pref, attr, buf, size_in_bytes) }?; + unsafe { hipblasLtMatmulPreferenceSetAttribute(pref, attr, buf, size_in_bytes) }?; Ok(()) } pub(crate) fn matrix_layout_create( - mat_layout: &mut cublasLtMatrixLayout_t, + mat_layout: &mut hipblasLtMatrixLayout_t, type_: hipDataType, rows: u64, cols: u64, ld: i64, ) -> cublasStatus_t { - let mut zluda_matrix_layout = MatrixLayout::new(); - unsafe { hipblasLtMatrixLayoutCreate(&mut zluda_matrix_layout.layout, type_, rows, cols, ld) }?; - *mat_layout = MatrixLayout::wrap(zluda_matrix_layout); + unsafe { hipblasLtMatrixLayoutCreate(mat_layout, type_, rows, cols, ld) }?; Ok(()) } -pub(crate) fn matrix_layout_destroy(mat_layout: cublasLtMatrixLayout_t) -> cublasStatus_t { - zluda_common::drop_checked::(mat_layout) +pub(crate) fn matrix_layout_destroy(mat_layout: hipblasLtMatrixLayout_t) -> cublasStatus_t { + unsafe { hipblasLtMatrixLayoutDestroy(mat_layout) }?; + Ok(()) } pub(crate) fn matrix_layout_set_attribute( - mat_layout: &MatrixLayout, + mat_layout: hipblasLtMatrixLayout_t, attr: hipblasLtMatrixLayoutAttribute_t, buf: *const ::core::ffi::c_void, size_in_bytes: usize, ) -> cublasStatus_t { - unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout.layout, attr, buf, size_in_bytes) }?; + unsafe { hipblasLtMatrixLayoutSetAttribute(mat_layout, attr, buf, size_in_bytes) }?; Ok(()) } diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index 437ed0d..4f8aef7 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -190,7 +190,10 @@ from_cuda_transmute!( CUstreamCaptureMode => hipStreamCaptureMode, CUgraphNode => hipGraphNode_t, CUgraphExec => hipGraphExec_t, - CUkernel => hipFunction_t + CUkernel => hipFunction_t, + cublasLtMatmulDesc_t => hipblasLtMatmulDesc_t, + cublasLtMatmulPreference_t => hipblasLtMatmulPreference_t, + cublasLtMatrixLayout_t => hipblasLtMatrixLayout_t ); impl<'a, E: CudaErrorType> FromCuda<'a, CUlimit, E> for hipLimit_t {