mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-26 19:29:05 +00:00
Add nvmlDeviceGetHandleByPciBusId_v2
This commit is contained in:
parent
c1adc216fe
commit
2e35d157ce
3 changed files with 121 additions and 4 deletions
|
@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
|
|||
rsmi_num_monitor_devices(device_count)
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||
pci_bus_id: &std::ffi::CStr,
|
||||
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||
) -> nvmlReturn_t {
|
||||
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
|
||||
let bdfid = pci.to_bdfid();
|
||||
let mut device_count = 0;
|
||||
rsmi_num_monitor_devices(&mut device_count)?;
|
||||
for dv_ind in 0..device_count {
|
||||
let mut curr_bdfid = 0;
|
||||
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
|
||||
if curr_bdfid == bdfid {
|
||||
*device = Device { _index: dv_ind }.wrap();
|
||||
return nvmlReturn_t::SUCCESS;
|
||||
}
|
||||
}
|
||||
nvmlReturn_t::ERROR_NOT_FOUND
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct PciBusId {
|
||||
domain: u16,
|
||||
bus: u8,
|
||||
device: u8,
|
||||
function: u8,
|
||||
}
|
||||
impl PciBusId {
|
||||
fn to_bdfid(self) -> u64 {
|
||||
((self.domain as u64) << 32)
|
||||
| ((self.bus as u64) << 8)
|
||||
| ((self.device as u64) << 3)
|
||||
| (self.function as u64)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
|
||||
let s = id.to_str().ok()?.trim();
|
||||
let mut domain: u16 = 0;
|
||||
let mut rest = s;
|
||||
if let Some(colon1) = s.find(':') {
|
||||
if colon1 == 4 {
|
||||
domain = hex_u16(&s[..4])?;
|
||||
rest = &s[5..];
|
||||
}
|
||||
}
|
||||
let mut parts = rest.split(':');
|
||||
let bus_part = parts.next()?;
|
||||
let tail = parts.next()?;
|
||||
if parts.next().is_some() {
|
||||
return None;
|
||||
}
|
||||
let mut dev_func = tail.split('.');
|
||||
let dev_part = dev_func.next()?;
|
||||
let func_part = dev_func.next();
|
||||
let function = match func_part {
|
||||
Some(f) => hex_u8(f)?,
|
||||
None => 0,
|
||||
};
|
||||
Some(PciBusId {
|
||||
domain,
|
||||
bus: hex_u8(bus_part)?,
|
||||
device: hex_u8(dev_part)?,
|
||||
function,
|
||||
})
|
||||
}
|
||||
|
||||
fn hex_u16(s: &str) -> Option<u16> {
|
||||
if s.len() > 4 {
|
||||
return None;
|
||||
}
|
||||
u16::from_str_radix(s, 16).ok()
|
||||
}
|
||||
|
||||
fn hex_u8(s: &str) -> Option<u8> {
|
||||
if s.len() > 2 {
|
||||
return None;
|
||||
}
|
||||
u8::from_str_radix(s, 16).ok()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_field_values(
|
||||
_device: &Device,
|
||||
values_count: ::core::ffi::c_int,
|
||||
|
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
|
|||
*device = Device { _index: index }.wrap();
|
||||
nvmlReturn_t::SUCCESS
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn parse_pci_bus_id_full() {
|
||||
let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
|
||||
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||
assert_eq!(parsed.domain, 0x0100);
|
||||
assert_eq!(parsed.bus, 0x65);
|
||||
assert_eq!(parsed.device, 0xa0);
|
||||
assert_eq!(parsed.function, 0xf);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_pci_bus_id_no_func() {
|
||||
let id = std::ffi::CString::new("0100:65:a0").unwrap();
|
||||
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||
assert_eq!(parsed.domain, 0x0100);
|
||||
assert_eq!(parsed.bus, 0x65);
|
||||
assert_eq!(parsed.device, 0xa0);
|
||||
assert_eq!(parsed.function, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_pci_bus_id_no_domain() {
|
||||
let id = std::ffi::CString::new("65:a0.f").unwrap();
|
||||
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||
assert_eq!(parsed.domain, 0);
|
||||
assert_eq!(parsed.bus, 0x65);
|
||||
assert_eq!(parsed.device, 0xa0);
|
||||
assert_eq!(parsed.function, 0xf);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
|
|||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||
pci_bus_id: &std::ffi::CStr,
|
||||
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||
) -> nvmlReturn_t {
|
||||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_field_values(
|
||||
_device: cuda_types::nvml::nvmlDevice_t,
|
||||
_values_count: ::core::ffi::c_int,
|
||||
|
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
|
|||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
|
||||
crate::impl_common::unimplemented()
|
||||
}
|
||||
|
||||
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
||||
_device: cuda_types::nvml::nvmlDevice_t,
|
||||
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
||||
|
|
|
@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
|
|||
nvmlDeviceGetFieldValues,
|
||||
nvmlDeviceGetGpuFabricInfo,
|
||||
nvmlDeviceGetHandleByIndex_v2,
|
||||
nvmlDeviceGetHandleByPciBusId_v2,
|
||||
nvmlInit,
|
||||
nvmlInitWithFlags,
|
||||
nvmlInit_v2,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue