llm_inference.log
Describe the bug
Impact: Blocker — This bug makes it impossible to use CUDA graph capture for the generation (decode) profile and then switch back to the prefill profile within the same IExecutionContext. It blocks the primary inference pipeline for hybrid linear-attention LLM models on DRIVE Orin.
IExecutionContext::enqueueV3() fails with CUDA error 700 (cudaErrorIllegalAddress) inside TensorRT's internal Myelin runtime (shapeChangeHelper → async copy) during the first prefill execution. The crash occurs when switching the optimization profile from profile 1 (generation) back to profile 0 (prefill) via setOptimizationProfileAsync(), in a multi-stream engine scenario.
Error message:
[18:31:00.337] [ERROR] [TensorRT] IExecutionContext::enqueueV3: Error Code 1: Myelin
([copy.cpp:exec:180] CUDA error 700 enqueueing async copy.
In shapeChangeHelper at runtime/myelin/runner.cpp:460)
Followed by a cascade failure on cleanup:
terminate called after throwing an instance of 'std::runtime_error'
what(): CUDA runtime error in cudaFreeHost(data): an illegal memory access was encountered
Root cause analysis:
The error originates inside TensorRT's internal Myelin runtime (not in any user plugin), specifically in shapeChangeHelper at runtime/myelin/runner.cpp:460, performing an async memory copy.
We believe the root cause is a TensorRT internal state corruption when switching optimization profiles after CUDA graph capture in a multi-stream engine. Specifically:
- During CUDA graph capture on profile 1, TensorRT's Myelin runtime initializes internal state (buffer addresses, shape metadata) for its auxiliary worker streams.
- When setOptimizationProfileAsync(0, stream) is called to switch back to profile 0, Myelin's internal state on the auxiliary worker streams is not fully reinitialized — even if a cudaStreamSynchronize(stream) is issued on the main stream.
- When enqueueV3() is subsequently called on profile 0, Myelin's shapeChangeHelper operates on stale internal state (e.g., buffer addresses/shapes from profile 1), causing the async copy to access invalid memory.
This is NOT a simple race condition. Adding cudaStreamSynchronize(stream) between setOptimizationProfileAsync() and enqueueV3() does not resolve the crash. The main stream is fully quiesced before enqueueV3(), yet the crash still occurs. This rules out async profile switch timing as the cause and points to Myelin's auxiliary stream state not being properly updated during profile switches.
Key evidence:
- The TensorRT log confirms the profile switch: "Switching optimization profile from: 1 to 0" (at 18:30:59.915)
- The crash happens 422ms later during enqueueV3 (at 18:31:00.337)
- cudaStreamSynchronize(stream) was added between setOptimizationProfileAsync and enqueueV3 — the crash still occurs, 100% reproducibly
- The engine uses multi-stream (Number of total worker streams is 2)
- The error is in Myelin's internal copy.cpp:exec:180, not in any user-provided plugin
- The crash only occurs after CUDA graph capture has been performed on profile 1; if CUDA graph capture is skipped entirely, the profile switch works correctly
Workarounds attempted (all failed):
- Warm up profile 0 before profile 1 — Still crashes. The Myelin internal state established during the warm-up is overwritten during CUDA graph capture on profile 1.
- Remove all cudaMemcpyAsync in custom LinearAttentionPlugin — Still crashes. The error is in Myelin's internal copy, not any plugin-side copy.
- Add cudaStreamSynchronize(stream) after setOptimizationProfileAsync — Still crashes. The stream sync completes successfully, confirming the profile switch itself finishes without error, yet the subsequent enqueueV3() still hits CUDA 700.
- Add cudaDeviceSynchronize() after setOptimizationProfileAsync — Still crashes. This synchronizes ALL streams on the device. Confirms the issue is not a synchronization/timing problem, but a state management bug in Myelin's profile switching logic.
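The two synchronization attempts above boil down to the following sketch (hypothetical condensation of our runner code; `context` and `stream` are assumed to be a valid IExecutionContext and cudaStream_t, error handling elided):

```cpp
// Attempted fix: fully quiesce the device between the profile switch and
// the next enqueue. Neither sync prevents the crash.
context->setOptimizationProfileAsync(0, stream);
cudaStreamSynchronize(stream);   // attempt: quiesce the main stream — still crashes
cudaDeviceSynchronize();         // attempt: quiesce every stream on the device — still crashes
// ... setTensorAddress / setInputShape for profile 0 ...
context->enqueueV3(stream);      // CUDA error 700 in Myelin shapeChangeHelper regardless
```

Since even a full device synchronize does not help, ordering of work on any stream cannot be the issue; the stale state must live in host-side Myelin bookkeeping.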
Questions for NVIDIA:
- Is switching optimization profiles supported after CUDA graph capture on a multi-stream engine? If not, this should be documented.
- Does CUDA graph capture on one optimization profile permanently lock Myelin's internal auxiliary stream state?
- What is the recommended pattern for using multiple optimization profiles with CUDA graph capture in a multi-stream engine? (e.g., a separate IExecutionContext per profile?)
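If one context per profile is indeed the supported pattern, we would expect it to look roughly like the sketch below (hypothetical; `engine`, `prefillStream`, and `decodeStream` are assumed, error handling elided). Confirmation either way would be appreciated:

```cpp
// Hypothetical workaround: dedicate one IExecutionContext to each
// optimization profile, so graph capture on the decode context can never
// disturb the prefill context's Myelin state.
std::unique_ptr<nvinfer1::IExecutionContext> prefillCtx{engine->createExecutionContext()};
std::unique_ptr<nvinfer1::IExecutionContext> decodeCtx{engine->createExecutionContext()};

prefillCtx->setOptimizationProfileAsync(0, prefillStream); // prefill profile, set once
decodeCtx->setOptimizationProfileAsync(1, decodeStream);   // generation profile, set once

// CUDA graph capture would then run only on decodeCtx; prefillCtx never
// switches profiles, avoiding the 1 -> 0 switch that triggers the crash.
```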
Steps/Code to reproduce bug
Build configuration:
cmake .. -DCMAKE_BUILD_TYPE=Release \
-DTRT_PACKAGE_DIR=/path/to/TensorRT \
-DCMAKE_TOOLCHAIN_FILE=cmake/aarch64_linux_toolchain.cmake \
-DEMBEDDED_TARGET=drive-orin
Runtime command used:
# Step 1: Build the engine with 2 optimization profiles (profile 0: prefill, profile 1: generation)
./build/examples/llm/llm_build --onnxDir ./onnx_models/qwen3.5-0.8b \
--engineDir ./engines/qwen3.5-0.8b \
--maxBatchSize 1 --maxInputLen 1024 --maxKVCacheCapacity 4096
# Step 2: Run inference — crashes on first prefill after CUDA graph capture
./build/examples/llm/llm_inference --engineDir ./engines/qwen3.5-0.8b \
--inputFile input.json --outputFile output.json
Execution flow that triggers the crash (simplified):
① Engine deserialization (profile defaults to 0)
② captureDecodingCUDAGraph()
└─ captureVanillaDecodingCudaGraph():
├─ setOptimizationProfileAsync(1, stream) // switch to generation profile
├─ ... set tensor addresses & shapes ...
├─ enqueueV3(stream) // warm-up run
├─ cudaStreamSynchronize(stream)
├─ cudaStreamBeginCapture(...)
├─ enqueueV3(stream) // capture
└─ cudaStreamEndCapture(...)
// ⚠ Function returns with profile still at 1
③ handleRequest() → executePrefillStep():
├─ setOptimizationProfileAsync(0, stream) // switch back to prefill profile
├─ cudaStreamSynchronize(stream) // ← explicit sync added; does NOT prevent crash
├─ setTensorAddress(...) // bind tensors for profile 0
├─ setInputShape(...)
└─ enqueueV3(stream) // 💥 CUDA 700 in Myelin shapeChangeHelper
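The flow above condenses to roughly this call sequence (a sketch of our runner, not a complete program; `context`, `stream`, `graph`, and the tensor-binding helpers `bindDecodeTensors`/`bindPrefillTensors` are assumed names, error handling elided):

```cpp
// Step ②: capture the decode graph on profile 1
context->setOptimizationProfileAsync(1, stream);
bindDecodeTensors(*context);        // setTensorAddress / setInputShape for profile 1
context->enqueueV3(stream);         // warm-up run
cudaStreamSynchronize(stream);
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
context->enqueueV3(stream);         // captured run
cudaStreamEndCapture(stream, &graph);
// NOTE: context is still on profile 1 here

// Step ③: first prefill after capture
context->setOptimizationProfileAsync(0, stream);
cudaStreamSynchronize(stream);      // explicit sync — does NOT prevent the crash
bindPrefillTensors(*context);       // setTensorAddress / setInputShape for profile 0
context->enqueueV3(stream);         // 💥 CUDA error 700 in Myelin shapeChangeHelper
```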
Relevant source code is in cpp/runtime/llmEngineRunner.cpp:
executePrefillStep() — line ~1009: where the profile switch + enqueueV3 crash occurs
captureVanillaDecodingCudaGraph() — line ~1530: CUDA graph capture that leaves context on profile 1
executeVanillaDecodingStep() — line ~1225: generation path (bypasses profile switching via CUDA graph, so does not crash)
All setOptimizationProfileAsync call sites and their synchronization status:
| Location | Profile | cudaStreamSynchronize after? |
| --- | --- | --- |
| executePrefillStep line 1014 | 0 (prefill) | Yes (added, does NOT prevent crash) |
| executeVanillaDecodingStep line 1229 | 1 (generation) | No |
| captureVanillaDecodingCudaGraph line 1535 | 1 (generation) | No (sync is 100+ lines later, after enqueueV3 warm-up) |
| captureEagleBaseTreeDecodingCudaGraph line 1687 | 1 (generation) | No |
Model / Engine Details:
- Architecture: Qwen3.5-0.8B hybrid model — 24 layers total: [L, L, L, F] × 6 (18 linear attention + 6 full attention)
- Full Attention config: 8 q_heads, 2 kv_heads, head_dim=256, GQA 4:1
- Linear Attention config: GatedDeltaNet, 16 heads, key/value_dim=128, conv_kernel=4
- Precision: INT4 AWQ
- Optimization profiles: Profile 0 (prefill, dynamic sequence length), Profile 1 (generation, single token)
- Engine multi-stream: Number of total worker streams is 2, Number of aux streams is 1
- Engine size: 1224 MiB
- Engine bindings include:
  - Standard: inputs_embeds, context_lengths, last_token_ids, rope_cos_sin, logits
  - KV Cache: past_key_values_{3,7,11,15,19,23} (6 full attention layers)
  - Linear Attention State: conv_state_{0,1,2,4,5,6,...}, recurrent_state_{0,1,2,4,5,6,...} (18 layers)
- Plugins: Int4GroupwiseGemmPlugin (IPluginV3), LinearAttentionPlugin (IPluginV3), AttentionPlugin (IPluginV2DynamicExt)
Relevant log (trimmed):
[18:30:57.511] [INFO] [TensorRT] [MS] Running engine with multi stream info
[18:30:57.511] [INFO] [TensorRT] [MS] Number of aux streams is 1
[18:30:57.511] [INFO] [TensorRT] [MS] Number of total worker streams is 2
[18:30:57.511] [INFO] [TensorRT] [MS] The main stream provided by execute/enqueue calls is the first worker stream
...
[18:30:59.523] [INFO] [TensorRT] Switching optimization profile from: 0 to 1. Please ensure there are no enqueued operations pending in this context prior to switching profiles
[18:30:59.915] [DEBUG] captureVanillaDecodingCudaGraph(): CUDA graph captured successfully for input shape [1, 1, 1024]
[18:30:59.915] [INFO] LLMInferenceRuntime(): Successfully captured the decoding CUDA graph for all execution batch sizes and LoRA weights.
[18:30:59.915] [INFO] [TensorRT] Switching optimization profile from: 1 to 0. Please ensure there are no enqueued operations pending in this context prior to switching profiles
[18:31:00.337] [ERROR] [TensorRT] IExecutionContext::enqueueV3: Error Code 1: Myelin ([copy.cpp:exec:180] CUDA error 700 enqueueing async copy. In shapeChangeHelper at runtime/myelin/runner.cpp:460)
[18:31:00.337] [ERROR] executePrefill(): Failed on TensorRT prefill stage enqueueV3() call.
Expected behavior
setOptimizationProfileAsync(profile, stream) followed by proper tensor binding and enqueueV3(stream) on the same stream should work correctly after CUDA graph capture on a different profile. TensorRT's internal state management should ensure that Myelin's shapeChangeHelper and all auxiliary worker streams properly reinitialize their state when the optimization profile changes, regardless of whether CUDA graph capture was previously performed.
Actual behavior: After CUDA graph capture on profile 1, switching back to profile 0 via setOptimizationProfileAsync(0, stream) — even with explicit cudaStreamSynchronize(stream) or cudaDeviceSynchronize() — leaves Myelin's internal state inconsistent. The subsequent enqueueV3(stream) triggers CUDA error 700 (illegal memory access) in Myelin's shapeChangeHelper async copy. The error is 100% reproducible on the first prefill enqueueV3 after CUDA graph capture. If CUDA graph capture is skipped entirely, profile switching between 0 and 1 with enqueueV3 works correctly, confirming the issue is specific to the interaction between CUDA graph capture and subsequent profile switches in multi-stream Myelin engines.
System information (Edge Device)
- Platform: NVIDIA DRIVE Orin
- Software release: NVIDIA DRIVE OS
- CPU architecture: aarch64
- GPU compute capability: SM87 (Ampere)
- Total device memory: 32 GB (unified memory)
- Build type: Release
- Library versions:
- TensorRT Edge-LLM version or commit hash: v0.5.0
- CUDA: 11.4
- TensorRT: 10.13.1 (bundled with DriveOS)
- C++ compiler: GCC 11.4 (cross-compilation toolchain)
- CMake options used:
- CMAKE_TOOLCHAIN_FILE: cmake/aarch64_linux_toolchain.cmake
- EMBEDDED_TARGET: drive-orin
- TRT_PACKAGE_DIR: /path/to/TensorRT
- Any other details that may help:
  - Engine uses multi-stream execution: Number of total worker streams is 2, Number of aux streams is 1
  - Engine size: 1224 MiB
  - Model: Qwen3.5-0.8B hybrid (18 linear attention + 6 full attention layers, INT4 AWQ)
  - Plugins used: Int4GroupwiseGemmPlugin (IPluginV3), LinearAttentionPlugin (IPluginV3), AttentionPlugin (IPluginV2DynamicExt)