From 9c1485648f566fb9ac1e7cbd1092f7aba32df60e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 18 Apr 2026 23:18:01 -0500 Subject: [PATCH 01/16] Add pinned GPU memory allocation on startup. Saves alloc/dealloc overhead by statically allocating buffers for model IO/params on session start --- .../migraphx/migraphx_execution_provider.cc | 958 ++++++++---------- .../migraphx/migraphx_execution_provider.h | 36 +- 2 files changed, 458 insertions(+), 536 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2b73055c87df7..0a545cf7e9873 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1418,190 +1418,6 @@ static void pad_input_tensor(const void* src_data, void* dst_data, } } -// Allocate padded input buffers and pad the data for dynamic batching -// Returns true if padding was applied, false otherwise -// OPTIMIZATION: Reuses existing buffers if padded batch size matches -static bool allocate_and_pad_inputs( - MIGraphXFuncState* mgx_state, - Ort::KernelContext& ctx, - std::size_t original_batch_size, - std::size_t padded_batch_size, - hipStream_t stream) { - - if (padded_batch_size <= original_batch_size || mgx_state->cached_inputs.empty()) { - return false; // No padding needed - } - - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION: Check if we can reuse existing padded buffers - // ═══════════════════════════════════════════════════════════════════════════ - bool can_reuse_buffers = ( - mgx_state->last_padded_batch_size == padded_batch_size && - !mgx_state->padded_input_buffers.empty() && - mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size() - ); - - if (can_reuse_buffers) { - - // Just copy new data into existing buffers - skip allocation - for (size_t i = 0; i < 
mgx_state->cached_inputs.size(); ++i) { - const auto& cached_inp = mgx_state->cached_inputs[i]; - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (tensor_shape.empty()) continue; - - auto& padded_buf = mgx_state->padded_input_buffers[i]; - - // Calculate elements per batch - std::size_t elements_per_batch = 1; - for (std::size_t j = 1; j < tensor_shape.size(); ++j) { - elements_per_batch *= tensor_shape[j]; - } - - // Calculate element size from tensor type - std::size_t element_size_bytes; - switch (tensor_info.GetElementType()) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - element_size_bytes = sizeof(float); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - element_size_bytes = sizeof(uint16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - element_size_bytes = sizeof(int64_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - element_size_bytes = sizeof(int32_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - element_size_bytes = sizeof(int16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - element_size_bytes = sizeof(int8_t); - break; - default: - element_size_bytes = sizeof(float); // Fallback to float - break; - } - - // Reuse existing buffer - just pad with new data - const void* original_data = input_tensor.GetTensorRawData(); - pad_input_tensor(original_data, padded_buf.data, original_batch_size, padded_batch_size, - element_size_bytes, elements_per_batch, stream); - } - - // Update original batch tracking (padded batch is already correct) - mgx_state->last_original_batch_size = 
original_batch_size; - - return true; - } - - // ═══════════════════════════════════════════════════════════════════════════ - // Normal path: Allocate new buffers (batch size changed or first run) - // ═══════════════════════════════════════════════════════════════════════════ - - // Free old buffers if they exist - for (auto& buf : mgx_state->padded_input_buffers) { - if (buf.data != nullptr) { - HIP_CALL_THROW(hipFree(buf.data)); - buf.data = nullptr; - } - } - mgx_state->padded_input_buffers.clear(); - - // Allocate and pad each input - mgx_state->padded_input_buffers.reserve(mgx_state->cached_inputs.size()); - - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (tensor_shape.empty()) { - continue; - } - - // Calculate padded shape - std::vector padded_lens(tensor_shape.begin(), tensor_shape.end()); - padded_lens[0] = padded_batch_size; // Replace batch dimension - - // Create padded MIGraphX shape - migraphx::shape padded_mgx_shape{cached_inp.mgx_shape.type(), padded_lens}; - std::size_t padded_bytes = padded_mgx_shape.bytes(); - - // Allocate GPU buffer for padded data - void* padded_data = nullptr; - HIP_CALL_THROW(hipMalloc(&padded_data, padded_bytes)); - - // Calculate elements per batch - std::size_t elements_per_batch = 1; - for (std::size_t i = 1; i < tensor_shape.size(); ++i) { - elements_per_batch *= tensor_shape[i]; - } - - // Calculate element size from tensor type - std::size_t element_size_bytes; - switch (tensor_info.GetElementType()) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - element_size_bytes = sizeof(float); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - element_size_bytes = sizeof(uint16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - element_size_bytes = sizeof(int64_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - element_size_bytes = sizeof(int32_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - element_size_bytes = sizeof(int16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - element_size_bytes = sizeof(int8_t); - break; - default: - element_size_bytes = sizeof(float); // Fallback to float - break; - } - - // Pad the data - const void* original_data = input_tensor.GetTensorRawData(); - pad_input_tensor(original_data, padded_data, original_batch_size, padded_batch_size, - element_size_bytes, elements_per_batch, stream); - - // Store padded buffer info - MIGraphXFuncState::PaddedBuffer buf; - buf.data = padded_data; - buf.size_bytes = padded_bytes; - buf.mgx_shape = padded_mgx_shape; - mgx_state->padded_input_buffers.push_back(buf); - - } - - // Update batch tracking - mgx_state->last_original_batch_size = original_batch_size; - mgx_state->last_padded_batch_size = padded_batch_size; - - return true; -} - // Helper: Extract output index from MIGraphX output parameter name // MIGraphX names outputs as "#output_0", "#output_1", etc. static int compute_output_index(const std::string_view sv) { @@ -1614,92 +1430,200 @@ static int compute_output_index(const std::string_view sv) { return ToInteger(Trim(index_str, std::isdigit)); } -// Free temporary output buffers -static void free_temp_output_buffers(MIGraphXFuncState* mgx_state) { - for (auto& buf : mgx_state->temp_output_buffers) { - if (buf.data != nullptr) { - (void)hipFree(buf.data); // Don't throw on cleanup - buf.data = nullptr; - } - } - mgx_state->temp_output_buffers.clear(); - mgx_state->temp_output_padded_batch_size = 0; + +// Allocate pinned I/O buffers at the given max batch size. 
Called once per node +// at session creation (or lazily on first inference for deferred compilation). +// All batch sizes share these buffers — smaller batches use the leading prefix. +static void allocate_pinned_io( + MIGraphXFuncState* mgx_state, + const migraphx::program_parameter_shapes& param_shapes, + const migraphx::shapes& output_shapes, + std::size_t max_batch_size) +{ + auto& pio = mgx_state->pinned_io; + if (pio.allocated) return; + + const auto& map_input_name_index = mgx_state->input_name_indexes; + + pio.inputs.clear(); + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) == map_input_name_index.end()) continue; + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape = migraphx::shape(base_shape.type(), lens); + std::size_t bytes = max_shape.bytes(); + + void* ptr = nullptr; + HIP_CALL_THROW(hipMalloc(&ptr, bytes)); + HIP_CALL_THROW(hipMemset(ptr, 0, bytes)); + pio.inputs.push_back({ptr, bytes, max_shape}); + } + + pio.outputs.clear(); + for (std::size_t i = 0; i < output_shapes.size(); ++i) { + auto lens = output_shapes[i].lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape = migraphx::shape(output_shapes[i].type(), lens); + std::size_t bytes = max_shape.bytes(); + + void* ptr = nullptr; + HIP_CALL_THROW(hipMalloc(&ptr, bytes)); + HIP_CALL_THROW(hipMemset(ptr, 0, bytes)); + pio.outputs.push_back({ptr, bytes, max_shape}); + } + + pio.max_batch_size = max_batch_size; + pio.allocated = true; + + std::size_t total_bytes = 0; + for (const auto& b : pio.inputs) total_bytes += b.size_bytes; + for (const auto& b : pio.outputs) total_bytes += b.size_bytes; + LOGS_DEFAULT(INFO) << "[PinnedIO] Allocated: max_batch=" << max_batch_size + << " inputs=" << pio.inputs.size() + << " outputs=" << pio.outputs.size() + << " total=" << (total_bytes / (1024.0 * 1024.0)) << " MB"; } -// Clear cached MIGraphX shapes (call when 
program changes) -static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { - mgx_state->cached_mgx_param_shapes.reset(); - mgx_state->cached_mgx_output_shapes.reset(); - mgx_state->ultra_fast_caches_populated = false; - mgx_state->cached_program_hash.clear(); +static void free_pinned_io(MIGraphXFuncState* mgx_state) { + auto& pio = mgx_state->pinned_io; + for (auto& buf : pio.inputs) { + if (buf.data) { (void)hipFree(buf.data); buf.data = nullptr; } + } + pio.inputs.clear(); + for (auto& buf : pio.outputs) { + if (buf.data) { (void)hipFree(buf.data); buf.data = nullptr; } + } + pio.outputs.clear(); + pio.allocated = false; } -// Allocate or reuse temporary output buffers for slicing mode -// Returns vector of raw pointers for use with handle_program_input_outputs -static std::vector get_or_allocate_temp_output_buffers( +// Copy ORT input tensors into pinned buffers and pad if needed. +// Returns nothing; element size and bytes-per-batch are derived from param_shapes for each input. +static void copy_inputs_to_pinned( MIGraphXFuncState* mgx_state, const migraphx::program_parameter_shapes& param_shapes, - const migraphx::shapes& output_shapes, - const std::unordered_map& map_input_name_index, - std::size_t padded_batch_size) + Ort::KernelContext& ctx, + std::size_t actual_batch, + std::size_t compiled_batch, + hipStream_t stream) { - // Check if we can reuse existing buffers - bool can_reuse = ( - mgx_state->temp_output_padded_batch_size == padded_batch_size && - !mgx_state->temp_output_buffers.empty() - ); - - if (can_reuse) { - // Return raw pointers from existing buffers - std::vector ptrs; - ptrs.reserve(mgx_state->temp_output_buffers.size()); - for (const auto& buf : mgx_state->temp_output_buffers) { - ptrs.push_back(buf.data); + auto& pio = mgx_state->pinned_io; + const auto& map_input_name_index = mgx_state->input_name_indexes; + std::size_t idx = 0; + + for (const auto& name : param_shapes.names()) { + auto it = map_input_name_index.find(name); + if (it == 
map_input_name_index.end()) continue; + + auto& pin = pio.inputs[idx]; + const auto& input_tensor = ctx.GetInput(it->second); + const void* src = input_tensor.GetTensorRawData(); + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + + std::size_t elements_per_batch = 1; + for (std::size_t d = 1; d < lens.size(); ++d) elements_per_batch *= lens[d]; + + std::size_t total_elems = 1; + for (auto l : lens) total_elems *= l; + std::size_t byte_per_elem = (total_elems > 0) ? base_shape.bytes() / total_elems : 0; + std::size_t bytes_per_batch = elements_per_batch * byte_per_elem; + + if (actual_batch == compiled_batch) { + std::size_t copy_bytes = actual_batch * bytes_per_batch; + if (copy_bytes > 0) { + HIP_CALL_THROW(hipMemcpyAsync(pin.data, src, copy_bytes, hipMemcpyDefault, stream)); + } + } else { + pad_input_tensor(src, pin.data, actual_batch, compiled_batch, + byte_per_elem, elements_per_batch, stream); } - return ptrs; + ++idx; } - - // Free old buffers if they exist - free_temp_output_buffers(mgx_state); - - // Count outputs and allocate - std::vector ptrs; +} + +// Build program_parameters binding pinned buffers at the given compiled shape. 
+static std::pair> +bind_pinned_program_params( + MIGraphXFuncState* mgx_state, + const migraphx::program_parameter_shapes& param_shapes, + const migraphx::shapes& output_shapes) +{ + auto& pio = mgx_state->pinned_io; + const auto& map_input_name_index = mgx_state->input_name_indexes; + + migraphx::program_parameters m; + std::vector prog_output_indices; + + std::size_t input_idx = 0; + std::size_t output_idx = 0; + for (const auto& name : param_shapes.names()) { - // Skip inputs if (map_input_name_index.find(name) != map_input_name_index.end()) { - continue; - } - - // This is an output - const auto output_index = compute_output_index(name); - if (output_index != -1) { - const auto& mgx_shape = param_shapes[name]; - std::size_t size_bytes = mgx_shape.bytes(); - - void* buffer = nullptr; - auto hip_status = hipMalloc(&buffer, size_bytes); - if (hip_status != hipSuccess) { - // Clean up any allocated buffers on failure - for (auto& buf : mgx_state->temp_output_buffers) { - if (buf.data) (void)hipFree(buf.data); - } - mgx_state->temp_output_buffers.clear(); - ORT_THROW("hipMalloc failed for temporary output buffer"); + m.add(name, migraphx::argument(param_shapes[name], pio.inputs[input_idx].data)); + ++input_idx; + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + m.add(name, migraphx::argument(param_shapes[name], pio.outputs[output_idx].data)); + prog_output_indices.push_back(static_cast(oi)); + ++output_idx; } - - MIGraphXFuncState::TempOutputBuffer temp_buf; - temp_buf.data = buffer; - temp_buf.size_bytes = size_bytes; - temp_buf.mgx_shape = mgx_shape; - mgx_state->temp_output_buffers.push_back(temp_buf); - ptrs.push_back(buffer); - } } - - mgx_state->temp_output_padded_batch_size = padded_batch_size; - - return ptrs; + + return {std::move(m), std::move(prog_output_indices)}; +} + +// Copy results from pinned output buffers to ORT output tensors. 
+static void copy_pinned_outputs_to_ort( + MIGraphXFuncState* mgx_state, + const migraphx::shapes& output_shapes, + const std::vector& prog_output_indices, + Ort::KernelContext& ctx, + std::size_t actual_batch, + hipStream_t stream) +{ + auto& pio = mgx_state->pinned_io; + + for (std::size_t i = 0; i < prog_output_indices.size() && i < pio.outputs.size(); ++i) { + const auto oi = prog_output_indices[i]; + const auto& pin = pio.outputs[i]; + const auto& out_shape = output_shapes[oi]; + auto lens = out_shape.lengths(); + + std::vector ort_shape(lens.begin(), lens.end()); + if (!ort_shape.empty()) { + ort_shape[0] = static_cast(actual_batch); + } + + auto output_tensor = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + void* dst = output_tensor.GetTensorMutableRawData(); + + std::size_t total_elems = 1; + for (auto l : lens) total_elems *= l; + std::size_t copy_bytes = 0; + if (total_elems > 0 && !lens.empty()) { + std::size_t byte_per_elem = out_shape.bytes() / total_elems; + std::size_t elems_per_batch = total_elems / std::max(1, lens[0]); + copy_bytes = actual_batch * elems_per_batch * byte_per_elem; + } + + if (copy_bytes > 0) { + HIP_CALL_THROW(hipMemcpyAsync(dst, pin.data, copy_bytes, hipMemcpyDefault, stream)); + } + } +} + + + +// Clear cached MIGraphX shapes (call when program changes) +static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { + mgx_state->cached_mgx_param_shapes.reset(); + mgx_state->cached_mgx_output_shapes.reset(); + mgx_state->ultra_fast_caches_populated = false; + mgx_state->cached_program_hash.clear(); } // Order matters here especially if the program uses mixed quantization @@ -2310,84 +2234,58 @@ static bool execute_ultra_fast_path( if (!mgx_state->caches_valid || mgx_state->last_input_shapes_raw.empty()) { return false; } - - // Ultra-fast path not supported when outputs need dynamic slicing + if (mgx_state->cached_outputs.empty()) { return false; } - // Quick shape comparison bool shapes_match = true; std::size_t 
offset = 0; const auto& last_shapes = mgx_state->last_input_shapes_raw; - + std::size_t original_batch_size = 0; std::size_t padded_batch_size = 0; bool is_first = true; for (const auto& inp : mgx_state->cached_inputs) { const auto& shape = ctx.GetInput(inp.ort_index).GetTensorTypeAndShapeInfo().GetShape(); - + if (offset + shape.size() > last_shapes.size()) { shapes_match = false; break; } - - // For dynamic batch, we check if the current batch needs padding + if (mgx_state->has_dynamic_batch && !mgx_state->compiled_batch_sizes.empty()) { - // Get batch sizes from first input if (is_first) { original_batch_size = static_cast(shape[0]); padded_batch_size = static_cast(last_shapes[offset]); is_first = false; - - // Check if the batch size matches (original or padded) + if (shape[0] != last_shapes[offset]) { - // Batch size changed - check if we can use padding std::size_t required_padded = find_nearest_compiled_batch_size( original_batch_size, mgx_state->compiled_batch_sizes); - if (required_padded != padded_batch_size) { shapes_match = false; break; } } } - - // All current inputs should have the same batch size (original_batch_size) + if (static_cast(shape[0]) != original_batch_size) { - shapes_match = false; - break; + shapes_match = false; break; } - - // Cached shape should have padded batch size in dimension 0 if (last_shapes[offset] != static_cast(padded_batch_size)) { - shapes_match = false; - break; + shapes_match = false; break; } - - // Check non-batch dimensions (current vs cached) - bool rest_matches = true; for (std::size_t i = 1; i < shape.size(); ++i) { - if (last_shapes[offset + i] != shape[i]) { - rest_matches = false; - break; - } - } - if (!rest_matches) { - shapes_match = false; - break; + if (last_shapes[offset + i] != shape[i]) { shapes_match = false; break; } } } else { - // No dynamic batching - strict comparison for (std::size_t i = 0; i < shape.size(); ++i) { - if (last_shapes[offset + i] != shape[i]) { - shapes_match = false; - break; - } 
+ if (last_shapes[offset + i] != shape[i]) { shapes_match = false; break; } } } - + if (!shapes_match) break; offset += shape.size(); } @@ -2396,40 +2294,38 @@ static bool execute_ultra_fast_path( return false; } - // Ultra-fast path doesn't support output slicing because cached_output_ort_shapes - // contains padded shapes, not sliced shapes. Fall back to fast path which handles - // slicing properly via temp output buffers. - if (padded_batch_size > 0 && original_batch_size > 0 && padded_batch_size > original_batch_size) { - return false; + auto& prog = mgx_state->prog; + std::size_t actual_batch = original_batch_size > 0 ? original_batch_size + : (!mgx_state->cached_inputs.empty() + ? static_cast(ctx.GetInput(mgx_state->cached_inputs[0].ort_index) + .GetTensorTypeAndShapeInfo().GetShape()[0]) + : 0); + std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; + bool needs_pinned = (actual_batch < compiled_batch) && mgx_state->pinned_io.allocated; + + if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { + // Pinned I/O path: pad inputs -> run -> slice outputs + const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); + const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); + + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); + + auto& m = mgx_state->cached_prog_params.value(); + run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, + mgx_state->cached_prog_output_indices); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + ctx, actual_batch, rocm_stream); + return true; } - // Shapes unchanged (or compatible with padding) - rebind pointers and run directly + // Direct ORT pointer path: bind ORT tensors directly — zero memcpy overhead auto& m = mgx_state->cached_prog_params.value(); - auto& prog = mgx_state->prog; - - // Allocate and pad inputs if needed for dynamic batching - 
bool using_padded_inputs = false; - if (padded_batch_size > original_batch_size) { - using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); - } - - // Rebind inputs - use padded buffers if available, otherwise use original inputs - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - m.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); - } - } else { - for (const auto& inp : mgx_state->cached_inputs) { - const auto& input_tensor = ctx.GetInput(inp.ort_index); - m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, - const_cast(input_tensor.GetTensorRawData()))); - } + for (const auto& inp : mgx_state->cached_inputs) { + const auto& input_tensor = ctx.GetInput(inp.ort_index); + m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, + const_cast(input_tensor.GetTensorRawData()))); } - - // Rebind outputs - direct iteration, uses pre-allocated shape vectors for (std::size_t i = 0; i < mgx_state->cached_outputs.size(); ++i) { const auto& out = mgx_state->cached_outputs[i]; const auto& ort_shape = mgx_state->cached_output_ort_shapes[i]; @@ -2437,12 +2333,8 @@ static bool execute_ultra_fast_path( m.add(out.name.c_str(), migraphx::argument(out.mgx_shape, output_tensor.GetTensorMutableRawData())); } - - // Run directly - minimal overhead path run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, - mgx_state->cached_prog_output_indices, - original_batch_size, padded_batch_size); - + mgx_state->cached_prog_output_indices); return true; } @@ -2456,125 +2348,95 @@ static bool execute_fast_path( const std::string& current_hash, std::vector& all_input_shapes) { - - if (mgx_state->defer_compilation || 
!mgx_state->cached_programs_ref.has_value()) { + if (!mgx_state->cached_programs_ref.has_value()) { return false; } auto& cached_programs = mgx_state->cached_programs_ref.value().get(); + + if (mgx_state->defer_compilation && cached_programs.empty()) { + return false; + } + auto prog_it = cached_programs.find(current_hash); - - // If not found directly, check if we need to use a padded batch size + std::size_t original_batch_size = 0; std::size_t padded_batch_size = 0; bool needs_padding = false; - - if (prog_it == cached_programs.end() && mgx_state->has_dynamic_batch && + + if (prog_it == cached_programs.end() && mgx_state->has_dynamic_batch && !mgx_state->compiled_batch_sizes.empty()) { - // Try to find a padded batch size const auto& map_input_name_index = mgx_state->input_name_indexes; - + for (const auto& [name, index] : map_input_name_index) { auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); + const auto tensor_shape = input_tensor.GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { original_batch_size = static_cast(tensor_shape[0]); padded_batch_size = find_nearest_compiled_batch_size(original_batch_size, - mgx_state->compiled_batch_sizes); + mgx_state->compiled_batch_sizes); needs_padding = (padded_batch_size > original_batch_size); break; } } - + if (needs_padding && padded_batch_size > 0) { - // Build padded shapes in alphabetical order (map order) for hash calculation - // This matches the order used during compilation in compile_dynamic_batch_models std::vector padded_shapes_for_hash; padded_shapes_for_hash.reserve(all_input_shapes.size()); - for (const auto& [name, index] : map_input_name_index) { - auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - + const auto tensor_shape = 
ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { padded_shapes_for_hash.push_back(static_cast(padded_batch_size)); padded_shapes_for_hash.insert(padded_shapes_for_hash.end(), tensor_shape.begin() + 1, tensor_shape.end()); } } - auto padded_hash = make_hash(padded_shapes_for_hash); prog_it = cached_programs.find(padded_hash); - - if (prog_it != cached_programs.end()) { - - // Now rebuild padded_shapes in cached_inputs order for saving to last_input_shapes_raw - // This ensures ultra-fast path shape comparison works correctly - if (!mgx_state->cached_inputs.empty()) { - std::vector padded_shapes_for_cache; - padded_shapes_for_cache.reserve(mgx_state->cached_inputs.size() * 2); - - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (!tensor_shape.empty()) { - padded_shapes_for_cache.push_back(static_cast(padded_batch_size)); - padded_shapes_for_cache.insert(padded_shapes_for_cache.end(), tensor_shape.begin() + 1, tensor_shape.end()); - } + + if (prog_it != cached_programs.end() && !mgx_state->cached_inputs.empty()) { + std::vector padded_shapes_for_cache; + padded_shapes_for_cache.reserve(mgx_state->cached_inputs.size() * 2); + for (const auto& cached_inp : mgx_state->cached_inputs) { + const auto tensor_shape = ctx.GetInput(cached_inp.ort_index).GetTensorTypeAndShapeInfo().GetShape(); + if (!tensor_shape.empty()) { + padded_shapes_for_cache.push_back(static_cast(padded_batch_size)); + padded_shapes_for_cache.insert(padded_shapes_for_cache.end(), tensor_shape.begin() + 1, tensor_shape.end()); } - all_input_shapes = std::move(padded_shapes_for_cache); - } else { - // Fallback: use map order (shouldn't happen if caches are populated) - all_input_shapes = std::move(padded_shapes_for_hash); } + all_input_shapes = std::move(padded_shapes_for_cache); 
} } } - + if (prog_it == cached_programs.end()) { return false; } - // Determine which hash was used to find the program - // This is needed to detect program changes and invalidate caches std::string effective_program_hash = current_hash; if (needs_padding && padded_batch_size > 0) { - // If we used padded hash, compute it for tracking std::vector padded_shapes_for_hash_tracking; for (const auto& [name, index] : mgx_state->input_name_indexes) { - auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); + const auto tensor_shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { padded_shapes_for_hash_tracking.push_back(static_cast(padded_batch_size)); - padded_shapes_for_hash_tracking.insert(padded_shapes_for_hash_tracking.end(), + padded_shapes_for_hash_tracking.insert(padded_shapes_for_hash_tracking.end(), tensor_shape.begin() + 1, tensor_shape.end()); } } effective_program_hash = make_hash(padded_shapes_for_hash_tracking); } - // Found cached program - use it and populate caches auto& prog = mgx_state->prog; prog = prog_it->second; const auto& map_input_name_index = mgx_state->input_name_indexes; - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 1: Cache MIGraphX API results (avoid redundant API calls) - // Check if program changed - if so, invalidate caches - // ═══════════════════════════════════════════════════════════════════════════ bool program_changed = (mgx_state->cached_program_hash != effective_program_hash); - if (program_changed) { clear_cached_mgx_shapes(mgx_state); - free_temp_output_buffers(mgx_state); mgx_state->cached_program_hash = effective_program_hash; } - + if (!mgx_state->cached_mgx_param_shapes.has_value()) { mgx_state->cached_mgx_param_shapes = prog.get_parameter_shapes(); mgx_state->cached_mgx_output_shapes = prog.get_output_shapes(); @@ -2582,68 
+2444,56 @@ static bool execute_fast_path( const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); - bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); - - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 2: Skip populate_ultra_fast_caches when already populated - // ═══════════════════════════════════════════════════════════════════════════ if (!mgx_state->ultra_fast_caches_populated) { populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, original_batch_size, padded_batch_size); mgx_state->ultra_fast_caches_populated = true; } - // Allocate and pad inputs if needed for dynamic batching - bool using_padded_inputs = false; - if (padded_batch_size > original_batch_size) { - using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); + std::size_t actual_batch = original_batch_size > 0 ? original_batch_size : 0; + if (actual_batch == 0) { + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } + } } + std::size_t compiled_batch = padded_batch_size > 0 ? 
padded_batch_size : actual_batch; + bool needs_pinned = (actual_batch < compiled_batch) && mgx_state->pinned_io.allocated; - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 3: Reuse temp output buffers when slicing - // ═══════════════════════════════════════════════════════════════════════════ - std::vector temp_output_buffer_ptrs; - if (needs_slicing) { - temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers( - mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size); + if (needs_pinned) { + // Pinned I/O path: pad inputs -> run on compiled shape -> slice outputs + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); + auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + + mgx_state->cached_prog_params = std::move(m); + mgx_state->cached_prog_output_indices = std::move(out_indices); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( + mgx_state, ctx, padded_batch_size); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, + mgx_state->cached_prog_params.value(), + mgx_state->cached_prog_output_indices); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + ctx, actual_batch, rocm_stream); + return true; } - // Bind inputs/outputs (use temp buffers for outputs when slicing) + // Direct ORT pointer path: bind ORT tensors directly — zero memcpy overhead auto [m, prog_output_indices] = handle_program_input_outputs( - param_shapes, output_shapes, map_input_name_index, ctx, needs_slicing, - needs_slicing ? 
&temp_output_buffer_ptrs : nullptr); + param_shapes, output_shapes, map_input_name_index, ctx); mgx_state->cached_prog_params = std::move(m); mgx_state->cached_prog_output_indices = std::move(prog_output_indices); - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering - mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( - mgx_state, ctx, using_padded_inputs ? padded_batch_size : 0); - + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - // Rebind padded inputs to program parameters - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - auto& prog_params = mgx_state->cached_prog_params.value(); - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - prog_params.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); - } - } - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, mgx_state->cached_prog_params.value(), - mgx_state->cached_prog_output_indices, - original_batch_size, padded_batch_size); - - // NOTE: Temp output buffers are kept for reuse - they will be freed when batch size changes - // NOTE: Padded input buffers are also kept for reuse - + mgx_state->cached_prog_output_indices); return true; } @@ -2873,10 +2723,27 @@ static void compile_dynamic_batch_models( // Disable dynamic batch compilation for subsequent runs (set max_dynamic_batch to 0) mgx_state->max_dynamic_batch = 0; - // Also disable defer_compilation since we've now compiled mgx_state->defer_compilation = false; LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Set defer_compilation = false"; - + + // Allocate pinned I/O now 
that all batch models are compiled + if (!mgx_state->pinned_io.allocated && mgx_state->cached_programs_ref.has_value()) { + auto& progs = mgx_state->cached_programs_ref.value().get(); + if (!progs.empty()) { + std::size_t max_batch = 0; + if (!mgx_state->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); + } + auto& last_prog = progs.begin()->second; + auto ps = last_prog.get_parameter_shapes(); + auto os = last_prog.get_output_shapes(); + if (max_batch > 0) { + allocate_pinned_io(mgx_state, ps, os, max_batch); + } + } + } + LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; } @@ -2962,87 +2829,48 @@ static void execute_standard_path( auto param_shapes = prog.get_parameter_shapes(); auto output_shapes = prog.get_output_shapes(); + populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, + original_batch_size, padded_batch_size); + if (needs_padding) { - // ============ PADDING PATH: Batch size needs to be padded ============ - - // Populate caches (with slicing info so ultra-fast path is disabled) - populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, - original_batch_size, padded_batch_size); - - // Rebuild padded_shapes in cached_inputs order (MIGraphX parameter order) - // This ensures consistency with ultra-fast path shape comparison - padded_shapes.clear(); - padded_shapes.reserve(mgx_state->cached_inputs.size() * 2); - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (!tensor_shape.empty()) { - padded_shapes.push_back(static_cast(padded_batch_size)); - padded_shapes.insert(padded_shapes.end(), tensor_shape.begin() + 1, tensor_shape.end()); + // Padding required — 
use pinned I/O for pad + slice + if (!mgx_state->pinned_io.allocated) { + std::size_t max_batch = padded_batch_size; + if (!mgx_state->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); } + allocate_pinned_io(mgx_state, param_shapes, output_shapes, max_batch); } - - // Allocate and pad inputs for dynamic batching - bool using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); - - // Get or reuse cached temp output buffers (avoids hipMalloc/hipFree per run) - auto temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers( - mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size); - auto [m, prog_output_indices] = handle_program_input_outputs( - param_shapes, output_shapes, map_input_name_index, ctx, true, &temp_output_buffer_ptrs); - + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, original_batch_size, padded_batch_size, rocm_stream); + auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + mgx_state->cached_prog_params = m; - mgx_state->cached_prog_output_indices = prog_output_indices; - mgx_state->last_input_shapes_raw = std::move(padded_shapes); + mgx_state->cached_prog_output_indices = out_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( + mgx_state, ctx, padded_batch_size); mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; - - // Rebind padded inputs to program parameters - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - m.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); - } - } - - 
run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, - prog_output_indices, original_batch_size, padded_batch_size); - - // Temp output buffers are cached on mgx_state for reuse across runs. - // They are freed when the batch size changes or when the state is destroyed. - - return; + + run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, out_indices); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, out_indices, + ctx, original_batch_size, rocm_stream); } else { - // ============ EXACT MATCH PATH: Batch size matches exactly, no padding needed ============ - - // Populate caches for ultra-fast path (no slicing needed) - populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index); - - // Bind inputs and allocate outputs (no slicing) + // Exact batch match — direct ORT pointer binding auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); - - // Complete cache population + mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); - mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices, - 0, 0); - - return; + + run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } + return; } } } @@ -3068,30 +2896,39 @@ static void execute_standard_path( param_shapes = prog.get_parameter_shapes(); } - // Fetch output shapes once auto output_shapes = prog.get_output_shapes(); - // Populate optimized caches for ultra-fast path populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, 
map_input_name_index); - // Bind inputs and allocate outputs + // Lazily allocate pinned I/O for future pad/slice use (e.g. first inference after JIT compile). + // The buffers are not used on this path since the program was compiled for the exact shape. + if (!mgx_state->pinned_io.allocated) { + std::size_t batch_for_alloc = 0; + if (!mgx_state->compiled_batch_sizes.empty()) { + batch_for_alloc = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); + } + if (batch_for_alloc == 0) { + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { batch_for_alloc = static_cast(shape[0]); break; } + } + } + if (batch_for_alloc > 0) { + allocate_pinned_io(mgx_state, param_shapes, output_shapes, batch_for_alloc); + } + } + + // Direct ORT pointer path: program compiled for exact runtime shape — no padding needed auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); - // Complete cache population mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); - mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - // The program at this point was compiled (or already matched) for the exact - // runtime input shapes, so its outputs already have the original batch - // dimension. Pass 0,0 to avoid incorrect slicing arithmetic. 
run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } @@ -3722,6 +3559,54 @@ static inline void precompile_static_model( LOGS_DEFAULT(INFO) << "[precompile_static_model] ✓ Static model precompiled and cached"; } +// Scan disk cache for .mxr files matching the node prefix and pre-load them +// into the in-memory cache. Eliminates first-inference stalls for deferred +// compilation where .mxr files exist from a previous session. +static void preload_mxr_cache_from_disk( + const std::filesystem::path& model_cache_path, + const std::string& mxr_filename_prefix, + std::unordered_map& cached_programs) +{ + if (model_cache_path.empty() || !std::filesystem::exists(model_cache_path)) return; + + const std::string suffix = ".mxr"; + std::vector> to_load; + + for (const auto& entry : std::filesystem::directory_iterator(model_cache_path)) { + if (!entry.is_regular_file()) continue; + const auto fname = entry.path().filename().string(); + if (fname.size() <= mxr_filename_prefix.size() + suffix.size()) continue; + if (fname.substr(0, mxr_filename_prefix.size()) != mxr_filename_prefix) continue; + if (fname.substr(fname.size() - suffix.size()) != suffix) continue; + + auto hash = fname.substr(mxr_filename_prefix.size(), + fname.size() - mxr_filename_prefix.size() - suffix.size()); + if (cached_programs.find(hash) != cached_programs.end()) continue; + to_load.emplace_back(hash, entry.path()); + } + + if (to_load.empty()) return; + + LOGS_DEFAULT(INFO) << "[preload_mxr_cache] Found " << to_load.size() + << " .mxr file(s) to pre-load for prefix '" << mxr_filename_prefix << "'"; + + std::mutex mu; + std::vector> futs; + for (const auto& [hash, path] : to_load) { + futs.push_back(std::async(std::launch::async, [&, hash, path]() { + migraphx::program prog; + if (load_precompiled_model(prog, path)) { + std::lock_guard lk(mu); + cached_programs[hash] = std::move(prog); + } + })); + } + for (auto& f : futs) f.get(); + + LOGS_DEFAULT(INFO) << 
"[preload_mxr_cache] Pre-loaded " << cached_programs.size() + << " program(s) into in-memory cache"; +} + // Encapsulates precompilation decision logic from Compile() // Returns true if compilation should be deferred to runtime, false if precompilation succeeded static inline bool handle_precompilation_decision( @@ -3994,6 +3879,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& max_dynamic_batch_, compile_batches_); + // Pre-load any .mxr files from disk that aren't already in memory. + preload_mxr_cache_from_disk(model_cache_path_, mxr_filename_prefix, + cached_programs_[fused_node.Name()]); + // Create program object (may be empty if precompiled programs are in cache) migraphx::program prog; map_progs_[fused_node.Name()] = prog; @@ -4036,7 +3925,53 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] Static model mode for node '" << context->node_name << "'"; LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] defer_compilation=" << p->defer_compilation; } - + + // Allocate pinned I/O buffers from the cached programs. + // For dynamic batch: use max compiled batch size. + // For static: use the program's own batch size. 
+ if (p->cached_programs_ref.has_value() && !p->cached_programs_ref.value().get().empty()) { + std::size_t max_batch = 0; + if (!p->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(p->compiled_batch_sizes.begin(), + p->compiled_batch_sizes.end()); + } + // Find the program with the largest batch to get representative shapes + migraphx::program* largest_prog = nullptr; + for (auto& [hash, prog] : p->cached_programs_ref.value().get()) { + if (!largest_prog) { + largest_prog = &prog; + if (max_batch == 0) { + auto ps = prog.get_parameter_shapes(); + for (const auto& name : ps.names()) { + if (p->input_name_indexes.find(name) != p->input_name_indexes.end()) { + auto lens = ps[name].lengths(); + if (!lens.empty() && lens[0] > 0) { + max_batch = lens[0]; + break; + } + } + } + } + } + largest_prog = &prog; + } + if (largest_prog && max_batch > 0) { + auto ps = largest_prog->get_parameter_shapes(); + auto os = largest_prog->get_output_shapes(); + allocate_pinned_io(p.get(), ps, os, max_batch); + } + + // If all batch sizes are pre-loaded, disable deferred compilation + if (p->defer_compilation && p->has_dynamic_batch && !p->compiled_batch_sizes.empty()) { + auto& progs = p->cached_programs_ref.value().get(); + if (progs.size() >= p->compiled_batch_sizes.size()) { + p->defer_compilation = false; + LOGS_DEFAULT(INFO) << "[Compile][CREATE_STATE] All " << p->compiled_batch_sizes.size() + << " batch model(s) pre-loaded — defer_compilation disabled"; + } + } + } + *state = p.release(); return 0; }; @@ -4044,12 +3979,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.release_state_func = [](FunctionState state) { if (state) { auto* s = static_cast(state); - for (auto& buf : s->padded_input_buffers) { - if (buf.data) (void)hipFree(buf.data); - } - for (auto& buf : s->temp_output_buffers) { - if (buf.data) (void)hipFree(buf.data); - } + free_pinned_io(s); delete s; } }; diff --git 
a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 24aeee586263d..f09ccaea07288 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -74,17 +74,22 @@ struct MIGraphXFuncState { bool has_dynamic_batch = false; std::vector compiled_batch_sizes; - // Padded input buffers for dynamic batching (allocated on GPU) - struct PaddedBuffer { - void* data = nullptr; // GPU buffer pointer - std::size_t size_bytes = 0; // Buffer size in bytes - migraphx::shape mgx_shape; // Padded MIGraphX shape + // Pinned I/O buffers: allocated once at max compiled batch, reused across all inferences. + // Eliminates per-inference hipMalloc/hipFree for padding and temp outputs. + struct PinnedIOBuffer { + void* data = nullptr; + std::size_t size_bytes = 0; + migraphx::shape max_shape; // Shape at max_batch_size }; - std::vector padded_input_buffers; // One per input when padding is active - // Track last batch sizes to avoid re-allocation when batch size is unchanged - std::size_t last_original_batch_size = 0; // Original batch size from last run - std::size_t last_padded_batch_size = 0; // Padded batch size from last run + struct PinnedIOSet { + std::vector inputs; + std::vector outputs; + std::size_t max_batch_size = 0; + bool allocated = false; + }; + + PinnedIOSet pinned_io; // ═══════════════════════════════════════════════════════════════════════════ // PERFORMANCE CACHES - Avoid redundant MIGraphX API calls per inference @@ -142,20 +147,7 @@ struct MIGraphXFuncState { // Track which program hash the cached shapes belong to (invalidate when program changes) std::string cached_program_hash; - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION: Reusable temporary output buffers (for slicing mode) - // 
═══════════════════════════════════════════════════════════════════════════ - - // Temporary output buffers for slicing (allocated at padded size) - struct TempOutputBuffer { - void* data = nullptr; // GPU buffer pointer - std::size_t size_bytes = 0; // Buffer size in bytes - migraphx::shape mgx_shape; // Padded MIGraphX shape - }; - std::vector temp_output_buffers; - // Track padded batch size for temp output buffers - std::size_t temp_output_padded_batch_size = 0; }; // Logical device representation. From 0c0e3036ea75d4d366b8e99a3a725e24c491ba5f Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 18 Apr 2026 23:20:33 -0500 Subject: [PATCH 02/16] Cleanup part of code for pinned alloc --- .../providers/migraphx/migraphx_execution_provider.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 0a545cf7e9873..9a96a41693e1d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1461,10 +1462,10 @@ static void allocate_pinned_io( } pio.outputs.clear(); - for (std::size_t i = 0; i < output_shapes.size(); ++i) { - auto lens = output_shapes[i].lengths(); + for (const auto& out_shape : output_shapes) { + auto lens = out_shape.lengths(); if (!lens.empty()) lens[0] = max_batch_size; - auto max_shape = migraphx::shape(output_shapes[i].type(), lens); + auto max_shape = migraphx::shape(out_shape.type(), lens); std::size_t bytes = max_shape.bytes(); void* ptr = nullptr; @@ -1522,8 +1523,8 @@ static void copy_inputs_to_pinned( const auto& base_shape = param_shapes[name]; auto lens = base_shape.lengths(); - std::size_t elements_per_batch = 1; - for (std::size_t d = 1; d < lens.size(); ++d) elements_per_batch *= lens[d]; + 
std::size_t elements_per_batch = std::accumulate( + lens.begin() + 1, lens.end(), std::size_t{1}, std::multiplies<>{}); std::size_t total_elems = 1; for (auto l : lens) total_elems *= l; From 2e7e102ba04faa8df983b3a725e0783c39c07d56 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 21 Apr 2026 15:19:07 -0500 Subject: [PATCH 03/16] Ensure no race condition and contention with pinned_io allocation on session startup or deferred compile --- .../migraphx/migraphx_execution_provider.cc | 50 ++++++++++++------- .../migraphx/migraphx_execution_provider.h | 1 + 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9a96a41693e1d..8e9bbf0a48d56 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -143,6 +143,12 @@ static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) static std::vector parse_compile_batches(const std::string& spec); +// Serializes remaining synchronous hipMalloc calls (e.g. temp output buffers) +// across all MIGraphX EP instances in the process. The primary pinned I/O +// allocation paths use hipMallocAsync/hipFreeAsync which are per-stream safe, +// but a few fallback paths still use synchronous hipMalloc.
+static std::mutex g_hip_alloc_mutex; + MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, device_id_{info.device_id}, @@ -1439,7 +1445,8 @@ static void allocate_pinned_io( MIGraphXFuncState* mgx_state, const migraphx::program_parameter_shapes& param_shapes, const migraphx::shapes& output_shapes, - std::size_t max_batch_size) + std::size_t max_batch_size, + hipStream_t stream) { auto& pio = mgx_state->pinned_io; if (pio.allocated) return; @@ -1456,8 +1463,8 @@ static void allocate_pinned_io( std::size_t bytes = max_shape.bytes(); void* ptr = nullptr; - HIP_CALL_THROW(hipMalloc(&ptr, bytes)); - HIP_CALL_THROW(hipMemset(ptr, 0, bytes)); + HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); pio.inputs.push_back({ptr, bytes, max_shape}); } @@ -1469,11 +1476,13 @@ static void allocate_pinned_io( std::size_t bytes = max_shape.bytes(); void* ptr = nullptr; - HIP_CALL_THROW(hipMalloc(&ptr, bytes)); - HIP_CALL_THROW(hipMemset(ptr, 0, bytes)); + HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); pio.outputs.push_back({ptr, bytes, max_shape}); } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + pio.max_batch_size = max_batch_size; pio.allocated = true; @@ -1486,15 +1495,16 @@ static void allocate_pinned_io( << " total=" << (total_bytes / (1024.0 * 1024.0)) << " MB"; } -static void free_pinned_io(MIGraphXFuncState* mgx_state) { +static void free_pinned_io(MIGraphXFuncState* mgx_state, hipStream_t stream) { auto& pio = mgx_state->pinned_io; for (auto& buf : pio.inputs) { - if (buf.data) { (void)hipFree(buf.data); buf.data = nullptr; } + if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } } - pio.inputs.clear(); for (auto& buf : pio.outputs) { - if 
(buf.data) { (void)hipFree(buf.data); buf.data = nullptr; } + if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + pio.inputs.clear(); pio.outputs.clear(); pio.allocated = false; } @@ -2111,9 +2121,12 @@ std::pair> handle_program temp_buffer = (*temp_output_buffers)[temp_buffer_count]; } else { // Allocate new buffer (first run or buffer list is empty) - auto hip_status = hipMalloc(&temp_buffer, output_size_bytes); - if (hip_status != hipSuccess) { - ORT_THROW("hipMalloc failed for temporary output buffer"); + { + std::lock_guard alloc_lock(g_hip_alloc_mutex); + auto hip_status = hipMalloc(&temp_buffer, output_size_bytes); + if (hip_status != hipSuccess) { + ORT_THROW("hipMalloc failed for temporary output buffer"); + } } temp_output_buffers->push_back(temp_buffer); } @@ -2740,7 +2753,7 @@ static void compile_dynamic_batch_models( auto ps = last_prog.get_parameter_shapes(); auto os = last_prog.get_output_shapes(); if (max_batch > 0) { - allocate_pinned_io(mgx_state, ps, os, max_batch); + allocate_pinned_io(mgx_state, ps, os, max_batch, mgx_state->stream); } } } @@ -2841,7 +2854,7 @@ static void execute_standard_path( max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), mgx_state->compiled_batch_sizes.end()); } - allocate_pinned_io(mgx_state, param_shapes, output_shapes, max_batch); + allocate_pinned_io(mgx_state, param_shapes, output_shapes, max_batch, rocm_stream); } copy_inputs_to_pinned(mgx_state, param_shapes, ctx, original_batch_size, padded_batch_size, rocm_stream); @@ -2916,7 +2929,7 @@ static void execute_standard_path( } } if (batch_for_alloc > 0) { - allocate_pinned_io(mgx_state, param_shapes, output_shapes, batch_for_alloc); + allocate_pinned_io(mgx_state, param_shapes, output_shapes, batch_for_alloc, rocm_stream); } } @@ -3902,7 +3915,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& int8_calibration_cache_available_, dynamic_range_map_, 
model_cache_path_, dump_model_ops_, exhaustive_tune_, max_dynamic_batch_, std::ref(cached_programs_[context->node_name])}; - + p->stream = stream_; + // Initialize dynamic batch support if max_dynamic_batch > 0 if (max_dynamic_batch_ > 0) { p->has_dynamic_batch = true; @@ -3959,7 +3973,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (largest_prog && max_batch > 0) { auto ps = largest_prog->get_parameter_shapes(); auto os = largest_prog->get_output_shapes(); - allocate_pinned_io(p.get(), ps, os, max_batch); + allocate_pinned_io(p.get(), ps, os, max_batch, stream_); } // If all batch sizes are pre-loaded, disable deferred compilation @@ -3980,7 +3994,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.release_state_func = [](FunctionState state) { if (state) { auto* s = static_cast(state); - free_pinned_io(s); + free_pinned_io(s, s->stream); delete s; } }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f09ccaea07288..ce13a80312b64 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -56,6 +56,7 @@ struct MIGraphXFuncState { migraphx::target t{}; std::unordered_map input_name_indexes; std::mutex* mgx_mu_ptr = nullptr; + hipStream_t stream = nullptr; bool defer_compilation = false; bool fp16_enable = false; bool bf16_enable = false; From 2fc5e0fb1a2665753467419bb906c8278c1c83d5 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 11 Apr 2026 23:28:46 -0500 Subject: [PATCH 04/16] Add hip_graph_enable flag to MIGraphX EP provider options Plumbs a new boolean option through the full provider options pipeline: - Provider option key: migraphx_hip_graph_enable - Environment variable: ORT_MIGRAPHX_HIP_GRAPH_ENABLE - MIGraphXExecutionProviderInfo struct field - EP constructor initialization and env override - 
GetProviderOptions / ToProviderOptions round-trip - Hash function for provider info No behavioral changes; flag is wired but not yet consumed. --- .../core/providers/migraphx/migraphx_execution_provider.cc | 7 +++++-- .../core/providers/migraphx/migraphx_execution_provider.h | 3 +++ .../providers/migraphx/migraphx_execution_provider_info.cc | 2 ++ .../providers/migraphx/migraphx_execution_provider_info.h | 5 ++++- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 8e9bbf0a48d56..349aac104a635 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -168,7 +168,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv external_free_{info.external_free}, external_empty_cache_{info.external_empty_cache}, max_dynamic_batch_{info.max_dynamic_batch}, - compile_batches_{info.compile_batches} { + compile_batches_{info.compile_batches}, + hip_graph_enable_{info.hip_graph_enable} { InitProviderOrtApi(); // Set GPU device to be used and read device properties for feature usage. 
@@ -213,6 +214,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); GET_ENV_STRING(migraphx_env_vars::kCompileBatches, compile_batches_); + GET_ENV_BOOL(migraphx_env_vars::kHipGraphEnable, hip_graph_enable_); // If compile_batches is set, auto-derive max_dynamic_batch from the spec's max value if (!compile_batches_.empty()) { @@ -292,7 +294,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_ << "\n " << migraphx_provider_option::kModelMaxDynamicBatch << ": " << max_dynamic_batch_ - << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? "(not set)" : compile_batches_); + << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? 
"(not set)" : compile_batches_) + << "\n " << migraphx_provider_option::kHipGraphEnable << ": " << hip_graph_enable_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index ce13a80312b64..9cbb80fe15e5c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -37,6 +37,7 @@ constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; constexpr auto kModelMaxDynamicBatch = "ORT_MIGRAPHX_MAX_DYNAMIC_BATCH"sv; constexpr auto kCompileBatches = "ORT_MIGRAPHX_COMPILE_BATCHES"sv; +constexpr auto kHipGraphEnable = "ORT_MIGRAPHX_HIP_GRAPH_ENABLE"sv; } // namespace migraphx_env_vars // Tracks which dimensions are symbolic for a given input @@ -206,6 +207,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}, {std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch_)}, {std::string{migraphx_provider_option::kCompileBatches}, compile_batches_}, + {std::string{migraphx_provider_option::kHipGraphEnable}, MakeStringWithClassicLocale(hip_graph_enable_)}, {std::string{migraphx_provider_option::kHasUserComputeStream}, MakeStringWithClassicLocale(external_stream_)}, {std::string{migraphx_provider_option::kUserComputeStream}, MakeStringWithClassicLocale(reinterpret_cast(stream_))}}; } @@ -250,6 +252,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool first_start_ = true; size_t max_dynamic_batch_{0}; std::string compile_batches_{}; // Comma-separated list of batch sizes to compile, e.g. 
"1,4,8,16,32" + bool hip_graph_enable_{false}; }; }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 81169dc59e327..33876c173c41b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -74,6 +74,7 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .AddAssignmentToReference(migraphx_provider_option::kModelMaxDynamicBatch, max_dynamic_batch) .AddAssignmentToReference(migraphx_provider_option::kCompileBatches, compile_batches) + .AddAssignmentToReference(migraphx_provider_option::kHipGraphEnable, hip_graph_enable) .AddAssignmentToReference(migraphx_provider_option::kHasUserComputeStream, has_user_compute_stream) .AddValueParser( migraphx_provider_option::kUserComputeStream, @@ -118,6 +119,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, {std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch)}, {std::string{migraphx_provider_option::kCompileBatches}, compile_batches}, + {std::string{migraphx_provider_option::kHipGraphEnable}, MakeStringWithClassicLocale(hip_graph_enable)}, {std::string{migraphx_provider_option::kHasUserComputeStream}, MakeStringWithClassicLocale(has_user_compute_stream)}, {std::string{migraphx_provider_option::kUserComputeStream}, MakeStringWithClassicLocale(reinterpret_cast(user_compute_stream))}, }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h 
index f814485b901dc..1a7754bd70656 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -36,6 +36,7 @@ constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; constexpr auto kModelMaxDynamicBatch = "migraphx_max_dynamic_batch"sv; constexpr auto kCompileBatches = "migraphx_compile_batches"sv; +constexpr auto kHipGraphEnable = "migraphx_hip_graph_enable"sv; constexpr auto kHasUserComputeStream = "has_user_compute_stream"sv; constexpr auto kUserComputeStream = "user_compute_stream"sv; } // namespace migraphx_provider_option @@ -61,6 +62,7 @@ struct MIGraphXExecutionProviderInfo { OrtArenaCfg* default_memory_arena_cfg{nullptr}; size_t max_dynamic_batch{static_cast(0)}; std::string compile_batches{}; // Comma-separated list of batch sizes to compile, e.g. "1,4,8,16,32" + bool hip_graph_enable{false}; void* external_alloc{nullptr}; void* external_free{nullptr}; @@ -94,7 +96,8 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ (static_cast(info.exhaustive_tune) << 21) ^ - (static_cast(info.bf16_enable) << 22); + (static_cast(info.bf16_enable) << 22) ^ + (static_cast(info.hip_graph_enable) << 23); onnxruntime::HashCombine(data, value); From 5abfe34df2fd3679f0b1a90e448f9565f2a2ca0e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 23 Apr 2026 09:48:36 -0500 Subject: [PATCH 05/16] Filter hipGraph capture based on MIGraphX ENV variables Add requirements to disable graph capture if MIGraphX env variables are set on session creation. 
--- .../migraphx/migraphx_execution_provider.cc | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 349aac104a635..9754edae1fcd7 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -216,6 +216,34 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_STRING(migraphx_env_vars::kCompileBatches, compile_batches_); GET_ENV_BOOL(migraphx_env_vars::kHipGraphEnable, hip_graph_enable_); + // hipGraph requires single-stream MIGraphX execution (MIGRAPHX_NSTREAMS=1). + if (hip_graph_enable_) { + const auto nstreams_env = GetEnvironmentVar("MIGRAPHX_NSTREAMS"); + int nstreams = nstreams_env.empty() ? 1 : std::stoi(nstreams_env); + if (nstreams > 1) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_NSTREAMS=" << nstreams + << " is incompatible with hipGraph capture. Disabling hipGraph."; + hip_graph_enable_ = false; + } + + const auto trace_env = GetEnvironmentVar("MIGRAPHX_TRACE_EVAL"); + if (!trace_env.empty() && std::stoi(trace_env) != 0) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_TRACE_EVAL is enabled, which calls hipStreamSynchronize " + << "per instruction. Disabling hipGraph."; + hip_graph_enable_ = false; + } + + const auto null_stream_env = GetEnvironmentVar("MIGRAPHX_ENABLE_NULL_STREAM"); + if (!null_stream_env.empty() && std::stoi(null_stream_env) != 0) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_ENABLE_NULL_STREAM is enabled (default stream = illegal " + << "during capture). 
Disabling hipGraph."; + hip_graph_enable_ = false; + } + } + // If compile_batches is set, auto-derive max_dynamic_batch from the spec's max value if (!compile_batches_.empty()) { auto explicit_sizes = parse_compile_batches(compile_batches_); From c7ae4276501a1b926b15ba77a2dbe30054f2cd4a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 23 Apr 2026 12:30:25 -0500 Subject: [PATCH 06/16] Add hipGraph Capture/replay to MIGraphXFuncState Ensure we have primitives we need to perform hipGraph capture/replay --- .../migraphx/migraphx_execution_provider.h | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 9cbb80fe15e5c..f939c8c0495be 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -148,8 +148,20 @@ struct MIGraphXFuncState { // Track which program hash the cached shapes belong to (invalidate when program changes) std::string cached_program_hash; - - + + // ═══════════════════════════════════════════════════════════════════════════ + // hipGraph CAPTURE / REPLAY + // ═══════════════════════════════════════════════════════════════════════════ + + struct CapturedHipGraph { + hipGraph_t graph = nullptr; + hipGraphExec_t exec = nullptr; + bool captured = false; + }; + + bool hip_graph_enabled = false; + // shape_hash -> captured graph (one per compiled program variant) + std::unordered_map hip_graph_cache; }; // Logical device representation.
From bbe62780a2465913ad7856d6081c36cd3838aff1 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 24 Apr 2026 14:58:13 -0500 Subject: [PATCH 07/16] Add INFO-level debug logging to MIGraphX EP compile and compute pipeline Track execution path taken (ultra-fast/fast/standard), hipGraph warmup/ capture/replay phases, pinned I/O buffer allocation sizes, dynamic batch compilation progress, and program cache hit/miss across all code paths. Made-with: Cursor --- .../migraphx/migraphx_execution_provider.cc | 658 ++++++++++++------ 1 file changed, 457 insertions(+), 201 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9754edae1fcd7..f0e27bbcdd5b3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1480,11 +1480,18 @@ static void allocate_pinned_io( hipStream_t stream) { auto& pio = mgx_state->pinned_io; - if (pio.allocated) return; + if (pio.allocated) { + LOGS_DEFAULT(INFO) << "[PinnedIO] Already allocated (max_batch=" << pio.max_batch_size + << "), skipping re-allocation for max_batch=" << max_batch_size; + return; + } + + LOGS_DEFAULT(INFO) << "[PinnedIO] Allocating buffers: max_batch=" << max_batch_size; const auto& map_input_name_index = mgx_state->input_name_indexes; pio.inputs.clear(); + std::size_t input_total_bytes = 0; for (const auto& name : param_shapes.names()) { if (map_input_name_index.find(name) == map_input_name_index.end()) continue; const auto& base_shape = param_shapes[name]; @@ -1492,6 +1499,7 @@ static void allocate_pinned_io( if (!lens.empty()) lens[0] = max_batch_size; auto max_shape = migraphx::shape(base_shape.type(), lens); std::size_t bytes = max_shape.bytes(); + input_total_bytes += bytes; void* ptr = nullptr; HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); @@ -1500,11 +1508,13 @@ static void allocate_pinned_io( 
} pio.outputs.clear(); + std::size_t output_total_bytes = 0; for (const auto& out_shape : output_shapes) { auto lens = out_shape.lengths(); if (!lens.empty()) lens[0] = max_batch_size; auto max_shape = migraphx::shape(out_shape.type(), lens); std::size_t bytes = max_shape.bytes(); + output_total_bytes += bytes; void* ptr = nullptr; HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); @@ -1517,12 +1527,12 @@ static void allocate_pinned_io( pio.max_batch_size = max_batch_size; pio.allocated = true; - std::size_t total_bytes = 0; - for (const auto& b : pio.inputs) total_bytes += b.size_bytes; - for (const auto& b : pio.outputs) total_bytes += b.size_bytes; + std::size_t total_bytes = input_total_bytes + output_total_bytes; LOGS_DEFAULT(INFO) << "[PinnedIO] Allocated: max_batch=" << max_batch_size << " inputs=" << pio.inputs.size() + << " (" << (input_total_bytes / (1024.0 * 1024.0)) << " MB)" << " outputs=" << pio.outputs.size() + << " (" << (output_total_bytes / (1024.0 * 1024.0)) << " MB)" << " total=" << (total_bytes / (1024.0 * 1024.0)) << " MB"; } @@ -1659,6 +1669,104 @@ static void copy_pinned_outputs_to_ort( } +// Helper: Run the MIGraphX program and handle outputs +// This function executes the compiled MIGraphX program and copies outputs that +// were not pre-allocated (input parameters reused as outputs) to the ORT output tensors +// If original_batch_size is provided and < padded batch size, slices the output to remove padding +static void run_migraphx_program( + std::mutex* mgx_mu_ptr, + hipStream_t rocm_stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + LOGS_DEFAULT(INFO) << "[RunMIGraphX] run_async START" + << " original_batch=" << original_batch_size + << " padded_batch=" << padded_batch_size; + std::optional prog_outputs; + { + std::lock_guard lock(*mgx_mu_ptr); + prog_outputs = 
prog.run_async(m, rocm_stream); + } + LOGS_DEFAULT(INFO) << "[RunMIGraphX] run_async DONE"; + + bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size); + + auto output_num = prog_outputs->size(); + + // Fast path: no padding/slicing and all outputs were pre-allocated — nothing to do. + if (!needs_slicing && prog_output_indices.size() == output_num) + return; + + std::unordered_set prog_output_indices_set(prog_output_indices.begin(), prog_output_indices.end()); + + if (needs_slicing && !prog_output_indices_set.empty()) { + // Must sync before reallocating any pre-allocated output buffer for slicing. + HIP_CALL_THROW(hipStreamSynchronize(rocm_stream)); + + for (std::size_t i = 0; i < output_num; ++i) { + if (prog_output_indices_set.count(i) == 0) continue; + + auto gpu_res = (*prog_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + if (!ort_shape.empty() && static_cast(ort_shape[0]) != original_batch_size) { + ort_shape[0] = static_cast(original_batch_size); + + std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size; + std::size_t bytes_to_copy = bytes_per_batch * original_batch_size; + + const void* src_data = gpu_res.data(); + auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + if (output_data != src_data) { + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + src_data, + bytes_to_copy, + hipMemcpyDeviceToDevice, + rocm_stream)); + } + } + } + } + + // Copy outputs that were not pre-allocated into ORT output tensors. + // All copies are async on rocm_stream — no sync needed here. 
+ for (std::size_t i = 0; i < output_num; ++i) { + if (prog_output_indices_set.count(i) > 0) continue; + + auto gpu_res = (*prog_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + if (needs_slicing && !ort_shape.empty()) { + ort_shape[0] = original_batch_size; + } + + auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + std::size_t bytes_to_copy = res_shape.bytes(); + if (needs_slicing && !res_lens.empty()) { + bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size; + } + + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + gpu_res.data(), + bytes_to_copy, + hipMemcpyDeviceToDevice, + rocm_stream)); + } +} + // Clear cached MIGraphX shapes (call when program changes) static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { @@ -1668,6 +1776,157 @@ static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { mgx_state->cached_program_hash.clear(); } +// ═══════════════════════════════════════════════════════════════════════════════ +// hipGraph CAPTURE / REPLAY helpers +// ═══════════════════════════════════════════════════════════════════════════════ + +static bool check_hip_graph_compatibility(const migraphx::program& prog, + const std::string& node_name) { + /* std::ostringstream prog_text; + prog.print(prog_text); + const std::string text = prog_text.str(); + + static const std::vector unsafe_ops = { + "hip::sync_stream", + "hip::allocate", + "hip::copy_from_gpu", + "hip::copy_to_gpu", + "gpu::record_event", + "gpu::wait_event", + "gpu::set_stream", + }; + + for (const auto& op : unsafe_ops) { + if (text.find(op) != std::string::npos) { + LOGS_DEFAULT(WARNING) + << "[HipGraph] Node '" << node_name + << "' contains '" << op + << "' which is incompatible with hipGraph capture. 
" + << "Falling back to eager execution for this node."; + return false; + } + } */ + return true; +} + +static void destroy_hip_graphs(MIGraphXFuncState* mgx_state) { + for (auto& [hash, entry] : mgx_state->hip_graph_cache) { + if (entry.exec) { + (void)hipGraphExecDestroy(entry.exec); + entry.exec = nullptr; + } + if (entry.graph) { + (void)hipGraphDestroy(entry.graph); + entry.graph = nullptr; + } + entry.captured = false; + } + mgx_state->hip_graph_cache.clear(); +} + +// Warmup run (ensures lazy GPU allocations are finalized) then capture the graph. +// The warmup output is discarded — pinned output buffers are overwritten on replay. +static bool warmup_and_capture_hip_graph( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::string& shape_hash) +{ + LOGS_DEFAULT(INFO) << "[HipGraph] WARMUP: running eager execution for hash=" << shape_hash.substr(0, 12) << "..."; + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + prog.run_async(m, stream); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + LOGS_DEFAULT(INFO) << "[HipGraph] WARMUP complete, starting CAPTURE"; + + auto& entry = mgx_state->hip_graph_cache[shape_hash]; + + try { + HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + prog.run_async(m, stream); + } + hipError_t err = hipStreamEndCapture(stream, &entry.graph); + if (err != hipSuccess || entry.graph == nullptr) { + LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE FAILED (err=" << err + << "). Falling back to eager execution."; + entry.graph = nullptr; + entry.captured = false; + mgx_state->hip_graph_enabled = false; + return false; + } + + HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); + entry.captured = true; + LOGS_DEFAULT(INFO) << "[HipGraph] CAPTURE SUCCESS: instantiated graph for hash=" + << shape_hash.substr(0, 12) << "..." 
+ << " total_graphs=" << mgx_state->hip_graph_cache.size(); + return true; + } catch (...) { + LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE EXCEPTION. Falling back to eager execution."; + hipGraph_t dummy = nullptr; + (void)hipStreamEndCapture(stream, &dummy); + if (dummy) (void)hipGraphDestroy(dummy); + entry.graph = nullptr; + entry.exec = nullptr; + entry.captured = false; + mgx_state->hip_graph_enabled = false; + return false; + } +} + +static void replay_hip_graph(MIGraphXFuncState* mgx_state, + hipStream_t stream, + const std::string& shape_hash) { + LOGS_DEFAULT(INFO) << "[HipGraph] REPLAY launch for hash=" << shape_hash.substr(0, 12) << "..."; + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); +} + +// Dispatch point: replay a cached hipGraph, capture one on first use, or fall back to eager. +// This replaces run_migraphx_program in all pinned-I/O paths when hipGraph is enabled. +// IMPORTANT: when hipGraph is enabled this function must ONLY be called via the pinned-I/O +// code path so that buffer addresses captured in the graph remain stable across replays. 
+static void run_program_or_hip_graph( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + if (!mgx_state->hip_graph_enabled) { + LOGS_DEFAULT(INFO) << "[Dispatch] EAGER run_async (hipGraph disabled)" + << " hash=" << shape_hash.substr(0, 12) << "..."; + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + return; + } + + auto it = mgx_state->hip_graph_cache.find(shape_hash); + if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + LOGS_DEFAULT(INFO) << "[Dispatch] REPLAY hipGraph" + << " hash=" << shape_hash.substr(0, 12) << "..." + << " cache_size=" << mgx_state->hip_graph_cache.size(); + replay_hip_graph(mgx_state, stream, shape_hash); + } else { + LOGS_DEFAULT(INFO) << "[Dispatch] WARMUP+CAPTURE hipGraph" + << " hash=" << shape_hash.substr(0, 12) << "..." + << " cache_size=" << mgx_state->hip_graph_cache.size(); + if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, shape_hash)) { + LOGS_DEFAULT(WARNING) << "[Dispatch] Capture FAILED — falling back to EAGER"; + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + } + } +} + // Order matters here especially if the program uses mixed quantization // Calibrate on full precision for int8/fp8 and then quantize down to fp16 void calibrate_and_quantize(migraphx::program& prog, @@ -1866,18 +2125,12 @@ static migraphx::program load_or_compile_model( { migraphx::program prog; - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ==== ENTERING ===="; - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Cache file: " << (cache_file.empty() ? 
"(none)" : cache_file.string()); - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Batch size: " << (batch_size > 0 ? std::to_string(batch_size) : "(default)"); + LOGS_DEFAULT(INFO) << "[LoadOrCompile] batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)") + << " cache_file=" << (cache_file.empty() ? "(none)" : cache_file.string()); if (!load_precompiled_model(prog, cache_file)) { - // Cache miss - need to compile - if (batch_size > 0) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✗ CACHE MISS for batch size " << batch_size << " - COMPILING..."; - } else { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✗ CACHE MISS - COMPILING..."; - } - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Compilation started (this may take a while)..."; + LOGS_DEFAULT(INFO) << "[LoadOrCompile] DISK CACHE MISS — COMPILING batch_size=" + << (batch_size > 0 ? std::to_string(batch_size) : "(default)"); prog = CompileProgramWithBatch( onnx_string, @@ -1897,119 +2150,22 @@ static migraphx::program load_or_compile_model( all_input_base_shapes, batch_size); - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Compilation finished"; + LOGS_DEFAULT(INFO) << "[LoadOrCompile] Compilation DONE batch_size=" + << (batch_size > 0 ? 
std::to_string(batch_size) : "(default)"); save_compiled_model(prog, cache_file); if (!cache_file.empty()) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Saved compiled model to disk: " << cache_file.string(); + LOGS_DEFAULT(INFO) << "[LoadOrCompile] Saved to disk: " << cache_file.string(); } } else { - // Cache hit - loaded from disk - if (batch_size > 0) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✓ CACHE HIT - LOADING FROM DISK for batch size " << batch_size; - } else { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✓ CACHE HIT - LOADING FROM DISK"; - } - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Loaded precompiled model from: " << cache_file.string(); + LOGS_DEFAULT(INFO) << "[LoadOrCompile] DISK CACHE HIT — loaded batch_size=" + << (batch_size > 0 ? std::to_string(batch_size) : "(default)") + << " from " << cache_file.string(); } - - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ==== EXITING ===="; return prog; } -// Helper: Run the MIGraphX program and handle outputs -// This function executes the compiled MIGraphX program and copies outputs that -// were not pre-allocated (input parameters reused as outputs) to the ORT output tensors -// If original_batch_size is provided and < padded batch size, slices the output to remove padding -static void run_migraphx_program( - std::mutex* mgx_mu_ptr, - hipStream_t rocm_stream, - Ort::KernelContext& ctx, - migraphx::program& prog, - migraphx::program_parameters& m, - const std::vector& prog_output_indices, - std::size_t original_batch_size = 0, - std::size_t padded_batch_size = 0) -{ - std::optional prog_outputs; - { - std::lock_guard lock(*mgx_mu_ptr); - prog_outputs = prog.run_async(m, rocm_stream); - } - - bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); - auto output_num = prog_outputs->size(); - - // Fast path: no padding/slicing and all outputs were pre-allocated — nothing to do. 
- if (!needs_slicing && prog_output_indices.size() == output_num) - return; - - std::unordered_set prog_output_indices_set(prog_output_indices.begin(), prog_output_indices.end()); - - if (needs_slicing && !prog_output_indices_set.empty()) { - // Must sync before reallocating any pre-allocated output buffer for slicing. - HIP_CALL_THROW(hipStreamSynchronize(rocm_stream)); - - for (std::size_t i = 0; i < output_num; ++i) { - if (prog_output_indices_set.count(i) == 0) continue; - - auto gpu_res = (*prog_outputs)[i]; - migraphx::shape res_shape = gpu_res.get_shape(); - auto res_lens = res_shape.lengths(); - - std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (!ort_shape.empty() && static_cast(ort_shape[0]) != original_batch_size) { - ort_shape[0] = static_cast(original_batch_size); - - std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size; - std::size_t bytes_to_copy = bytes_per_batch * original_batch_size; - - const void* src_data = gpu_res.data(); - auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); - void* output_data = output_tensor.GetTensorMutableRawData(); - - if (output_data != src_data) { - HIP_CALL_THROW(hipMemcpyWithStream(output_data, - src_data, - bytes_to_copy, - hipMemcpyDeviceToDevice, - rocm_stream)); - } - } - } - } - - // Copy outputs that were not pre-allocated into ORT output tensors. - // All copies are async on rocm_stream — no sync needed here. 
- for (std::size_t i = 0; i < output_num; ++i) { - if (prog_output_indices_set.count(i) > 0) continue; - - auto gpu_res = (*prog_outputs)[i]; - migraphx::shape res_shape = gpu_res.get_shape(); - auto res_lens = res_shape.lengths(); - - std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (needs_slicing && !ort_shape.empty()) { - ort_shape[0] = original_batch_size; - } - - auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); - void* output_data = output_tensor.GetTensorMutableRawData(); - - std::size_t bytes_to_copy = res_shape.bytes(); - if (needs_slicing && !res_lens.empty()) { - bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size; - } - - HIP_CALL_THROW(hipMemcpyWithStream(output_data, - gpu_res.data(), - bytes_to_copy, - hipMemcpyDeviceToDevice, - rocm_stream)); - } -} // Helper: Handle input shape mismatch by recompiling the model with new input shapes // This function is called when runtime input shapes differ from compiled shapes @@ -2346,25 +2502,33 @@ static bool execute_ultra_fast_path( .GetTensorTypeAndShapeInfo().GetShape()[0]) : 0); std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; - bool needs_pinned = (actual_batch < compiled_batch) && mgx_state->pinned_io.allocated; + // hipGraph requires stable buffer addresses → always route through pinned I/O + bool needs_pinned = ((actual_batch < compiled_batch) || mgx_state->hip_graph_enabled) + && mgx_state->pinned_io.allocated; if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { - // Pinned I/O path: pad inputs -> run -> slice outputs + LOGS_DEFAULT(INFO) << "[UltraFastPath] batch=" << actual_batch + << " compiled_batch=" << compiled_batch + << " mode=PINNED_IO" + << " hipGraph=" << (mgx_state->hip_graph_enabled ? 
"ON" : "OFF"); const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); auto& m = mgx_state->cached_prog_params.value(); - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, - mgx_state->cached_prog_output_indices); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, + mgx_state->cached_prog_output_indices, + mgx_state->last_input_shape_hash); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, ctx, actual_batch, rocm_stream); return true; } - // Direct ORT pointer path: bind ORT tensors directly — zero memcpy overhead + LOGS_DEFAULT(INFO) << "[UltraFastPath] batch=" << actual_batch + << " compiled_batch=" << compiled_batch + << " mode=DIRECT_ORT_POINTERS"; auto& m = mgx_state->cached_prog_params.value(); for (const auto& inp : mgx_state->cached_inputs) { const auto& input_tensor = ctx.GetInput(inp.ort_index); @@ -2503,10 +2667,17 @@ static bool execute_fast_path( } } std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; - bool needs_pinned = (actual_batch < compiled_batch) && mgx_state->pinned_io.allocated; + // hipGraph requires stable buffer addresses → always route through pinned I/O + bool needs_pinned = ((actual_batch < compiled_batch) || mgx_state->hip_graph_enabled) + && mgx_state->pinned_io.allocated; if (needs_pinned) { - // Pinned I/O path: pad inputs -> run on compiled shape -> slice outputs + LOGS_DEFAULT(INFO) << "[FastPath] batch=" << actual_batch + << " compiled_batch=" << compiled_batch + << " padded=" << (needs_padding ? "YES" : "NO") + << " mode=PINNED_IO" + << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") + << " program_changed=" << (program_changed ? 
"YES" : "NO"); copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); @@ -2517,16 +2688,20 @@ static bool execute_fast_path( mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, - mgx_state->cached_prog_params.value(), - mgx_state->cached_prog_output_indices); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, + mgx_state->cached_prog_params.value(), + mgx_state->cached_prog_output_indices, + effective_program_hash); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, ctx, actual_batch, rocm_stream); return true; } - // Direct ORT pointer path: bind ORT tensors directly — zero memcpy overhead + LOGS_DEFAULT(INFO) << "[FastPath] batch=" << actual_batch + << " compiled_batch=" << compiled_batch + << " mode=DIRECT_ORT_POINTERS" + << " program_changed=" << (program_changed ? 
"YES" : "NO"); auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -2628,22 +2803,23 @@ static void compile_dynamic_batch_models( const std::string& mxr_filename_prefix, const Ort::KernelContext& ctx) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] has_dynamic_batch = " << mgx_state->has_dynamic_batch; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] compiled_batch_sizes.size() = " - << mgx_state->compiled_batch_sizes.size(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] max_dynamic_batch = " << mgx_state->max_dynamic_batch; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] has_dynamic_batch=" << mgx_state->has_dynamic_batch + << " batch_sizes_count=" << mgx_state->compiled_batch_sizes.size() + << " max_dynamic_batch=" << mgx_state->max_dynamic_batch; if (!mgx_state->has_dynamic_batch || mgx_state->compiled_batch_sizes.empty()) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; return; } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Compiling models for " - << mgx_state->compiled_batch_sizes.size() << " batch sizes"; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Batch sizes: "; - for (const auto& bs : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] - " << bs; + { + std::ostringstream bs_list; + for (std::size_t i = 0; i < mgx_state->compiled_batch_sizes.size(); ++i) { + if (i > 0) bs_list << ", "; + bs_list << mgx_state->compiled_batch_sizes[i]; + } + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Batch sizes to compile: [" << bs_list.str() << "]"; } // Get input names and base 
shapes (without batch dimension) @@ -2651,15 +2827,15 @@ static void compile_dynamic_batch_models( std::vector input_names; std::vector> all_input_base_shapes; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; for (const auto& [name, index] : map_input_name_index) { input_names.push_back(name); auto input_tensor = ctx.GetInput(index); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Input '" << name << "' (index " << index - << ") runtime shape: [" << [&]() { + LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Input '" << name << "' (index " << index + << ") runtime shape: [" << [&]() { std::ostringstream ss; for (size_t i = 0; i < tensor_shape.size(); ++i) { if (i > 0) ss << ", "; @@ -2687,9 +2863,8 @@ static void compile_dynamic_batch_models( // Compile a model for each configured batch size for (const auto& batch_size : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ---- Processing batch size: " << batch_size << " ----"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Processing batch_size=" << batch_size; - // Build cache key for this batch size std::vector batch_shape_key; for (size_t i = 0; i < input_names.size(); ++i) { batch_shape_key.push_back(batch_size); @@ -2699,40 +2874,24 @@ static void compile_dynamic_batch_models( } auto cache_hash = make_hash(batch_shape_key); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Shape key for batch " << batch_size << ": [" << [&]() { - std::ostringstream ss; - for (size_t i = 0; i < batch_shape_key.size(); ++i) { - if (i > 0) ss << ", "; - ss << batch_shape_key[i]; - } - return ss.str(); - }() << "]"; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Cache hash: " << cache_hash; 
- - // Check if already cached if (mgx_state->cached_programs_ref.has_value()) { auto& cached_progs = mgx_state->cached_programs_ref.value().get(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Checking in-memory cache (size: " - << cached_progs.size() << ")"; if (cached_progs.find(cache_hash) != cached_progs.end()) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ✓ Batch size " << batch_size - << " already in memory cache, skipping"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + << " CACHE HIT (in-memory), skipping"; continue; } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Cache miss - need to compile/load"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + << " CACHE MISS (in-memory cache_size=" << cached_progs.size() << ")"; } - // Build cache file path std::filesystem::path batch_cache_file; if (!model_cache_path.empty()) { batch_cache_file = model_cache_path / (mxr_filename_prefix + cache_hash + ".mxr"); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Disk cache file: " << batch_cache_file.string(); - } else { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] No disk cache path configured"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Disk cache: " << batch_cache_file.string(); } - // Compile or load the model for this batch size - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Calling load_or_compile_model for batch " << batch_size; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] load_or_compile batch_size=" << batch_size << " hash=" << cache_hash.substr(0, 12) << "..."; migraphx::program batch_prog = load_or_compile_model( batch_cache_file, mgx_state->onnx_string, @@ -2752,24 +2911,19 @@ static void compile_dynamic_batch_models( all_input_base_shapes, batch_size); - // Store in cache if (mgx_state->cached_programs_ref.has_value()) { mgx_state->cached_programs_ref.value().get()[cache_hash] = batch_prog; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ✓ Stored 
program for batch size " << batch_size - << " in memory cache with hash " << cache_hash; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Memory cache now contains " - << mgx_state->cached_programs_ref.value().get().size() << " programs"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + << " STORED in cache (total_programs=" + << mgx_state->cached_programs_ref.value().get().size() << ")"; } } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== All batch models compiled and cached ===="; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Setting max_dynamic_batch to 0 to disable future compilation"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== All batch models compiled/loaded ===="; - // Disable dynamic batch compilation for subsequent runs (set max_dynamic_batch to 0) mgx_state->max_dynamic_batch = 0; - mgx_state->defer_compilation = false; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Set defer_compilation = false"; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] defer_compilation=false, max_dynamic_batch=0"; // Allocate pinned I/O now that all batch models are compiled if (!mgx_state->pinned_io.allocated && mgx_state->cached_programs_ref.has_value()) { @@ -2789,7 +2943,7 @@ static void compile_dynamic_batch_models( } } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; + LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; } // Standard path: Shape checking, potential recompilation, and execution @@ -2812,12 +2966,25 @@ static void execute_standard_path( // NOTE: max_dynamic_batch > 0 means compilation was deferred to runtime (not precompiled) // If precompilation happened during Compile(), max_dynamic_batch will be > 0 but defer_compilation = false // In that case, the programs are already in cache and we can skip runtime compilation + LOGS_DEFAULT(INFO) << "[StandardPath] Entered: defer_compilation=" << 
mgx_state->defer_compilation + << " has_dynamic_batch=" << mgx_state->has_dynamic_batch + << " max_dynamic_batch=" << mgx_state->max_dynamic_batch + << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") + << " pinned_allocated=" << (mgx_state->pinned_io.allocated ? "YES" : "NO"); + if (mgx_state->has_dynamic_batch && mgx_state->max_dynamic_batch > 0 && mgx_state->defer_compilation) { - // Runtime compilation path - used when precompilation was not possible (e.g., non-pure dynamic batch) - - // Compile all batch models at runtime + LOGS_DEFAULT(INFO) << "[StandardPath] Triggering DEFERRED COMPILATION for all batch sizes"; compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx); - + + // Validate newly compiled programs for hipGraph compatibility + if (mgx_state->hip_graph_enabled && mgx_state->cached_programs_ref.has_value()) { + for (const auto& [hash, cached_prog] : mgx_state->cached_programs_ref.value().get()) { + if (!check_hip_graph_compatibility(cached_prog, "runtime_dynamic_batch")) { + mgx_state->hip_graph_enabled = false; + break; + } + } + } } else if (mgx_state->has_dynamic_batch) { } @@ -2870,15 +3037,19 @@ static void execute_standard_path( if (prog_it != cached_progs.end()) { prog = prog_it->second; - // Get shapes for the cached program auto param_shapes = prog.get_parameter_shapes(); auto output_shapes = prog.get_output_shapes(); populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, original_batch_size, padded_batch_size); - if (needs_padding) { - // Padding required — use pinned I/O for pad + slice + bool use_pinned = needs_padding || mgx_state->hip_graph_enabled; + if (use_pinned) { + LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << original_batch_size + << " padded_batch=" << padded_batch_size + << " padded=" << (needs_padding ? "YES" : "NO") + << " mode=PINNED_IO (dynamic batch cache hit)" + << " hipGraph=" << (mgx_state->hip_graph_enabled ? 
"ON" : "OFF"); if (!mgx_state->pinned_io.allocated) { std::size_t max_batch = padded_batch_size; if (!mgx_state->compiled_batch_sizes.empty()) { @@ -2888,7 +3059,8 @@ static void execute_standard_path( allocate_pinned_io(mgx_state, param_shapes, output_shapes, max_batch, rocm_stream); } - copy_inputs_to_pinned(mgx_state, param_shapes, ctx, original_batch_size, padded_batch_size, rocm_stream); + std::size_t copy_actual = needs_padding ? original_batch_size : padded_batch_size; + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, copy_actual, padded_batch_size, rocm_stream); auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); mgx_state->cached_prog_params = m; @@ -2898,12 +3070,15 @@ static void execute_standard_path( mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, out_indices); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, + out_indices, padded_hash); copy_pinned_outputs_to_ort(mgx_state, output_shapes, out_indices, - ctx, original_batch_size, rocm_stream); + ctx, copy_actual, rocm_stream); } else { - // Exact batch match — direct ORT pointer binding + LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << original_batch_size + << " padded_batch=" << padded_batch_size + << " mode=DIRECT_ORT_POINTERS (dynamic batch cache hit, exact match)"; auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -2925,7 +3100,7 @@ static void execute_standard_path( mgx_state->defer_compilation, map_input_name_index, ctx, cmp_options, prog); if (!input_shape_match) { - // Invalidate caches before recompilation + LOGS_DEFAULT(INFO) << "[StandardPath] Input shape MISMATCH — triggering recompilation"; mgx_state->caches_valid = false; handle_input_shape_mismatch( @@ -2937,16 +3112,19 @@ static void execute_standard_path( param_shapes, input_shapes); - // Re-fetch 
param_shapes after recompilation param_shapes = prog.get_parameter_shapes(); + LOGS_DEFAULT(INFO) << "[StandardPath] Recompilation complete"; + + if (mgx_state->hip_graph_enabled && !check_hip_graph_compatibility(prog, "standard_path_recompile")) { + mgx_state->hip_graph_enabled = false; + } } auto output_shapes = prog.get_output_shapes(); populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index); - // Lazily allocate pinned I/O for future pad/slice use (e.g. first inference after JIT compile). - // The buffers are not used on this path since the program was compiled for the exact shape. + // Allocate pinned I/O: required for hipGraph (stable addresses), also useful for future pad/slice. if (!mgx_state->pinned_io.allocated) { std::size_t batch_for_alloc = 0; if (!mgx_state->compiled_batch_sizes.empty()) { @@ -2964,7 +3142,43 @@ static void execute_standard_path( } } - // Direct ORT pointer path: program compiled for exact runtime shape — no padding needed + if (mgx_state->hip_graph_enabled && mgx_state->pinned_io.allocated) { + std::size_t actual_batch = 0; + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } + } + LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << actual_batch + << " mode=PINNED_IO (hipGraph static path)" + << " hipGraph=ON"; + + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, actual_batch, rocm_stream); + auto [m, prog_output_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + + mgx_state->cached_prog_params = m; + mgx_state->cached_prog_output_indices = prog_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, + 
prog_output_indices, current_hash); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, prog_output_indices, + ctx, actual_batch, rocm_stream); + return; + } + + { + std::size_t actual_batch = 0; + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } + } + LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << actual_batch + << " mode=DIRECT_ORT_POINTERS (static fallback)" + << " hipGraph=OFF"; + } auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -3940,13 +4154,28 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& NodeComputeInfo compute_info; compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); - *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], - map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_defer_compilation_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map_, - model_cache_path_, dump_model_ops_, exhaustive_tune_, max_dynamic_batch_, - std::ref(cached_programs_[context->node_name])}; + p->allocate_func = context->allocate_func; + p->release_func = context->release_func; + p->allocate_handle = context->allocator_handle; + p->prog = map_progs_[context->node_name]; + p->onnx_string = map_onnx_string_[context->node_name]; + p->options = options; + p->t = t_; + p->input_name_indexes = map_input_index_[context->node_name]; + p->mgx_mu_ptr = &mgx_mu_; p->stream = stream_; + p->defer_compilation = map_defer_compilation_[context->node_name]; + p->fp16_enable = fp16_enable_; + p->bf16_enable = bf16_enable_; + p->fp8_enable = fp8_enable_; + p->int8_enable = int8_enable_; 
+ p->int8_calibration_cache_available = int8_calibration_cache_available_; + p->dynamic_range_map = dynamic_range_map_; + p->model_cache_dir = model_cache_path_; + p->dump_model_ops = dump_model_ops_; + p->exhaustive_tune = exhaustive_tune_; + p->max_dynamic_batch = max_dynamic_batch_; + p->cached_programs_ref = std::ref(cached_programs_[context->node_name]); // Initialize dynamic batch support if max_dynamic_batch > 0 if (max_dynamic_batch_ > 0) { @@ -4018,6 +4247,22 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } + // hipGraph: set per-node enable flag and validate cached programs + p->hip_graph_enabled = hip_graph_enable_; + if (p->hip_graph_enabled && p->cached_programs_ref.has_value()) { + for (const auto& [hash, cached_prog] : p->cached_programs_ref.value().get()) { + if (!check_hip_graph_compatibility(cached_prog, context->node_name)) { + p->hip_graph_enabled = false; + break; + } + } + if (p->hip_graph_enabled) { + LOGS_DEFAULT(INFO) << "[HipGraph] Enabled for node '" << context->node_name << "'"; + } else { + LOGS_DEFAULT(INFO) << "[HipGraph] Disabled for node '" << context->node_name << "'"; + } + } + *state = p.release(); return 0; }; @@ -4025,6 +4270,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.release_state_func = [](FunctionState state) { if (state) { auto* s = static_cast(state); + destroy_hip_graphs(s); free_pinned_io(s, s->stream); delete s; } @@ -4036,8 +4282,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& const auto& map_input_name_index = mgx_state->input_name_indexes; - // stream_ is always valid: either the user's external stream or an - // EP-owned hipStreamNonBlocking created in the constructor. 
+ // Determine batch size from first input for logging + std::size_t log_batch = 0; + for (const auto& [name, index] : map_input_name_index) { + const auto& shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { log_batch = static_cast(shape[0]); break; } + } // ═══════════════════════════════════════════════════════════════════════ // ULTRA-FAST PATH: Shapes unchanged from last run @@ -4049,6 +4299,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // ═══════════════════════════════════════════════════════════════════════ // Build input shape hash - only computed when shapes change // ═══════════════════════════════════════════════════════════════════════ + LOGS_DEFAULT(INFO) << "[Compute] UltraFast miss — building hash for batch=" << log_batch + << " inputs=" << map_input_name_index.size() + << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF"); std::vector all_input_shapes; all_input_shapes.reserve(map_input_name_index.size() * 4); for (const auto& [name, index] : map_input_name_index) { @@ -4064,6 +4317,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } + LOGS_DEFAULT(INFO) << "[Compute] FastPath miss — entering StandardPath for batch=" << log_batch + << " hash=" << current_hash.substr(0, 12) << "..."; + // ═══════════════════════════════════════════════════════════════════════ // STANDARD PATH: Shape checking and potential recompilation // ═══════════════════════════════════════════════════════════════════════ From 0d1324c912831d845ee35d27f4a8db345d2758cf Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 24 Apr 2026 16:45:32 -0500 Subject: [PATCH 08/16] fixes for hipgraph capture/replay WIP --- .../migraphx/migraphx_execution_provider.cc | 366 +++++++++++++----- .../migraphx/migraphx_execution_provider.h | 11 + 2 files changed, 270 insertions(+), 107 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc 
b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f0e27bbcdd5b3..33b64cf1a5748 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1481,16 +1481,17 @@ static void allocate_pinned_io( { auto& pio = mgx_state->pinned_io; if (pio.allocated) { - LOGS_DEFAULT(INFO) << "[PinnedIO] Already allocated (max_batch=" << pio.max_batch_size + LOGS_DEFAULT(WARNING) << "[PinnedIO] Already allocated (max_batch=" << pio.max_batch_size << "), skipping re-allocation for max_batch=" << max_batch_size; return; } - LOGS_DEFAULT(INFO) << "[PinnedIO] Allocating buffers: max_batch=" << max_batch_size; + LOGS_DEFAULT(WARNING) << "[PinnedIO] Allocating buffers: max_batch=" << max_batch_size; const auto& map_input_name_index = mgx_state->input_name_indexes; pio.inputs.clear(); + pio.input_name_to_idx.clear(); std::size_t input_total_bytes = 0; for (const auto& name : param_shapes.names()) { if (map_input_name_index.find(name) == map_input_name_index.end()) continue; @@ -1501,6 +1502,7 @@ static void allocate_pinned_io( std::size_t bytes = max_shape.bytes(); input_total_bytes += bytes; + pio.input_name_to_idx[name] = pio.inputs.size(); void* ptr = nullptr; HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); @@ -1508,18 +1510,29 @@ static void allocate_pinned_io( } pio.outputs.clear(); + pio.output_name_to_idx.clear(); std::size_t output_total_bytes = 0; - for (const auto& out_shape : output_shapes) { + std::size_t output_alloc_idx = 0; + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) != map_input_name_index.end()) continue; + const auto oi = compute_output_index(name); + if (oi == -1) continue; + if (static_cast(oi) >= output_shapes.size()) continue; + if (output_alloc_idx >= output_shapes.size()) break; + + const auto& out_shape = output_shapes[oi]; auto lens = 
out_shape.lengths(); if (!lens.empty()) lens[0] = max_batch_size; auto max_shape = migraphx::shape(out_shape.type(), lens); std::size_t bytes = max_shape.bytes(); output_total_bytes += bytes; + pio.output_name_to_idx[name] = pio.outputs.size(); void* ptr = nullptr; HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); pio.outputs.push_back({ptr, bytes, max_shape}); + ++output_alloc_idx; } HIP_CALL_THROW(hipStreamSynchronize(stream)); @@ -1528,7 +1541,7 @@ static void allocate_pinned_io( pio.allocated = true; std::size_t total_bytes = input_total_bytes + output_total_bytes; - LOGS_DEFAULT(INFO) << "[PinnedIO] Allocated: max_batch=" << max_batch_size + LOGS_DEFAULT(WARNING) << "[PinnedIO] Allocated: max_batch=" << max_batch_size << " inputs=" << pio.inputs.size() << " (" << (input_total_bytes / (1024.0 * 1024.0)) << " MB)" << " outputs=" << pio.outputs.size() @@ -1551,7 +1564,6 @@ static void free_pinned_io(MIGraphXFuncState* mgx_state, hipStream_t stream) { } // Copy ORT input tensors into pinned buffers and pad if needed. -// Returns the number of bytes-per-batch for each input (used for element-size calc). 
static void copy_inputs_to_pinned( MIGraphXFuncState* mgx_state, const migraphx::program_parameter_shapes& param_shapes, @@ -1562,13 +1574,15 @@ static void copy_inputs_to_pinned( { auto& pio = mgx_state->pinned_io; const auto& map_input_name_index = mgx_state->input_name_indexes; - std::size_t idx = 0; for (const auto& name : param_shapes.names()) { auto it = map_input_name_index.find(name); if (it == map_input_name_index.end()) continue; - auto& pin = pio.inputs[idx]; + auto pin_it = pio.input_name_to_idx.find(name); + if (pin_it == pio.input_name_to_idx.end()) continue; + auto& pin = pio.inputs[pin_it->second]; + const auto& input_tensor = ctx.GetInput(it->second); const void* src = input_tensor.GetTensorRawData(); const auto& base_shape = param_shapes[name]; @@ -1582,8 +1596,10 @@ static void copy_inputs_to_pinned( std::size_t byte_per_elem = (total_elems > 0) ? base_shape.bytes() / total_elems : 0; std::size_t bytes_per_batch = elements_per_batch * byte_per_elem; + std::size_t copy_bytes = actual_batch * bytes_per_batch; + if (copy_bytes > pin.size_bytes) copy_bytes = pin.size_bytes; + if (actual_batch == compiled_batch) { - std::size_t copy_bytes = actual_batch * bytes_per_batch; if (copy_bytes > 0) { HIP_CALL_THROW(hipMemcpyAsync(pin.data, src, copy_bytes, hipMemcpyDefault, stream)); } @@ -1591,12 +1607,20 @@ static void copy_inputs_to_pinned( pad_input_tensor(src, pin.data, actual_batch, compiled_batch, byte_per_elem, elements_per_batch, stream); } - ++idx; } } // Build program_parameters binding pinned buffers at the given compiled shape. -static std::pair> +// Uses name-based lookup into pinned buffers so parameter ordering differences +// between compiled programs don't cause mismatched buffer access. 
+// Returns: {program_parameters, ORT_output_indices, pinned_buffer_indices} +struct PinnedBindResult { + migraphx::program_parameters params; + std::vector prog_output_indices; + std::vector pinned_output_indices; +}; + +static PinnedBindResult bind_pinned_program_params( MIGraphXFuncState* mgx_state, const migraphx::program_parameter_shapes& param_shapes, @@ -1605,27 +1629,26 @@ bind_pinned_program_params( auto& pio = mgx_state->pinned_io; const auto& map_input_name_index = mgx_state->input_name_indexes; - migraphx::program_parameters m; - std::vector prog_output_indices; - - std::size_t input_idx = 0; - std::size_t output_idx = 0; + PinnedBindResult result; for (const auto& name : param_shapes.names()) { if (map_input_name_index.find(name) != map_input_name_index.end()) { - m.add(name, migraphx::argument(param_shapes[name], pio.inputs[input_idx].data)); - ++input_idx; + auto pin_it = pio.input_name_to_idx.find(name); + if (pin_it == pio.input_name_to_idx.end()) continue; + result.params.add(name, migraphx::argument(param_shapes[name], pio.inputs[pin_it->second].data)); } else { const auto oi = compute_output_index(name); if (oi != -1) { - m.add(name, migraphx::argument(param_shapes[name], pio.outputs[output_idx].data)); - prog_output_indices.push_back(static_cast(oi)); - ++output_idx; + auto pin_it = pio.output_name_to_idx.find(name); + if (pin_it == pio.output_name_to_idx.end()) continue; + result.params.add(name, migraphx::argument(param_shapes[name], pio.outputs[pin_it->second].data)); + result.prog_output_indices.push_back(static_cast(oi)); + result.pinned_output_indices.push_back(pin_it->second); } } } - return {std::move(m), std::move(prog_output_indices)}; + return result; } // Copy results from pinned output buffers to ORT output tensors. 
@@ -1633,15 +1656,18 @@ static void copy_pinned_outputs_to_ort( MIGraphXFuncState* mgx_state, const migraphx::shapes& output_shapes, const std::vector& prog_output_indices, + const std::vector& pinned_output_indices, Ort::KernelContext& ctx, std::size_t actual_batch, hipStream_t stream) { auto& pio = mgx_state->pinned_io; - for (std::size_t i = 0; i < prog_output_indices.size() && i < pio.outputs.size(); ++i) { + for (std::size_t i = 0; i < prog_output_indices.size() && i < pinned_output_indices.size(); ++i) { const auto oi = prog_output_indices[i]; - const auto& pin = pio.outputs[i]; + const auto pin_idx = pinned_output_indices[i]; + if (pin_idx >= pio.outputs.size()) continue; + const auto& pin = pio.outputs[pin_idx]; const auto& out_shape = output_shapes[oi]; auto lens = out_shape.lengths(); @@ -1683,7 +1709,7 @@ static void run_migraphx_program( std::size_t original_batch_size = 0, std::size_t padded_batch_size = 0) { - LOGS_DEFAULT(INFO) << "[RunMIGraphX] run_async START" + LOGS_DEFAULT(WARNING) << "[RunMIGraphX] run_async START" << " original_batch=" << original_batch_size << " padded_batch=" << padded_batch_size; std::optional prog_outputs; @@ -1691,7 +1717,7 @@ static void run_migraphx_program( std::lock_guard lock(*mgx_mu_ptr); prog_outputs = prog.run_async(m, rocm_stream); } - LOGS_DEFAULT(INFO) << "[RunMIGraphX] run_async DONE"; + LOGS_DEFAULT(WARNING) << "[RunMIGraphX] run_async DONE"; bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && original_batch_size < padded_batch_size); @@ -1825,21 +1851,32 @@ static void destroy_hip_graphs(MIGraphXFuncState* mgx_state) { } // Warmup run (ensures lazy GPU allocations are finalized) then capture the graph. -// The warmup output is discarded — pinned output buffers are overwritten on replay. +// Stores extra (non-pre-allocated) output metadata so replay can materialize them. 
static bool warmup_and_capture_hip_graph( MIGraphXFuncState* mgx_state, hipStream_t stream, migraphx::program& prog, migraphx::program_parameters& m, + const std::vector& prog_output_indices, const std::string& shape_hash) { - LOGS_DEFAULT(INFO) << "[HipGraph] WARMUP: running eager execution for hash=" << shape_hash.substr(0, 12) << "..."; + // Zero all pinned buffers before warmup to avoid stale data from prior batch runs + auto& pio = mgx_state->pinned_io; + for (auto& pin : pio.inputs) { + HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); + } + for (auto& pin : pio.outputs) { + HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); + } + + LOGS_DEFAULT(WARNING) << "[HipGraph] WARMUP: running eager execution for hash=" << shape_hash.substr(0, 12) << "..."; + std::optional warmup_outputs; { std::lock_guard lock(*mgx_state->mgx_mu_ptr); - prog.run_async(m, stream); + warmup_outputs = prog.run_async(m, stream); } HIP_CALL_THROW(hipStreamSynchronize(stream)); - LOGS_DEFAULT(INFO) << "[HipGraph] WARMUP complete, starting CAPTURE"; + LOGS_DEFAULT(WARNING) << "[HipGraph] WARMUP complete, starting CAPTURE"; auto& entry = mgx_state->hip_graph_cache[shape_hash]; @@ -1861,7 +1898,29 @@ static bool warmup_and_capture_hip_graph( HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); entry.captured = true; - LOGS_DEFAULT(INFO) << "[HipGraph] CAPTURE SUCCESS: instantiated graph for hash=" + + std::unordered_set pre_alloc_set(prog_output_indices.begin(), + prog_output_indices.end()); + entry.extra_outputs.clear(); + if (warmup_outputs) { + auto output_num = warmup_outputs->size(); + for (std::size_t i = 0; i < output_num; ++i) { + if (pre_alloc_set.count(i) > 0) continue; + auto gpu_res = (*warmup_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + entry.extra_outputs.push_back({i, std::move(ort_shape), + 
gpu_res.data(), res_shape.bytes()}); + } + if (!entry.extra_outputs.empty()) { + LOGS_DEFAULT(WARNING) << "[HipGraph] Recorded " << entry.extra_outputs.size() + << " extra (non-pre-allocated) outputs for hash=" + << shape_hash.substr(0, 12) << "..."; + } + } + + LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE SUCCESS: instantiated graph for hash=" << shape_hash.substr(0, 12) << "..." << " total_graphs=" << mgx_state->hip_graph_cache.size(); return true; @@ -1881,11 +1940,43 @@ static bool warmup_and_capture_hip_graph( static void replay_hip_graph(MIGraphXFuncState* mgx_state, hipStream_t stream, const std::string& shape_hash) { - LOGS_DEFAULT(INFO) << "[HipGraph] REPLAY launch for hash=" << shape_hash.substr(0, 12) << "..."; + LOGS_DEFAULT(WARNING) << "[HipGraph] REPLAY launch for hash=" << shape_hash.substr(0, 12) << "..."; auto& entry = mgx_state->hip_graph_cache.at(shape_hash); HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); } +// Materialize extra (non-pre-allocated) outputs recorded during hipGraph capture. +// These are MIGraphX outputs not exposed as named parameters — their GPU data +// pointers are stable across replays because hipGraph replays the same kernels. 
+static void materialize_extra_outputs( + Ort::KernelContext& ctx, + hipStream_t stream, + const std::vector& extras, + std::size_t original_batch_size, + std::size_t padded_batch_size) +{ + bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size); + + for (const auto& extra : extras) { + auto ort_shape = extra.ort_shape; + std::size_t bytes = extra.bytes; + if (needs_slicing && !ort_shape.empty()) { + std::size_t full_batch = static_cast(ort_shape[0]); + if (full_batch > 0) { + bytes = (extra.bytes / full_batch) * original_batch_size; + } + ort_shape[0] = static_cast(original_batch_size); + } + + auto output_tensor = ctx.GetOutput(extra.output_index, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + HIP_CALL_THROW(hipMemcpyWithStream(output_data, extra.gpu_data, + bytes, hipMemcpyDeviceToDevice, stream)); + } +} + // Dispatch point: replay a cached hipGraph, capture one on first use, or fall back to eager. // This replaces run_migraphx_program in all pinned-I/O paths when hipGraph is enabled. 
// IMPORTANT: when hipGraph is enabled this function must ONLY be called via the pinned-I/O @@ -1902,7 +1993,7 @@ static void run_program_or_hip_graph( std::size_t padded_batch_size = 0) { if (!mgx_state->hip_graph_enabled) { - LOGS_DEFAULT(INFO) << "[Dispatch] EAGER run_async (hipGraph disabled)" + LOGS_DEFAULT(WARNING) << "[Dispatch] EAGER run_async (hipGraph disabled)" << " hash=" << shape_hash.substr(0, 12) << "..."; run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, prog_output_indices, original_batch_size, padded_batch_size); @@ -1911,18 +2002,30 @@ static void run_program_or_hip_graph( auto it = mgx_state->hip_graph_cache.find(shape_hash); if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { - LOGS_DEFAULT(INFO) << "[Dispatch] REPLAY hipGraph" + LOGS_DEFAULT(WARNING) << "[Dispatch] REPLAY hipGraph" << " hash=" << shape_hash.substr(0, 12) << "..." << " cache_size=" << mgx_state->hip_graph_cache.size(); replay_hip_graph(mgx_state, stream, shape_hash); + + if (!it->second.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, it->second.extra_outputs, + original_batch_size, padded_batch_size); + } } else { - LOGS_DEFAULT(INFO) << "[Dispatch] WARMUP+CAPTURE hipGraph" + LOGS_DEFAULT(WARNING) << "[Dispatch] WARMUP+CAPTURE hipGraph" << " hash=" << shape_hash.substr(0, 12) << "..." 
<< " cache_size=" << mgx_state->hip_graph_cache.size(); - if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, shape_hash)) { + if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, + prog_output_indices, shape_hash)) { LOGS_DEFAULT(WARNING) << "[Dispatch] Capture FAILED — falling back to EAGER"; run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, prog_output_indices, original_batch_size, padded_batch_size); + } else { + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + if (!entry.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, entry.extra_outputs, + original_batch_size, padded_batch_size); + } } } } @@ -2125,11 +2228,11 @@ static migraphx::program load_or_compile_model( { migraphx::program prog; - LOGS_DEFAULT(INFO) << "[LoadOrCompile] batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)") + LOGS_DEFAULT(WARNING) << "[LoadOrCompile] batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)") << " cache_file=" << (cache_file.empty() ? "(none)" : cache_file.string()); if (!load_precompiled_model(prog, cache_file)) { - LOGS_DEFAULT(INFO) << "[LoadOrCompile] DISK CACHE MISS — COMPILING batch_size=" + LOGS_DEFAULT(WARNING) << "[LoadOrCompile] DISK CACHE MISS — COMPILING batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)"); prog = CompileProgramWithBatch( @@ -2150,15 +2253,15 @@ static migraphx::program load_or_compile_model( all_input_base_shapes, batch_size); - LOGS_DEFAULT(INFO) << "[LoadOrCompile] Compilation DONE batch_size=" + LOGS_DEFAULT(WARNING) << "[LoadOrCompile] Compilation DONE batch_size=" << (batch_size > 0 ? 
std::to_string(batch_size) : "(default)"); save_compiled_model(prog, cache_file); if (!cache_file.empty()) { - LOGS_DEFAULT(INFO) << "[LoadOrCompile] Saved to disk: " << cache_file.string(); + LOGS_DEFAULT(WARNING) << "[LoadOrCompile] Saved to disk: " << cache_file.string(); } } else { - LOGS_DEFAULT(INFO) << "[LoadOrCompile] DISK CACHE HIT — loaded batch_size=" + LOGS_DEFAULT(WARNING) << "[LoadOrCompile] DISK CACHE HIT — loaded batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)") << " from " << cache_file.string(); } @@ -2507,7 +2610,7 @@ static bool execute_ultra_fast_path( && mgx_state->pinned_io.allocated; if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { - LOGS_DEFAULT(INFO) << "[UltraFastPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[UltraFastPath] batch=" << actual_batch << " compiled_batch=" << compiled_batch << " mode=PINNED_IO" << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF"); @@ -2522,11 +2625,12 @@ static bool execute_ultra_fast_path( mgx_state->last_input_shape_hash); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + mgx_state->cached_pinned_output_indices, ctx, actual_batch, rocm_stream); return true; } - LOGS_DEFAULT(INFO) << "[UltraFastPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[UltraFastPath] batch=" << actual_batch << " compiled_batch=" << compiled_batch << " mode=DIRECT_ORT_POINTERS"; auto& m = mgx_state->cached_prog_params.value(); @@ -2672,17 +2776,18 @@ static bool execute_fast_path( && mgx_state->pinned_io.allocated; if (needs_pinned) { - LOGS_DEFAULT(INFO) << "[FastPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[FastPath] batch=" << actual_batch << " compiled_batch=" << compiled_batch << " padded=" << (needs_padding ? "YES" : "NO") << " mode=PINNED_IO" << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") << " program_changed=" << (program_changed ? 
"YES" : "NO"); copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); - auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); - mgx_state->cached_prog_params = std::move(m); - mgx_state->cached_prog_output_indices = std::move(out_indices); + mgx_state->cached_prog_params = std::move(bind_result.params); + mgx_state->cached_prog_output_indices = std::move(bind_result.prog_output_indices); + mgx_state->cached_pinned_output_indices = std::move(bind_result.pinned_output_indices); mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( mgx_state, ctx, padded_batch_size); mgx_state->last_input_shape_hash = current_hash; @@ -2694,11 +2799,12 @@ static bool execute_fast_path( effective_program_hash); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + mgx_state->cached_pinned_output_indices, ctx, actual_batch, rocm_stream); return true; } - LOGS_DEFAULT(INFO) << "[FastPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[FastPath] batch=" << actual_batch << " compiled_batch=" << compiled_batch << " mode=DIRECT_ORT_POINTERS" << " program_changed=" << (program_changed ? 
"YES" : "NO"); @@ -2803,13 +2909,13 @@ static void compile_dynamic_batch_models( const std::string& mxr_filename_prefix, const Ort::KernelContext& ctx) { - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] has_dynamic_batch=" << mgx_state->has_dynamic_batch + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] has_dynamic_batch=" << mgx_state->has_dynamic_batch << " batch_sizes_count=" << mgx_state->compiled_batch_sizes.size() << " max_dynamic_batch=" << mgx_state->max_dynamic_batch; if (!mgx_state->has_dynamic_batch || mgx_state->compiled_batch_sizes.empty()) { - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; return; } @@ -2819,7 +2925,7 @@ static void compile_dynamic_batch_models( if (i > 0) bs_list << ", "; bs_list << mgx_state->compiled_batch_sizes[i]; } - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Batch sizes to compile: [" << bs_list.str() << "]"; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Batch sizes to compile: [" << bs_list.str() << "]"; } // Get input names and base shapes (without batch dimension) @@ -2827,7 +2933,7 @@ static void compile_dynamic_batch_models( std::vector input_names; std::vector> all_input_base_shapes; - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; for (const auto& [name, index] : map_input_name_index) { input_names.push_back(name); auto input_tensor = ctx.GetInput(index); @@ -2863,7 +2969,7 @@ static void compile_dynamic_batch_models( // Compile a model for each configured batch size for 
(const auto& batch_size : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Processing batch_size=" << batch_size; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Processing batch_size=" << batch_size; std::vector batch_shape_key; for (size_t i = 0; i < input_names.size(); ++i) { @@ -2877,21 +2983,21 @@ static void compile_dynamic_batch_models( if (mgx_state->cached_programs_ref.has_value()) { auto& cached_progs = mgx_state->cached_programs_ref.value().get(); if (cached_progs.find(cache_hash) != cached_progs.end()) { - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size << " CACHE HIT (in-memory), skipping"; continue; } - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size << " CACHE MISS (in-memory cache_size=" << cached_progs.size() << ")"; } std::filesystem::path batch_cache_file; if (!model_cache_path.empty()) { batch_cache_file = model_cache_path / (mxr_filename_prefix + cache_hash + ".mxr"); - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] Disk cache: " << batch_cache_file.string(); + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Disk cache: " << batch_cache_file.string(); } - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] load_or_compile batch_size=" << batch_size << " hash=" << cache_hash.substr(0, 12) << "..."; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] load_or_compile batch_size=" << batch_size << " hash=" << cache_hash.substr(0, 12) << "..."; migraphx::program batch_prog = load_or_compile_model( batch_cache_file, mgx_state->onnx_string, @@ -2913,19 +3019,21 @@ static void compile_dynamic_batch_models( if (mgx_state->cached_programs_ref.has_value()) { mgx_state->cached_programs_ref.value().get()[cache_hash] = batch_prog; - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] batch_size=" << batch_size + 
LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size << " STORED in cache (total_programs=" << mgx_state->cached_programs_ref.value().get().size() << ")"; } } - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== All batch models compiled/loaded ===="; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== All batch models compiled/loaded ===="; mgx_state->max_dynamic_batch = 0; mgx_state->defer_compilation = false; - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] defer_compilation=false, max_dynamic_batch=0"; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] defer_compilation=false, max_dynamic_batch=0"; - // Allocate pinned I/O now that all batch models are compiled + // Allocate pinned I/O now that all batch models are compiled. + // Must use the largest-batch program's shapes so the buffer count and + // parameter ordering match every subsequent bind_pinned_program_params call. if (!mgx_state->pinned_io.allocated && mgx_state->cached_programs_ref.has_value()) { auto& progs = mgx_state->cached_programs_ref.value().get(); if (!progs.empty()) { @@ -2934,16 +3042,35 @@ static void compile_dynamic_batch_models( max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), mgx_state->compiled_batch_sizes.end()); } - auto& last_prog = progs.begin()->second; - auto ps = last_prog.get_parameter_shapes(); - auto os = last_prog.get_output_shapes(); - if (max_batch > 0) { + migraphx::program* largest_prog = nullptr; + std::size_t largest_batch_found = 0; + for (auto& [hash, prog] : progs) { + auto ps = prog.get_parameter_shapes(); + std::size_t prog_batch = 0; + for (const auto& name : ps.names()) { + if (mgx_state->input_name_indexes.find(name) != mgx_state->input_name_indexes.end()) { + auto lens = ps[name].lengths(); + if (!lens.empty() && lens[0] > 0) { + prog_batch = lens[0]; + break; + } + } + } + if (prog_batch > largest_batch_found) { + largest_batch_found = prog_batch; + largest_prog = &prog; + } + } + if (max_batch == 0) 
max_batch = largest_batch_found; + if (largest_prog && max_batch > 0) { + auto ps = largest_prog->get_parameter_shapes(); + auto os = largest_prog->get_output_shapes(); allocate_pinned_io(mgx_state, ps, os, max_batch, mgx_state->stream); } } } - LOGS_DEFAULT(INFO) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; + LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; } // Standard path: Shape checking, potential recompilation, and execution @@ -2966,14 +3093,14 @@ static void execute_standard_path( // NOTE: max_dynamic_batch > 0 means compilation was deferred to runtime (not precompiled) // If precompilation happened during Compile(), max_dynamic_batch will be > 0 but defer_compilation = false // In that case, the programs are already in cache and we can skip runtime compilation - LOGS_DEFAULT(INFO) << "[StandardPath] Entered: defer_compilation=" << mgx_state->defer_compilation + LOGS_DEFAULT(WARNING) << "[StandardPath] Entered: defer_compilation=" << mgx_state->defer_compilation << " has_dynamic_batch=" << mgx_state->has_dynamic_batch << " max_dynamic_batch=" << mgx_state->max_dynamic_batch << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") << " pinned_allocated=" << (mgx_state->pinned_io.allocated ? 
"YES" : "NO"); if (mgx_state->has_dynamic_batch && mgx_state->max_dynamic_batch > 0 && mgx_state->defer_compilation) { - LOGS_DEFAULT(INFO) << "[StandardPath] Triggering DEFERRED COMPILATION for all batch sizes"; + LOGS_DEFAULT(WARNING) << "[StandardPath] Triggering DEFERRED COMPILATION for all batch sizes"; compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx); // Validate newly compiled programs for hipGraph compatibility @@ -3045,7 +3172,7 @@ static void execute_standard_path( bool use_pinned = needs_padding || mgx_state->hip_graph_enabled; if (use_pinned) { - LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << original_batch_size + LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << original_batch_size << " padded_batch=" << padded_batch_size << " padded=" << (needs_padding ? "YES" : "NO") << " mode=PINNED_IO (dynamic batch cache hit)" @@ -3056,27 +3183,49 @@ static void execute_standard_path( max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), mgx_state->compiled_batch_sizes.end()); } - allocate_pinned_io(mgx_state, param_shapes, output_shapes, max_batch, rocm_stream); + auto alloc_ps = param_shapes; + auto alloc_os = output_shapes; + if (max_batch > padded_batch_size && mgx_state->cached_programs_ref.has_value()) { + bool found = false; + for (auto& [h, p] : mgx_state->cached_programs_ref.value().get()) { + if (found) break; + auto candidate_ps = p.get_parameter_shapes(); + for (const auto& nm : candidate_ps.names()) { + if (mgx_state->input_name_indexes.count(nm)) { + auto lens = candidate_ps[nm].lengths(); + if (!lens.empty() && lens[0] == max_batch) { + alloc_ps = candidate_ps; + alloc_os = p.get_output_shapes(); + found = true; + } + break; + } + } + } + } + allocate_pinned_io(mgx_state, alloc_ps, alloc_os, max_batch, rocm_stream); } std::size_t copy_actual = needs_padding ? 
original_batch_size : padded_batch_size; copy_inputs_to_pinned(mgx_state, param_shapes, ctx, copy_actual, padded_batch_size, rocm_stream); - auto [m, out_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); - mgx_state->cached_prog_params = m; - mgx_state->cached_prog_output_indices = out_indices; + mgx_state->cached_prog_params = bind_result.params; + mgx_state->cached_prog_output_indices = bind_result.prog_output_indices; + mgx_state->cached_pinned_output_indices = bind_result.pinned_output_indices; mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( mgx_state, ctx, padded_batch_size); mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; - run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, - out_indices, padded_hash); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, + bind_result.prog_output_indices, padded_hash); - copy_pinned_outputs_to_ort(mgx_state, output_shapes, out_indices, + copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, + bind_result.pinned_output_indices, ctx, copy_actual, rocm_stream); } else { - LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << original_batch_size + LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << original_batch_size << " padded_batch=" << padded_batch_size << " mode=DIRECT_ORT_POINTERS (dynamic batch cache hit, exact match)"; auto [m, prog_output_indices] = handle_program_input_outputs( @@ -3100,7 +3249,7 @@ static void execute_standard_path( mgx_state->defer_compilation, map_input_name_index, ctx, cmp_options, prog); if (!input_shape_match) { - LOGS_DEFAULT(INFO) << "[StandardPath] Input shape MISMATCH — triggering recompilation"; + LOGS_DEFAULT(WARNING) << "[StandardPath] Input shape MISMATCH — triggering recompilation"; mgx_state->caches_valid = false; handle_input_shape_mismatch( @@ 
-3113,7 +3262,7 @@ static void execute_standard_path( input_shapes); param_shapes = prog.get_parameter_shapes(); - LOGS_DEFAULT(INFO) << "[StandardPath] Recompilation complete"; + LOGS_DEFAULT(WARNING) << "[StandardPath] Recompilation complete"; if (mgx_state->hip_graph_enabled && !check_hip_graph_compatibility(prog, "standard_path_recompile")) { mgx_state->hip_graph_enabled = false; @@ -3148,23 +3297,25 @@ static void execute_standard_path( auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } } - LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << actual_batch << " mode=PINNED_IO (hipGraph static path)" << " hipGraph=ON"; copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, actual_batch, rocm_stream); - auto [m, prog_output_indices] = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); - mgx_state->cached_prog_params = m; - mgx_state->cached_prog_output_indices = prog_output_indices; + mgx_state->cached_prog_params = bind_result.params; + mgx_state->cached_prog_output_indices = bind_result.prog_output_indices; + mgx_state->cached_pinned_output_indices = bind_result.pinned_output_indices; mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, - prog_output_indices, current_hash); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, + bind_result.prog_output_indices, current_hash); - copy_pinned_outputs_to_ort(mgx_state, output_shapes, prog_output_indices, + copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, + bind_result.pinned_output_indices, ctx, 
actual_batch, rocm_stream); return; } @@ -3175,7 +3326,7 @@ static void execute_standard_path( auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } } - LOGS_DEFAULT(INFO) << "[StandardPath] batch=" << actual_batch + LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << actual_batch << " mode=DIRECT_ORT_POINTERS (static fallback)" << " hipGraph=OFF"; } @@ -4202,34 +4353,35 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } // Allocate pinned I/O buffers from the cached programs. - // For dynamic batch: use max compiled batch size. - // For static: use the program's own batch size. + // Uses the program compiled for the largest batch size so that + // allocate_pinned_io sees parameter shapes whose batch dim matches + // max_batch. All smaller batches share the same buffers. if (p->cached_programs_ref.has_value() && !p->cached_programs_ref.value().get().empty()) { std::size_t max_batch = 0; if (!p->compiled_batch_sizes.empty()) { max_batch = *std::max_element(p->compiled_batch_sizes.begin(), p->compiled_batch_sizes.end()); } - // Find the program with the largest batch to get representative shapes migraphx::program* largest_prog = nullptr; + std::size_t largest_batch_found = 0; for (auto& [hash, prog] : p->cached_programs_ref.value().get()) { - if (!largest_prog) { - largest_prog = &prog; - if (max_batch == 0) { - auto ps = prog.get_parameter_shapes(); - for (const auto& name : ps.names()) { - if (p->input_name_indexes.find(name) != p->input_name_indexes.end()) { - auto lens = ps[name].lengths(); - if (!lens.empty() && lens[0] > 0) { - max_batch = lens[0]; - break; - } - } + auto ps = prog.get_parameter_shapes(); + std::size_t prog_batch = 0; + for (const auto& name : ps.names()) { + if (p->input_name_indexes.find(name) != p->input_name_indexes.end()) { + auto lens = ps[name].lengths(); + if (!lens.empty() && lens[0] > 0) { + prog_batch = lens[0]; + break; } } } - 
largest_prog = &prog; + if (prog_batch > largest_batch_found) { + largest_batch_found = prog_batch; + largest_prog = &prog; + } } + if (max_batch == 0) max_batch = largest_batch_found; if (largest_prog && max_batch > 0) { auto ps = largest_prog->get_parameter_shapes(); auto os = largest_prog->get_output_shapes(); @@ -4257,9 +4409,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } if (p->hip_graph_enabled) { - LOGS_DEFAULT(INFO) << "[HipGraph] Enabled for node '" << context->node_name << "'"; + LOGS_DEFAULT(WARNING) << "[HipGraph] Enabled for node '" << context->node_name << "'"; } else { - LOGS_DEFAULT(INFO) << "[HipGraph] Disabled for node '" << context->node_name << "'"; + LOGS_DEFAULT(WARNING) << "[HipGraph] Disabled for node '" << context->node_name << "'"; } } @@ -4299,7 +4451,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // ═══════════════════════════════════════════════════════════════════════ // Build input shape hash - only computed when shapes change // ═══════════════════════════════════════════════════════════════════════ - LOGS_DEFAULT(INFO) << "[Compute] UltraFast miss — building hash for batch=" << log_batch + LOGS_DEFAULT(WARNING) << "[Compute] UltraFast miss — building hash for batch=" << log_batch << " inputs=" << map_input_name_index.size() << " hipGraph=" << (mgx_state->hip_graph_enabled ? 
"ON" : "OFF"); std::vector all_input_shapes; @@ -4317,7 +4469,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } - LOGS_DEFAULT(INFO) << "[Compute] FastPath miss — entering StandardPath for batch=" << log_batch + LOGS_DEFAULT(WARNING) << "[Compute] FastPath miss — entering StandardPath for batch=" << log_batch << " hash=" << current_hash.substr(0, 12) << "..."; // ═══════════════════════════════════════════════════════════════════════ diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f939c8c0495be..891db01eda263 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -87,6 +87,8 @@ struct MIGraphXFuncState { struct PinnedIOSet { std::vector inputs; std::vector outputs; + std::unordered_map input_name_to_idx; + std::unordered_map output_name_to_idx; std::size_t max_batch_size = 0; bool allocated = false; }; @@ -123,6 +125,7 @@ struct MIGraphXFuncState { // Cached output indices for pre-allocated outputs (used by run_migraphx_program) std::vector cached_prog_output_indices; + std::vector cached_pinned_output_indices; // Last input shapes for quick comparison (avoids hash computation in ultra-fast path) std::vector last_input_shapes_raw; @@ -153,10 +156,18 @@ struct MIGraphXFuncState { // hipGraph CAPTURE / REPLAY // ═══════════════════════════════════════════════════════════════════════════ + struct ExtraOutputInfo { + std::size_t output_index; + std::vector ort_shape; + void* gpu_data; + std::size_t bytes; + }; + struct CapturedHipGraph { hipGraph_t graph = nullptr; hipGraphExec_t exec = nullptr; bool captured = false; + std::vector extra_outputs; }; bool hip_graph_enabled = false; From 9138f1b532b6346d225cd061f7de872f69a3f2cf Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sun, 26 Apr 2026 20:02:36 -0500 Subject: [PATCH 
09/16] remove debug logging --- .../migraphx/migraphx_execution_provider.cc | 203 ------------------ 1 file changed, 203 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 33b64cf1a5748..eca580c49efc0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1481,18 +1481,13 @@ static void allocate_pinned_io( { auto& pio = mgx_state->pinned_io; if (pio.allocated) { - LOGS_DEFAULT(WARNING) << "[PinnedIO] Already allocated (max_batch=" << pio.max_batch_size - << "), skipping re-allocation for max_batch=" << max_batch_size; return; } - LOGS_DEFAULT(WARNING) << "[PinnedIO] Allocating buffers: max_batch=" << max_batch_size; - const auto& map_input_name_index = mgx_state->input_name_indexes; pio.inputs.clear(); pio.input_name_to_idx.clear(); - std::size_t input_total_bytes = 0; for (const auto& name : param_shapes.names()) { if (map_input_name_index.find(name) == map_input_name_index.end()) continue; const auto& base_shape = param_shapes[name]; @@ -1500,7 +1495,6 @@ static void allocate_pinned_io( if (!lens.empty()) lens[0] = max_batch_size; auto max_shape = migraphx::shape(base_shape.type(), lens); std::size_t bytes = max_shape.bytes(); - input_total_bytes += bytes; pio.input_name_to_idx[name] = pio.inputs.size(); void* ptr = nullptr; @@ -1511,7 +1505,6 @@ static void allocate_pinned_io( pio.outputs.clear(); pio.output_name_to_idx.clear(); - std::size_t output_total_bytes = 0; std::size_t output_alloc_idx = 0; for (const auto& name : param_shapes.names()) { if (map_input_name_index.find(name) != map_input_name_index.end()) continue; @@ -1525,7 +1518,6 @@ static void allocate_pinned_io( if (!lens.empty()) lens[0] = max_batch_size; auto max_shape = migraphx::shape(out_shape.type(), lens); std::size_t bytes = max_shape.bytes(); - output_total_bytes += bytes; 
pio.output_name_to_idx[name] = pio.outputs.size(); void* ptr = nullptr; @@ -1539,14 +1531,6 @@ static void allocate_pinned_io( pio.max_batch_size = max_batch_size; pio.allocated = true; - - std::size_t total_bytes = input_total_bytes + output_total_bytes; - LOGS_DEFAULT(WARNING) << "[PinnedIO] Allocated: max_batch=" << max_batch_size - << " inputs=" << pio.inputs.size() - << " (" << (input_total_bytes / (1024.0 * 1024.0)) << " MB)" - << " outputs=" << pio.outputs.size() - << " (" << (output_total_bytes / (1024.0 * 1024.0)) << " MB)" - << " total=" << (total_bytes / (1024.0 * 1024.0)) << " MB"; } static void free_pinned_io(MIGraphXFuncState* mgx_state, hipStream_t stream) { @@ -1709,15 +1693,11 @@ static void run_migraphx_program( std::size_t original_batch_size = 0, std::size_t padded_batch_size = 0) { - LOGS_DEFAULT(WARNING) << "[RunMIGraphX] run_async START" - << " original_batch=" << original_batch_size - << " padded_batch=" << padded_batch_size; std::optional prog_outputs; { std::lock_guard lock(*mgx_mu_ptr); prog_outputs = prog.run_async(m, rocm_stream); } - LOGS_DEFAULT(WARNING) << "[RunMIGraphX] run_async DONE"; bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && original_batch_size < padded_batch_size); @@ -1869,14 +1849,12 @@ static bool warmup_and_capture_hip_graph( HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); } - LOGS_DEFAULT(WARNING) << "[HipGraph] WARMUP: running eager execution for hash=" << shape_hash.substr(0, 12) << "..."; std::optional warmup_outputs; { std::lock_guard lock(*mgx_state->mgx_mu_ptr); warmup_outputs = prog.run_async(m, stream); } HIP_CALL_THROW(hipStreamSynchronize(stream)); - LOGS_DEFAULT(WARNING) << "[HipGraph] WARMUP complete, starting CAPTURE"; auto& entry = mgx_state->hip_graph_cache[shape_hash]; @@ -1888,8 +1866,6 @@ static bool warmup_and_capture_hip_graph( } hipError_t err = hipStreamEndCapture(stream, &entry.graph); if (err != hipSuccess || entry.graph == nullptr) { - 
LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE FAILED (err=" << err - << "). Falling back to eager execution."; entry.graph = nullptr; entry.captured = false; mgx_state->hip_graph_enabled = false; @@ -1913,19 +1889,10 @@ static bool warmup_and_capture_hip_graph( entry.extra_outputs.push_back({i, std::move(ort_shape), gpu_res.data(), res_shape.bytes()}); } - if (!entry.extra_outputs.empty()) { - LOGS_DEFAULT(WARNING) << "[HipGraph] Recorded " << entry.extra_outputs.size() - << " extra (non-pre-allocated) outputs for hash=" - << shape_hash.substr(0, 12) << "..."; - } } - LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE SUCCESS: instantiated graph for hash=" - << shape_hash.substr(0, 12) << "..." - << " total_graphs=" << mgx_state->hip_graph_cache.size(); return true; } catch (...) { - LOGS_DEFAULT(WARNING) << "[HipGraph] CAPTURE EXCEPTION. Falling back to eager execution."; hipGraph_t dummy = nullptr; (void)hipStreamEndCapture(stream, &dummy); if (dummy) (void)hipGraphDestroy(dummy); @@ -1940,7 +1907,6 @@ static bool warmup_and_capture_hip_graph( static void replay_hip_graph(MIGraphXFuncState* mgx_state, hipStream_t stream, const std::string& shape_hash) { - LOGS_DEFAULT(WARNING) << "[HipGraph] REPLAY launch for hash=" << shape_hash.substr(0, 12) << "..."; auto& entry = mgx_state->hip_graph_cache.at(shape_hash); HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); } @@ -1993,8 +1959,6 @@ static void run_program_or_hip_graph( std::size_t padded_batch_size = 0) { if (!mgx_state->hip_graph_enabled) { - LOGS_DEFAULT(WARNING) << "[Dispatch] EAGER run_async (hipGraph disabled)" - << " hash=" << shape_hash.substr(0, 12) << "..."; run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, prog_output_indices, original_batch_size, padded_batch_size); return; @@ -2002,9 +1966,6 @@ static void run_program_or_hip_graph( auto it = mgx_state->hip_graph_cache.find(shape_hash); if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { - LOGS_DEFAULT(WARNING) << 
"[Dispatch] REPLAY hipGraph" - << " hash=" << shape_hash.substr(0, 12) << "..." - << " cache_size=" << mgx_state->hip_graph_cache.size(); replay_hip_graph(mgx_state, stream, shape_hash); if (!it->second.extra_outputs.empty()) { @@ -2012,12 +1973,8 @@ static void run_program_or_hip_graph( original_batch_size, padded_batch_size); } } else { - LOGS_DEFAULT(WARNING) << "[Dispatch] WARMUP+CAPTURE hipGraph" - << " hash=" << shape_hash.substr(0, 12) << "..." - << " cache_size=" << mgx_state->hip_graph_cache.size(); if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, prog_output_indices, shape_hash)) { - LOGS_DEFAULT(WARNING) << "[Dispatch] Capture FAILED — falling back to EAGER"; run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, prog_output_indices, original_batch_size, padded_batch_size); } else { @@ -2228,12 +2185,7 @@ static migraphx::program load_or_compile_model( { migraphx::program prog; - LOGS_DEFAULT(WARNING) << "[LoadOrCompile] batch_size=" << (batch_size > 0 ? std::to_string(batch_size) : "(default)") - << " cache_file=" << (cache_file.empty() ? "(none)" : cache_file.string()); - if (!load_precompiled_model(prog, cache_file)) { - LOGS_DEFAULT(WARNING) << "[LoadOrCompile] DISK CACHE MISS — COMPILING batch_size=" - << (batch_size > 0 ? std::to_string(batch_size) : "(default)"); prog = CompileProgramWithBatch( onnx_string, @@ -2253,17 +2205,7 @@ static migraphx::program load_or_compile_model( all_input_base_shapes, batch_size); - LOGS_DEFAULT(WARNING) << "[LoadOrCompile] Compilation DONE batch_size=" - << (batch_size > 0 ? std::to_string(batch_size) : "(default)"); - save_compiled_model(prog, cache_file); - if (!cache_file.empty()) { - LOGS_DEFAULT(WARNING) << "[LoadOrCompile] Saved to disk: " << cache_file.string(); - } - } else { - LOGS_DEFAULT(WARNING) << "[LoadOrCompile] DISK CACHE HIT — loaded batch_size=" - << (batch_size > 0 ? 
std::to_string(batch_size) : "(default)") - << " from " << cache_file.string(); } return prog; } @@ -2610,10 +2552,6 @@ static bool execute_ultra_fast_path( && mgx_state->pinned_io.allocated; if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { - LOGS_DEFAULT(WARNING) << "[UltraFastPath] batch=" << actual_batch - << " compiled_batch=" << compiled_batch - << " mode=PINNED_IO" - << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF"); const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); @@ -2630,9 +2568,6 @@ static bool execute_ultra_fast_path( return true; } - LOGS_DEFAULT(WARNING) << "[UltraFastPath] batch=" << actual_batch - << " compiled_batch=" << compiled_batch - << " mode=DIRECT_ORT_POINTERS"; auto& m = mgx_state->cached_prog_params.value(); for (const auto& inp : mgx_state->cached_inputs) { const auto& input_tensor = ctx.GetInput(inp.ort_index); @@ -2776,12 +2711,6 @@ static bool execute_fast_path( && mgx_state->pinned_io.allocated; if (needs_pinned) { - LOGS_DEFAULT(WARNING) << "[FastPath] batch=" << actual_batch - << " compiled_batch=" << compiled_batch - << " padded=" << (needs_padding ? "YES" : "NO") - << " mode=PINNED_IO" - << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") - << " program_changed=" << (program_changed ? "YES" : "NO"); copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); @@ -2804,10 +2733,6 @@ static bool execute_fast_path( return true; } - LOGS_DEFAULT(WARNING) << "[FastPath] batch=" << actual_batch - << " compiled_batch=" << compiled_batch - << " mode=DIRECT_ORT_POINTERS" - << " program_changed=" << (program_changed ? 
"YES" : "NO"); auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -2909,31 +2834,15 @@ static void compile_dynamic_batch_models( const std::string& mxr_filename_prefix, const Ort::KernelContext& ctx) { - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] has_dynamic_batch=" << mgx_state->has_dynamic_batch - << " batch_sizes_count=" << mgx_state->compiled_batch_sizes.size() - << " max_dynamic_batch=" << mgx_state->max_dynamic_batch; - if (!mgx_state->has_dynamic_batch || mgx_state->compiled_batch_sizes.empty()) { - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; return; } - { - std::ostringstream bs_list; - for (std::size_t i = 0; i < mgx_state->compiled_batch_sizes.size(); ++i) { - if (i > 0) bs_list << ", "; - bs_list << mgx_state->compiled_batch_sizes[i]; - } - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Batch sizes to compile: [" << bs_list.str() << "]"; - } - // Get input names and base shapes (without batch dimension) const auto& map_input_name_index = mgx_state->input_name_indexes; std::vector input_names; std::vector> all_input_base_shapes; - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; for (const auto& [name, index] : map_input_name_index) { input_names.push_back(name); auto input_tensor = ctx.GetInput(index); @@ -2969,7 +2878,6 @@ static void compile_dynamic_batch_models( // Compile a model for each configured batch size for (const auto& batch_size : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Processing batch_size=" << batch_size; std::vector batch_shape_key; for (size_t i = 0; i < input_names.size(); ++i) { @@ -2983,21 +2891,15 @@ static void compile_dynamic_batch_models( if 
(mgx_state->cached_programs_ref.has_value()) { auto& cached_progs = mgx_state->cached_programs_ref.value().get(); if (cached_progs.find(cache_hash) != cached_progs.end()) { - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size - << " CACHE HIT (in-memory), skipping"; continue; } - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size - << " CACHE MISS (in-memory cache_size=" << cached_progs.size() << ")"; } std::filesystem::path batch_cache_file; if (!model_cache_path.empty()) { batch_cache_file = model_cache_path / (mxr_filename_prefix + cache_hash + ".mxr"); - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] Disk cache: " << batch_cache_file.string(); } - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] load_or_compile batch_size=" << batch_size << " hash=" << cache_hash.substr(0, 12) << "..."; migraphx::program batch_prog = load_or_compile_model( batch_cache_file, mgx_state->onnx_string, @@ -3019,17 +2921,11 @@ static void compile_dynamic_batch_models( if (mgx_state->cached_programs_ref.has_value()) { mgx_state->cached_programs_ref.value().get()[cache_hash] = batch_prog; - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] batch_size=" << batch_size - << " STORED in cache (total_programs=" - << mgx_state->cached_programs_ref.value().get().size() << ")"; } } - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== All batch models compiled/loaded ===="; - mgx_state->max_dynamic_batch = 0; mgx_state->defer_compilation = false; - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] defer_compilation=false, max_dynamic_batch=0"; // Allocate pinned I/O now that all batch models are compiled. 
// Must use the largest-batch program's shapes so the buffer count and @@ -3070,7 +2966,6 @@ static void compile_dynamic_batch_models( } } - LOGS_DEFAULT(WARNING) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; } // Standard path: Shape checking, potential recompilation, and execution @@ -3093,14 +2988,7 @@ static void execute_standard_path( // NOTE: max_dynamic_batch > 0 means compilation was deferred to runtime (not precompiled) // If precompilation happened during Compile(), max_dynamic_batch will be > 0 but defer_compilation = false // In that case, the programs are already in cache and we can skip runtime compilation - LOGS_DEFAULT(WARNING) << "[StandardPath] Entered: defer_compilation=" << mgx_state->defer_compilation - << " has_dynamic_batch=" << mgx_state->has_dynamic_batch - << " max_dynamic_batch=" << mgx_state->max_dynamic_batch - << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF") - << " pinned_allocated=" << (mgx_state->pinned_io.allocated ? "YES" : "NO"); - if (mgx_state->has_dynamic_batch && mgx_state->max_dynamic_batch > 0 && mgx_state->defer_compilation) { - LOGS_DEFAULT(WARNING) << "[StandardPath] Triggering DEFERRED COMPILATION for all batch sizes"; compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx); // Validate newly compiled programs for hipGraph compatibility @@ -3172,11 +3060,6 @@ static void execute_standard_path( bool use_pinned = needs_padding || mgx_state->hip_graph_enabled; if (use_pinned) { - LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << original_batch_size - << " padded_batch=" << padded_batch_size - << " padded=" << (needs_padding ? "YES" : "NO") - << " mode=PINNED_IO (dynamic batch cache hit)" - << " hipGraph=" << (mgx_state->hip_graph_enabled ? 
"ON" : "OFF"); if (!mgx_state->pinned_io.allocated) { std::size_t max_batch = padded_batch_size; if (!mgx_state->compiled_batch_sizes.empty()) { @@ -3225,9 +3108,6 @@ static void execute_standard_path( bind_result.pinned_output_indices, ctx, copy_actual, rocm_stream); } else { - LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << original_batch_size - << " padded_batch=" << padded_batch_size - << " mode=DIRECT_ORT_POINTERS (dynamic batch cache hit, exact match)"; auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -3249,7 +3129,6 @@ static void execute_standard_path( mgx_state->defer_compilation, map_input_name_index, ctx, cmp_options, prog); if (!input_shape_match) { - LOGS_DEFAULT(WARNING) << "[StandardPath] Input shape MISMATCH — triggering recompilation"; mgx_state->caches_valid = false; handle_input_shape_mismatch( @@ -3262,7 +3141,6 @@ static void execute_standard_path( input_shapes); param_shapes = prog.get_parameter_shapes(); - LOGS_DEFAULT(WARNING) << "[StandardPath] Recompilation complete"; if (mgx_state->hip_graph_enabled && !check_hip_graph_compatibility(prog, "standard_path_recompile")) { mgx_state->hip_graph_enabled = false; @@ -3297,10 +3175,6 @@ static void execute_standard_path( auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } } - LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << actual_batch - << " mode=PINNED_IO (hipGraph static path)" - << " hipGraph=ON"; - copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, actual_batch, rocm_stream); auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); @@ -3320,16 +3194,6 @@ static void execute_standard_path( return; } - { - std::size_t actual_batch = 0; - for (const auto& [name, index] : map_input_name_index) { - auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); - if 
(!shape.empty()) { actual_batch = static_cast(shape[0]); break; } - } - LOGS_DEFAULT(WARNING) << "[StandardPath] batch=" << actual_batch - << " mode=DIRECT_ORT_POINTERS (static fallback)" - << " hipGraph=OFF"; - } auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); @@ -3509,53 +3373,19 @@ extract_base_shapes_from_graph( const NodeArg* nodearg = it->second; auto tensor_shape = nodearg->Shape(); if (tensor_shape != nullptr && tensor_shape->dim_size() > 1) { - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Processing input '" << name - << "' with " << tensor_shape->dim_size() << " dimensions"; - // Extract non-batch dimensions (skip dim 0) for (int j = 1; j < tensor_shape->dim_size(); ++j) { const auto& dim = tensor_shape->dim(j); if (dim.has_dim_value()) { base_shape.push_back(dim.dim_value()); - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] dim[" << j << "] = " << dim.dim_value(); } else { - // Symbolic non-batch dimension found - cannot precompile - // Do NOT default to any value - mark as failure - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' dim " << j << " is symbolic - cannot extract concrete base shapes"; all_concrete = false; } } - } else if (tensor_shape != nullptr && tensor_shape->dim_size() == 1) { - // Single dimension input (just batch) - no base shape needed - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Input '" << name - << "' has only 1 dimension (batch only) - empty base shape"; - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' has null or empty shape"; } - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' not found in NodeArg map - may be subgraph-only input"; } base_shapes.push_back(base_shape); - - // Log the extracted base shape - std::ostringstream ss; - ss << "["; - for (std::size_t k = 0; k < 
base_shape.size(); ++k) { - if (k > 0) ss << ", "; - ss << base_shape[k]; - } - ss << "]"; - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Input '" << name << "' base_shape: " << ss.str(); } - if (all_concrete) { - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Successfully extracted " << ordered_names.size() - << " input base shapes (all concrete)"; - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Failed - found symbolic non-batch dimensions"; - } return {all_concrete, ordered_names, base_shapes}; } @@ -3903,9 +3733,6 @@ static inline void precompile_static_model( } } } - } else { - LOGS_DEFAULT(WARNING) << "[precompile_static_model] Input '" << name - << "' not found in NodeArg map - skipping"; } full_shapes.push_back(shape); } @@ -4408,11 +4235,6 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& break; } } - if (p->hip_graph_enabled) { - LOGS_DEFAULT(WARNING) << "[HipGraph] Enabled for node '" << context->node_name << "'"; - } else { - LOGS_DEFAULT(WARNING) << "[HipGraph] Disabled for node '" << context->node_name << "'"; - } } *state = p.release(); @@ -4434,26 +4256,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& const auto& map_input_name_index = mgx_state->input_name_indexes; - // Determine batch size from first input for logging - std::size_t log_batch = 0; - for (const auto& [name, index] : map_input_name_index) { - const auto& shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); - if (!shape.empty()) { log_batch = static_cast(shape[0]); break; } - } - - // ═══════════════════════════════════════════════════════════════════════ - // ULTRA-FAST PATH: Shapes unchanged from last run - // ═══════════════════════════════════════════════════════════════════════ if (execute_ultra_fast_path(mgx_state, stream_, ctx)) { return Status::OK(); } - // ═══════════════════════════════════════════════════════════════════════ - // Build input shape hash - only computed when shapes 
change - // ═══════════════════════════════════════════════════════════════════════ - LOGS_DEFAULT(WARNING) << "[Compute] UltraFast miss — building hash for batch=" << log_batch - << " inputs=" << map_input_name_index.size() - << " hipGraph=" << (mgx_state->hip_graph_enabled ? "ON" : "OFF"); std::vector all_input_shapes; all_input_shapes.reserve(map_input_name_index.size() * 4); for (const auto& [name, index] : map_input_name_index) { @@ -4462,19 +4268,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } const auto current_hash = make_hash(all_input_shapes); - // ═══════════════════════════════════════════════════════════════════════ - // FAST PATH: Check cached programs for this shape hash - // ═══════════════════════════════════════════════════════════════════════ if (execute_fast_path(mgx_state, stream_, ctx, current_hash, all_input_shapes)) { return Status::OK(); } - LOGS_DEFAULT(WARNING) << "[Compute] FastPath miss — entering StandardPath for batch=" << log_batch - << " hash=" << current_hash.substr(0, 12) << "..."; - - // ═══════════════════════════════════════════════════════════════════════ - // STANDARD PATH: Shape checking and potential recompilation - // ═══════════════════════════════════════════════════════════════════════ execute_standard_path(mgx_state, stream_, ctx, current_hash, std::move(all_input_shapes), model_cache_path_, model_path_, mxr_filename_prefix); From dcf04b06e91749c133614886b0236370dd5ee6fa Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 Apr 2026 11:29:06 -0500 Subject: [PATCH 10/16] Add size-bucketed pool allocator for hipGraph pointer stability When hipGraph is enabled, the MIGraphXAllocator now caches freed device pointers by size and returns them on subsequent same-size allocations. This ensures ORT tensor buffers get stable GPU addresses across inference calls, which is a prerequisite for capturing hipGraph directly on ORT buffers (eliminating the intermediary pinned-copy overhead). 
Pool mode is gated behind hip_graph_enable -- zero behavioral change when hipGraph is disabled. Made-with: Cursor --- .../providers/migraphx/migraphx_allocator.cc | 38 +++++++++++++++++-- .../providers/migraphx/migraphx_allocator.h | 14 +++++++ .../migraphx/migraphx_execution_provider.cc | 8 +++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 911a1a7fd18b9..1092a30bd3469 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -22,18 +22,50 @@ void MIGraphXAllocator::CheckDevice() const { #endif } +void MIGraphXAllocator::EnablePoolMode() { + std::lock_guard lock(pool_mu_); + pool_enabled_ = true; +} + void* MIGraphXAllocator::Alloc(size_t size) { CheckDevice(); + if (size == 0) return nullptr; + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + auto it = free_list_.find(size); + if (it != free_list_.end() && !it->second.empty()) { + void* p = it->second.back(); + it->second.pop_back(); + return p; + } + } + void* p = nullptr; - if (size > 0) { - HIP_CALL_THROW(hipMalloc((void**)&p, size)); + HIP_CALL_THROW(hipMalloc((void**)&p, size)); + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + alloc_sizes_[p] = size; } + return p; } void MIGraphXAllocator::Free(void* p) { CheckDevice(); - (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown + if (!p) return; + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + auto it = alloc_sizes_.find(p); + if (it != alloc_sizes_.end()) { + free_list_[it->second].push_back(p); + return; + } + } + + (void)hipFree(p); } void* MIGraphXExternalAllocator::Alloc(size_t size) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 10e06ab2f35ad..ee681e07c546a 100644 --- 
a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -4,7 +4,9 @@ #pragma once #include +#include #include +#include #include "core/framework/allocator.h" namespace onnxruntime { @@ -21,8 +23,20 @@ class MIGraphXAllocator : public IAllocator { virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; + void EnablePoolMode(); + bool IsPoolModeEnabled() const { return pool_enabled_; } + private: void CheckDevice() const; + + // When pool mode is enabled (for hipGraph), freed allocations are cached by + // size so that subsequent Alloc calls for the same size return the same + // device pointer. This provides pointer stability required by hipGraph + // replay without the cost of intermediary buffer copies. + bool pool_enabled_ = false; + mutable std::mutex pool_mu_; + std::unordered_map> free_list_; + std::unordered_map alloc_sizes_; }; class MIGraphXExternalAllocator : public MIGraphXAllocator { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index eca580c49efc0..ff18cef0113d3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -328,8 +328,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA); + [this](OrtDevice::DeviceId device_id) { + auto alloc = std::make_unique(device_id, onnxruntime::CUDA); + if (hip_graph_enable_) { + alloc->EnablePoolMode(); + } + return alloc; }, device_id_); AllocatorCreationInfo pinned_allocator_info( From 049b18725a7af8be95c03e9c37060afd139d1621 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 
Apr 2026 11:30:31 -0500 Subject: [PATCH 11/16] Add direct-bind hipGraph capture/replay path Add warmup_and_capture_hip_graph_direct() and run_program_or_hip_graph_direct() which capture and replay hipGraphs using ORT's tensor pointers directly instead of intermediary pinned buffers. Captured addresses are stored in CapturedHipGraph so pointer drift can be detected and trigger re-capture. Also adds use_direct_hip_graph flag to MIGraphXFuncState, set alongside hip_graph_enabled during node state creation. Made-with: Cursor --- .../migraphx/migraphx_execution_provider.cc | 135 ++++++++++++++++++ .../migraphx/migraphx_execution_provider.h | 8 ++ 2 files changed, 143 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ff18cef0113d3..d9e233e8ddbac 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1915,6 +1915,139 @@ static void replay_hip_graph(MIGraphXFuncState* mgx_state, HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); } +// Direct-bind capture: bind ORT tensor pointers directly (no pinned buffers) +// and capture the hipGraph. Requires stable pointers from pool allocator. 
+static bool warmup_and_capture_hip_graph_direct( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + const std::unordered_map& input_ptrs, + const std::unordered_map& output_ptrs) +{ + std::optional warmup_outputs; + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + warmup_outputs = prog.run_async(m, stream); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + auto& entry = mgx_state->hip_graph_cache[shape_hash]; + + try { + HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + prog.run_async(m, stream); + } + hipError_t err = hipStreamEndCapture(stream, &entry.graph); + if (err != hipSuccess || entry.graph == nullptr) { + entry.graph = nullptr; + entry.captured = false; + mgx_state->use_direct_hip_graph = false; + return false; + } + + HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); + entry.captured = true; + entry.captured_input_ptrs = input_ptrs; + entry.captured_output_ptrs = output_ptrs; + + std::unordered_set pre_alloc_set(prog_output_indices.begin(), + prog_output_indices.end()); + entry.extra_outputs.clear(); + if (warmup_outputs) { + auto output_num = warmup_outputs->size(); + for (std::size_t i = 0; i < output_num; ++i) { + if (pre_alloc_set.count(i) > 0) continue; + auto gpu_res = (*warmup_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + entry.extra_outputs.push_back({i, std::move(ort_shape), + gpu_res.data(), res_shape.bytes()}); + } + } + + return true; + } catch (...) 
{ + hipGraph_t dummy = nullptr; + (void)hipStreamEndCapture(stream, &dummy); + if (dummy) (void)hipGraphDestroy(dummy); + entry.graph = nullptr; + entry.exec = nullptr; + entry.captured = false; + mgx_state->use_direct_hip_graph = false; + return false; + } +} + +// Check whether ORT's current tensor pointers match the addresses stored +// during capture. Returns true if all pointers match. +static bool check_captured_ptrs_match( + const MIGraphXFuncState::CapturedHipGraph& entry, + const std::unordered_map& current_input_ptrs, + const std::unordered_map& current_output_ptrs) +{ + for (const auto& [name, ptr] : current_input_ptrs) { + auto it = entry.captured_input_ptrs.find(name); + if (it == entry.captured_input_ptrs.end() || it->second != ptr) return false; + } + for (const auto& [name, ptr] : current_output_ptrs) { + auto it = entry.captured_output_ptrs.find(name); + if (it == entry.captured_output_ptrs.end() || it->second != ptr) return false; + } + return true; +} + +// Direct-bind dispatch: replay or capture hipGraph using ORT tensor pointers +// directly. Falls back to the pinned-copy path on pointer mismatch. 
+static void run_program_or_hip_graph_direct( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + const std::unordered_map& input_ptrs, + const std::unordered_map& output_ptrs, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + auto it = mgx_state->hip_graph_cache.find(shape_hash); + if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + if (!check_captured_ptrs_match(it->second, input_ptrs, output_ptrs)) { + // Pointer drift -- destroy old graph and re-capture + if (it->second.exec) { (void)hipGraphExecDestroy(it->second.exec); it->second.exec = nullptr; } + if (it->second.graph) { (void)hipGraphDestroy(it->second.graph); it->second.graph = nullptr; } + it->second.captured = false; + } else { + HIP_CALL_THROW(hipGraphLaunch(it->second.exec, stream)); + if (!it->second.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, it->second.extra_outputs, + original_batch_size, padded_batch_size); + } + return; + } + } + + if (!warmup_and_capture_hip_graph_direct(mgx_state, stream, prog, m, + prog_output_indices, shape_hash, + input_ptrs, output_ptrs)) { + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + } else { + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + if (!entry.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, entry.extra_outputs, + original_batch_size, padded_batch_size); + } + } +} + // Materialize extra (non-pre-allocated) outputs recorded during hipGraph capture. // These are MIGraphX outputs not exposed as named parameters — their GPU data // pointers are stable across replays because hipGraph replays the same kernels. 
@@ -4232,10 +4365,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // hipGraph: set per-node enable flag and validate cached programs p->hip_graph_enabled = hip_graph_enable_; + p->use_direct_hip_graph = hip_graph_enable_; if (p->hip_graph_enabled && p->cached_programs_ref.has_value()) { for (const auto& [hash, cached_prog] : p->cached_programs_ref.value().get()) { if (!check_hip_graph_compatibility(cached_prog, context->node_name)) { p->hip_graph_enabled = false; + p->use_direct_hip_graph = false; break; } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 891db01eda263..efdc3ff631832 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -168,9 +168,17 @@ struct MIGraphXFuncState { hipGraphExec_t exec = nullptr; bool captured = false; std::vector extra_outputs; + + // Addresses captured in the graph for direct-bind mode. + // Used to detect pointer drift and trigger re-capture. + std::unordered_map captured_input_ptrs; + std::unordered_map captured_output_ptrs; }; bool hip_graph_enabled = false; + // When true, capture/replay binds ORT tensor pointers directly (no pinned copies). + // Requires the pool allocator to provide stable addresses. + bool use_direct_hip_graph = false; // shape_hash -> captured graph (one per compiled program variant) std::unordered_map hip_graph_cache; }; From 22876279fa2622f1c853beb28554a402a38c6bf7 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 Apr 2026 11:32:27 -0500 Subject: [PATCH 12/16] Route ultra-fast/fast/standard paths through direct-bind hipGraph When use_direct_hip_graph is true and no batch padding is needed, all three execution paths (ultra-fast, fast, standard) now bind ORT tensor pointers directly into MIGraphX program_parameters and dispatch through run_program_or_hip_graph_direct(). 
This eliminates the copy_inputs_to_pinned and copy_pinned_outputs_to_ort memcpy rounds entirely. The pinned-copy path is preserved as fallback for padding cases and when use_direct_hip_graph is false. Made-with: Cursor --- .../migraphx/migraphx_execution_provider.cc | 105 +++++++++++++++++- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index d9e233e8ddbac..9aab279bf0342 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -2684,8 +2684,35 @@ static bool execute_ultra_fast_path( .GetTensorTypeAndShapeInfo().GetShape()[0]) : 0); std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; - // hipGraph requires stable buffer addresses → always route through pinned I/O - bool needs_pinned = ((actual_batch < compiled_batch) || mgx_state->hip_graph_enabled) + bool needs_padding = (actual_batch < compiled_batch); + + // Direct-bind hipGraph: no copies, bind ORT pointers and replay + if (mgx_state->use_direct_hip_graph && !needs_padding) { + auto& m = mgx_state->cached_prog_params.value(); + std::unordered_map input_ptrs, output_ptrs; + for (const auto& inp : mgx_state->cached_inputs) { + const auto& input_tensor = ctx.GetInput(inp.ort_index); + void* ptr = const_cast(input_tensor.GetTensorRawData()); + m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, ptr)); + input_ptrs[inp.name] = ptr; + } + for (std::size_t i = 0; i < mgx_state->cached_outputs.size(); ++i) { + const auto& out = mgx_state->cached_outputs[i]; + const auto& ort_shape = mgx_state->cached_output_ort_shapes[i]; + auto output_tensor = ctx.GetOutput(out.output_index, ort_shape.data(), ort_shape.size()); + void* ptr = output_tensor.GetTensorMutableRawData(); + m.add(out.name.c_str(), migraphx::argument(out.mgx_shape, ptr)); + 
output_ptrs[out.name] = ptr; + } + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + mgx_state->cached_prog_output_indices, + mgx_state->last_input_shape_hash, + input_ptrs, output_ptrs); + return true; + } + + // Pinned-copy path: padding needed or legacy hipGraph path + bool needs_pinned = (needs_padding || mgx_state->hip_graph_enabled) && mgx_state->pinned_io.allocated; if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { @@ -2843,8 +2870,45 @@ static bool execute_fast_path( } } std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; - // hipGraph requires stable buffer addresses → always route through pinned I/O - bool needs_pinned = ((actual_batch < compiled_batch) || mgx_state->hip_graph_enabled) + bool fast_needs_padding = (actual_batch < compiled_batch); + + // Direct-bind hipGraph path: bind ORT pointers and replay, no copies + if (mgx_state->use_direct_hip_graph && !fast_needs_padding) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } + } + } + + mgx_state->cached_prog_params = std::move(m); + mgx_state->cached_prog_output_indices = std::move(prog_output_indices); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + 
run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, + mgx_state->cached_prog_params.value(), + mgx_state->cached_prog_output_indices, + effective_program_hash, + input_ptrs, output_ptrs); + return true; + } + + // Pinned-copy path: padding needed or legacy hipGraph + bool needs_pinned = (fast_needs_padding || mgx_state->hip_graph_enabled) && mgx_state->pinned_io.allocated; if (needs_pinned) { @@ -3306,6 +3370,39 @@ static void execute_standard_path( } } + // Direct-bind hipGraph for standard path (no padding case) + if (mgx_state->use_direct_hip_graph) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } + } + } + + mgx_state->cached_prog_params = m; + mgx_state->cached_prog_output_indices = prog_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + prog_output_indices, current_hash, + input_ptrs, output_ptrs); + return; + } + if (mgx_state->hip_graph_enabled && mgx_state->pinned_io.allocated) { std::size_t actual_batch = 0; for (const auto& [name, index] : map_input_name_index) { From 2c9f4e723339d5e041c635fa5e9938d992883e1a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 Apr 2026 11:33:13 
-0500 Subject: [PATCH 13/16] Add pointer-drift detection with re-capture limit and fallback Track direct-bind re-capture count in MIGraphXFuncState. If pointer drift triggers more than kMaxDirectRecaptures (3) graph re-captures, permanently disable use_direct_hip_graph for that node and fall back to eager run_migraphx_program execution. This prevents infinite re-capture loops if the pool allocator cannot maintain pointer stability. Made-with: Cursor --- .../providers/migraphx/migraphx_execution_provider.cc | 11 ++++++++++- .../providers/migraphx/migraphx_execution_provider.h | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9aab279bf0342..ecb8eb2627884 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -2020,7 +2020,16 @@ static void run_program_or_hip_graph_direct( auto it = mgx_state->hip_graph_cache.find(shape_hash); if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { if (!check_captured_ptrs_match(it->second, input_ptrs, output_ptrs)) { - // Pointer drift -- destroy old graph and re-capture + ++mgx_state->direct_recapture_count; + if (mgx_state->direct_recapture_count > MIGraphXFuncState::kMaxDirectRecaptures) { + LOGS_DEFAULT(WARNING) << "[HipGraph] Too many pointer-drift re-captures (" + << mgx_state->direct_recapture_count + << "), falling back to eager execution"; + mgx_state->use_direct_hip_graph = false; + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + return; + } if (it->second.exec) { (void)hipGraphExecDestroy(it->second.exec); it->second.exec = nullptr; } if (it->second.graph) { (void)hipGraphDestroy(it->second.graph); it->second.graph = nullptr; } it->second.captured = false; diff --git 
a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index efdc3ff631832..20b9538aad1ab 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -179,6 +179,9 @@ struct MIGraphXFuncState { // When true, capture/replay binds ORT tensor pointers directly (no pinned copies). // Requires the pool allocator to provide stable addresses. bool use_direct_hip_graph = false; + // If pointer drift causes too many re-captures, disable direct mode permanently. + static constexpr int kMaxDirectRecaptures = 3; + int direct_recapture_count = 0; // shape_hash -> captured graph (one per compiled program variant) std::unordered_map hip_graph_cache; }; From 7d274761cfbf169c19b856da41f378867c07ebc5 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 Apr 2026 11:34:32 -0500 Subject: [PATCH 14/16] Ensure padding path falls back to pinned-copy hipGraph correctly When batch padding is needed (actual_batch < compiled_batch), all execution paths continue to use the existing pinned-copy hipGraph path since the padded buffer sizes differ from ORT's tensor sizes. Also adds direct-bind path to the dynamic-batch compilation code for exact-match (no-padding) cases, and ensures use_direct_hip_graph is disabled alongside hip_graph_enabled in all compatibility-check failure paths. 
Made-with: Cursor --- .../migraphx/migraphx_execution_provider.cc | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ecb8eb2627884..6c39b5067987e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1915,6 +1915,14 @@ static void replay_hip_graph(MIGraphXFuncState* mgx_state, HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); } +// Forward declaration (defined after run_program_or_hip_graph) +static void materialize_extra_outputs( + Ort::KernelContext& ctx, + hipStream_t stream, + const std::vector& extras, + std::size_t original_batch_size, + std::size_t padded_batch_size); + // Direct-bind capture: bind ORT tensor pointers directly (no pinned buffers) // and capture the hipGraph. Requires stable pointers from pool allocator. 
static bool warmup_and_capture_hip_graph_direct( @@ -3206,6 +3214,7 @@ static void execute_standard_path( for (const auto& [hash, cached_prog] : mgx_state->cached_programs_ref.value().get()) { if (!check_hip_graph_compatibility(cached_prog, "runtime_dynamic_batch")) { mgx_state->hip_graph_enabled = false; + mgx_state->use_direct_hip_graph = false; break; } } @@ -3268,6 +3277,39 @@ static void execute_standard_path( populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, original_batch_size, padded_batch_size); + // Direct-bind hipGraph for exact-match batch (no padding) + if (mgx_state->use_direct_hip_graph && !needs_padding) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } + } + } + + mgx_state->cached_prog_params = m; + mgx_state->cached_prog_output_indices = prog_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = padded_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + prog_output_indices, padded_hash, + input_ptrs, output_ptrs); + return; + } + bool use_pinned = needs_padding || mgx_state->hip_graph_enabled; if (use_pinned) { if (!mgx_state->pinned_io.allocated) { @@ -3354,6 +3396,7 @@ static void execute_standard_path( if 
(mgx_state->hip_graph_enabled && !check_hip_graph_compatibility(prog, "standard_path_recompile")) { mgx_state->hip_graph_enabled = false; + mgx_state->use_direct_hip_graph = false; } } From 0176ead6f8e33d6600f28c1bf1d5e157262dd929 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 27 Apr 2026 22:42:59 -0500 Subject: [PATCH 15/16] update warmup iterations --- .../migraphx/migraphx_execution_provider.cc | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 6c39b5067987e..e69351804d01e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1836,6 +1836,8 @@ static void destroy_hip_graphs(MIGraphXFuncState* mgx_state) { // Warmup run (ensures lazy GPU allocations are finalized) then capture the graph. // Stores extra (non-pre-allocated) output metadata so replay can materialize them. +static constexpr int kHipGraphWarmInIterations = 8; + static bool warmup_and_capture_hip_graph( MIGraphXFuncState* mgx_state, hipStream_t stream, @@ -1853,8 +1855,10 @@ static bool warmup_and_capture_hip_graph( HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); } + // Run multiple eager warmup iterations before capture to let MIGraphX + // internal workspace buffers (scratch, reductions, etc.) stabilize. std::optional warmup_outputs; - { + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { std::lock_guard lock(*mgx_state->mgx_mu_ptr); warmup_outputs = prog.run_async(m, stream); } @@ -1879,6 +1883,13 @@ static bool warmup_and_capture_hip_graph( HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); entry.captured = true; + // Replay the captured graph several more times post-capture to ensure + // workspace is fully settled before the first real inference. 
+ for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + std::unordered_set pre_alloc_set(prog_output_indices.begin(), prog_output_indices.end()); entry.extra_outputs.clear(); @@ -1936,7 +1947,7 @@ static bool warmup_and_capture_hip_graph_direct( const std::unordered_map& output_ptrs) { std::optional warmup_outputs; - { + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { std::lock_guard lock(*mgx_state->mgx_mu_ptr); warmup_outputs = prog.run_async(m, stream); } @@ -1963,6 +1974,11 @@ static bool warmup_and_capture_hip_graph_direct( entry.captured_input_ptrs = input_ptrs; entry.captured_output_ptrs = output_ptrs; + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + std::unordered_set pre_alloc_set(prog_output_indices.begin(), prog_output_indices.end()); entry.extra_outputs.clear(); From b9a4102847b05e74f446f06eeb0cfc05dc18077b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 21 May 2026 13:41:49 -0500 Subject: [PATCH 16/16] Update stream management for hipGraph capture/replay modes --- .../migraphx/migraphx_execution_provider.cc | 36 +++++++++++++++---- .../migraphx/migraphx_stream_handle.cc | 7 ++-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index e69351804d01e..a81614337a11d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -3061,12 +3061,21 @@ static InputShapeResult handle_input_shape( } // Helper: Compile models for all configured batch sizes and cache them +// rocm_stream is the per-Run compute stream resolved from ctx.GetGPUComputeStream() +// in compute_func. 
It MUST be threaded through to allocate_pinned_io so the +// stream-ordered memory pool used by hipMallocAsync has the same lineage as the +// stream that will later issue copies, captured-graph launches, and replays +// against those pinned buffers. Using a different stream here (e.g. the EP's +// own mgx_state->stream) is undefined behavior under the hipMemPool semantics +// and on ROCm typically surfaces as the captured graph reading stale or +// uninitialized pinned memory on first replay. static void compile_dynamic_batch_models( MIGraphXFuncState* mgx_state, const std::filesystem::path& model_cache_path, const std::filesystem::path& model_path, const std::string& mxr_filename_prefix, - const Ort::KernelContext& ctx) { + const Ort::KernelContext& ctx, + hipStream_t rocm_stream) { if (!mgx_state->has_dynamic_batch || mgx_state->compiled_batch_sizes.empty()) { return; @@ -3195,7 +3204,7 @@ static void compile_dynamic_batch_models( if (largest_prog && max_batch > 0) { auto ps = largest_prog->get_parameter_shapes(); auto os = largest_prog->get_output_shapes(); - allocate_pinned_io(mgx_state, ps, os, max_batch, mgx_state->stream); + allocate_pinned_io(mgx_state, ps, os, max_batch, rocm_stream); } } } @@ -3223,7 +3232,7 @@ static void execute_standard_path( // If precompilation happened during Compile(), max_dynamic_batch will be > 0 but defer_compilation = false // In that case, the programs are already in cache and we can skip runtime compilation if (mgx_state->has_dynamic_batch && mgx_state->max_dynamic_batch > 0 && mgx_state->defer_compilation) { - compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx); + compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx, rocm_stream); // Validate newly compiled programs for hipGraph compatibility if (mgx_state->hip_graph_enabled && mgx_state->cached_programs_ref.has_value()) { @@ -4482,6 +4491,13 @@ Status 
MIGraphXExecutionProvider::Compile(const std::vector& } // Allocate pinned I/O buffers from the cached programs. + // create_state_func runs ONCE at session init (long before any Run()), + // so there is no per-Run compute stream to query here — ComputeContext + // does not expose one. We use stream_ (the EP-owned init stream) and + // rely on the hipStreamSynchronize(stream) inside allocate_pinned_io to + // establish the hipMallocAsync pool memory so the per-Run compute stream + // (resolved from ctx.GetGPUComputeStream() in compute_func) can safely + // consume these pointers without further cross-stream ordering. // Uses the program compiled for the largest batch size so that // allocate_pinned_io sees parameter shapes whose batch dim matches // max_batch. All smaller batches share the same buffers. @@ -4558,9 +4574,17 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); + // Run on whichever stream ORT elected for this device for THIS Run(). + // - external_stream_=true -> ORT wrapper around the user-supplied stream + // - external_stream_=false -> stream ORT created via RegisterCreateStreamFn + // Either way, ORT's MemcpyFromHost/MemcpyToHost ran on this stream, so issuing + // kernels on it removes the cross-stream race that EP::stream_ would introduce. 
+ hipStream_t run_stream = static_cast(ctx.GetGPUComputeStream()); + if (run_stream == nullptr) run_stream = stream_; // fallback for harnesses w/o stream registry + const auto& map_input_name_index = mgx_state->input_name_indexes; - if (execute_ultra_fast_path(mgx_state, stream_, ctx)) { + if (execute_ultra_fast_path(mgx_state, run_stream, ctx)) { return Status::OK(); } @@ -4572,11 +4596,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } const auto current_hash = make_hash(all_input_shapes); - if (execute_fast_path(mgx_state, stream_, ctx, current_hash, all_input_shapes)) { + if (execute_fast_path(mgx_state, run_stream, ctx, current_hash, all_input_shapes)) { return Status::OK(); } - execute_standard_path(mgx_state, stream_, ctx, current_hash, std::move(all_input_shapes), + execute_standard_path(mgx_state, run_stream, ctx, current_hash, std::move(all_input_shapes), model_cache_path_, model_path_, mxr_filename_prefix); return Status::OK(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index b021e37cb5112..2829bb45fb93b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -67,8 +67,11 @@ std::unique_ptr MIGraphXStream::CreateNotification(si } void MIGraphXStream::Flush() { - if (auto* handle = GetHandle()) - HIP_CALL_THROW(hipStreamSynchronize(static_cast(handle))); + // Only sync streams we own. External streams are caller-managed; implicit sync + // here would break async pipelines (e.g. Triton) by serializing every Run(). + if(own_stream_) + if (auto* handle = GetHandle()) + HIP_CALL_THROW(hipStreamSynchronize(static_cast(handle))); } void MIGraphXStream::EnqueDeferredCPUBuffer(void* cpu_buffer) {