diff --git a/.gitignore b/.gitignore
index 7d4676ce3..2fbd53fdf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,4 +172,6 @@ mase-trainer/
 test-trainer/
 
 # DiffLogic: tutorial files
-docs/tutorials/difflogic/data-mnist/
\ No newline at end of file
+docs/tutorials/difflogic/data-mnist/
+
+tiny-imagenet-200/
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 72d438534..3c685ae7a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,9 +5,11 @@ description = "Machine-Learning Accelerator System Exploration Tools"
 readme = "README.md"
 requires-python = ">=3.11.9"
 
+# PyTorch stack is pinned to match ExecuTorch 1.0.x on PyPI (torch>=2.9,<2.10).
+# torchvision 0.24.1 requires torch==2.9.1. For export tooling: pip install -e ".[executorch]"
 dependencies = [
-    "torch==2.6",
-    "torchvision",
+    "torch==2.9.1",
+    "torchvision==0.24.1",
     "onnx",
     "black",
     "toml",
@@ -81,6 +83,11 @@ dependencies = [
     "mase-triton>=0.0.6.post4; platform_machine == 'x86_64' and sys_platform == 'linux'",
 ]
 
+[project.optional-dependencies]
+executorch = [
+    "executorch==1.0.1",
+]
+
 [build-system]
 requires = ["setuptools"]
 build-backend = "setuptools.build_meta"
diff --git a/src/chop/passes/graph/transforms/pruning/prune.py b/src/chop/passes/graph/transforms/pruning/prune.py
index 373249706..9770d7ce8 100644
--- a/src/chop/passes/graph/transforms/pruning/prune.py
+++ b/src/chop/passes/graph/transforms/pruning/prune.py
@@ -1,4 +1,5 @@
 import torch
+from torch._subclasses.fake_tensor import DataDependentOutputException
 
 from chop.tools import get_logger
 
@@ -43,7 +44,14 @@ def sparsify_input(module, args):
             f"{module.__class__.__name__} takes more than 1 argument at inference, the current sparsiy_input pre forward hook only allows one!"
         )
     x = args[0]
-    mask = a_rank_fn(x, info, a_sparsity)
+    # torch.export / fake-tensor tracing: L1/random activation masks use ops like
+    # torch.quantile that are data-dependent and cannot appear in an ExportedProgram.
+    # Weight pruning (parametrization) stays in the graph; dynamic activation sparsity
+    # is skipped for export only — the deployed .pte has dense activations.
+    try:
+        mask = a_rank_fn(x, info, a_sparsity)
+    except DataDependentOutputException:
+        return x
     module.activation_mask = mask
     # it seems like the output of this can be a non-tuple thing??
     return x * mask
diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
index 665abc17e..63441f79d 100644
--- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py
+++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py
@@ -22,6 +22,37 @@
 """
 
 
+def _under_fake_or_export_trace(tensor: torch.Tensor) -> bool:
+    """Avoid data-dependent ops (quantile, rand) when PT2 fake mode is active so they are
+    not embedded in the graph; export may succeed while to_edge decomps would raise."""
+    try:
+        if torch._guards.detect_fake_mode() is not None:
+            return True
+    except Exception:
+        pass
+    try:
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        if isinstance(tensor, FakeTensor):
+            return True
+    except Exception:
+        pass
+    try:
+        from torch.fx.experimental.proxy_tensor import _ProxyTensor
+
+        if isinstance(tensor, _ProxyTensor):
+            return True
+    except Exception:
+        pass
+    if type(tensor).__name__ == "FunctionalTensor":
+        return True
+    return False
+
+
+def _dense_bool_mask_like(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.ones(tensor.shape, dtype=torch.bool, device=tensor.device)
+
+
 def random(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     """set sparsity percentage of values in the mask to False (i.e. 0) randomly
 
@@ -34,6 +65,8 @@ def random(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     :return: a random sparsity mask generated based on the sparsity value
     :rtype: torch.Tensor
     """
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     mask = torch.ones(tensor.size(), dtype=torch.bool, device=tensor.device)
     mask[torch.rand(tensor.size()) < sparsity] = False
     return mask
@@ -52,12 +85,16 @@ def l1(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     :return: a sparsity mask
     :rtype: torch.Tensor
     """
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     threshold = torch.quantile(tensor.abs().flatten(), sparsity)
     mask = (tensor.abs() > threshold).to(torch.bool).to(tensor.device)
     return mask
 
 
 def global_weight_l1(tensor: torch.Tensor, info: dict, sparsity: float):
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     tensors = [v["weight_value"] for _, v in info.items() if v is not None]
     flattened_tensors = [t.abs().flatten() for t in tensors]
     threshold = torch.quantile(torch.cat(flattened_tensors, dim=0), sparsity)
@@ -66,6 +103,8 @@
 
 
 def global_activation_l1(tensor: torch.Tensor, info: dict, sparsity: float):
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     tensors = [v["activation_value"] for _, v in info.items() if v is not None]
     flattened_tensors = [t.abs().flatten() for t in tensors]
     threshold = torch.quantile(torch.cat(flattened_tensors, dim=0), sparsity)
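
Usage sketch (not part of the patch): the intent of these changes is that a pruned
module can pass through torch.export and ExecuTorch's to_edge without hitting
data-dependent ops, after installing the new extra with pip install -e ".[executorch]".
The following is a minimal, hedged illustration; TinyNet, the input shape, and the
output file name are placeholders, not anything defined in this repository. The
torch.export.export / to_edge / to_executorch calls are the standard public APIs.

import torch
import torch.nn as nn
from executorch.exir import to_edge

class TinyNet(nn.Module):
    # Placeholder stand-in for a weight-pruned MASE model; any exportable
    # nn.Module follows the same path.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyNet().eval()
example_inputs = (torch.randn(1, 16),)

# Under fake-tensor tracing the pruning hooks above return dense masks (or skip
# activation masking entirely), so export succeeds and to_edge's decompositions
# see no torch.quantile / torch.rand calls.
exported = torch.export.export(model, example_inputs)
edge = to_edge(exported)
with open("tiny.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)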