Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
language_version: python

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.4
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix, --no-cache]

- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
- id: remove-tabs
name: Tabs remver (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
3 changes: 1 addition & 2 deletions test/graph_fusibility_demo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import torch
import torch.fx as fx
from typing import Callable, List, Dict, Any, Union, Optional, Tuple
from typing import Callable, List, Any, Optional

# --- Imports (No local definitions or decorators, just use) ---
from tst.torch_ap.torch_ap_trace import torch_ap_trace
from tst.torch_ap.spider import down_spider as DS, up_spider as US
from tst.torch_ap.load_store_op import load, store
from torch.fx.passes.infra.pass_manager import PassManager, PassResult
from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import (
DemoMatmulEpilogueReplacerPass,
)
Expand Down
2 changes: 0 additions & 2 deletions test/matmul_epilogue_ap_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import torch.fx as fx

from tst.torch_ap.ap_pass import ApPass
from tst.torch_ap.match_replace_util import MatchContext
from tst.torch_ap.torch_ap_trace import torch_ap_trace


Expand Down Expand Up @@ -51,7 +50,6 @@ def forward(self, x):
return x - self.bias

class SimpleTracer(fx.Tracer):

def __init__(self, leaf_module_classes):
super().__init__()
self.leaf_module_classes = leaf_module_classes
Expand Down
13 changes: 10 additions & 3 deletions test/test_fusibility_predictor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import torch


def fusibility_of(epilogue_func) -> bool:
"""Check if an epilogue function is fusible."""
from tst.torch_ap.torch_ap_trace import torch_ap_trace
from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import DemoMatmulEpilogueReplacerPass
from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import (
DemoMatmulEpilogueReplacerPass,
)
from tst.torch_ap.trivial_ops_folder_pass import TrivialOpsFolderPass
from tst.torch_ap.concrete_pass.matmul_epilogue_util import get_matmul_epilogue_arg_name_to_is_mm_out
from tst.torch_ap.concrete_pass.matmul_epilogue_extractor_pass import MatmulEpilogueExtractorPass
from tst.torch_ap.concrete_pass.matmul_epilogue_util import (
get_matmul_epilogue_arg_name_to_is_mm_out,
)
from tst.torch_ap.concrete_pass.matmul_epilogue_extractor_pass import (
MatmulEpilogueExtractorPass,
)
from tst.torch_ap.concrete_pass.matmul_epilogue_fusibility_predictor import (
MatmuEpilogueFusibilityPredicator,
)
Expand Down
1 change: 0 additions & 1 deletion tst/fn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import wraps
import tst


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.fx as fx
from typing import Any, Union, List
from typing import Any

# --- Core Functor Implementation ---

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.fx as fx
import inspect
import string
from typing import Any
Expand Down
3 changes: 1 addition & 2 deletions tst/torch_ap/concrete_pass/down_spider_inserter_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.fx as fx
from typing import List, Tuple, Callable
from tst.torch_ap.spider import down_spider
Expand Down Expand Up @@ -63,7 +62,7 @@ def get_topo_hash(g):

# --- Inline Logic ---
for gm, input_idx in gms_with_pos:
inserter = DownSpiderInserter(input_idx)
inserter = DownSpiderInserterPass(input_idx)
inserter(gm)


Expand Down
3 changes: 0 additions & 3 deletions tst/torch_ap/concrete_pass/matmul_epilogue_extractor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import torch.fx as fx

from tst.torch_ap.ap_pass import ApPass, PassResult
from tst.torch_ap.match_replace_util import MatchContext


class SimpleTracer(fx.Tracer):

def __init__(self, leaf_module_classes):
super().__init__()
self.leaf_module_classes = leaf_module_classes
Expand Down Expand Up @@ -46,7 +44,6 @@ def __call__(self, target: fx.GraphModule) -> PassResult:


if __name__ == "__main__":

# Target: Matmul -> MatmulEpilogue (call_module)
class TargetModel(torch.nn.Module):
def __init__(self):
Expand Down
84 changes: 55 additions & 29 deletions tst/torch_ap/concrete_pass/matmul_epilogue_fusibility_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ class MatmuEpilogueFusibilityPredicator:

def __init__(
self,
config_pattern_rewriters: Callable[[List[AccessTopoRule]], List[AccessTopoRule]],
config_pattern_removers: Optional[Callable[[List[ConfirmPattern]], List[ConfirmPattern]]] = None,
config_pattern_rewriters: Callable[
[List[AccessTopoRule]], List[AccessTopoRule]
],
config_pattern_removers: Optional[
Callable[[List[ConfirmPattern]], List[ConfirmPattern]]
] = None,
):
self._config_pattern_rewriters = config_pattern_rewriters
self._config_pattern_remover_overrider = config_pattern_removers
Expand All @@ -51,17 +55,34 @@ def config_pattern_rewriters(self) -> List[AccessTopoRule]:

def _default_rewriters(self) -> List[AccessTopoRule]:
return [
AccessTopoRule("y=x**2", "y=relu(x)", lambda x: x**2, lambda x: torch.relu(x)),
AccessTopoRule("y=tanh(x)", "y=relu(x)", lambda x: torch.tanh(x), lambda x: torch.relu(x)),
AccessTopoRule("z=DS(x)+y", "z=DS(US(x,y))", lambda x, y: DS(x) + y, lambda x, y: DS(US(x, y))),
AccessTopoRule("z=y+DS(x)", "z=DS(US(x,y))", lambda x, y: y + DS(x), lambda x, y: DS(US(x, y))),
AccessTopoRule("z=relu(DS(x))", "z=DS(x)", lambda x: torch.relu(DS(x)), lambda x: DS(x)),
AccessTopoRule(
"y=x**2", "y=relu(x)", lambda x: x**2, lambda x: torch.relu(x)
),
AccessTopoRule(
"y=tanh(x)",
"y=relu(x)",
lambda x: torch.tanh(x),
lambda x: torch.relu(x),
),
AccessTopoRule(
"z=DS(x)+y",
"z=DS(US(x,y))",
lambda x, y: DS(x) + y,
lambda x, y: DS(US(x, y)),
),
AccessTopoRule(
"z=y+DS(x)",
"z=DS(US(x,y))",
lambda x, y: y + DS(x),
lambda x, y: DS(US(x, y)),
),
AccessTopoRule(
"z=relu(DS(x))", "z=DS(x)", lambda x: torch.relu(DS(x)), lambda x: DS(x)
),
]

def _default_removers(
self,
mm_epi: fx.GraphModule,
mm_out_idx: int
self, mm_epi: fx.GraphModule, mm_out_idx: int
) -> List[ConfirmPattern]:
"""Generate default ConfirmPatterns from epilogue structure."""
arg_list = self._get_arg_list(mm_epi)
Expand All @@ -73,23 +94,27 @@ def _default_removers(
removers: List[ConfirmPattern] = []

for out_idx in output_indices:
removers.append(ConfirmPattern(
f"store(DS(x, {out_idx}))",
lambda x, idx=out_idx: store(DS(x), idx)
))
removers.append(
ConfirmPattern(
f"store(DS(x, {out_idx}))", lambda x, idx=out_idx: store(DS(x), idx)
)
)

for arg_name in other_arg_names:
removers.append(ConfirmPattern(
f"US(x, load(y, {arg_name}))",
lambda x, y, n=arg_name: US(x, load(y, n))
))
removers.append(
ConfirmPattern(
f"US(x, load(y, {arg_name}))",
lambda x, y, n=arg_name: US(x, load(y, n)),
)
)

if mm_out_arg_names:
mm_arg_name = mm_out_arg_names[0]
removers.append(ConfirmPattern(
f"load(x, {mm_arg_name})",
lambda x, n=mm_arg_name: load(x, n)
))
removers.append(
ConfirmPattern(
f"load(x, {mm_arg_name})", lambda x, n=mm_arg_name: load(x, n)
)
)

return removers

Expand All @@ -98,7 +123,10 @@ def _get_arg_list(self, gm: fx.GraphModule) -> List[tuple[str, bool]]:
placeholder_names = [n.name for n in gm.graph.nodes if n.op == "placeholder"]
if len(placeholder_names) == 0:
return []
return [(name, i == len(placeholder_names) - 1) for i, name in enumerate(placeholder_names)]
return [
(name, i == len(placeholder_names) - 1)
for i, name in enumerate(placeholder_names)
]

def _get_output_indices(self, gm: fx.GraphModule) -> List[Optional[int]]:
out_node = next(n for n in gm.graph.nodes if n.op == "output")
Expand All @@ -107,11 +135,7 @@ def _get_output_indices(self, gm: fx.GraphModule) -> List[Optional[int]]:
return list(range(len(res)))
return [None]

def __call__(
self,
mm_epi: fx.GraphModule,
mm_out_as_epi_in_index: int
) -> bool:
def __call__(self, mm_epi: fx.GraphModule, mm_out_as_epi_in_index: int) -> bool:
"""
Predict if the matmul epilogue is fusible.

Expand All @@ -137,7 +161,9 @@ def __call__(

for remover in removers:
pattern_gm = torch_ap_trace(remover.pattern_func)
replacement_gm = torch_ap_trace(remover.pattern_func) # Identity replacement
replacement_gm = torch_ap_trace(
remover.pattern_func
) # Identity replacement

matches = subgraph_rewriter.replace_pattern(
working_gm, pattern_gm, replacement_gm
Expand Down
6 changes: 1 addition & 5 deletions tst/torch_ap/concrete_pass/matmul_epilogue_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import torch
import torch.fx as fx

from tst.torch_ap.ap_pass import ApPass, PassResult
from tst.torch_ap.match_replace_util import MatchContext
from tst.torch_ap.ap_pass import ApPass


def get_matmul_epilogue_arg_name_to_is_mm_out(graph_module) -> list[(str, bool)]:
return MatmulEpilogueArgNameToIsMmOutGetter()(graph_module)


class SimpleTracer(fx.Tracer):

def __init__(self, leaf_module_classes):
super().__init__()
self.leaf_module_classes = leaf_module_classes
Expand Down Expand Up @@ -64,7 +62,6 @@ def _get_epilogue_placeholder_names(self, match_ctx) -> list[str]:


if __name__ == "__main__":

# Target: Matmul -> MatmulEpilogue (call_module)
class TargetModel(torch.nn.Module):
def __init__(self):
Expand All @@ -74,7 +71,6 @@ def forward(self, a, b):
return torch.tanh(torch.matmul(a, b) - 2.0)

t_gm = fx.GraphModule(TargetModel(), fx.Tracer().trace(TargetModel()))
from torch.fx.passes.infra.pass_manager import PassManager
from tst.torch_ap.trivial_ops_folder_pass import TrivialOpsFolderPass

pass_mgr = TrivialOpsFolderPass()
Expand Down
1 change: 0 additions & 1 deletion tst/torch_ap/load_store_op.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from tst.torch_ap.ops import load, store
1 change: 0 additions & 1 deletion tst/torch_ap/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.fx as fx


def down_spider(x):
Expand Down
Loading