diff --git a/zluda/src/impl/graph.rs b/zluda/src/impl/graph.rs index 9cfc786..59d091a 100644 --- a/zluda/src/impl/graph.rs +++ b/zluda/src/impl/graph.rs @@ -1,3 +1,4 @@ +use cuda_types::cuda::CUgraphExecUpdateResult; use hip_runtime_sys::*; pub(crate) unsafe fn destroy(graph: hipGraph_t) -> hipError_t { @@ -8,6 +9,48 @@ pub(crate) unsafe fn exec_destroy(graph_exec: hipGraphExec_t) -> hipError_t { hipGraphExecDestroy(graph_exec) } +pub(crate) fn exec_update_v2( + h_graph_exec: hipGraphExec_t, + h_graph: hipGraph_t, + result_info: &mut cuda_types::cuda::CUgraphExecUpdateResultInfo, +) -> hipError_t { + let mut h_error_node: hipGraphNode_t = unsafe { std::mem::zeroed() }; + let mut update_result: hipGraphExecUpdateResult = unsafe { std::mem::zeroed() }; + unsafe { hipGraphExecUpdate(h_graph_exec, h_graph, &mut h_error_node, &mut update_result) }?; + + result_info.errorNode = unsafe { std::mem::transmute(h_error_node) }; + result_info.errorFromNode = unsafe { std::mem::transmute(h_error_node) }; + result_info.result = match update_result { + hipGraphExecUpdateResult::hipGraphExecUpdateSuccess => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS + } + hipGraphExecUpdateResult::hipGraphExecUpdateError => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorTopologyChanged => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorNodeTypeChanged => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorFunctionChanged => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorParametersChanged => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorNotSupported => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED + } + hipGraphExecUpdateResult::hipGraphExecUpdateErrorUnsupportedFunctionChange => { + CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE + } + _ => return hipError_t::ErrorNotSupported, + }; + + Ok(()) +} + pub(crate) unsafe fn get_nodes( graph: hipGraph_t, nodes: *mut hipGraphNode_t, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 959691a..01bdb9f 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -111,6 +111,7 @@ cuda_macros::cuda_function_declarations!( cuGetProcAddress_v2, cuGraphDestroy, cuGraphExecDestroy, + cuGraphExecUpdate_v2, cuGraphGetNodes, cuGraphInstantiateWithFlags, cuGraphLaunch, diff --git a/zluda_common/src/lib.rs b/zluda_common/src/lib.rs index 4f8aef7..b55d539 100644 --- a/zluda_common/src/lib.rs +++ b/zluda_common/src/lib.rs @@ -173,7 +173,8 @@ from_cuda_nop!( cublasLtMatmulDescAttributes_t, CUmemAllocationGranularity_flags, CUmemAllocationProp, - CUresult + CUresult, + CUgraphExecUpdateResultInfo ); from_cuda_transmute!( CUuuid => hipUUID,