diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..cf1f5e4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + language_version: python + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.4 + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix, --no-cache] + +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + rev: v1.5.1 + hooks: + - id: remove-crlf + - id: remove-tabs + name: Tabs remver (Python) + files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ + args: [--whitespaces-count, '4'] diff --git a/test/graph_fusibility_demo.py b/test/graph_fusibility_demo.py index 401f9b1..2e4e66b 100644 --- a/test/graph_fusibility_demo.py +++ b/test/graph_fusibility_demo.py @@ -1,12 +1,11 @@ import torch import torch.fx as fx -from typing import Callable, List, Dict, Any, Union, Optional, Tuple +from typing import Callable, List, Any, Optional # --- Imports (No local definitions or decorators, just use) --- from tst.torch_ap.torch_ap_trace import torch_ap_trace from tst.torch_ap.spider import down_spider as DS, up_spider as US from tst.torch_ap.load_store_op import load, store -from torch.fx.passes.infra.pass_manager import PassManager, PassResult from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import ( DemoMatmulEpilogueReplacerPass, ) diff --git a/test/matmul_epilogue_ap_pass.py b/test/matmul_epilogue_ap_pass.py index dc3614f..519343e 100644 --- a/test/matmul_epilogue_ap_pass.py +++ b/test/matmul_epilogue_ap_pass.py @@ -2,7 +2,6 @@ import torch.fx as fx from tst.torch_ap.ap_pass import ApPass -from tst.torch_ap.match_replace_util import MatchContext from tst.torch_ap.torch_ap_trace import torch_ap_trace @@ -51,7 +50,6 @@ def forward(self, x): return x - self.bias class SimpleTracer(fx.Tracer): - def __init__(self, leaf_module_classes): super().__init__() self.leaf_module_classes = leaf_module_classes diff --git a/test/test_fusibility_predictor.py b/test/test_fusibility_predictor.py index 5346e84..4fcd21d 100644 --- a/test/test_fusibility_predictor.py +++ b/test/test_fusibility_predictor.py @@ -1,12 +1,19 @@ import torch + def fusibility_of(epilogue_func) -> bool: """Check if an epilogue function is fusible.""" from tst.torch_ap.torch_ap_trace import torch_ap_trace - from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import DemoMatmulEpilogueReplacerPass + from tst.torch_ap.concrete_pass.demo_matmul_epilogue_replacer_pass import ( + DemoMatmulEpilogueReplacerPass, + ) from tst.torch_ap.trivial_ops_folder_pass import TrivialOpsFolderPass - from tst.torch_ap.concrete_pass.matmul_epilogue_util import get_matmul_epilogue_arg_name_to_is_mm_out - from tst.torch_ap.concrete_pass.matmul_epilogue_extractor_pass import MatmulEpilogueExtractorPass + from tst.torch_ap.concrete_pass.matmul_epilogue_util import ( + get_matmul_epilogue_arg_name_to_is_mm_out, + ) + from tst.torch_ap.concrete_pass.matmul_epilogue_extractor_pass import ( + MatmulEpilogueExtractorPass, + ) from tst.torch_ap.concrete_pass.matmul_epilogue_fusibility_predictor import ( MatmuEpilogueFusibilityPredicator, ) diff --git a/tst/fn.py b/tst/fn.py index ff908f0..9e652e0 100644 --- a/tst/fn.py +++ b/tst/fn.py @@ -1,4 +1,3 @@ -from functools import wraps import tst diff --git a/tst/torch_ap/access_topo_pass_validator/op_convertion_rule_validator.py b/tst/torch_ap/access_topo_pass_validator/op_convertion_rule_validator.py index 8aefac0..67fc0cb 100644 --- a/tst/torch_ap/access_topo_pass_validator/op_convertion_rule_validator.py +++ b/tst/torch_ap/access_topo_pass_validator/op_convertion_rule_validator.py @@ -1,6 +1,6 @@ import torch import torch.fx as fx -from typing import Any, Union, List +from typing import Any # --- Core Functor Implementation --- diff --git a/tst/torch_ap/concrete_pass/demo_matmul_epilogue_replacer_pass.py b/tst/torch_ap/concrete_pass/demo_matmul_epilogue_replacer_pass.py index d710ac1..cd36142 100644 --- a/tst/torch_ap/concrete_pass/demo_matmul_epilogue_replacer_pass.py +++ b/tst/torch_ap/concrete_pass/demo_matmul_epilogue_replacer_pass.py @@ -1,5 +1,4 @@ import torch -import torch.fx as fx import inspect import string from typing import Any diff --git a/tst/torch_ap/concrete_pass/down_spider_inserter_pass.py b/tst/torch_ap/concrete_pass/down_spider_inserter_pass.py index 32fb022..dbac78f 100644 --- a/tst/torch_ap/concrete_pass/down_spider_inserter_pass.py +++ b/tst/torch_ap/concrete_pass/down_spider_inserter_pass.py @@ -1,4 +1,3 @@ -import torch import torch.fx as fx from typing import List, Tuple, Callable from tst.torch_ap.spider import down_spider @@ -63,7 +62,7 @@ def get_topo_hash(g): # --- Inline Logic --- for gm, input_idx in gms_with_pos: - inserter = DownSpiderInserter(input_idx) + inserter = DownSpiderInserterPass(input_idx) inserter(gm) diff --git a/tst/torch_ap/concrete_pass/matmul_epilogue_extractor_pass.py b/tst/torch_ap/concrete_pass/matmul_epilogue_extractor_pass.py index d62074d..076448b 100644 --- a/tst/torch_ap/concrete_pass/matmul_epilogue_extractor_pass.py +++ b/tst/torch_ap/concrete_pass/matmul_epilogue_extractor_pass.py @@ -2,11 +2,9 @@ import torch.fx as fx from tst.torch_ap.ap_pass import ApPass, PassResult -from tst.torch_ap.match_replace_util import MatchContext class SimpleTracer(fx.Tracer): - def __init__(self, leaf_module_classes): super().__init__() self.leaf_module_classes = leaf_module_classes @@ -46,7 +44,6 @@ def __call__(self, target: fx.GraphModule) -> PassResult: if __name__ == "__main__": - # Target: Matmul -> MatmulEpilogue (call_module) class TargetModel(torch.nn.Module): def __init__(self): diff --git a/tst/torch_ap/concrete_pass/matmul_epilogue_fusibility_predictor.py b/tst/torch_ap/concrete_pass/matmul_epilogue_fusibility_predictor.py index 15b221c..56f8a21 100644 --- a/tst/torch_ap/concrete_pass/matmul_epilogue_fusibility_predictor.py +++ b/tst/torch_ap/concrete_pass/matmul_epilogue_fusibility_predictor.py @@ -39,8 +39,12 @@ class MatmuEpilogueFusibilityPredicator: def __init__( self, - config_pattern_rewriters: Callable[[List[AccessTopoRule]], List[AccessTopoRule]], - config_pattern_removers: Optional[Callable[[List[ConfirmPattern]], List[ConfirmPattern]]] = None, + config_pattern_rewriters: Callable[ + [List[AccessTopoRule]], List[AccessTopoRule] + ], + config_pattern_removers: Optional[ + Callable[[List[ConfirmPattern]], List[ConfirmPattern]] + ] = None, ): self._config_pattern_rewriters = config_pattern_rewriters self._config_pattern_remover_overrider = config_pattern_removers @@ -51,17 +55,34 @@ def config_pattern_rewriters(self) -> List[AccessTopoRule]: def _default_rewriters(self) -> List[AccessTopoRule]: return [ - AccessTopoRule("y=x**2", "y=relu(x)", lambda x: x**2, lambda x: torch.relu(x)), - AccessTopoRule("y=tanh(x)", "y=relu(x)", lambda x: torch.tanh(x), lambda x: torch.relu(x)), - AccessTopoRule("z=DS(x)+y", "z=DS(US(x,y))", lambda x, y: DS(x) + y, lambda x, y: DS(US(x, y))), - AccessTopoRule("z=y+DS(x)", "z=DS(US(x,y))", lambda x, y: y + DS(x), lambda x, y: DS(US(x, y))), - AccessTopoRule("z=relu(DS(x))", "z=DS(x)", lambda x: torch.relu(DS(x)), lambda x: DS(x)), + AccessTopoRule( + "y=x**2", "y=relu(x)", lambda x: x**2, lambda x: torch.relu(x) + ), + AccessTopoRule( + "y=tanh(x)", + "y=relu(x)", + lambda x: torch.tanh(x), + lambda x: torch.relu(x), + ), + AccessTopoRule( + "z=DS(x)+y", + "z=DS(US(x,y))", + lambda x, y: DS(x) + y, + lambda x, y: DS(US(x, y)), + ), + AccessTopoRule( + "z=y+DS(x)", + "z=DS(US(x,y))", + lambda x, y: y + DS(x), + lambda x, y: DS(US(x, y)), + ), + AccessTopoRule( + "z=relu(DS(x))", "z=DS(x)", lambda x: torch.relu(DS(x)), lambda x: DS(x) + ), ] def _default_removers( - self, - mm_epi: fx.GraphModule, - mm_out_idx: int + self, mm_epi: fx.GraphModule, mm_out_idx: int ) -> List[ConfirmPattern]: """Generate default ConfirmPatterns from epilogue structure.""" arg_list = self._get_arg_list(mm_epi) @@ -73,23 +94,27 @@ def _default_removers( removers: List[ConfirmPattern] = [] for out_idx in output_indices: - removers.append(ConfirmPattern( - f"store(DS(x, {out_idx}))", - lambda x, idx=out_idx: store(DS(x), idx) - )) + removers.append( + ConfirmPattern( + f"store(DS(x, {out_idx}))", lambda x, idx=out_idx: store(DS(x), idx) + ) + ) for arg_name in other_arg_names: - removers.append(ConfirmPattern( - f"US(x, load(y, {arg_name}))", - lambda x, y, n=arg_name: US(x, load(y, n)) - )) + removers.append( + ConfirmPattern( + f"US(x, load(y, {arg_name}))", + lambda x, y, n=arg_name: US(x, load(y, n)), + ) + ) if mm_out_arg_names: mm_arg_name = mm_out_arg_names[0] - removers.append(ConfirmPattern( - f"load(x, {mm_arg_name})", - lambda x, n=mm_arg_name: load(x, n) - )) + removers.append( + ConfirmPattern( + f"load(x, {mm_arg_name})", lambda x, n=mm_arg_name: load(x, n) + ) + ) return removers @@ -98,7 +123,10 @@ def _get_arg_list(self, gm: fx.GraphModule) -> List[tuple[str, bool]]: placeholder_names = [n.name for n in gm.graph.nodes if n.op == "placeholder"] if len(placeholder_names) == 0: return [] - return [(name, i == len(placeholder_names) - 1) for i, name in enumerate(placeholder_names)] + return [ + (name, i == len(placeholder_names) - 1) + for i, name in enumerate(placeholder_names) + ] def _get_output_indices(self, gm: fx.GraphModule) -> List[Optional[int]]: out_node = next(n for n in gm.graph.nodes if n.op == "output") @@ -107,11 +135,7 @@ def _get_output_indices(self, gm: fx.GraphModule) -> List[Optional[int]]: return list(range(len(res))) return [None] - def __call__( - self, - mm_epi: fx.GraphModule, - mm_out_as_epi_in_index: int - ) -> bool: + def __call__(self, mm_epi: fx.GraphModule, mm_out_as_epi_in_index: int) -> bool: """ Predict if the matmul epilogue is fusible. @@ -137,7 +161,9 @@ def __call__( for remover in removers: pattern_gm = torch_ap_trace(remover.pattern_func) - replacement_gm = torch_ap_trace(remover.pattern_func) # Identity replacement + replacement_gm = torch_ap_trace( + remover.pattern_func + ) # Identity replacement matches = subgraph_rewriter.replace_pattern( working_gm, pattern_gm, replacement_gm diff --git a/tst/torch_ap/concrete_pass/matmul_epilogue_util.py b/tst/torch_ap/concrete_pass/matmul_epilogue_util.py index af6bc0e..f7f7f0d 100644 --- a/tst/torch_ap/concrete_pass/matmul_epilogue_util.py +++ b/tst/torch_ap/concrete_pass/matmul_epilogue_util.py @@ -1,8 +1,7 @@ import torch import torch.fx as fx -from tst.torch_ap.ap_pass import ApPass, PassResult -from tst.torch_ap.match_replace_util import MatchContext +from tst.torch_ap.ap_pass import ApPass def get_matmul_epilogue_arg_name_to_is_mm_out(graph_module) -> list[(str, bool)]: @@ -10,7 +9,6 @@ def get_matmul_epilogue_arg_name_to_is_mm_out(graph_module) -> list[(str, bool)] class SimpleTracer(fx.Tracer): - def __init__(self, leaf_module_classes): super().__init__() self.leaf_module_classes = leaf_module_classes @@ -64,7 +62,6 @@ def _get_epilogue_placeholder_names(self, match_ctx) -> list[str]: if __name__ == "__main__": - # Target: Matmul -> MatmulEpilogue (call_module) class TargetModel(torch.nn.Module): def __init__(self): @@ -74,7 +71,6 @@ def forward(self, a, b): return torch.tanh(torch.matmul(a, b) - 2.0) t_gm = fx.GraphModule(TargetModel(), fx.Tracer().trace(TargetModel())) - from torch.fx.passes.infra.pass_manager import PassManager from tst.torch_ap.trivial_ops_folder_pass import TrivialOpsFolderPass pass_mgr = TrivialOpsFolderPass() diff --git a/tst/torch_ap/load_store_op.py b/tst/torch_ap/load_store_op.py index 921d87f..e69de29 100644 --- a/tst/torch_ap/load_store_op.py +++ b/tst/torch_ap/load_store_op.py @@ -1 +0,0 @@ -from tst.torch_ap.ops import load, store diff --git a/tst/torch_ap/ops.py b/tst/torch_ap/ops.py index 39a029c..5470968 100644 --- a/tst/torch_ap/ops.py +++ b/tst/torch_ap/ops.py @@ -1,5 +1,4 @@ import torch -import torch.fx as fx def down_spider(x): diff --git a/tst/torch_ap/outputs_as_mut_inputs_transformer.py b/tst/torch_ap/outputs_as_mut_inputs_transformer.py index 1284643..7031af8 100644 --- a/tst/torch_ap/outputs_as_mut_inputs_transformer.py +++ b/tst/torch_ap/outputs_as_mut_inputs_transformer.py @@ -25,7 +25,10 @@ def __call__( # 2. ($symbolic_output_shapes * $input_symbol_map <- $target <- $input_dtypes <- $symbolic_input_shapes <- $placeholder_nodes) # We need to capture the symbolic environment to trace symbol sources - symbolic_output_shapes, input_symbol_map = self._infer_output_shapes_with_sources( + ( + symbolic_output_shapes, + input_symbol_map, + ) = self._infer_output_shapes_with_sources( target, example_inputs, placeholder_nodes ) @@ -53,7 +56,9 @@ def __call__( self._convert_outputs_to_mut_ops(sole_sub) # Finalize the main graph output - gm_with_sub.graph.output(mut_input_nodes[0] if len(mut_input_nodes) == 1 else tuple(mut_input_nodes)) + gm_with_sub.graph.output( + mut_input_nodes[0] if len(mut_input_nodes) == 1 else tuple(mut_input_nodes) + ) gm_with_sub.recompile() return gm_with_sub @@ -75,10 +80,10 @@ def _fold_to_sole_submodule(self, target: fx.GraphModule) -> fx.GraphModule: return new_gm, placeholder_nodes def _extract_input_symbols( - self, - fake_input: torch.Tensor, - ph_node: fx.Node, - symbol_map: Dict[Any, Tuple[fx.Node, int]] + self, + fake_input: torch.Tensor, + ph_node: fx.Node, + symbol_map: Dict[Any, Tuple[fx.Node, int]], ): """Helper to map SymInts from a single input to their source node and index.""" for i, dim in enumerate(fake_input.shape): @@ -90,7 +95,7 @@ def _infer_output_shapes_with_sources( self, target: fx.GraphModule, example_inputs: List[torch.Tensor], - placeholders: List[fx.Node] + placeholders: List[fx.Node], ) -> Tuple[List[List[torch.SymInt]], Dict[Any, Tuple[fx.Node, int]]]: """ Infer output shapes and build a map from SymInt to (placeholder_node, dim_index). @@ -101,22 +106,25 @@ def _infer_output_shapes_with_sources( input_symbol_map = {} with FakeTensorMode(shape_env=ShapeEnv()) as mode: fake_inputs = [mode.from_tensor(t) for t in example_inputs] - + # Build the map using a helper to reduce nesting for fake_input, ph_node in zip(fake_inputs, placeholders): self._extract_input_symbols(fake_input, ph_node, input_symbol_map) - + with torch.no_grad(): outputs = target(*fake_inputs) - - symbolic_shapes = [list(o.shape) for o in (outputs if isinstance(outputs, (tuple, list)) else [outputs])] - + + symbolic_shapes = [ + list(o.shape) + for o in (outputs if isinstance(outputs, (tuple, list)) else [outputs]) + ] + return symbolic_shapes, input_symbol_map def _bind_single_shape( - self, - shape: List[torch.SymInt], - input_symbol_map: Dict[Any, Tuple[fx.Node, int]] + self, + shape: List[torch.SymInt], + input_symbol_map: Dict[Any, Tuple[fx.Node, int]], ) -> List[Union[int, Tuple[fx.Node, int]]]: """Helper to bind a single shape's SymInts to their sources.""" recipe = [] @@ -131,21 +139,24 @@ def _bind_single_shape( return recipe def _bind_symbols_to_nodes( - self, - symbolic_shapes: List[List[torch.SymInt]], - input_symbol_map: Dict[Any, Tuple[fx.Node, int]] + self, + symbolic_shapes: List[List[torch.SymInt]], + input_symbol_map: Dict[Any, Tuple[fx.Node, int]], ) -> List[List[Union[int, Tuple[fx.Node, int]]]]: """ 3. ($runnable_output_shapes <- $symbolic_output_shapes <- $input_symbol_map) Uses helpers to reduce nesting while finding the EXACT source for each SymInt. """ - return [self._bind_single_shape(shape, input_symbol_map) for shape in symbolic_shapes] + return [ + self._bind_single_shape(shape, input_symbol_map) + for shape in symbolic_shapes + ] def _materialize_shape( - self, - gm: fx.GraphModule, + self, + gm: fx.GraphModule, recipe: List[Union[int, Tuple[fx.Node, int]]], - cache: Dict[fx.Node, fx.Node] + cache: Dict[fx.Node, fx.Node], ) -> Tuple[Any, ...]: """Helper to convert a shape recipe into a tuple of FX nodes/ints.""" actual_shape_nodes = [] @@ -155,20 +166,24 @@ def _materialize_shape( elif isinstance(item, tuple): source_node, dim_idx = item if source_node not in cache: - cache[source_node] = gm.graph.call_function(getattr, args=(source_node, "shape")) + cache[source_node] = gm.graph.call_function( + getattr, args=(source_node, "shape") + ) shape_attr = cache[source_node] - dim_val = gm.graph.call_function(operator.getitem, args=(shape_attr, dim_idx)) + dim_val = gm.graph.call_function( + operator.getitem, args=(shape_attr, dim_idx) + ) actual_shape_nodes.append(dim_val) else: actual_shape_nodes.append(item) return tuple(actual_shape_nodes) def _insert_empty_nodes( - self, - gm: fx.GraphModule, - runnable_shapes: List[List[Union[int, Tuple[fx.Node, int]]]], + self, + gm: fx.GraphModule, + runnable_shapes: List[List[Union[int, Tuple[fx.Node, int]]]], dtypes: List[torch.dtype], - placeholders: List[fx.Node] + placeholders: List[fx.Node], ) -> List[fx.Node]: """ 4. ($inserted_mut_input_nodes <- $gm_with_sub <- $runnable_output_shapes) @@ -176,21 +191,21 @@ def _insert_empty_nodes( """ insert_point = next(n for n in gm.graph.nodes if n.op != "placeholder") inserted_empty_nodes = [] - shape_node_cache = {} # Cache for getattr(node, 'shape') - + shape_node_cache = {} # Cache for getattr(node, 'shape') + with gm.graph.inserting_before(insert_point): for recipe in runnable_shapes: # Materialize the shape recipe into actual FX nodes actual_shape = self._materialize_shape(gm, recipe, shape_node_cache) - + empty_node = gm.graph.call_function( - torch.empty, - args=(actual_shape,), - kwargs={"dtype": dtypes[0]} + torch.empty, args=(actual_shape,), kwargs={"dtype": dtypes[0]} ) inserted_empty_nodes.append(empty_node) - sub_node = next(n for n in gm.graph.nodes if n.op == "call_module" and n.target == "sub") + sub_node = next( + n for n in gm.graph.nodes if n.op == "call_module" and n.target == "sub" + ) sub_node.args = (*placeholders, *inserted_empty_nodes) gm.recompile() return inserted_empty_nodes @@ -250,9 +265,11 @@ def run_dynamic_test(name, model, example_inputs, dynamic_test_inputs): correct = torch.allclose(out, ref) - print(f"✅ Input shapes {shapes} " - f"→ Output shape {tuple(out.shape)} " - f"| Correct: {correct}") + print( + f"✅ Input shapes {shapes} " + f"→ Output shape {tuple(out.shape)} " + f"| Correct: {correct}" + ) except Exception as e: print(f"❌ Input shapes {shapes} FAILED") @@ -260,15 +277,17 @@ def run_dynamic_test(name, model, example_inputs, dynamic_test_inputs): def test_main(): - class AddModel(torch.nn.Module): - def forward(self, x, y): return x + y + def forward(self, x, y): + return x + y class MulModel(torch.nn.Module): - def forward(self, x, y): return x * y + def forward(self, x, y): + return x * y class MatMulModel(torch.nn.Module): - def forward(self, x, y): return torch.matmul(x, y) + def forward(self, x, y): + return torch.matmul(x, y) # ------------------------------------------------- # Scenario 1: AddModel @@ -283,9 +302,9 @@ def forward(self, x, y): return torch.matmul(x, y) dynamic_test_inputs=[ (torch.randn(128, 64), torch.randn(128, 64)), # original (torch.randn(256, 64), torch.randn(256, 64)), # change batch - (torch.randn(32, 64), torch.randn(32, 64)), # smaller batch - (torch.randn(128, 128), torch.randn(128, 128)) # change feature dim - ] + (torch.randn(32, 64), torch.randn(32, 64)), # smaller batch + (torch.randn(128, 128), torch.randn(128, 128)), # change feature dim + ], ) # ------------------------------------------------- @@ -302,7 +321,7 @@ def forward(self, x, y): return torch.matmul(x, y) (torch.randn(256, 256), torch.randn(256, 256)), (torch.randn(128, 256), torch.randn(128, 256)), (torch.randn(256, 128), torch.randn(256, 128)), - ] + ], ) # ------------------------------------------------- @@ -317,9 +336,9 @@ def forward(self, x, y): return torch.matmul(x, y) ], dynamic_test_inputs=[ (torch.randn(128, 64), torch.randn(64, 32)), - (torch.randn(256, 64), torch.randn(64, 32)), # change batch - (torch.randn(128, 128), torch.randn(128, 16)), # change inner dims - ] + (torch.randn(256, 64), torch.randn(64, 32)), # change batch + (torch.randn(128, 128), torch.randn(128, 16)), # change inner dims + ], ) diff --git a/tst/torch_ap/store_op_inserter_pass.py b/tst/torch_ap/store_op_inserter_pass.py index e074b78..fe7a614 100644 --- a/tst/torch_ap/store_op_inserter_pass.py +++ b/tst/torch_ap/store_op_inserter_pass.py @@ -1,6 +1,6 @@ import torch import torch.fx as fx -from typing import List, Union +from typing import Union from tst.torch_ap.load_store_op import store from torch.fx.passes.infra.pass_manager import PassResult from tst.torch_ap.torch_ap_trace import torch_ap_trace diff --git a/tst/torch_ap/struct_matcher.py b/tst/torch_ap/struct_matcher.py index 3813000..8e2a2e7 100644 --- a/tst/torch_ap/struct_matcher.py +++ b/tst/torch_ap/struct_matcher.py @@ -1,6 +1,5 @@ -import torch import torch.fx as fx -from typing import Dict, List, Set, Union, Any, Tuple, Optional, Generator +from typing import Dict, List, Set, Union, Tuple # Viba type-aliases diff --git a/tst/torch_ap/struct_pattern_replacer.py b/tst/torch_ap/struct_pattern_replacer.py index 4a102f5..c02b537 100644 --- a/tst/torch_ap/struct_pattern_replacer.py +++ b/tst/torch_ap/struct_pattern_replacer.py @@ -1,7 +1,6 @@ import torch import torch.fx as fx -from typing import Dict, List, Set, Union, Tuple, Optional, Callable -from dataclasses import dataclass +from typing import Dict, List, Set, Tuple, Callable # ImportFrom sources (Logical Reference) from tst.torch_ap.struct_matcher import StructMatcher @@ -114,7 +113,7 @@ def _replace_subgraph( def test_main(): # Setup shared patterns p_add = fx.symbolic_trace(lambda x: torch.add(x, 1.0)) - repl_mul = lambda ctx: fx.symbolic_trace(lambda x: torch.mul(x, 10.0)) + repl_mul = lambda ctx: fx.symbolic_trace(lambda x: torch.mul(x, 10.0)) # noqa replacer = StructPatternReplacer(p_add, repl_mul) # 1. Multi-match (Standard) @@ -202,7 +201,7 @@ def forward(self, x): # 10. Multi-Output pattern (Tuple) p10 = fx.symbolic_trace(lambda x: (x + 1.0, x + 2.0)) t10 = fx.symbolic_trace(lambda x: (x + 1.0, x + 2.0, x + 3.0)) - repl10 = lambda ctx: fx.symbolic_trace(lambda x: (x * 10.0, x * 20.0)) + repl10 = lambda ctx: fx.symbolic_trace(lambda x: (x * 10.0, x * 20.0)) # noqa assert StructPatternReplacer(p10, repl10)(t10).modified print("Test 10: Multi-Output Pattern Passed") diff --git a/tst/torch_ap/trivial_ops_folder_pass.py b/tst/torch_ap/trivial_ops_folder_pass.py index ac4b3e8..ee7430b 100644 --- a/tst/torch_ap/trivial_ops_folder_pass.py +++ b/tst/torch_ap/trivial_ops_folder_pass.py @@ -46,7 +46,6 @@ def __call__(self, gm: fx.GraphModule) -> PassResult: def main(): import torch - import operator # Scenario: Multi-input model class M(torch.nn.Module):