diff --git a/ci/install_test_pkgs.sh b/ci/install_test_pkgs.sh index 8fbbf0c2..14c59b12 100644 --- a/ci/install_test_pkgs.sh +++ b/ci/install_test_pkgs.sh @@ -5,3 +5,4 @@ python3 -m pip install black==22.10.0 python3 -m pip install transformers==4.25.1 --no-deps python3 -m pip install pylint==2.14.0 astroid==2.11.6 mock==4.0.3 +python3 -m pip install z3-solver tabulate diff --git a/setup.py b/setup.py index e5a401ea..b99ff40c 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,8 @@ def setup(): long_description_content_type="text/markdown", setup_requires=[], install_requires=[ + "z3-solver", + "tabulate", "packaging", "psutil", ], diff --git a/slapo/sharding/__init__.py b/slapo/sharding/__init__.py index 051919f5..4c310cf5 100644 --- a/slapo/sharding/__init__.py +++ b/slapo/sharding/__init__.py @@ -10,3 +10,4 @@ scatter_forward_output, reduce_forward_output, ) +from .solver import Solver diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py new file mode 100644 index 00000000..c7d04d01 --- /dev/null +++ b/slapo/sharding/solver.py @@ -0,0 +1,706 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Auto-parallelism solver that finds the optimal sharding scheme for a given model. +It models the problem as a program synthesis problem and uses Z3 to solve it. +""" + +import operator +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F +from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata +import z3 +from tabulate import tabulate + +from ..logger import get_logger + +logger = get_logger(__name__) + + +class ShardSpec: + def __init__(self, spec): + """ + R: replicated + S: sharded + """ + self.map = {"RR": 0, "RS": 1, "SR": 2} + if isinstance(spec, str): + self.spec = spec + else: + self.spec = list(self.map.keys())[list(self.map.values()).index(spec)] + + @property + def id(self): + return self.map[self.spec] + + def __str__(self): + return self.spec + + +class FxOp: + def __init__(self, node): + self.node = node + self.name = node.name + self.args = [] + self.users = [] + self.out_shape = node.meta["tensor_meta"].shape + self.out_size = int(torch.prod(torch.tensor(self.out_shape))) + self.z3_inputs = [] + + def add_arg(self, arg): + self.args.append(arg) + + def add_user(self, user): + self.users.append(user) + + def generate_input_z3(self): + raise NotImplementedError + + def generate_output(self, mod): + output = self.generate_output_z3() + if isinstance(output, int): + return output + return mod.evaluate(output).as_long() + + def generate_output_z3(self): + raise NotImplementedError + + def calculate_comm_cost(self, mod): + cost = self.calculate_comm_cost_z3() + if isinstance(cost, int): + return cost + return mod.evaluate(cost).as_long() + + def calculate_comm_cost_z3(self): + raise NotImplementedError + + +class PlaceholderOp(FxOp): + def generate_input_z3(self): + # input should not be sharded + return [], [] + + def generate_output_z3(self): + return ShardSpec("RR").id + + def calculate_comm_cost_z3(self): + return 0 + + +class ElementwiseOp(FxOp): + def generate_input_z3(self): + return [], [] + + def generate_output_z3(self): + return self.args[0].generate_output_z3() + + def calculate_comm_cost_z3(self): + return 0 + + +class BinaryOp(FxOp): + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + self.z3_inputs.append(z3.BitVec(f"{self.name}_1", 2)) + compute_constraints = [self.z3_inputs[0] == self.z3_inputs[1]] + format_constraints = [ + z3.ULE(self.z3_inputs[0], 3), + z3.ULE(self.z3_inputs[1], 3), + ] + constraints = compute_constraints + format_constraints + return self.z3_inputs, constraints + + def generate_output_z3(self): + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class LayerNormOp(FxOp): + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + # Reduction across the last dimension, so `RS` is prohibited. + format_constraints += [self.z3_inputs[0] != 1] + return self.z3_inputs, format_constraints + + def generate_output_z3(self): + # The same spec as the input + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + # No communication cost + return 0 + + +class SoftmaxOp(FxOp): + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + # Reduction across the last dimension, so `RS` is prohibited. + format_constraints += [self.z3_inputs[0] != 1] + return self.z3_inputs, format_constraints + + def generate_output_z3(self): + # The same spec as the input + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + # No communication cost + return 0 + + +class ViewOp(FxOp): + """ + # TODO: 1. Verify the behavior of general view function + # 2. Support merging two dimensions + Only certain view functions can be sharded without communication. + Currently only reshaping the *last* dimension is supported. + + Consider the view function in Transformer: + (bs,seq,d) -> (bs,seq,h,d//h) + + We have the following communication matrix: + (The src spec only considers the last dimension) + src\dst RR RS SR + R 0 0 0 + S 1/p 1/p 0 + + S->RR requires an all-gather to retrieve all the data. + S->RS also requires an all-gather before the view function to ensure the data + is correct. To illustrate this case, consider h=2, p=2, d=8, d//h=4, and + the original data is [0 1 2 3 4 5 6 7], and the expected result is + [[0 1 2 3], [4 5 6 7]]. If we want to shard it into RS spec on two devices, + the data should be as follows: + Device 1 | Device 2 + 0 1 | 2 3 + 4 5 | 6 7 + But if we directly view from the source sharded spec shown below, + Device 1 | Device 2 + 0 1 2 3 | 4 5 6 7 + and reshape it to (h,d//h//p), we get + Device 1 | Device 2 + 0 1 | 4 5 + 2 3 | 6 7 + Thus, the data is incorrect. + To avoid this, we need to all-gather the data first, and then reshape it. + """ + + def __init__(self, node, z3_graph, p): + super().__init__(node) + self.z3_graph = z3_graph + self.num_devices = p + self.prev_op = self.z3_graph[self.node.args[0].name] + + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + # TODO: Need to consider when we support higher dimensions + # compute_constraints = [ + # z3.Implies( + # self.prev_op.generate_output_z3() == ShardSpec("SR").id, + # self.z3_inputs[0] == ShardSpec("RR").id, + # ), + # ] + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + return self.z3_inputs, format_constraints + + def generate_output_z3(self): + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + result = z3.If( + z3.And( + self.prev_op.generate_output_z3() == ShardSpec("RS").id, + self.z3_inputs[0] == ShardSpec("RR").id, + ), + self.out_size * 1 / self.num_devices, + z3.If( + z3.And( + self.prev_op.generate_output_z3() == ShardSpec("RS").id, + self.z3_inputs[0] == ShardSpec("RS").id, + ), + self.out_size * 1 / self.num_devices, + 0, + ), + ) + return result + + +class PermuteOp(FxOp): + def __init__(self, node, z3_graph): + super().__init__(node) + self.z3_graph = z3_graph + permute_idx = list(node.args[1:]) + self.output_map = {} + for in_spec in ("RR", "RS", "SR"): + spec = "R" * (len(permute_idx) - 2) + in_spec + out_spec = spec[-2:] + self.output_map[in_spec] = out_spec + self.prev_op = self.z3_graph[self.node.args[0].name] + + def generate_input_z3(self): + return [], [] + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If( + self.prev_op.generate_output_z3() == ShardSpec(inp).id, + ShardSpec(out).id, + result, + ) + return result + + def calculate_comm_cost_z3(self): + # permutation does not involve communication + return 0 + + +class TransposeOp(FxOp): + def __init__(self, node, z3_graph): + # FIXME: Suppose always transpose the last two dims + super().__init__(node) + self.z3_graph = z3_graph + self.output_map = {"RR": "RR", "RS": "SR", "SR": "RS"} + self.prev_op = self.z3_graph[self.node.args[0].name] + + def generate_input_z3(self): + return [], [] + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If( + self.prev_op.generate_output_z3() == ShardSpec(inp).id, + ShardSpec(out).id, + result, + ) + return result + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class MatmulOp(FxOp): + def __init__(self, node): + super().__init__(node) + self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"} + self.comm_cost_map = { # map from input spec to comm cost + "RR": 0, + "RS": self.out_size, # all_reduce + "SR": 0, + } + + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) # input + self.z3_inputs.append(z3.BitVec(f"{self.name}_1", 2)) # weight + + compute_constraints = [ + z3.Or( + [ + z3.And( + self.z3_inputs[0] == ShardSpec("RR").id, + self.z3_inputs[1] == ShardSpec("RS").id, + ), + z3.And( + self.z3_inputs[0] == ShardSpec("RS").id, + self.z3_inputs[1] == ShardSpec("SR").id, + ), + z3.And( + self.z3_inputs[0] == ShardSpec("SR").id, + self.z3_inputs[1] == ShardSpec("RR").id, + ), + ] + ) + ] + format_constraints = [ + z3.ULE(self.z3_inputs[0], 3), + z3.ULE(self.z3_inputs[1], 3), + ] + constraints = compute_constraints + format_constraints + # force to shard + # constraints += [self.z3_inputs[0] != ShardSpec("RR").id, self.z3_inputs[1] != ShardSpec("RR").id] + return self.z3_inputs, constraints + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If( + self.z3_inputs[0] == ShardSpec(inp).id, ShardSpec(out).id, result + ) + return result + + def calculate_comm_cost_z3(self): + result = 1e12 # invalid + for inp, cost in self.comm_cost_map.items(): + result = z3.If(self.z3_inputs[0] == ShardSpec(inp).id, cost, result) + return result + + +fx_op_map = { + nn.Linear: MatmulOp, + nn.LayerNorm: LayerNormOp, + F.softmax: SoftmaxOp, + nn.Dropout: ElementwiseOp, + torch.matmul: MatmulOp, + F.relu: ElementwiseOp, + F.gelu: ElementwiseOp, + torch.tensor: PlaceholderOp, + # FIXME: three operands, need to ensure specs are the same + torch.where: ElementwiseOp, + torch.pow: ElementwiseOp, + torch.tanh: ElementwiseOp, + operator.truediv: ElementwiseOp, + operator.getitem: ElementwiseOp, + operator.add: BinaryOp, + operator.mul: BinaryOp, +} + + +class Solver: + def __init__(self, gm, p) -> None: + assert isinstance(gm, fx.GraphModule), "gm must be a GraphModule" + self.gm = gm + self.gm.graph.eliminate_dead_code() + logger.debug(self.gm.graph, ranks=0) + self.named_modules = dict(self.gm.named_modules()) + self.z3_graph = {} # {node_name: FxOp} + self.goal = [] + self.cost = None + self.num_devices = p + self.reshard_cost_map = { + "RR": {"RR": 0, "RS": 0, "SR": 0}, + "RS": {"RR": 1 / p, "RS": 0, "SR": 1 / p - 1 / (p * p)}, + "SR": {"RR": 1 / p, "RS": 1 / p - 1 / (p * p), "SR": 0}, + } + + def inference_shape(self, inputs): + sp = ShapeProp(self.gm) + # Tackle the case of meta device + device = next(self.gm.named_parameters())[1].device + inputs = [inp.to("meta") for inp in inputs] + self.gm = self.gm.to(device) + sp.propagate(*inputs) + + def dump_fx_node(self): + res = [] + for node in self.gm.graph.nodes: + if "tensor_meta" in node.meta: + if isinstance(node.meta["tensor_meta"], list): + lst = node.meta["tensor_meta"] + else: + lst = [node.meta["tensor_meta"]] + for data in lst: + if node.op == "call_module": + target = type(self.named_modules[node.target]) + else: + target = node.target + if not isinstance(data, TensorMetadata): + continue + res.append( + [node.name, node.op, target, list(data.shape), data.dtype] + ) + if node.op == "call_module": + for name, param in self.named_modules[ + node.target + ].named_parameters(): + res.append( + ["|-" + name, "", "", list(param.shape), param.dtype] + ) + logger.info( + "\n %s \n", + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]), + ranks=0, + ) + + def calculate_reshard_cost(self, mod, prev, curr, shape): + return mod.evaluate(self.calculate_reshard_cost_z3(prev, curr, shape)) + + def calculate_reshard_cost_z3(self, prev, curr, shape): + result = 1e12 # invalid + for in_spec, target_map in self.reshard_cost_map.items(): + tmp = 1e12 # invalid + for out_spec, val in target_map.items(): + if in_spec == "RR" and out_spec in {"RS", "SR"}: + cost = 1 # add penalty for splitting cost + else: + cost = int(val * shape) + tmp = z3.If(curr == ShardSpec(out_spec).id, cost, tmp) + result = z3.If(prev == ShardSpec(in_spec).id, tmp, result) + return result + + def construct_z3_graph(self): + for node in self.gm.graph.nodes: + if ( + "tensor_meta" not in node.meta + ): # not an activation tensor, no need to care + continue + if node.op == "placeholder": # input + new_op = PlaceholderOp(node) + elif node.op == "call_module": + mod = self.named_modules[node.target] + if isinstance(mod, nn.Linear): + new_op = MatmulOp(node) + elif type(mod) in fx_op_map: + new_op = fx_op_map[type(mod)](node) + else: + raise RuntimeError(f"Unsupported module: {node.target}") + elif node.op == "call_function": + if node.target in fx_op_map: + new_cls = fx_op_map[node.target] + new_op = new_cls(node) + else: + raise RuntimeError(f"Unsupported function: {node.target}") + elif node.op == "call_method": + # pylint: disable=redefined-variable-type + if node.target == "view": + new_op = ViewOp(node, self.z3_graph, self.num_devices) + elif node.target == "permute": + new_op = PermuteOp(node, self.z3_graph) + elif node.target == "transpose": + new_op = TransposeOp(node, self.z3_graph) + elif node.target in ["contiguous", "to"]: + new_op = ElementwiseOp(node) + else: + raise RuntimeError(f"Unsupported method: {node.target}") + elif node.op == "get_attr": # extra buffers + new_op = PlaceholderOp(node) + else: # output + continue + # construct edges + if not (node.op == "call_method" and node.target == "view"): + for arg in node.args: + if not isinstance(arg, fx.Node) or arg.name not in self.z3_graph: + continue + new_op.add_arg(self.z3_graph[arg.name]) + self.z3_graph[arg.name].add_user(new_op) + else: + arg = node.args[0] + new_op.add_arg(self.z3_graph[arg.name]) + self.z3_graph[arg.name].add_user(new_op) + self.z3_graph[node.name] = new_op + + def dump_z3_graph(self, mod=None, dot_file="z3_graph.dot"): + """ + Dump the z3 graph in dot format + """ + if mod is None: + results = None + else: + results = {d.name(): mod[d] for d in mod.decls()} + res = "digraph z3_graph {\n" + # add nodes + for op in self.z3_graph.values(): + attr = f'label="{op.name}\\n({op.__class__.__name__})"' + if isinstance(op, PlaceholderOp): + attr += ",shape=box" + elif isinstance(op, MatmulOp): + if results is None: + attr += ",style=filled,fillcolor=yellow" + else: + weight_spec = results[op.name + "_1"] + if weight_spec == ShardSpec("RR").id: + attr += ",style=filled,fillcolor=yellow" + elif weight_spec == ShardSpec("RS").id: + attr += ',shape=box,style=striped,fillcolor="#FF5733:#FFBD33"' + else: # weight_spec == ShardSpec("SR").id + attr += ',shape=box,style=wedged,fillcolor="#FF5733:#FFBD33"' + res += f" {op.name} [{attr}];\n" + # add edges + for op in self.z3_graph.values(): + for i, arg in enumerate(op.args): + if results is None: + label = "" + elif op.name + "_" + str(i) not in results: + label = "" + else: + label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' + res += f" {arg.name} -> {op.name}{label};\n" + res += "}" + with open(dot_file, "w", encoding="utf-8") as f: + f.write(res) + + def construct_z3_problem(self): + bitvecs = {} + input_constraints = [] + comm_costs = [] + reshard_costs = [] + for op in self.z3_graph.values(): + # no need to include output, since output can be obtained from inputs + inputs, constraints = op.generate_input_z3() + for inp in inputs: + bitvecs[str(inp)] = inp + # input constraints + input_constraints.extend(constraints) + # communication cost + comm_costs.append(op.calculate_comm_cost_z3()) + # reshard cost + for i, arg in enumerate(op.args): + name = f"{op.name}_{i}" + if name not in bitvecs: + continue + curr = bitvecs[name] + prev = arg.generate_output_z3() + reshard_costs.append( + self.calculate_reshard_cost_z3(prev, curr, arg.out_size) + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + reshard_costs.append( + self.calculate_reshard_cost_z3( + op.generate_output_z3(), next_inp, op.out_size + ) + ) + + self.cost = sum(comm_costs) + sum(reshard_costs) + self.goal += input_constraints + + def calculate_new_cost(self, mod): + results = {d.name(): mod[d] for d in mod.decls()} + max_cost = 0 + table = [] + for name, op in self.z3_graph.items(): + # communication cost + inputs = [] + if f"{name}_0" in results: + inputs.append(results[f"{name}_0"]) + if f"{name}_1" in results: + inputs.append(results[f"{name}_1"]) + output = op.generate_output(mod) + comm_cost = op.calculate_comm_cost(mod) + max_cost += comm_cost + if len(inputs) == 1: + table.append( + [op.name, ShardSpec(inputs[0]), ShardSpec(output), comm_cost] + ) + elif len(inputs) == 2: + table.append( + [ + op.name, + f"{ShardSpec(inputs[0])}x{ShardSpec(inputs[1])}", + ShardSpec(output), + comm_cost, + ] + ) + elif len(inputs) > 2: + raise RuntimeError("Not supported") + # resharding cost + for i, arg in enumerate(op.args): + arg_name = f"{op.name}_{i}" + if arg_name not in results: + continue + curr = results[arg_name] + prev = arg.generate_output(mod) + reshard_cost = self.calculate_reshard_cost( + mod, prev, curr, arg.out_size + ) + max_cost += reshard_cost + table.append( + [f"|-{arg.name}", ShardSpec(prev), ShardSpec(curr), reshard_cost] + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + reshard_cost = self.calculate_reshard_cost( + mod, output, next_inp, op.out_size + ) + max_cost += reshard_cost + table.append( + ["output", ShardSpec(output), ShardSpec(next_inp), reshard_cost] + ) + max_cost = z3.simplify(max_cost).as_long() + table.append(["Total", "", "", max_cost]) + logger.info( + "\n %s \n", + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]), + ranks=0, + ) + return max_cost + + def generate_schedule_sequence(self, mod): + print() + print("Best solution:") + results = {d.name(): mod[d] for d in mod.decls()} + for name, op in self.z3_graph.items(): + if not isinstance(op, MatmulOp): + continue + weight = results[f"{name}_1"] + if weight == ShardSpec("RS").id: + dim = 0 # transposed + elif weight == ShardSpec("SR").id: + dim = 1 + else: + continue + if op.node.op == "call_module": + print(f'sch["{op.node.target}"].shard("weight", axis={dim})') + if dim == 0: + print(f'sch["{op.node.target}"].shard("bias", axis={dim})') + if ( + results[f"{name}_0"] == ShardSpec("RS").id + and results[f"{name}_1"] == ShardSpec("SR").id + ): + print( + f'sch["{op.node.target}"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")' + ) + # reshard + for name, op in self.z3_graph.items(): + for i, arg in enumerate(op.args): + arg_name = f"{op.name}_{i}" + if arg_name not in results: + continue + curr = results[arg_name].as_long() + prev = arg.generate_output(mod) + if curr != prev: + print( + f'sch["{op.name}"].sync(mode="fwd_pre", sync_op_or_fn="{ShardSpec(prev)}->{ShardSpec(curr)}")' + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + output = op.generate_output(mod) + if output != next_inp: + print( + f'sch["{op.name}"].sync(mode="fwd_post", sync_op_or_fn="{ShardSpec(output)}->{ShardSpec(next_inp)}")' + ) + + def solve(self, inputs, max_iter=100): + # 1. Shape propagation + self.inference_shape(inputs) + self.dump_fx_node() + # 2. Construct a simplied z3 graph from the fx graph + self.construct_z3_graph() + self.dump_z3_graph() + # 3. Construct the z3 constraints + self.construct_z3_problem() + # 4. Construct the z3 solver + sol = z3.Solver() + sol.add(self.goal) + max_cost = int(1e12) + for it in range(max_iter): + logger.info("=================== Iter %d ===================", it, ranks=0) + sol.push() + # 5. Update cost constraint + sol.add(self.cost < max_cost) + # 6. Solve the problem + sat = sol.check() + if str(sat) == "unsat": + logger.info("Cannot find better solutions", ranks=0) + break + mod = sol.model() + total_cost = mod.evaluate(self.cost) + logger.info(mod, ranks=0) + # 7. Calculate new cost from the results + max_cost = self.calculate_new_cost(mod) + assert max_cost == total_cost.as_long() + sol.pop() + # 8. Generate sharding sequence + self.generate_schedule_sequence(mod) + self.dump_z3_graph(mod, "z3_graph_sharded.dot") + results = {d.name(): mod[d] for d in mod.decls()} + return results, max_cost diff --git a/tests/autoshard/test_bert.py b/tests/autoshard/test_bert.py new file mode 100644 index 00000000..7aafd81a --- /dev/null +++ b/tests/autoshard/test_bert.py @@ -0,0 +1,262 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import copy +import inspect +import operator +import argparse + +import torch +import torch.distributed as dist +from transformers import BertLMHeadModel, AutoConfig + +import slapo +from slapo.logger import get_logger + +logger = get_logger(__name__) + +# Config for verification +bs = 4 +seq_len = 512 + + +def perf_model(mod, input_tensor): + """Measure the performance of a mod with certain resharding schemes""" + # warmup + mod.eval() + for _ in range(10): + mod(input_tensor) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + iters = 40 + for _ in range(iters): + mod(input_tensor) + end_event.record() + torch.cuda.synchronize() + if dist.get_rank() == 0: + print(f"{start_event.elapsed_time(end_event) / iters:.3f} ms") + + +def trace_and_find_view(sch): + input_names = ["hidden_states"] + sig = inspect.signature(sch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + sch.trace( + recursive=False, flatten=True, tracer="pytorch", concrete_args=concrete_args + ) + ops = sch.find_node(lambda node: node.op == "call_method" and node.target == "view") + assert len(ops) == 4 # q,k,v,context_layer + return ops + + +def fix_attention_mask_shape_megatron(sch): + ops = trace_and_find_view(sch) + + def new_view(tensor, args): + if len(args) == 4: # q,k,v + out = tensor.view(args[0], args[1], args[2] // sch.world_size, args[3]) + else: # context_layer + out = tensor.view(args[0], args[1], args[2] // sch.world_size) + return out + + for op in ops: + sch.replace(new_view, op) + + +def scheme_megatron(model, input_ids, config): + sch = slapo.create_schedule(model) + + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + subsch["dense"].shard("weight", axis=1) # replace + subsch["dense"].sync("fwd_post", sync_op_or_fn="all_reduce") # replace + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["output.dense"].shard("weight", axis=1) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="all_reduce") + + return sch + + +def scheme_sequence_parallel(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import ( + reshard_SR_to_RR, + reshard_RS_to_RR, + ) + + def new_matmul(lhs, rhs): + return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) + + def new_matmul_1(lhs, rhs): + return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) + + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): + sch["bert.embeddings.LayerNorm"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") + for i in range(config.num_hidden_layers): + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + trace_and_find_view(subsch) + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == torch.matmul + ) + assert len(ops) == 2 + subsch.replace(new_matmul, ops[0]) + subsch.replace(new_matmul_1, ops[1]) + sch[f"bert.encoder.layer.{config.num_hidden_layers - 1}.output.LayerNorm"].sync( + mode="fwd_post", sync_op_or_fn="SR->RR" + ) + + return sch + + +def scheme_activation_stationary(model, input_ids, config): + sch = slapo.create_schedule(model) + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + # shape here: [4096, 256](RS). Need to matmul with [1024, 1024] (without shard) + subsch["dense"].sync("fwd_pre", sync_op_or_fn="RS->RR") + subsch["dense"].shard("weight", axis=0) + subsch["dense"].shard("bias", axis=0) + subsch["dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["intermediate.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + subsch["output.dense"].shard("weight", axis=0) + subsch["output.dense"].shard("bias", axis=0) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + + return sch + + +def scheme_activation_sharding(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import reshard_RR_to_SR + + def reshard_and_add(dropout, hidden_states): + """Replace the add operator with reshard_and_add""" + reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) + return dropout + reshard_hidden_states + + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + + subsch.trace(recursive=False, flatten=False, tracer="pytorch") + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == operator.add + ) + subsch.replace(reshard_and_add, ops[0]) + + # shape here: RS + subsch["dense"].sync( + "fwd_pre", sync_op_or_fn="RS->SR" + ) # LayerNorm will crash for SR x RR = SR + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["output.LayerNorm"].sync("fwd_post", sync_op_or_fn="SR->RR") + + return sch + + +def test_schemes(init_dist): + torch.cuda.set_device(dist.get_rank()) + device = torch.cuda.current_device() + + config = AutoConfig.from_pretrained("bert-large-uncased") + with slapo.init_empty_weights(): + model = BertLMHeadModel(config) + + schs = [] + input_ids = torch.ones(bs, seq_len, dtype=torch.long, device=device) + # 1. Slapo-Megatron + # RR x RS = RS, RS x SR = RR + schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) + # 2. Sequence-Parallel + # RR->RS x RR = RS, RS x RR = RS->RR + schs.append(scheme_sequence_parallel(copy.deepcopy(model), input_ids, config)) + # 3. Activation-Stationary + # RR x RS = RS + schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) + # 4. Activation Sharding. SR x RR = SR + schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) + return schs + + +if __name__ == "__main__": + # Create parser + parser = argparse.ArgumentParser(description="Resharding schemes on BERT") + # Add arguments + parser.add_argument("--bs", type=int, help="Batch size", default=8) + parser.add_argument("--seq", type=int, help="Sequence length", default=512) + # Parse the arguments + args = parser.parse_args() + + bs = args.bs + seq_len = args.seq + + dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"])) + + logger.info( + "Number of GPUs: %d, bs=%d, seq_len=%d; Model: BERT-large", + dist.get_world_size(), + bs, + seq_len, + ranks=0, + ) + + schs = test_schemes(None) + + input_ids = torch.ones( + bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}" + ) + for i, sch in enumerate(schs): + mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) + mod.to(f"cuda:{dist.get_rank()}") + torch.cuda.empty_cache() + perf_model(mod, input_ids) + del mod diff --git a/tests/autoshard/test_gpt.py b/tests/autoshard/test_gpt.py new file mode 100644 index 00000000..c3f8dce2 --- /dev/null +++ b/tests/autoshard/test_gpt.py @@ -0,0 +1,297 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import copy +import inspect +import operator +import argparse + +import torch +from torch import nn +from torch import fx +import torch.distributed as dist +from transformers import GPTNeoModel, AutoConfig + +import slapo +from slapo.logger import get_logger + +logger = get_logger(__name__) + +# Config for verification +bs = 4 +seq_len = 1024 + + +def perf_model(mod, input_tensor): + """Measure the performance of a mod with certain resharding schemes""" + # warmup + mod.eval() + # mod.to(torch.float16) + for _ in range(10): + mod(input_tensor) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + iters = 40 + for _ in range(iters): + mod(input_tensor) + end_event.record() + torch.cuda.synchronize() + if dist.get_rank() == 0: + print(f"{start_event.elapsed_time(end_event) / iters:.3f} ms") + + +def trace_and_find_view(sch, config): + input_names = ["hidden_states"] + sig = inspect.signature(sch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + sch.trace( + recursive=False, + flatten=True, + tracer="huggingface", + concrete_args=concrete_args, + config=config, + ) + ops = sch.find_node( + lambda node: node.op == "call_method" + and node.target == "view" + and ( + (node.args[0].op == "call_module" and "proj" in node.args[0].target) + or ( + len(node.args) > 1 + and isinstance(node.args[1], fx.Node) + and node.args[1].op == "call_function" + and node.args[1].target == operator.add + ) + ) + ) + assert len(ops) == 4 # q,k,v,context_layer + return ops + + +def fix_attention_mask_shape_megatron(sch, config): + ops = trace_and_find_view(sch, config) + + def new_view(tensor, args): + if len(args) == 4: # q,k,v + out = tensor.view(args[0], args[1], args[2] // sch.world_size, args[3]) + else: # context_layer + out = tensor.view(args[0], args[1], args[2] // sch.world_size) + return out + + for op in ops: + sch.replace(new_view, op) + + +def scheme_megatron(model, input_ids, config): + sch = slapo.create_schedule(model) + + enable = True if input_ids.shape[0] == 1 else False + with slapo.Verify(sch, [input_ids], enable=enable): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"h.{i}.attn.attention"] + # no bias for GPTNeo + subsch["q_proj"].shard("weight", axis=0) + subsch["k_proj"].shard("weight", axis=0) + subsch["v_proj"].shard("weight", axis=0) + subsch["out_proj"].shard("weight", axis=1) + subsch["out_proj"].sync("fwd_post", sync_op_or_fn="all_reduce") + fix_attention_mask_shape_megatron(subsch, config) + # shard MLP + subsch = sch[f"h.{i}.mlp"] + subsch["c_fc"].shard("weight", axis=0) + subsch["c_fc"].shard("bias", axis=0) + subsch["c_proj"].shard("weight", axis=1) + subsch["c_proj"].sync("fwd_post", sync_op_or_fn="all_reduce") + + return sch + + +def scheme_sequence_parallel(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import ( + reshard_SR_to_RR, + reshard_RS_to_RR, + ) + + def new_matmul(lhs, rhs): + return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) + + def new_matmul_1(lhs, rhs): + return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) + + class NewMask(nn.Module): + def forward(self, query, key, bias): + query_length, key_length = ( + query.size(-2) * sch.world_size, + key.size(-2) * sch.world_size, + ) + size_per_chunk = query_length // sch.world_size + start_idx = key_length - query_length + size_per_chunk * sch.rank + end_idx = start_idx + size_per_chunk + causal_mask = bias[:, :, start_idx:end_idx, :key_length] + return causal_mask + + enable = True if input_ids.shape[0] == 1 else False + with slapo.Verify(sch, [input_ids], eval_mode=True, enable=enable): + sch["drop"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") + for i in range(config.num_hidden_layers): + subsch = sch[f"h.{i}.attn.attention"] + trace_and_find_view(subsch, config) + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == torch.matmul + ) + assert len(ops) == 2 + subsch.replace(new_matmul, ops[0]) + subsch.replace(new_matmul_1, ops[1]) + + # Need to shard the tril matrix (causal mask) + def pattern(query, key, bias): + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = bias[ + :, :, key_length - query_length : key_length, :key_length + ] + return causal_mask + + ops = subsch.find(pattern) + subsch.replace(NewMask(), target_ops=[ops[-1]]) + sch[f"ln_f"].sync(mode="fwd_post", sync_op_or_fn="SR->RR") + + return sch + + +def scheme_activation_stationary(model, input_ids, config): + sch = slapo.create_schedule(model) + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + # shape here: [4096, 256](RS). Need to matmul with [1024, 1024] (without shard) + subsch["dense"].sync("fwd_pre", sync_op_or_fn="RS->RR") + subsch["dense"].shard("weight", axis=0) + subsch["dense"].shard("bias", axis=0) + subsch["dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["intermediate.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + subsch["output.dense"].shard("weight", axis=0) + subsch["output.dense"].shard("bias", axis=0) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + + return sch + + +def scheme_activation_sharding(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import reshard_RR_to_SR + + def reshard_and_add(dropout, hidden_states): + """Replace the add operator with reshard_and_add""" + reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) + return dropout + reshard_hidden_states + + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + + subsch.trace(recursive=False, flatten=False, tracer="pytorch") + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == operator.add + ) + subsch.replace(reshard_and_add, ops[0]) + + # shape here: RS + subsch["dense"].sync( + "fwd_pre", sync_op_or_fn="RS->SR" + ) # LayerNorm will crash for SR x RR = SR + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["output.LayerNorm"].sync("fwd_post", sync_op_or_fn="SR->RR") + + return sch + + +def test_schemes(init_dist): + torch.cuda.set_device(dist.get_rank()) + device = torch.cuda.current_device() + + config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B") + config.use_cache = False + with slapo.init_empty_weights(): + model = GPTNeoModel(config) + + schs = [] + input_ids = torch.ones(bs, seq_len, dtype=torch.long, device=device) + # 1. Slapo-Megatron + # RR x RS = RS, RS x SR = RR + schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) + # 2. Sequence-Parallel + # RR->RS x RR = RS, RS x RR = RS->RR + schs.append(scheme_sequence_parallel(copy.deepcopy(model), input_ids, config)) + # 3. Activation-Stationary + # RR x RS = RS + # schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) + # # 4. Activation Sharding. SR x RR = SR + # schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) + return schs + + +if __name__ == "__main__": + # Create parser + parser = argparse.ArgumentParser(description="Resharding schemes on GPTNeo") + # Add arguments + parser.add_argument("--bs", type=int, help="Batch size", default=4) + parser.add_argument("--seq", type=int, help="Sequence length", default=1024) + # Parse the arguments + args = parser.parse_args() + + bs = args.bs + seq_len = args.seq + + dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"])) + + logger.info( + "Number of GPUs: %d, bs=%d, seq_len=%d; Model: GPTNeo", + dist.get_world_size(), + bs, + seq_len, + ranks=0, + ) + + input_ids = torch.ones( + bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}" + ) + schs = test_schemes(None) + for i, sch in enumerate(schs): + mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) + mod.to(f"cuda:{dist.get_rank()}") + torch.cuda.empty_cache() + perf_model(mod, input_ids) + del mod diff --git a/tests/autoshard/test_solver.py b/tests/autoshard/test_solver.py new file mode 100644 index 00000000..3a4b9733 --- /dev/null +++ b/tests/autoshard/test_solver.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Test different resharding schemes on MLP. +Verified by different combinations of resharding schemes. +""" + +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F + +import slapo +from slapo.logger import get_logger +from slapo.sharding import Solver + +logger = get_logger(__name__) + +# Config for verification +p = 8 +bs = 8 +seq_len = 512 +hidden_size = 1024 + + +class MLP(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc1 = nn.Linear(dim, 4 * dim) + self.fc2 = nn.Linear(4 * dim, dim) + + def forward(self, x): + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + return x + + +def test_mlp(): + with slapo.init_empty_weights(): + mlp = MLP(hidden_size) + + sch = slapo.create_schedule(mlp) + sch.trace() + assert isinstance(sch.mod, fx.GraphModule) + + sol = Solver(sch.mod, p=p) + results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + # fc1: SRxRR->SR + # fc2: SRxRR->SR->RR + assert results["fc1_0"] == 2 + assert results["fc1_1"] == 0 + assert results["fc2_0"] == 2 + assert results["fc2_1"] == 0 + assert max_cost == (bs * seq_len * hidden_size / p + 1) + + +def test_bert_attn(): + from transformers import BertLMHeadModel, AutoConfig + import inspect + + config = AutoConfig.from_pretrained("bert-large-uncased") + with slapo.init_empty_weights(): + model = BertLMHeadModel(config) + logger.info(config, ranks=0) + + sch = slapo.create_schedule(model) + input_names = ["hidden_states"] + i = 0 + subsch = sch[f"bert.encoder.layer.{i}"] + sig = inspect.signature(subsch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + subsch.trace( + recursive=False, + flatten=True, + tracer="pytorch", + concrete_args=concrete_args, + ) + logger.info(subsch.mod.graph, ranks=0) + + seq_len = 512 + sol = Solver(subsch.mod, p=p) + _, max_cost = sol.solve([torch.randn(bs, seq_len, config.hidden_size)]) + assert max_cost == 3 * (bs * seq_len * config.hidden_size / p) + 4 + + +def test_gpt_attn(): + from transformers import GPTNeoModel, AutoConfig + import inspect + + config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B") + # config.use_cache = False + with slapo.init_empty_weights(): + model = GPTNeoModel(config) + logger.info(config, ranks=0) + + sch = slapo.create_schedule(model) + input_names = ["hidden_states"] + i = 0 + subsch = sch[f"h.{i}"] + sig = inspect.signature(subsch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + subsch.trace( + recursive=False, + flatten=True, + tracer="huggingface", + concrete_args=concrete_args, + config=config, + ) + logger.info(subsch.mod.graph, ranks=0) + + seq_len = 1024 + sol = Solver(subsch.mod, p=p) + _, max_cost = sol.solve([torch.randn(bs, seq_len, config.hidden_size)]) + assert max_cost == 3 * (bs * seq_len * config.hidden_size // p) + 3 + + +if __name__ == "__main__": + test_mlp() + test_bert_attn() + test_gpt_attn()