Conversation

Contributor: should be addressed by #2924
Human Note
drisspg/pt_job_queue#9
Agent Report
Fix: DeepSeek 16B model with make_fx + SAC + regional_inductor on TP=2 FSDP=4
Summary
The `flex_attention_backward` HOP node and its associated `mask_graph` get_attr nodes fail to receive the `compile_with_inductor` annotation when using DeepSeek attention with TP=2 FSDP=4. This causes `regional_inductor` to skip compiling the backward attention pass, leading to failures.

Root Cause
`_copy_fwd_metadata_to_bw_nodes()` in `make_fx_tracer.py` propagates forward node metadata (including `compile_with_inductor`) to backward nodes by matching `seq_nr` values, on the assumption that forward and backward operations for the same HOP share the same `seq_nr`.

In DeepSeek with tensor parallelism, the forward `flex_attention` gets `seq_nr=78` while the backward `flex_attention_backward` gets `seq_nr=76`. The mismatch is caused by DTensor redistribution operations (detach nodes with different `seq_nr` values) inserted between the forward output and the backward input. Because the `seq_nr` values differ, the backward node never receives the `compile_with_inductor` annotation.
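The failure mode can be illustrated with a small toy sketch (hypothetical dict-based node records, not the actual tracer code):

```python
# Toy illustration of seq_nr-based metadata propagation. The real logic
# lives in _copy_fwd_metadata_to_bw_nodes(); these dicts are stand-ins
# for FX nodes, with illustrative field names.

fwd_nodes = [
    {"target": "flex_attention", "seq_nr": 78,
     "meta": {"custom": {"compile_with_inductor": "flex_attention"}}},
]
bwd_nodes = [
    # DTensor redistribution shifted the backward seq_nr from 78 to 76.
    {"target": "flex_attention_backward", "seq_nr": 76, "meta": {}},
]

fwd_by_seq_nr = {n["seq_nr"]: n for n in fwd_nodes}

for bw in bwd_nodes:
    fw = fwd_by_seq_nr.get(bw["seq_nr"])
    if fw is not None:  # 76 is not in the map, so this never fires
        bw["meta"].update(fw["meta"])

# bwd_nodes[0]["meta"] is still empty: the annotation was lost
```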
Additionally, the `mask_graph0` get_attr node (which is nested inside a tuple argument of `flex_attention_backward`) also fails to get annotated, because the original code only checked direct arguments, not nested tuple arguments.

Fix
Added a second pass to `_copy_fwd_metadata_to_bw_nodes()` that specifically handles `flex_attention_backward` nodes missing the annotation:

- `_trace_to_flex_attention_fwd()`: traces the backward node's `out` argument (arg index 3) backwards through `detach`/`getitem`/`view` nodes to find the originating `flex_attention` forward call.
- Tuple-aware propagation: after annotating the backward node, recursively walks its arguments (including nested tuples like the block_mask) to find and annotate the associated `get_attr` nodes (`mask_graph`, `fw_graph`, `joint_graph`).
- `_copy_fwd_meta_to_node()`: extracted the metadata copy logic into a reusable helper to avoid duplication between the seq_nr-based first pass and the new data-flow-based second pass.

Test Results
Before fix:
After fix:
Repro Script
Output after fix:
Files Changed
`torchtitan/experiments/graph_trainer/make_fx_tracer.py`: added `_copy_fwd_meta_to_node()`, `_trace_to_flex_attention_fwd()`, and a second pass in `_copy_fwd_metadata_to_bw_nodes()` for data-flow-based annotation propagation.

Fixes #2818
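As a rough sketch of the data-flow-based second pass (simplified stand-in classes, not the real torch.fx API; names are illustrative):

```python
# Simplified sketch of the second pass: when seq_nr matching fails, walk the
# backward node's `out` argument (arg index 3) upstream through
# detach/getitem/view nodes to the originating flex_attention forward call,
# then copy its annotation. `Node` is a stand-in for torch.fx.Node.

PASS_THROUGH_OPS = {"detach", "getitem", "view"}

class Node:
    def __init__(self, target, inp=None):
        self.target = target  # op name, e.g. "flex_attention"
        self.inp = inp        # single upstream input, enough for this sketch
        self.meta = {}

def trace_to_flex_attention_fwd(node):
    """Follow pass-through ops upstream; return the forward node or None."""
    while node is not None and node.target in PASS_THROUGH_OPS:
        node = node.inp
    if node is not None and node.target == "flex_attention":
        return node
    return None

def copy_fwd_meta_to_node(fwd, bwd):
    """Reusable metadata copy, in the spirit of _copy_fwd_meta_to_node()."""
    bwd.meta.update(fwd.meta)

# Chain seen in the worklog: flex_attention -> getitem -> detach -> detach
fwd = Node("flex_attention")
fwd.meta["custom"] = {"compile_with_inductor": "flex_attention"}
out_arg = Node("detach", Node("detach", Node("getitem", fwd)))
bwd = Node("flex_attention_backward", out_arg)

src = trace_to_flex_attention_fwd(bwd.inp)
if src is not None:
    copy_fwd_meta_to_node(src, bwd)
# bwd.meta now carries the compile_with_inductor annotation
```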
Agent Worklog
Worklog: Issue #2818 - DeepSeek 16B with make_fx + SAC + regional_inductor
Reproduced the issue
- `repro_seqnr.py` using DeepSeek Attention with TP=2 FSDP=4 (4 GPUs)
- `flex_attention` (fwd) gets `seq_nr=78` with `compile_with_inductor=True`
- `flex_attention_backward` gets `seq_nr=76` with `compile_with_inductor=False`
- `mask_graph0` get_attr node also missing the annotation

Root cause identified
- `_copy_fwd_metadata_to_bw_nodes()` relies solely on seq_nr matching
- Chain: `flex_attention` (seq=78) -> `getitem` (seq=78) -> `detach` (seq=78) -> `detach` (seq=76) -> `flex_attention_backward` (seq=76)
- `mask_graph0` is nested inside a tuple arg, not a direct arg

Fix implemented
- `_trace_to_flex_attention_fwd()` to trace backward's `out` arg back through detach/getitem/view to the forward flex_attention
- Second pass in `_copy_fwd_metadata_to_bw_nodes()` for backward nodes missing annotation

Test result
- `flex_attention_backward` now gets `compile_with_inductor=True`

This PR was generated by ptq with human review.
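The tuple-aware propagation step mentioned in the worklog can be sketched like this (dict-based stand-ins for FX nodes; names are illustrative, the real code walks `torch.fx` node arguments):

```python
# Sketch of tuple-aware annotation propagation: recursively walk a backward
# node's arguments, descending into nested tuples (like the block_mask), and
# copy the annotation onto any get_attr node found. Dicts stand in for
# torch.fx nodes.

def annotate_get_attrs(args, annotation):
    for arg in args:
        if isinstance(arg, (tuple, list)):
            annotate_get_attrs(arg, annotation)  # descend into nested tuples
        elif isinstance(arg, dict) and arg.get("op") == "get_attr":
            arg.setdefault("meta", {})["custom"] = dict(annotation)

mask_graph = {"op": "get_attr", "target": "mask_graph0"}
# The block_mask tuple nests mask_graph one level deep in the backward args,
# which is why checking only direct arguments missed it.
backward_args = ("q", "k", "v", "out", (1, 1, mask_graph))

annotate_get_attrs(backward_args, {"compile_with_inductor": "flex_attention"})
# mask_graph["meta"]["custom"] == {"compile_with_inductor": "flex_attention"}
```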