feat(gpt): add output postprocess hook #4686
Open

Glitchfix wants to merge 1 commit into NVIDIA:main from
Conversation
Force-pushed from 1eda379 to ab3bdb5
Add a keyword-only output_processor and context to GPTModel.forward so downstream RL callers can consume decoder hidden states and output-layer helpers without monkey-patching GPTModel postprocessing. Thread the same hook through schedule-plan postprocess nodes so 1F1B schedule users can rely on the same extension point, and add coverage for direct forward invocation plus schedule-plan threading.

Addresses NVIDIA#4590.

Signed-off-by: Shivanjan Chakravorty <shivanjanc@nvidia.com>
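For illustration, a hook with the shape this commit describes might look like the following pure-Python sketch. The names `logprob_output_processor`, the `"target_ids"` context key, and the `output_layer` callable are hypothetical stand-ins, not the actual Megatron-Core signature:

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over one row of logits.
    m = max(logits)
    z = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - z for x in logits]

def logprob_output_processor(hidden_states, output_layer, context):
    """Hypothetical hook: return the logprob of each target token
    instead of materializing full logits for the caller."""
    result = []
    for h, target in zip(hidden_states, context["target_ids"]):
        logits = output_layer(h)  # project hidden state -> vocab logits
        result.append(log_softmax(logits)[target])
    return result
```

A downstream RL caller would then pass `output_processor=logprob_output_processor` and `output_processor_context={"target_ids": ...}` to `forward` rather than monkey-patching postprocessing.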
Force-pushed from ab3bdb5 to bbb5e60
santhnm2 approved these changes on May 9, 2026
fanshiqing approved these changes on May 9, 2026
What does this PR do?
Adds an objective-neutral GPT output postprocess hook so downstream RL callers can consume decoder hidden states and output-layer helpers without monkey-patching GPTModel.forward or the schedule-plan postprocess path.

Issue tracking
Linked issue: Related to #4590
Summary
- Adds keyword-only output_processor and output_processor_context arguments to GPTModel.forward.
- Threads the hook through build_schedule_plan and the 1F1B PostProcessNode.

This is intentionally a narrow first slice for #4590. It does not add RL-specific arguments to GPTModel.forward, and it leaves MTP-specific custom postprocess behavior as a follow-up unless maintainers prefer a broader API.

Contribution process
Pre-checks
Testing
- pre-commit run --files megatron/core/models/common/model_chunk_schedule_plan.py megatron/core/models/gpt/fine_grained_callables.py megatron/core/models/gpt/gpt_model.py tests/unit_tests/models/test_gpt_model.py
- python -m py_compile megatron/core/models/gpt/gpt_model.py megatron/core/models/common/model_chunk_schedule_plan.py megatron/core/models/gpt/fine_grained_callables.py tests/unit_tests/models/test_gpt_model.py
- PATH=/tmp/megatron-autoformat-stable/bin:$PATH CHECK_ONLY=true SKIP_DOCS=true ./tools/autoformat.sh (tools/autoformat.sh run with || true); running it in this local environment reports pre-existing project typing issues in the touched files' surrounding code.
- Unit tests: selected-token logprobs from GPTModel.forward(..., output_processor=...) match logprobs computed from the default logits path; the hook exposes output_layer.weight; PostProcessNode invokes the threaded processor/context.
- Training smoke run with output_processor: loss 4.2114 -> 2.4986.
- Salesforce/wikitext, wikitext-2-raw-v1, train[:64]: with vs. without output_processor, 10.833719 vs 10.833719; after training, 10.8337 -> 10.3881.

Full pytest collection was not run locally because this environment is missing omegaconf and transformer_engine.
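For context, the schedule-plan threading described in the Summary can be sketched as follows. This is an illustrative class shape assuming the node captures the processor/context at build time; the names mirror, but are not, the actual Megatron-Core implementation:

```python
class PostProcessNode:
    """Sketch: store an optional output_processor at build time and
    invoke it when the node runs, so 1F1B schedule users share the same
    extension point as direct GPTModel.forward calls."""

    def __init__(self, default_postprocess, output_processor=None,
                 output_processor_context=None):
        self.default_postprocess = default_postprocess
        self.output_processor = output_processor
        self.output_processor_context = output_processor_context

    def forward(self, hidden_states, output_layer):
        if self.output_processor is not None:
            # Custom hook path: caller-defined postprocessing.
            return self.output_processor(
                hidden_states, output_layer, self.output_processor_context
            )
        # Default path: standard logits postprocessing.
        return self.default_postprocess(hidden_states, output_layer)
```

The design point the PR relies on is that the node only stores and forwards the hook; all objective-specific logic stays in the caller-supplied processor.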