Commits (18)
8debc5e
feat(inference): add ServeConfig and _EventLoopManager
YangFei1990 May 7, 2026
4ca3c95
feat(inference): add coordinator runtime and _MegatronLLMBase
YangFei1990 May 7, 2026
d7e68f1
feat(inference): add MegatronAsyncLLM, slim base class
YangFei1990 May 7, 2026
c03ab48
feat(inference): add MegatronLLM
YangFei1990 May 7, 2026
5c38044
refactor(inference): drop model_name fields from ServeConfig
YangFei1990 May 7, 2026
6331cc0
feat(inference): add MegatronAsyncLLM.serve(), drop ServeConfig.role
YangFei1990 May 7, 2026
b6448e8
feat(inference): add offline_inference example, fix high-level API bugs
YangFei1990 May 7, 2026
44ec5a9
feat(inference): add launch_inference_server example, fix daemon-thre…
YangFei1990 May 8, 2026
0d8ae8b
fix(tests): repoint inference recipes and cuda_graphs.sh to examples/…
YangFei1990 May 8, 2026
589626d
test(inference): add unit tests for the high-level inference API
YangFei1990 May 8, 2026
2eb2e81
test(inference): add functional tests for offline_inference 4 modes w…
YangFei1990 May 8, 2026
739f3b0
test(inference): add HTTP smoke test for launch_inference_server with…
YangFei1990 May 8, 2026
8ff9a26
docs(inference): add README for the high-level inference API
YangFei1990 May 8, 2026
9e7fae3
docs(inference): rewrite examples README and remove stale llama_mistr…
YangFei1990 May 8, 2026
5f651b7
Merge branch 'main' into inference_apis
YangFei1990 May 8, 2026
34494c5
ci(inference): satisfy linting, copyright-check, and build-docs
YangFei1990 May 8, 2026
8caa84d
Merge branch 'inference_apis' of https://github.com/YangFei1990/Megat…
YangFei1990 May 8, 2026
963f663
Merge branch 'main' into inference_apis
YangFei1990 May 8, 2026
18 changes: 0 additions & 18 deletions docs/llama_mistral.md
@@ -37,12 +37,10 @@ Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatr
- [Download Huggingface checkpoints](#download-huggingface-checkpoints)
- [Convert checkpoint format](#convert-checkpoint-format)
- [Huggingface format](#huggingface-format)
- [(Optional) Validate checkpoints](#optional-validate-checkpoints)
- [Launch model](#launch-model)
- [Mistral-7b](#mistral-7b)
- [Download Huggingface checkpoints](#download-huggingface-checkpoints)
- [Convert checkpoint format](#convert-checkpoint-format)
- [(Optional) Validate checkpoints](#optional-validate-checkpoints)
- [Launch model](#launch-model)
- [Other Llama-like model support](#other-llama-like-model-support)
- [Known numerical differences](#known-numerical-differences)
@@ -210,14 +208,6 @@ python Megatron-Bridge/examples/conversion/convert_checkpoints.py import \

After this conversion, we are ready to load the checkpoints into a Megatron GPT model.

## (Optional) Validate checkpoints

A Megatron-LM text generation server for Llama3 can be launched using the script `examples/inference/llama_mistral/run_text_generation_llama3.sh <PATH_TO_CONVERTED_CORE_CHECKPOINT> <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT>`. For Llama3.1, please use `examples/inference/llama_mistral/run_text_generation_llama3.1.sh`.

Once running, query the server with `curl 'http://<TEXT_GENERATION_SERVER_IP>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["<SOME_PROMPT>"], "tokens_to_generate":100, "top_k":1}'`.

A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT> --prompt <SOME_PROMPT>`.

## Launch model

If loading for either inference or finetuning, use the following arguments for Llama 3.0:
@@ -314,14 +304,6 @@ python Megatron-Bridge/examples/conversion/convert_checkpoints.py import \

After this conversion, we are ready to load the checkpoints into a Megatron GPT model.

## (Optional) Validate checkpoints

A Megatron-LM text generation server for Mistral-7B can be launched using the script `examples/inference/llama_mistral/run_text_generation_mistral.sh <PATH_TO_CONVERTED_MCORE_CHECKPOINT> <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT>`.

Once running, query the server with `curl 'http://<TEXT_GENERATION_SERVER_IP>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["<SOME_PROMPT>"], "tokens_to_generate":100, "top_k":1}'`.

A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/inference/llama_mistral/huggingface_reference.py --model_path <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT> --prompt <SOME_PROMPT>`.

## Launch model

If loading for either inference or finetuning, use the following arguments:
360 changes: 92 additions & 268 deletions examples/inference/README.md

Large diffs are not rendered by default.

127 changes: 0 additions & 127 deletions examples/inference/gpt/gpt_dynamic_inference_12b.sh

This file was deleted.

115 changes: 0 additions & 115 deletions examples/inference/gpt/gpt_dynamic_inference_357m.sh

This file was deleted.

106 changes: 106 additions & 0 deletions examples/inference/launch_inference_server.py
@@ -0,0 +1,106 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""OpenAI-compatible inference server using the Megatron high-level API.

Mirrors tools/run_dynamic_text_generation_server.py but drives the
``DynamicInferenceEngine`` through ``MegatronAsyncLLM.serve(...)`` instead
of building the coordinator/engine pipeline manually. Coordinator mode is
required (HTTP serving uses the coordinator path); ``use_coordinator=True``
is hardcoded in the script.
"""

import asyncio
import os
import sys
from argparse import ArgumentParser

import torch

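# Make the repository root importable so the example runs in place without
# an installed package (two directories up from examples/inference).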
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)

from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer
from megatron.core.utils import configure_nvtx_profiling
from megatron.inference import MegatronAsyncLLM, ServeConfig
from megatron.inference.utils import (
    add_inference_args,
    get_inference_config_from_model_and_args,
    get_model_for_inference,
)
from megatron.training import get_args, initialize_megatron
from megatron.training.arguments import parse_and_validate_args


def add_serve_args(parser: ArgumentParser) -> ArgumentParser:
    parser = add_inference_args(parser)
    group = parser.add_argument_group(title='High-level inference server')
    group.add_argument("--coordinator-host", type=str, default=None)
    group.add_argument("--coordinator-port", type=int, default=None)
    group.add_argument("--host", type=str, default="0.0.0.0", help="HTTP bind host")
    group.add_argument("--port", type=int, default=5000, help="HTTP bind port")
    group.add_argument(
        "--parsers", type=str, nargs="+", default=[], help="Response parser names"
    )
    group.add_argument(
        "--verbose", action="store_true", default=False, help="Per-request HTTP logging"
    )
    group.add_argument(
        "--frontend-replicas", type=int, default=4,
        help="Number of HTTP frontend processes spawned on the primary rank.",
    )
    return parser


async def _serve(args, model, tokenizer, inference_config):
    async with MegatronAsyncLLM(
        model=model,
        tokenizer=tokenizer,
        inference_config=inference_config,
        use_coordinator=True,
        coordinator_host=args.coordinator_host,
        coordinator_port=args.coordinator_port,
    ) as llm:
        serve_config = ServeConfig(
            host=args.host,
            port=args.port,
            parsers=args.parsers,
            verbose=args.verbose,
            frontend_replicas=args.frontend_replicas,
        )
        await llm.serve(serve_config, blocking=True)


def main():
    parse_and_validate_args(
        extra_args_provider=add_serve_args,
        args_defaults={'no_load_rng': True, 'no_load_optim': True},
    )
    initialize_megatron()

    args = get_args()

    # Match the legacy tool's NVTX gating.
    if args.profile and args.nvtx_ranges:
        configure_nvtx_profiling(True)

    # Required for lm-eval loglikelihood compatibility: keeps prompt logits
    # materialized so echo=True / logprob requests work end-to-end. Matches
    # tools/run_dynamic_text_generation_server.py.
    args.return_log_probs = True

    tokenizer = build_tokenizer(args)
    model = get_model_for_inference()
    inference_config = get_inference_config_from_model_and_args(model, args)

    try:
        asyncio.run(_serve(args, model, tokenizer, inference_config))
    except KeyboardInterrupt:
        print("Server process interrupted by user.")
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
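
A minimal client sketch for querying the server once it is up. The HTTP routes exposed by the frontend are not visible in this diff, so the /v1/completions path and the response shape below are assumptions based on the script's "OpenAI-compatible" description; the host and port match the add_serve_args defaults.

    # client_example.py -- hypothetical client; the endpoint path and payload
    # schema are assumed from OpenAI-compatible convention, not from this diff.
    import requests

    SERVER = "http://localhost:5000"  # defaults of --host / --port above

    payload = {
        "model": "megatron",  # illustrative; the server may ignore this field
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0.0,
    }

    resp = requests.post(f"{SERVER}/v1/completions", json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["text"])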