.gitignore (3 additions, 1 deletion)

@@ -172,4 +172,6 @@ mase-trainer/
 test-trainer/
 
 # DiffLogic: tutorial files
-docs/tutorials/difflogic/data-mnist/
+docs/tutorials/difflogic/data-mnist/
+
+tiny-imagenet-200/
pyproject.toml (9 additions, 2 deletions)

@@ -5,9 +5,11 @@ description = "Machine-Learning Accelerator System Exploration Tools"
 readme = "README.md"
 requires-python = ">=3.11.9"
 
+# PyTorch stack is pinned to match ExecuTorch 1.0.x on PyPI (torch>=2.9,<2.10).
+# torchvision 0.24.1 requires torch==2.9.1. For export tooling: pip install -e ".[executorch]"
 dependencies = [
-    "torch==2.6",
-    "torchvision",
+    "torch==2.9.1",
+    "torchvision==0.24.1",
     "onnx",
     "black",
     "toml",
@@ -81,6 +83,11 @@ dependencies = [
     "mase-triton>=0.0.6.post4; platform_machine == 'x86_64' and sys_platform == 'linux'",
 ]
 
+[project.optional-dependencies]
+executorch = [
+    "executorch==1.0.1",
+]
+
 [build-system]
 requires = ["setuptools"]
 build-backend = "setuptools.build_meta"
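
The new optional dependency keeps ExecuTorch out of the default install. For context, a minimal sketch of the export flow this extra enables, assuming ExecuTorch 1.0's executorch.exir.to_edge API; the model and file name below are placeholders, not code from this PR:

    # Sketch: torch.export -> ExecuTorch .pte, assuming the [executorch] extra
    # is installed (pip install -e ".[executorch]").
    import torch
    from executorch.exir import to_edge

    model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 16),)

    exported = torch.export.export(model, example_inputs)  # ExportedProgram
    program = to_edge(exported).to_executorch()             # lower to ExecuTorch
    with open("model.pte", "wb") as f:
        f.write(program.buffer)                             # flatbuffer payload

Pinning torch==2.9.1 and torchvision==0.24.1 as an exact pair avoids pip resolving mismatched wheels when the extra pulls in executorch==1.0.1.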
src/chop/passes/graph/transforms/pruning/prune.py (9 additions, 1 deletion)

@@ -1,4 +1,5 @@
 import torch
+from torch._subclasses.fake_tensor import DataDependentOutputException
 
 from chop.tools import get_logger
 
@@ -43,7 +44,14 @@ def sparsify_input(module, args):
             f"{module.__class__.__name__} takes more than 1 argument at inference, the current sparsiy_input pre forward hook only allows one!"
         )
     x = args[0]
-    mask = a_rank_fn(x, info, a_sparsity)
+    # torch.export / fake-tensor tracing: L1/random activation masks use ops like
+    # torch.quantile that are data-dependent and cannot appear in an ExportedProgram.
+    # Weight pruning (parametrization) stays in the graph; dynamic activation sparsity
+    # is skipped for export only — the deployed .pte has dense activations.
+    try:
+        mask = a_rank_fn(x, info, a_sparsity)
+    except DataDependentOutputException:
+        return x
     module.activation_mask = mask
     # it seems like the output of this can be a non-tuple thing??
     return x * mask
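
To illustrate the change, here is a self-contained sketch (assumed names, heavily simplified from the real sparsify_input) of how a pre-forward hook with this fallback behaves: eagerly the mask is applied, while torch.export traces the hook under fake tensors, which is where DataDependentOutputException can surface:

    import torch
    from torch._subclasses.fake_tensor import DataDependentOutputException

    def sparsify_input_sketch(module, args):
        # Simplified stand-in for the real hook: one positional input only.
        x = args[0]
        try:
            threshold = torch.quantile(x.abs().flatten(), 0.5)  # data-dependent
            mask = x.abs() > threshold
        except DataDependentOutputException:
            return x  # export path: leave activations dense
        return x * mask  # non-tuple returns are wrapped into a tuple by nn.Module

    linear = torch.nn.Linear(8, 4)
    linear.register_forward_pre_hook(sparsify_input_sketch)
    out = linear(torch.randn(2, 8))                          # eager: mask applied
    ep = torch.export.export(linear, (torch.randn(2, 8),))   # traced: fallback path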
src/chop/passes/graph/transforms/pruning/pruning_methods.py (39 additions, 0 deletions)

@@ -22,6 +22,37 @@
 """
 
 
+def _under_fake_or_export_trace(tensor: torch.Tensor) -> bool:
+    """Avoid data-dependent ops (quantile, rand) when PT2 fake mode is active so they are
+    not embedded in the graph; export may succeed while to_edge decomps would raise."""
+    try:
+        if torch._guards.detect_fake_mode() is not None:
+            return True
+    except Exception:
+        pass
+    try:
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        if isinstance(tensor, FakeTensor):
+            return True
+    except Exception:
+        pass
+    try:
+        from torch.fx.experimental.proxy_tensor import _ProxyTensor
+
+        if isinstance(tensor, _ProxyTensor):
+            return True
+    except Exception:
+        pass
+    if type(tensor).__name__ == "FunctionalTensor":
+        return True
+    return False
+
+
+def _dense_bool_mask_like(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.ones(tensor.shape, dtype=torch.bool, device=tensor.device)
+
+
 def random(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     """set sparsity percentage of values
     in the mask to False (i.e. 0) randomly
@@ -34,6 +65,8 @@ def random(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     :return: a random sparsity mask generated based on the sparsity value
     :rtype: torch.Tensor
     """
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     mask = torch.ones(tensor.size(), dtype=torch.bool, device=tensor.device)
     mask[torch.rand(tensor.size()) < sparsity] = False
     return mask
@@ -52,12 +85,16 @@ def l1(tensor: torch.Tensor, info: dict, sparsity: float) -> torch.Tensor:
     :return: a sparsity mask
     :rtype: torch.Tensor
     """
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     threshold = torch.quantile(tensor.abs().flatten(), sparsity)
     mask = (tensor.abs() > threshold).to(torch.bool).to(tensor.device)
     return mask
 
 
 def global_weight_l1(tensor: torch.Tensor, info: dict, sparsity: float):
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     tensors = [v["weight_value"] for _, v in info.items() if v is not None]
     flattened_tensors = [t.abs().flatten() for t in tensors]
     threshold = torch.quantile(torch.cat(flattened_tensors, dim=0), sparsity)
@@ -66,6 +103,8 @@ def global_weight_l1(tensor: torch.Tensor, info: dict, sparsity: float):
 
 
 def global_activation_l1(tensor: torch.Tensor, info: dict, sparsity: float):
+    if _under_fake_or_export_trace(tensor):
+        return _dense_bool_mask_like(tensor)
     tensors = [v["activation_value"] for _, v in info.items() if v is not None]
     flattened_tensors = [t.abs().flatten() for t in tensors]
     threshold = torch.quantile(torch.cat(flattened_tensors, dim=0), sparsity)
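
A quick check of the detection helper's first two branches (an illustrative sketch, not part of the PR; FakeTensorMode is an internal torch API, so treat the exact behavior as an assumption):

    import torch
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    fake_mode = FakeTensorMode()
    fake = fake_mode.from_tensor(torch.randn(4, 4))
    assert isinstance(fake, FakeTensor)                       # branch 2: fake input

    with fake_mode:
        assert torch._guards.detect_fake_mode() is not None   # branch 1: active mode

Under either condition the ranking functions return the all-True mask from _dense_bool_mask_like, so the traced graph multiplies by ones and the exported program keeps dense activations, matching the fallback in prune.py.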