From 5185138596338ff9aac05218c8cecf897fd4f8ca Mon Sep 17 00:00:00 2001 From: Violet Date: Tue, 16 Sep 2025 16:23:15 -0700 Subject: [PATCH] Create bindings for hipblasLt (#510) Generate bindings for hipblasLt and make some changes to the bindings for cublasLt. Notably, the `hip_type` `Option` is changed to a `Vec`, so that multiple `From` implementations (for `rocblas_error` and `hipblasLtError`) can be created for `cublasError_t`. --- .github/workflows/rocm_setup_build.sh | 2 +- Cargo.lock | 8 + cuda_macros/src/cublaslt.rs | 8 +- cuda_types/Cargo.toml | 1 + cuda_types/src/cublas.rs | 5 + cuda_types/src/cublaslt.rs | 272 ----- ext/hipblaslt-sys/.rustfmt.toml | 1 + ext/hipblaslt-sys/Cargo.toml | 10 + ext/hipblaslt-sys/build.rs | 9 + ext/hipblaslt-sys/src/lib.rs | 1439 +++++++++++++++++++++++++ ext/rocm_smi-sys/.rustfmt.toml | 1 + format/src/format_generated_blaslt.rs | 362 +------ zluda_bindgen/src/main.rs | 131 ++- 13 files changed, 1589 insertions(+), 660 deletions(-) create mode 100644 ext/hipblaslt-sys/.rustfmt.toml create mode 100644 ext/hipblaslt-sys/Cargo.toml create mode 100644 ext/hipblaslt-sys/build.rs create mode 100644 ext/hipblaslt-sys/src/lib.rs create mode 100644 ext/rocm_smi-sys/.rustfmt.toml diff --git a/.github/workflows/rocm_setup_build.sh b/.github/workflows/rocm_setup_build.sh index f66c34e..0c27ae5 100644 --- a/.github/workflows/rocm_setup_build.sh +++ b/.github/workflows/rocm_setup_build.sh @@ -11,7 +11,7 @@ echo deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.c echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' \ | tee /etc/apt/preferences.d/rocm-pin-600 DEBIAN_FRONTEND=noninteractive apt update -y -DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends rocm-smi-lib rocm-llvm-dev hip-runtime-amd hip-dev rocblas-dev +DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends rocm-smi-lib rocm-llvm-dev hip-runtime-amd hip-dev rocblas-dev hipblaslt-dev echo 'export PATH="$PATH:/opt/rocm/bin"' | tee /etc/profile.d/rocm.sh echo "/opt/rocm/lib" | tee /etc/ld.so.conf.d/rocm.conf ldconfig diff --git a/Cargo.lock b/Cargo.lock index 3f883dd..32cf398 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -431,6 +431,7 @@ dependencies = [ "bitflags 2.9.1", "cuda_macros", "hip_runtime-sys", + "hipblaslt-sys", "rocblas-sys", "rocm_smi-sys", ] @@ -1773,6 +1774,13 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" name = "hip_runtime-sys" version = "0.0.0" +[[package]] +name = "hipblaslt-sys" +version = "0.1.0" +dependencies = [ + "hip_runtime-sys", +] + [[package]] name = "home" version = "0.5.11" diff --git a/cuda_macros/src/cublaslt.rs b/cuda_macros/src/cublaslt.rs index be04c7f..09e27f4 100644 --- a/cuda_macros/src/cublaslt.rs +++ b/cuda_macros/src/cublaslt.rs @@ -171,7 +171,7 @@ extern "system" { fn cublasLtMatmulDescInit_internal( matmulDesc: cuda_types::cublaslt::cublasLtMatmulDesc_t, size: usize, - computeType: cuda_types::cublaslt::cublasComputeType_t, + computeType: cuda_types::cublas::cublasComputeType_t, scaleType: cuda_types::cublaslt::cudaDataType_t, ) -> cuda_types::cublas::cublasStatus_t; #[must_use] @@ -181,7 +181,7 @@ extern "system" { \retval CUBLAS_STATUS_SUCCESS if descriptor was created successfully*/ fn cublasLtMatmulDescCreate( matmulDesc: *mut cuda_types::cublaslt::cublasLtMatmulDesc_t, - computeType: cuda_types::cublaslt::cublasComputeType_t, + computeType: cuda_types::cublas::cublasComputeType_t, scaleType: cuda_types::cublaslt::cudaDataType_t, ) -> 
cuda_types::cublas::cublasStatus_t; #[must_use] @@ -396,7 +396,7 @@ extern "system" { available*/ fn cublasLtMatmulAlgoGetIds( lightHandle: cuda_types::cublaslt::cublasLtHandle_t, - computeType: cuda_types::cublaslt::cublasComputeType_t, + computeType: cuda_types::cublas::cublasComputeType_t, scaleType: cuda_types::cublaslt::cudaDataType_t, Atype: cuda_types::cublaslt::cudaDataType_t, Btype: cuda_types::cublaslt::cudaDataType_t, @@ -414,7 +414,7 @@ extern "system" { \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized*/ fn cublasLtMatmulAlgoInit( lightHandle: cuda_types::cublaslt::cublasLtHandle_t, - computeType: cuda_types::cublaslt::cublasComputeType_t, + computeType: cuda_types::cublas::cublasComputeType_t, scaleType: cuda_types::cublaslt::cudaDataType_t, Atype: cuda_types::cublaslt::cudaDataType_t, Btype: cuda_types::cublaslt::cudaDataType_t, diff --git a/cuda_types/Cargo.toml b/cuda_types/Cargo.toml index 66320b9..47d004e 100644 --- a/cuda_types/Cargo.toml +++ b/cuda_types/Cargo.toml @@ -9,6 +9,7 @@ cuda_macros = { path = "../cuda_macros" } hip_runtime-sys = { path = "../ext/hip_runtime-sys" } bitflags = "2.9.1" rocblas-sys = { path = "../ext/rocblas-sys" } +hipblaslt-sys = { path = "../ext/hipblaslt-sys" } [target.'cfg(unix)'.dependencies] rocm_smi-sys = { path = "../ext/rocm_smi-sys" } diff --git a/cuda_types/src/cublas.rs b/cuda_types/src/cublas.rs index cb9190d..9aa3d99 100644 --- a/cuda_types/src/cublas.rs +++ b/cuda_types/src/cublas.rs @@ -363,3 +363,8 @@ impl From<rocblas_sys::rocblas_error> for cublasError_t { Self(error.0) } } +impl From<hipblaslt_sys::hipblasLtError> for cublasError_t { + fn from(error: hipblaslt_sys::hipblasLtError) -> Self { + Self(error.0) + } +} diff --git a/cuda_types/src/cublaslt.rs b/cuda_types/src/cublaslt.rs index c89b69e..4af4c7f 100644 --- a/cuda_types/src/cublaslt.rs +++ b/cuda_types/src/cublaslt.rs @@ -33,277 +33,6 @@ pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3: u32 = 4194304; pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2: u32 = 8388608; pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK: u32 = 16711680; pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN: u64 = 4294967296; -impl cublasFillMode_t { - pub const CUBLAS_FILL_MODE_LOWER: cublasFillMode_t = cublasFillMode_t(0); -} -impl cublasFillMode_t { - pub const CUBLAS_FILL_MODE_UPPER: cublasFillMode_t = cublasFillMode_t(1); -} -impl cublasFillMode_t { - pub const CUBLAS_FILL_MODE_FULL: cublasFillMode_t = cublasFillMode_t(2); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasFillMode_t(pub ::core::ffi::c_uint); -impl cublasDiagType_t { - pub const CUBLAS_DIAG_NON_UNIT: cublasDiagType_t = cublasDiagType_t(0); -} -impl cublasDiagType_t { - pub const CUBLAS_DIAG_UNIT: cublasDiagType_t = cublasDiagType_t(1); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasDiagType_t(pub ::core::ffi::c_uint); -impl cublasSideMode_t { - pub const CUBLAS_SIDE_LEFT: cublasSideMode_t = cublasSideMode_t(0); -} -impl cublasSideMode_t { - pub const CUBLAS_SIDE_RIGHT: cublasSideMode_t = cublasSideMode_t(1); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasSideMode_t(pub ::core::ffi::c_uint); -impl cublasOperation_t { - pub const CUBLAS_OP_N: cublasOperation_t = cublasOperation_t(0); -} -impl cublasOperation_t { - pub const CUBLAS_OP_T: cublasOperation_t = cublasOperation_t(1); -} -impl cublasOperation_t { - pub const CUBLAS_OP_C: cublasOperation_t = cublasOperation_t(2); -} -impl cublasOperation_t { - 
pub const CUBLAS_OP_HERMITAN: cublasOperation_t = cublasOperation_t(2); -} -impl cublasOperation_t { - pub const CUBLAS_OP_CONJG: cublasOperation_t = cublasOperation_t(3); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasOperation_t(pub ::core::ffi::c_uint); -impl cublasPointerMode_t { - pub const CUBLAS_POINTER_MODE_HOST: cublasPointerMode_t = cublasPointerMode_t(0); -} -impl cublasPointerMode_t { - pub const CUBLAS_POINTER_MODE_DEVICE: cublasPointerMode_t = cublasPointerMode_t(1); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasPointerMode_t(pub ::core::ffi::c_uint); -impl cublasAtomicsMode_t { - pub const CUBLAS_ATOMICS_NOT_ALLOWED: cublasAtomicsMode_t = cublasAtomicsMode_t(0); -} -impl cublasAtomicsMode_t { - pub const CUBLAS_ATOMICS_ALLOWED: cublasAtomicsMode_t = cublasAtomicsMode_t(1); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasAtomicsMode_t(pub ::core::ffi::c_uint); -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_DFALT: cublasGemmAlgo_t = cublasGemmAlgo_t(-1); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_DEFAULT: cublasGemmAlgo_t = cublasGemmAlgo_t(-1); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO0: cublasGemmAlgo_t = cublasGemmAlgo_t(0); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO1: cublasGemmAlgo_t = cublasGemmAlgo_t(1); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO2: cublasGemmAlgo_t = cublasGemmAlgo_t(2); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO3: cublasGemmAlgo_t = cublasGemmAlgo_t(3); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO4: cublasGemmAlgo_t = cublasGemmAlgo_t(4); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO5: cublasGemmAlgo_t = cublasGemmAlgo_t(5); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO6: cublasGemmAlgo_t = cublasGemmAlgo_t(6); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO7: cublasGemmAlgo_t = cublasGemmAlgo_t(7); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO8: cublasGemmAlgo_t = cublasGemmAlgo_t(8); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO9: cublasGemmAlgo_t = cublasGemmAlgo_t(9); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO10: cublasGemmAlgo_t = cublasGemmAlgo_t(10); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO11: cublasGemmAlgo_t = cublasGemmAlgo_t(11); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO12: cublasGemmAlgo_t = cublasGemmAlgo_t(12); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO13: cublasGemmAlgo_t = cublasGemmAlgo_t(13); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO14: cublasGemmAlgo_t = cublasGemmAlgo_t(14); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO15: cublasGemmAlgo_t = cublasGemmAlgo_t(15); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO16: cublasGemmAlgo_t = cublasGemmAlgo_t(16); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO17: cublasGemmAlgo_t = cublasGemmAlgo_t(17); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO18: cublasGemmAlgo_t = cublasGemmAlgo_t(18); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO19: cublasGemmAlgo_t = cublasGemmAlgo_t(19); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO20: cublasGemmAlgo_t = cublasGemmAlgo_t(20); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO21: cublasGemmAlgo_t = cublasGemmAlgo_t(21); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO22: cublasGemmAlgo_t = cublasGemmAlgo_t(22); -} -impl 
cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO23: cublasGemmAlgo_t = cublasGemmAlgo_t(23); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_DEFAULT_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(99); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_DFALT_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(99); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO0_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(100); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO1_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(101); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO2_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(102); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO3_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(103); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO4_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(104); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO5_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(105); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO6_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(106); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO7_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(107); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO8_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(108); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO9_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(109); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO10_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(110); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO11_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(111); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO12_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(112); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO13_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(113); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO14_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(114); -} -impl cublasGemmAlgo_t { - pub const CUBLAS_GEMM_ALGO15_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(115); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasGemmAlgo_t(pub ::core::ffi::c_int); -impl cublasMath_t { - pub const CUBLAS_DEFAULT_MATH: cublasMath_t = cublasMath_t(0); -} -impl cublasMath_t { - pub const CUBLAS_TENSOR_OP_MATH: cublasMath_t = cublasMath_t(1); -} -impl cublasMath_t { - pub const CUBLAS_PEDANTIC_MATH: cublasMath_t = cublasMath_t(2); -} -impl cublasMath_t { - pub const CUBLAS_TF32_TENSOR_OP_MATH: cublasMath_t = cublasMath_t(3); -} -impl cublasMath_t { - pub const CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION: cublasMath_t = cublasMath_t( - 16, - ); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasMath_t(pub ::core::ffi::c_uint); -pub use super::cuda::cudaDataType as cublasDataType_t; -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_16F: cublasComputeType_t = cublasComputeType_t(64); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_16F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(65); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32F: cublasComputeType_t = cublasComputeType_t(68); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(69); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32F_FAST_16F: cublasComputeType_t = cublasComputeType_t(74); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32F_FAST_16BF: 
cublasComputeType_t = cublasComputeType_t( - 75, - ); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32F_FAST_TF32: cublasComputeType_t = cublasComputeType_t( - 77, - ); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_64F: cublasComputeType_t = cublasComputeType_t(70); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_64F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(71); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32I: cublasComputeType_t = cublasComputeType_t(72); -} -impl cublasComputeType_t { - pub const CUBLAS_COMPUTE_32I_PEDANTIC: cublasComputeType_t = cublasComputeType_t(73); -} -#[repr(transparent)] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub struct cublasComputeType_t(pub ::core::ffi::c_uint); -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct cublasContext { - _unused: [u8; 0], -} -pub type cublasHandle_t = *mut cublasContext; -pub type cublasLogCallback = ::core::option::Option< - unsafe extern "C" fn(msg: *const ::core::ffi::c_char), ->; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct cublasLtContext { @@ -5071,7 +4800,6 @@ pub struct cublasLtMatmulPreferenceAttributes_t(pub ::core::ffi::c_uint); Holds returned configured algo descriptor and its runtime properties.*/ #[repr(C)] -#[derive(Debug, Copy, Clone, PartialEq)] pub struct cublasLtMatmulHeuristicResult_t { /** Matmul algorithm descriptor. diff --git a/ext/hipblaslt-sys/.rustfmt.toml b/ext/hipblaslt-sys/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/ext/hipblaslt-sys/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/ext/hipblaslt-sys/Cargo.toml b/ext/hipblaslt-sys/Cargo.toml new file mode 100644 index 0000000..a4111eb --- /dev/null +++ b/ext/hipblaslt-sys/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hipblaslt-sys" +version = "0.1.0" +authors = ["Violet "] +edition = "2021" + +[lib] + +[dependencies] +hip_runtime-sys = { version = "0.0.0", path = "../hip_runtime-sys" } diff --git a/ext/hipblaslt-sys/build.rs b/ext/hipblaslt-sys/build.rs new file mode 100644 index 0000000..b9fe5ec --- /dev/null +++ b/ext/hipblaslt-sys/build.rs @@ -0,0 +1,9 @@ +use std::env::VarError; + +fn main() -> Result<(), VarError> { + if !cfg!(windows) { + println!("cargo:rustc-link-lib=dylib=hipblaslt"); + println!("cargo:rustc-link-search=native=/opt/rocm/lib/"); + } + Ok(()) +} diff --git a/ext/hipblaslt-sys/src/lib.rs b/ext/hipblaslt-sys/src/lib.rs new file mode 100644 index 0000000..ae4fef2 --- /dev/null +++ b/ext/hipblaslt-sys/src/lib.rs @@ -0,0 +1,1439 @@ +// Generated automatically by zluda_bindgen +// DO NOT EDIT MANUALLY +#![allow(warnings)] +impl hipblasOperation_t { + ///< Operate with the matrix. + pub const HIPBLAS_OP_N: hipblasOperation_t = hipblasOperation_t(111); +} +impl hipblasOperation_t { + ///< Operate with the transpose of the matrix. + pub const HIPBLAS_OP_T: hipblasOperation_t = hipblasOperation_t(112); +} +impl hipblasOperation_t { + ///< Operate with the conjugate transpose of the matrix. + pub const HIPBLAS_OP_C: hipblasOperation_t = hipblasOperation_t(113); +} +#[repr(transparent)] +/// \brief Used to specify whether the matrix is to be transposed or not. 
+#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct hipblasOperation_t(pub ::core::ffi::c_uint); +impl hipblasComputeType_t { + ///< compute will be at least 16-bit precision + pub const HIPBLAS_COMPUTE_16F: hipblasComputeType_t = hipblasComputeType_t(0); +} +impl hipblasComputeType_t { + ///< compute will be exactly 16-bit precision + pub const HIPBLAS_COMPUTE_16F_PEDANTIC: hipblasComputeType_t = hipblasComputeType_t( + 1, + ); +} +impl hipblasComputeType_t { + ///< compute will be at least 32-bit precision + pub const HIPBLAS_COMPUTE_32F: hipblasComputeType_t = hipblasComputeType_t(2); +} +impl hipblasComputeType_t { + ///< compute will be exactly 32-bit precision + pub const HIPBLAS_COMPUTE_32F_PEDANTIC: hipblasComputeType_t = hipblasComputeType_t( + 3, + ); +} +impl hipblasComputeType_t { + ///< 32-bit input can use 16-bit compute + pub const HIPBLAS_COMPUTE_32F_FAST_16F: hipblasComputeType_t = hipblasComputeType_t( + 4, + ); +} +impl hipblasComputeType_t { + ///< 32-bit input can use bf16 compute + pub const HIPBLAS_COMPUTE_32F_FAST_16BF: hipblasComputeType_t = hipblasComputeType_t( + 5, + ); +} +impl hipblasComputeType_t { + pub const HIPBLAS_COMPUTE_32F_FAST_TF32: hipblasComputeType_t = hipblasComputeType_t( + 6, + ); +} +impl hipblasComputeType_t { + ///< compute will be at least 64-bit precision + pub const HIPBLAS_COMPUTE_64F: hipblasComputeType_t = hipblasComputeType_t(7); +} +impl hipblasComputeType_t { + ///< compute will be exactly 64-bit precision + pub const HIPBLAS_COMPUTE_64F_PEDANTIC: hipblasComputeType_t = hipblasComputeType_t( + 8, + ); +} +impl hipblasComputeType_t { + ///< compute will be at least 32-bit integer precision + pub const HIPBLAS_COMPUTE_32I: hipblasComputeType_t = hipblasComputeType_t(9); +} +impl hipblasComputeType_t { + ///< compute will be exactly 32-bit integer precision + pub const HIPBLAS_COMPUTE_32I_PEDANTIC: hipblasComputeType_t = hipblasComputeType_t( + 10, + ); +} +#[repr(transparent)] +/** \brief The compute type to be used. Currently only used with GemmEx with the HIPBLAS_V2 interface. + Note that support for compute types is largely dependent on backend.*/ +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct hipblasComputeType_t(pub ::core::ffi::c_uint); +/// \brief Struct to represent a 16 bit brain floating point number. 
+#[repr(C)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct hip_bfloat16 { + pub data: u16, +} +impl hipDataType { + pub const HIP_R_32F: hipDataType = hipDataType(0); +} +impl hipDataType { + pub const HIP_R_64F: hipDataType = hipDataType(1); +} +impl hipDataType { + pub const HIP_R_16F: hipDataType = hipDataType(2); +} +impl hipDataType { + pub const HIP_R_8I: hipDataType = hipDataType(3); +} +impl hipDataType { + pub const HIP_C_32F: hipDataType = hipDataType(4); +} +impl hipDataType { + pub const HIP_C_64F: hipDataType = hipDataType(5); +} +impl hipDataType { + pub const HIP_C_16F: hipDataType = hipDataType(6); +} +impl hipDataType { + pub const HIP_C_8I: hipDataType = hipDataType(7); +} +impl hipDataType { + pub const HIP_R_8U: hipDataType = hipDataType(8); +} +impl hipDataType { + pub const HIP_C_8U: hipDataType = hipDataType(9); +} +impl hipDataType { + pub const HIP_R_32I: hipDataType = hipDataType(10); +} +impl hipDataType { + pub const HIP_C_32I: hipDataType = hipDataType(11); +} +impl hipDataType { + pub const HIP_R_32U: hipDataType = hipDataType(12); +} +impl hipDataType { + pub const HIP_C_32U: hipDataType = hipDataType(13); +} +impl hipDataType { + pub const HIP_R_16BF: hipDataType = hipDataType(14); +} +impl hipDataType { + pub const HIP_C_16BF: hipDataType = hipDataType(15); +} +impl hipDataType { + pub const HIP_R_4I: hipDataType = hipDataType(16); +} +impl hipDataType { + pub const HIP_C_4I: hipDataType = hipDataType(17); +} +impl hipDataType { + pub const HIP_R_4U: hipDataType = hipDataType(18); +} +impl hipDataType { + pub const HIP_C_4U: hipDataType = hipDataType(19); +} +impl hipDataType { + pub const HIP_R_16I: hipDataType = hipDataType(20); +} +impl hipDataType { + pub const HIP_C_16I: hipDataType = hipDataType(21); +} +impl hipDataType { + pub const HIP_R_16U: hipDataType = hipDataType(22); +} +impl hipDataType { + pub const HIP_C_16U: hipDataType = hipDataType(23); +} +impl hipDataType { + pub const HIP_R_64I: hipDataType = hipDataType(24); +} +impl hipDataType { + pub const HIP_C_64I: hipDataType = hipDataType(25); +} +impl hipDataType { + pub const HIP_R_64U: hipDataType = hipDataType(26); +} +impl hipDataType { + pub const HIP_C_64U: hipDataType = hipDataType(27); +} +impl hipDataType { + pub const HIP_R_8F_E4M3: hipDataType = hipDataType(28); +} +impl hipDataType { + pub const HIP_R_8F_E5M2: hipDataType = hipDataType(29); +} +impl hipDataType { + pub const HIP_R_8F_E4M3_FNUZ: hipDataType = hipDataType(1000); +} +impl hipDataType { + pub const HIP_R_8F_E5M2_FNUZ: hipDataType = hipDataType(1001); +} +#[repr(transparent)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct hipDataType(pub ::core::ffi::c_uint); +/// \brief Single precision floating point type +pub type hipblasLtFloat = f32; +/// \brief Structure definition for hipblasLtHalf +#[repr(C)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct _hipblasLtHalf { + pub data: u16, +} +/// \brief Structure definition for hipblasLtHalf +pub type hipblasLtHalf = _hipblasLtHalf; +/// \brief Struct to represent a 16 bit brain floating point number. 
+pub type hipblasLtBfloat16 = hip_bfloat16; +pub type hipblasLtInt8 = i8; +pub type hipblasLtInt32 = i32; +impl hipblasLtEpilogue_t { + /// hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + pub fn hipblasLtGetGitRevision( + handle: hipblasLtHandle_t, + rev: *mut ::core::ffi::c_char, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + pub fn hipblasLtGetArchName( + archName: *mut *mut ::core::ffi::c_char, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Create a hipblaslt handle + + \details + This function initializes the hipBLASLt library and creates a handle to an + opaque structure holding the hipBLASLt library context. It allocates light + hardware resources on the host and device, and must be called prior to making + any other hipBLASLt library calls. The hipBLASLt library context is tied to + the current ROCm device. To use the library on multiple devices, one + hipBLASLt handle should be created for each device. + + @param[out] + handle Pointer to the allocated hipBLASLt handle for the created hipBLASLt + context. + + \retval HIPBLAS_STATUS_SUCCESS The allocation completed successfully. + \retval HIPBLAS_STATUS_INVALID_VALUE \p handle == NULL.*/ + pub fn hipblasLtCreate(handle: *mut hipblasLtHandle_t) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Destroy a hipblaslt handle + + \details + This function releases hardware resources used by the hipBLASLt library. + This function is usually the last call with a particular handle to the + hipBLASLt library. Because hipblasLtCreate() allocates some internal + resources and the release of those resources by calling hipblasLtDestroy() + will implicitly call hipDeviceSynchronize(), it is recommended to minimize + the number of hipblasLtCreate()/hipblasLtDestroy() occurrences. + + @param[in] + handle Pointer to the hipBLASLt handle to be destroyed. + + \retval HIPBLAS_STATUS_SUCCESS The hipBLASLt context was successfully + destroyed. \retval HIPBLAS_STATUS_NOT_INITIALIZED The hipBLASLt library was + not initialized. \retval HIPBLAS_STATUS_INVALID_VALUE \p handle == NULL.*/ + pub fn hipblasLtDestroy(handle: hipblasLtHandle_t) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Create a matrix layout descriptor + + \details + This function creates a matrix layout descriptor by allocating the memory + needed to hold its opaque structure. + + @param[out] + matLayout Pointer to the structure holding the matrix layout descriptor + created by this function. see \ref hipblasLtMatrixLayout_t . + @param[in] + type Enumerant that specifies the data precision for the matrix layout + descriptor this function creates. See hipDataType. + @param[in] + rows Number of rows of the matrix. + @param[in] + cols Number of columns of the matrix. + @param[in] + ld The leading dimension of the matrix. In column major layout, this is the + number of elements to jump to reach the next column. Thus ld >= m (number of + rows). + + \retval HIPBLAS_STATUS_SUCCESS If the descriptor was created successfully. 
+ \retval HIPBLAS_STATUS_ALLOC_FAILED If the memory could not be allocated.*/ + pub fn hipblasLtMatrixLayoutCreate( + matLayout: *mut hipblasLtMatrixLayout_t, + type_: hipDataType, + rows: u64, + cols: u64, + ld: i64, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Destroy a matrix layout descriptor + + \details + This function destroys a previously created matrix layout descriptor object. + + @param[in] + matLayout Pointer to the structure holding the matrix layout descriptor that + should be destroyed by this function. see \ref hipblasLtMatrixLayout_t . + + \retval HIPBLAS_STATUS_SUCCESS If the operation was successful.*/ + pub fn hipblasLtMatrixLayoutDestroy( + matLayout: hipblasLtMatrixLayout_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Set attribute to a matrix descriptor + + \details + This function sets the value of the specified attribute belonging to a + previously created matrix descriptor. + + @param[in] + matLayout Pointer to the previously created structure holding the matrix + descriptor queried by this function. See \ref hipblasLtMatrixLayout_t. + @param[in] + attr The attribute that will be set by this function. See \ref + hipblasLtMatrixLayoutAttribute_t. + @param[in] + buf The value to which the specified attribute should be set. + @param[in] + sizeInBytes Size of buf buffer (in bytes) for verification. + + \retval HIPBLAS_STATUS_SUCCESS If the attribute was set successfully. + \retval HIPBLAS_STATUS_INVALID_VALUE If \p buf is NULL or \p sizeInBytes + doesn't match the size of the internal storage for the selected attribute.*/ + pub fn hipblasLtMatrixLayoutSetAttribute( + matLayout: hipblasLtMatrixLayout_t, + attr: hipblasLtMatrixLayoutAttribute_t, + buf: *const ::core::ffi::c_void, + sizeInBytes: usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Query attribute from a matrix descriptor + + \details + This function returns the value of the queried attribute belonging to a + previously created matrix descriptor. + + @param[in] + matLayout Pointer to the previously created structure holding the matrix + descriptor queried by this function. See \ref hipblasLtMatrixLayout_t. + @param[in] + attr The attribute that will be retrieved by this function. See + \ref hipblasLtMatrixLayoutAttribute_t. + @param[out] + buf Memory address containing the attribute value retrieved by this + function. + @param[in] + sizeInBytes Size of \p buf buffer (in bytes) for verification. + @param[out] + sizeWritten Valid only when the return value is HIPBLAS_STATUS_SUCCESS. If + sizeInBytes is non-zero: then sizeWritten is the number of bytes actually + written; if sizeInBytes is 0: then sizeWritten is the number of bytes needed + to write full contents. + + \retval HIPBLAS_STATUS_SUCCESS If attribute's value was successfully + written to user memory. 
\retval HIPBLAS_STATUS_INVALID_VALUE If \p + sizeInBytes is 0 and \p sizeWritten is NULL, or if \p sizeInBytes is non-zero + and \p buf is NULL, or \p sizeInBytes doesn't match size of internal storage + for the selected attribute.*/ + pub fn hipblasLtMatrixLayoutGetAttribute( + matLayout: hipblasLtMatrixLayout_t, + attr: hipblasLtMatrixLayoutAttribute_t, + buf: *mut ::core::ffi::c_void, + sizeInBytes: usize, + sizeWritten: *mut usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Create a matrix multiply descriptor + + \details + This function creates a matrix multiply descriptor by allocating the memory + needed to hold its opaque structure. + + @param[out] + matmulDesc Pointer to the structure holding the matrix multiply descriptor + created by this function. See \ref hipblasLtMatmulDesc_t . + @param[in] + computeType Enumerant that specifies the data precision for the matrix + multiply descriptor this function creates. See hipblasComputeType_t . + @param[in] + scaleType Enumerant that specifies the data precision for the matrix + transform descriptor this function creates. See hipDataType. + + \retval HIPBLAS_STATUS_SUCCESS If the descriptor was created successfully. + \retval HIPBLAS_STATUS_ALLOC_FAILED If the memory could not be allocated.*/ + pub fn hipblasLtMatmulDescCreate( + matmulDesc: *mut hipblasLtMatmulDesc_t, + computeType: hipblasComputeType_t, + scaleType: hipDataType, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Destroy a matrix multiply descriptor + + \details + This function destroys a previously created matrix multiply descriptor + object. + + @param[in] + matmulDesc Pointer to the structure holding the matrix multiply descriptor + that should be destroyed by this function. See \ref hipblasLtMatmulDesc_t . + + \retval HIPBLAS_STATUS_SUCCESS If operation was successful.*/ + pub fn hipblasLtMatmulDescDestroy( + matmulDesc: hipblasLtMatmulDesc_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Set attribute to a matrix multiply descriptor + + \details + This function sets the value of the specified attribute belonging to a + previously created matrix multiply descriptor. + + @param[in] + matmulDesc Pointer to the previously created structure holding the matrix + multiply descriptor queried by this function. See \ref hipblasLtMatmulDesc_t. + @param[in] + attr The attribute that will be set by this function. See \ref + hipblasLtMatmulDescAttributes_t. + @param[in] + buf The value to which the specified attribute should be set. + @param[in] + sizeInBytes Size of buf buffer (in bytes) for verification. + + \retval HIPBLAS_STATUS_SUCCESS If the attribute was set successfully. 
+ \retval HIPBLAS_STATUS_INVALID_VALUE If \p buf is NULL or \p sizeInBytes + doesn't match the size of the internal storage for the selected attribute.*/ + pub fn hipblasLtMatmulDescSetAttribute( + matmulDesc: hipblasLtMatmulDesc_t, + attr: hipblasLtMatmulDescAttributes_t, + buf: *const ::core::ffi::c_void, + sizeInBytes: usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Query attribute from a matrix multiply descriptor + + \details + This function returns the value of the queried attribute belonging to a + previously created matrix multiply descriptor. + + @param[in] + matmulDesc Pointer to the previously created structure holding the matrix + multiply descriptor queried by this function. See \ref hipblasLtMatmulDesc_t. + @param[in] + attr The attribute that will be retrieved by this function. See + \ref hipblasLtMatmulDescAttributes_t. + @param[out] + buf Memory address containing the attribute value retrieved by this + function. + @param[in] + sizeInBytes Size of \p buf buffer (in bytes) for verification. + @param[out] + sizeWritten Valid only when the return value is HIPBLAS_STATUS_SUCCESS. If + sizeInBytes is non-zero: then sizeWritten is the number of bytes actually + written; if sizeInBytes is 0: then sizeWritten is the number of bytes needed + to write full contents. + + \retval HIPBLAS_STATUS_SUCCESS If attribute's value was successfully + written to user memory. \retval HIPBLAS_STATUS_INVALID_VALUE If \p + sizeInBytes is 0 and \p sizeWritten is NULL, or if \p sizeInBytes is non-zero + and \p buf is NULL, or \p sizeInBytes doesn't match size of internal storage + for the selected attribute.*/ + pub fn hipblasLtMatmulDescGetAttribute( + matmulDesc: hipblasLtMatmulDesc_t, + attr: hipblasLtMatmulDescAttributes_t, + buf: *mut ::core::ffi::c_void, + sizeInBytes: usize, + sizeWritten: *mut usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Create a preference descriptor + + \details + This function creates a matrix multiply heuristic search preferences + descriptor by allocating the memory needed to hold its opaque structure. + + @param[out] + pref Pointer to the structure holding the matrix multiply preferences + descriptor created by this function. see \ref hipblasLtMatmulPreference_t . + + \retval HIPBLAS_STATUS_SUCCESS If the descriptor was created + successfully. \retval HIPBLAS_STATUS_ALLOC_FAILED If memory could not be + allocated.*/ + pub fn hipblasLtMatmulPreferenceCreate( + pref: *mut hipblasLtMatmulPreference_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Destroy a preferences descriptor + + \details + This function destroys a previously created matrix multiply preferences + descriptor object. + + @param[in] + pref Pointer to the structure holding the matrix multiply preferences + descriptor that should be destroyed by this function. See \ref + hipblasLtMatmulPreference_t . 
+ + \retval HIPBLAS_STATUS_SUCCESS If operation was successful.*/ + pub fn hipblasLtMatmulPreferenceDestroy( + pref: hipblasLtMatmulPreference_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Set attribute to a preference descriptor + + \details + This function sets the value of the specified attribute belonging to a + previously created matrix multiply preferences descriptor. + + @param[in] + pref Pointer to the previously created structure holding the matrix + multiply preferences descriptor queried by this function. See \ref + hipblasLtMatmulPreference_t + @param[in] + attr The attribute that will be set by this function. See \ref + hipblasLtMatmulPreferenceAttributes_t. + @param[in] + buf The value to which the specified attribute should be set. + @param[in] + sizeInBytes Size of \p buf buffer (in bytes) for verification. + + \retval HIPBLAS_STATUS_SUCCESS If the attribute was set successfully. + \retval HIPBLAS_STATUS_INVALID_VALUE If \p buf is NULL or \p sizeInBytes + doesn't match the size of the internal storage for the selected attribute.*/ + pub fn hipblasLtMatmulPreferenceSetAttribute( + pref: hipblasLtMatmulPreference_t, + attr: hipblasLtMatmulPreferenceAttributes_t, + buf: *const ::core::ffi::c_void, + sizeInBytes: usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Query attribute from a preference descriptor + + \details + This function returns the value of the queried attribute belonging to a + previously created matrix multiply heuristic search preferences descriptor. + + @param[in] + pref Pointer to the previously created structure holding the matrix + multiply heuristic search preferences descriptor queried by this function. + See \ref hipblasLtMatmulPreference_t. + @param[in] + attr The attribute that will be retrieved by this function. See + \ref hipblasLtMatmulPreferenceAttributes_t. + @param[out] + buf Memory address containing the attribute value retrieved by this + function. + @param[in] + sizeInBytes Size of \p buf buffer (in bytes) for verification. + @param[out] + sizeWritten Valid only when the return value is HIPBLAS_STATUS_SUCCESS. If + sizeInBytes is non-zero: then sizeWritten is the number of bytes actually + written; if sizeInBytes is 0: then sizeWritten is the number of bytes needed + to write full contents. + + \retval HIPBLAS_STATUS_SUCCESS If attribute's value was successfully + written to user memory. \retval HIPBLAS_STATUS_INVALID_VALUE If \p + sizeInBytes is 0 and \p sizeWritten is NULL, or if \p sizeInBytes is non-zero + and \p buf is NULL, or \p sizeInBytes doesn't match size of internal storage + for the selected attribute.*/ + pub fn hipblasLtMatmulPreferenceGetAttribute( + pref: hipblasLtMatmulPreference_t, + attr: hipblasLtMatmulPreferenceAttributes_t, + buf: *mut ::core::ffi::c_void, + sizeInBytes: usize, + sizeWritten: *mut usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Retrieve the possible algorithms + + \details + This function retrieves the possible algorithms for the matrix multiply + operation hipblasLtMatmul() function with the given input matrices A, B and + C, and the output matrix D. The output is placed in heuristicResultsArray[] + in the order of increasing estimated compute time. 
Note that the wall duration + increases if the requestedAlgoCount increases. + + @param[in] + handle Pointer to the allocated hipBLASLt handle for the + hipBLASLt context. See \ref hipblasLtHandle_t . + @param[in] + matmulDesc Handle to a previously created matrix multiplication + descriptor of type \ref hipblasLtMatmulDesc_t . + @param[in] + Adesc,Bdesc,Cdesc,Ddesc Handles to the previously created matrix layout + descriptors of the type \ref hipblasLtMatrixLayout_t . + @param[in] + pref Pointer to the structure holding the heuristic + search preferences descriptor. See \ref hipblasLtMatmulPreference_t . + @param[in] + requestedAlgoCount Size of the \p heuristicResultsArray (in elements). + This is the requested maximum number of algorithms to return. + @param[out] + heuristicResultsArray[] Array containing the algorithm heuristics and + associated runtime characteristics, returned by this function, in the order + of increasing estimated compute time. + @param[out] + returnAlgoCount Number of algorithms returned by this function. This + is the number of \p heuristicResultsArray elements written. + + \retval HIPBLAS_STATUS_SUCCESS If query was successful. Inspect + heuristicResultsArray[0 to (returnAlgoCount -1)].state for the status of the + results. \retval HIPBLAS_STATUS_NOT_SUPPORTED If no heuristic function + available for current configuration. \retval HIPBLAS_STATUS_INVALID_VALUE If + \p requestedAlgoCount is less or equal to zero.*/ + pub fn hipblasLtMatmulAlgoGetHeuristic( + handle: hipblasLtHandle_t, + matmulDesc: hipblasLtMatmulDesc_t, + Adesc: hipblasLtMatrixLayout_t, + Bdesc: hipblasLtMatrixLayout_t, + Cdesc: hipblasLtMatrixLayout_t, + Ddesc: hipblasLtMatrixLayout_t, + pref: hipblasLtMatmulPreference_t, + requestedAlgoCount: ::core::ffi::c_int, + heuristicResultsArray: *mut hipblasLtMatmulHeuristicResult_t, + returnAlgoCount: *mut ::core::ffi::c_int, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Perform the matrix multiplication + + \details + This function computes the matrix multiplication of matrices A and B to + produce the output matrix D, according to the following operation: \p D = \p + alpha*( \p A *\p B) + \p beta*( \p C ), where \p A, \p B, and \p C are input + matrices, and \p alpha and \p beta are input scalars. Note: This function + supports both in-place matrix multiplication (C == D and Cdesc == Ddesc) and + out-of-place matrix multiplication (C != D, both matrices must have the same + data type, number of rows, number of columns, batch size, and memory order). + In the out-of-place case, the leading dimension of C can be different from + the leading dimension of D. Specifically the leading dimension of C can be 0 + to achieve row or column broadcast. If Cdesc is omitted, this function + assumes it to be equal to Ddesc. + + @param[in] + handle Pointer to the allocated hipBLASLt handle for the + hipBLASLt context. See \ref hipblasLtHandle_t . + @param[in] + matmulDesc Handle to a previously created matrix multiplication + descriptor of type \ref hipblasLtMatmulDesc_t . + @param[in] + alpha,beta Pointers to the scalars used in the multiplication. + @param[in] + Adesc,Bdesc,Cdesc,Ddesc Handles to the previously created matrix layout + descriptors of the type \ref hipblasLtMatrixLayout_t . + @param[in] + A,B,C Pointers to the GPU memory associated with the + corresponding descriptors \p Adesc, \p Bdesc and \p Cdesc . 
+ @param[out] + D Pointer to the GPU memory associated with the + descriptor \p Ddesc . + @param[in] + algo Handle for matrix multiplication algorithm to be + used. See \ref hipblasLtMatmulAlgo_t . When NULL, an implicit heuristics query + with default search preferences will be performed to determine actual + algorithm to use. + @param[in] + workspace Pointer to the workspace buffer allocated in the GPU + memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must + be 0). + @param[in] + workspaceSizeInBytes Size of the workspace. + @param[in] + stream The HIP stream where all the GPU work will be + submitted. + + \retval HIPBLAS_STATUS_SUCCESS If the operation completed + successfully. \retval HIPBLAS_STATUS_EXECUTION_FAILED If HIP reported an + execution error from the device. \retval HIPBLAS_STATUS_ARCH_MISMATCH If + the configured operation cannot be run using the selected device. \retval + HIPBLAS_STATUS_NOT_SUPPORTED If the current implementation on the + selected device doesn't support the configured operation. \retval + HIPBLAS_STATUS_INVALID_VALUE If the parameters are unexpectedly NULL, in + conflict or in an impossible configuration. For example, when + workspaceSizeInBytes is less than workspace required by the configured algo. + \retval HIPBLAS_STATUS_NOT_INITIALIZED If hipBLASLt handle has not been + initialized.*/ + pub fn hipblasLtMatmul( + handle: hipblasLtHandle_t, + matmulDesc: hipblasLtMatmulDesc_t, + alpha: *const ::core::ffi::c_void, + A: *const ::core::ffi::c_void, + Adesc: hipblasLtMatrixLayout_t, + B: *const ::core::ffi::c_void, + Bdesc: hipblasLtMatrixLayout_t, + beta: *const ::core::ffi::c_void, + C: *const ::core::ffi::c_void, + Cdesc: hipblasLtMatrixLayout_t, + D: *mut ::core::ffi::c_void, + Ddesc: hipblasLtMatrixLayout_t, + algo: *const hipblasLtMatmulAlgo_t, + workspace: *mut ::core::ffi::c_void, + workspaceSizeInBytes: usize, + stream: hip_runtime_sys::hipStream_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** Create new matrix transform operation descriptor. + + \retval HIPBLAS_STATUS_ALLOC_FAILED if memory could not be allocated + \retval HIPBLAS_STATUS_SUCCESS if descriptor was created successfully*/ + pub fn hipblasLtMatrixTransformDescCreate( + transformDesc: *mut hipblasLtMatrixTransformDesc_t, + scaleType: hipDataType, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** Destroy matrix transform operation descriptor. + + \retval HIPBLAS_STATUS_SUCCESS if operation was successful*/ + pub fn hipblasLtMatrixTransformDescDestroy( + transformDesc: hipblasLtMatrixTransformDesc_t, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** Set matrix transform operation descriptor attribute. 
+ + \param[in] transformDesc The descriptor + \param[in] attr The attribute + \param[in] buf memory address containing the new value + \param[in] sizeInBytes size of buf buffer for verification (in bytes) + + \retval HIPBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for + selected attribute + \retval HIPBLAS_STATUS_SUCCESS if attribute was set successfully*/ + pub fn hipblasLtMatrixTransformDescSetAttribute( + transformDesc: hipblasLtMatrixTransformDesc_t, + attr: hipblasLtMatrixTransformDescAttributes_t, + buf: *const ::core::ffi::c_void, + sizeInBytes: usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Matrix transform operation getter + \details Get matrix transform operation descriptor attribute. + + @param[in] transformDesc The descriptor + @param[in] attr The attribute + @param[out] buf memory address containing the new value + @param[in] sizeInBytes size of buf buffer for verification (in bytes) + @param[out] sizeWritten only valid when return value is HIPBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number + of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents + + \retval HIPBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero + and buf is NULL or sizeInBytes doesn't match size of internal storage for + selected attribute + \retval HIPBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory*/ + pub fn hipblasLtMatrixTransformDescGetAttribute( + transformDesc: hipblasLtMatrixTransformDesc_t, + attr: hipblasLtMatrixTransformDescAttributes_t, + buf: *mut ::core::ffi::c_void, + sizeInBytes: usize, + sizeWritten: *mut usize, + ) -> hipblasStatus_t; +} +#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))] +extern "C" { + #[must_use] + /** \ingroup library_module + \brief Matrix layout conversion helper + \details + Matrix layout conversion helper (C = alpha * op(A) + beta * op(B)), + can be used to change memory order of data or to scale and shift the values. + @param[in] lightHandle Pointer to the allocated hipBLASLt handle for the + hipBLASLt context. See \ref hipblasLtHandle_t . + @param[in] transformDesc Pointer to allocated matrix transform descriptor. + @param[in] alpha Pointer to scalar alpha, either pointer to host or device address. + @param[in] A Pointer to matrix A, must be pointer to device address. + @param[in] Adesc Pointer to layout for input matrix A. + @param[in] beta Pointer to scalar beta, either pointer to host or device address. + @param[in] B Pointer to matrix B, must be pointer to device address + @param[in] Bdesc Pointer to layout for input matrix B. + @param[in] C Pointer to matrix C, must be pointer to device address + @param[out] Cdesc Pointer to layout for output matrix C. + @param[in] stream The HIP stream where all the GPU work will be submitted. + + \retval HIPBLAS_STATUS_NOT_INITIALIZED if hipBLASLt handle has not been initialized + \retval HIPBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g. 
+ when A is not NULL, but Adesc is NULL + \retval HIPBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured + operation + \retval HIPBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device + \retval HIPBLAS_STATUS_EXECUTION_FAILED if HIP reported execution error from the device + \retval HIPBLAS_STATUS_SUCCESS if the operation completed successfully*/ + pub fn hipblasLtMatrixTransform( + lightHandle: hipblasLtHandle_t, + transformDesc: hipblasLtMatrixTransformDesc_t, + alpha: *const ::core::ffi::c_void, + A: *const ::core::ffi::c_void, + Adesc: hipblasLtMatrixLayout_t, + beta: *const ::core::ffi::c_void, + B: *const ::core::ffi::c_void, + Bdesc: hipblasLtMatrixLayout_t, + C: *mut ::core::ffi::c_void, + Cdesc: hipblasLtMatrixLayout_t, + stream: hip_runtime_sys::hipStream_t, + ) -> hipblasStatus_t; +} +impl hipblasLtError { + pub const r#NOT_INITIALIZED: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(1) + }); + pub const r#ALLOC_FAILED: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(2) + }); + pub const r#INVALID_VALUE: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(3) + }); + pub const r#MAPPING_ERROR: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(4) + }); + pub const r#EXECUTION_FAILED: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(5) + }); + pub const r#INTERNAL_ERROR: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(6) + }); + pub const r#NOT_SUPPORTED: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(7) + }); + pub const r#ARCH_MISMATCH: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(8) + }); + pub const r#HANDLE_IS_NULLPTR: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(9) + }); + pub const r#INVALID_ENUM: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(10) + }); + pub const r#UNKNOWN: hipblasLtError = hipblasLtError(unsafe { + ::core::num::NonZeroU32::new_unchecked(11) + }); +} +#[repr(transparent)] +#[derive(Debug, Hash, Copy, Clone, PartialEq, Eq)] +pub struct hipblasLtError(pub ::core::num::NonZeroU32); +pub trait hipblasStatus_tConsts { + const SUCCESS: hipblasStatus_t = hipblasStatus_t::Ok(()); + const ERROR_NOT_INITIALIZED: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#NOT_INITIALIZED, + ); + const ERROR_ALLOC_FAILED: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#ALLOC_FAILED, + ); + const ERROR_INVALID_VALUE: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#INVALID_VALUE, + ); + const ERROR_MAPPING_ERROR: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#MAPPING_ERROR, + ); + const ERROR_EXECUTION_FAILED: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#EXECUTION_FAILED, + ); + const ERROR_INTERNAL_ERROR: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#INTERNAL_ERROR, + ); + const ERROR_NOT_SUPPORTED: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#NOT_SUPPORTED, + ); + const ERROR_ARCH_MISMATCH: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#ARCH_MISMATCH, + ); + const ERROR_HANDLE_IS_NULLPTR: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#HANDLE_IS_NULLPTR, + ); + const ERROR_INVALID_ENUM: hipblasStatus_t = hipblasStatus_t::Err( + 
hipblasLtError::r#INVALID_ENUM, + ); + const ERROR_UNKNOWN: hipblasStatus_t = hipblasStatus_t::Err( + hipblasLtError::r#UNKNOWN, + ); +} +impl hipblasStatus_tConsts for hipblasStatus_t {} +#[must_use] +pub type hipblasStatus_t = ::core::result::Result<(), hipblasLtError>; +const _: fn() = || { + let _ = std::mem::transmute::<hipblasStatus_t, u32>; +}; +unsafe impl Send for hipblasLtHandle_t {} +unsafe impl Sync for hipblasLtHandle_t {} +unsafe impl Send for hipblasLtMatmulDesc_t {} +unsafe impl Sync for hipblasLtMatmulDesc_t {} +unsafe impl Send for hipblasLtMatmulPreference_t {} +unsafe impl Sync for hipblasLtMatmulPreference_t {} +unsafe impl Send for hipblasLtMatrixLayout_t {} +unsafe impl Sync for hipblasLtMatrixLayout_t {} diff --git a/ext/rocm_smi-sys/.rustfmt.toml b/ext/rocm_smi-sys/.rustfmt.toml new file mode 100644 index 0000000..c7ad93b --- /dev/null +++ b/ext/rocm_smi-sys/.rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/format/src/format_generated_blaslt.rs b/format/src/format_generated_blaslt.rs index 9bc143e..4167583 100644 --- a/format/src/format_generated_blaslt.rs +++ b/format/src/format_generated_blaslt.rs @@ -1,360 +1,6 @@ // Generated automatically by zluda_bindgen // DO NOT EDIT MANUALLY #![allow(warnings)] -impl crate::CudaDisplay for cuda_types::cublaslt::cublasFillMode_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER => { - writer.write_all(stringify!(CUBLAS_FILL_MODE_LOWER).as_bytes()) - } - &cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER => { - writer.write_all(stringify!(CUBLAS_FILL_MODE_UPPER).as_bytes()) - } - &cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_FULL => { - writer.write_all(stringify!(CUBLAS_FILL_MODE_FULL).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasDiagType_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT => { - writer.write_all(stringify!(CUBLAS_DIAG_NON_UNIT).as_bytes()) - } - &cuda_types::cublaslt::cublasDiagType_t::CUBLAS_DIAG_UNIT => { - writer.write_all(stringify!(CUBLAS_DIAG_UNIT).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasSideMode_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasSideMode_t::CUBLAS_SIDE_LEFT => { - writer.write_all(stringify!(CUBLAS_SIDE_LEFT).as_bytes()) - } - &cuda_types::cublaslt::cublasSideMode_t::CUBLAS_SIDE_RIGHT => { - writer.write_all(stringify!(CUBLAS_SIDE_RIGHT).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasOperation_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_N => { - writer.write_all(stringify!(CUBLAS_OP_N).as_bytes()) - } - &cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_T => { - writer.write_all(stringify!(CUBLAS_OP_T).as_bytes()) - } - &cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_C => { - 
writer.write_all(stringify!(CUBLAS_OP_C).as_bytes()) - } - &cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_HERMITAN => { - writer.write_all(stringify!(CUBLAS_OP_HERMITAN).as_bytes()) - } - &cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_CONJG => { - writer.write_all(stringify!(CUBLAS_OP_CONJG).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasPointerMode_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST => { - writer.write_all(stringify!(CUBLAS_POINTER_MODE_HOST).as_bytes()) - } - &cuda_types::cublaslt::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE => { - writer.write_all(stringify!(CUBLAS_POINTER_MODE_DEVICE).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasAtomicsMode_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => { - writer.write_all(stringify!(CUBLAS_ATOMICS_NOT_ALLOWED).as_bytes()) - } - &cuda_types::cublaslt::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => { - writer.write_all(stringify!(CUBLAS_ATOMICS_ALLOWED).as_bytes()) - } - _ => write!(writer, "{}", self.0), - } - } -} -impl crate::CudaDisplay for cuda_types::cublaslt::cublasGemmAlgo_t { - fn write( - &self, - _fn_name: &'static str, - _index: usize, - writer: &mut (impl std::io::Write + ?Sized), - ) -> std::io::Result<()> { - match self { - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DFALT => { - writer.write_all(stringify!(CUBLAS_GEMM_DFALT).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT => { - writer.write_all(stringify!(CUBLAS_GEMM_DEFAULT).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO0 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO0).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO1 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO1).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO2 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO2).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO3 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO3).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO4 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO4).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO5 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO5).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO6 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO6).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO7 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO7).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO8 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO8).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO9 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO9).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO10 => { - writer.write_all(stringify!(CUBLAS_GEMM_ALGO10).as_bytes()) - } - &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO11 => { - 
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO11).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO12 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO12).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO13 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO13).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO14 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO14).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO15 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO15).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO16 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO16).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO17 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO17).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO18 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO18).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO19 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO19).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO20 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO20).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO21 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO21).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO22 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO22).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO23 => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO23).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_DEFAULT_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DFALT_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_DFALT_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO0_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO0_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO1_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO1_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO2_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO2_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO3_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO3_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO4_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO4_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO5_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO5_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO6_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO6_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO7_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO7_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO8_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO8_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO9_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO9_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO10_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO10_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO11_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO11_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO12_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO12_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO13_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO13_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO14_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO14_TENSOR_OP).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO15_TENSOR_OP => {
-                writer.write_all(stringify!(CUBLAS_GEMM_ALGO15_TENSOR_OP).as_bytes())
-            }
-            _ => write!(writer, "{}", self.0),
-        }
-    }
-}
-impl crate::CudaDisplay for cuda_types::cublaslt::cublasMath_t {
-    fn write(
-        &self,
-        _fn_name: &'static str,
-        _index: usize,
-        writer: &mut (impl std::io::Write + ?Sized),
-    ) -> std::io::Result<()> {
-        match self {
-            &cuda_types::cublaslt::cublasMath_t::CUBLAS_DEFAULT_MATH => {
-                writer.write_all(stringify!(CUBLAS_DEFAULT_MATH).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasMath_t::CUBLAS_TENSOR_OP_MATH => {
-                writer.write_all(stringify!(CUBLAS_TENSOR_OP_MATH).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasMath_t::CUBLAS_PEDANTIC_MATH => {
-                writer.write_all(stringify!(CUBLAS_PEDANTIC_MATH).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => {
-                writer.write_all(stringify!(CUBLAS_TF32_TENSOR_OP_MATH).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasMath_t::CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION => {
-                writer
-                    .write_all(
-                        stringify!(CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)
-                            .as_bytes(),
-                    )
-            }
-            _ => write!(writer, "{}", self.0),
-        }
-    }
-}
-impl crate::CudaDisplay for cuda_types::cublaslt::cublasComputeType_t {
-    fn write(
-        &self,
-        _fn_name: &'static str,
-        _index: usize,
-        writer: &mut (impl std::io::Write + ?Sized),
-    ) -> std::io::Result<()> {
-        match self {
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_16F => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_16F).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_16F_PEDANTIC => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_16F_PEDANTIC).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32F).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_PEDANTIC => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32F_PEDANTIC).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_16F).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_16BF).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_TF32).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_64F => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_64F).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_64F_PEDANTIC => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_64F_PEDANTIC).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32I => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32I).as_bytes())
-            }
-            &cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32I_PEDANTIC => {
-                writer.write_all(stringify!(CUBLAS_COMPUTE_32I_PEDANTIC).as_bytes())
-            }
-            _ => write!(writer, "{}", self.0),
-        }
-    }
-}
-impl crate::CudaDisplay for cuda_types::cublaslt::cublasHandle_t {
-    fn write(
-        &self,
-        _fn_name: &'static str,
-        _index: usize,
-        writer: &mut (impl std::io::Write + ?Sized),
-    ) -> std::io::Result<()> {
-        if self.is_null() {
-            writer.write_all(b"NULL")
-        } else {
-            write!(writer, "{:p}", *self)
-        }
-    }
-}
 impl crate::CudaDisplay for cuda_types::cublaslt::cublasLtHandle_t {
     fn write(
         &self,
@@ -3558,7 +3204,7 @@ pub fn write_cublasLtMatmulDescInit_internal(
     writer: &mut (impl std::io::Write + ?Sized),
     matmulDesc: cuda_types::cublaslt::cublasLtMatmulDesc_t,
     size: usize,
-    computeType: cuda_types::cublaslt::cublasComputeType_t,
+    computeType: cuda_types::cublas::cublasComputeType_t,
     scaleType: cuda_types::cublaslt::cudaDataType_t,
 ) -> std::io::Result<()> {
     let mut arg_idx = 0usize;
@@ -3602,7 +3248,7 @@ pub fn write_cublasLtMatmulDescInit_internal(
 pub fn write_cublasLtMatmulDescCreate(
     writer: &mut (impl std::io::Write + ?Sized),
     matmulDesc: *mut cuda_types::cublaslt::cublasLtMatmulDesc_t,
-    computeType: cuda_types::cublaslt::cublasComputeType_t,
+    computeType: cuda_types::cublas::cublasComputeType_t,
     scaleType: cuda_types::cublaslt::cudaDataType_t,
 ) -> std::io::Result<()> {
     let mut arg_idx = 0usize;
@@ -4405,7 +4051,7 @@ pub fn write_cublasLtMatmulAlgoGetHeuristic(
 pub fn write_cublasLtMatmulAlgoGetIds(
     writer: &mut (impl std::io::Write + ?Sized),
     lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
-    computeType: cuda_types::cublaslt::cublasComputeType_t,
+    computeType: cuda_types::cublas::cublasComputeType_t,
     scaleType: cuda_types::cublaslt::cudaDataType_t,
     Atype: cuda_types::cublaslt::cudaDataType_t,
     Btype: cuda_types::cublaslt::cudaDataType_t,
@@ -4485,7 +4131,7 @@ pub fn write_cublasLtMatmulAlgoGetIds(
 pub fn write_cublasLtMatmulAlgoInit(
     writer: &mut (impl std::io::Write + ?Sized),
     lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
-    computeType: cuda_types::cublaslt::cublasComputeType_t,
+    computeType: cuda_types::cublas::cublasComputeType_t,
     scaleType: cuda_types::cublaslt::cudaDataType_t,
     Atype: cuda_types::cublaslt::cudaDataType_t,
     Btype: cuda_types::cublaslt::cudaDataType_t,
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index 7a5bbcd..8203865 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -29,6 +29,10 @@ fn main() {
         &["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
     );
     generate_rocblas(&crate_root, &["..", "ext", "rocblas-sys", "src", "lib.rs"]);
+    generate_hiplaslt(
+        &crate_root,
+        &["..", "ext", "hipblaslt-sys", "src", "lib.rs"],
+    );
     generate_rocm_smi(&crate_root, &["..", "ext", "rocm_smi-sys", "src", "lib.rs"]);
     let cuda_functions = generate_cuda(&crate_root);
     generate_process_address_table(&crate_root, cuda_functions);
@@ -172,7 +176,7 @@ fn generate_cufft(crate_root: &PathBuf) {
         new_error_type: "cufftError_t",
         error_prefix: ("CUFFT_", "ERROR_"),
         success: ("CUFFT_SUCCESS", "SUCCESS"),
-        hip_type: None,
+        hip_types: vec![],
     };
     generate_types_library(
         Some(&result_options),
@@ -185,6 +189,7 @@ fn generate_cufft(crate_root: &PathBuf) {
     generate_display_perflib(
         Some(&result_options),
         &crate_root,
+        None,
         &["..", "format", "src", "format_generated_fft.rs"],
         &["cuda_types", "cufft"],
         &module,
@@ -239,7 +244,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
new_error_type: "cusparseError_t", error_prefix: ("CUSPARSE_STATUS_", "ERROR_"), success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"), - hip_type: None, + hip_types: vec![], }; generate_types_library( Some(&result_options), @@ -252,6 +257,7 @@ fn generate_cusparse(crate_root: &PathBuf) { generate_display_perflib( Some(&result_options), &crate_root, + None, &["..", "format", "src", "format_generated_sparse.rs"], &["cuda_types", "cusparse"], &module, @@ -277,7 +283,7 @@ fn generate_cudnn(crate_root: &PathBuf) { new_error_type: "cudnnError_", error_prefix: ("CUDNN_STATUS_", "ERROR_"), success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"), - hip_type: None, + hip_types: vec![], }; let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap(); let cudnn9_types = generate_types_library_impl(Some(&result_options), &cudnn9_module); @@ -322,6 +328,7 @@ fn generate_cudnn(crate_root: &PathBuf) { generate_display_perflib( Some(&result_options), &crate_root, + None, &["..", "format", "src", "format_generated_dnn9.rs"], &["cuda_types", "cudnn9"], &cudnn9_module, @@ -680,7 +687,10 @@ fn generate_cublas(crate_root: &PathBuf) { new_error_type: "cublasError_t", error_prefix: ("CUBLAS_STATUS_", "ERROR_"), success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"), - hip_type: Some(syn::parse_str("rocblas_sys::rocblas_error").unwrap()), + hip_types: vec![ + syn::parse_str("rocblas_sys::rocblas_error").unwrap(), + syn::parse_str("hipblaslt_sys::hipblasLtError").unwrap(), + ], }; generate_types_library( Some(&result_options), @@ -693,6 +703,7 @@ fn generate_cublas(crate_root: &PathBuf) { generate_display_perflib( Some(&result_options), &crate_root, + None, &["..", "format", "src", "format_generated_blas.rs"], &["cuda_types", "cublas"], &module, @@ -717,7 +728,7 @@ fn remove_type(module: &mut syn::File, type_name: &str) { fn generate_cublaslt(crate_root: &PathBuf) { let cublaslt_header = new_builder() .header("/usr/local/cuda/include/cublasLt.h") - .allowlist_type("^cublas.*") + .allowlist_type("^cublasLt.*") .allowlist_function("^cublasLt.*") .allowlist_var("^CUBLASLT_.*") .must_use_type("cublasStatus_t") @@ -749,8 +760,7 @@ fn generate_cublaslt(crate_root: &PathBuf) { cublaslt_internal_header, ) .unwrap(); - let mut module_blas: syn::File = syn::parse_str(&cublaslt_header).unwrap(); - remove_type(&mut module_blas, "cublasStatus_t"); + let module_blas: syn::File = syn::parse_str(&cublaslt_header).unwrap(); generate_functions( &crate_root, "cublaslt", @@ -768,6 +778,7 @@ fn generate_cublaslt(crate_root: &PathBuf) { generate_display_perflib( None, &crate_root, + Some(LibraryOverride::CuBlasLt), &["..", "format", "src", "format_generated_blaslt.rs"], &["cuda_types", "cublaslt"], &module_blas, @@ -775,6 +786,7 @@ fn generate_cublaslt(crate_root: &PathBuf) { generate_display_perflib( None, &crate_root, + Some(LibraryOverride::CuBlasLt), &["..", "format", "src", "format_generated_blaslt_internal.rs"], &["cuda_types", "cublaslt"], &module_blaslt_internal, @@ -816,7 +828,7 @@ fn generate_cuda(crate_root: &PathBuf) -> Vec { new_error_type: "CUerror", error_prefix: ("CUDA_ERROR_", "ERROR_"), success: ("CUDA_SUCCESS", "SUCCESS"), - hip_type: Some(syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()), + hip_types: vec![syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()], }; generate_types_cuda( &result_options, @@ -862,7 +874,7 @@ fn generate_ml(crate_root: &PathBuf) { new_error_type: "nvmlError_t", error_prefix: ("NVML_ERROR_", "ERROR_"), success: ("NVML_SUCCESS", "SUCCESS"), - hip_type: None, + hip_types: vec![], }; let suffix = 
"#[cfg(unix)] @@ -893,6 +905,7 @@ impl From for nvmlError_t { generate_display_perflib( Some(&result_options), &crate_root, + None, &["..", "format", "src", "format_generated_nvml.rs"], &["cuda_types", "nvml"], &module, @@ -1002,7 +1015,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) { new_error_type: "hipErrorCode_t", error_prefix: ("hipError", "Error"), success: ("hipSuccess", "Success"), - hip_type: None, + hip_types: vec![], }); module.items = converter.convert(module.items).collect::>(); converter.flush(&mut module.items); @@ -1044,7 +1057,7 @@ fn generate_rocblas(output: &PathBuf, path: &[&str]) { new_error_type: "rocblas_error", error_prefix: ("rocblas_status_", "error_"), success: ("rocblas_status_success", "success"), - hip_type: None, + hip_types: vec![], }; let mut converter = ConvertIntoRustResult::new(result_options); module.items = converter @@ -1069,6 +1082,59 @@ fn generate_rocblas(output: &PathBuf, path: &[&str]) { write_rust_to_file(output, text) } +fn generate_hiplaslt(output: &PathBuf, path: &[&str]) { + let rocblas_header = new_builder() + .header("/opt/rocm/include/hipblaslt/hipblaslt.h") + .allowlist_type("^hipblasLt.*") + .allowlist_type("hipblasOperation_t") + .allowlist_function("^hipblasLt.*") + .allowlist_var("^hipblasLt.*") + .must_use_type("hipblasStatus_t") + .constified_enum("hipblasStatus_t") + .new_type_alias("^hipblasLtHandle_t$") + .new_type_alias("^hipblasLtMatmulDesc_t$") + .new_type_alias("^hipblasLtMatmulPreference_t$") + .new_type_alias("^hipblasLtMatrixLayout_t$") + .clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__", "-x", "c++"]) + .generate() + .unwrap() + .to_string(); + let mut module: syn::File = syn::parse_str(&rocblas_header).unwrap(); + remove_type(&mut module, "hipStream_t"); + remove_type(&mut module, "ihipStream_t"); + let result_options = ConvertIntoRustResultOptions { + type_: "hipblasStatus_t", + underlying_type: "hipblasStatus_t", + new_error_type: "hipblasLtError", + error_prefix: ("HIPBLAS_STATUS_", "ERROR_"), + success: ("HIPBLAS_STATUS_SUCCESS", "SUCCESS"), + hip_types: vec![], + }; + let mut converter = ConvertIntoRustResult::new(result_options); + module.items = converter + .convert(module.items) + .map(|item| match item { + Item::ForeignMod(mut extern_) => { + extern_.attrs.push( + parse_quote!(#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))]), + ); + Item::ForeignMod(extern_) + } + item => item, + }) + .collect(); + converter.flush(&mut module.items); + add_send_sync(&mut module.items, &["hipblasLtHandle_t"]); + add_send_sync(&mut module.items, &["hipblasLtMatmulDesc_t"]); + add_send_sync(&mut module.items, &["hipblasLtMatmulPreference_t"]); + add_send_sync(&mut module.items, &["hipblasLtMatrixLayout_t"]); + let mut output = output.clone(); + output.extend(path); + let text = + &prettyplease::unparse(&module).replace("hipStream_t", "hip_runtime_sys::hipStream_t"); + write_rust_to_file(output, text) +} + fn generate_rocm_smi(output: &PathBuf, path: &[&str]) { let rocm_smi_header = new_builder() .header("/opt/rocm/include/rocm_smi/rocm_smi.h") @@ -1088,7 +1154,7 @@ fn generate_rocm_smi(output: &PathBuf, path: &[&str]) { new_error_type: "rsmi_error", error_prefix: ("RSMI_STATUS_", "ERROR_"), success: ("RSMI_STATUS_SUCCESS", "SUCCESS"), - hip_type: None, + hip_types: vec![], }; let mut converter = ConvertIntoRustResult::new(result_options); module.items = converter.convert(module.items).collect(); @@ -1114,7 +1180,7 @@ fn add_send_sync(items: &mut Vec, arg: &[&str]) { fn 
 fn generate_functions(
     output: &PathBuf,
-    submodule: &str,
+    submodule_str: &str,
     path: &[&str],
     module: &syn::File,
 ) -> syn::File {
@@ -1141,7 +1207,7 @@ fn generate_functions(
             #(#fns_)*
         }
     };
-    let submodule = Ident::new(submodule, Span::call_site());
+    let submodule = Ident::new(submodule_str, Span::call_site());
     syn::visit_mut::visit_file_mut(
         &mut PrependCudaPath {
             module: vec![Ident::new("cuda_types", Span::call_site()), submodule],
@@ -1152,7 +1218,15 @@ fn generate_functions(
     syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module);
     let mut output = output.clone();
     output.extend(path);
-    write_rust_to_file(output, &prettyplease::unparse(&module));
+    let text = prettyplease::unparse(&module);
+    let text = match submodule_str {
+        "cublaslt" => text.replace(
+            "cuda_types::cublaslt::cublasComputeType_t",
+            "cuda_types::cublas::cublasComputeType_t",
+        ),
+        _ => text,
+    };
+    write_rust_to_file(output, &text);
     module
     /*
     module
@@ -1239,7 +1313,7 @@ struct ConvertIntoRustResultOptions {
     error_prefix: (&'static str, &'static str),
     success: (&'static str, &'static str),
     // TODO: this should no longer be an Option once all hip perf libraries are present
-    hip_type: Option<syn::Path>,
+    hip_types: Vec<syn::Path>,
 }
 
 struct ConvertIntoRustResult {
@@ -1326,7 +1400,7 @@ impl ConvertIntoRustResult {
             };
         };
         items.extend(extra_items);
-        if let Some(hip_error_path) = self.options.hip_type {
+        for hip_error_path in self.options.hip_types {
            items.push(
                parse_quote! {impl From<#hip_error_path> for #new_error_type {
                    fn from(error: #hip_error_path) -> Self {
@@ -1495,6 +1569,7 @@ fn generate_display_cuda(
 fn generate_display_perflib(
     result_options: Option<&ConvertIntoRustResultOptions>,
     output: &PathBuf,
+    override_: Option<LibraryOverride>,
     path: &[&str],
     types_crate: &[&'static str],
     module: &syn::File,
@@ -1539,14 +1614,20 @@ fn generate_display_perflib(
     }
     let mut output = output.clone();
     output.extend(path);
-    write_rust_to_file(
-        output,
-        &prettyplease::unparse(&syn::File {
-            shebang: None,
-            attrs: Vec::new(),
-            items,
-        }),
-    );
+    let text = prettyplease::unparse(&syn::File {
+        shebang: None,
+        attrs: Vec::new(),
+        items,
+    });
+    let text = match override_ {
+        None => text,
+        Some(LibraryOverride::CuBlasLt) => text.replace(
+            "cuda_types::cublaslt::cublasComputeType_t",
+            "cuda_types::cublas::cublasComputeType_t",
+        ),
+        Some(LibraryOverride::CuFft) => text,
+    };
+    write_rust_to_file(output, &text);
 }
 
 struct DeriveDisplayState<'a> {
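-- 
A note on the error plumbing this patch generates (illustrative, not part of the
patch itself): because `hip_types` is now a `Vec`, the `for hip_error_path in
self.options.hip_types` loop above emits one `From` impl per listed HIP error
type, so a failure from either rocBLAS or hipBLASLt can be forwarded with `?`.
A minimal standalone sketch of that shape; the `NonZeroU32` payload and the
`lt_matmul_desc_create` caller are assumptions for illustration only:

    use std::num::NonZeroU32;

    #[allow(non_camel_case_types)]
    pub struct cublasError_t(pub NonZeroU32); // payload layout assumed
    #[allow(non_camel_case_types)]
    pub type cublasStatus_t = Result<(), cublasError_t>;
    pub struct hipblasLtError(pub NonZeroU32); // stand-in for hipblaslt_sys::hipblasLtError

    // Shape of the impl that the `parse_quote!` above produces per `hip_types` entry.
    impl From<hipblasLtError> for cublasError_t {
        fn from(error: hipblasLtError) -> Self {
            Self(error.0)
        }
    }

    // Hypothetical caller: a hipBLASLt failure becomes a cuBLAS status via `?`.
    fn lt_matmul_desc_create(hip_status: Result<(), hipblasLtError>) -> cublasStatus_t {
        hip_status?;
        Ok(())
    }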