mirror of
https://github.com/vosen/ZLUDA.git
synced 2025-09-28 20:29:11 +00:00
Add nvmlDeviceGetHandleByPciBusId_v2
This commit is contained in:
parent
c1adc216fe
commit
2e35d157ce
3 changed files with 121 additions and 4 deletions
|
@ -43,6 +43,86 @@ pub(crate) unsafe fn device_get_count_v2(device_count: &mut ::core::ffi::c_uint)
|
||||||
rsmi_num_monitor_devices(device_count)
|
rsmi_num_monitor_devices(device_count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||||
|
pci_bus_id: &std::ffi::CStr,
|
||||||
|
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||||
|
) -> nvmlReturn_t {
|
||||||
|
let pci = parse_pci_bus_id(pci_bus_id).ok_or(nvmlError_t::INVALID_ARGUMENT)?;
|
||||||
|
let bdfid = pci.to_bdfid();
|
||||||
|
let mut device_count = 0;
|
||||||
|
rsmi_num_monitor_devices(&mut device_count)?;
|
||||||
|
for dv_ind in 0..device_count {
|
||||||
|
let mut curr_bdfid = 0;
|
||||||
|
rsmi_dev_pci_id_get(dv_ind, &mut curr_bdfid)?;
|
||||||
|
if curr_bdfid == bdfid {
|
||||||
|
*device = Device { _index: dv_ind }.wrap();
|
||||||
|
return nvmlReturn_t::SUCCESS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nvmlReturn_t::ERROR_NOT_FOUND
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct PciBusId {
|
||||||
|
domain: u16,
|
||||||
|
bus: u8,
|
||||||
|
device: u8,
|
||||||
|
function: u8,
|
||||||
|
}
|
||||||
|
impl PciBusId {
|
||||||
|
fn to_bdfid(self) -> u64 {
|
||||||
|
((self.domain as u64) << 32)
|
||||||
|
| ((self.bus as u64) << 8)
|
||||||
|
| ((self.device as u64) << 3)
|
||||||
|
| (self.function as u64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_pci_bus_id(id: &std::ffi::CStr) -> Option<PciBusId> {
|
||||||
|
let s = id.to_str().ok()?.trim();
|
||||||
|
let mut domain: u16 = 0;
|
||||||
|
let mut rest = s;
|
||||||
|
if let Some(colon1) = s.find(':') {
|
||||||
|
if colon1 == 4 {
|
||||||
|
domain = hex_u16(&s[..4])?;
|
||||||
|
rest = &s[5..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut parts = rest.split(':');
|
||||||
|
let bus_part = parts.next()?;
|
||||||
|
let tail = parts.next()?;
|
||||||
|
if parts.next().is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let mut dev_func = tail.split('.');
|
||||||
|
let dev_part = dev_func.next()?;
|
||||||
|
let func_part = dev_func.next();
|
||||||
|
let function = match func_part {
|
||||||
|
Some(f) => hex_u8(f)?,
|
||||||
|
None => 0,
|
||||||
|
};
|
||||||
|
Some(PciBusId {
|
||||||
|
domain,
|
||||||
|
bus: hex_u8(bus_part)?,
|
||||||
|
device: hex_u8(dev_part)?,
|
||||||
|
function,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hex_u16(s: &str) -> Option<u16> {
|
||||||
|
if s.len() > 4 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
u16::from_str_radix(s, 16).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hex_u8(s: &str) -> Option<u8> {
|
||||||
|
if s.len() > 2 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
u8::from_str_radix(s, 16).ok()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_field_values(
|
pub(crate) unsafe fn device_get_field_values(
|
||||||
_device: &Device,
|
_device: &Device,
|
||||||
values_count: ::core::ffi::c_int,
|
values_count: ::core::ffi::c_int,
|
||||||
|
@ -75,3 +155,36 @@ pub(crate) fn device_get_handle_by_index_v2(
|
||||||
*device = Device { _index: index }.wrap();
|
*device = Device { _index: index }.wrap();
|
||||||
nvmlReturn_t::SUCCESS
|
nvmlReturn_t::SUCCESS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_full() {
|
||||||
|
let id = std::ffi::CString::new("0100:65:a0.f").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0x0100);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0xf);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_no_func() {
|
||||||
|
let id = std::ffi::CString::new("0100:65:a0").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0x0100);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_pci_bus_id_no_domain() {
|
||||||
|
let id = std::ffi::CString::new("65:a0.f").unwrap();
|
||||||
|
let parsed = super::parse_pci_bus_id(&id).unwrap();
|
||||||
|
assert_eq!(parsed.domain, 0);
|
||||||
|
assert_eq!(parsed.bus, 0x65);
|
||||||
|
assert_eq!(parsed.device, 0xa0);
|
||||||
|
assert_eq!(parsed.function, 0xf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,13 @@ pub(crate) unsafe fn device_get_count_v2(_device_count: &mut ::core::ffi::c_uint
|
||||||
crate::impl_common::unimplemented()
|
crate::impl_common::unimplemented()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) unsafe fn device_get_handle_by_pci_bus_id_v2(
|
||||||
|
pci_bus_id: &std::ffi::CStr,
|
||||||
|
device: &mut cuda_types::nvml::nvmlDevice_t,
|
||||||
|
) -> nvmlReturn_t {
|
||||||
|
crate::impl_common::unimplemented()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_field_values(
|
pub(crate) unsafe fn device_get_field_values(
|
||||||
_device: cuda_types::nvml::nvmlDevice_t,
|
_device: cuda_types::nvml::nvmlDevice_t,
|
||||||
_values_count: ::core::ffi::c_int,
|
_values_count: ::core::ffi::c_int,
|
||||||
|
@ -31,10 +38,6 @@ pub(crate) unsafe fn device_get_field_values(
|
||||||
crate::impl_common::unimplemented()
|
crate::impl_common::unimplemented()
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn get_field_value(_field: &mut nvmlFieldValue_st) -> Result<(), nvmlError_t> {
|
|
||||||
crate::impl_common::unimplemented()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
pub(crate) unsafe fn device_get_gpu_fabric_info(
|
||||||
_device: cuda_types::nvml::nvmlDevice_t,
|
_device: cuda_types::nvml::nvmlDevice_t,
|
||||||
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
_gpu_fabric_info: &mut cuda_types::nvml::nvmlGpuFabricInfo_t,
|
||||||
|
|
|
@ -48,6 +48,7 @@ cuda_macros::nvml_function_declarations!(
|
||||||
nvmlDeviceGetFieldValues,
|
nvmlDeviceGetFieldValues,
|
||||||
nvmlDeviceGetGpuFabricInfo,
|
nvmlDeviceGetGpuFabricInfo,
|
||||||
nvmlDeviceGetHandleByIndex_v2,
|
nvmlDeviceGetHandleByIndex_v2,
|
||||||
|
nvmlDeviceGetHandleByPciBusId_v2,
|
||||||
nvmlInit,
|
nvmlInit,
|
||||||
nvmlInitWithFlags,
|
nvmlInitWithFlags,
|
||||||
nvmlInit_v2,
|
nvmlInit_v2,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue