Create bindings for hipblasLt (#510)

Generate bindings for hipblasLt and make some changes to the bindings for cublasLt. Notably, the `hip_type` `Option` is changed to a `Vec`, so that multiple `From` implementations (for `rocblas_error` and `hipblasLtError`) can be created for `cublasError_t`.
Violet 2025-09-16 16:23:15 -07:00 committed by GitHub
commit 5185138596
13 changed files with 1589 additions and 660 deletions
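
Not part of the commit, for illustration only: a minimal, self-contained sketch of why `hip_type` becomes a `Vec` of hip types. The stand-in newtypes below are simplified assumptions, not the crate's real definitions (the real ones live in `rocblas-sys`, `hipblaslt-sys` and `cuda_types` and are emitted by `zluda_bindgen`); the point is that one `From` impl per AMD error type lets `?` convert either failure into the cuBLAS-facing error.

```rust
// Stand-in newtypes (assumed, simplified) mirroring the generated wrappers.
#[derive(Debug, Clone, Copy)]
struct RocblasError(u32); // stands in for rocblas_sys::rocblas_error
#[derive(Debug, Clone, Copy)]
struct HipblasLtError(u32); // stands in for hipblaslt_sys::hipblasLtError
#[derive(Debug, Clone, Copy)]
struct CublasError(u32); // stands in for cuda_types::cublas::cublasError_t

// With `hip_types: Vec<Path>`, the generator can emit both impls for the same
// target error type (see the cuda_types hunk and the flush() loop further down).
impl From<RocblasError> for CublasError {
    fn from(e: RocblasError) -> Self {
        Self(e.0)
    }
}
impl From<HipblasLtError> for CublasError {
    fn from(e: HipblasLtError) -> Self {
        Self(e.0)
    }
}

// Either AMD-side failure now propagates with `?` from cuBLAS-facing code.
fn matmul_via_hipblaslt() -> Result<(), HipblasLtError> {
    Err(HipblasLtError(1))
}

fn cublas_facing_entry_point() -> Result<(), CublasError> {
    matmul_via_hipblaslt()?;
    Ok(())
}

fn main() {
    println!("{:?}", cublas_facing_entry_point());
}
```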

@@ -11,7 +11,7 @@ echo deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.c
echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' \
| tee /etc/apt/preferences.d/rocm-pin-600
DEBIAN_FRONTEND=noninteractive apt update -y
DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends rocm-smi-lib rocm-llvm-dev hip-runtime-amd hip-dev rocblas-dev
DEBIAN_FRONTEND=noninteractive apt install -y --no-install-recommends rocm-smi-lib rocm-llvm-dev hip-runtime-amd hip-dev rocblas-dev hipblaslt-dev
echo 'export PATH="$PATH:/opt/rocm/bin"' | tee /etc/profile.d/rocm.sh
echo "/opt/rocm/lib" | tee /etc/ld.so.conf.d/rocm.conf
ldconfig

Cargo.lock (generated file, 8 changes)
@@ -431,6 +431,7 @@ dependencies = [
"bitflags 2.9.1",
"cuda_macros",
"hip_runtime-sys",
"hipblaslt-sys",
"rocblas-sys",
"rocm_smi-sys",
]
@@ -1773,6 +1774,13 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hip_runtime-sys"
version = "0.0.0"
[[package]]
name = "hipblaslt-sys"
version = "0.1.0"
dependencies = [
"hip_runtime-sys",
]
[[package]]
name = "home"
version = "0.5.11"

@@ -171,7 +171,7 @@ extern "system" {
fn cublasLtMatmulDescInit_internal(
matmulDesc: cuda_types::cublaslt::cublasLtMatmulDesc_t,
size: usize,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
) -> cuda_types::cublas::cublasStatus_t;
#[must_use]
@@ -181,7 +181,7 @@ extern "system" {
\retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully*/
fn cublasLtMatmulDescCreate(
matmulDesc: *mut cuda_types::cublaslt::cublasLtMatmulDesc_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
) -> cuda_types::cublas::cublasStatus_t;
#[must_use]
@@ -396,7 +396,7 @@ extern "system" {
available*/
fn cublasLtMatmulAlgoGetIds(
lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
Atype: cuda_types::cublaslt::cudaDataType_t,
Btype: cuda_types::cublaslt::cudaDataType_t,
@@ -414,7 +414,7 @@ extern "system" {
\retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized*/
fn cublasLtMatmulAlgoInit(
lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
Atype: cuda_types::cublaslt::cudaDataType_t,
Btype: cuda_types::cublaslt::cudaDataType_t,

@@ -9,6 +9,7 @@ cuda_macros = { path = "../cuda_macros" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
bitflags = "2.9.1"
rocblas-sys = { path = "../ext/rocblas-sys" }
hipblaslt-sys = { path = "../ext/hipblaslt-sys" }
[target.'cfg(unix)'.dependencies]
rocm_smi-sys = { path = "../ext/rocm_smi-sys" }

@@ -363,3 +363,8 @@ impl From<rocblas_sys::rocblas_error> for cublasError_t {
Self(error.0)
}
}
impl From<hipblaslt_sys::hipblasLtError> for cublasError_t {
fn from(error: hipblaslt_sys::hipblasLtError) -> Self {
Self(error.0)
}
}

@@ -33,277 +33,6 @@ pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3: u32 = 4194304;
pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2: u32 = 8388608;
pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK: u32 = 16711680;
pub const CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN: u64 = 4294967296;
impl cublasFillMode_t {
pub const CUBLAS_FILL_MODE_LOWER: cublasFillMode_t = cublasFillMode_t(0);
}
impl cublasFillMode_t {
pub const CUBLAS_FILL_MODE_UPPER: cublasFillMode_t = cublasFillMode_t(1);
}
impl cublasFillMode_t {
pub const CUBLAS_FILL_MODE_FULL: cublasFillMode_t = cublasFillMode_t(2);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasFillMode_t(pub ::core::ffi::c_uint);
impl cublasDiagType_t {
pub const CUBLAS_DIAG_NON_UNIT: cublasDiagType_t = cublasDiagType_t(0);
}
impl cublasDiagType_t {
pub const CUBLAS_DIAG_UNIT: cublasDiagType_t = cublasDiagType_t(1);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasDiagType_t(pub ::core::ffi::c_uint);
impl cublasSideMode_t {
pub const CUBLAS_SIDE_LEFT: cublasSideMode_t = cublasSideMode_t(0);
}
impl cublasSideMode_t {
pub const CUBLAS_SIDE_RIGHT: cublasSideMode_t = cublasSideMode_t(1);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasSideMode_t(pub ::core::ffi::c_uint);
impl cublasOperation_t {
pub const CUBLAS_OP_N: cublasOperation_t = cublasOperation_t(0);
}
impl cublasOperation_t {
pub const CUBLAS_OP_T: cublasOperation_t = cublasOperation_t(1);
}
impl cublasOperation_t {
pub const CUBLAS_OP_C: cublasOperation_t = cublasOperation_t(2);
}
impl cublasOperation_t {
pub const CUBLAS_OP_HERMITAN: cublasOperation_t = cublasOperation_t(2);
}
impl cublasOperation_t {
pub const CUBLAS_OP_CONJG: cublasOperation_t = cublasOperation_t(3);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasOperation_t(pub ::core::ffi::c_uint);
impl cublasPointerMode_t {
pub const CUBLAS_POINTER_MODE_HOST: cublasPointerMode_t = cublasPointerMode_t(0);
}
impl cublasPointerMode_t {
pub const CUBLAS_POINTER_MODE_DEVICE: cublasPointerMode_t = cublasPointerMode_t(1);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasPointerMode_t(pub ::core::ffi::c_uint);
impl cublasAtomicsMode_t {
pub const CUBLAS_ATOMICS_NOT_ALLOWED: cublasAtomicsMode_t = cublasAtomicsMode_t(0);
}
impl cublasAtomicsMode_t {
pub const CUBLAS_ATOMICS_ALLOWED: cublasAtomicsMode_t = cublasAtomicsMode_t(1);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasAtomicsMode_t(pub ::core::ffi::c_uint);
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DFALT: cublasGemmAlgo_t = cublasGemmAlgo_t(-1);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DEFAULT: cublasGemmAlgo_t = cublasGemmAlgo_t(-1);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO0: cublasGemmAlgo_t = cublasGemmAlgo_t(0);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO1: cublasGemmAlgo_t = cublasGemmAlgo_t(1);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO2: cublasGemmAlgo_t = cublasGemmAlgo_t(2);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO3: cublasGemmAlgo_t = cublasGemmAlgo_t(3);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO4: cublasGemmAlgo_t = cublasGemmAlgo_t(4);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO5: cublasGemmAlgo_t = cublasGemmAlgo_t(5);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO6: cublasGemmAlgo_t = cublasGemmAlgo_t(6);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO7: cublasGemmAlgo_t = cublasGemmAlgo_t(7);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO8: cublasGemmAlgo_t = cublasGemmAlgo_t(8);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO9: cublasGemmAlgo_t = cublasGemmAlgo_t(9);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO10: cublasGemmAlgo_t = cublasGemmAlgo_t(10);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO11: cublasGemmAlgo_t = cublasGemmAlgo_t(11);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO12: cublasGemmAlgo_t = cublasGemmAlgo_t(12);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO13: cublasGemmAlgo_t = cublasGemmAlgo_t(13);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO14: cublasGemmAlgo_t = cublasGemmAlgo_t(14);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO15: cublasGemmAlgo_t = cublasGemmAlgo_t(15);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO16: cublasGemmAlgo_t = cublasGemmAlgo_t(16);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO17: cublasGemmAlgo_t = cublasGemmAlgo_t(17);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO18: cublasGemmAlgo_t = cublasGemmAlgo_t(18);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO19: cublasGemmAlgo_t = cublasGemmAlgo_t(19);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO20: cublasGemmAlgo_t = cublasGemmAlgo_t(20);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO21: cublasGemmAlgo_t = cublasGemmAlgo_t(21);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO22: cublasGemmAlgo_t = cublasGemmAlgo_t(22);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO23: cublasGemmAlgo_t = cublasGemmAlgo_t(23);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DEFAULT_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(99);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_DFALT_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(99);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO0_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(100);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO1_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(101);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO2_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(102);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO3_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(103);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO4_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(104);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO5_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(105);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO6_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(106);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO7_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(107);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO8_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(108);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO9_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(109);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO10_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(110);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO11_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(111);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO12_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(112);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO13_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(113);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO14_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(114);
}
impl cublasGemmAlgo_t {
pub const CUBLAS_GEMM_ALGO15_TENSOR_OP: cublasGemmAlgo_t = cublasGemmAlgo_t(115);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasGemmAlgo_t(pub ::core::ffi::c_int);
impl cublasMath_t {
pub const CUBLAS_DEFAULT_MATH: cublasMath_t = cublasMath_t(0);
}
impl cublasMath_t {
pub const CUBLAS_TENSOR_OP_MATH: cublasMath_t = cublasMath_t(1);
}
impl cublasMath_t {
pub const CUBLAS_PEDANTIC_MATH: cublasMath_t = cublasMath_t(2);
}
impl cublasMath_t {
pub const CUBLAS_TF32_TENSOR_OP_MATH: cublasMath_t = cublasMath_t(3);
}
impl cublasMath_t {
pub const CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION: cublasMath_t = cublasMath_t(
16,
);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasMath_t(pub ::core::ffi::c_uint);
pub use super::cuda::cudaDataType as cublasDataType_t;
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_16F: cublasComputeType_t = cublasComputeType_t(64);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_16F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(65);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32F: cublasComputeType_t = cublasComputeType_t(68);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(69);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32F_FAST_16F: cublasComputeType_t = cublasComputeType_t(74);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32F_FAST_16BF: cublasComputeType_t = cublasComputeType_t(
75,
);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32F_FAST_TF32: cublasComputeType_t = cublasComputeType_t(
77,
);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_64F: cublasComputeType_t = cublasComputeType_t(70);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_64F_PEDANTIC: cublasComputeType_t = cublasComputeType_t(71);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32I: cublasComputeType_t = cublasComputeType_t(72);
}
impl cublasComputeType_t {
pub const CUBLAS_COMPUTE_32I_PEDANTIC: cublasComputeType_t = cublasComputeType_t(73);
}
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct cublasComputeType_t(pub ::core::ffi::c_uint);
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cublasContext {
_unused: [u8; 0],
}
pub type cublasHandle_t = *mut cublasContext;
pub type cublasLogCallback = ::core::option::Option<
unsafe extern "C" fn(msg: *const ::core::ffi::c_char),
>;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cublasLtContext {
@@ -5071,7 +4800,6 @@ pub struct cublasLtMatmulPreferenceAttributes_t(pub ::core::ffi::c_uint);
Holds returned configured algo descriptor and its runtime properties.*/
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct cublasLtMatmulHeuristicResult_t {
/** Matmul algorithm descriptor.

ext/hipblaslt-sys/.rustfmt.toml (vendored, new file, 1 change)
@@ -0,0 +1 @@
disable_all_formatting = true

ext/hipblaslt-sys/Cargo.toml (vendored, new file, 10 changes)
@@ -0,0 +1,10 @@
[package]
name = "hipblaslt-sys"
version = "0.1.0"
authors = ["Violet <c01368481@gmail.com>"]
edition = "2021"
[lib]
[dependencies]
hip_runtime-sys = { version = "0.0.0", path = "../hip_runtime-sys" }

ext/hipblaslt-sys/build.rs (vendored, new file, 9 changes)
@@ -0,0 +1,9 @@
use std::env::VarError;
fn main() -> Result<(), VarError> {
if !cfg!(windows) {
println!("cargo:rustc-link-lib=dylib=hipblaslt");
println!("cargo:rustc-link-search=native=/opt/rocm/lib/");
}
Ok(())
}
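
Aside (not part of this build.rs): on Windows no cargo link directives are emitted here; instead the generated extern blocks carry a raw-dylib link attribute (the `cfg_attr` pushed in `generate_hiplaslt` later in this diff). A minimal sketch of that pattern — the `hipblasLtCreate` signature below is a simplified assumption, not the real prototype:

```rust
// Sketch only: with raw-dylib the DLL name lives on the extern block itself,
// so build.rs does not need to emit a library search path on Windows.
#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))]
extern "C" {
    // Assumed, simplified signature for illustration.
    fn hipblasLtCreate(handle: *mut *mut core::ffi::c_void) -> i32;
}

fn main() {
    // Declaration-only sketch; nothing is called.
}
```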

ext/hipblaslt-sys/src/lib.rs (vendored, new file, 1439 changes)

File diff suppressed because it is too large.

ext/rocm_smi-sys/.rustfmt.toml (vendored, new file, 1 change)
@@ -0,0 +1 @@
disable_all_formatting = true

@@ -1,360 +1,6 @@
// Generated automatically by zluda_bindgen
// DO NOT EDIT MANUALLY
#![allow(warnings)]
impl crate::CudaDisplay for cuda_types::cublaslt::cublasFillMode_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER => {
writer.write_all(stringify!(CUBLAS_FILL_MODE_LOWER).as_bytes())
}
&cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_UPPER => {
writer.write_all(stringify!(CUBLAS_FILL_MODE_UPPER).as_bytes())
}
&cuda_types::cublaslt::cublasFillMode_t::CUBLAS_FILL_MODE_FULL => {
writer.write_all(stringify!(CUBLAS_FILL_MODE_FULL).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasDiagType_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasDiagType_t::CUBLAS_DIAG_NON_UNIT => {
writer.write_all(stringify!(CUBLAS_DIAG_NON_UNIT).as_bytes())
}
&cuda_types::cublaslt::cublasDiagType_t::CUBLAS_DIAG_UNIT => {
writer.write_all(stringify!(CUBLAS_DIAG_UNIT).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasSideMode_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasSideMode_t::CUBLAS_SIDE_LEFT => {
writer.write_all(stringify!(CUBLAS_SIDE_LEFT).as_bytes())
}
&cuda_types::cublaslt::cublasSideMode_t::CUBLAS_SIDE_RIGHT => {
writer.write_all(stringify!(CUBLAS_SIDE_RIGHT).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasOperation_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_N => {
writer.write_all(stringify!(CUBLAS_OP_N).as_bytes())
}
&cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_T => {
writer.write_all(stringify!(CUBLAS_OP_T).as_bytes())
}
&cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_C => {
writer.write_all(stringify!(CUBLAS_OP_C).as_bytes())
}
&cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_HERMITAN => {
writer.write_all(stringify!(CUBLAS_OP_HERMITAN).as_bytes())
}
&cuda_types::cublaslt::cublasOperation_t::CUBLAS_OP_CONJG => {
writer.write_all(stringify!(CUBLAS_OP_CONJG).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasPointerMode_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST => {
writer.write_all(stringify!(CUBLAS_POINTER_MODE_HOST).as_bytes())
}
&cuda_types::cublaslt::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE => {
writer.write_all(stringify!(CUBLAS_POINTER_MODE_DEVICE).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasAtomicsMode_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasAtomicsMode_t::CUBLAS_ATOMICS_NOT_ALLOWED => {
writer.write_all(stringify!(CUBLAS_ATOMICS_NOT_ALLOWED).as_bytes())
}
&cuda_types::cublaslt::cublasAtomicsMode_t::CUBLAS_ATOMICS_ALLOWED => {
writer.write_all(stringify!(CUBLAS_ATOMICS_ALLOWED).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasGemmAlgo_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DFALT => {
writer.write_all(stringify!(CUBLAS_GEMM_DFALT).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT => {
writer.write_all(stringify!(CUBLAS_GEMM_DEFAULT).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO0 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO0).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO1 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO1).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO2 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO2).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO3 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO3).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO4 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO4).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO5 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO5).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO6 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO6).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO7 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO7).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO8 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO8).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO9 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO9).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO10 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO10).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO11 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO11).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO12 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO12).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO13 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO13).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO14 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO14).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO15 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO15).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO16 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO16).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO17 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO17).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO18 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO18).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO19 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO19).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO20 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO20).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO21 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO21).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO22 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO22).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO23 => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO23).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_DEFAULT_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_DFALT_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_DFALT_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO0_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO0_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO1_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO1_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO2_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO2_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO3_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO3_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO4_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO4_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO5_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO5_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO6_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO6_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO7_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO7_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO8_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO8_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO9_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO9_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO10_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO10_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO11_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO11_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO12_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO12_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO13_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO13_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO14_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO14_TENSOR_OP).as_bytes())
}
&cuda_types::cublaslt::cublasGemmAlgo_t::CUBLAS_GEMM_ALGO15_TENSOR_OP => {
writer.write_all(stringify!(CUBLAS_GEMM_ALGO15_TENSOR_OP).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasMath_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasMath_t::CUBLAS_DEFAULT_MATH => {
writer.write_all(stringify!(CUBLAS_DEFAULT_MATH).as_bytes())
}
&cuda_types::cublaslt::cublasMath_t::CUBLAS_TENSOR_OP_MATH => {
writer.write_all(stringify!(CUBLAS_TENSOR_OP_MATH).as_bytes())
}
&cuda_types::cublaslt::cublasMath_t::CUBLAS_PEDANTIC_MATH => {
writer.write_all(stringify!(CUBLAS_PEDANTIC_MATH).as_bytes())
}
&cuda_types::cublaslt::cublasMath_t::CUBLAS_TF32_TENSOR_OP_MATH => {
writer.write_all(stringify!(CUBLAS_TF32_TENSOR_OP_MATH).as_bytes())
}
&cuda_types::cublaslt::cublasMath_t::CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION => {
writer
.write_all(
stringify!(CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)
.as_bytes(),
)
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasComputeType_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
match self {
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_16F => {
writer.write_all(stringify!(CUBLAS_COMPUTE_16F).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_16F_PEDANTIC => {
writer.write_all(stringify!(CUBLAS_COMPUTE_16F_PEDANTIC).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32F).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_PEDANTIC => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32F_PEDANTIC).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_16F).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_16BF).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32F_FAST_TF32).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_64F => {
writer.write_all(stringify!(CUBLAS_COMPUTE_64F).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_64F_PEDANTIC => {
writer.write_all(stringify!(CUBLAS_COMPUTE_64F_PEDANTIC).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32I => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32I).as_bytes())
}
&cuda_types::cublaslt::cublasComputeType_t::CUBLAS_COMPUTE_32I_PEDANTIC => {
writer.write_all(stringify!(CUBLAS_COMPUTE_32I_PEDANTIC).as_bytes())
}
_ => write!(writer, "{}", self.0),
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasHandle_t {
fn write(
&self,
_fn_name: &'static str,
_index: usize,
writer: &mut (impl std::io::Write + ?Sized),
) -> std::io::Result<()> {
if self.is_null() {
writer.write_all(b"NULL")
} else {
write!(writer, "{:p}", *self)
}
}
}
impl crate::CudaDisplay for cuda_types::cublaslt::cublasLtHandle_t {
fn write(
&self,
@@ -3558,7 +3204,7 @@ pub fn write_cublasLtMatmulDescInit_internal(
writer: &mut (impl std::io::Write + ?Sized),
matmulDesc: cuda_types::cublaslt::cublasLtMatmulDesc_t,
size: usize,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
) -> std::io::Result<()> {
let mut arg_idx = 0usize;
@@ -3602,7 +3248,7 @@ pub fn write_cublasLtMatmulDescInit_internal(
pub fn write_cublasLtMatmulDescCreate(
writer: &mut (impl std::io::Write + ?Sized),
matmulDesc: *mut cuda_types::cublaslt::cublasLtMatmulDesc_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
) -> std::io::Result<()> {
let mut arg_idx = 0usize;
@@ -4405,7 +4051,7 @@ pub fn write_cublasLtMatmulAlgoGetHeuristic(
pub fn write_cublasLtMatmulAlgoGetIds(
writer: &mut (impl std::io::Write + ?Sized),
lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
Atype: cuda_types::cublaslt::cudaDataType_t,
Btype: cuda_types::cublaslt::cudaDataType_t,
@@ -4485,7 +4131,7 @@ pub fn write_cublasLtMatmulAlgoGetIds(
pub fn write_cublasLtMatmulAlgoInit(
writer: &mut (impl std::io::Write + ?Sized),
lightHandle: cuda_types::cublaslt::cublasLtHandle_t,
computeType: cuda_types::cublaslt::cublasComputeType_t,
computeType: cuda_types::cublas::cublasComputeType_t,
scaleType: cuda_types::cublaslt::cudaDataType_t,
Atype: cuda_types::cublaslt::cudaDataType_t,
Btype: cuda_types::cublaslt::cudaDataType_t,

@@ -29,6 +29,10 @@ fn main() {
&["..", "ext", "hip_runtime-sys", "src", "lib.rs"],
);
generate_rocblas(&crate_root, &["..", "ext", "rocblas-sys", "src", "lib.rs"]);
generate_hiplaslt(
&crate_root,
&["..", "ext", "hipblaslt-sys", "src", "lib.rs"],
);
generate_rocm_smi(&crate_root, &["..", "ext", "rocm_smi-sys", "src", "lib.rs"]);
let cuda_functions = generate_cuda(&crate_root);
generate_process_address_table(&crate_root, cuda_functions);
@@ -172,7 +176,7 @@ fn generate_cufft(crate_root: &PathBuf) {
new_error_type: "cufftError_t",
error_prefix: ("CUFFT_", "ERROR_"),
success: ("CUFFT_SUCCESS", "SUCCESS"),
hip_type: None,
hip_types: vec![],
};
generate_types_library(
Some(&result_options),
@@ -185,6 +189,7 @@ fn generate_cufft(crate_root: &PathBuf) {
generate_display_perflib(
Some(&result_options),
&crate_root,
None,
&["..", "format", "src", "format_generated_fft.rs"],
&["cuda_types", "cufft"],
&module,
@@ -239,7 +244,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
new_error_type: "cusparseError_t",
error_prefix: ("CUSPARSE_STATUS_", "ERROR_"),
success: ("CUSPARSE_STATUS_SUCCESS", "SUCCESS"),
hip_type: None,
hip_types: vec![],
};
generate_types_library(
Some(&result_options),
@@ -252,6 +257,7 @@ fn generate_cusparse(crate_root: &PathBuf) {
generate_display_perflib(
Some(&result_options),
&crate_root,
None,
&["..", "format", "src", "format_generated_sparse.rs"],
&["cuda_types", "cusparse"],
&module,
@@ -277,7 +283,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
new_error_type: "cudnnError_",
error_prefix: ("CUDNN_STATUS_", "ERROR_"),
success: ("CUDNN_STATUS_SUCCESS", "SUCCESS"),
hip_type: None,
hip_types: vec![],
};
let cudnn9_module: syn::File = syn::parse_str(&cudnn9).unwrap();
let cudnn9_types = generate_types_library_impl(Some(&result_options), &cudnn9_module);
@@ -322,6 +328,7 @@ fn generate_cudnn(crate_root: &PathBuf) {
generate_display_perflib(
Some(&result_options),
&crate_root,
None,
&["..", "format", "src", "format_generated_dnn9.rs"],
&["cuda_types", "cudnn9"],
&cudnn9_module,
@@ -680,7 +687,10 @@ fn generate_cublas(crate_root: &PathBuf) {
new_error_type: "cublasError_t",
error_prefix: ("CUBLAS_STATUS_", "ERROR_"),
success: ("CUBLAS_STATUS_SUCCESS", "SUCCESS"),
hip_type: Some(syn::parse_str("rocblas_sys::rocblas_error").unwrap()),
hip_types: vec![
syn::parse_str("rocblas_sys::rocblas_error").unwrap(),
syn::parse_str("hipblaslt_sys::hipblasLtError").unwrap(),
],
};
generate_types_library(
Some(&result_options),
@@ -693,6 +703,7 @@ fn generate_cublas(crate_root: &PathBuf) {
generate_display_perflib(
Some(&result_options),
&crate_root,
None,
&["..", "format", "src", "format_generated_blas.rs"],
&["cuda_types", "cublas"],
&module,
@@ -717,7 +728,7 @@ fn remove_type(module: &mut syn::File, type_name: &str) {
fn generate_cublaslt(crate_root: &PathBuf) {
let cublaslt_header = new_builder()
.header("/usr/local/cuda/include/cublasLt.h")
.allowlist_type("^cublas.*")
.allowlist_type("^cublasLt.*")
.allowlist_function("^cublasLt.*")
.allowlist_var("^CUBLASLT_.*")
.must_use_type("cublasStatus_t")
@@ -749,8 +760,7 @@ fn generate_cublaslt(crate_root: &PathBuf) {
cublaslt_internal_header,
)
.unwrap();
let mut module_blas: syn::File = syn::parse_str(&cublaslt_header).unwrap();
remove_type(&mut module_blas, "cublasStatus_t");
let module_blas: syn::File = syn::parse_str(&cublaslt_header).unwrap();
generate_functions(
&crate_root,
"cublaslt",
@@ -768,6 +778,7 @@ fn generate_cublaslt(crate_root: &PathBuf) {
generate_display_perflib(
None,
&crate_root,
Some(LibraryOverride::CuBlasLt),
&["..", "format", "src", "format_generated_blaslt.rs"],
&["cuda_types", "cublaslt"],
&module_blas,
@@ -775,6 +786,7 @@ fn generate_cublaslt(crate_root: &PathBuf) {
generate_display_perflib(
None,
&crate_root,
Some(LibraryOverride::CuBlasLt),
&["..", "format", "src", "format_generated_blaslt_internal.rs"],
&["cuda_types", "cublaslt"],
&module_blaslt_internal,
@@ -816,7 +828,7 @@ fn generate_cuda(crate_root: &PathBuf) -> Vec<Ident> {
new_error_type: "CUerror",
error_prefix: ("CUDA_ERROR_", "ERROR_"),
success: ("CUDA_SUCCESS", "SUCCESS"),
hip_type: Some(syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()),
hip_types: vec![syn::parse_str("hip_runtime_sys::hipErrorCode_t").unwrap()],
};
generate_types_cuda(
&result_options,
@@ -862,7 +874,7 @@ fn generate_ml(crate_root: &PathBuf) {
new_error_type: "nvmlError_t",
error_prefix: ("NVML_ERROR_", "ERROR_"),
success: ("NVML_SUCCESS", "SUCCESS"),
hip_type: None,
hip_types: vec![],
};
let suffix =
"#[cfg(unix)]
@@ -893,6 +905,7 @@ impl From<rocm_smi_sys::rsmi_error> for nvmlError_t {
generate_display_perflib(
Some(&result_options),
&crate_root,
None,
&["..", "format", "src", "format_generated_nvml.rs"],
&["cuda_types", "nvml"],
&module,
@@ -1002,7 +1015,7 @@ fn generate_hip_runtime(output: &PathBuf, path: &[&str]) {
new_error_type: "hipErrorCode_t",
error_prefix: ("hipError", "Error"),
success: ("hipSuccess", "Success"),
hip_type: None,
hip_types: vec![],
});
module.items = converter.convert(module.items).collect::<Vec<Item>>();
converter.flush(&mut module.items);
@@ -1044,7 +1057,7 @@ fn generate_rocblas(output: &PathBuf, path: &[&str]) {
new_error_type: "rocblas_error",
error_prefix: ("rocblas_status_", "error_"),
success: ("rocblas_status_success", "success"),
hip_type: None,
hip_types: vec![],
};
let mut converter = ConvertIntoRustResult::new(result_options);
module.items = converter
@@ -1069,6 +1082,59 @@ fn generate_rocblas(output: &PathBuf, path: &[&str]) {
write_rust_to_file(output, text)
}
fn generate_hiplaslt(output: &PathBuf, path: &[&str]) {
let rocblas_header = new_builder()
.header("/opt/rocm/include/hipblaslt/hipblaslt.h")
.allowlist_type("^hipblasLt.*")
.allowlist_type("hipblasOperation_t")
.allowlist_function("^hipblasLt.*")
.allowlist_var("^hipblasLt.*")
.must_use_type("hipblasStatus_t")
.constified_enum("hipblasStatus_t")
.new_type_alias("^hipblasLtHandle_t$")
.new_type_alias("^hipblasLtMatmulDesc_t$")
.new_type_alias("^hipblasLtMatmulPreference_t$")
.new_type_alias("^hipblasLtMatrixLayout_t$")
.clang_args(["-I/opt/rocm/include", "-D__HIP_PLATFORM_AMD__", "-x", "c++"])
.generate()
.unwrap()
.to_string();
let mut module: syn::File = syn::parse_str(&rocblas_header).unwrap();
remove_type(&mut module, "hipStream_t");
remove_type(&mut module, "ihipStream_t");
let result_options = ConvertIntoRustResultOptions {
type_: "hipblasStatus_t",
underlying_type: "hipblasStatus_t",
new_error_type: "hipblasLtError",
error_prefix: ("HIPBLAS_STATUS_", "ERROR_"),
success: ("HIPBLAS_STATUS_SUCCESS", "SUCCESS"),
hip_types: vec![],
};
let mut converter = ConvertIntoRustResult::new(result_options);
module.items = converter
.convert(module.items)
.map(|item| match item {
Item::ForeignMod(mut extern_) => {
extern_.attrs.push(
parse_quote!(#[cfg_attr(windows, link(name = "hipblaslt", kind = "raw-dylib"))]),
);
Item::ForeignMod(extern_)
}
item => item,
})
.collect();
converter.flush(&mut module.items);
add_send_sync(&mut module.items, &["hipblasLtHandle_t"]);
add_send_sync(&mut module.items, &["hipblasLtMatmulDesc_t"]);
add_send_sync(&mut module.items, &["hipblasLtMatmulPreference_t"]);
add_send_sync(&mut module.items, &["hipblasLtMatrixLayout_t"]);
let mut output = output.clone();
output.extend(path);
let text =
&prettyplease::unparse(&module).replace("hipStream_t", "hip_runtime_sys::hipStream_t");
write_rust_to_file(output, text)
}
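
For readers unfamiliar with these bindgen options: `new_type_alias` wraps each listed handle alias in a `#[repr(transparent)]` newtype, and `add_send_sync` then marks those pointer newtypes as thread-safe. A rough sketch of the expected shape for one handle — the inner pointer type is an assumption, not copied from the generated `lib.rs`:

```rust
// Rough sketch (assumed inner type) of what new_type_alias is expected to emit.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct hipblasLtHandle_t(pub *mut ::core::ffi::c_void);

// Raw-pointer newtypes are not Send/Sync by default, so add_send_sync appends
// explicit impls for each handle type listed above.
unsafe impl Send for hipblasLtHandle_t {}
unsafe impl Sync for hipblasLtHandle_t {}
```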
fn generate_rocm_smi(output: &PathBuf, path: &[&str]) {
let rocm_smi_header = new_builder()
.header("/opt/rocm/include/rocm_smi/rocm_smi.h")
@@ -1088,7 +1154,7 @@ fn generate_rocm_smi(output: &PathBuf, path: &[&str]) {
new_error_type: "rsmi_error",
error_prefix: ("RSMI_STATUS_", "ERROR_"),
success: ("RSMI_STATUS_SUCCESS", "SUCCESS"),
hip_type: None,
hip_types: vec![],
};
let mut converter = ConvertIntoRustResult::new(result_options);
module.items = converter.convert(module.items).collect();
@@ -1114,7 +1180,7 @@ fn add_send_sync(items: &mut Vec<Item>, arg: &[&str]) {
fn generate_functions(
output: &PathBuf,
submodule: &str,
submodule_str: &str,
path: &[&str],
module: &syn::File,
) -> syn::File {
@@ -1141,7 +1207,7 @@ fn generate_functions(
#(#fns_)*
}
};
let submodule = Ident::new(submodule, Span::call_site());
let submodule = Ident::new(submodule_str, Span::call_site());
syn::visit_mut::visit_file_mut(
&mut PrependCudaPath {
module: vec![Ident::new("cuda_types", Span::call_site()), submodule],
@@ -1152,7 +1218,15 @@ fn generate_functions(
syn::visit_mut::visit_file_mut(&mut ExplicitReturnType, &mut module);
let mut output = output.clone();
output.extend(path);
write_rust_to_file(output, &prettyplease::unparse(&module));
let text = prettyplease::unparse(&module);
let text = match submodule_str {
"cublaslt" => text.replace(
"cuda_types::cublaslt::cublasComputeType_t",
"cuda_types::cublas::cublasComputeType_t",
),
_ => text,
};
write_rust_to_file(output, &text);
module
/*
module
@@ -1239,7 +1313,7 @@ struct ConvertIntoRustResultOptions {
error_prefix: (&'static str, &'static str),
success: (&'static str, &'static str),
// TODO: this should no longer be an Option once all hip perf libraries are present
hip_type: Option<Path>,
hip_types: Vec<Path>,
}
struct ConvertIntoRustResult {
@@ -1326,7 +1400,7 @@ impl ConvertIntoRustResult {
};
};
items.extend(extra_items);
if let Some(hip_error_path) = self.options.hip_type {
for hip_error_path in self.options.hip_types {
items.push(
parse_quote! {impl From<#hip_error_path> for #new_error_type {
fn from(error: #hip_error_path) -> Self {
@@ -1495,6 +1569,7 @@ fn generate_display_cuda(
fn generate_display_perflib(
result_options: Option<&ConvertIntoRustResultOptions>,
output: &PathBuf,
override_: Option<LibraryOverride>,
path: &[&str],
types_crate: &[&'static str],
module: &syn::File,
@@ -1539,14 +1614,20 @@ fn generate_display_perflib(
}
let mut output = output.clone();
output.extend(path);
write_rust_to_file(
output,
&prettyplease::unparse(&syn::File {
let text = prettyplease::unparse(&syn::File {
shebang: None,
attrs: Vec::new(),
items,
}),
);
});
let text = match override_ {
None => text,
Some(LibraryOverride::CuBlasLt) => text.replace(
"cuda_types::cublaslt::cublasComputeType_t",
"cuda_types::cublas::cublasComputeType_t",
),
Some(LibraryOverride::CuFft) => text,
};
write_rust_to_file(output, &text);
}
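
The `LibraryOverride` enum itself is not visible in the hunks shown here; inferred from the `Some(LibraryOverride::CuBlasLt)` call sites and the match above, it presumably looks roughly like this (a sketch, not the actual definition):

```rust
// Assumed shape, inferred from usage; the real definition is elsewhere in the commit.
#[derive(Debug, Clone, Copy)]
enum LibraryOverride {
    CuBlasLt,
    CuFft,
}

fn main() {
    for lib in [LibraryOverride::CuBlasLt, LibraryOverride::CuFft] {
        println!("{:?}", lib);
    }
}
```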
struct DeriveDisplayState<'a> {