diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 911a1a7fd18b9..1092a30bd3469 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -22,18 +22,50 @@ void MIGraphXAllocator::CheckDevice() const { #endif } +void MIGraphXAllocator::EnablePoolMode() { + std::lock_guard lock(pool_mu_); + pool_enabled_ = true; +} + void* MIGraphXAllocator::Alloc(size_t size) { CheckDevice(); + if (size == 0) return nullptr; + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + auto it = free_list_.find(size); + if (it != free_list_.end() && !it->second.empty()) { + void* p = it->second.back(); + it->second.pop_back(); + return p; + } + } + void* p = nullptr; - if (size > 0) { - HIP_CALL_THROW(hipMalloc((void**)&p, size)); + HIP_CALL_THROW(hipMalloc((void**)&p, size)); + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + alloc_sizes_[p] = size; } + return p; } void MIGraphXAllocator::Free(void* p) { CheckDevice(); - (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown + if (!p) return; + + if (pool_enabled_) { + std::lock_guard lock(pool_mu_); + auto it = alloc_sizes_.find(p); + if (it != alloc_sizes_.end()) { + free_list_[it->second].push_back(p); + return; + } + } + + (void)hipFree(p); } void* MIGraphXExternalAllocator::Alloc(size_t size) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 10e06ab2f35ad..ee681e07c546a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -4,7 +4,9 @@ #pragma once #include +#include #include +#include #include "core/framework/allocator.h" namespace onnxruntime { @@ -21,8 +23,20 @@ class MIGraphXAllocator : public IAllocator { virtual void* Alloc(size_t size) override; 
virtual void Free(void* p) override; + void EnablePoolMode(); + bool IsPoolModeEnabled() const { return pool_enabled_; } + private: void CheckDevice() const; + + // When pool mode is enabled (for hipGraph), freed allocations are cached by + // size so that subsequent Alloc calls for the same size return the same + // device pointer. This provides pointer stability required by hipGraph + // replay without the cost of intermediary buffer copies. + bool pool_enabled_ = false; + mutable std::mutex pool_mu_; + std::unordered_map> free_list_; + std::unordered_map alloc_sizes_; }; class MIGraphXExternalAllocator : public MIGraphXAllocator { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2b73055c87df7..a81614337a11d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -142,6 +143,12 @@ static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) static std::vector parse_compile_batches(const std::string& spec); +// Serializes remaining synchronous hipMalloc calls (e.g. temp output buffers) +// across all MIGraphX EP instances in the process. The primary pinned I/O +// allocation paths use hipMallocAsync/hipFreeAsync which are per-stream safe, +// but a few fallback paths still use synchronous hipMalloc. 
+static std::mutex g_hip_alloc_mutex; + MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, device_id_{info.device_id}, @@ -161,7 +168,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv external_free_{info.external_free}, external_empty_cache_{info.external_empty_cache}, max_dynamic_batch_{info.max_dynamic_batch}, - compile_batches_{info.compile_batches} { + compile_batches_{info.compile_batches}, + hip_graph_enable_{info.hip_graph_enable} { InitProviderOrtApi(); // Set GPU device to be used and read device properties for feature usage. @@ -206,6 +214,35 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); GET_ENV_STRING(migraphx_env_vars::kCompileBatches, compile_batches_); + GET_ENV_BOOL(migraphx_env_vars::kHipGraphEnable, hip_graph_enable_); + + // hipGraph requires single-stream MIGraphX execution (MIGRAPHX_NSTREAMS=1). + if (hip_graph_enable_) { + const auto nstreams_env = GetEnvironmentVar("MIGRAPHX_NSTREAMS"); + int nstreams = nstreams_env.empty() ? 1 : std::stoi(nstreams_env); + if (nstreams > 1) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_NSTREAMS=" << nstreams + << " is incompatible with hipGraph capture. Disabling hipGraph."; + hip_graph_enable_ = false; + } + + const auto trace_env = GetEnvironmentVar("MIGRAPHX_TRACE_EVAL"); + if (!trace_env.empty() && std::stoi(trace_env) != 0) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_TRACE_EVAL is enabled, which calls hipStreamSynchronize " + << "per instruction. 
Disabling hipGraph."; + hip_graph_enable_ = false; + } + + const auto null_stream_env = GetEnvironmentVar("MIGRAPHX_ENABLE_NULL_STREAM"); + if (!null_stream_env.empty() && std::stoi(null_stream_env) != 0) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] MIGRAPHX_ENABLE_NULL_STREAM is enabled (default stream = illegal " + << "during capture). Disabling hipGraph."; + hip_graph_enable_ = false; + } + } // If compile_batches is set, auto-derive max_dynamic_batch from the spec's max value if (!compile_batches_.empty()) { @@ -285,13 +322,18 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_ << "\n " << migraphx_provider_option::kModelMaxDynamicBatch << ": " << max_dynamic_batch_ - << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? "(not set)" : compile_batches_); + << "\n " << migraphx_provider_option::kCompileBatches << ": " << (compile_batches_.empty() ? 
"(not set)" : compile_batches_) + << "\n " << migraphx_provider_option::kHipGraphEnable << ": " << hip_graph_enable_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA); + [this](OrtDevice::DeviceId device_id) { + auto alloc = std::make_unique(device_id, onnxruntime::CUDA); + if (hip_graph_enable_) { + alloc->EnablePoolMode(); + } + return alloc; }, device_id_); AllocatorCreationInfo pinned_allocator_info( @@ -1418,214 +1460,324 @@ static void pad_input_tensor(const void* src_data, void* dst_data, } } -// Allocate padded input buffers and pad the data for dynamic batching -// Returns true if padding was applied, false otherwise -// OPTIMIZATION: Reuses existing buffers if padded batch size matches -static bool allocate_and_pad_inputs( +// Helper: Extract output index from MIGraphX output parameter name +// MIGraphX names outputs as "#output_0", "#output_1", etc. +static int compute_output_index(const std::string_view sv) { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { + return -1; + } + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); +} + + +// Allocate pinned I/O buffers at the given max batch size. Called once per node +// at session creation (or lazily on first inference for deferred compilation). +// All batch sizes share these buffers — smaller batches use the leading prefix. 
+static void allocate_pinned_io( MIGraphXFuncState* mgx_state, - Ort::KernelContext& ctx, - std::size_t original_batch_size, - std::size_t padded_batch_size, - hipStream_t stream) { - - if (padded_batch_size <= original_batch_size || mgx_state->cached_inputs.empty()) { - return false; // No padding needed + const migraphx::program_parameter_shapes& param_shapes, + const migraphx::shapes& output_shapes, + std::size_t max_batch_size, + hipStream_t stream) +{ + auto& pio = mgx_state->pinned_io; + if (pio.allocated) { + return; } - - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION: Check if we can reuse existing padded buffers - // ═══════════════════════════════════════════════════════════════════════════ - bool can_reuse_buffers = ( - mgx_state->last_padded_batch_size == padded_batch_size && - !mgx_state->padded_input_buffers.empty() && - mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size() - ); - - if (can_reuse_buffers) { - - // Just copy new data into existing buffers - skip allocation - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& cached_inp = mgx_state->cached_inputs[i]; - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (tensor_shape.empty()) continue; - - auto& padded_buf = mgx_state->padded_input_buffers[i]; - - // Calculate elements per batch - std::size_t elements_per_batch = 1; - for (std::size_t j = 1; j < tensor_shape.size(); ++j) { - elements_per_batch *= tensor_shape[j]; - } - - // Calculate element size from tensor type - std::size_t element_size_bytes; - switch (tensor_info.GetElementType()) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - element_size_bytes = sizeof(float); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - element_size_bytes = sizeof(uint16_t); - break; 
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - element_size_bytes = sizeof(int64_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - element_size_bytes = sizeof(int32_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - element_size_bytes = sizeof(int16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - element_size_bytes = sizeof(int8_t); - break; - default: - element_size_bytes = sizeof(float); // Fallback to float - break; + + const auto& map_input_name_index = mgx_state->input_name_indexes; + + pio.inputs.clear(); + pio.input_name_to_idx.clear(); + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) == map_input_name_index.end()) continue; + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape = migraphx::shape(base_shape.type(), lens); + std::size_t bytes = max_shape.bytes(); + + pio.input_name_to_idx[name] = pio.inputs.size(); + void* ptr = nullptr; + HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); + pio.inputs.push_back({ptr, bytes, max_shape}); + } + + pio.outputs.clear(); + pio.output_name_to_idx.clear(); + std::size_t output_alloc_idx = 0; + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) != map_input_name_index.end()) continue; + const auto oi = compute_output_index(name); + if (oi == -1) continue; + if (static_cast(oi) >= output_shapes.size()) continue; + if (output_alloc_idx >= output_shapes.size()) break; + + const auto& out_shape = output_shapes[oi]; + auto lens = out_shape.lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape 
= migraphx::shape(out_shape.type(), lens); + std::size_t bytes = max_shape.bytes(); + + pio.output_name_to_idx[name] = pio.outputs.size(); + void* ptr = nullptr; + HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); + pio.outputs.push_back({ptr, bytes, max_shape}); + ++output_alloc_idx; + } + + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + pio.max_batch_size = max_batch_size; + pio.allocated = true; +} + +static void free_pinned_io(MIGraphXFuncState* mgx_state, hipStream_t stream) { + auto& pio = mgx_state->pinned_io; + for (auto& buf : pio.inputs) { + if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } + } + for (auto& buf : pio.outputs) { + if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + pio.inputs.clear(); + pio.outputs.clear(); + pio.allocated = false; +} + +// Copy ORT input tensors into pinned buffers and pad if needed. 
+static void copy_inputs_to_pinned( + MIGraphXFuncState* mgx_state, + const migraphx::program_parameter_shapes& param_shapes, + Ort::KernelContext& ctx, + std::size_t actual_batch, + std::size_t compiled_batch, + hipStream_t stream) +{ + auto& pio = mgx_state->pinned_io; + const auto& map_input_name_index = mgx_state->input_name_indexes; + + for (const auto& name : param_shapes.names()) { + auto it = map_input_name_index.find(name); + if (it == map_input_name_index.end()) continue; + + auto pin_it = pio.input_name_to_idx.find(name); + if (pin_it == pio.input_name_to_idx.end()) continue; + auto& pin = pio.inputs[pin_it->second]; + + const auto& input_tensor = ctx.GetInput(it->second); + const void* src = input_tensor.GetTensorRawData(); + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + + std::size_t elements_per_batch = std::accumulate( + lens.begin() + 1, lens.end(), std::size_t{1}, std::multiplies<>{}); + + std::size_t total_elems = 1; + for (auto l : lens) total_elems *= l; + std::size_t byte_per_elem = (total_elems > 0) ? 
base_shape.bytes() / total_elems : 0; + std::size_t bytes_per_batch = elements_per_batch * byte_per_elem; + + std::size_t copy_bytes = actual_batch * bytes_per_batch; + if (copy_bytes > pin.size_bytes) copy_bytes = pin.size_bytes; + + if (actual_batch == compiled_batch) { + if (copy_bytes > 0) { + HIP_CALL_THROW(hipMemcpyAsync(pin.data, src, copy_bytes, hipMemcpyDefault, stream)); } - - // Reuse existing buffer - just pad with new data - const void* original_data = input_tensor.GetTensorRawData(); - pad_input_tensor(original_data, padded_buf.data, original_batch_size, padded_batch_size, - element_size_bytes, elements_per_batch, stream); + } else { + pad_input_tensor(src, pin.data, actual_batch, compiled_batch, + byte_per_elem, elements_per_batch, stream); } - - // Update original batch tracking (padded batch is already correct) - mgx_state->last_original_batch_size = original_batch_size; - - return true; } - - // ═══════════════════════════════════════════════════════════════════════════ - // Normal path: Allocate new buffers (batch size changed or first run) - // ═══════════════════════════════════════════════════════════════════════════ - - // Free old buffers if they exist - for (auto& buf : mgx_state->padded_input_buffers) { - if (buf.data != nullptr) { - HIP_CALL_THROW(hipFree(buf.data)); - buf.data = nullptr; +} + +// Build program_parameters binding pinned buffers at the given compiled shape. +// Uses name-based lookup into pinned buffers so parameter ordering differences +// between compiled programs don't cause mismatched buffer access. 
+// Returns: {program_parameters, ORT_output_indices, pinned_buffer_indices} +struct PinnedBindResult { + migraphx::program_parameters params; + std::vector prog_output_indices; + std::vector pinned_output_indices; +}; + +static PinnedBindResult +bind_pinned_program_params( + MIGraphXFuncState* mgx_state, + const migraphx::program_parameter_shapes& param_shapes, + const migraphx::shapes& output_shapes) +{ + auto& pio = mgx_state->pinned_io; + const auto& map_input_name_index = mgx_state->input_name_indexes; + + PinnedBindResult result; + + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) != map_input_name_index.end()) { + auto pin_it = pio.input_name_to_idx.find(name); + if (pin_it == pio.input_name_to_idx.end()) continue; + result.params.add(name, migraphx::argument(param_shapes[name], pio.inputs[pin_it->second].data)); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + auto pin_it = pio.output_name_to_idx.find(name); + if (pin_it == pio.output_name_to_idx.end()) continue; + result.params.add(name, migraphx::argument(param_shapes[name], pio.outputs[pin_it->second].data)); + result.prog_output_indices.push_back(static_cast(oi)); + result.pinned_output_indices.push_back(pin_it->second); + } } } - mgx_state->padded_input_buffers.clear(); - - // Allocate and pad each input - mgx_state->padded_input_buffers.reserve(mgx_state->cached_inputs.size()); - - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (tensor_shape.empty()) { - continue; + + return result; +} + +// Copy results from pinned output buffers to ORT output tensors. 
+static void copy_pinned_outputs_to_ort( + MIGraphXFuncState* mgx_state, + const migraphx::shapes& output_shapes, + const std::vector& prog_output_indices, + const std::vector& pinned_output_indices, + Ort::KernelContext& ctx, + std::size_t actual_batch, + hipStream_t stream) +{ + auto& pio = mgx_state->pinned_io; + + for (std::size_t i = 0; i < prog_output_indices.size() && i < pinned_output_indices.size(); ++i) { + const auto oi = prog_output_indices[i]; + const auto pin_idx = pinned_output_indices[i]; + if (pin_idx >= pio.outputs.size()) continue; + const auto& pin = pio.outputs[pin_idx]; + const auto& out_shape = output_shapes[oi]; + auto lens = out_shape.lengths(); + + std::vector ort_shape(lens.begin(), lens.end()); + if (!ort_shape.empty()) { + ort_shape[0] = static_cast(actual_batch); } - - // Calculate padded shape - std::vector padded_lens(tensor_shape.begin(), tensor_shape.end()); - padded_lens[0] = padded_batch_size; // Replace batch dimension - - // Create padded MIGraphX shape - migraphx::shape padded_mgx_shape{cached_inp.mgx_shape.type(), padded_lens}; - std::size_t padded_bytes = padded_mgx_shape.bytes(); - - // Allocate GPU buffer for padded data - void* padded_data = nullptr; - HIP_CALL_THROW(hipMalloc(&padded_data, padded_bytes)); - - // Calculate elements per batch - std::size_t elements_per_batch = 1; - for (std::size_t i = 1; i < tensor_shape.size(); ++i) { - elements_per_batch *= tensor_shape[i]; + + auto output_tensor = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + void* dst = output_tensor.GetTensorMutableRawData(); + + std::size_t total_elems = 1; + for (auto l : lens) total_elems *= l; + std::size_t copy_bytes = 0; + if (total_elems > 0 && !lens.empty()) { + std::size_t byte_per_elem = out_shape.bytes() / total_elems; + std::size_t elems_per_batch = total_elems / std::max(1, lens[0]); + copy_bytes = actual_batch * elems_per_batch * byte_per_elem; } - - // Calculate element size from tensor type - std::size_t element_size_bytes; 
- switch (tensor_info.GetElementType()) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - element_size_bytes = sizeof(float); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - element_size_bytes = sizeof(uint16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - element_size_bytes = sizeof(int64_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - element_size_bytes = sizeof(int32_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - element_size_bytes = sizeof(int16_t); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - element_size_bytes = sizeof(int8_t); - break; - default: - element_size_bytes = sizeof(float); // Fallback to float - break; + + if (copy_bytes > 0) { + HIP_CALL_THROW(hipMemcpyAsync(dst, pin.data, copy_bytes, hipMemcpyDefault, stream)); } - - // Pad the data - const void* original_data = input_tensor.GetTensorRawData(); - pad_input_tensor(original_data, padded_data, original_batch_size, padded_batch_size, - element_size_bytes, elements_per_batch, stream); - - // Store padded buffer info - MIGraphXFuncState::PaddedBuffer buf; - buf.data = padded_data; - buf.size_bytes = padded_bytes; - buf.mgx_shape = padded_mgx_shape; - mgx_state->padded_input_buffers.push_back(buf); - } - - // Update batch tracking - mgx_state->last_original_batch_size = original_batch_size; - mgx_state->last_padded_batch_size = padded_batch_size; - - return true; } -// Helper: Extract output index from MIGraphX output parameter name -// MIGraphX names outputs as "#output_0", "#output_1", etc. 
-static int compute_output_index(const std::string_view sv) { - constexpr std::string_view out_name_prefix = "#output_"; - const auto pos = sv.find(out_name_prefix); - if (pos == std::string_view::npos) { - return -1; + +// Helper: Run the MIGraphX program and handle outputs +// This function executes the compiled MIGraphX program and copies outputs that +// were not pre-allocated (input parameters reused as outputs) to the ORT output tensors +// If original_batch_size is provided and < padded batch size, slices the output to remove padding +static void run_migraphx_program( + std::mutex* mgx_mu_ptr, + hipStream_t rocm_stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + std::optional prog_outputs; + { + std::lock_guard lock(*mgx_mu_ptr); + prog_outputs = prog.run_async(m, rocm_stream); } - const auto index_str = sv.substr(pos + out_name_prefix.length()); - return ToInteger(Trim(index_str, std::isdigit)); -} -// Free temporary output buffers -static void free_temp_output_buffers(MIGraphXFuncState* mgx_state) { - for (auto& buf : mgx_state->temp_output_buffers) { - if (buf.data != nullptr) { - (void)hipFree(buf.data); // Don't throw on cleanup - buf.data = nullptr; + bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size); + + auto output_num = prog_outputs->size(); + + // Fast path: no padding/slicing and all outputs were pre-allocated — nothing to do. + if (!needs_slicing && prog_output_indices.size() == output_num) + return; + + std::unordered_set prog_output_indices_set(prog_output_indices.begin(), prog_output_indices.end()); + + if (needs_slicing && !prog_output_indices_set.empty()) { + // Must sync before reallocating any pre-allocated output buffer for slicing. 
+ HIP_CALL_THROW(hipStreamSynchronize(rocm_stream)); + + for (std::size_t i = 0; i < output_num; ++i) { + if (prog_output_indices_set.count(i) == 0) continue; + + auto gpu_res = (*prog_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + if (!ort_shape.empty() && static_cast(ort_shape[0]) != original_batch_size) { + ort_shape[0] = static_cast(original_batch_size); + + std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size; + std::size_t bytes_to_copy = bytes_per_batch * original_batch_size; + + const void* src_data = gpu_res.data(); + auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + if (output_data != src_data) { + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + src_data, + bytes_to_copy, + hipMemcpyDeviceToDevice, + rocm_stream)); + } + } + } + } + + // Copy outputs that were not pre-allocated into ORT output tensors. + // All copies are async on rocm_stream — no sync needed here. 
+ for (std::size_t i = 0; i < output_num; ++i) { + if (prog_output_indices_set.count(i) > 0) continue; + + auto gpu_res = (*prog_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + if (needs_slicing && !ort_shape.empty()) { + ort_shape[0] = original_batch_size; + } + + auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + std::size_t bytes_to_copy = res_shape.bytes(); + if (needs_slicing && !res_lens.empty()) { + bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size; } + + HIP_CALL_THROW(hipMemcpyWithStream(output_data, + gpu_res.data(), + bytes_to_copy, + hipMemcpyDeviceToDevice, + rocm_stream)); } - mgx_state->temp_output_buffers.clear(); - mgx_state->temp_output_padded_batch_size = 0; } + // Clear cached MIGraphX shapes (call when program changes) static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { mgx_state->cached_mgx_param_shapes.reset(); @@ -1634,72 +1786,375 @@ static void clear_cached_mgx_shapes(MIGraphXFuncState* mgx_state) { mgx_state->cached_program_hash.clear(); } -// Allocate or reuse temporary output buffers for slicing mode -// Returns vector of raw pointers for use with handle_program_input_outputs -static std::vector get_or_allocate_temp_output_buffers( +// ═══════════════════════════════════════════════════════════════════════════════ +// hipGraph CAPTURE / REPLAY helpers +// ═══════════════════════════════════════════════════════════════════════════════ + +static bool check_hip_graph_compatibility(const migraphx::program& prog, + const std::string& node_name) { + /* std::ostringstream prog_text; + prog.print(prog_text); + const std::string text = prog_text.str(); + + static const std::vector unsafe_ops = { + "hip::sync_stream", + "hip::allocate", + "hip::copy_from_gpu", + "hip::copy_to_gpu", + 
"gpu::record_event", + "gpu::wait_event", + "gpu::set_stream", + }; + + for (const auto& op : unsafe_ops) { + if (text.find(op) != std::string::npos) { + LOGS_DEFAULT(WARNING) + << "[HipGraph] Node '" << node_name + << "' contains '" << op + << "' which is incompatible with hipGraph capture. " + << "Falling back to eager execution for this node."; + return false; + } + } */ + return true; +} + +static void destroy_hip_graphs(MIGraphXFuncState* mgx_state) { + for (auto& [hash, entry] : mgx_state->hip_graph_cache) { + if (entry.exec) { + (void)hipGraphExecDestroy(entry.exec); + entry.exec = nullptr; + } + if (entry.graph) { + (void)hipGraphDestroy(entry.graph); + entry.graph = nullptr; + } + entry.captured = false; + } + mgx_state->hip_graph_cache.clear(); +} + +// Warmup run (ensures lazy GPU allocations are finalized) then capture the graph. +// Stores extra (non-pre-allocated) output metadata so replay can materialize them. +static constexpr int kHipGraphWarmInIterations = 8; + +static bool warmup_and_capture_hip_graph( MIGraphXFuncState* mgx_state, - const migraphx::program_parameter_shapes& param_shapes, - const migraphx::shapes& output_shapes, - const std::unordered_map& map_input_name_index, + hipStream_t stream, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash) +{ + // Zero all pinned buffers before warmup to avoid stale data from prior batch runs + auto& pio = mgx_state->pinned_io; + for (auto& pin : pio.inputs) { + HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); + } + for (auto& pin : pio.outputs) { + HIP_CALL_THROW(hipMemsetAsync(pin.data, 0, pin.size_bytes, stream)); + } + + // Run multiple eager warmup iterations before capture to let MIGraphX + // internal workspace buffers (scratch, reductions, etc.) stabilize. 
+ std::optional warmup_outputs; + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + warmup_outputs = prog.run_async(m, stream); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + auto& entry = mgx_state->hip_graph_cache[shape_hash]; + + try { + HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + prog.run_async(m, stream); + } + hipError_t err = hipStreamEndCapture(stream, &entry.graph); + if (err != hipSuccess || entry.graph == nullptr) { + entry.graph = nullptr; + entry.captured = false; + mgx_state->hip_graph_enabled = false; + return false; + } + + HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); + entry.captured = true; + + // Replay the captured graph several more times post-capture to ensure + // workspace is fully settled before the first real inference. + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + std::unordered_set pre_alloc_set(prog_output_indices.begin(), + prog_output_indices.end()); + entry.extra_outputs.clear(); + if (warmup_outputs) { + auto output_num = warmup_outputs->size(); + for (std::size_t i = 0; i < output_num; ++i) { + if (pre_alloc_set.count(i) > 0) continue; + auto gpu_res = (*warmup_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + entry.extra_outputs.push_back({i, std::move(ort_shape), + gpu_res.data(), res_shape.bytes()}); + } + } + + return true; + } catch (...) 
{ + hipGraph_t dummy = nullptr; + (void)hipStreamEndCapture(stream, &dummy); + if (dummy) (void)hipGraphDestroy(dummy); + entry.graph = nullptr; + entry.exec = nullptr; + entry.captured = false; + mgx_state->hip_graph_enabled = false; + return false; + } +} + +static void replay_hip_graph(MIGraphXFuncState* mgx_state, + hipStream_t stream, + const std::string& shape_hash) { + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); +} + +// Forward declaration (defined after run_program_or_hip_graph) +static void materialize_extra_outputs( + Ort::KernelContext& ctx, + hipStream_t stream, + const std::vector& extras, + std::size_t original_batch_size, + std::size_t padded_batch_size); + +// Direct-bind capture: bind ORT tensor pointers directly (no pinned buffers) +// and capture the hipGraph. Requires stable pointers from pool allocator. +static bool warmup_and_capture_hip_graph_direct( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + const std::unordered_map& input_ptrs, + const std::unordered_map& output_ptrs) +{ + std::optional warmup_outputs; + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + warmup_outputs = prog.run_async(m, stream); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + auto& entry = mgx_state->hip_graph_cache[shape_hash]; + + try { + HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + { + std::lock_guard lock(*mgx_state->mgx_mu_ptr); + prog.run_async(m, stream); + } + hipError_t err = hipStreamEndCapture(stream, &entry.graph); + if (err != hipSuccess || entry.graph == nullptr) { + entry.graph = nullptr; + entry.captured = false; + mgx_state->use_direct_hip_graph = false; + return false; + } + + HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, 
nullptr, 0)); + entry.captured = true; + entry.captured_input_ptrs = input_ptrs; + entry.captured_output_ptrs = output_ptrs; + + for (int i = 0; i < kHipGraphWarmInIterations; ++i) { + HIP_CALL_THROW(hipGraphLaunch(entry.exec, stream)); + } + HIP_CALL_THROW(hipStreamSynchronize(stream)); + + std::unordered_set pre_alloc_set(prog_output_indices.begin(), + prog_output_indices.end()); + entry.extra_outputs.clear(); + if (warmup_outputs) { + auto output_num = warmup_outputs->size(); + for (std::size_t i = 0; i < output_num; ++i) { + if (pre_alloc_set.count(i) > 0) continue; + auto gpu_res = (*warmup_outputs)[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto res_lens = res_shape.lengths(); + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + entry.extra_outputs.push_back({i, std::move(ort_shape), + gpu_res.data(), res_shape.bytes()}); + } + } + + return true; + } catch (...) { + hipGraph_t dummy = nullptr; + (void)hipStreamEndCapture(stream, &dummy); + if (dummy) (void)hipGraphDestroy(dummy); + entry.graph = nullptr; + entry.exec = nullptr; + entry.captured = false; + mgx_state->use_direct_hip_graph = false; + return false; + } +} + +// Check whether ORT's current tensor pointers match the addresses stored +// during capture. Returns true if all pointers match. +static bool check_captured_ptrs_match( + const MIGraphXFuncState::CapturedHipGraph& entry, + const std::unordered_map& current_input_ptrs, + const std::unordered_map& current_output_ptrs) +{ + for (const auto& [name, ptr] : current_input_ptrs) { + auto it = entry.captured_input_ptrs.find(name); + if (it == entry.captured_input_ptrs.end() || it->second != ptr) return false; + } + for (const auto& [name, ptr] : current_output_ptrs) { + auto it = entry.captured_output_ptrs.find(name); + if (it == entry.captured_output_ptrs.end() || it->second != ptr) return false; + } + return true; +} + +// Direct-bind dispatch: replay or capture hipGraph using ORT tensor pointers +// directly. 
Re-captures on pointer mismatch; falls back to eager execution after too many re-captures or a failed capture. +static void run_program_or_hip_graph_direct( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + const std::unordered_map& input_ptrs, + const std::unordered_map& output_ptrs, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + auto it = mgx_state->hip_graph_cache.find(shape_hash); + if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + if (!check_captured_ptrs_match(it->second, input_ptrs, output_ptrs)) { + ++mgx_state->direct_recapture_count; + if (mgx_state->direct_recapture_count > MIGraphXFuncState::kMaxDirectRecaptures) { + LOGS_DEFAULT(WARNING) << "[HipGraph] Too many pointer-drift re-captures (" + << mgx_state->direct_recapture_count + << "), falling back to eager execution"; + mgx_state->use_direct_hip_graph = false; + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + return; + } + if (it->second.exec) { (void)hipGraphExecDestroy(it->second.exec); it->second.exec = nullptr; } + if (it->second.graph) { (void)hipGraphDestroy(it->second.graph); it->second.graph = nullptr; } + it->second.captured = false; + } else { + HIP_CALL_THROW(hipGraphLaunch(it->second.exec, stream)); + if (!it->second.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, it->second.extra_outputs, + original_batch_size, padded_batch_size); + } + return; + } + } + + if (!warmup_and_capture_hip_graph_direct(mgx_state, stream, prog, m, + prog_output_indices, shape_hash, + input_ptrs, output_ptrs)) { + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + } else { + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + if (!entry.extra_outputs.empty())
{ + materialize_extra_outputs(ctx, stream, entry.extra_outputs, + original_batch_size, padded_batch_size); + } + } +} + +// Materialize extra (non-pre-allocated) outputs recorded during hipGraph capture. +// These are MIGraphX outputs not exposed as named parameters — their GPU data +// pointers are stable across replays because hipGraph replays the same kernels. +static void materialize_extra_outputs( + Ort::KernelContext& ctx, + hipStream_t stream, + const std::vector& extras, + std::size_t original_batch_size, std::size_t padded_batch_size) { - // Check if we can reuse existing buffers - bool can_reuse = ( - mgx_state->temp_output_padded_batch_size == padded_batch_size && - !mgx_state->temp_output_buffers.empty() - ); - - if (can_reuse) { - // Return raw pointers from existing buffers - std::vector ptrs; - ptrs.reserve(mgx_state->temp_output_buffers.size()); - for (const auto& buf : mgx_state->temp_output_buffers) { - ptrs.push_back(buf.data); + bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && + original_batch_size < padded_batch_size); + + for (const auto& extra : extras) { + auto ort_shape = extra.ort_shape; + std::size_t bytes = extra.bytes; + if (needs_slicing && !ort_shape.empty()) { + std::size_t full_batch = static_cast(ort_shape[0]); + if (full_batch > 0) { + bytes = (extra.bytes / full_batch) * original_batch_size; + } + ort_shape[0] = static_cast(original_batch_size); } - return ptrs; + + auto output_tensor = ctx.GetOutput(extra.output_index, ort_shape.data(), ort_shape.size()); + void* output_data = output_tensor.GetTensorMutableRawData(); + + HIP_CALL_THROW(hipMemcpyWithStream(output_data, extra.gpu_data, + bytes, hipMemcpyDeviceToDevice, stream)); } - - // Free old buffers if they exist - free_temp_output_buffers(mgx_state); - - // Count outputs and allocate - std::vector ptrs; - for (const auto& name : param_shapes.names()) { - // Skip inputs - if (map_input_name_index.find(name) != map_input_name_index.end()) { - continue; 
+} + +// Dispatch point: replay a cached hipGraph, capture one on first use, or fall back to eager. +// This replaces run_migraphx_program in all pinned-I/O paths when hipGraph is enabled. +// IMPORTANT: when hipGraph is enabled this function must ONLY be called via the pinned-I/O +// code path so that buffer addresses captured in the graph remain stable across replays. +static void run_program_or_hip_graph( + MIGraphXFuncState* mgx_state, + hipStream_t stream, + Ort::KernelContext& ctx, + migraphx::program& prog, + migraphx::program_parameters& m, + const std::vector& prog_output_indices, + const std::string& shape_hash, + std::size_t original_batch_size = 0, + std::size_t padded_batch_size = 0) +{ + if (!mgx_state->hip_graph_enabled) { + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, original_batch_size, padded_batch_size); + return; + } + + auto it = mgx_state->hip_graph_cache.find(shape_hash); + if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + replay_hip_graph(mgx_state, stream, shape_hash); + + if (!it->second.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, it->second.extra_outputs, + original_batch_size, padded_batch_size); } - - // This is an output - const auto output_index = compute_output_index(name); - if (output_index != -1) { - const auto& mgx_shape = param_shapes[name]; - std::size_t size_bytes = mgx_shape.bytes(); - - void* buffer = nullptr; - auto hip_status = hipMalloc(&buffer, size_bytes); - if (hip_status != hipSuccess) { - // Clean up any allocated buffers on failure - for (auto& buf : mgx_state->temp_output_buffers) { - if (buf.data) (void)hipFree(buf.data); - } - mgx_state->temp_output_buffers.clear(); - ORT_THROW("hipMalloc failed for temporary output buffer"); + } else { + if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, + prog_output_indices, shape_hash)) { + run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, + prog_output_indices, 
original_batch_size, padded_batch_size); + } else { + auto& entry = mgx_state->hip_graph_cache.at(shape_hash); + if (!entry.extra_outputs.empty()) { + materialize_extra_outputs(ctx, stream, entry.extra_outputs, + original_batch_size, padded_batch_size); } - - MIGraphXFuncState::TempOutputBuffer temp_buf; - temp_buf.data = buffer; - temp_buf.size_bytes = size_bytes; - temp_buf.mgx_shape = mgx_shape; - mgx_state->temp_output_buffers.push_back(temp_buf); - ptrs.push_back(buffer); - } } - - mgx_state->temp_output_padded_batch_size = padded_batch_size; - - return ptrs; } // Order matters here especially if the program uses mixed quantization @@ -1900,18 +2355,7 @@ static migraphx::program load_or_compile_model( { migraphx::program prog; - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ==== ENTERING ===="; - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Cache file: " << (cache_file.empty() ? "(none)" : cache_file.string()); - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Batch size: " << (batch_size > 0 ? 
std::to_string(batch_size) : "(default)"); - if (!load_precompiled_model(prog, cache_file)) { - // Cache miss - need to compile - if (batch_size > 0) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✗ CACHE MISS for batch size " << batch_size << " - COMPILING..."; - } else { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✗ CACHE MISS - COMPILING..."; - } - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Compilation started (this may take a while)..."; prog = CompileProgramWithBatch( onnx_string, @@ -1931,119 +2375,12 @@ static migraphx::program load_or_compile_model( all_input_base_shapes, batch_size); - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Compilation finished"; - save_compiled_model(prog, cache_file); - if (!cache_file.empty()) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Saved compiled model to disk: " << cache_file.string(); - } - } else { - // Cache hit - loaded from disk - if (batch_size > 0) { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✓ CACHE HIT - LOADING FROM DISK for batch size " << batch_size; - } else { - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ✓ CACHE HIT - LOADING FROM DISK"; - } - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] Loaded precompiled model from: " << cache_file.string(); } - - LOGS_DEFAULT(VERBOSE) << "[load_or_compile_model] ==== EXITING ===="; return prog; } -// Helper: Run the MIGraphX program and handle outputs -// This function executes the compiled MIGraphX program and copies outputs that -// were not pre-allocated (input parameters reused as outputs) to the ORT output tensors -// If original_batch_size is provided and < padded batch size, slices the output to remove padding -static void run_migraphx_program( - std::mutex* mgx_mu_ptr, - hipStream_t rocm_stream, - Ort::KernelContext& ctx, - migraphx::program& prog, - migraphx::program_parameters& m, - const std::vector& prog_output_indices, - std::size_t original_batch_size = 0, - std::size_t padded_batch_size = 0) -{ - 
std::optional prog_outputs; - { - std::lock_guard lock(*mgx_mu_ptr); - prog_outputs = prog.run_async(m, rocm_stream); - } - - bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); - - auto output_num = prog_outputs->size(); - - // Fast path: no padding/slicing and all outputs were pre-allocated — nothing to do. - if (!needs_slicing && prog_output_indices.size() == output_num) - return; - - std::unordered_set prog_output_indices_set(prog_output_indices.begin(), prog_output_indices.end()); - - if (needs_slicing && !prog_output_indices_set.empty()) { - // Must sync before reallocating any pre-allocated output buffer for slicing. - HIP_CALL_THROW(hipStreamSynchronize(rocm_stream)); - - for (std::size_t i = 0; i < output_num; ++i) { - if (prog_output_indices_set.count(i) == 0) continue; - auto gpu_res = (*prog_outputs)[i]; - migraphx::shape res_shape = gpu_res.get_shape(); - auto res_lens = res_shape.lengths(); - - std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (!ort_shape.empty() && static_cast(ort_shape[0]) != original_batch_size) { - ort_shape[0] = static_cast(original_batch_size); - - std::size_t bytes_per_batch = res_shape.bytes() / padded_batch_size; - std::size_t bytes_to_copy = bytes_per_batch * original_batch_size; - - const void* src_data = gpu_res.data(); - auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); - void* output_data = output_tensor.GetTensorMutableRawData(); - - if (output_data != src_data) { - HIP_CALL_THROW(hipMemcpyWithStream(output_data, - src_data, - bytes_to_copy, - hipMemcpyDeviceToDevice, - rocm_stream)); - } - } - } - } - - // Copy outputs that were not pre-allocated into ORT output tensors. - // All copies are async on rocm_stream — no sync needed here. 
- for (std::size_t i = 0; i < output_num; ++i) { - if (prog_output_indices_set.count(i) > 0) continue; - - auto gpu_res = (*prog_outputs)[i]; - migraphx::shape res_shape = gpu_res.get_shape(); - auto res_lens = res_shape.lengths(); - - std::vector ort_shape{res_lens.begin(), res_lens.end()}; - if (needs_slicing && !ort_shape.empty()) { - ort_shape[0] = original_batch_size; - } - - auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); - void* output_data = output_tensor.GetTensorMutableRawData(); - - std::size_t bytes_to_copy = res_shape.bytes(); - if (needs_slicing && !res_lens.empty()) { - bytes_to_copy = (res_shape.bytes() / padded_batch_size) * original_batch_size; - } - - HIP_CALL_THROW(hipMemcpyWithStream(output_data, - gpu_res.data(), - bytes_to_copy, - hipMemcpyDeviceToDevice, - rocm_stream)); - } -} // Helper: Handle input shape mismatch by recompiling the model with new input shapes // This function is called when runtime input shapes differ from compiled shapes @@ -2186,9 +2523,12 @@ std::pair> handle_program temp_buffer = (*temp_output_buffers)[temp_buffer_count]; } else { // Allocate new buffer (first run or buffer list is empty) - auto hip_status = hipMalloc(&temp_buffer, output_size_bytes); - if (hip_status != hipSuccess) { - ORT_THROW("hipMalloc failed for temporary output buffer"); + { + std::lock_guard alloc_lock(g_hip_alloc_mutex); + auto hip_status = hipMalloc(&temp_buffer, output_size_bytes); + if (hip_status != hipSuccess) { + ORT_THROW("hipMalloc failed for temporary output buffer"); + } } temp_output_buffers->push_back(temp_buffer); } @@ -2310,84 +2650,58 @@ static bool execute_ultra_fast_path( if (!mgx_state->caches_valid || mgx_state->last_input_shapes_raw.empty()) { return false; } - - // Ultra-fast path not supported when outputs need dynamic slicing + if (mgx_state->cached_outputs.empty()) { return false; } - // Quick shape comparison bool shapes_match = true; std::size_t offset = 0; const auto& last_shapes = 
mgx_state->last_input_shapes_raw; - + std::size_t original_batch_size = 0; std::size_t padded_batch_size = 0; bool is_first = true; for (const auto& inp : mgx_state->cached_inputs) { const auto& shape = ctx.GetInput(inp.ort_index).GetTensorTypeAndShapeInfo().GetShape(); - + if (offset + shape.size() > last_shapes.size()) { shapes_match = false; break; } - - // For dynamic batch, we check if the current batch needs padding + if (mgx_state->has_dynamic_batch && !mgx_state->compiled_batch_sizes.empty()) { - // Get batch sizes from first input if (is_first) { original_batch_size = static_cast(shape[0]); padded_batch_size = static_cast(last_shapes[offset]); is_first = false; - - // Check if the batch size matches (original or padded) + if (shape[0] != last_shapes[offset]) { - // Batch size changed - check if we can use padding std::size_t required_padded = find_nearest_compiled_batch_size( original_batch_size, mgx_state->compiled_batch_sizes); - if (required_padded != padded_batch_size) { shapes_match = false; break; } } } - - // All current inputs should have the same batch size (original_batch_size) + if (static_cast(shape[0]) != original_batch_size) { - shapes_match = false; - break; + shapes_match = false; break; } - - // Cached shape should have padded batch size in dimension 0 if (last_shapes[offset] != static_cast(padded_batch_size)) { - shapes_match = false; - break; + shapes_match = false; break; } - - // Check non-batch dimensions (current vs cached) - bool rest_matches = true; for (std::size_t i = 1; i < shape.size(); ++i) { - if (last_shapes[offset + i] != shape[i]) { - rest_matches = false; - break; - } - } - if (!rest_matches) { - shapes_match = false; - break; + if (last_shapes[offset + i] != shape[i]) { shapes_match = false; break; } } } else { - // No dynamic batching - strict comparison for (std::size_t i = 0; i < shape.size(); ++i) { - if (last_shapes[offset + i] != shape[i]) { - shapes_match = false; - break; - } + if (last_shapes[offset + i] != 
shape[i]) { shapes_match = false; break; } } } - + if (!shapes_match) break; offset += shape.size(); } @@ -2396,40 +2710,67 @@ static bool execute_ultra_fast_path( return false; } - // Ultra-fast path doesn't support output slicing because cached_output_ort_shapes - // contains padded shapes, not sliced shapes. Fall back to fast path which handles - // slicing properly via temp output buffers. - if (padded_batch_size > 0 && original_batch_size > 0 && padded_batch_size > original_batch_size) { - return false; - } - - // Shapes unchanged (or compatible with padding) - rebind pointers and run directly - auto& m = mgx_state->cached_prog_params.value(); auto& prog = mgx_state->prog; - - // Allocate and pad inputs if needed for dynamic batching - bool using_padded_inputs = false; - if (padded_batch_size > original_batch_size) { - using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); - } - - // Rebind inputs - use padded buffers if available, otherwise use original inputs - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - m.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); - } - } else { + std::size_t actual_batch = original_batch_size > 0 ? original_batch_size + : (!mgx_state->cached_inputs.empty() + ? static_cast(ctx.GetInput(mgx_state->cached_inputs[0].ort_index) + .GetTensorTypeAndShapeInfo().GetShape()[0]) + : 0); + std::size_t compiled_batch = padded_batch_size > 0 ? 
padded_batch_size : actual_batch; + bool needs_padding = (actual_batch < compiled_batch); + + // Direct-bind hipGraph: no copies, bind ORT pointers and replay + if (mgx_state->use_direct_hip_graph && !needs_padding) { + auto& m = mgx_state->cached_prog_params.value(); + std::unordered_map input_ptrs, output_ptrs; for (const auto& inp : mgx_state->cached_inputs) { const auto& input_tensor = ctx.GetInput(inp.ort_index); - m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, - const_cast(input_tensor.GetTensorRawData()))); + void* ptr = const_cast(input_tensor.GetTensorRawData()); + m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, ptr)); + input_ptrs[inp.name] = ptr; } + for (std::size_t i = 0; i < mgx_state->cached_outputs.size(); ++i) { + const auto& out = mgx_state->cached_outputs[i]; + const auto& ort_shape = mgx_state->cached_output_ort_shapes[i]; + auto output_tensor = ctx.GetOutput(out.output_index, ort_shape.data(), ort_shape.size()); + void* ptr = output_tensor.GetTensorMutableRawData(); + m.add(out.name.c_str(), migraphx::argument(out.mgx_shape, ptr)); + output_ptrs[out.name] = ptr; + } + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + mgx_state->cached_prog_output_indices, + mgx_state->last_input_shape_hash, + input_ptrs, output_ptrs); + return true; } - // Rebind outputs - direct iteration, uses pre-allocated shape vectors + // Pinned-copy path: padding needed or legacy hipGraph path + bool needs_pinned = (needs_padding || mgx_state->hip_graph_enabled) + && mgx_state->pinned_io.allocated; + + if (needs_pinned && mgx_state->cached_mgx_param_shapes.has_value()) { + const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); + const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); + + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); + + auto& m = mgx_state->cached_prog_params.value(); + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, 
m, + mgx_state->cached_prog_output_indices, + mgx_state->last_input_shape_hash); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + mgx_state->cached_pinned_output_indices, + ctx, actual_batch, rocm_stream); + return true; + } + + auto& m = mgx_state->cached_prog_params.value(); + for (const auto& inp : mgx_state->cached_inputs) { + const auto& input_tensor = ctx.GetInput(inp.ort_index); + m.add(inp.name.c_str(), migraphx::argument(inp.mgx_shape, + const_cast(input_tensor.GetTensorRawData()))); + } for (std::size_t i = 0; i < mgx_state->cached_outputs.size(); ++i) { const auto& out = mgx_state->cached_outputs[i]; const auto& ort_shape = mgx_state->cached_output_ort_shapes[i]; @@ -2437,12 +2778,8 @@ static bool execute_ultra_fast_path( m.add(out.name.c_str(), migraphx::argument(out.mgx_shape, output_tensor.GetTensorMutableRawData())); } - - // Run directly - minimal overhead path run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, - mgx_state->cached_prog_output_indices, - original_batch_size, padded_batch_size); - + mgx_state->cached_prog_output_indices); return true; } @@ -2456,125 +2793,95 @@ static bool execute_fast_path( const std::string& current_hash, std::vector& all_input_shapes) { - - if (mgx_state->defer_compilation || !mgx_state->cached_programs_ref.has_value()) { + if (!mgx_state->cached_programs_ref.has_value()) { return false; } auto& cached_programs = mgx_state->cached_programs_ref.value().get(); + + if (mgx_state->defer_compilation && cached_programs.empty()) { + return false; + } + auto prog_it = cached_programs.find(current_hash); - - // If not found directly, check if we need to use a padded batch size + std::size_t original_batch_size = 0; std::size_t padded_batch_size = 0; bool needs_padding = false; - - if (prog_it == cached_programs.end() && mgx_state->has_dynamic_batch && + + if (prog_it == cached_programs.end() && mgx_state->has_dynamic_batch && 
!mgx_state->compiled_batch_sizes.empty()) { - // Try to find a padded batch size const auto& map_input_name_index = mgx_state->input_name_indexes; - + for (const auto& [name, index] : map_input_name_index) { auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); + const auto tensor_shape = input_tensor.GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { original_batch_size = static_cast(tensor_shape[0]); padded_batch_size = find_nearest_compiled_batch_size(original_batch_size, - mgx_state->compiled_batch_sizes); + mgx_state->compiled_batch_sizes); needs_padding = (padded_batch_size > original_batch_size); break; } } - + if (needs_padding && padded_batch_size > 0) { - // Build padded shapes in alphabetical order (map order) for hash calculation - // This matches the order used during compilation in compile_dynamic_batch_models std::vector padded_shapes_for_hash; padded_shapes_for_hash.reserve(all_input_shapes.size()); - for (const auto& [name, index] : map_input_name_index) { - auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - + const auto tensor_shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { padded_shapes_for_hash.push_back(static_cast(padded_batch_size)); padded_shapes_for_hash.insert(padded_shapes_for_hash.end(), tensor_shape.begin() + 1, tensor_shape.end()); } } - auto padded_hash = make_hash(padded_shapes_for_hash); prog_it = cached_programs.find(padded_hash); - - if (prog_it != cached_programs.end()) { - - // Now rebuild padded_shapes in cached_inputs order for saving to last_input_shapes_raw - // This ensures ultra-fast path shape comparison works correctly - if (!mgx_state->cached_inputs.empty()) { - std::vector padded_shapes_for_cache; - 
padded_shapes_for_cache.reserve(mgx_state->cached_inputs.size() * 2); - - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (!tensor_shape.empty()) { - padded_shapes_for_cache.push_back(static_cast(padded_batch_size)); - padded_shapes_for_cache.insert(padded_shapes_for_cache.end(), tensor_shape.begin() + 1, tensor_shape.end()); - } + + if (prog_it != cached_programs.end() && !mgx_state->cached_inputs.empty()) { + std::vector padded_shapes_for_cache; + padded_shapes_for_cache.reserve(mgx_state->cached_inputs.size() * 2); + for (const auto& cached_inp : mgx_state->cached_inputs) { + const auto tensor_shape = ctx.GetInput(cached_inp.ort_index).GetTensorTypeAndShapeInfo().GetShape(); + if (!tensor_shape.empty()) { + padded_shapes_for_cache.push_back(static_cast(padded_batch_size)); + padded_shapes_for_cache.insert(padded_shapes_for_cache.end(), tensor_shape.begin() + 1, tensor_shape.end()); } - all_input_shapes = std::move(padded_shapes_for_cache); - } else { - // Fallback: use map order (shouldn't happen if caches are populated) - all_input_shapes = std::move(padded_shapes_for_hash); } + all_input_shapes = std::move(padded_shapes_for_cache); } } } - + if (prog_it == cached_programs.end()) { return false; } - // Determine which hash was used to find the program - // This is needed to detect program changes and invalidate caches std::string effective_program_hash = current_hash; if (needs_padding && padded_batch_size > 0) { - // If we used padded hash, compute it for tracking std::vector padded_shapes_for_hash_tracking; for (const auto& [name, index] : mgx_state->input_name_indexes) { - auto input_tensor = ctx.GetInput(index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); + const auto tensor_shape = 
ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); if (!tensor_shape.empty()) { padded_shapes_for_hash_tracking.push_back(static_cast(padded_batch_size)); - padded_shapes_for_hash_tracking.insert(padded_shapes_for_hash_tracking.end(), + padded_shapes_for_hash_tracking.insert(padded_shapes_for_hash_tracking.end(), tensor_shape.begin() + 1, tensor_shape.end()); } } effective_program_hash = make_hash(padded_shapes_for_hash_tracking); } - // Found cached program - use it and populate caches auto& prog = mgx_state->prog; prog = prog_it->second; const auto& map_input_name_index = mgx_state->input_name_indexes; - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 1: Cache MIGraphX API results (avoid redundant API calls) - // Check if program changed - if so, invalidate caches - // ═══════════════════════════════════════════════════════════════════════════ bool program_changed = (mgx_state->cached_program_hash != effective_program_hash); - if (program_changed) { clear_cached_mgx_shapes(mgx_state); - free_temp_output_buffers(mgx_state); mgx_state->cached_program_hash = effective_program_hash; } - + if (!mgx_state->cached_mgx_param_shapes.has_value()) { mgx_state->cached_mgx_param_shapes = prog.get_parameter_shapes(); mgx_state->cached_mgx_output_shapes = prog.get_output_shapes(); @@ -2582,68 +2889,96 @@ static bool execute_fast_path( const auto& param_shapes = mgx_state->cached_mgx_param_shapes.value(); const auto& output_shapes = mgx_state->cached_mgx_output_shapes.value(); - bool needs_slicing = (original_batch_size > 0 && padded_batch_size > 0 && - original_batch_size < padded_batch_size); - - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 2: Skip populate_ultra_fast_caches when already populated - // ═══════════════════════════════════════════════════════════════════════════ if (!mgx_state->ultra_fast_caches_populated) { populate_ultra_fast_caches(mgx_state, 
param_shapes, output_shapes, map_input_name_index, original_batch_size, padded_batch_size); mgx_state->ultra_fast_caches_populated = true; } - // Allocate and pad inputs if needed for dynamic batching - bool using_padded_inputs = false; - if (padded_batch_size > original_batch_size) { - using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); + std::size_t actual_batch = original_batch_size > 0 ? original_batch_size : 0; + if (actual_batch == 0) { + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } + } } + std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; + bool fast_needs_padding = (actual_batch < compiled_batch); - // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION 3: Reuse temp output buffers when slicing - // ═══════════════════════════════════════════════════════════════════════════ - std::vector temp_output_buffer_ptrs; - if (needs_slicing) { - temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers( - mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size); + // Direct-bind hipGraph path: bind ORT pointers and replay, no copies + if (mgx_state->use_direct_hip_graph && !fast_needs_padding) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector 
ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } + } + } + + mgx_state->cached_prog_params = std::move(m); + mgx_state->cached_prog_output_indices = std::move(prog_output_indices); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, + mgx_state->cached_prog_params.value(), + mgx_state->cached_prog_output_indices, + effective_program_hash, + input_ptrs, output_ptrs); + return true; + } + + // Pinned-copy path: padding needed or legacy hipGraph + bool needs_pinned = (fast_needs_padding || mgx_state->hip_graph_enabled) + && mgx_state->pinned_io.allocated; + + if (needs_pinned) { + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, compiled_batch, rocm_stream); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + + mgx_state->cached_prog_params = std::move(bind_result.params); + mgx_state->cached_prog_output_indices = std::move(bind_result.prog_output_indices); + mgx_state->cached_pinned_output_indices = std::move(bind_result.pinned_output_indices); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( + mgx_state, ctx, padded_batch_size); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, + mgx_state->cached_prog_params.value(), + mgx_state->cached_prog_output_indices, + effective_program_hash); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, + mgx_state->cached_pinned_output_indices, + ctx, actual_batch, rocm_stream); + return true; } - // Bind inputs/outputs (use temp buffers for outputs when slicing) auto [m, prog_output_indices] = 
handle_program_input_outputs( - param_shapes, output_shapes, map_input_name_index, ctx, needs_slicing, - needs_slicing ? &temp_output_buffer_ptrs : nullptr); + param_shapes, output_shapes, map_input_name_index, ctx); mgx_state->cached_prog_params = std::move(m); mgx_state->cached_prog_output_indices = std::move(prog_output_indices); - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering - mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( - mgx_state, ctx, using_padded_inputs ? padded_batch_size : 0); - + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - // Rebind padded inputs to program parameters - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - auto& prog_params = mgx_state->cached_prog_params.value(); - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - prog_params.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); - } - } - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, mgx_state->cached_prog_params.value(), - mgx_state->cached_prog_output_indices, - original_batch_size, padded_batch_size); - - // NOTE: Temp output buffers are kept for reuse - they will be freed when batch size changes - // NOTE: Padded input buffers are also kept for reuse - + mgx_state->cached_prog_output_indices); return true; } @@ -2726,45 +3061,39 @@ static InputShapeResult handle_input_shape( } // Helper: Compile models for all configured batch sizes and cache them +// rocm_stream is the per-Run compute stream resolved from ctx.GetGPUComputeStream() +// in compute_func. 
It MUST be threaded through to allocate_pinned_io so the +// stream-ordered memory pool used by hipMallocAsync has the same lineage as the +// stream that will later issue copies, captured-graph launches, and replays +// against those pinned buffers. Using a different stream here (e.g. the EP's +// own mgx_state->stream) is undefined behavior under the hipMemPool semantics +// and on ROCm typically surfaces as the captured graph reading stale or +// uninitialized pinned memory on first replay. static void compile_dynamic_batch_models( MIGraphXFuncState* mgx_state, const std::filesystem::path& model_cache_path, const std::filesystem::path& model_path, const std::string& mxr_filename_prefix, - const Ort::KernelContext& ctx) { - - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== ENTERING compile_dynamic_batch_models ===="; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] has_dynamic_batch = " << mgx_state->has_dynamic_batch; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] compiled_batch_sizes.size() = " - << mgx_state->compiled_batch_sizes.size(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] max_dynamic_batch = " << mgx_state->max_dynamic_batch; + const Ort::KernelContext& ctx, + hipStream_t rocm_stream) { if (!mgx_state->has_dynamic_batch || mgx_state->compiled_batch_sizes.empty()) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Skipping - dynamic batch disabled or no batch sizes"; return; } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Compiling models for " - << mgx_state->compiled_batch_sizes.size() << " batch sizes"; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Batch sizes: "; - for (const auto& bs : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] - " << bs; - } - // Get input names and base shapes (without batch dimension) const auto& map_input_name_index = mgx_state->input_name_indexes; std::vector input_names; std::vector> all_input_base_shapes; - LOGS_DEFAULT(VERBOSE) << 
"[DynamicBatch][COMPILE] Processing " << map_input_name_index.size() << " input parameters"; for (const auto& [name, index] : map_input_name_index) { input_names.push_back(name); auto input_tensor = ctx.GetInput(index); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Input '" << name << "' (index " << index - << ") runtime shape: [" << [&]() { + LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Input '" << name << "' (index " << index + << ") runtime shape: [" << [&]() { std::ostringstream ss; for (size_t i = 0; i < tensor_shape.size(); ++i) { if (i > 0) ss << ", "; @@ -2792,9 +3121,7 @@ static void compile_dynamic_batch_models( // Compile a model for each configured batch size for (const auto& batch_size : mgx_state->compiled_batch_sizes) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ---- Processing batch size: " << batch_size << " ----"; - // Build cache key for this batch size std::vector batch_shape_key; for (size_t i = 0; i < input_names.size(); ++i) { batch_shape_key.push_back(batch_size); @@ -2804,40 +3131,18 @@ static void compile_dynamic_batch_models( } auto cache_hash = make_hash(batch_shape_key); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Shape key for batch " << batch_size << ": [" << [&]() { - std::ostringstream ss; - for (size_t i = 0; i < batch_shape_key.size(); ++i) { - if (i > 0) ss << ", "; - ss << batch_shape_key[i]; - } - return ss.str(); - }() << "]"; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Cache hash: " << cache_hash; - - // Check if already cached if (mgx_state->cached_programs_ref.has_value()) { auto& cached_progs = mgx_state->cached_programs_ref.value().get(); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Checking in-memory cache (size: " - << cached_progs.size() << ")"; if (cached_progs.find(cache_hash) != cached_progs.end()) { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ✓ Batch size 
" << batch_size - << " already in memory cache, skipping"; continue; } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Cache miss - need to compile/load"; } - // Build cache file path std::filesystem::path batch_cache_file; if (!model_cache_path.empty()) { batch_cache_file = model_cache_path / (mxr_filename_prefix + cache_hash + ".mxr"); - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Disk cache file: " << batch_cache_file.string(); - } else { - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] No disk cache path configured"; } - // Compile or load the model for this batch size - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Calling load_or_compile_model for batch " << batch_size; migraphx::program batch_prog = load_or_compile_model( batch_cache_file, mgx_state->onnx_string, @@ -2857,27 +3162,53 @@ static void compile_dynamic_batch_models( all_input_base_shapes, batch_size); - // Store in cache if (mgx_state->cached_programs_ref.has_value()) { mgx_state->cached_programs_ref.value().get()[cache_hash] = batch_prog; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ✓ Stored program for batch size " << batch_size - << " in memory cache with hash " << cache_hash; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Memory cache now contains " - << mgx_state->cached_programs_ref.value().get().size() << " programs"; } } - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== All batch models compiled and cached ===="; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Setting max_dynamic_batch to 0 to disable future compilation"; - - // Disable dynamic batch compilation for subsequent runs (set max_dynamic_batch to 0) mgx_state->max_dynamic_batch = 0; - - // Also disable defer_compilation since we've now compiled mgx_state->defer_compilation = false; - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] Set defer_compilation = false"; - - LOGS_DEFAULT(VERBOSE) << "[DynamicBatch][COMPILE] ==== EXITING compile_dynamic_batch_models ===="; + + // Allocate pinned 
I/O now that all batch models are compiled. + // Must use the largest-batch program's shapes so the buffer count and + // parameter ordering match every subsequent bind_pinned_program_params call. + if (!mgx_state->pinned_io.allocated && mgx_state->cached_programs_ref.has_value()) { + auto& progs = mgx_state->cached_programs_ref.value().get(); + if (!progs.empty()) { + std::size_t max_batch = 0; + if (!mgx_state->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); + } + migraphx::program* largest_prog = nullptr; + std::size_t largest_batch_found = 0; + for (auto& [hash, prog] : progs) { + auto ps = prog.get_parameter_shapes(); + std::size_t prog_batch = 0; + for (const auto& name : ps.names()) { + if (mgx_state->input_name_indexes.find(name) != mgx_state->input_name_indexes.end()) { + auto lens = ps[name].lengths(); + if (!lens.empty() && lens[0] > 0) { + prog_batch = lens[0]; + break; + } + } + } + if (prog_batch > largest_batch_found) { + largest_batch_found = prog_batch; + largest_prog = &prog; + } + } + if (max_batch == 0) max_batch = largest_batch_found; + if (largest_prog && max_batch > 0) { + auto ps = largest_prog->get_parameter_shapes(); + auto os = largest_prog->get_output_shapes(); + allocate_pinned_io(mgx_state, ps, os, max_batch, rocm_stream); + } + } + } + } // Standard path: Shape checking, potential recompilation, and execution @@ -2901,11 +3232,18 @@ static void execute_standard_path( // If precompilation happened during Compile(), max_dynamic_batch will be > 0 but defer_compilation = false // In that case, the programs are already in cache and we can skip runtime compilation if (mgx_state->has_dynamic_batch && mgx_state->max_dynamic_batch > 0 && mgx_state->defer_compilation) { - // Runtime compilation path - used when precompilation was not possible (e.g., non-pure dynamic batch) - - // Compile all batch models at runtime - 
compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx); - + compile_dynamic_batch_models(mgx_state, model_cache_path, model_path, mxr_filename_prefix, ctx, rocm_stream); + + // Validate newly compiled programs for hipGraph compatibility + if (mgx_state->hip_graph_enabled && mgx_state->cached_programs_ref.has_value()) { + for (const auto& [hash, cached_prog] : mgx_state->cached_programs_ref.value().get()) { + if (!check_hip_graph_compatibility(cached_prog, "runtime_dynamic_batch")) { + mgx_state->hip_graph_enabled = false; + mgx_state->use_direct_hip_graph = false; + break; + } + } + } } else if (mgx_state->has_dynamic_batch) { } @@ -2958,91 +3296,107 @@ static void execute_standard_path( if (prog_it != cached_progs.end()) { prog = prog_it->second; - // Get shapes for the cached program auto param_shapes = prog.get_parameter_shapes(); auto output_shapes = prog.get_output_shapes(); - if (needs_padding) { - // ============ PADDING PATH: Batch size needs to be padded ============ - - // Populate caches (with slicing info so ultra-fast path is disabled) - populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, - original_batch_size, padded_batch_size); - - // Rebuild padded_shapes in cached_inputs order (MIGraphX parameter order) - // This ensures consistency with ultra-fast path shape comparison - padded_shapes.clear(); - padded_shapes.reserve(mgx_state->cached_inputs.size() * 2); - for (const auto& cached_inp : mgx_state->cached_inputs) { - auto input_tensor = ctx.GetInput(cached_inp.ort_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shape = tensor_info.GetShape(); - - if (!tensor_shape.empty()) { - padded_shapes.push_back(static_cast(padded_batch_size)); - padded_shapes.insert(padded_shapes.end(), tensor_shape.begin() + 1, tensor_shape.end()); + populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index, + 
original_batch_size, padded_batch_size); + + // Direct-bind hipGraph for exact-match batch (no padding) + if (mgx_state->use_direct_hip_graph && !needs_padding) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } } } - - // Allocate and pad inputs for dynamic batching - bool using_padded_inputs = allocate_and_pad_inputs(mgx_state, ctx, original_batch_size, - padded_batch_size, rocm_stream); - - // Get or reuse cached temp output buffers (avoids hipMalloc/hipFree per run) - auto temp_output_buffer_ptrs = get_or_allocate_temp_output_buffers( - mgx_state, param_shapes, output_shapes, map_input_name_index, padded_batch_size); - auto [m, prog_output_indices] = handle_program_input_outputs( - param_shapes, output_shapes, map_input_name_index, ctx, true, &temp_output_buffer_ptrs); - mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - mgx_state->last_input_shapes_raw = std::move(padded_shapes); + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; - - // Rebind padded inputs to program parameters - if (using_padded_inputs && mgx_state->padded_input_buffers.size() == mgx_state->cached_inputs.size()) { - for (size_t i = 0; i < mgx_state->cached_inputs.size(); ++i) { - const auto& 
inp = mgx_state->cached_inputs[i]; - const auto& padded_buf = mgx_state->padded_input_buffers[i]; - m.add(inp.name.c_str(), migraphx::argument(padded_buf.mgx_shape, padded_buf.data)); + + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + prog_output_indices, padded_hash, + input_ptrs, output_ptrs); + return; + } + + bool use_pinned = needs_padding || mgx_state->hip_graph_enabled; + if (use_pinned) { + if (!mgx_state->pinned_io.allocated) { + std::size_t max_batch = padded_batch_size; + if (!mgx_state->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); + } + auto alloc_ps = param_shapes; + auto alloc_os = output_shapes; + if (max_batch > padded_batch_size && mgx_state->cached_programs_ref.has_value()) { + bool found = false; + for (auto& [h, p] : mgx_state->cached_programs_ref.value().get()) { + if (found) break; + auto candidate_ps = p.get_parameter_shapes(); + for (const auto& nm : candidate_ps.names()) { + if (mgx_state->input_name_indexes.count(nm)) { + auto lens = candidate_ps[nm].lengths(); + if (!lens.empty() && lens[0] == max_batch) { + alloc_ps = candidate_ps; + alloc_os = p.get_output_shapes(); + found = true; + } + break; + } + } + } } + allocate_pinned_io(mgx_state, alloc_ps, alloc_os, max_batch, rocm_stream); } - - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, - prog_output_indices, original_batch_size, padded_batch_size); - - // Temp output buffers are cached on mgx_state for reuse across runs. - // They are freed when the batch size changes or when the state is destroyed. - - return; + + std::size_t copy_actual = needs_padding ? 
original_batch_size : padded_batch_size; + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, copy_actual, padded_batch_size, rocm_stream); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + + mgx_state->cached_prog_params = bind_result.params; + mgx_state->cached_prog_output_indices = bind_result.prog_output_indices; + mgx_state->cached_pinned_output_indices = bind_result.pinned_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order( + mgx_state, ctx, padded_batch_size); + mgx_state->last_input_shape_hash = padded_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, + bind_result.prog_output_indices, padded_hash); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, + bind_result.pinned_output_indices, + ctx, copy_actual, rocm_stream); } else { - // ============ EXACT MATCH PATH: Batch size matches exactly, no padding needed ============ - - // Populate caches for ultra-fast path (no slicing needed) - populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index); - - // Bind inputs and allocate outputs (no slicing) auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); - - // Complete cache population + mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); - mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - - run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices, - 0, 0); - - return; + + 
run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } + return; } } } @@ -3052,7 +3406,6 @@ static void execute_standard_path( mgx_state->defer_compilation, map_input_name_index, ctx, cmp_options, prog); if (!input_shape_match) { - // Invalidate caches before recompilation mgx_state->caches_valid = false; handle_input_shape_mismatch( @@ -3064,34 +3417,103 @@ static void execute_standard_path( param_shapes, input_shapes); - // Re-fetch param_shapes after recompilation param_shapes = prog.get_parameter_shapes(); + + if (mgx_state->hip_graph_enabled && !check_hip_graph_compatibility(prog, "standard_path_recompile")) { + mgx_state->hip_graph_enabled = false; + mgx_state->use_direct_hip_graph = false; + } } - // Fetch output shapes once auto output_shapes = prog.get_output_shapes(); - // Populate optimized caches for ultra-fast path populate_ultra_fast_caches(mgx_state, param_shapes, output_shapes, map_input_name_index); - // Bind inputs and allocate outputs + // Allocate pinned I/O: required for hipGraph (stable addresses), also useful for future pad/slice. 
+ if (!mgx_state->pinned_io.allocated) { + std::size_t batch_for_alloc = 0; + if (!mgx_state->compiled_batch_sizes.empty()) { + batch_for_alloc = *std::max_element(mgx_state->compiled_batch_sizes.begin(), + mgx_state->compiled_batch_sizes.end()); + } + if (batch_for_alloc == 0) { + for (const auto& [name, index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { batch_for_alloc = static_cast(shape[0]); break; } + } + } + if (batch_for_alloc > 0) { + allocate_pinned_io(mgx_state, param_shapes, output_shapes, batch_for_alloc, rocm_stream); + } + } + + // Direct-bind hipGraph for standard path (no padding case) + if (mgx_state->use_direct_hip_graph) { + auto [m, prog_output_indices] = handle_program_input_outputs( + param_shapes, output_shapes, map_input_name_index, ctx); + + std::unordered_map input_ptrs, output_ptrs; + for (const auto& name : param_shapes.names()) { + auto inp_it = map_input_name_index.find(name); + if (inp_it != map_input_name_index.end()) { + input_ptrs[name] = const_cast(ctx.GetInput(inp_it->second).GetTensorRawData()); + } else { + const auto oi = compute_output_index(name); + if (oi != -1) { + const auto& lens = output_shapes[oi].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + auto ot = ctx.GetOutput(oi, ort_shape.data(), ort_shape.size()); + output_ptrs[name] = ot.GetTensorMutableRawData(); + } + } + } + + mgx_state->cached_prog_params = m; + mgx_state->cached_prog_output_indices = prog_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, + prog_output_indices, current_hash, + input_ptrs, output_ptrs); + return; + } + + if (mgx_state->hip_graph_enabled && mgx_state->pinned_io.allocated) { + std::size_t actual_batch = 0; + for (const auto& [name, 
index] : map_input_name_index) { + auto shape = ctx.GetInput(index).GetTensorTypeAndShapeInfo().GetShape(); + if (!shape.empty()) { actual_batch = static_cast(shape[0]); break; } + } + copy_inputs_to_pinned(mgx_state, param_shapes, ctx, actual_batch, actual_batch, rocm_stream); + auto bind_result = bind_pinned_program_params(mgx_state, param_shapes, output_shapes); + + mgx_state->cached_prog_params = bind_result.params; + mgx_state->cached_prog_output_indices = bind_result.prog_output_indices; + mgx_state->cached_pinned_output_indices = bind_result.pinned_output_indices; + mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); + mgx_state->last_input_shape_hash = current_hash; + mgx_state->caches_valid = true; + + run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, + bind_result.prog_output_indices, current_hash); + + copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, + bind_result.pinned_output_indices, + ctx, actual_batch, rocm_stream); + return; + } + auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx); - // Complete cache population mgx_state->cached_prog_params = m; mgx_state->cached_prog_output_indices = prog_output_indices; - - // IMPORTANT: Build last_input_shapes_raw in cached_inputs order (MIGraphX parameter order) - // This ensures ultra-fast path shape comparison uses consistent ordering mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); - mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; - // The program at this point was compiled (or already matched) for the exact - // runtime input shapes, so its outputs already have the original batch - // dimension. Pass 0,0 to avoid incorrect slicing arithmetic. 
run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } @@ -3262,53 +3684,19 @@ extract_base_shapes_from_graph( const NodeArg* nodearg = it->second; auto tensor_shape = nodearg->Shape(); if (tensor_shape != nullptr && tensor_shape->dim_size() > 1) { - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Processing input '" << name - << "' with " << tensor_shape->dim_size() << " dimensions"; - // Extract non-batch dimensions (skip dim 0) for (int j = 1; j < tensor_shape->dim_size(); ++j) { const auto& dim = tensor_shape->dim(j); if (dim.has_dim_value()) { base_shape.push_back(dim.dim_value()); - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] dim[" << j << "] = " << dim.dim_value(); } else { - // Symbolic non-batch dimension found - cannot precompile - // Do NOT default to any value - mark as failure - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' dim " << j << " is symbolic - cannot extract concrete base shapes"; all_concrete = false; } } - } else if (tensor_shape != nullptr && tensor_shape->dim_size() == 1) { - // Single dimension input (just batch) - no base shape needed - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Input '" << name - << "' has only 1 dimension (batch only) - empty base shape"; - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' has null or empty shape"; } - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Input '" << name - << "' not found in NodeArg map - may be subgraph-only input"; } base_shapes.push_back(base_shape); - - // Log the extracted base shape - std::ostringstream ss; - ss << "["; - for (std::size_t k = 0; k < base_shape.size(); ++k) { - if (k > 0) ss << ", "; - ss << base_shape[k]; - } - ss << "]"; - LOGS_DEFAULT(VERBOSE) << "[extract_base_shapes_from_graph] Input '" << name << "' base_shape: " << ss.str(); } - if (all_concrete) { - LOGS_DEFAULT(VERBOSE) << 
"[extract_base_shapes_from_graph] Successfully extracted " << ordered_names.size() - << " input base shapes (all concrete)"; - } else { - LOGS_DEFAULT(WARNING) << "[extract_base_shapes_from_graph] Failed - found symbolic non-batch dimensions"; - } return {all_concrete, ordered_names, base_shapes}; } @@ -3656,9 +4044,6 @@ static inline void precompile_static_model( } } } - } else { - LOGS_DEFAULT(WARNING) << "[precompile_static_model] Input '" << name - << "' not found in NodeArg map - skipping"; } full_shapes.push_back(shape); } @@ -3722,6 +4107,54 @@ static inline void precompile_static_model( LOGS_DEFAULT(INFO) << "[precompile_static_model] ✓ Static model precompiled and cached"; } +// Scan disk cache for .mxr files matching the node prefix and pre-load them +// into the in-memory cache. Eliminates first-inference stalls for deferred +// compilation where .mxr files exist from a previous session. +static void preload_mxr_cache_from_disk( + const std::filesystem::path& model_cache_path, + const std::string& mxr_filename_prefix, + std::unordered_map& cached_programs) +{ + if (model_cache_path.empty() || !std::filesystem::exists(model_cache_path)) return; + + const std::string suffix = ".mxr"; + std::vector> to_load; + + for (const auto& entry : std::filesystem::directory_iterator(model_cache_path)) { + if (!entry.is_regular_file()) continue; + const auto fname = entry.path().filename().string(); + if (fname.size() <= mxr_filename_prefix.size() + suffix.size()) continue; + if (fname.substr(0, mxr_filename_prefix.size()) != mxr_filename_prefix) continue; + if (fname.substr(fname.size() - suffix.size()) != suffix) continue; + + auto hash = fname.substr(mxr_filename_prefix.size(), + fname.size() - mxr_filename_prefix.size() - suffix.size()); + if (cached_programs.find(hash) != cached_programs.end()) continue; + to_load.emplace_back(hash, entry.path()); + } + + if (to_load.empty()) return; + + LOGS_DEFAULT(INFO) << "[preload_mxr_cache] Found " << to_load.size() + << " 
.mxr file(s) to pre-load for prefix '" << mxr_filename_prefix << "'"; + + std::mutex mu; + std::vector> futs; + for (const auto& [hash, path] : to_load) { + futs.push_back(std::async(std::launch::async, [&, hash, path]() { + migraphx::program prog; + if (load_precompiled_model(prog, path)) { + std::lock_guard lk(mu); + cached_programs[hash] = std::move(prog); + } + })); + } + for (auto& f : futs) f.get(); + + LOGS_DEFAULT(INFO) << "[preload_mxr_cache] Pre-loaded " << cached_programs.size() + << " program(s) into in-memory cache"; +} + // Encapsulates precompilation decision logic from Compile() // Returns true if compilation should be deferred to runtime, false if precompilation succeeded static inline bool handle_precompilation_decision( @@ -3994,6 +4427,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& max_dynamic_batch_, compile_batches_); + // Pre-load any .mxr files from disk that aren't already in memory. + preload_mxr_cache_from_disk(model_cache_path_, mxr_filename_prefix, + cached_programs_[fused_node.Name()]); + // Create program object (may be empty if precompiled programs are in cache) migraphx::program prog; map_progs_[fused_node.Name()] = prog; @@ -4006,13 +4443,29 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& NodeComputeInfo compute_info; compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); - *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], - map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_defer_compilation_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map_, - model_cache_path_, dump_model_ops_, exhaustive_tune_, max_dynamic_batch_, - std::ref(cached_programs_[context->node_name])}; - + p->allocate_func = context->allocate_func; + p->release_func 
= context->release_func; + p->allocate_handle = context->allocator_handle; + p->prog = map_progs_[context->node_name]; + p->onnx_string = map_onnx_string_[context->node_name]; + p->options = options; + p->t = t_; + p->input_name_indexes = map_input_index_[context->node_name]; + p->mgx_mu_ptr = &mgx_mu_; + p->stream = stream_; + p->defer_compilation = map_defer_compilation_[context->node_name]; + p->fp16_enable = fp16_enable_; + p->bf16_enable = bf16_enable_; + p->fp8_enable = fp8_enable_; + p->int8_enable = int8_enable_; + p->int8_calibration_cache_available = int8_calibration_cache_available_; + p->dynamic_range_map = dynamic_range_map_; + p->model_cache_dir = model_cache_path_; + p->dump_model_ops = dump_model_ops_; + p->exhaustive_tune = exhaustive_tune_; + p->max_dynamic_batch = max_dynamic_batch_; + p->cached_programs_ref = std::ref(cached_programs_[context->node_name]); + // Initialize dynamic batch support if max_dynamic_batch > 0 if (max_dynamic_batch_ > 0) { p->has_dynamic_batch = true; @@ -4036,7 +4489,74 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] Static model mode for node '" << context->node_name << "'"; LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] defer_compilation=" << p->defer_compilation; } - + + // Allocate pinned I/O buffers from the cached programs. + // create_state_func runs ONCE at session init (long before any Run()), + // so there is no per-Run compute stream to query here — ComputeContext + // does not expose one. We use stream_ (the EP-owned init stream) and + // rely on the hipStreamSynchronize(stream) inside allocate_pinned_io to + // establish the hipMallocAsync pool memory so the per-Run compute stream + // (resolved from ctx.GetGPUComputeStream() in compute_func) can safely + // consume these pointers without further cross-stream ordering. 
+ // Uses the program compiled for the largest batch size so that + // allocate_pinned_io sees parameter shapes whose batch dim matches + // max_batch. All smaller batches share the same buffers. + if (p->cached_programs_ref.has_value() && !p->cached_programs_ref.value().get().empty()) { + std::size_t max_batch = 0; + if (!p->compiled_batch_sizes.empty()) { + max_batch = *std::max_element(p->compiled_batch_sizes.begin(), + p->compiled_batch_sizes.end()); + } + migraphx::program* largest_prog = nullptr; + std::size_t largest_batch_found = 0; + for (auto& [hash, prog] : p->cached_programs_ref.value().get()) { + auto ps = prog.get_parameter_shapes(); + std::size_t prog_batch = 0; + for (const auto& name : ps.names()) { + if (p->input_name_indexes.find(name) != p->input_name_indexes.end()) { + auto lens = ps[name].lengths(); + if (!lens.empty() && lens[0] > 0) { + prog_batch = lens[0]; + break; + } + } + } + if (prog_batch > largest_batch_found) { + largest_batch_found = prog_batch; + largest_prog = &prog; + } + } + if (max_batch == 0) max_batch = largest_batch_found; + if (largest_prog && max_batch > 0) { + auto ps = largest_prog->get_parameter_shapes(); + auto os = largest_prog->get_output_shapes(); + allocate_pinned_io(p.get(), ps, os, max_batch, stream_); + } + + // If all batch sizes are pre-loaded, disable deferred compilation + if (p->defer_compilation && p->has_dynamic_batch && !p->compiled_batch_sizes.empty()) { + auto& progs = p->cached_programs_ref.value().get(); + if (progs.size() >= p->compiled_batch_sizes.size()) { + p->defer_compilation = false; + LOGS_DEFAULT(INFO) << "[Compile][CREATE_STATE] All " << p->compiled_batch_sizes.size() + << " batch model(s) pre-loaded — defer_compilation disabled"; + } + } + } + + // hipGraph: set per-node enable flag and validate cached programs + p->hip_graph_enabled = hip_graph_enable_; + p->use_direct_hip_graph = hip_graph_enable_; + if (p->hip_graph_enabled && p->cached_programs_ref.has_value()) { + for (const auto& 
[hash, cached_prog] : p->cached_programs_ref.value().get()) { + if (!check_hip_graph_compatibility(cached_prog, context->node_name)) { + p->hip_graph_enabled = false; + p->use_direct_hip_graph = false; + break; + } + } + } + *state = p.release(); return 0; }; @@ -4044,12 +4564,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& compute_info.release_state_func = [](FunctionState state) { if (state) { auto* s = static_cast(state); - for (auto& buf : s->padded_input_buffers) { - if (buf.data) (void)hipFree(buf.data); - } - for (auto& buf : s->temp_output_buffers) { - if (buf.data) (void)hipFree(buf.data); - } + destroy_hip_graphs(s); + free_pinned_io(s, s->stream); delete s; } }; @@ -4058,21 +4574,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); - const auto& map_input_name_index = mgx_state->input_name_indexes; + // Run on whichever stream ORT elected for this device for THIS Run(). + // - external_stream_=true -> ORT wrapper around the user-supplied stream + // - external_stream_=false -> stream ORT created via RegisterCreateStreamFn + // Either way, ORT's MemcpyFromHost/MemcpyToHost ran on this stream, so issuing + // kernels on it removes the cross-stream race that EP::stream_ would introduce. + hipStream_t run_stream = static_cast(ctx.GetGPUComputeStream()); + if (run_stream == nullptr) run_stream = stream_; // fallback for harnesses w/o stream registry - // stream_ is always valid: either the user's external stream or an - // EP-owned hipStreamNonBlocking created in the constructor. 
+ const auto& map_input_name_index = mgx_state->input_name_indexes; - // ═══════════════════════════════════════════════════════════════════════ - // ULTRA-FAST PATH: Shapes unchanged from last run - // ═══════════════════════════════════════════════════════════════════════ - if (execute_ultra_fast_path(mgx_state, stream_, ctx)) { + if (execute_ultra_fast_path(mgx_state, run_stream, ctx)) { return Status::OK(); } - // ═══════════════════════════════════════════════════════════════════════ - // Build input shape hash - only computed when shapes change - // ═══════════════════════════════════════════════════════════════════════ std::vector all_input_shapes; all_input_shapes.reserve(map_input_name_index.size() * 4); for (const auto& [name, index] : map_input_name_index) { @@ -4081,17 +4596,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } const auto current_hash = make_hash(all_input_shapes); - // ═══════════════════════════════════════════════════════════════════════ - // FAST PATH: Check cached programs for this shape hash - // ═══════════════════════════════════════════════════════════════════════ - if (execute_fast_path(mgx_state, stream_, ctx, current_hash, all_input_shapes)) { + if (execute_fast_path(mgx_state, run_stream, ctx, current_hash, all_input_shapes)) { return Status::OK(); } - // ═══════════════════════════════════════════════════════════════════════ - // STANDARD PATH: Shape checking and potential recompilation - // ═══════════════════════════════════════════════════════════════════════ - execute_standard_path(mgx_state, stream_, ctx, current_hash, std::move(all_input_shapes), + execute_standard_path(mgx_state, run_stream, ctx, current_hash, std::move(all_input_shapes), model_cache_path_, model_path_, mxr_filename_prefix); return Status::OK(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 24aeee586263d..20b9538aad1ab 100644 --- 
a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -37,6 +37,7 @@ constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; constexpr auto kModelMaxDynamicBatch = "ORT_MIGRAPHX_MAX_DYNAMIC_BATCH"sv; constexpr auto kCompileBatches = "ORT_MIGRAPHX_COMPILE_BATCHES"sv; +constexpr auto kHipGraphEnable = "ORT_MIGRAPHX_HIP_GRAPH_ENABLE"sv; } // namespace migraphx_env_vars // Tracks which dimensions are symbolic for a given input @@ -56,6 +57,7 @@ struct MIGraphXFuncState { migraphx::target t{}; std::unordered_map input_name_indexes; std::mutex* mgx_mu_ptr = nullptr; + hipStream_t stream = nullptr; bool defer_compilation = false; bool fp16_enable = false; bool bf16_enable = false; @@ -74,17 +76,24 @@ struct MIGraphXFuncState { bool has_dynamic_batch = false; std::vector compiled_batch_sizes; - // Padded input buffers for dynamic batching (allocated on GPU) - struct PaddedBuffer { - void* data = nullptr; // GPU buffer pointer - std::size_t size_bytes = 0; // Buffer size in bytes - migraphx::shape mgx_shape; // Padded MIGraphX shape + // Pinned I/O buffers: allocated once at max compiled batch, reused across all inferences. + // Eliminates per-inference hipMalloc/hipFree for padding and temp outputs. 
+ struct PinnedIOBuffer { + void* data = nullptr; + std::size_t size_bytes = 0; + migraphx::shape max_shape; // Shape at max_batch_size }; - std::vector padded_input_buffers; // One per input when padding is active - // Track last batch sizes to avoid re-allocation when batch size is unchanged - std::size_t last_original_batch_size = 0; // Original batch size from last run - std::size_t last_padded_batch_size = 0; // Padded batch size from last run + struct PinnedIOSet { + std::vector inputs; + std::vector outputs; + std::unordered_map input_name_to_idx; + std::unordered_map output_name_to_idx; + std::size_t max_batch_size = 0; + bool allocated = false; + }; + + PinnedIOSet pinned_io; // ═══════════════════════════════════════════════════════════════════════════ // PERFORMANCE CACHES - Avoid redundant MIGraphX API calls per inference @@ -116,6 +125,7 @@ struct MIGraphXFuncState { // Cached output indices for pre-allocated outputs (used by run_migraphx_program) std::vector cached_prog_output_indices; + std::vector cached_pinned_output_indices; // Last input shapes for quick comparison (avoids hash computation in ultra-fast path) std::vector last_input_shapes_raw; @@ -141,21 +151,39 @@ struct MIGraphXFuncState { // Track which program hash the cached shapes belong to (invalidate when program changes) std::string cached_program_hash; - + // ═══════════════════════════════════════════════════════════════════════════ - // OPTIMIZATION: Reusable temporary output buffers (for slicing mode) + // hipGraph CAPTURE / REPLAY // ═══════════════════════════════════════════════════════════════════════════ - - // Temporary output buffers for slicing (allocated at padded size) - struct TempOutputBuffer { - void* data = nullptr; // GPU buffer pointer - std::size_t size_bytes = 0; // Buffer size in bytes - migraphx::shape mgx_shape; // Padded MIGraphX shape + + struct ExtraOutputInfo { + std::size_t output_index; + std::vector ort_shape; + void* gpu_data; + std::size_t bytes; }; - 
std::vector temp_output_buffers; - - // Track padded batch size for temp output buffers - std::size_t temp_output_padded_batch_size = 0; + + struct CapturedHipGraph { + hipGraph_t graph = nullptr; + hipGraphExec_t exec = nullptr; + bool captured = false; + std::vector extra_outputs; + + // Addresses captured in the graph for direct-bind mode. + // Used to detect pointer drift and trigger re-capture. + std::unordered_map captured_input_ptrs; + std::unordered_map captured_output_ptrs; + }; + + bool hip_graph_enabled = false; + // When true, capture/replay binds ORT tensor pointers directly (no pinned copies). + // Requires the pool allocator to provide stable addresses. + bool use_direct_hip_graph = false; + // If pointer drift causes too many re-captures, disable direct mode permanently. + static constexpr int kMaxDirectRecaptures = 3; + int direct_recapture_count = 0; + // shape_hash -> captured graph (one per compiled program variant) + std::unordered_map hip_graph_cache; }; // Logical device representation. @@ -213,6 +241,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}, {std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch_)}, {std::string{migraphx_provider_option::kCompileBatches}, compile_batches_}, + {std::string{migraphx_provider_option::kHipGraphEnable}, MakeStringWithClassicLocale(hip_graph_enable_)}, {std::string{migraphx_provider_option::kHasUserComputeStream}, MakeStringWithClassicLocale(external_stream_)}, {std::string{migraphx_provider_option::kUserComputeStream}, MakeStringWithClassicLocale(reinterpret_cast(stream_))}}; } @@ -257,6 +286,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool first_start_ = true; size_t max_dynamic_batch_{0}; std::string compile_batches_{}; // Comma-separated list of batch sizes to compile, e.g. 
"1,4,8,16,32" + bool hip_graph_enable_{false}; }; }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 81169dc59e327..33876c173c41b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -74,6 +74,7 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .AddAssignmentToReference(migraphx_provider_option::kModelMaxDynamicBatch, max_dynamic_batch) .AddAssignmentToReference(migraphx_provider_option::kCompileBatches, compile_batches) + .AddAssignmentToReference(migraphx_provider_option::kHipGraphEnable, hip_graph_enable) .AddAssignmentToReference(migraphx_provider_option::kHasUserComputeStream, has_user_compute_stream) .AddValueParser( migraphx_provider_option::kUserComputeStream, @@ -118,6 +119,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, {std::string{migraphx_provider_option::kModelMaxDynamicBatch}, MakeStringWithClassicLocale(max_dynamic_batch)}, {std::string{migraphx_provider_option::kCompileBatches}, compile_batches}, + {std::string{migraphx_provider_option::kHipGraphEnable}, MakeStringWithClassicLocale(hip_graph_enable)}, {std::string{migraphx_provider_option::kHasUserComputeStream}, MakeStringWithClassicLocale(has_user_compute_stream)}, {std::string{migraphx_provider_option::kUserComputeStream}, MakeStringWithClassicLocale(reinterpret_cast(user_compute_stream))}, }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h 
index f814485b901dc..1a7754bd70656 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -36,6 +36,7 @@ constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; constexpr auto kModelMaxDynamicBatch = "migraphx_max_dynamic_batch"sv; constexpr auto kCompileBatches = "migraphx_compile_batches"sv; +constexpr auto kHipGraphEnable = "migraphx_hip_graph_enable"sv; constexpr auto kHasUserComputeStream = "has_user_compute_stream"sv; constexpr auto kUserComputeStream = "user_compute_stream"sv; } // namespace migraphx_provider_option @@ -61,6 +62,7 @@ struct MIGraphXExecutionProviderInfo { OrtArenaCfg* default_memory_arena_cfg{nullptr}; size_t max_dynamic_batch{static_cast(0)}; std::string compile_batches{}; // Comma-separated list of batch sizes to compile, e.g. "1,4,8,16,32" + bool hip_graph_enable{false}; void* external_alloc{nullptr}; void* external_free{nullptr}; @@ -94,7 +96,8 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ (static_cast(info.exhaustive_tune) << 21) ^ - (static_cast(info.bf16_enable) << 22); + (static_cast(info.bf16_enable) << 22) ^ + (static_cast(info.hip_graph_enable) << 23); onnxruntime::HashCombine(data, value); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index b021e37cb5112..2829bb45fb93b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -67,8 +67,11 @@ std::unique_ptr MIGraphXStream::CreateNotification(si } void MIGraphXStream::Flush() { - if (auto* handle = GetHandle()) - HIP_CALL_THROW(hipStreamSynchronize(static_cast(handle))); + // Only sync streams 
we own. External streams are caller-managed; implicit sync + // here would break async pipelines (e.g. Triton) by serializing every Run(). + if(own_stream_) + if (auto* handle = GetHandle()) + HIP_CALL_THROW(hipStreamSynchronize(static_cast(handle))); } void MIGraphXStream::EnqueDeferredCPUBuffer(void* cpu_buffer) {