From 2e35d157ceb1b8eb1e7af654ced39027bd8c715e Mon Sep 17 00:00:00 2001
From: Andrzej Janik
Date: Tue, 16 Sep 2025 18:56:30 +0000
Subject: [PATCH] Add nvmlDeviceGetHandleByPciBusId_v2

---
Notes: this patch was edited by hand after export. The commit id on the first
line and the blob ids on the `index` lines still describe the original revision.

 zluda_ml/src/impl_unix.rs | 133 ++++++++++++++++++++++++++++++++++++++
 zluda_ml/src/impl_win.rs  |  12 ++--
 zluda_ml/src/lib.rs       |   1 +
 3 files changed, 142 insertions(+), 4 deletions(-)

diff --git a/zluda_ml/src/impl_unix.rs b/zluda_ml/src/impl_unix.rs
index 93d04e3..55437a6 100644
--- a/zluda_ml/src/impl_unix.rs
+++ b/zluda_ml/src/impl_unix.rs
@@ -43,6 +43,96 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
     rsmi_num_monitor_devices(device_count)
 }
 
+/// Resolves `pci_bus_id` to the handle of the monitored device with that address.
+///
+/// Fails with `INVALID_ARGUMENT` when the string cannot be parsed, propagates
+/// any error from the rsmi calls, and returns `ERROR_NOT_FOUND` when no device
+/// reports a matching BDF id.
+pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
+    pci_bus_id: &std::ffi::CStr,
+    device: &mut cuda_types::nvml::nvmlDevice_t,
+) -> nvmlReturn_t {
+    let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
+    let bdfid = pci.to_bdfid();
+    let mut device_count = 0;
+    rsmi_num_monitor_devices(&mut device_count)?;
+    for dv_ind in 0..device_count {
+        let mut curr_bdfid = 0;
+        rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
+        if curr_bdfid == bdfid {
+            *device = Device { _index: dv_ind }.wrap();
+            return nvmlReturn_t::SUCCESS;
+        }
+    }
+    nvmlReturn_t::ERROR_NOT_FOUND
+}
+
+/// PCI address parsed from a `[domain:]bus:device[.function]` string.
+#[derive(Clone, Copy)]
+struct PciBusId {
+    domain: u32,
+    bus: u8,
+    device: u8,
+    function: u8,
+}
+impl PciBusId {
+    /// Packs the address as `domain << 32 | bus << 8 | device << 3 | function`.
+    ///
+    /// Fields are shifted as-is; nothing here range-checks `device` or `function`.
+    fn to_bdfid(self) -> u64 {
+        ((self.domain as u64) << 32)
+            | ((self.bus as u64) << 8)
+            | ((self.device as u64) << 3)
+            | (self.function as u64)
+    }
+}
+
+/// Parses `[domain:]bus:device[.function]`, where every field is hexadecimal.
+///
+/// The domain may be up to eight digits long, so the four-digit and the
+/// eight-digit spellings both parse. An omitted domain or function is zero.
+fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
+    let s = id.to_str().ok()?.trim();
+    // Two ':'-separated fields mean the domain was omitted; three include it.
+    let mut fields = s.split(':');
+    let (domain, bus_part, tail) = match (fields.next()?, fields.next()?, fields.next()) {
+        (bus, tail, None) => (0, bus, tail),
+        (domain, bus, Some(tail)) => (hex_u32(domain)?, bus, tail),
+    };
+    if fields.next().is_some() {
+        return None;
+    }
+    let mut dev_func = tail.split('.');
+    let dev_part = dev_func.next()?;
+    let func_part = dev_func.next();
+    let function = match func_part {
+        Some(f) => hex_u8(f)?,
+        None => 0,
+    };
+    Some(PciBusId {
+        domain,
+        bus: hex_u8(bus_part)?,
+        device: hex_u8(dev_part)?,
+        function,
+    })
+}
+
+/// Parses a hexadecimal string of at most eight characters into a `u32`.
+fn hex_u32(s: &str) -> Option<u32> {
+    if s.len() > 8 {
+        return None;
+    }
+    u32::from_str_radix(s, 16).ok()
+}
+
+/// Parses a hexadecimal string of at most two characters into a `u8`.
+fn hex_u8(s: &str) -> Option<u8> {
+    if s.len() > 2 {
+        return None;
+    }
+    u8::from_str_radix(s, 16).ok()
+}
+
 pub(crate) unsafe fn device_get_field_values(
     _device: &Device,
     values_count: ::core::ffi::c_int,
@@ -75,3 +165,46 @@ pub(crate) fn device_get_handle_by_index_v2(
     *device = Device { _index: index }.wrap();
     nvmlReturn_t::SUCCESS
 }
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn parse_pci_bus_id_full() {
+        let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
+        let parsed = super::parse_pci_bus_id(&id).unwrap();
+        assert_eq!(parsed.domain, 0x0100);
+        assert_eq!(parsed.bus, 0x65);
+        assert_eq!(parsed.device, 0xa0);
+        assert_eq!(parsed.function, 0xf);
+    }
+
+    #[test]
+    fn parse_pci_bus_id_no_func() {
+        let id = std::ffi::CString::new("0100:65:a0").unwrap();
+        let parsed = super::parse_pci_bus_id(&id).unwrap();
+        assert_eq!(parsed.domain, 0x0100);
+        assert_eq!(parsed.bus, 0x65);
+        assert_eq!(parsed.device, 0xa0);
+        assert_eq!(parsed.function, 0);
+    }
+
+    #[test]
+    fn parse_pci_bus_id_no_domain() {
+        let id = std::ffi::CString::new("65:a0.f").unwrap();
+        let parsed = super::parse_pci_bus_id(&id).unwrap();
+        assert_eq!(parsed.domain, 0);
+        assert_eq!(parsed.bus, 0x65);
+        assert_eq!(parsed.device, 0xa0);
+        assert_eq!(parsed.function, 0xf);
+    }
+
+    #[test]
+    fn parse_pci_bus_id_long_domain() {
+        let id = std::ffi::CString::new("00010000:65:1f.7").unwrap();
+        let parsed = super::parse_pci_bus_id(&id).unwrap();
+        assert_eq!(parsed.domain, 0x0001_0000);
+        assert_eq!(parsed.bus, 0x65);
+        assert_eq!(parsed.device, 0x1f);
+        assert_eq!(parsed.function, 0x7);
+    }
+}
diff --git a/zluda_ml/src/impl_win.rs b/zluda_ml/src/impl_win.rs
index 35f0dfc..205e792 100644
--- a/zluda_ml/src/impl_win.rs
+++ b/zluda_ml/src/impl_win.rs
@@ -23,6 +23,14 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
     crate::impl_common::unimplemented()
 }
 
+/// Windows stub: ignores both arguments and returns `impl_common::unimplemented()`.
+pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
+    _pci_bus_id: &std::ffi::CStr,
+    _device: &mut cuda_types::nvml::nvmlDevice_t,
+) -> nvmlReturn_t {
+    crate::impl_common::unimplemented()
+}
+
 pub(crate) unsafe fn device_get_field_values(
     _device: cuda_types::nvml::nvmlDevice_t,
     _values_count: ::core::ffi::c_int,
@@ -31,10 +39,6 @@ pub(crate) unsafe fn device_get_field_values(
     crate::impl_common::unimplemented()
 }
 
-unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
-    crate::impl_common::unimplemented()
-}
-
 pub(crate) unsafe fn device_get_gpu_fabric_info(
     _device: cuda_types::nvml::nvmlDevice_t,
     _gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
diff --git a/zluda_ml/src/lib.rs b/zluda_ml/src/lib.rs
index fe8271c..40a7e30 100644
--- a/zluda_ml/src/lib.rs
+++ b/zluda_ml/src/lib.rs
@@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
     nvmlDeviceGetFieldValues,
     nvmlDeviceGetGpuFabricInfo,
     nvmlDeviceGetHandleByIndex_v2,
+    nvmlDeviceGetHandleByPciBusId_v2,
     nvmlInit,
     nvmlInitWithFlags,
     nvmlInit_v2,