Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepspeed/compile/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

# DeepSpeed Team

from typing import List, Optional, Literal
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

PassName = Literal["z1", "z3", "autosp"]


class CompileConfig(DeepSpeedConfigModel):
""" Configure compile settings """
Expand Down Expand Up @@ -53,3 +56,6 @@ class CompileConfig(DeepSpeedConfigModel):

keep_all_input_tensors: bool = False
""" Keep real values for all input tensors in InputStorage instead of using dummy values """

passes: Optional[List[PassName]] = None
""" Composes different optimizations. """
11 changes: 11 additions & 0 deletions deepspeed/compile/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#########################################
# AUTOSP
#########################################
# Tag names used to locate model-input placeholder nodes in the captured
# graph. Presumably matched against the dynamo `tensor_dict` "tag" metadata
# on fx nodes — verify against the pass that consumes these keys.
AUTOSP_INPUT_ID_KEY = "input_id"
AUTOSP_LABEL_ID_KEY = "label_id"
AUTOSP_POSITION_ID_KEY = "position_id"
9 changes: 9 additions & 0 deletions deepspeed/compile/custom_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .all_to_all import all_to_all
# Import sp_compat so every name advertised in __all__ is actually importable;
# without this, `from deepspeed.compile.custom_ops import *` raises
# AttributeError on "sp_compat".
from . import sp_compat
from . import sp_dp_registry

__all__ = ["all_to_all", "sp_dp_registry", "sp_compat"]
92 changes: 92 additions & 0 deletions deepspeed/compile/custom_ops/all_to_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
from torch.utils._sympy.functions import FloorDiv
from .sp_dp_registry import get_group, is_setup, sp_size


@torch.library.custom_op("autosp::all_to_all", mutates_args=())
def all_to_all(
    input: torch.Tensor,
    scatter_idx: int,
    gather_idx: int,
    name: str,
) -> torch.Tensor:
    """
    All-to-all collective for SDPA tensors [B, N, S, H].

    For QKV (scatter_idx=1, gather_idx=2):
        [B, N, S/P, H] -> [B, N/P, S, H]
    For O (scatter_idx=2, gather_idx=1):
        [B, N/P, S, H] -> [B, N, S/P, H]

    Only scatter_idx is branched on; gather_idx is implied by the layout and
    is kept in the signature so the backward pass can swap the two.
    `name` is an opaque label carried through to the gradient op.
    NOTE(review): assumes N is divisible by the SP degree — confirm upstream.
    """
    assert is_setup(), 'Incorrect initialization of SP/DP mesh.'
    B, dim1, dim2, H = input.shape
    # Ranks are grouped contiguously: rank // sp_size selects this rank's SP group.
    gid = dist.get_rank() // sp_size()
    group = get_group(gid)

    if scatter_idx == 1:
        # QKV path: split heads across peers, gather the sequence dim.
        N, local_S = dim1, dim2
        # [B, N, S/P, H] -> [B, P, N/P, S/P, H], then move the peer dim first
        # so all_to_all_single exchanges contiguous per-peer chunks.
        input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H)
        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # [P, B, N/P, S/P, H] -> [B, N/P, P, S/P, H] -> [B, N/P, S, H]
        output = output.permute(1, 2, 0, 3, 4).contiguous()
        output = output.reshape(B, N // sp_size(), sp_size() * local_S, H)
    else:
        # O path: split the sequence dim across peers, gather heads.
        local_N, S = dim1, dim2
        # [B, N/P, S, H] -> [B, N/P, P, S/P, H], peer dim moved first.
        input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H)
        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()

        output = torch.empty_like(input_t)
        dist.all_to_all_single(output, input_t, group=group)

        # [P, B, N/P, S/P, H] -> [B, P, N/P, S/P, H] -> [B, N, S/P, H]
        output = output.permute(1, 0, 2, 3, 4).contiguous()
        output = output.reshape(B, sp_size() * local_N, S // sp_size(), H)

    return output


@torch.library.register_fake("autosp::all_to_all")
def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str):
    """Fake (meta) kernel: derive the output shape without any communication."""

    def maybe_restore_sharded_dim(dim: torch.SymInt, factor: int):
        # Torch 2.9 may keep `P * (s // P)` distinct from the original `s` during
        # fake shape propagation. When the local dim is exactly `FloorDiv(s, P)`,
        # restore the original symbol so downstream ops see a consistent sequence dim.
        node = getattr(dim, "node", None)
        if node is None:
            # Concrete int (static shapes): plain multiplication is exact.
            return dim * factor

        expr = node.expr
        if isinstance(expr, FloorDiv) and expr.args[1] == factor:
            # Rebuild a SymInt for the original symbol, scaling the hint to match.
            hint = node.hint * factor if node.has_hint() else None
            return node.shape_env.create_symintnode(expr.args[0], hint=hint)

        return dim * factor

    B, dim1, dim2, H = input.shape
    if scatter_idx == 1:
        # QKV direction: heads sharded, sequence gathered (symbol restored).
        return input.new_empty(B, dim1 // sp_size(), maybe_restore_sharded_dim(dim2, sp_size()), H)
    else:
        # O direction: heads gathered, sequence sharded.
        # NOTE(review): the gathered head dim is NOT symbol-restored like the
        # QKV path above — confirm this asymmetry is intentional.
        return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H)


def _all_to_all_backward_setup(ctx, inputs, output):
_, scatter_idx, gather_idx, name = inputs
ctx.scatter_idx = gather_idx
ctx.gather_idx = scatter_idx
ctx.name = name + "_grad"


def _all_to_all_backward(ctx, grad):
    """Backward: run the inverse all-to-all on the incoming gradient.

    The three trailing Nones match the non-differentiable int/str inputs.
    """
    grad_input = all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name)
    return grad_input, None, None, None


# Wire the autograd rule for the custom op: backward is itself an all-to-all
# with scatter/gather dims swapped (prepared in _all_to_all_backward_setup).
torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup)
24 changes: 24 additions & 0 deletions deepspeed/compile/custom_ops/sp_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from packaging.version import Version


def _check_autosp_compatibility():
    """Validate that the installed torch/transformers versions support AutoSP.

    Raises:
        RuntimeError: when PyTorch is older than 2.9, or transformers is
            installed and newer than 4.50.3. A missing transformers is fine.
    """
    # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort
    # above the max bound when using packaging's local-version ordering rules.
    base_version = torch.__version__.split("+")[0]
    if Version(base_version) < Version("2.9"):
        raise RuntimeError("AutoSP requires PyTorch >= 2.9, found "
                           f"{torch.__version__}.")

    try:
        import transformers
    except ImportError:
        return  # transformers not installed; skip the check
    if Version(transformers.__version__) > Version("4.50.3"):
        raise RuntimeError("AutoSP requires transformers <= 4.50.3, found "
                           f"{transformers.__version__}.")
67 changes: 67 additions & 0 deletions deepspeed/compile/custom_ops/sp_dp_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed.comm as dist

# gid (int) -> dist.ProcessGroup, plus 'SP_SIZE'/'DP_SIZE'/'is_reg' metadata keys.
GROUP_REGISTRY = {}


def register_groups(groups):
    """Create and cache one process group per rank list.

    Args:
        groups: List[List[int]] of ranks, e.g. [[0, 1], [2, 3]].
    """
    for gid, ranks in enumerate(groups):
        # Idempotent: never re-create a group for an already-registered gid.
        if gid not in GROUP_REGISTRY:
            GROUP_REGISTRY[gid] = dist.new_group(ranks)


def get_group(gid: int):
    """Return the registered group for *gid*; the world group when gid is None."""
    return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group()


def get_registry():
    """Expose the module-level registry (shared mutable state)."""
    return GROUP_REGISTRY


def is_setup():
    """Return True once populate_registry() has completed."""
    # Single dict lookup instead of the membership-test-then-index pattern.
    return GROUP_REGISTRY.get('is_reg', False)


def extract_mesh_size(param_dict):
    """Derive (sp_size, dp_size) from the DeepSpeed param dict.

    The world size must be divisible by 'sequence_parallel_size' (default 1).
    """
    sp_size = param_dict.get('sequence_parallel_size', 1)
    assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE'
    dp_size = dist.get_world_size() // sp_size

    return sp_size, dp_size


def sp_size():
    """Return the registered sequence-parallel degree."""
    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'

    return GROUP_REGISTRY['SP_SIZE']


def dp_size():
    """Return the registered data-parallel degree."""
    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'

    return GROUP_REGISTRY['DP_SIZE']


def populate_registry(SP_SIZE, DP_SIZE):
    """ Populate rank to SP/DP mesh index. """

    if GROUP_REGISTRY.get('is_reg', False):
        return

    # Contiguous rank blocks of size SP_SIZE form the SP groups:
    # e.g. SP_SIZE=2, DP_SIZE=2 -> [[0, 1], [2, 3]].
    group_listing = [[dp * SP_SIZE + i for i in range(SP_SIZE)] for dp in range(DP_SIZE)]
    register_groups(group_listing)

    ## Extraneous metadata required for proper instantiation. ##
    GROUP_REGISTRY['SP_SIZE'] = SP_SIZE
    GROUP_REGISTRY['DP_SIZE'] = DP_SIZE
    GROUP_REGISTRY['is_reg'] = True
33 changes: 31 additions & 2 deletions deepspeed/compile/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

# DeepSpeed Team

from typing import Callable, Any, List, Dict
from typing import Callable, Any, List, Dict, Optional
from collections import defaultdict

import torch
from torch.fx import Node, Graph
from torch.fx import Node, Graph, GraphModule

from .util import get_last_uses

Expand Down Expand Up @@ -138,3 +138,32 @@ def free_tensors(tensors: List[torch.Tensor]):

# Python version for debugging
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)


def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]:
    """Return the first graph node whose name equals *name*, or None."""
    return next((candidate for candidate in gm.graph.nodes if candidate.name == name), None)


def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]:
    """Return the fake/example tensor recorded on *node*, preferring meta['val'].

    Uses explicit None checks rather than `val or fallback`: `or` evaluates
    tensor truthiness, which raises for tensors with more than one element
    and would also wrongly fall through on legitimately falsy values.
    """
    val = node.meta.get("val")
    if val is not None:
        return val
    return node.meta.get("example_value")


def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]:
    """Return the first node whose dynamo 'tensor_dict' metadata carries *tag*."""
    for candidate in gm.graph.nodes:
        # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048
        meta = candidate.meta.get('tensor_dict')
        if meta and meta.get('tag') == tag:
            return candidate
    return None


def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None):
    """Rewire every user of *node* to consume *replacement* instead.

    Nodes listed in *exclude* keep their original input.
    """
    excluded = tuple(exclude) if exclude else ()
    # Snapshot the users first: replace_input_with mutates node.users as we go.
    for user in list(node.users):
        if user not in excluded:
            user.replace_input_with(node, replacement)
23 changes: 23 additions & 0 deletions deepspeed/compile/init_sp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from torch.fx import GraphModule
from .passes.sp_compile import apply_autosp
from .passes.long_context_checkpointing import register_long_context_checkpointing
from .custom_ops.sp_dp_registry import extract_mesh_size
from .custom_ops.sp_compat import _check_autosp_compatibility


def init_autosp(config):
    """Build a torch.compile backend that applies AutoSP graph rewrites.

    Validates torch/transformers versions, derives the SP/DP mesh sizes from
    the DeepSpeed config, registers long-context checkpointing, and returns
    the backend callable.
    """
    _check_autosp_compatibility()
    mesh_sp, mesh_dp = extract_mesh_size(config._param_dict)
    register_long_context_checkpointing()

    def backend_fn(gm: GraphModule, real_inputs):
        # Rewrite the captured graph for sequence parallelism, then hand it
        # to inductor for the actual compilation.
        apply_autosp(gm, real_inputs, debug=False, sp_size=mesh_sp, dp_size=mesh_dp)
        return torch._inductor.compile(gm, real_inputs)

    return backend_fn
Loading
Loading