Tether

Tether is a Triton-powered framework for training and deploying Spiking Transformers and deep Spiking Neural Networks (SNNs).

It sidesteps the non-differentiability of discrete spikes with modular surrogate gradients, implemented as high-performance fused Triton kernels.

Key Features

  • High-Performance Neurons:
    • LIF (Leaky Integrate-and-Fire): Standard spiking neuron with fused Triton kernels.
    • ALIF (Adaptive LIF): Neurons with adaptive thresholds for better temporal dynamics.
    • PLIF (Parametric LIF): Neurons with learnable, per-channel decay and threshold parameters.
  • Modular Surrogate Gradients: Choose from Arctan, Sigmoid, or FastSigmoid to train your SNNs effectively.
  • Linear Spike-Driven Attention: Eliminates the $O(N^2)$ Softmax bottleneck, allowing for massive context windows with significantly lower energy per inference.
  • Data Utilities: SpikingDatasetWrapper and encoding functions (rate_encoding, latency_encoding) to convert static datasets to spike trains.
  • Triton-Powered: Leverages OpenAI's Triton language to write custom GPU kernels, enabling massive speedups (60x+) over vanilla PyTorch.

Installation

This project is managed with uv.

uv sync

Or install dependencies manually:

pip install torch triton numpy

Usage

Using PLIF with Sigmoid Surrogate

import torch
from tether import PLIF, Sigmoid

# Create a Parametric LIF layer with Sigmoid surrogate
# Decay and threshold are learnable vectors per neuron
layer = PLIF(
    n_neurons=128, 
    init_decay=0.9, 
    surrogate=Sigmoid(alpha=4.0)
).cuda()

# Input sequence: (Time, Batch, Neurons)
x = torch.randn(32, 16, 128).cuda()
spikes = layer(x)
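
The same pattern should apply to the other neurons and surrogates listed under Key Features. The snippet below is a minimal sketch of swapping in LIF with the Arctan surrogate: the LIF and Arctan names come from the feature list, but their constructor arguments here are assumptions mirroring the PLIF example above, not confirmed API.

import torch
from tether import LIF, Arctan

# Hypothetical: a plain LIF layer with the Arctan surrogate, assuming the
# same constructor pattern as the PLIF example above
layer = LIF(
    n_neurons=128,
    surrogate=Arctan(alpha=2.0)
).cuda()

x = torch.randn(32, 16, 128).cuda()  # (Time, Batch, Neurons)
spikes = layer(x)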

Training a Spiking Language Model

The train_stories.py script demonstrates training a Spiking-LLM on the TinyShakespeare dataset.

python train_stories.py

Data Encoding

from tether.data import SpikingDatasetWrapper, rate_encoding
from torchvision.datasets import MNIST

# Wrap MNIST to output spike trains
spiking_mnist = SpikingDatasetWrapper(
    MNIST(root="./data", download=True, train=True),
    encode_fn=lambda x: rate_encoding(x, n_steps=10)
)
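
As a follow-up sketch, the wrapped dataset should drop straight into a standard PyTorch DataLoader, and latency_encoding (listed alongside rate_encoding under Key Features) can be swapped in as the encoder; its n_steps argument below is an assumption based on the rate_encoding call above, not confirmed API.

from torch.utils.data import DataLoader
from tether.data import SpikingDatasetWrapper, latency_encoding
from torchvision.datasets import MNIST

# Standard PyTorch DataLoader over the spike-encoded dataset defined above
loader = DataLoader(spiking_mnist, batch_size=64, shuffle=True)

# Same wrapper, but with latency encoding; the signature is assumed to
# mirror rate_encoding(x, n_steps=...)
latency_mnist = SpikingDatasetWrapper(
    MNIST(root="./data", download=True, train=True),
    encode_fn=lambda x: latency_encoding(x, n_steps=10)
)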

Architecture

  • tether.kernels: Custom Triton kernels for LIF, ALIF, and PLIF.
  • tether.functional: PyTorch autograd functions wrapping the Triton kernels.
  • tether.nn: Neural network modules including LIF, ALIF, PLIF, and SpikingSelfAttention (composed in the sketch after this list).
  • tether.data: Utilities for spike encoding and dataset wrapping.
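
These pieces compose like ordinary PyTorch modules. The sketch below is a rough illustration only: the tether.nn import path, the SpikingSelfAttention constructor arguments, and the LIF signature are all assumptions, not confirmed API.

import torch.nn as nn
from tether.nn import SpikingSelfAttention, LIF

class SpikingBlock(nn.Module):
    # Hypothetical transformer-style block: spike-driven attention followed
    # by a linear projection and an LIF spiking nonlinearity
    def __init__(self, dim=128, heads=4):
        super().__init__()
        self.attn = SpikingSelfAttention(dim, heads)  # assumed signature
        self.proj = nn.Linear(dim, dim)
        self.neuron = LIF(n_neurons=dim)              # assumed signature

    def forward(self, x):  # x: (Time, Batch, dim)
        x = self.attn(x)
        return self.neuron(self.proj(x))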

License

Apache-2.0
