Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 84 additions & 9 deletions torchtitan/experiments/graph_trainer/make_fx_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,63 @@ def _patched(t_outputs, *args, **kwargs): # type: ignore[no-untyped-def]
torch.autograd._engine_run_backward = _orig_fn # type: ignore[assignment]


def _copy_fwd_meta_to_node(fwd_node: torch.fx.Node, node: torch.fx.Node) -> None:
    """Propagate annotation metadata from *fwd_node* onto *node* in place.

    Three ``meta`` keys are handled, each only when present on *fwd_node*:

    * ``"custom"`` -- merged into *node*'s ``"custom"`` dict (created if
      missing); keys from *fwd_node* win on conflict.  A falsy (e.g. empty)
      source dict is ignored.
    * ``"nn_module_stack"`` -- replaced on *node* with a shallow copy.
    * ``"stack_trace"`` -- replaced on *node* with the same object.
    """
    src_meta = fwd_node.meta
    dst_meta = node.meta

    # Merge rather than overwrite so annotations already on *node* survive.
    if fwd_custom := src_meta.get("custom"):
        dst_meta.setdefault("custom", {}).update(fwd_custom)

    # Shallow-copy so later edits to one node's stack don't affect the other.
    if (module_stack := src_meta.get("nn_module_stack")) is not None:
        dst_meta["nn_module_stack"] = module_stack.copy()

    if (fwd_trace := src_meta.get("stack_trace")) is not None:
        dst_meta["stack_trace"] = fwd_trace


def _trace_to_flex_attention_fwd(node: torch.fx.Node) -> torch.fx.Node | None:
"""Trace flex_attention_backward's ``out`` arg back to its flex_attention fwd.

flex_attention_backward(q, k, v, out, logsumexp, grad_out, ...) — the
``out`` (arg index 3) originates from the forward flex_attention call via a
chain of getitem / detach / view nodes. Walk backwards through these
pass-through ops until we find a ``flex_attention`` call_function.
"""
import operator

_PASSTHROUGH_OPS = {
torch.ops.aten.detach.default,
torch.ops.aten.view.default,
}
cur = node.args[3] if len(node.args) > 3 else None
for _ in range(20): # bounded walk to avoid infinite loops
if cur is None or not isinstance(cur, torch.fx.Node):
return None
if cur.op == "call_function":
if getattr(cur.target, "__name__", "") == "flex_attention":
return cur
# operator.getitem or aten getitem: follow the first arg
if cur.target in (operator.getitem, operator.__getitem__):
cur = cur.args[0] if cur.args else None
continue
if cur.target in _PASSTHROUGH_OPS:
cur = cur.args[0] if cur.args else None
continue
return None
return None


def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
"""Copy forward node metadata (custom) to later nodes sharing the same seq_nr.

Walks the graph in a single pass. The first node seen for each seq_nr is
treated as the forward node.
Subsequent nodes with the same seq_nr (typically backward nodes) receive
the forward node's custom metadata.

A second pass handles flex_attention_backward nodes whose seq_nr differs
from the forward flex_attention (e.g. DeepSeek with TP). For these, the
forward node is found by tracing the ``out`` input back through
getitem/detach/view to the originating flex_attention call.
"""
seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {}

Expand All @@ -244,16 +294,41 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
seq_nr_to_fwd_node[seq_nr] = node
else:
fwd_node = seq_nr_to_fwd_node[seq_nr]
_copy_fwd_meta_to_node(fwd_node, node)

custom = fwd_node.meta.get("custom")
if custom:
node.meta.setdefault("custom", {}).update(custom)
nn_module_stack = fwd_node.meta.get("nn_module_stack")
if nn_module_stack is not None:
node.meta["nn_module_stack"] = nn_module_stack.copy()
stack_trace = fwd_node.meta.get("stack_trace")
if stack_trace is not None:
node.meta["stack_trace"] = stack_trace
# Second pass: fix up flex_attention_backward (and its get_attr mask/score
# subgraphs) that didn't get annotations because their seq_nr differs from
# the forward flex_attention node.
for node in fx_g.graph.nodes:
if node.op == "call_function" and getattr(
node.target, "__name__", ""
) == "flex_attention_backward":
if "compile_with_inductor" in node.meta.get("custom", {}):
continue # already annotated
fwd_node = _trace_to_flex_attention_fwd(node)
if fwd_node is not None:
_copy_fwd_meta_to_node(fwd_node, node)
# Also propagate to get_attr nodes (mask_graph, score graph,
# joint_graph, fw_graph) that are arguments of this backward
# call and share its seq_nr. These may be direct args or
# nested inside tuple args (e.g. the block_mask tuple).
bw_seq = node.meta.get("seq_nr")

def _maybe_propagate(arg: object) -> None:
if isinstance(arg, torch.fx.Node):
if (
arg.op == "get_attr"
and arg.meta.get("seq_nr") == bw_seq
and "compile_with_inductor"
not in arg.meta.get("custom", {})
):
_copy_fwd_meta_to_node(fwd_node, arg)
elif isinstance(arg, (tuple, list)):
for item in arg:
_maybe_propagate(item)

for arg in node.args:
_maybe_propagate(arg)


def trace_module(
Expand Down
Loading