Test for https://github.com/drisspg/pt_job_queue/pull/9#2904

Draft
xmfan wants to merge 1 commit into main from ptq/2818

Conversation


@xmfan xmfan commented Apr 9, 2026

Human Note

drisspg/pt_job_queue#9

Agent Report

Fix: DeepSeek 16B model with make_fx + SAC + regional_inductor on TP=2 FSDP=4

Summary

The flex_attention_backward HOP node and its associated mask_graph get_attr nodes fail to receive the compile_with_inductor annotation when using DeepSeek attention with TP=2 FSDP=4. This causes regional_inductor to skip compiling the backward attention pass, leading to failures.

Root Cause

_copy_fwd_metadata_to_bw_nodes() in make_fx_tracer.py propagates forward node metadata (including compile_with_inductor) to backward nodes by matching seq_nr values. The assumption is that forward and backward operations for the same HOP share the same seq_nr.

In DeepSeek with tensor parallelism, the forward flex_attention gets seq_nr=78 while the backward flex_attention_backward gets seq_nr=76. This mismatch is caused by DTensor redistribution operations (detach nodes with different seq_nr values) inserted between the forward output and the backward input. Because the seq_nr values differ, the backward node never receives the compile_with_inductor annotation.

Additionally, the mask_graph0 get_attr node (which is nested inside a tuple argument of flex_attention_backward) also fails to get annotated because the original code only checked direct arguments, not nested tuple arguments.
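The seq_nr mismatch can be illustrated with a toy stand-in for the first-pass matching logic (illustrative only; the real pass in make_fx_tracer.py operates on torch.fx.Node objects, and these dicts merely mimic their seq_nr and custom metadata):

```python
# Toy model of seq_nr-based metadata propagation (not the real fx pass):
# each dict stands in for a traced node with a seq_nr and custom metadata.
fwd_nodes = [{"name": "flex_attention", "seq_nr": 78,
              "custom": {"compile_with_inductor": True}}]
bwd_nodes = [{"name": "flex_attention_backward", "seq_nr": 76, "custom": {}}]

fwd_by_seq = {n["seq_nr"]: n for n in fwd_nodes}
for bw in bwd_nodes:
    fw = fwd_by_seq.get(bw["seq_nr"])  # 76 is not in {78}: the lookup misses
    if fw is not None:
        bw["custom"].update(fw["custom"])

# The backward node silently ends up without the annotation.
print(bwd_nodes[0]["custom"])  # {}
```

Any mechanism that inserts extra autograd nodes between forward output and backward input (here, DTensor redistribution) can shift seq_nr and break this matching.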

Fix

Added a second pass to _copy_fwd_metadata_to_bw_nodes() that specifically handles flex_attention_backward nodes missing the annotation:

  1. _trace_to_flex_attention_fwd(): Traces the backward node's out argument (arg index 3) backwards through detach/getitem/view nodes to find the originating flex_attention forward call.

  2. Tuple-aware propagation: After annotating the backward node, recursively walks its arguments (including nested tuples like the block_mask) to find and annotate associated get_attr nodes (mask_graph, fw_graph, joint_graph).

  3. Extracted _copy_fwd_meta_to_node(): Refactored the metadata copy logic into a reusable helper to avoid duplication between the seq_nr-based first pass and the new data-flow-based second pass.
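The two new ingredients can be sketched with stand-in node objects rather than real torch.fx.Node instances (the helper names follow the report; the Node class, the pass-through set, and the annotation value are illustrative assumptions):

```python
# Stand-in for torch.fx.Node: just op, name, args, and a meta dict.
class Node:
    def __init__(self, op, name, args=()):
        self.op, self.name, self.args = op, name, args
        self.meta = {"custom": {}}

PASSTHROUGH = {"detach", "getitem", "view"}

def trace_to_flex_attention_fwd(bwd_node):
    # Walk the backward node's "out" argument (arg index 3) backwards
    # through pass-through ops until the originating forward call.
    cur = bwd_node.args[3]
    while cur.op == "call_function" and cur.name in PASSTHROUGH:
        cur = cur.args[0]
    return cur if cur.name == "flex_attention" else None

def annotate_nested_get_attrs(args, custom_meta):
    # Recursively walk arguments, including nested tuples such as the
    # block_mask, and annotate any get_attr nodes found inside them.
    for a in args:
        if isinstance(a, tuple):
            annotate_nested_get_attrs(a, custom_meta)
        elif isinstance(a, Node) and a.op == "get_attr":
            a.meta["custom"].update(custom_meta)

# Mimic the failing graph: fwd -> detach -> bwd, mask_graph0 in a tuple arg.
fwd = Node("call_function", "flex_attention")
fwd.meta["custom"]["compile_with_inductor"] = True
detach = Node("call_function", "detach", (fwd,))
mask = Node("get_attr", "mask_graph0")
bwd = Node("call_function", "flex_attention_backward",
           (None, None, None, detach, (mask,)))

src = trace_to_flex_attention_fwd(bwd)
bwd.meta["custom"].update(src.meta["custom"])
annotate_nested_get_attrs(bwd.args, src.meta["custom"])
```

After this second pass, both the backward node and the tuple-nested mask_graph0 carry the annotation regardless of their seq_nr values.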

Test Results

Before fix:

flex_attention                 seq_nr=78 compile_with_inductor=True
mask_graph0                    seq_nr=76 compile_with_inductor=False
flex_attention_backward        seq_nr=76 compile_with_inductor=False
FAIL

After fix:

flex_attention                 seq_nr=78 compile_with_inductor=True
mask_graph0                    seq_nr=76 compile_with_inductor=True
flex_attention_backward        seq_nr=76 compile_with_inductor=True
PASS
Repro Script
#!/usr/bin/env python3
"""Minimal repro: flex_attention_backward doesn't get compile_with_inductor
when its seq_nr differs from the forward flex_attention node.

Run:
    python -m torch.distributed.run --standalone --nproc_per_node=4 repro_seqnr.py
"""

import os
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
)
from torch.fx.traceback import annotate_fn
from torch.nn.attention.flex_attention import and_masks

from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import NoParallel
from torchtitan.experiments.graph_trainer.common_utils import (
    _AC_REGION_ID,
    annotate_flex_attention_for_regional_inductor,
    register_blockmask_pytree_node,
)
from torchtitan.experiments.graph_trainer.make_fx_tracer import trace_module
from torchtitan.experiments.graph_trainer.simple_fsdp import data_parallel
from torchtitan.models.common.attention import (
    FlexAttention,
    create_attention_mask,
    get_causal_mask_mod,
    get_document_mask_mod,
)
from torchtitan.models.common.linear import Linear
from torchtitan.models.common.rmsnorm import RMSNorm
from torchtitan.models.common.rope import RoPE
from torchtitan.models.deepseek_v3.model import Attention

_PARAM_INIT = {"init_fn": "full_", "init_fill_value": 0.02}
register_blockmask_pytree_node()


class TinyAttentionModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        dim = 64
        n_heads = 4
        rope_dim = 16
        kv_lora_rank = 32
        qk_nope_head_dim = 16
        v_head_dim = 16
        qk_head_dim = qk_nope_head_dim + rope_dim
        self.attn = Attention(
            Attention.Config(
                dim=dim, n_heads=n_heads, q_lora_rank=0,
                kv_lora_rank=kv_lora_rank, qk_nope_head_dim=qk_nope_head_dim,
                qk_rope_head_dim=rope_dim, v_head_dim=v_head_dim,
                q_norm=RMSNorm.Config(normalized_shape=1, param_init=_PARAM_INIT),
                kv_norm=RMSNorm.Config(normalized_shape=kv_lora_rank, param_init=_PARAM_INIT),
                wq=Linear.Config(in_features=dim, out_features=n_heads * qk_head_dim, param_init=_PARAM_INIT),
                wkv_a=Linear.Config(in_features=dim, out_features=kv_lora_rank + rope_dim, param_init=_PARAM_INIT),
                wkv_b=Linear.Config(in_features=kv_lora_rank, out_features=n_heads * (qk_nope_head_dim + v_head_dim), param_init=_PARAM_INIT),
                wo=Linear.Config(in_features=n_heads * v_head_dim, out_features=dim, param_init=_PARAM_INIT),
                inner_attention=FlexAttention.Config(), mask_type="block_causal",
                rope_factor=1.0, rope_max_seq_len=64, rope_original_seq_len=64,
            ),
        )
        self.rope = RoPE(RoPE.Config(dim=rope_dim, max_seq_len=64, backend="complex", scaling="none"))

    def forward(self, x, block_mask):
        out = self.attn(x, self.rope.cache, block_mask)
        loss = out.float().sum()
        torch.autograd.grad(loss, [p for p in self.parameters() if p.requires_grad])
        return loss


def main():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    device = torch.device("cuda", local_rank)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    try:
        model = TinyAttentionModel().to(device=device, dtype=torch.bfloat16)
        with torch.no_grad():
            for p in model.parameters():
                if p.requires_grad:
                    torch.nn.init.normal_(p, std=0.02)

        parallel_dims = ParallelDims(dp_replicate=1, dp_shard=2, cp=1, tp=2, pp=1, ep=1, etp=1, world_size=4)
        parallel_dims.build_mesh()
        tp_mesh = parallel_dims.get_mesh("tp")
        parallelize_module(model, tp_mesh, {
            "attn": PrepareModuleInput(input_layouts=(Replicate(), Replicate(), None), desired_input_layouts=(Replicate(), Replicate(), None)),
            "attn.wq": ColwiseParallel(use_local_output=False),
            "attn.wkv_a": NoParallel(),
            "attn.wkv_b": ColwiseParallel(use_local_output=False),
            "attn.kv_norm": NoParallel(),
            "attn.inner_attention": PrepareModuleInput(input_layouts=(Shard(1), Shard(1), Shard(1)), desired_input_layouts=(Shard(1), Shard(1), Shard(1)), use_local_output=True),
            "attn.wo": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
        })
        model = data_parallel(model, parallel_dims.get_mesh("fsdp"), mode="fully_shard")

        seq_len = 64
        x = torch.randn(1, seq_len, 64, device=device, dtype=torch.bfloat16)
        tokens = torch.randint(2, 128, (1, seq_len), device=device)
        tokens[:, 15::16] = 1
        block_mask = create_attention_mask(
            and_masks(get_causal_mask_mod(), get_document_mask_mod(tokens, eos_id=1)),
            B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
        )

        model.forward = annotate_fn({_AC_REGION_ID: 1})(model.forward)
        with annotate_flex_attention_for_regional_inductor():
            traced = trace_module(model, (x, block_mask))

        dist.barrier()
        if rank == 0:
            ok = True
            for node in traced.gm.graph.nodes:
                if node.op == "call_function":
                    name = getattr(node.target, "__name__", "")
                    if name in ("flex_attention", "flex_attention_backward"):
                        ci = "compile_with_inductor" in node.meta.get("custom", {})
                        seq = node.meta.get("seq_nr", "?")
                        print(f"  {name:30s} seq_nr={seq} compile_with_inductor={ci}")
                        if not ci:
                            ok = False
                elif node.op == "get_attr" and isinstance(node.target, str):
                    if node.target.startswith("mask_graph"):
                        ci = "compile_with_inductor" in node.meta.get("custom", {})
                        seq = node.meta.get("seq_nr", "?")
                        print(f"  {node.target:30s} seq_nr={seq} compile_with_inductor={ci}")
                        if not ci:
                            ok = False
            print("\nPASS" if ok else "\nFAIL")
            if not ok:
                sys.exit(1)
    finally:
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

Output after fix:

  sdpa_score0                    seq_nr=78 compile_with_inductor=True
  sdpa_mask0                     seq_nr=78 compile_with_inductor=True
  flex_attention                 seq_nr=78 compile_with_inductor=True
  mask_graph0                    seq_nr=76 compile_with_inductor=True
  flex_attention_backward        seq_nr=76 compile_with_inductor=True

PASS: All flex_attention nodes have compile_with_inductor

Files Changed

  • torchtitan/experiments/graph_trainer/make_fx_tracer.py: Added _copy_fwd_meta_to_node(), _trace_to_flex_attention_fwd(), and a second pass in _copy_fwd_metadata_to_bw_nodes() for data-flow-based annotation propagation.

Fixes #2818

Repro Script
#!/usr/bin/env python3
"""Smaller distributed repro for block-causal regional_inductor failure.

Run:
    torchrun --standalone --nproc_per_node=4 \
        scripts/repro_regional_inductor_tiny_attention.py
"""

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
)
from torch.fx.graph import CodeGen
from torch.fx.traceback import annotate_fn
from torch.nn.attention.flex_attention import and_masks

from torchtitan.distributed import ParallelDims
from torchtitan.distributed.tensor_parallel import NoParallel
from torchtitan.experiments.graph_trainer.common_utils import (
    _AC_REGION_ID,
    register_blockmask_pytree_node,
)
from torchtitan.experiments.graph_trainer.make_fx_tracer import minimal_fx_tracer
from torchtitan.experiments.graph_trainer.passes import apply_ac_on_fwd_bwd_graph
from torchtitan.experiments.graph_trainer.simple_fsdp import data_parallel
from torchtitan.models.common.attention import (
    annotate_flex_attention_for_regional_inductor,
    create_attention_mask,
    get_causal_mask_mod,
    get_document_mask_mod,
)
from torchtitan.models.common.linear import Linear
from torchtitan.models.common.rmsnorm import RMSNorm
from torchtitan.models.common.rope import RoPE
from torchtitan.models.deepseek_v3.model import Attention


register_blockmask_pytree_node()


class TinyAttentionModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        dim = 64
        rope_dim = 16
        self.attn = Attention(
            Attention.Config(
                n_heads=4,
                q_lora_rank=0,
                kv_lora_rank=32,
                qk_nope_head_dim=16,
                qk_rope_head_dim=rope_dim,
                v_head_dim=16,
                q_norm=RMSNorm.Config(),
                kv_norm=RMSNorm.Config(),
                attn_backend="flex",
                attn_mask_type="block_causal",
                wq=Linear.Config(),
                rope_factor=1.0,
                rope_max_seq_len=64,
                rope_original_seq_len=64,
            ),
            dim=dim,
        )
        self.rope = RoPE(
            RoPE.Config(
                dim=rope_dim,
                max_seq_len=64,
                backend="complex",
                scaling="none",
            )
        )

    def init_weights(self) -> None:
        self.rope.init_weights(buffer_device=torch.device("cuda"))
        self.attn.init_weights(init_std=0.02)

    def forward(self, x: torch.Tensor, block_mask) -> torch.Tensor:
        return self.attn(x, self.rope.cache, block_mask)


def _init_dist() -> tuple[int, int]:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return rank, local_rank


def _parallelize(model: TinyAttentionModel) -> TinyAttentionModel:
    parallel_dims = ParallelDims(
        dp_replicate=1,
        dp_shard=2,
        cp=1,
        tp=2,
        pp=1,
        ep=1,
        etp=1,
        world_size=4,
    )
    parallel_dims.build_mesh()

    tp_mesh = parallel_dims.get_mesh("tp")
    parallelize_module(
        model,
        tp_mesh,
        {
            "attn": PrepareModuleInput(
                input_layouts=(Replicate(), Replicate(), None),
                desired_input_layouts=(Replicate(), Replicate(), None),
            ),
            "attn.wq": ColwiseParallel(use_local_output=False),
            "attn.wkv_a": NoParallel(),
            "attn.wkv_b": ColwiseParallel(use_local_output=False),
            "attn.kv_norm": NoParallel(),
            "attn.inner_attention": PrepareModuleInput(
                input_layouts=(Shard(1), Shard(1), Shard(1)),
                desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
                use_local_output=True,
            ),
            "attn.wo": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
        },
    )
    return data_parallel(model, parallel_dims.get_mesh("fsdp"), mode="fully_shard")


def _make_inputs(device: torch.device) -> tuple[torch.Tensor, object]:
    seq_len = 64
    x = torch.randn(1, seq_len, 64, device=device, dtype=torch.bfloat16)
    tokens = torch.randint(2, 128, (1, seq_len), device=device)
    tokens[:, 15::16] = 1
    block_mask = create_attention_mask(
        and_masks(get_causal_mask_mod(), get_document_mask_mod(tokens, eos_id=1)),
        B=1,
        H=None,
        Q_LEN=seq_len,
        KV_LEN=seq_len,
    )
    return x, block_mask


def _trace(model: TinyAttentionModel, x, block_mask):
    def step(mod, x_in, mask):
        out = mod(x_in, mask)
        loss = out.float().sum()
        grads = torch.autograd.grad(loss, [p for p in mod.parameters() if p.requires_grad])
        return [loss] + list(grads)

    model.forward = annotate_fn({_AC_REGION_ID: 1})(model.forward)
    with annotate_flex_attention_for_regional_inductor():
        traced = minimal_fx_tracer(step, (model, x, block_mask))
    traced.gm = apply_ac_on_fwd_bwd_graph(traced.gm)
    return traced


def _apply_regional_inductor(traced) -> None:
    from torch.fx.passes.regional_inductor import regional_inductor

    fake_mode = None
    for node in traced.gm.graph.nodes:
        if node.op == "placeholder" and "val" in node.meta:
            val = node.meta["val"]
            if isinstance(val, torch.Tensor) and hasattr(val, "fake_mode"):
                fake_mode = val.fake_mode
                break
    context = torch._guards.TracingContext(fake_mode)
    with torch._guards.tracing(context):
        traced.gm = regional_inductor(traced.gm)
    traced.gm.graph.set_codegen(CodeGen())
    traced.gm.recompile()


def _print_diag(gm: torch.fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        if node.op == "get_attr" and isinstance(node.target, str) and node.target.startswith("mask_graph"):
            ci = "compile_with_inductor" in node.meta.get("custom", {})
            print(f"{node.target:24s} seq={node.meta.get('seq_nr')} ci={ci}")
        if node.op == "call_function" and getattr(node.target, "__name__", "") == "flex_attention_backward":
            ci = "compile_with_inductor" in node.meta.get("custom", {})
            print(f"{node.target.__name__:24s} seq={node.meta.get('seq_nr')} ci={ci}")
    for name, mod in gm.named_children():
        if name.startswith("mask_graph") and isinstance(mod, torch.fx.GraphModule):
            n_ph = sum(1 for node in mod.graph.nodes if node.op == "placeholder")
            print(f"{name}: placeholders={n_ph} codegen={type(mod.graph._codegen).__name__}")


def main() -> None:
    rank, local_rank = _init_dist()
    device = torch.device("cuda", local_rank)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    try:
        model = TinyAttentionModel().to(device=device, dtype=torch.bfloat16)
        with torch.no_grad():
            model.init_weights()
        model = _parallelize(model)
        x, block_mask = _make_inputs(device)
        traced = _trace(model, x, block_mask)
        _apply_regional_inductor(traced)

        dist.barrier()
        if rank == 0:
            _print_diag(traced.gm)

        traced(model, x, block_mask)
    finally:
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
Agent Worklog

Worklog: Issue #2818 - DeepSeek 16B with make_fx + SAC + regional_inductor

Reproduced the issue

  • Wrote repro_seqnr.py using DeepSeek Attention with TP=2 FSDP=4 (4 GPUs)
  • Confirmed: flex_attention (fwd) gets seq_nr=78 with compile_with_inductor=True
  • Confirmed: flex_attention_backward gets seq_nr=76 with compile_with_inductor=False
  • The mask_graph0 get_attr node was also missing the annotation

Root cause identified

  • _copy_fwd_metadata_to_bw_nodes() relies solely on seq_nr matching
  • DTensor redistribution ops (detach) between fwd output and bwd input cause seq_nr divergence
  • The chain: flex_attention(seq=78) -> getitem(seq=78) -> detach(seq=78) -> detach(seq=76) -> flex_attention_backward(seq=76)
  • Additionally, mask_graph0 is nested inside a tuple arg, not a direct arg

Fix implemented

  • Added _trace_to_flex_attention_fwd() to trace backward's out arg back through detach/getitem/view to the forward flex_attention
  • Added second pass in _copy_fwd_metadata_to_bw_nodes() for backward nodes missing annotation
  • Handles tuple-nested get_attr nodes (mask_graph inside block_mask tuple)

Test result

  • All flex_attention-related nodes now correctly annotated with compile_with_inductor=True
  • PASS on 4xH100 with TP=2 FSDP=4

This PR was generated by ptq with human review.

@meta-cla meta-cla bot added the CLA Signed label Apr 9, 2026
@SherlockNoMad
Contributor

should be addressed by #2924


Labels

ciflow/8gpu, CLA Signed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Deepseek 16B model doesn't work with make_fx + SAC + regional_inductor on TP=2 FSDP=4

2 participants