From 1e3a836af4b92f27b34d4780facee98a57d978d8 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Sat, 14 Feb 2026 20:10:50 -0600 Subject: [PATCH 1/5] Add AutoSP to DeepSpeed --------- Signed-off-by: Neel Dani Co-authored-by: Ahan Gupta --- deepspeed/compile/config.py | 6 + deepspeed/compile/constants.py | 11 + deepspeed/compile/custom_ops/__init__.py | 9 + deepspeed/compile/custom_ops/all_to_all.py | 75 ++++++ .../compile/custom_ops/sp_dp_registry.py | 67 +++++ deepspeed/compile/fx.py | 33 ++- deepspeed/compile/init_sp.py | 21 ++ .../passes/long_context_checkpointing.py | 109 ++++++++ deepspeed/compile/passes/sp_compile.py | 243 ++++++++++++++++++ deepspeed/compile/util.py | 87 ++++++- deepspeed/runtime/engine.py | 117 ++++++--- docs/_pages/config-json.md | 28 ++ docs/code-docs/source/training.rst | 88 +++++++ 13 files changed, 850 insertions(+), 44 deletions(-) create mode 100644 deepspeed/compile/constants.py create mode 100644 deepspeed/compile/custom_ops/__init__.py create mode 100644 deepspeed/compile/custom_ops/all_to_all.py create mode 100644 deepspeed/compile/custom_ops/sp_dp_registry.py create mode 100644 deepspeed/compile/init_sp.py create mode 100644 deepspeed/compile/passes/long_context_checkpointing.py create mode 100644 deepspeed/compile/passes/sp_compile.py diff --git a/deepspeed/compile/config.py b/deepspeed/compile/config.py index 739add99271c..2137b94722f2 100644 --- a/deepspeed/compile/config.py +++ b/deepspeed/compile/config.py @@ -3,8 +3,11 @@ # DeepSpeed Team +from typing import List, Optional, Literal from deepspeed.runtime.config_utils import DeepSpeedConfigModel +PassName = Literal["z1", "z3", "autosp"] + class CompileConfig(DeepSpeedConfigModel): """ Configure compile settings """ @@ -53,3 +56,6 @@ class CompileConfig(DeepSpeedConfigModel): keep_all_input_tensors: bool = False """ Keep real values for all input tensors in InputStorage instead of using dummy values """ + + passes: Optional[List[PassName]] = None + """ Composes different 
optimizations. """ diff --git a/deepspeed/compile/constants.py b/deepspeed/compile/constants.py new file mode 100644 index 000000000000..e365b692a7d8 --- /dev/null +++ b/deepspeed/compile/constants.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######################################### +# AUTOSP +######################################### +AUTOSP_INPUT_ID_KEY = "input_id" +AUTOSP_LABEL_ID_KEY = "label_id" +AUTOSP_POSITION_ID_KEY = "position_id" diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py new file mode 100644 index 000000000000..0342f257fb5f --- /dev/null +++ b/deepspeed/compile/custom_ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .all_to_all import all_to_all +from . import sp_dp_registry + +__all__ = ["all_to_all", "sp_dp_registry"] diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py new file mode 100644 index 000000000000..dea50695c5df --- /dev/null +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed.comm as dist +from .sp_dp_registry import get_group, is_setup, sp_size + + +@torch.library.custom_op("autosp::all_to_all", mutates_args=()) +def all_to_all( + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + name: str, +) -> torch.Tensor: + """ + All-to-all collective for SDPA tensors [B, N, S, H]. + + For QKV (scatter_idx=1, gather_idx=2): + [B, N, S/P, H] -> [B, N/P, S, H] + For O (scatter_idx=2, gather_idx=1): + [B, N/P, S, H] -> [B, N, S/P, H] + """ + assert is_setup(), 'Incorrect initialization of SP/DP mesh.' 
+ B, dim1, dim2, H = input.shape + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + if scatter_idx == 1: + N, local_S = dim1, dim2 + input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) + else: + local_N, S = dim1, dim2 + input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + output = output.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) + + return output + + +@torch.library.register_fake("autosp::all_to_all") +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): + B, dim1, dim2, H = input.shape + if scatter_idx == 1: + return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) + else: + return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) + + +def _all_to_all_backward_setup(ctx, inputs, output): + _, scatter_idx, gather_idx, name = inputs + ctx.scatter_idx = gather_idx + ctx.gather_idx = scatter_idx + ctx.name = name + "_grad" + + +def _all_to_all_backward(ctx, grad): + return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None) + + +torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup) diff --git a/deepspeed/compile/custom_ops/sp_dp_registry.py b/deepspeed/compile/custom_ops/sp_dp_registry.py new file mode 100644 index 000000000000..a93707032959 --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_dp_registry.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed.comm as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + + +def register_groups(groups): + """groups: List[List[int]], e.g. [[0,1],[2,3]]""" + for gid, ranks in enumerate(groups): + if gid not in GROUP_REGISTRY: + GROUP_REGISTRY[gid] = dist.new_group(ranks) + + +def get_group(gid: int): + return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group() + + +def get_registry(): + return GROUP_REGISTRY + + +def is_setup(): + return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False + + +def extract_mesh_size(param_dict): + sp_size = param_dict.get('sequence_parallel_size', 1) + assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE' + dp_size = dist.get_world_size() // sp_size + + return sp_size, dp_size + + +def sp_size(): + assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' + + return GROUP_REGISTRY['SP_SIZE'] + + +def dp_size(): + assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' + + return GROUP_REGISTRY['DP_SIZE'] + + +def populate_registry(SP_SIZE, DP_SIZE): + """ Populate rank to SP/DP mesh index. """ + + if GROUP_REGISTRY.get('is_reg', False): + return + + group_listing = [] + offset = 0 + for _ in range(DP_SIZE): + group_listing.append([i + offset for i in range(SP_SIZE)]) + offset += SP_SIZE + + register_groups(group_listing) + + ## Extraneous metadata required for proper instatiation. 
## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index 7b3408b56afe..d794851b4532 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -3,11 +3,11 @@ # DeepSpeed Team -from typing import Callable, Any, List, Dict +from typing import Callable, Any, List, Dict, Optional from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from .util import get_last_uses @@ -138,3 +138,32 @@ def free_tensors(tensors: List[torch.Tensor]): # Python version for debugging # graph.create_node('call_function', free_tensors, args, {}, name=node_name) + + +def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]: + for node in gm.graph.nodes: + if node.name == name: + return node + return None + + +def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]: + return node.meta.get("val") or node.meta.get("example_value") + + +def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]: + input_id_node = None + for node in gm.graph.nodes: + # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048 + tensor_dict = node.meta.get('tensor_dict') + if tensor_dict and tensor_dict.get('tag') == tag: + input_id_node = node + break + return input_id_node + + +def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None): + exclude = exclude or [] + to_replace = [u for u in node.users if u not in exclude] + for user in to_replace: + user.replace_input_with(node, replacement) diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py new file mode 100644 index 000000000000..fdf2c1c499ae --- /dev/null +++ b/deepspeed/compile/init_sp.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from torch.fx import GraphModule +from .passes.sp_compile import apply_autosp +from .passes.long_context_checkpointing import register_long_context_checkpointing +from .custom_ops.sp_dp_registry import extract_mesh_size + + +def init_autosp(config): + sp_size, dp_size = extract_mesh_size(config._param_dict) + register_long_context_checkpointing() + + def backend_fn(gm: GraphModule, real_inputs): + apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size) + return torch._inductor.compile(gm, real_inputs) + + return backend_fn diff --git a/deepspeed/compile/passes/long_context_checkpointing.py b/deepspeed/compile/passes/long_context_checkpointing.py new file mode 100644 index 000000000000..3762da330df9 --- /dev/null +++ b/deepspeed/compile/passes/long_context_checkpointing.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import inspect +import textwrap +import torch._functorch.partitioners as _partitioners + +# The custom should_ban_recomputation to splice into solve_min_cut. +# All names it references (aten, operator, config, op_types, min_cut_options, +# is_materialized_backwards, get_aten_target, _size_of, fx, torch, +# CheckpointPolicy) are either module-level in torch._functorch.partitioners +# or local variables already in scope when this function executes inside +# solve_min_cut. 
+_CUSTOM_SHOULD_BAN = """\ +def should_ban_recomputation(node): + \"\"\"Sequence-aware recomputation banning logic\"\"\" + if node.op != "call_function": + return False + if node.target == operator.getitem: + return False + if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: + return True + if config.recompute_views and op_types.is_view(node): + return False + if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: + return False + + must_save_set = [ + aten.convolution, + aten.convolution_backward, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + aten._flash_attention_forward, + aten._efficient_attention_forward, + aten.upsample_bilinear2d, + aten.native_dropout, + aten.rand_like, + aten.randn_like, + ] + + if get_aten_target(node) in must_save_set: + return True + + def heuristic(node): + if "val" in node.meta: + if isinstance(node.meta["val"], torch.Tensor) and node.meta["val"].dim() >= 2: + return node.meta["val"].shape[1] >= 4096 + return False + + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): + return False + + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(node): + if heuristic(node): + return False + return True + + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return False + + if min_cut_options.ban_if_reduction: + input_tensors_size = sum( + _size_of(i) for i in node.args if isinstance(i, fx.Node) + ) + output_size = _size_of(node) + return output_size * 4 < input_tensors_size + return False +""" + + +def register_long_context_checkpointing(): + """Splice the custom should_ban_recomputation into solve_min_cut. + + Uses inspect.getsource to extract solve_min_cut's source, replaces the + original should_ban_recomputation with _CUSTOM_SHOULD_BAN, then execs the + result directly in _partitioners.__dict__. 
+ + The exec'd function's __globals__ is the real partitioners module dict, so + every other nested function (is_fusible, is_materialized_backwards, + can_fuse_into_*, etc.) and every local/closure variable (op_types, + min_cut_options, node_info, config, …) is exactly as in the original — + nothing else changes. + + Backward compatible: if solve_min_cut gains new heuristics in a future + PyTorch version the exec automatically picks them up; only + _CUSTOM_SHOULD_BAN needs to stay in sync with any changes to the + original should_ban_recomputation signature/contract. + """ + src = inspect.getsource(_partitioners.solve_min_cut) + lines = src.split('\n') + + # Locate the original should_ban_recomputation and the function after it. + start = next( + i for i, l in enumerate(lines) + if l.startswith(' def should_ban_recomputation(') + ) + end = next( + i for i, l in enumerate(lines) + if i > start and l.startswith(' def ') + ) + + # Indent the replacement to the nesting level inside solve_min_cut (4 spaces). + replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, ' ') + + new_src = '\n'.join(lines[:start]) + '\n' + replacement + '\n'.join(lines[end:]) + exec(new_src, _partitioners.__dict__) # redefines _partitioners.solve_min_cut diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py new file mode 100644 index 000000000000..5b33aa119a56 --- /dev/null +++ b/deepspeed/compile/passes/sp_compile.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from typing import Optional, List, Callable + +import torch +import deepspeed.comm as dist +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import GraphModule, Node +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +from deepspeed.compile import constants + +from ..custom_ops import all_to_all, sp_dp_registry # noqa: F401 +from ..fx import find_node_by_name, get_node_shape_meta +from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes + + +def prepare_autosp_inputs(input_id: torch.Tensor, + label_id: torch.Tensor, + position_id: torch.Tensor = None, + attention_mask: torch.Tensor = None, + seq_dim: int = 1): + """ + Prepare inputs for AutoSP by marking dynamic dimensions and tagging tensors. + + Args: + input_id: Token IDs tensor (required) + label_id: Label IDs tensor (required) + position_id: Position IDs tensor (optional) + attention_mask: Attention mask tensor (optional) + seq_dim: Sequence dimension index to mark as dynamic (default: 1) + """ + + if input_id is None: + raise ValueError("input_id is required") + if label_id is None: + raise ValueError("label_id is required") + + if seq_dim < 0 or seq_dim >= input_id.ndim: + raise ValueError(f"seq_dim {seq_dim} must be a valid index for input_id with shape {input_id.shape}") + + if position_id is not None: + if seq_dim >= position_id.ndim: + raise ValueError(f"seq_dim {seq_dim} is out of bounds for position_id with shape {position_id.shape}") + + if attention_mask is not None: + if seq_dim >= attention_mask.ndim: + raise ValueError( + f"seq_dim {seq_dim} is out of bounds for attention_mask with shape {attention_mask.shape}") + + torch._dynamo.decorators.mark_dynamic(input_id, seq_dim) + torch._dynamo.decorators.mark_dynamic(label_id, seq_dim) + if position_id is not None: + 
torch._dynamo.decorators.mark_dynamic(position_id, seq_dim) + if attention_mask is not None: + torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim) + + input_id.tag = constants.AUTOSP_INPUT_ID_KEY + label_id.tag = constants.AUTOSP_LABEL_ID_KEY + if position_id is not None: + position_id.tag = constants.AUTOSP_POSITION_ID_KEY + + return input_id, label_id, position_id, attention_mask + + +def pass_shard_seq_dim(gm: GraphModule, example_inputs): + """ + Finds all direct and indirect consumers of the input sequence, label and position ids. + Shard the sequence dimension used by all such consumers. + """ + sp_size = sp_dp_registry.sp_size() + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + assert isinstance( + seq_symint, + torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`" + + sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) + if sym_seq_dim_node is None: + print(f"WARNING: Could not find the symbolic node for the sequence dimension") + return + + with gm.graph.inserting_after(sym_seq_dim_node): + sharded_node = gm.graph.call_function(operator.floordiv, args=(sym_seq_dim_node, sp_size)) + + sharded_input_nodes = set() + label_ids_node = get_label_id_node(gm) + position_ids_node = get_position_id_node(gm) + + if input_ids_node is not None: + sharded_input_nodes.add(input_ids_node) + if label_ids_node is not None: + sharded_input_nodes.add(label_ids_node) + if position_ids_node is not None: + sharded_input_nodes.add(position_ids_node) + + # find all consumers of the sharded inputs + consumer_nodes = set() + worklist = list(sharded_input_nodes) + visited = set() + + while worklist: + node = worklist.pop(0) + if node in visited: + continue + visited.add(node) + consumer_nodes.add(node) + + for user in node.users: + if user not in visited: + worklist.append(user) + + to_replace = [] + for node in consumer_nodes: + if sym_seq_dim_node 
in node.all_input_nodes: + to_replace.append(node) + + for user in to_replace: + user.replace_input_with(sym_seq_dim_node, sharded_node) + + +def pass_shard_input_ids(gm: GraphModule, example_inputs): + input_ids_node = get_input_id_node(gm) + shard_tensor_node(gm, input_ids_node) + + +def pass_shard_label_ids(gm: GraphModule, example_inputs): + label_ids_node = get_label_id_node(gm) + shard_tensor_node(gm, label_ids_node) + + +def pass_shard_position_ids(gm: GraphModule, example_inputs): + position_ids_node = get_position_id_node(gm) + if position_ids_node is None: + print("[WARNING] position id node not found. Skipping sharding of position ids.") + return + shard_tensor_node(gm, position_ids_node) + + +def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs): + + def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node: + with gm.graph.inserting_after(node): + a2a_node = gm.graph.call_function( + torch.ops.autosp.all_to_all.default, + args=(node, scatter_idx, gather_idx, name), + ) + a2a_node.name = f"a2a_{name}" + node.replace_all_uses_with(a2a_node) + a2a_node.update_arg(0, node) + return a2a_node + + attention_nodes = get_sdpa_nodes(gm) + if len(attention_nodes) == 0: + raise RuntimeError("AutoSP currently supports torch.nn.functional.scaled_dot_product_attention as the " + "attention backend. No SDPA attention operations were found in the compiled graph. 
" + "Please ensure your model uses torch.nn.functional.scaled_dot_product_attention " + "for AutoSP to work as expected.") + + for idx, attn_node in enumerate(attention_nodes): + q, k, v = attn_node.args[:3] + suffix = f"_{idx}" if len(attention_nodes) > 1 else "" + + # QKV: [B, N, S/P, H] -> [B, N/P, S, H] + insert_a2a(q, scatter_idx=1, gather_idx=2, name=f"q{suffix}") + insert_a2a(k, scatter_idx=1, gather_idx=2, name=f"k{suffix}") + insert_a2a(v, scatter_idx=1, gather_idx=2, name=f"v{suffix}") + + # O: [B, N/P, S, H] -> [B, N, S/P, H] + insert_a2a(attn_node, scatter_idx=2, gather_idx=1, name=f"o{suffix}") + + +def pass_canonicalize(gm: GraphModule, real_inputs): + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + fake_inputs = [] + for t in real_inputs: + if isinstance(t, torch.Tensor): + fake_inputs.append(fake_mode.from_tensor(t)) + else: + fake_inputs.append(t) + FakeTensorProp(gm).propagate(*fake_inputs) + + +def apply_autosp(gm: GraphModule, + real_inputs, + debug: bool = False, + passes: Optional[List[Callable]] = None, + sp_size: int = 2, + dp_size: int = 1): + """ + Apply AutoSP (Ulysses) transformation passes to the graph and setup either DP/SP (2D) or SP (1D) mesh. 
+ + Args: + gm: GraphModule to transform + real_inputs: Example inputs for shape propagation + debug: If True, print graph before/after each pass + passes: Optional custom list of passes (default: DEFAULT_PASSES) + """ + assert sp_size * dp_size <= dist.get_world_size(), 'Insufficient device count for mesh size' + + sp_dp_registry.populate_registry(sp_size, dp_size) + + AUTOSP_PASSES = [ + pass_shard_seq_dim, + pass_shard_input_ids, + pass_shard_label_ids, + pass_shard_position_ids, + pass_insert_attention_all_to_all, + pass_propagate_shapes, + pass_canonicalize, + ] + + passes = passes or AUTOSP_PASSES + rank = dist.get_rank() + + for p in passes: + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" BEFORE: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) + + p(gm, real_inputs) + + if debug and rank == 0: + print(f"\n{'='*60}") + print(f" AFTER: {p.__name__}") + print(f"{'='*60}\n") + print(gm.print_readable(print_output=False)) diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index e8abcc2c8b3c..97b76f46b866 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -9,8 +9,9 @@ from collections import defaultdict import torch -from torch.fx import Node, Graph +from torch.fx import Node, Graph, GraphModule from torch.fx.node import map_aggregate, Argument, map_arg +import torch.nn.functional as F try: from torch._subclasses.fake_tensor import unset_fake_temporarily @@ -22,6 +23,9 @@ from deepspeed.accelerator import get_accelerator from deepspeed.utils.torch import required_torch_version from deepspeed.ops.op_builder.dc import DeepCompileBuilder +from deepspeed.compile import constants + +from .custom_ops import sp_dp_registry def is_deepcompile_supported() -> bool: @@ -521,3 +525,84 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor padded.append(out) return padded + + +def create_shard_offsets(gm: GraphModule, s0_node: Node) -> Tuple[Node, Node]: + 
sp_size: int = sp_dp_registry.sp_size() + sp_rank: int = dist.get_rank() % sp_dp_registry.sp_size() + with gm.graph.inserting_after(s0_node): + chunk_size_node = gm.graph.call_function(operator.floordiv, args=(s0_node, sp_size)) + with gm.graph.inserting_after(chunk_size_node): + start_node = gm.graph.call_function(operator.mul, args=(sp_rank, chunk_size_node)) + with gm.graph.inserting_after(start_node): + end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node)) + + return start_node, end_node + + +def get_sdpa_nodes(gm: GraphModule) -> List[Node]: + return list(gm.graph.find_nodes( + op="call_function", + target=F.scaled_dot_product_attention, + )) + + +def get_input_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_INPUT_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the input sequence.") + return node + + +def get_label_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_LABEL_ID_KEY) + if node is None: + raise RuntimeError("Failed to find a node for the label.") + return node + + +def get_position_id_node(gm: GraphModule) -> Node: + from .fx import find_node_by_tag + node = find_node_by_tag(gm, constants.AUTOSP_POSITION_ID_KEY) + return node + + +def create_symbolic_slice_indices( + gm: GraphModule, + sym_seq_dim_node: Node, +) -> Tuple[Node, Node]: + start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node) + + with gm.graph.inserting_after(end_node): + slice_all = gm.graph.call_function(slice, args=(None, None, None)) + with gm.graph.inserting_after(slice_all): + slice_range = gm.graph.call_function(slice, args=(start_node, end_node, None)) + + return slice_all, slice_range + + +def shard_tensor_node(gm: GraphModule, tensor_node: Node): + from .fx import find_node_by_name, get_node_shape_meta, replace_node_users + val = get_node_shape_meta(tensor_node) + assert val is 
not None, f"Node {tensor_node.name} has no shape metadata" + + seq_len = val.shape[1] + + assert isinstance( + seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`" + + symb_seq_int_node = find_node_by_name(gm, str(seq_len)) + assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" + + slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node) + indices = (slice_all, slice_range) + + with gm.graph.inserting_after(tensor_node): + sliced_node = gm.graph.call_function( + operator.getitem, + args=(tensor_node, indices), + ) + + replace_node_users(tensor_node, sliced_node, exclude=[sliced_node]) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a4e4608c847..87e61bb367a0 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -127,6 +127,7 @@ from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states from deepspeed.compile.init_z1 import init_z1 from deepspeed.compile.init_z3 import init_z3 +from deepspeed.compile.init_sp import init_autosp MEMORY_OPT_ALLREDUCE_SIZE = 500000000 @@ -1004,6 +1005,14 @@ def zero_sub_group_size(self): def zero_optimization_stage(self): return self._config.zero_optimization_stage + def compile_zero_optimization_stage(self): + """Determines if zero-pass is set in deepcompile's passes attributes.""" + return "z1" in self._config.compile_config.passes or "z3" in self._config.compile_config.passes + + def compile_autosp(self): + """Determines if AutoSP is set in deepcompile's passes attributes.""" + return "autosp" in (getattr(self._config.compile_config, "passes", None) or []) + def mics_shard_size(self): return self._config.mics_shard_size @@ -2373,7 +2382,7 @@ def print_forward_breakdown(self, fwd_time): def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Skip gradient reduction when DeepCompile is enabled # DeepCompile handles its own 
gradient reduction through compiled graph operations - if self.is_deepcompile_active(): + if self.is_deepcompile_active() and not self.compile_autosp(): return # Pass (PP) gas boundary flag to optimizer (required for zero) @@ -4336,6 +4345,62 @@ def empty_partition_cache(self): gc.collect() get_accelerator().empty_cache() + def get_autosp_backend(self, compile_kwargs): + if self.compile_autosp() and self.zero_optimization_stage() not in [ + ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states + ]: + logger.info( + f"Currently AutoSP does not compose with ZeRO stage 2 and 3. Falling back to the torch compiler.") + return None + + compile_config = self._config.compile_config + compile_kwargs['fullgraph'] = True + return init_autosp(self._config) + + def get_deepcompile_backend(self, backend, compile_kwargs, schedule): + if self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ + and self.zero_optimization_stage() != ZeroStageEnum.weights \ + and self.zero_optimization_stage() != ZeroStageEnum.gradients: + logger.info( + f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." 
+ ) + return None + + compile_config = self._config.compile_config + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + return init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + return init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + return init_z3(self, backend, compile_config, compile_kwargs, schedule) + return None + + def get_deepspeed_compile_backend(self, backend, compile_kwargs, schedule): + resolved_backend = None + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] + + assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
+ + if self.compile_autosp(): + resolved_backend = self.get_autosp_backend(compile_kwargs) + else: + resolved_backend = self.get_deepcompile_backend(backend, compile_kwargs, schedule) + + return resolved_backend, schedule + def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, @@ -4358,53 +4423,23 @@ def compile(self, logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") - enable_deepcompile = self.is_deepcompile_enabled() - if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \ - and self.zero_optimization_stage() != ZeroStageEnum.weights \ - and self.zero_optimization_stage() != ZeroStageEnum.gradients: - logger.info( - f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler." - ) - enable_deepcompile = False - - if enable_deepcompile: - - if schedule is not None: - - def passes_name_to_fn(passes): - for p in passes: - assert callable(p) or p in opt_passes, f"Unknown pass {p}" - return [p if callable(p) else opt_passes[p] for p in passes] - - schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] - - assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." 
- - compile_config = self._config.compile_config - if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] - and "offload_param" in self.config["zero_optimization"]) - and self._config.zero_config.offload_param.device == "cpu" - and self._config.zero_config.offload_optimizer.device == "cpu"): - compile_config.offload_parameters = True - if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) - elif self.zero_optimization_stage() == ZeroStageEnum.gradients: - backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True) - elif self.zero_optimization_stage() == ZeroStageEnum.weights: - if required_torch_version(min_version=2.9): - raise RuntimeError( - "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " - "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.") - backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) + resolved_backend = None + if self.is_deepcompile_enabled(): + resolved_backend, schedule = self.get_deepspeed_compile_backend(backend, compile_kwargs, schedule) + + is_deepspeed_compile_backend = resolved_backend is not None + + # default to torch.compiler backend if deepspeed config validation fails + backend = resolved_backend or backend # Hook state must align with whether DeepCompile is active. - self._set_deepcompile_active(enable_deepcompile) + self._set_deepcompile_active(is_deepspeed_compile_backend) # create new dict to avoid modifying original dict try: self.module.compile(**{**compile_kwargs, 'backend': backend}) except Exception: - if enable_deepcompile: + if is_deepspeed_compile_backend: # Restore default hooks if compilation fails before completing. 
self._set_deepcompile_active(False) + raise diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index d5344d3b2320..1e164d273bd0 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1896,6 +1896,34 @@ Different pruning sets, this is used for different pruning parameters. In this e | ------------------------------------------------------------- | ------- | | Use pipeline stages to parallelize the writing of checkpoints.| `false` | +### AutoSP options + +DeepSpeed provides compiler-based optimization passes through the `compile` configuration. This includes enabling Ulysses-styled sequence parallelism and a custom heuristic selective activation checkpointing pass. To enable Automatic Sequence Parallelism (AutoSP), configure the `compile` section: + +```json +{ + "zero_optimization": {"stage": 0}, + "compile": { + "deepcompile": true, + "passes": ["autosp"], + "pass_args": {"sp_size": 2} + } +} +``` + +**passes**: [array of strings] + +| Description | Default | +| ------------------------------------------------------------------------ | ------- | +| List of compiler passes to apply. Currently supported: `["autosp"]`. | `[]` | + + +**sp_size**: [integer] + +| Description | Default | +| ----------------------------------------------------------------------------------- | ------- | +| Sequence parallel degree (number of devices to shard the sequence dimension across). | `1` | + ### Data Type options ```json diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 92e3bcf80f1f..265ea7fd2bec 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -505,3 +505,91 @@ unless you provide a custom ``partition_config``. These presets are also useful when you want to extend the default patterns: set ``use_default_specs`` to ``true`` in ``partition_config`` to merge your custom specs on top of the selected preset.
+ + +Automatic Sequence Parallel Training +------------------------------------ +DeepSpeed supports **Automatic Sequence Parallel (AutoSP) training** for enabling +compiler-based sequence parallelism to unlock long-context LLM training. AutoSP +defines custom passes to automatically shard inputs along the +sequence dimension and enable Ulysses-styled sequence parallelism. + +AutoSP training is enabled by setting ``compile`` and ``passes`` in the DeepSpeed +config and calling ``prepare_autosp_inputs()`` to prepare inputs before each forward pass. + +.. code-block:: python + + import deepspeed + from deepspeed.compile.passes.sp_compile import prepare_autosp_inputs + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": {"stage": 0}, + "compile": { + "deepcompile": True, + "passes": ["autosp"], + "pass_args": {"sp_size": 2} + } + } + + engine, optimizer, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizer, + config=ds_config, + ) + + # Compile the model before training + engine.compile(backend='inductor') + + for batch in dataloader: + input_ids = prepare_autosp_inputs( + input_id=batch["input_ids"], + label_id=batch["labels"], + position_id=batch.get("position_ids"), + seq_dim=1 + ) + loss = engine(input_ids) + engine.backward(loss) + engine.step() + +.. note:: + AutoSP requires ZeRO stage 0 (no ZeRO optimization). Using AutoSP with ZeRO stages 1, 2, or 3 is not currently supported. + AutoSP also requires ``torch.nn.functional.scaled_dot_product_attention()`` as the attention backend. + +Input Preparation +~~~~~~~~~~~~~~~~~ + +Before each forward pass, inputs must be prepared using ``prepare_autosp_inputs()`` to +mark the sequence dimension as dynamic and annotate tensors for identification during +automatic sharding: + +..
code-block:: python + + from deepspeed.compile.passes.sp_compile import prepare_autosp_inputs + + input_ids = prepare_autosp_inputs( + input_id=input_ids, + label_id=labels, + position_id=position_ids, # optional + attention_mask=attention_mask, # optional + seq_dim=1 + ) + +This serves as a hint to the compiler to know which inputs should be sharded across which dimension. + +Memory Optimization +~~~~~~~~~~~~~~~~~~~ + +AutoSP includes selective activation checkpointing that recomputes matmul operations +during backpropagation while preserving attention activations. This is effective for +long-context training because attention operations scale quadratically with sequence +length and dominate computation latency, while matmul operations scale linearly and are relatively cheaper +to recompute. This provides significant memory savings with minimal computational +overhead. + +Limitations +~~~~~~~~~~~ + +AutoSP currently supports only ``torch.nn.functional.scaled_dot_product_attention``. Other attention patterns require additional pattern matching logic. + +AutoSP requires a fully connected computation graph without breaks. Graph breaks destroy the use-def chains across graphs and the compiler cannot propagate sequence dimension sharding information.
From 6f73ea282cf172cfec4d2a819aa034f7a1921e75 Mon Sep 17 00:00:00 2001 From: Ahan Gupta Date: Fri, 13 Mar 2026 20:44:30 +0000 Subject: [PATCH 2/5] Add AutoSP unit and end-to-end tests Signed-off-by: Ahan Gupta Co-authored-by: Neel Dani --- deepspeed/compile/custom_ops/__init__.py | 2 +- deepspeed/compile/custom_ops/sp_compat.py | 28 +++ deepspeed/compile/init_sp.py | 2 + .../passes/long_context_checkpointing.py | 10 +- tests/unit/v1/compile/test_compile_autosp.py | 237 ++++++++++++++++++ tests/unit/v1/compile/util.py | 207 +++++++++++++++ 6 files changed, 477 insertions(+), 9 deletions(-) create mode 100644 deepspeed/compile/custom_ops/sp_compat.py create mode 100644 tests/unit/v1/compile/test_compile_autosp.py diff --git a/deepspeed/compile/custom_ops/__init__.py b/deepspeed/compile/custom_ops/__init__.py index 0342f257fb5f..e5fc593a2e7e 100644 --- a/deepspeed/compile/custom_ops/__init__.py +++ b/deepspeed/compile/custom_ops/__init__.py @@ -6,4 +6,4 @@ from .all_to_all import all_to_all from . import sp_dp_registry -__all__ = ["all_to_all", "sp_dp_registry"] +__all__ = ["all_to_all", "sp_dp_registry", "sp_compat"] diff --git a/deepspeed/compile/custom_ops/sp_compat.py b/deepspeed/compile/custom_ops/sp_compat.py new file mode 100644 index 000000000000..a13bcf42fbf4 --- /dev/null +++ b/deepspeed/compile/custom_ops/sp_compat.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from packaging.version import Version + + +def _check_autosp_compatibility(): + # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort + # above the max bound when using packaging's local-version ordering rules. + torch_version = Version(torch.__version__.split("+")[0]) + if torch_version < Version("2.6") or torch_version >= Version("2.8"): + raise RuntimeError( + "AutoSP requires PyTorch >= 2.6 and <= 2.7, found " + f"{torch.__version__}." 
+ ) + + try: + import transformers + if Version(transformers.__version__) > Version("4.50.3"): + raise RuntimeError( + "AutoSP requires transformers <= 4.50.3, found " + f"{transformers.__version__}." + ) + except ImportError: + pass # transformers not installed; skip the check diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py index fdf2c1c499ae..23ffd28cae4c 100644 --- a/deepspeed/compile/init_sp.py +++ b/deepspeed/compile/init_sp.py @@ -8,9 +8,11 @@ from .passes.sp_compile import apply_autosp from .passes.long_context_checkpointing import register_long_context_checkpointing from .custom_ops.sp_dp_registry import extract_mesh_size +from .custom_ops.sp_compat import _check_autosp_compatibility def init_autosp(config): + _check_autosp_compatibility() sp_size, dp_size = extract_mesh_size(config._param_dict) register_long_context_checkpointing() diff --git a/deepspeed/compile/passes/long_context_checkpointing.py b/deepspeed/compile/passes/long_context_checkpointing.py index 3762da330df9..0f72d94fdf9e 100644 --- a/deepspeed/compile/passes/long_context_checkpointing.py +++ b/deepspeed/compile/passes/long_context_checkpointing.py @@ -93,14 +93,8 @@ def register_long_context_checkpointing(): lines = src.split('\n') # Locate the original should_ban_recomputation and the function after it. - start = next( - i for i, l in enumerate(lines) - if l.startswith(' def should_ban_recomputation(') - ) - end = next( - i for i, l in enumerate(lines) - if i > start and l.startswith(' def ') - ) + start = next(i for i, l in enumerate(lines) if l.startswith(' def should_ban_recomputation(')) + end = next(i for i, l in enumerate(lines) if i > start and l.startswith(' def ')) # Indent the replacement to the nesting level inside solve_min_cut (4 spaces). 
replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, ' ') diff --git a/tests/unit/v1/compile/test_compile_autosp.py b/tests/unit/v1/compile/test_compile_autosp.py new file mode 100644 index 000000000000..8277eff8928b --- /dev/null +++ b/tests/unit/v1/compile/test_compile_autosp.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F + +from deepspeed.utils.torch import required_torch_version +from deepspeed.accelerator import get_accelerator +from deepspeed.compile import constants + +from unit.v1.compile.util import compare_sp_loss, create_gm_nodes, find_sym_seq_node +from unit.common import DistributedTest +from unit.util import bf16_required_version_check, skip_on_arch + +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.6), + reason="AutoSP tests require PyTorch >= 2.6") + +# Fixed sp_size injected into mocks. 
+_SP_SIZE = 2 + + +class TestAutoSPCompile(DistributedTest): + world_size = 4 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [0, 1]) + @pytest.mark.parametrize('sp_size', [2, 4]) + def test(self, zero_stage, dtype, sp_size): + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" + ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + dp_size = self.world_size // sp_size + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": dp_size, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True, + "passes": ["autosp"] + }, + "sequence_parallel_size": sp_size, + "gradient_clipping": 1.0, + } + + if dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + compare_sp_loss(self, config_dict, sp_size) + + +# Plain pytest classes — no distributed runtime needed because these functions +# perform pure IR-level graph rewrites; sp_size and get_rank are mocked. 
+ + +class TestSDPANodesCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_sdpa_nodes + + gm, _ = create_gm_nodes(seq_len=seq_len) + sdpa_nodes = get_sdpa_nodes(gm) + + assert len(sdpa_nodes) >= 1, f"Expected at least 1 SDPA node, got {len(sdpa_nodes)}" + for node in sdpa_nodes: + assert node.target == F.scaled_dot_product_attention + + +class TestInputIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_input_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_input_id_node(gm) + + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_INPUT_ID_KEY + + +class TestLabelIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_label_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_label_id_node(gm) + + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_LABEL_ID_KEY + + +class TestPositionIdCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + from deepspeed.compile.util import get_position_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + node = get_position_id_node(gm) + + assert node is not None, "position_id node not found in graph" + assert node.op == "placeholder" + tensor_dict = node.meta.get("tensor_dict", {}) + assert tensor_dict.get("tag") == constants.AUTOSP_POSITION_ID_KEY + + +class TestShardOffsetsCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import create_shard_offsets + + gm, _ = 
create_gm_nodes(seq_len=seq_len) + sym_seq_node = find_sym_seq_node(gm) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + start_node, end_node = create_shard_offsets(gm, sym_seq_node) + + # create_shard_offsets emits: chunk = seq // sp_size; start = rank * chunk; end = start + chunk. + # Verify the three-node chain has the right operators and wiring. + chunk_size_node = start_node.args[1] # start = rank * chunk → chunk is arg[1] + + assert chunk_size_node.target == operator.floordiv + assert chunk_size_node.args[0] is sym_seq_node + assert chunk_size_node.args[1] == _SP_SIZE + + assert start_node.target == operator.mul + assert start_node.args[0] == 0 # rank 0 baked in at transform time + assert start_node.args[1] is chunk_size_node + + assert end_node.target == operator.add + assert end_node.args[0] is start_node + assert end_node.args[1] is chunk_size_node + + +class TestSymSliceCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import create_symbolic_slice_indices + + gm, _ = create_gm_nodes(seq_len=seq_len) + sym_seq_node = find_sym_seq_node(gm) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + slice_all, slice_range = create_symbolic_slice_indices(gm, sym_seq_node) + + # slice_all = slice(None, None, None) — selects the batch dimension unchanged + assert slice_all.target == slice + assert slice_all.args == (None, None, None) + + # slice_range selects [start, end) along the sequence dim, where start and + # end come from create_shard_offsets (mul and add nodes respectively). 
+ assert slice_range.target == slice + start_arg, end_arg, step_arg = slice_range.args + assert step_arg is None + + # start = rank * chunk → verify the full shard-offset wiring + chunk_size_node = start_arg.args[1] + assert start_arg.target == operator.mul + assert start_arg.args[0] == 0 # rank 0 baked in at transform time + assert chunk_size_node.target == operator.floordiv + assert chunk_size_node.args[0] is sym_seq_node + assert chunk_size_node.args[1] == _SP_SIZE + + # end = start + chunk + assert end_arg.target == operator.add + assert end_arg.args[0] is start_arg + assert end_arg.args[1] is chunk_size_node + + +class TestShardTensorCompile: + + @pytest.mark.parametrize('seq_len', [64, 128, 256]) + def test(self, seq_len): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.util import shard_tensor_node, get_input_id_node + + gm, _ = create_gm_nodes(seq_len=seq_len) + input_ids_node = get_input_id_node(gm) + original_users = set(input_ids_node.users.keys()) + assert len(original_users) > 0, "input_ids_node must have users before sharding" + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + shard_tensor_node(gm, input_ids_node) + + getitem_nodes = [n for n in gm.graph.nodes if n.target == operator.getitem and n.args[0] is input_ids_node] + assert len(getitem_nodes) == 1, f"Expected 1 slice node after sharding, got {len(getitem_nodes)}" + sliced_node = getitem_nodes[0] + + # After sharding, the raw node must only feed the slice; all downstream + # consumers are rewired to sliced_node by replace_node_users. 
+ assert set(input_ids_node.users.keys()) == {sliced_node} + + for user in original_users: + assert input_ids_node not in user.all_input_nodes, \ + f"User '{user.name}' still references the unsharded input_ids_node" + assert sliced_node in user.all_input_nodes, \ + f"User '{user.name}' does not reference the sliced node" diff --git a/tests/unit/v1/compile/util.py b/tests/unit/v1/compile/util.py index 40bc8cfe4ba1..2e4e45bd2b31 100644 --- a/tests/unit/v1/compile/util.py +++ b/tests/unit/v1/compile/util.py @@ -11,6 +11,7 @@ import torch import deepspeed +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero import GatheredParameters @@ -87,3 +88,209 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None): baseline_engine.destroy() target_engine.destroy() + + +def compare_sp_loss(self, config, sp_size, iterations=3): + """ + Compare AutoSP compiled model loss against a compiled Ulysses SP model (ground truth). + + Both engines are trained in lockstep. After all training steps the final-step + losses are compared. 
+ """ + import torch.nn.functional as F + from transformers import AutoModelForCausalLM, AutoConfig + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + from deepspeed.compile import constants as autosp_constants + from deepspeed.compile.custom_ops.sp_dp_registry import populate_registry, get_group + from deepspeed.sequence.layer import DistributedAttention + + RTOL, ATOL = 0.1, 0.01 + model_name = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + seq_length = 64 + + torch.manual_seed(42) + get_accelerator().manual_seed_all(42) + device = torch.device(get_accelerator().current_device_name()) + + model_config = AutoConfig.from_pretrained(model_name) + model_config._attn_implementation = "sdpa" + base_model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + vocab_size = model_config.vocab_size + + # Set up SP/DP process groups (shared by both Ulysses and AutoSP). + dp_size = dist.get_world_size() // sp_size + populate_registry(sp_size, dp_size) + # The DP-rank index selects which SP group the current rank belongs to. + sp_group = get_group(dist.get_rank() // sp_size) + sp_rank = dist.get_rank() % sp_size + chunk = seq_length // sp_size + + # Build a DistributedAttention wrapper that mirrors distributed_attention.py. + # Registered under a unique key so the model's "sdpa" slot stays untouched — + # AutoSP's graph pass can therefore find F.scaled_dot_product_attention nodes. + def _sdpa_inner(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None): + # DistributedAttention delivers tensors in [b, s, n, h]; SDPA wants [b, n, s, h]. 
+ out = F.scaled_dot_product_attention(q.permute(0, 2, 1, 3), + k.permute(0, 2, 1, 3), + v.permute(0, 2, 1, 3), + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale) + return out.permute(0, 2, 1, 3) + + _dist_attn = DistributedAttention(_sdpa_inner, sp_group, scatter_idx=2, gather_idx=1) + + def _ulysses_attn_forward(module, + query_states, + key_states, + value_states, + attention_mask, + scaling=None, + dropout=0.0, + is_causal=False, + **kwargs): + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + out = _dist_attn(q, k, v, batch_dim_idx=0, dropout_p=dropout, is_causal=is_causal, scale=scaling) + return out, None + + ALL_ATTENTION_FUNCTIONS["ulyssess"] = _ulysses_attn_forward + + # Ulysses baseline: regular torch.compile, no deepcompile or autosp pass. + ulysses_config = deepcopy(config) + ulysses_config.pop("compile", None) + ulysses_model = deepcopy(base_model) + ulysses_model.config._attn_implementation = "ulyssess" + ulysses_engine, _, _, _ = deepspeed.initialize(config=ulysses_config, + model=ulysses_model, + model_parameters=ulysses_model.parameters()) + ulysses_engine.compile() + + # AutoSP model: sdpa so the autosp pass can find F.scaled_dot_product_attention. + # dynamic=True ensures all shape dimensions are treated symbolically so the autosp + # pass can correctly shard the sequence dimension for all dtypes including fp16/bf16. + autosp_model = deepcopy(base_model) + autosp_engine, _, _, _ = deepspeed.initialize(config=config, + model=autosp_model, + model_parameters=autosp_model.parameters()) + autosp_engine.compile(compile_kwargs={"dynamic": True}) + + # Train both engines in lockstep; compare the losses at the final step. + ul_loss = autosp_loss = None + for i in range(iterations): + torch.manual_seed(42 + i) + full_ids = torch.randint(0, vocab_size, (1, seq_length), device=device) + + # Ulysses: each rank processes its own shard. 
+ shard_ids = full_ids[:, sp_rank * chunk:(sp_rank + 1) * chunk] + shard_pos = torch.arange(sp_rank * chunk, (sp_rank + 1) * chunk, device=device).unsqueeze(0) + shard_mask = torch.ones(1, chunk, device=device, dtype=torch.long) + ul_out = ulysses_engine(input_ids=shard_ids, + labels=shard_ids, + position_ids=shard_pos, + attention_mask=shard_mask) + # Average per-shard losses across SP ranks to get the full-sequence loss. + ul_loss = ul_out.loss.clone() + dist.all_reduce(ul_loss, group=sp_group) + ul_loss = ul_loss / sp_size + + # AutoSP: full sequence. dynamic=True makes all shapes symbolic, so mark_dynamic + # is not needed; only the tag attributes that the autosp pass uses are set here. + autosp_ids = full_ids.clone() + autosp_lbl = autosp_ids.clone() + autosp_pos = torch.arange(seq_length, device=device).unsqueeze(0) + autosp_msk = torch.ones(1, seq_length, device=device, dtype=torch.long) + autosp_ids.tag = autosp_constants.AUTOSP_INPUT_ID_KEY + autosp_lbl.tag = autosp_constants.AUTOSP_LABEL_ID_KEY + autosp_pos.tag = autosp_constants.AUTOSP_POSITION_ID_KEY + autosp_out = autosp_engine(input_ids=autosp_ids, + labels=autosp_lbl, + position_ids=autosp_pos, + attention_mask=autosp_msk) + autosp_loss = autosp_out.loss + + ulysses_engine.backward(ul_out.loss) + ulysses_engine.step() + autosp_engine.backward(autosp_loss) + autosp_engine.step() + + allclose_on_all_ranks(autosp_loss, ul_loss, "AutoSP and Ulysses losses are not close.", rtol=RTOL, atol=ATOL) + + ulysses_engine.destroy() + del ALL_ATTENTION_FUNCTIONS["ulyssess"] + autosp_engine.destroy() + + +def create_gm_nodes(batch_size: int = 1, seq_len: int = 16): + """ + Load a tiny LlamaForCausalLM, tag inputs with AutoSP keys, mark the sequence + dimension dynamic, and capture the torch-fx GraphModule via a custom + torch.compile backend. 
+ + The returned gm is identical to what the autosp pass receives during training: + placeholder nodes carry tensor_dict tags and meta['val'] shapes are symbolic + (SymInt) in the sequence dimension. + + Returns: + gm – GraphModule with fully populated node metadata + inputs – (input_ids, labels, position_ids) used for tracing + """ + from transformers import AutoModelForCausalLM, AutoConfig + from deepspeed.compile import constants + + # Each call needs a clean dynamo state; without this, the recompile_limit + # (default 8) is exhausted across tests and the backend is never invoked. + torch._dynamo.reset() + + model_name = 'hf-internal-testing/tiny-random-LlamaForCausalLM' + model_config = AutoConfig.from_pretrained(model_name) + model_config._attn_implementation = "sdpa" + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config) + model.eval() + + vocab_size = model_config.vocab_size + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + position_ids = torch.arange(seq_len).unsqueeze(0) + + # dynamo propagates Python tensor attributes into node.meta['tensor_dict']; + # find_node_by_tag relies on this to identify the AutoSP input nodes. + input_ids.tag = constants.AUTOSP_INPUT_ID_KEY + labels.tag = constants.AUTOSP_LABEL_ID_KEY + position_ids.tag = constants.AUTOSP_POSITION_ID_KEY + + # Marking the sequence dim dynamic causes dynamo to emit a SymInt placeholder + # node and store symbolic shapes in node.meta['val'], which shard_tensor_node + # needs to locate the sequence-length symbol in the graph. 
+ torch._dynamo.decorators.mark_dynamic(input_ids, 1) + torch._dynamo.decorators.mark_dynamic(labels, 1) + torch._dynamo.decorators.mark_dynamic(position_ids, 1) + + captured_gm = [None] + + def _capture_backend(gm, example_inputs): + if captured_gm[0] is None: + captured_gm[0] = gm + return gm + + compiled = torch.compile(model, backend=_capture_backend, dynamic=True) + with torch.no_grad(): + compiled(input_ids=input_ids, labels=labels, position_ids=position_ids) + + assert captured_gm[0] is not None, "Capture backend was never invoked — graph capture failed" + return captured_gm[0], (input_ids, labels, position_ids) + + +def find_sym_seq_node(gm): + """ + Return the SymInt placeholder node for the sequence-length dimension of + input_ids, or None if it cannot be found. + """ + from deepspeed.compile.util import get_input_id_node + from deepspeed.compile.fx import get_node_shape_meta, find_node_by_name + + input_ids_node = get_input_id_node(gm) + val = get_node_shape_meta(input_ids_node) + seq_symint = val.shape[1] + return find_node_by_name(gm, str(seq_symint)) From 6ba4117d0f839c2f36791890605dad7b98f1906f Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:34:15 +0900 Subject: [PATCH 3/5] AutoSP: fix torch 2.9 fake propagation issues (#2) * Fix AutoSP shape propagation fake mode reuse * Fix AutoSP torch 2.9 fake propagation * Fix AutoSP shard slice ordering * Add comments for AutoSP torch 2.9 fixes * Change AutoSP PyTorch requirement to 2.9+ --------- Signed-off-by: Masahiro Tanaka Signed-off-by: Ahan Gupta --- deepspeed/compile/custom_ops/all_to_all.py | 19 +++++++- deepspeed/compile/custom_ops/sp_compat.py | 4 +- deepspeed/compile/passes/sp_compile.py | 47 +++++++++++++++++--- deepspeed/compile/util.py | 10 ++++- tests/unit/v1/compile/test_compile_autosp.py | 47 ++++++++++++++++++++ 5 files changed, 117 insertions(+), 10 deletions(-) diff --git a/deepspeed/compile/custom_ops/all_to_all.py 
b/deepspeed/compile/custom_ops/all_to_all.py index dea50695c5df..3307bbc527ff 100644 --- a/deepspeed/compile/custom_ops/all_to_all.py +++ b/deepspeed/compile/custom_ops/all_to_all.py @@ -5,6 +5,7 @@ import torch import deepspeed.comm as dist +from torch.utils._sympy.functions import FloorDiv from .sp_dp_registry import get_group, is_setup, sp_size @@ -54,9 +55,25 @@ def all_to_all( @torch.library.register_fake("autosp::all_to_all") def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): + + def maybe_restore_sharded_dim(dim: torch.SymInt, factor: int): + # Torch 2.9 may keep `P * (s // P)` distinct from the original `s` during + # fake shape propagation. When the local dim is exactly `FloorDiv(s, P)`, + # restore the original symbol so downstream ops see a consistent sequence dim. + node = getattr(dim, "node", None) + if node is None: + return dim * factor + + expr = node.expr + if isinstance(expr, FloorDiv) and expr.args[1] == factor: + hint = node.hint * factor if node.has_hint() else None + return node.shape_env.create_symintnode(expr.args[0], hint=hint) + + return dim * factor + B, dim1, dim2, H = input.shape if scatter_idx == 1: - return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) + return input.new_empty(B, dim1 // sp_size(), maybe_restore_sharded_dim(dim2, sp_size()), H) else: return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) diff --git a/deepspeed/compile/custom_ops/sp_compat.py b/deepspeed/compile/custom_ops/sp_compat.py index a13bcf42fbf4..c8b25313e9be 100644 --- a/deepspeed/compile/custom_ops/sp_compat.py +++ b/deepspeed/compile/custom_ops/sp_compat.py @@ -11,9 +11,9 @@ def _check_autosp_compatibility(): # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort # above the max bound when using packaging's local-version ordering rules. 
torch_version = Version(torch.__version__.split("+")[0]) - if torch_version < Version("2.6") or torch_version >= Version("2.8"): + if torch_version < Version("2.9"): raise RuntimeError( - "AutoSP requires PyTorch >= 2.6 and <= 2.7, found " + "AutoSP requires PyTorch >= 2.9, found " f"{torch.__version__}." ) diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py index 5b33aa119a56..ab2b3fb9fa33 100644 --- a/deepspeed/compile/passes/sp_compile.py +++ b/deepspeed/compile/passes/sp_compile.py @@ -8,7 +8,7 @@ import torch import deepspeed.comm as dist -from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.fake_tensor import FakeTensorMode, maybe_get_fake_mode from torch.fx import GraphModule, Node from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.experimental.symbolic_shapes import ShapeEnv @@ -80,7 +80,7 @@ def pass_shard_seq_dim(gm: GraphModule, example_inputs): seq_symint = val.shape[1] assert isinstance( seq_symint, - torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`" + torch.SymInt), f"expected sequence dimension to be of type {torch.SymInt!r} but found {type(seq_symint)!r}" sym_seq_dim_node = find_node_by_name(gm, str(seq_symint)) if sym_seq_dim_node is None: @@ -184,15 +184,52 @@ def pass_canonicalize(gm: GraphModule, real_inputs): def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs): - shape_env = ShapeEnv() - fake_mode = FakeTensorMode(shape_env=shape_env) + fake_mode = None + for node in gm.graph.nodes: + # Reuse the graph's existing fake mode when metadata is already present. + # Its ShapeEnv owns the symbolic dims captured during tracing, so using a + # fresh mode here can desynchronize fake inputs from graph metadata. 
+ if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_mode = maybe_get_fake_mode(fake_val) + elif fake_mode is None: + fake_val = node.meta.get("example_value", node.meta.get("val")) + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_mode = maybe_get_fake_mode(fake_val) + if fake_mode is not None: + break + + if fake_mode is None: + # Some graphs do not carry fake tensor metadata yet; create a fallback + # mode so FakeTensorProp can still run shape-only execution. + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + fake_inputs = [] for t in real_inputs: if isinstance(t, torch.Tensor): fake_inputs.append(fake_mode.from_tensor(t)) else: fake_inputs.append(t) - FakeTensorProp(gm).propagate(*fake_inputs) + + # Torch 2.9 can fail fake propagation through SDPA's masked fake-CUDA path, + # even though this pass only needs output metadata. Temporarily clear + # attn_mask so shape propagation can proceed, then restore it immediately; + # SDPA output shapes are still determined by Q/K/V shapes, not mask values. + saved_sdpa_masks = [] + for attn_node in get_sdpa_nodes(gm): + attn_mask = attn_node.kwargs.get("attn_mask") + if attn_mask is not None: + saved_sdpa_masks.append((attn_node, attn_mask)) + attn_node.update_kwarg("attn_mask", None) + + try: + # fake_inputs are already created under fake_mode above, so run + # propagation without reconverting them into a different fake mode. 
+ FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs) + finally: + for attn_node, attn_mask in saved_sdpa_masks: + attn_node.update_kwarg("attn_mask", attn_mask) def apply_autosp(gm: GraphModule, diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 97b76f46b866..5982f7477f52 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -591,7 +591,8 @@ def shard_tensor_node(gm: GraphModule, tensor_node: Node): seq_len = val.shape[1] assert isinstance( - seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`" + seq_len, + torch.SymInt), (f"Expected sequence dimension to be {torch.SymInt!r} but instead found {type(seq_len)!r}") symb_seq_int_node = find_node_by_name(gm, str(seq_len)) assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}" @@ -599,7 +600,12 @@ def shard_tensor_node(gm: GraphModule, tensor_node: Node): slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node) indices = (slice_all, slice_range) - with gm.graph.inserting_after(tensor_node): + positions = {node: i for i, node in enumerate(gm.graph.nodes)} + # Insert after the later dependency so the new getitem does not appear + # before the symbolic slice nodes in graph order. Torch 2.9 bf16 can place + # the SymInt placeholder after the tensor placeholder. 
+ anchor_node = slice_range if positions[slice_range] > positions[tensor_node] else tensor_node + with gm.graph.inserting_after(anchor_node): sliced_node = gm.graph.call_function( operator.getitem, args=(tensor_node, indices), diff --git a/tests/unit/v1/compile/test_compile_autosp.py b/tests/unit/v1/compile/test_compile_autosp.py index 8277eff8928b..12a5dbdb68b4 100644 --- a/tests/unit/v1/compile/test_compile_autosp.py +++ b/tests/unit/v1/compile/test_compile_autosp.py @@ -9,6 +9,7 @@ import pytest import torch import torch.nn.functional as F +from torch.fx import Graph, GraphModule from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator @@ -235,3 +236,49 @@ def test(self, seq_len): f"User '{user.name}' still references the unsharded input_ids_node" assert sliced_node in user.all_input_nodes, \ f"User '{user.name}' does not reference the sliced node" + + def test_preserves_topological_order_when_sym_placeholder_follows_input(self): + import deepspeed.comm as _dist + from deepspeed.compile.custom_ops import sp_dp_registry as _registry + from deepspeed.compile.fx import find_node_by_name, get_node_shape_meta + from deepspeed.compile.util import shard_tensor_node, get_input_id_node + + # Regression test for the torch 2.9 bf16 trace where the SymInt + # placeholder can appear after input_ids. shard_tensor_node must still + # produce a lint-clean graph instead of inserting getitem before its + # symbolic slice dependencies. 
+ gm, _ = create_gm_nodes(seq_len=64) + input_ids_node = get_input_id_node(gm) + seq_symint = get_node_shape_meta(input_ids_node).shape[1] + sym_seq_node = find_node_by_name(gm, str(seq_symint)) + assert sym_seq_node is not None, "Symbolic sequence-length node not found in graph" + + nodes = list(gm.graph.nodes) + input_idx = nodes.index(input_ids_node) + sym_idx = nodes.index(sym_seq_node) + assert sym_idx < input_idx, "Expected source graph to place the symbolic placeholder before input_ids" + + # Reorder placeholders to mirror the torch 2.9 bf16 trace where the symbolic + # sequence placeholder can appear after input_ids. + reordered_nodes = nodes[:] + reordered_nodes.pop(input_idx) + reordered_nodes.insert(sym_idx, input_ids_node) + reordered_nodes.pop(sym_idx + 1) + reordered_nodes.insert(input_idx, sym_seq_node) + + reordered_graph = Graph() + env = {} + for node in reordered_nodes: + new_node = reordered_graph.node_copy(node, lambda n: env[n]) + new_node.meta = node.meta.copy() + env[node] = new_node + reordered_graph.lint() + + reordered_gm = GraphModule(gm, reordered_graph) + reordered_input_ids = get_input_id_node(reordered_gm) + + with patch.object(_registry, 'sp_size', return_value=_SP_SIZE), \ + patch.object(_dist, 'get_rank', return_value=0): + shard_tensor_node(reordered_gm, reordered_input_ids) + + reordered_gm.graph.lint() From cd27fa12fed657d7a9ac4ae46cca444ec6319885 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 23 Mar 2026 00:24:23 -0500 Subject: [PATCH 4/5] update docs Signed-off-by: Neel Dani --- deepspeed/compile/custom_ops/sp_compat.py | 12 ++++-------- docs/_pages/config-json.md | 8 -------- docs/code-docs/source/training.rst | 1 - 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/deepspeed/compile/custom_ops/sp_compat.py b/deepspeed/compile/custom_ops/sp_compat.py index c8b25313e9be..136d01edce0b 100644 --- a/deepspeed/compile/custom_ops/sp_compat.py +++ b/deepspeed/compile/custom_ops/sp_compat.py @@ -12,17 +12,13 @@ def 
_check_autosp_compatibility(): # above the max bound when using packaging's local-version ordering rules. torch_version = Version(torch.__version__.split("+")[0]) if torch_version < Version("2.9"): - raise RuntimeError( - "AutoSP requires PyTorch >= 2.9, found " - f"{torch.__version__}." - ) + raise RuntimeError("AutoSP requires PyTorch >= 2.9, found " + f"{torch.__version__}.") try: import transformers if Version(transformers.__version__) > Version("4.50.3"): - raise RuntimeError( - "AutoSP requires transformers <= 4.50.3, found " - f"{transformers.__version__}." - ) + raise RuntimeError("AutoSP requires transformers <= 4.50.3, found " + f"{transformers.__version__}.") except ImportError: pass # transformers not installed; skip the check diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 1e164d273bd0..32c454925b41 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1906,7 +1906,6 @@ DeepSpeed provides compiler-based optimization passes through the `compile` conf "compile": { "deepcompile": true, - "passes": ["autosp"], - "pass_args": {"sp_size": 2} + "passes": ["autosp"] } } ``` @@ -1917,13 +1916,6 @@ DeepSpeed provides compiler-based optimization passes through the `compile` conf | ------------------------------------------------------------------------ | ------- | | List of compiler passes to apply. Currently supported: `["autosp"]`. | `[]` | - -**sp_size**: [integer] - -| Description | Default | - | ----------------------------------------------------------------------------------- | ------- | -| Sequence parallel degree (number of devices to shard the sequence dimension across). 
| `1` | - ### Data Type options ```json diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 265ea7fd2bec..5bb1e87d2bd2 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -528,7 +528,6 @@ config and calling ``prepare_autosp_inputs()`` to prepare inputs before each for "compile": { "deepcompile": True, "passes": ["autosp"], - "pass_args": {"sp_size": 2} } } From 4ef94195dfae258945cd1ce053a48aef69f510d4 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Thu, 26 Mar 2026 22:52:03 -0500 Subject: [PATCH 5/5] Update torch required version Signed-off-by: Neel Dani --- tests/unit/v1/compile/test_compile_autosp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/v1/compile/test_compile_autosp.py b/tests/unit/v1/compile/test_compile_autosp.py index 12a5dbdb68b4..fcdb7fdefc88 100644 --- a/tests/unit/v1/compile/test_compile_autosp.py +++ b/tests/unit/v1/compile/test_compile_autosp.py @@ -19,8 +19,8 @@ from unit.common import DistributedTest from unit.util import bf16_required_version_check, skip_on_arch -pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.6), - reason="AutoSP tests require PyTorch >= 2.6") +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.9), + reason="AutoSP tests require PyTorch >= 2.9") # Fixed sp_size injected into mocks. _SP_SIZE = 2