Stippler/fast-iot


Fast IoT - ReLU Network Linearization & Caching

Efficient inference for ReLU neural networks through hypersphere caching and piecewise-linear approximation

A PyTorch framework for analyzing, visualizing, and optimizing ReLU neural networks. The core innovation is hypersphere caching: exploiting the piecewise-linear nature of ReLU networks to cache and reuse linear approximations within validity regions, so that queries falling inside a cached region skip the full forward pass.


Overview

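ReLU neural networks are piecewise-linear functions. Around any point in input space that does not lie on a neuron boundary, the network is exactly an affine map. That linear region is a convex polytope, and the cache covers part of it with a hypersphere centered at the query point. This repository provides: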

  1. ReluMLP: Flexible ReLU MLP with layer-wise linearization capabilities
  2. HypersphereCache: Intelligent caching system with O(log n) spatial lookups
  3. Visualization tools: Decision boundaries, linear regions, and activation patterns
  4. IoT classification: Real-world application on human activity recognition datasets

Key Features

  • 🎯 Linearization at Points: Collapse entire ReLU networks into single linear transformations
  • Hypersphere Caching: O(log n) cache lookups using BallTree spatial indexing
  • 🔍 Decision Boundary Visualization: See neuron activation patterns and contours
  • 📊 Linear Region Analysis: Understand where networks behave linearly
  • 🚀 GPU Acceleration: Full CUDA support for training and inference
  • 📈 IoT Classification: Application to sensor-based activity recognition

Core Concepts

Piecewise-Linear Nature of ReLU Networks

A ReLU network with L layers and H hidden units per layer divides input space into at most 2^(L×H) linear regions, one for each possible on/off pattern of its L×H hidden neurons. Within each region, the network is exactly:

f(x) = W·x + b

where W and b are determined by which neurons are active (on/off pattern).
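For a network with two hidden layers, the collapse can be written out explicitly. The notation below (layer weights W_i, biases b_i, and diagonal 0/1 masks D_i that encode the activation pattern at x) is introduced here for illustration and is not taken from the repository's code.

```latex
\begin{aligned}
z_1 &= W_1 x + b_1, & D_1 &= \operatorname{diag}\big(\mathbb{1}[z_1 > 0]\big),\\
z_2 &= W_2 D_1 z_1 + b_2, & D_2 &= \operatorname{diag}\big(\mathbb{1}[z_2 > 0]\big),\\
f(x) &= W_3 D_2 z_2 + b_3 \\
     &= \underbrace{W_3 D_2 W_2 D_1 W_1}_{W}\,x
      + \underbrace{W_3 D_2 W_2 D_1 b_1 + W_3 D_2 b_2 + b_3}_{b}.
\end{aligned}
```

As long as x stays in a region where D_1 and D_2 do not change, W and b are constant, which is what makes them cacheable.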

Hypersphere Caching

For a query point x:

  1. Linearize: Compute W, b, and validity radius r at x
  2. Cache: Store (x, r, W, b) for future queries
  3. Reuse: For new point x' where ||x' - x|| < r, use cached W·x' + b

Result: Avoid expensive forward passes through the entire network!
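The three steps can be illustrated on a toy problem without any library. This is a minimal sketch, not the repository's HypersphereCache: it assumes a 1-D "network" f(x) = relu(x) + relu(-x) = |x| whose only kink is at 0, and the names linearize and NaiveSphereCache are hypothetical.

```python
# Toy 1-D "network": f(x) = relu(x) + relu(-x) = |x|, with a single kink at 0.
def linearize(x):
    """Return (W, b, r) such that f(x') = W * x' + b whenever |x' - x| < r."""
    W = 1.0 if x > 0 else -1.0   # slope of the active linear piece
    b = 0.0
    r = abs(x)                   # distance from x to the kink at 0
    return W, b, r


class NaiveSphereCache:
    """Linear-scan cache of (center, radius, W, b) entries (illustrative only)."""

    def __init__(self):
        self.entries = []
        self.hits = 0
        self.misses = 0

    def __call__(self, x):
        # Reuse: a cached ball that strictly contains x gives the exact output.
        for center, r, W, b in self.entries:
            if abs(x - center) < r:
                self.hits += 1
                return W * x + b
        # Linearize and cache: no ball covers x, so create a new entry.
        W, b, r = linearize(x)
        self.entries.append((x, r, W, b))
        self.misses += 1
        return W * x + b


cache = NaiveSphereCache()
results = [cache(x) for x in (2.0, 2.5, -1.0, -0.5)]
print(results, cache.hits, cache.misses)
```

The queries 2.5 and -0.5 fall inside the balls created by 2.0 and -1.0, so they are answered from the cached slope and intercept. The linear scan over entries is the part that a spatial index is meant to replace.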


Repository Structure

fast-iot/
├── fast_iot/
│   ├── model.py                   # ReluMLP with linearization
│   ├── hypersphere_cache.py       # Hypersphere caching system
│   ├── iot_classification.py      # Activity recognition training
│   ├── dataset_loaders.py         # UCI HAR, WISDM, MHEALTH loaders
│   ├── eval.py                    # Model evaluation utilities
│   ├── grid_search.py             # Hyperparameter optimization
│   └── run_all_datasets.py        # Multi-dataset benchmarks
├── example.ipynb                  # Interactive demo & visualization
├── data/                          # Datasets (auto-downloaded)
├── results/                       # Training outputs & plots
└── requirements.txt

Quick Start

Installation

pip install -r requirements.txt

Core requirements: PyTorch, NumPy, Matplotlib, and Scikit-learn. The IoT scripts additionally use Pandas and Seaborn (see Requirements).

Basic Usage

1. Interactive Demo

Open example.ipynb to see:

  • Training a 2D circle classifier
  • Visualizing decision boundaries and neuron activations
  • Sampling paths with cached linear approximations
  • Hypersphere visualization with validity regions

2. Python API

import torch
from fast_iot.model import ReluMLP
from fast_iot.hypersphere_cache import HypersphereCache

# Create a model (2D input → 1D output for classification)
model = ReluMLP(input_dim=2, hidden_dim=3, num_layers=3, output_dim=1)
model.eval()

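# Create cache with BallTree indexing (falls back to CPU when CUDA is unavailable)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cache = HypersphereCache(model, device=device, rebuild_threshold=10)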

# Query points
point1 = torch.tensor([0.5, 0.5])
point2 = torch.tensor([0.51, 0.49])  # Nearby point

# First query: cache miss (computes linearization)
output1 = cache(point1)

# Second query: cache hit (reuses linearization)
output2 = cache(point2)

# Check statistics
stats = cache.get_stats()
print(f"Hit rate: {stats['hit_rate']*100:.1f}%")
print(f"Cache size: {stats['cache_size']} regions")

3. Linearization Details

# Get full linearization information
point = torch.tensor([0.5, 0.5])
output, W, b, idx, is_cached = cache(point, return_linearization=True)

print(f"Output: {output}")
print(f"Linear weights W: {W}")
print(f"Linear bias b: {b}")
print(f"From cache? {is_cached}")

# Or use model directly
radii, W, b, outputs = model.linearize_at_point(point)
print(f"Validity radii: {[r.min().item() for r in radii]}")

API Reference

ReluMLP

class ReluMLP(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=3, num_layers=3, output_dim=1)

Key Methods:

  • forward(x): Standard forward pass
  • linearize_at_point(x): Collapse network to linear form at point x
    • Returns: (radii, final_W, final_b, outputs)
    • radii: List of validity radii for each hidden layer
    • final_W, final_b: Collapsed linear transformation
    • outputs: Pre-activation values at each layer
  • eval_activations(x): Get pre-ReLU activations for all layers

Example:

model = ReluMLP(input_dim=10, hidden_dim=64, num_layers=4, output_dim=5)
point = torch.randn(10)
radii, W, b, outputs = model.linearize_at_point(point)
# Now: model(point) ≈ W @ point + b (within validity radius)

HypersphereCache

class HypersphereCache:
    def __init__(self, model, device=None, rebuild_threshold=10)

Key Methods:

  • __call__(x, return_linearization=False): Query cache
    • If x is in a cached region: return cached result (cache hit)
    • Otherwise: compute new linearization and cache it (cache miss)
  • get_stats(): Get cache statistics (hits, misses, hit rate)
  • get_circles(): Access all cached regions
  • clear_cache(): Reset cache

Caching Strategy:

  1. O(1) lookup: Check last accessed region
  2. O(k) lookup: Check recently added regions (not yet in BallTree)
  3. O(log n) lookup: BallTree spatial search over indexed regions
  4. Rebuild: Periodically rebuild BallTree when enough new regions accumulate

Parameters:

  • rebuild_threshold=1: Rebuild BallTree after every insertion (every region is indexed immediately; highest rebuild cost)
  • rebuild_threshold=10: Rebuild every 10 insertions (balanced, default)
  • rebuild_threshold=50: Rebuild every 50 insertions (fewest rebuilds; recent regions are scanned linearly until the next rebuild)
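The trade-off between these settings can be made concrete with a small simulation. This is a sketch under an assumption: new regions wait in a pending buffer, and the index is rebuilt as soon as the buffer holds rebuild_threshold regions. The helper simulate_rebuilds is hypothetical and does not touch the real cache.

```python
def simulate_rebuilds(num_insertions, rebuild_threshold):
    """Count index rebuilds and the largest unindexed buffer seen at query time.

    Assumes (hypothetically) that the index is rebuilt as soon as the pending
    buffer reaches `rebuild_threshold` regions.
    """
    pending = 0
    rebuilds = 0
    max_pending = 0
    for _ in range(num_insertions):
        pending += 1
        if pending >= rebuild_threshold:
            rebuilds += 1
            pending = 0
        # Buffer size that a query issued right after this insertion would scan.
        max_pending = max(max_pending, pending)
    return rebuilds, max_pending


summary = {}
for threshold in (1, 10, 50):
    summary[threshold] = simulate_rebuilds(100, threshold)
    rebuilds, max_pending = summary[threshold]
    print(f"threshold={threshold:>2}: {rebuilds} rebuilds, "
          f"up to {max_pending} regions scanned linearly")
```

Under this assumption, a larger threshold trades rebuild work for a longer linear scan over recent regions.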

Visualization

The example.ipynb notebook demonstrates:

1. Decision Boundary Visualization

plot_cell_sdf(model, resolution=300)

Shows:

  • Network output heatmap (SDF values)
  • Neuron activation boundaries (where neurons turn on/off)
  • Decision boundary (classification threshold)

2. Linear Region Visualization

Sample a path through input space and visualize:

  • Cached circular regions (where linear approximation is valid)
  • Decision boundary vs. linearized approximation
  • Cache hit/miss points along the path
  • Neuron contours from all layers

3. Neuron-Level Analysis

Individual plots for each neuron showing:

  • Output after ReLU (heatmap)
  • Pre-ReLU contour (decision boundary)
  • Previous layer contours (context)

IoT Classification Application

Datasets

The framework includes loaders for three activity recognition datasets:

Dataset  | Activities | Sensors       | Samples  | Features
---------|------------|---------------|----------|---------
UCI HAR  | 6          | Accel + Gyro  | 10,299   | 561
WISDM    | 5          | Accelerometer | Variable | 15
MHEALTH  | 7          | Multi-sensor  | Variable | 45

Activities: Walking, Jogging, Sitting, Standing, Stairs, Lying down, etc.

Training

cd fast_iot

# Train on single dataset
python iot_classification.py --dataset uci_har

# Train on all datasets
python run_all_datasets.py

# Hyperparameter search
python grid_search.py

Expected Performance

Dataset  | Test Accuracy | Training Time (GPU) | Parameters
---------|---------------|---------------------|-----------
UCI HAR  | 94-96%        | ~30-60s             | ~75K
WISDM    | 92-95%        | ~20-40s             | ~12K
MHEALTH  | 94-97%        | ~40-80s             | ~16K

Implementation Details

Linearization Algorithm

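For a point x, the linearization works as follows:

  1. Compute activation pattern: Which neurons are on (pre-ReLU > 0)?
  2. Collapse layers: Multiply weight matrices, zeroing out the neurons that are inactive at x
  3. Calculate validity radius: Distance from x to the nearest neuron boundary, i.e. the minimum over hidden neurons of
    radius = |pre_activation| / ||weights||
    For layers beyond the first, this is a distance in input space only if weights is the neuron's row of the collapsed input-to-neuron matrix.
  4. Return linear transform: f(x') = W·x' + b (up to floating-point error) for ||x' - x|| < radius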
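The four steps above can be sketched in NumPy. This is a stand-in under stated assumptions, not the repository's linearize_at_point: the function names linearize_numpy and forward, the random weights, and the layer sizes are all illustrative.

```python
import numpy as np


def linearize_numpy(weights, biases, x):
    """Collapse a ReLU MLP at x and return (W, b, radius).

    Every layer except the last is followed by ReLU.
    """
    W_eff = np.eye(x.shape[0])          # running linear map: input -> layer input
    b_eff = np.zeros(x.shape[0])
    radius = np.inf
    for W, b in zip(weights[:-1], biases[:-1]):
        W_pre = W @ W_eff               # pre-activation as a linear function of x
        b_pre = W @ b_eff + b
        z = W_pre @ x + b_pre
        # Distance from x to each neuron's hyperplane {x': w . x' + c = 0}.
        norms = np.linalg.norm(W_pre, axis=1)
        dists = np.abs(z) / np.maximum(norms, 1e-12)
        radius = min(radius, dists.min())
        mask = (z > 0).astype(x.dtype)  # activation pattern: 1 = on, 0 = off
        W_eff = mask[:, None] * W_pre   # zero out rows of inactive neurons
        b_eff = mask * b_pre
    return weights[-1] @ W_eff, weights[-1] @ b_eff + biases[-1], radius


def forward(weights, biases, x):
    """Reference forward pass used for comparison."""
    h = x
    for W, b in zip(weights[:-1], biases[:-1]):
        h = np.maximum(W @ h + b, 0.0)
    return weights[-1] @ h + biases[-1]


rng = np.random.default_rng(0)
sizes = [2, 3, 3, 1]  # 2-D input, two hidden layers of 3 units, 1-D output
weights = [rng.normal(size=(o, i)) for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=o) for o in sizes[1:]]

x = np.array([0.5, 0.5])
W, b, r = linearize_numpy(weights, biases, x)

# A point strictly inside the ball should reproduce the full forward pass.
direction = rng.normal(size=2)
x_near = x + 0.9 * r * direction / np.linalg.norm(direction)
print(np.allclose(W @ x_near + b, forward(weights, biases, x_near)))
```

Taking the minimum distance over all hidden layers is what keeps the ball inside a single linear region: no neuron in any layer can change state before the nearest boundary is reached.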

BallTree Caching Strategy

The cache uses a three-tier lookup, followed by a miss fallback (pseudocode):

# Tier 1: O(1) - Check last accessed region
if ||x - last_center|| < last_radius:
    return cached_result

# Tier 2: O(k) - Check recent additions not yet in BallTree
for (center, radius) in new_regions:
    if ||x - center|| < radius:
        return cached_result

# Tier 3: O(log n) - BallTree spatial query
# max_radius must be at least the largest indexed radius so no containing region is skipped
candidates = ball_tree.query_radius(x, max_radius)
for idx in candidates:
    if ||x - center[idx]|| < radius[idx]:
        return cached_result

# Miss: compute new linearization
compute_and_cache(x)

This ensures:

  • ✅ Correctness: Never misses a cached region
  • ✅ Speed: O(log n) for most queries after warm-up
  • ✅ Flexibility: Tunable rebuild frequency
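A runnable version of the tiered lookup might look like the sketch below. It uses scikit-learn's BallTree, but the class TieredSphereIndex, its attribute names, and the choice to treat a newly added region as the "last accessed" one are assumptions made for illustration and may differ from hypersphere_cache.py. It only locates a containing region; storing W and b alongside each region is left out.

```python
import numpy as np
from sklearn.neighbors import BallTree


class TieredSphereIndex:
    """Find a cached ball containing a query point (illustrative sketch)."""

    def __init__(self, rebuild_threshold=10):
        self.centers, self.radii = [], []   # all regions, in insertion order
        self.rebuild_threshold = rebuild_threshold
        self.tree = None                    # BallTree over the first n_indexed centers
        self.n_indexed = 0
        self.max_radius = 0.0               # largest radius among indexed regions
        self.last = None                    # index of the most recently used region

    def _contains(self, i, x):
        return np.linalg.norm(x - self.centers[i]) < self.radii[i]

    def add(self, center, radius):
        self.centers.append(np.asarray(center, dtype=float))
        self.radii.append(float(radius))
        self.last = len(self.centers) - 1
        if len(self.centers) - self.n_indexed >= self.rebuild_threshold:
            self.tree = BallTree(np.stack(self.centers))
            self.n_indexed = len(self.centers)
            self.max_radius = max(self.radii)

    def find(self, x):
        """Return the index of a region containing x, or None on a miss."""
        x = np.asarray(x, dtype=float)
        # Tier 1: last accessed region.
        if self.last is not None and self._contains(self.last, x):
            return self.last
        # Tier 2: regions added since the last rebuild.
        for i in range(self.n_indexed, len(self.centers)):
            if self._contains(i, x):
                self.last = i
                return i
        # Tier 3: only centers within max_radius of x can contain x.
        if self.tree is not None:
            candidates = self.tree.query_radius(x[None, :], r=self.max_radius)[0]
            for i in candidates:
                if self._contains(int(i), x):
                    self.last = int(i)
                    return self.last
        return None


index = TieredSphereIndex(rebuild_threshold=2)
index.add([0.0, 0.0], 1.0)
index.add([5.0, 5.0], 0.5)   # second insertion triggers a BallTree rebuild
index.add([9.0, 0.0], 2.0)   # stays in the linear-scan buffer
found = [index.find([0.2, 0.1]), index.find([8.0, 0.5]), index.find([3.0, 3.0])]
print(found)
```

Querying the tree with the largest indexed radius returns a superset of the regions that can contain x, and each candidate is then checked against its own radius. This is why the lookup does not skip a containing region, at the price of more candidates when radii vary widely.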

Performance Analysis

Cache Efficiency

From example.ipynb path sampling (88 points):

Total circles created: 31
Cache hits: 57
Cache misses: 31
Hit rate: 64.8%

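Speedup: Of the 88 queries, 31 (35%) required a full linearization and the other 57 were served from cached linear maps. This figure counts avoided network evaluations only; it does not include cache lookup time.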
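The percentages follow directly from the hit and miss counts reported above, as this short check shows:

```python
hits, misses = 57, 31            # counts reported for the example.ipynb path
queries = hits + misses          # 88 sampled points

hit_rate = hits / queries
linearization_fraction = misses / queries

print(f"Hit rate: {hit_rate:.1%}")
print(f"Queries needing a linearization: {linearization_fraction:.1%}")
```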

Scalability

Cache Size    | Lookup Time | Memory
--------------|-------------|-------
10 regions    | ~0.01ms     | <1MB
100 regions   | ~0.05ms     | ~5MB
1000 regions  | ~0.15ms     | ~50MB
10000 regions | ~0.5ms      | ~500MB

These figures assume 2D input, 1D output, and 3 layers with 3 hidden units each.


Advanced Usage

Custom Training

import torch
import torch.nn as nn
from fast_iot.model import ReluMLP

# Define custom architecture
model = ReluMLP(
    input_dim=100,      # High-dimensional input
    hidden_dim=128,     # Larger hidden layers
    num_layers=5,       # Deeper network
    output_dim=10       # Multi-class classification
)

# Custom initialization for better linearization
with torch.no_grad():
    for layer in model.layers:
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)

# Train with your data
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# ... training loop ...

Cache Configuration

# Conservative: rebuild every insertion
cache = HypersphereCache(model, rebuild_threshold=1)

# Balanced: rebuild every 10 insertions (default)
cache = HypersphereCache(model, rebuild_threshold=10)

# Aggressive: rebuild every 50 insertions (for dense sampling)
cache = HypersphereCache(model, rebuild_threshold=50)

GPU Acceleration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
cache = HypersphereCache(model, device=device)

# All cache operations run on the selected device (GPU when available)
points = torch.randn(1000, 2, device=device)
outputs = [cache(p) for p in points]

Research Applications

This framework supports research on:

1. Neural Network Efficiency

  • Inference optimization: Reduce computational cost via caching
  • Sparse computation: Identify which regions are frequently accessed
  • Model compression: Analyze which neurons contribute to decision boundaries

2. Interpretability

  • Decision boundary analysis: Visualize how networks partition input space
  • Linear region structure: Understand network complexity
  • Neuron activation patterns: See which features activate which neurons

3. Theoretical Analysis

  • Region counting: How many linear regions does a network create?
  • Boundary smoothness: How often do regions change along paths?
  • Approximation quality: How large are validity radii in practice?
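The first two questions can be explored empirically by recording activation patterns. The sketch below uses a small random NumPy network rather than ReluMLP. The helper activation_pattern, the grid bounds, and the path are illustrative choices, and a grid can only undercount the true number of regions.

```python
import numpy as np


def activation_pattern(weights, biases, x):
    """Return the on/off state of every hidden neuron at x as a tuple of bools."""
    pattern = []
    h = x
    for W, b in zip(weights, biases):
        z = W @ h + b
        pattern.extend((z > 0).tolist())
        h = np.maximum(z, 0.0)
    return tuple(pattern)


rng = np.random.default_rng(0)
sizes = [2, 3, 3]  # 2-D input, two hidden layers of 3 units (output layer omitted)
weights = [rng.normal(size=(o, i)) for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.normal(size=o) for o in sizes[1:]]

# Region counting: distinct patterns observed on a grid over [-2, 2]^2.
grid = np.linspace(-2.0, 2.0, 101)
patterns = {activation_pattern(weights, biases, np.array([u, v]))
            for u in grid for v in grid}

# Boundary crossings: pattern changes between consecutive samples on a path.
path = [np.array([t, 0.5 * t]) for t in np.linspace(-2.0, 2.0, 200)]
path_patterns = [activation_pattern(weights, biases, p) for p in path]
changes = sum(a != b for a, b in zip(path_patterns, path_patterns[1:]))

print(f"{len(patterns)} distinct patterns on the grid (upper bound 2**6 = 64)")
print(f"{changes} pattern changes along the path")
```

Each pattern change along the path corresponds to at least one region boundary, so the count gives a lower bound on how many cache entries a path of that length would need.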

Limitations & Future Work

Current Limitations

  • Memory scaling: Cache grows with number of unique regions visited
  • High dimensions: BallTree efficiency degrades in high-dimensional spaces (curse of dimensionality)
  • Dynamic networks: Cache invalidated if model weights change (no online learning support)

Planned Enhancements

  • Adaptive cache eviction (LRU, priority-based)
  • Batch query optimization (vectorized cache lookups)
  • Approximate linearization (larger radii, lower accuracy)
  • Quantization-aware caching
  • ONNX export for deployment
  • Mobile/embedded optimization
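As one possible direction for the first item, an LRU policy can be sketched with the standard library. This is not part of the current codebase: LRURegionStore and its methods are hypothetical, and the payloads are placeholders for the cached (W, b) pairs.

```python
from collections import OrderedDict


class LRURegionStore:
    """Keep at most `capacity` regions, evicting the least recently used one."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._regions = OrderedDict()   # region_id -> payload, oldest first

    def put(self, region_id, payload):
        self._regions[region_id] = payload
        self._regions.move_to_end(region_id)
        if len(self._regions) > self.capacity:
            self._regions.popitem(last=False)   # drop the least recently used

    def get(self, region_id):
        payload = self._regions.get(region_id)
        if payload is not None:
            self._regions.move_to_end(region_id)  # mark as recently used
        return payload

    def __len__(self):
        return len(self._regions)


store = LRURegionStore(capacity=2)
store.put("a", ("W_a", "b_a"))
store.put("b", ("W_b", "b_b"))
store.get("a")                    # "a" becomes the most recently used
store.put("c", ("W_c", "b_c"))    # evicts "b"
print(store.get("b"), len(store))
```

In a real cache, evicting a region would also require removing it from the spatial index, for example by folding evictions into the next scheduled rebuild.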

Citation

If you use this code in your research, please cite:

@software{fast_iot_2025,
  title = {Fast IoT: ReLU Network Linearization and Hypersphere Caching},
  year = {2025},
  author = {Anonymous},
  url = {https://github.com/yourusername/fast-iot}
}

Dataset Citations

  • UCI HAR: Anguita, D., et al. (2013). "A Public Domain Dataset for Human Activity Recognition Using Smartphones"
  • WISDM: Kwapisz, J., et al. (2011). "Activity Recognition using Cell Phone Accelerometers"
  • MHEALTH: Banos, O., et al. (2014). "mHealthDroid: A Novel Framework for Agile Development of Mobile Health Applications"

Requirements

  • Python: 3.8+ (required by PyTorch 2.0)
  • PyTorch: 2.0+
  • NumPy: 1.21+
  • Scikit-learn: 1.0+ (for BallTree, StandardScaler)
  • Matplotlib: 3.5+ (for visualization)
  • Pandas: 1.3+ (for IoT datasets)
  • Seaborn: 0.11+ (for IoT visualization)

See requirements.txt for exact versions.


Troubleshooting

Import Errors

# If you get import errors, add the repository root to sys.path before importing:
import sys
sys.path.append('/path/to/fast-iot')
from fast_iot.model import ReluMLP
from fast_iot.hypersphere_cache import HypersphereCache

CUDA Out of Memory

# Reduce batch size or fall back to CPU (model and cache must use the same device)
import torch
device = torch.device('cpu')
model = model.to(device)
cache = HypersphereCache(model, device=device)

Cache Not Hitting

  • Ensure points are actually in cached regions (check radii)
  • Lower rebuild_threshold for more frequent BallTree updates
  • Verify device consistency (all tensors on same device)
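To check the first item by hand, compare the distance from the query point to each cached center against that region's radius. The helper below is a hypothetical diagnostic working on plain (center, radius) pairs. How to extract those pairs from the cache (for example from get_circles()) depends on its return format, which is not assumed here.

```python
import math


def covering_regions(point, regions):
    """Return (index, distance, radius) for every region whose ball contains point.

    `regions` is a list of (center, radius) pairs with centers given as
    coordinate sequences of the same length as `point`.
    """
    hits = []
    for i, (center, radius) in enumerate(regions):
        d = math.dist(point, center)
        if d < radius:
            hits.append((i, d, radius))
    return hits


regions = [((0.5, 0.5), 0.05), ((1.0, 1.0), 0.3)]
print(covering_regions((0.51, 0.49), regions))   # inside the first ball
print(covering_regions((0.6, 0.6), regions))     # between the two balls: no hit
```

An empty result for a point that was expected to hit usually means the validity radius is smaller than the spacing between queries, not that the lookup is broken.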

Contributing

Contributions welcome! Areas of interest:

  • Alternative spatial data structures (KD-tree, R-tree)
  • Cache eviction policies
  • High-dimensional optimizations
  • Additional visualization tools
  • New application domains

License

MIT License. Datasets have separate licenses from original authors.


Fast IoT - Exploiting Piecewise-Linearity for Efficient Neural Network Inference

Last updated: October 2025
