Efficient inference for ReLU neural networks through hypersphere caching and piecewise-linear approximation
A PyTorch framework for analyzing, visualizing, and optimizing ReLU neural networks. The core innovation is hypersphere caching: exploiting the piecewise-linear nature of ReLU networks to cache and reuse linear approximations within validity regions, so that queries falling inside a cached region skip the full forward pass.
ReLU neural networks are piecewise-linear functions. Around any point in input space that does not lie on a neuron boundary, the network is exactly an affine map within a local region, which this framework bounds with a hypersphere centered on the point. This repository provides:
- ReluMLP: Flexible ReLU MLP with layer-wise linearization capabilities
- HypersphereCache: Intelligent caching system with O(log n) spatial lookups
- Visualization tools: Decision boundaries, linear regions, and activation patterns
- IoT classification: Real-world application on human activity recognition datasets
- 🎯 Linearization at Points: Collapse entire ReLU networks into single linear transformations
- ⚡ Hypersphere Caching: O(log n) cache lookups using BallTree spatial indexing
- 🔍 Decision Boundary Visualization: See neuron activation patterns and contours
- 📊 Linear Region Analysis: Understand where networks behave linearly
- 🚀 GPU Acceleration: Full CUDA support for training and inference
- 📈 IoT Classification: Application to sensor-based activity recognition
A ReLU network with $L$ layers and $H$ hidden units per layer divides input space into up to $2^{L \times H}$ linear regions. Within each region, the network is exactly:
f(x) = W·x + b
where W and b are determined by which neurons are active (on/off pattern).
For a query point x:
- Linearize: Compute W, b, and validity radius r at x
- Cache: Store (x, r, W, b) for future queries
- Reuse: For new point x' where ||x' - x|| < r, use cached W·x' + b
Result: Avoid expensive forward passes through the entire network!
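The linearize, cache, and reuse steps can be illustrated on a toy problem. The following is a minimal sketch in plain Python, assuming a hypothetical one-dimensional network with three hidden ReLU units; the weights, the `linearize` and `cached_forward` helpers, and the list-based cache are illustrative stand-ins, not the repository's `ReluMLP` or `HypersphereCache`.

```python
import math

# Hypothetical 1-D ReLU network: f(x) = sum_i v[i] * relu(w[i] * x + b[i]) + c
w = [1.0, -1.0, 2.0]
b = [0.0, 1.0, -3.0]
v = [1.0, 0.5, -1.0]
c = 0.25

def forward(x):
    """Full forward pass through the toy network."""
    return sum(vi * max(wi * x + bi, 0.0) for wi, bi, vi in zip(w, b, v)) + c

def linearize(x):
    """Return (slope, intercept, radius) of the linear piece containing x."""
    slope, intercept, radius = 0.0, c, math.inf
    for wi, bi, vi in zip(w, b, v):
        pre = wi * x + bi
        if pre > 0:  # an active neuron contributes its affine part
            slope += vi * wi
            intercept += vi * bi
        if wi != 0:  # distance from x to the point where this neuron flips
            radius = min(radius, abs(pre) / abs(wi))
    return slope, intercept, radius

cache = []  # entries are (center, radius, slope, intercept)
stats = {"hits": 0, "misses": 0}

def cached_forward(x):
    """Reuse a cached linear piece when x lies strictly inside its radius."""
    for center, radius, slope, intercept in cache:
        if abs(x - center) < radius:
            stats["hits"] += 1
            return slope * x + intercept
    slope, intercept, radius = linearize(x)
    cache.append((x, radius, slope, intercept))
    stats["misses"] += 1
    return slope * x + intercept

queries = [0.5, 0.6, 0.4, 1.2, 1.25, 2.0]
results = [cached_forward(x) for x in queries]
print(stats, len(cache))
```

Every cached answer in this sketch equals the full forward pass, because the radius never reaches past the nearest point where a neuron switches on or off.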
fast-iot/
├── fast_iot/
│ ├── model.py # ReluMLP with linearization
│ ├── hypersphere_cache.py # Hypersphere caching system
│ ├── iot_classification.py # Activity recognition training
│ ├── dataset_loaders.py # UCI HAR, WISDM, MHEALTH loaders
│ ├── eval.py # Model evaluation utilities
│ ├── grid_search.py # Hyperparameter optimization
│ └── run_all_datasets.py # Multi-dataset benchmarks
├── example.ipynb # Interactive demo & visualization
├── data/ # Datasets (auto-downloaded)
├── results/ # Training outputs & plots
└── requirements.txt
pip install -r requirements.txt
Requirements: PyTorch, NumPy, Matplotlib, Scikit-learn
Open example.ipynb to see:
- Training a 2D circle classifier
- Visualizing decision boundaries and neuron activations
- Sampling paths with cached linear approximations
- Hypersphere visualization with validity regions
import torch
from fast_iot.model import ReluMLP
from fast_iot.hypersphere_cache import HypersphereCache
# Create a model (2D input → 1D output for classification)
model = ReluMLP(input_dim=2, hidden_dim=3, num_layers=3, output_dim=1)
model.eval()
# Create cache with BallTree indexing
# (falls back to CPU when no GPU is available)
cache = HypersphereCache(model, device='cuda' if torch.cuda.is_available() else 'cpu', rebuild_threshold=10)
# Query points
point1 = torch.tensor([0.5, 0.5])
point2 = torch.tensor([0.51, 0.49]) # Nearby point
# First query: cache miss (computes linearization)
output1 = cache(point1)
# Second query: cache hit (reuses linearization)
output2 = cache(point2)
# Check statistics
stats = cache.get_stats()
print(f"Hit rate: {stats['hit_rate']*100:.1f}%")
print(f"Cache size: {stats['cache_size']} regions")
# Get full linearization information
point = torch.tensor([0.5, 0.5])
output, W, b, idx, is_cached = cache(point, return_linearization=True)
print(f"Output: {output}")
print(f"Linear weights W: {W}")
print(f"Linear bias b: {b}")
print(f"From cache? {is_cached}")
# Or use model directly
radii, W, b, outputs = model.linearize_at_point(point)
print(f"Validity radii: {[r.min().item() for r in radii]}")
class ReluMLP(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=3, num_layers=3, output_dim=1)
Key Methods:
- forward(x): Standard forward pass
- linearize_at_point(x): Collapse network to linear form at point x
  - Returns: (radii, final_W, final_b, outputs)
    - radii: List of validity radii for each hidden layer
    - final_W, final_b: Collapsed linear transformation
    - outputs: Pre-activation values at each layer
- eval_activations(x): Get pre-ReLU activations for all layers
Example:
model = ReluMLP(input_dim=10, hidden_dim=64, num_layers=4, output_dim=5)
point = torch.randn(10)
radii, W, b, outputs = model.linearize_at_point(point)
# Now: model(point) ≈ W @ point + b (within validity radius)
class HypersphereCache:
    def __init__(self, model, device=None, rebuild_threshold=10)
Key Methods:
- __call__(x, return_linearization=False): Query cache
  - If x is in a cached region: return cached result (cache hit)
  - Otherwise: compute new linearization and cache it (cache miss)
- get_stats(): Get cache statistics (hits, misses, hit rate)
- get_circles(): Access all cached regions
- clear_cache(): Reset cache
Caching Strategy:
- O(1) lookup: Check last accessed region
- O(k) lookup: Check recently added regions (not yet in BallTree)
- O(log n) lookup: BallTree spatial search over indexed regions
- Rebuild: Periodically rebuild BallTree when enough new regions accumulate
Parameters:
- rebuild_threshold=1: Rebuild BallTree after every insertion (safest, slower)
- rebuild_threshold=10: Rebuild every 10 insertions (balanced, default)
- rebuild_threshold=50: Rebuild every 50 insertions (least rebuild overhead, delayed indexing)
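To make the trade-off behind `rebuild_threshold` concrete, here is a small bookkeeping sketch. It assumes the index is rebuilt as soon as the number of unindexed regions reaches the threshold; the `simulate` helper is hypothetical and only counts events, it does not build a tree.

```python
def simulate(num_insertions, rebuild_threshold):
    """Count index rebuilds and the largest buffer of unindexed regions."""
    pending = 0        # regions inserted since the last rebuild
    rebuilds = 0
    max_pending = 0    # worst-case length of the linear scan a query may face
    for _ in range(num_insertions):
        pending += 1
        if pending >= rebuild_threshold:
            rebuilds += 1
            pending = 0
        max_pending = max(max_pending, pending)
    return rebuilds, max_pending

for threshold in (1, 10, 50):
    rebuilds, max_pending = simulate(1000, threshold)
    print(f"threshold={threshold:>2}: {rebuilds} rebuilds, "
          f"up to {max_pending} unindexed regions")
```

Under this assumption a larger threshold means fewer rebuilds but a longer linear scan over regions that are not yet indexed.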
The example.ipynb notebook demonstrates:
plot_cell_sdf(model, resolution=300)
Shows:
- Network output heatmap (SDF values)
- Neuron activation boundaries (where neurons turn on/off)
- Decision boundary (classification threshold)
Sample a path through input space and visualize:
- Cached circular regions (where linear approximation is valid)
- Decision boundary vs. linearized approximation
- Cache hit/miss points along the path
- Neuron contours from all layers
Individual plots for each neuron showing:
- Output after ReLU (heatmap)
- Pre-ReLU contour (decision boundary)
- Previous layer contours (context)
The framework includes loaders for three activity recognition datasets:
| Dataset | Activities | Sensors | Samples | Features |
|---|---|---|---|---|
| UCI HAR | 6 | Accel + Gyro | 10,299 | 561 |
| WISDM | 5 | Accelerometer | Variable | 15 |
| MHEALTH | 7 | Multi-sensor | Variable | 45 |
Activities: Walking, Jogging, Sitting, Standing, Stairs, Lying down, etc.
cd fast_iot
# Train on single dataset
python iot_classification.py --dataset uci_har
# Train on all datasets
python run_all_datasets.py
# Hyperparameter search
python grid_search.py

| Dataset | Test Accuracy | Training Time (GPU) | Parameters |
|---|---|---|---|
| UCI HAR | 94-96% | ~30-60s | ~75K |
| WISDM | 92-95% | ~20-40s | ~12K |
| MHEALTH | 94-97% | ~40-80s | ~16K |
For a point x, the linearization works by:
- Compute activation pattern: Which neurons are on (pre-ReLU > 0)?
- Collapse layers: Multiply weight matrices, skipping dead neurons
- Calculate validity radius: Distance to nearest neuron boundary
  radius = |pre_activation| / ||weights||
- Return linear transform: f(x') ≈ W·x' + b for ||x' - x|| < radius
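The four steps above can be sketched in NumPy for a plain fully connected ReLU network. This is a minimal illustration, not the repository's `linearize_at_point`: the `forward` and `linearize` helpers, the layer sizes, and the random weights are assumptions. For hidden layers past the first, the sketch measures each neuron's distance in input space by using the row of the collapsed weight matrix as the "weights" in the radius formula.

```python
import numpy as np

def forward(weights, biases, x):
    """Plain ReLU MLP forward pass; the last layer has no ReLU."""
    h = x
    for W_i, b_i in zip(weights[:-1], biases[:-1]):
        h = np.maximum(W_i @ h + b_i, 0.0)
    return weights[-1] @ h + biases[-1]

def linearize(weights, biases, x):
    """Collapse the network at x into (W, b, radius) with f(x') = W @ x' + b nearby."""
    A = np.eye(x.shape[0])     # running affine map from the input: A @ x + c
    c = np.zeros(x.shape[0])
    radius = np.inf
    for W_i, b_i in zip(weights[:-1], biases[:-1]):
        A = W_i @ A
        c = W_i @ c + b_i
        z = A @ x + c                      # pre-activations of this layer at x
        norms = np.linalg.norm(A, axis=1)  # input-space gradient norm per neuron
        valid = norms > 0                  # constant neurons can never flip
        if valid.any():
            radius = min(radius, float(np.min(np.abs(z[valid]) / norms[valid])))
        mask = (z > 0).astype(float)       # activation pattern: 1 = on, 0 = off
        A = mask[:, None] * A              # dead neurons drop out of the product
        c = mask * c
    return weights[-1] @ A, weights[-1] @ c + biases[-1], radius

rng = np.random.default_rng(0)
sizes = [4, 8, 8, 3]  # hypothetical layer widths
weights = [rng.normal(size=(o, i)) for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=o) for o in sizes[1:]]

x = rng.normal(size=4)
W, b, radius = linearize(weights, biases, x)

# A point strictly inside the validity ball shares the same linear map.
direction = rng.normal(size=4)
direction /= np.linalg.norm(direction)
x_near = x + 0.9 * radius * direction
print(np.allclose(forward(weights, biases, x_near), W @ x_near + b))
```

Taking the minimum over every hidden neuron keeps the ball inside a single activation pattern, which is why the collapsed map matches the network at `x_near`.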
The cache uses a three-tier lookup:
# Tier 1: O(1) - Check last accessed region
if ||x - last_center|| < last_radius:
    return cached_result
# Tier 2: O(k) - Check recent additions not yet in BallTree
for recent_region in new_regions:
    if ||x - center|| < radius:
        return cached_result
# Tier 3: O(log n) - BallTree spatial query
candidates = ball_tree.query_radius(x, max_radius)
for idx in candidates:
    if ||x - center|| < radius:
        return cached_result
# Tier 4: Cache miss - compute new linearization
compute_and_cache(x)
This ensures:
- ✅ Correctness: Never misses a cached region
- ✅ Speed: O(log n) for most queries after warm-up
- ✅ Flexibility: Tunable rebuild frequency
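The tiered lookup can be sketched as runnable code with scikit-learn's `BallTree`. This is a hedged illustration only: the `TieredRegionIndex` class and its attributes are hypothetical names, it stores just centers and radii, and the actual `HypersphereCache` may organize its tiers differently.

```python
import numpy as np
from sklearn.neighbors import BallTree

class TieredRegionIndex:
    """Hypothetical sketch of the tiered lookup; not the repository's class."""

    def __init__(self, rebuild_threshold=10):
        self.centers, self.radii = [], []
        self.rebuild_threshold = rebuild_threshold
        self.indexed = 0        # regions [0, indexed) are covered by the tree
        self.tree = None
        self.max_radius = 0.0   # largest radius among indexed regions
        self.last = None        # most recently added or matched region

    def _contains(self, i, x):
        return np.linalg.norm(x - self.centers[i]) < self.radii[i]

    def add(self, center, radius):
        self.centers.append(np.asarray(center, dtype=float))
        self.radii.append(float(radius))
        self.last = len(self.centers) - 1
        if len(self.centers) - self.indexed >= self.rebuild_threshold:
            self.tree = BallTree(np.stack(self.centers))
            self.indexed = len(self.centers)
            self.max_radius = max(self.radii)

    def find(self, x):
        """Return the index of a cached region containing x, or None on a miss."""
        x = np.asarray(x, dtype=float)
        # Tier 1: the last region used
        if self.last is not None and self._contains(self.last, x):
            return self.last
        # Tier 2: regions added since the last rebuild
        for i in range(self.indexed, len(self.centers)):
            if self._contains(i, x):
                self.last = i
                return i
        # Tier 3: any indexed region containing x has its center within max_radius
        if self.tree is not None:
            candidates = self.tree.query_radius(x.reshape(1, -1), r=self.max_radius)[0]
            for i in candidates:
                if self._contains(int(i), x):
                    self.last = int(i)
                    return self.last
        return None

index = TieredRegionIndex(rebuild_threshold=2)
index.add([0.0, 0.0], 1.0)
index.add([5.0, 5.0], 0.5)    # second insertion triggers a rebuild
index.add([10.0, 0.0], 2.0)   # stays in the unindexed buffer
print(index.find([0.2, -0.3]), index.find([10.5, 0.5]), index.find([3.0, 3.0]))
```

Querying the tree with the largest indexed radius is what keeps tier 3 from missing a region: a region can only contain `x` if its center lies within its own radius of `x`, which is never more than `max_radius`.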
From example.ipynb path sampling (88 points):
Total circles created: 31
Cache hits: 57
Cache misses: 31
Hit rate: 64.8%
Speedup: For 88 queries, only 31 full forward passes were needed (35% of the uncached count). This figure does not include the cost of the cache lookups themselves.
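The arithmetic behind these figures, plus a rough cost model, can be written out as follows. The model is an assumption made for illustration: every query pays one lookup, every miss pays one full forward pass, and the lookup-to-forward cost ratios are hypothetical values, not measurements from this repository.

```python
hits, misses = 57, 31            # figures reported above
queries = hits + misses          # 88 sampled points
hit_rate = hits / queries
forward_fraction = misses / queries

def relative_cost(miss_rate, lookup_to_forward_ratio):
    """Average cost per query relative to one uncached forward pass.

    Assumes each query pays one lookup and each miss pays one forward pass;
    the extra work of linearizing on a miss is ignored.
    """
    return lookup_to_forward_ratio + miss_rate

print(f"hit rate: {hit_rate:.1%}")
print(f"full forward passes: {forward_fraction:.1%} of queries")
for ratio in (0.0, 0.1, 0.5):    # hypothetical lookup cost / forward cost
    print(f"lookup ratio {ratio}: relative cost "
          f"{relative_cost(forward_fraction, ratio):.3f}")
```

Under this model caching pays off only while the lookup ratio plus the miss rate stays below 1, so cheap lookups matter most for small networks whose forward pass is already fast.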
| Cache Size | Lookup Time | Memory |
|---|---|---|
| 10 regions | ~0.01ms | <1MB |
| 100 regions | ~0.05ms | ~5MB |
| 1000 regions | ~0.15ms | ~50MB |
| 10000 regions | ~0.5ms | ~500MB |
Assumes 2D input, 1D output, 3 layers with 3 hidden units each
import torch
import torch.nn as nn
from fast_iot.model import ReluMLP
# Define custom architecture
model = ReluMLP(
    input_dim=100,   # High-dimensional input
    hidden_dim=128,  # Larger hidden layers
    num_layers=5,    # Deeper network
    output_dim=10    # Multi-class classification
)
# Custom initialization for better linearization
with torch.no_grad():
    for layer in model.layers:
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)
# Train with your data
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# ... training loop ...

# Conservative: rebuild every insertion
cache = HypersphereCache(model, rebuild_threshold=1)
# Balanced: rebuild every 10 insertions (default)
cache = HypersphereCache(model, rebuild_threshold=10)
# Aggressive: rebuild every 50 insertions (for dense sampling)
cache = HypersphereCache(model, rebuild_threshold=50)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
cache = HypersphereCache(model, device=device)
# All operations run on the selected device (GPU when available)
points = torch.randn(1000, 2).to(device)
outputs = [cache(p) for p in points]

This framework supports research on:
- Inference optimization: Reduce computational cost via caching
- Sparse computation: Identify which regions are frequently accessed
- Model compression: Analyze which neurons contribute to decision boundaries
- Decision boundary analysis: Visualize how networks partition input space
- Linear region structure: Understand network complexity
- Neuron activation patterns: See which features activate which neurons
- Region counting: How many linear regions does a network create?
- Boundary smoothness: How often do regions change along paths?
- Approximation quality: How large are validity radii in practice?
- Memory scaling: Cache grows with number of unique regions visited
- High dimensions: BallTree efficiency degrades in high-dimensional spaces (curse of dimensionality)
- Dynamic networks: Cache invalidated if model weights change (no online learning support)
- Adaptive cache eviction (LRU, priority-based)
- Batch query optimization (vectorized cache lookups)
- Approximate linearization (larger radii, lower accuracy)
- Quantization-aware caching
- ONNX export for deployment
- Mobile/embedded optimization
If you use this code in your research, please cite:
@software{fast_iot_2025,
title = {Fast IoT: ReLU Network Linearization and Hypersphere Caching},
year = {2025},
author = {Anonymous},
url = {https://github.com/yourusername/fast-iot}
}
- UCI HAR: Anguita, D., et al. (2013). "A Public Domain Dataset for Human Activity Recognition Using Smartphones"
- WISDM: Kwapisz, J., et al. (2011). "Activity Recognition using Cell Phone Accelerometers"
- MHEALTH: Banos, O., et al. (2014). "mHealthDroid: A Novel Framework for Agile Development of Mobile Health Applications"
- Python: 3.7+
- PyTorch: 2.0+
- NumPy: 1.21+
- Scikit-learn: 1.0+ (for BallTree, StandardScaler)
- Matplotlib: 3.5+ (for visualization)
- Pandas: 1.3+ (for IoT datasets)
- Seaborn: 0.11+ (for IoT visualization)
See requirements.txt for exact versions.
# If you get import errors, add the repository root to the Python path:
import sys
sys.path.append('/path/to/fast-iot')
from fast_iot.model import ReluMLP
from fast_iot.hypersphere_cache import HypersphereCache

# Reduce batch size or use CPU
device = torch.device('cpu')
cache = HypersphereCache(model, device=device)

- Ensure points are actually in cached regions (check radii)
- Lower rebuild_threshold for more frequent BallTree updates
- Verify device consistency (all tensors on same device)
Contributions welcome! Areas of interest:
- Alternative spatial data structures (KD-tree, R-tree)
- Cache eviction policies
- High-dimensional optimizations
- Additional visualization tools
- New application domains
MIT License. Datasets have separate licenses from original authors.
Fast IoT - Exploiting Piecewise-Linearity for Efficient Neural Network Inference
Last updated: October 2025