Implement cuGraphExecUpdate_v2 (#528)
Some checks are pending
ZLUDA / Build (Linux) (push) Waiting to run
ZLUDA / Build (Windows) (push) Waiting to run
ZLUDA / Build AMD GPU unit tests (push) Waiting to run
ZLUDA / Run AMD GPU unit tests (push) Blocked by required conditions

This commit is contained in:
Violet 2025-09-30 13:13:37 -07:00 committed by GitHub
commit f68ea06704
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 55 additions and 1 deletions

View file

@ -1,4 +1,6 @@
use cuda_types::cuda::{CUerror, CUgraphExecUpdateResult, CUresult, CUresultConsts};
use hip_runtime_sys::*;
use zluda_common::FromCuda;
pub(crate) unsafe fn destroy(graph: hipGraph_t) -> hipError_t {
hipGraphDestroy(graph)
@ -8,6 +10,56 @@ pub(crate) unsafe fn exec_destroy(graph_exec: hipGraphExec_t) -> hipError_t {
hipGraphExecDestroy(graph_exec)
}
pub(crate) fn exec_update_v2(
h_graph_exec: hipGraphExec_t,
h_graph: hipGraph_t,
result_info: &mut cuda_types::cuda::CUgraphExecUpdateResultInfo,
) -> CUresult {
let mut h_error_node: hipGraphNode_t = unsafe { std::mem::zeroed() };
let mut update_result: hipGraphExecUpdateResult = hipGraphExecUpdateResult(0);
unsafe { hipGraphExecUpdate(h_graph_exec, h_graph, &mut h_error_node, &mut update_result) }?;
// We use FromCuda here instead of transmute in case our hipGraphNode_t representation changes
// in the future.
let error_node: *mut hipGraphNode_t =
FromCuda::<_, CUerror>::from_cuda(&std::ptr::from_mut(&mut result_info.errorNode))?;
let error_from_node: *mut hipGraphNode_t =
FromCuda::<_, CUerror>::from_cuda(&std::ptr::from_mut(&mut result_info.errorFromNode))?;
unsafe { *error_node = h_error_node };
unsafe { *error_from_node = h_error_node };
result_info.errorFromNode = result_info.errorNode;
result_info.result = match update_result {
hipGraphExecUpdateResult::hipGraphExecUpdateSuccess => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS
}
hipGraphExecUpdateResult::hipGraphExecUpdateError => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorTopologyChanged => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorNodeTypeChanged => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorFunctionChanged => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorParametersChanged => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorNotSupported => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED
}
hipGraphExecUpdateResult::hipGraphExecUpdateErrorUnsupportedFunctionChange => {
CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE
}
_ => return CUresult::ERROR_NOT_SUPPORTED,
};
Ok(())
}
pub(crate) unsafe fn get_nodes(
graph: hipGraph_t,
nodes: *mut hipGraphNode_t,

View file

@ -111,6 +111,7 @@ cuda_macros::cuda_function_declarations!(
cuGetProcAddress_v2,
cuGraphDestroy,
cuGraphExecDestroy,
cuGraphExecUpdate_v2,
cuGraphGetNodes,
cuGraphInstantiateWithFlags,
cuGraphLaunch,

View file

@ -174,7 +174,8 @@ from_cuda_nop!(
CUmemAllocationGranularity_flags,
CUmemAllocationProp,
CUresult,
CUfunction_attribute
CUfunction_attribute,
CUgraphExecUpdateResultInfo
);
from_cuda_transmute!(
CUuuid => hipUUID,