Some fixes to BLASLt (#482)
Some checks failed
ZLUDA / Build (Linux) (push) Has been cancelled
ZLUDA / Build (Windows) (push) Has been cancelled
ZLUDA / Build AMD GPU unit tests (push) Has been cancelled
ZLUDA / Run AMD GPU unit tests (push) Has been cancelled

This commit is contained in:
Andrzej Janik 2025-08-26 23:28:36 +02:00 committed by GitHub
commit 3632f2bf03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 2078 additions and 2248 deletions

1
Cargo.lock generated
View file

@ -3766,6 +3766,7 @@ version = "0.0.0"
dependencies = [
"cuda_macros",
"cuda_types",
"zluda_common",
]
[[package]]

File diff suppressed because it is too large Load diff

View file

@ -50,8 +50,6 @@ cuda_macros::cublas_function_declarations!(
cublasDestroy_v2,
cublasGemmEx,
cublasGetMathMode,
cublasLtCreate,
cublasLtDestroy,
cublasSetMathMode,
cublasSetStream_v2,
cublasSetWorkspace_v2,

View file

@ -10,6 +10,7 @@ name = "cublaslt"
[dependencies]
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
zluda_common = { path = "../zluda_common" }
[package.metadata.zluda]
linux_symlinks = [

View file

@ -1,4 +1,22 @@
use cuda_types::cublas::*;
use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t};
use zluda_common::{from_cuda_object, ZludaObject};
pub struct Handle {
_handle: usize,
}
impl ZludaObject for Handle {
const COOKIE: usize = 0x49dec801578301ee;
type Error = cublasError_t;
type CudaHandle = cublasLtHandle_t;
fn drop_checked(&mut self) -> cublasStatus_t {
Ok(())
}
}
from_cuda_object!(Handle);
#[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> cublasStatus_t {
@ -10,15 +28,11 @@ pub(crate) fn unimplemented() -> cublasStatus_t {
cublasStatus_t::ERROR_NOT_SUPPORTED
}
pub(crate) fn get_status_name(
_status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_name(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
todo!()
}
pub(crate) fn get_status_string(
_status: cuda_types::cublas::cublasStatus_t,
) -> *const ::core::ffi::c_char {
pub(crate) fn get_status_string(_status: cublasStatus_t) -> *const ::core::ffi::c_char {
todo!()
}
@ -30,7 +44,15 @@ pub(crate) fn get_cudart_version() -> usize {
todo!()
}
#[allow(non_snake_case)]
pub(crate) fn disable_cpu_instructions_set_mask(_mask: ::core::ffi::c_uint) -> ::core::ffi::c_uint {
todo!()
}
pub(crate) fn create(handle: &mut cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
*handle = Handle { _handle: 0 }.wrap();
Ok(())
}
pub(crate) fn destroy(handle: cuda_types::cublaslt::cublasLtHandle_t) -> cublasStatus_t {
zluda_common::drop_checked::<Handle>(handle)
}

View file

@ -14,6 +14,20 @@ macro_rules! unimplemented {
}
macro_rules! implemented {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$(
#[cfg_attr(not(test), no_mangle)]
#[allow(improper_ctypes)]
#[allow(improper_ctypes_definitions)]
pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
cuda_macros::cublaslt_normalize_fn!( crate::r#impl::$fn_name ) ($(zluda_common::FromCuda::<_, cuda_types::cublas::cublasError_t>::from_cuda(&$arg_id)?),*)?;
Ok(())
}
)*
};
}
macro_rules! implemented_unmapped {
($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:ty;)*) => {
$(
#[cfg_attr(not(test), no_mangle)]
@ -28,12 +42,13 @@ macro_rules! implemented {
cuda_macros::cublaslt_function_declarations!(
unimplemented,
implemented
implemented <= [cublasLtCreate, cublasLtDestroy,],
implemented_unmapped
<= [
cublasLtDisableCpuInstructionsSetMask,
cublasLtGetCudartVersion,
cublasLtGetStatusName,
cublasLtGetStatusString,
cublasLtDisableCpuInstructionsSetMask,
cublasLtGetVersion,
cublasLtGetCudartVersion
]
);

View file

@ -1,4 +1,4 @@
use cuda_types::{cublas::*, cuda::*, nvml::*};
use cuda_types::{cublas::*, cublaslt::cublasLtHandle_t, cuda::*, nvml::*};
use hip_runtime_sys::*;
use rocblas_sys::*;
use std::{
@ -156,7 +156,8 @@ from_cuda_nop!(
cublasMath_t,
nvmlDevice_t,
nvmlFieldValue_t,
nvmlGpuFabricInfo_t
nvmlGpuFabricInfo_t,
cublasLtHandle_t
);
from_cuda_transmute!(
CUuuid => hipUUID,