Open-source implementation of Kimi Delta Attention — the hybrid linear-attention architecture from arXiv:2510.26692.
- Overview
- Key Features
- Architecture
- Technology Stack
- Setup & Installation
- Usage
- Core Capabilities
- Benchmarks
- Project Roadmap
- Implementation Status
- Development
- Contributing
- Citation
- License
kimi-linear implements Kimi Delta Attention (KDA) — the linear attention module at the heart of the Kimi Linear hybrid architecture described in arXiv:2510.26692.
KDA extends Gated DeltaNet with channel-wise (fine-grained) forget gating, replacing head-level scalar gates with per-channel vector gates.
The hybrid deployment stacks three KDA layers for every one full MLA (Multi-Head Latent Attention) layer (3:1 ratio), achieving up to 6× decoding throughput and 75% KV-cache reduction at 1M-token contexts versus a full-attention baseline.
> [!IMPORTANT]
> This implementation targets the KDA module only. The full Kimi Linear model (including MLA, MoE feed-forward, and training infra) is not included. This is a research reference implementation, not a production serving stack.
| Icon | Feature | Description | Impact | Status |
|---|---|---|---|---|
| 🎛️ | Fine-Grained Gating | Per-channel forget gate α_t via low-rank W_down/W_up + sigmoid | Selective memory control | ✅ Stable |
| 🔄 | DPLR Transition | Two-step diagonal + rank-1 correction; O(K·V) not O(K²·V) | 2× faster than general DPLR | ✅ Stable |
| 🧠 | State Management | Constant-memory RNN state with checkpointing and NaN guards | O(1) per-token memory | ✅ Stable |
| 🔬 | Short Conv on K | Depthwise causal conv (kernel=4) on key projection (§3.1) | Local context in keys | ✅ Stable |
| 📐 | RMSNorm Output | Per-head RMSNorm on retrieved content before output gate | Numerical stability | ✅ Stable |
| 🚪 | Output Gate | Low-rank sigmoid gate on the retrieved content | Expressiveness | ✅ Stable |
| 🔢 | Mixed Precision | FP32 / FP16 / BF16 via PyTorch dtype | Flexibility | ✅ Stable |
| 🧪 | Test Suite | 45 unit + integration tests across all components | Coverage | ✅ Stable |
Performance from original paper (at 1M-token context):
- KV cache reduced by up to 75% vs full MLA baseline
- Decoding throughput up to 6× faster than full attention
- Prefill speedup of 2.3–2.9× at 512k–1M token range
- Quality maintained: MMLU-Pro (4k) score ≥ 51.0, competitive with the full-attention baseline
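The 75% figure follows directly from the 3:1 layer ratio: only the MLA layers keep a KV cache, while KDA layers carry a constant-size state. A back-of-the-envelope sketch (layer count, KV width, and dtype here are illustrative assumptions, not the paper's exact model config):

```python
# Hypothetical 32-layer model at 1M-token context, bf16 (2 bytes per element)
layers, tokens, kv_dim, bytes_per = 32, 1_000_000, 512, 2

# Full attention: every layer stores K and V for every token
full_attn_cache = layers * tokens * 2 * kv_dim * bytes_per

# 3:1 hybrid: only the 1-in-4 MLA layers keep a per-token KV cache
hybrid_cache = (layers // 4) * tokens * 2 * kv_dim * bytes_per

print(f"full: {full_attn_cache / 1e9:.1f} GB, hybrid: {hybrid_cache / 1e9:.1f} GB")
print(f"reduction: {1 - hybrid_cache / full_attn_cache:.0%}")  # 75%
```

MLA's latent compression shrinks the remaining cache further; the 75% above comes from the layer ratio alone.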
> [!TIP]
> For inference-heavy workloads, keep `use_short_conv=True` (default) and `use_output_gate=True` (default) for the best accuracy; disable them only when comparing against the minimal ablation baseline.
At each token position, the state update is computed as a two-step transition followed by a KV write (no materialised K×K matrix):

1. Diagonal decay — $S' = \text{Diag}(\alpha_t) \cdot S_{t-1}$ (element-wise broadcast)
2. Rank-1 delta correction — $S_t = S' - \beta_t k_t (k_t^\top S')$ (two einsum calls)
3. KV write — $S_t \mathrel{+}= \beta_t k_t v_t^\top$
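The three steps above can be sketched for a single token with plain tensor ops (a minimal illustration, not the repo's batched implementation; shapes and the `kda_step` name are assumptions):

```python
import torch

def kda_step(S, q, k, v, alpha, beta):
    """One illustrative KDA recurrence step for a single token (no batch dim).

    S: (H, K, V) state; q, k: (H, K); v: (H, V);
    alpha: (H, K) per-channel decay in (0, 1); beta: (H,) write strength in (0, 1).
    """
    S = alpha.unsqueeze(-1) * S                                      # 1. diagonal decay
    r = torch.einsum("hk,hkv->hv", k, S)                             #    k_t^T S'
    S = S - beta[:, None, None] * k.unsqueeze(-1) * r.unsqueeze(1)   # 2. rank-1 correction
    S = S + beta[:, None, None] * k.unsqueeze(-1) * v.unsqueeze(1)   # 3. KV write
    o = torch.einsum("hkv,hk->hv", S, q)                             #    retrieval
    return o, S

H, K, V = 8, 64, 64
S = torch.zeros(H, K, V)
k = torch.nn.functional.normalize(torch.randn(H, K), dim=-1)
o, S = kda_step(S, torch.randn(H, K), k, torch.randn(H, V), torch.rand(H, K), torch.rand(H))
```

With α = 1, β = 1, and an L2-normalised key, writing (k, v) into an empty state and querying with q = k returns exactly v — the delta-rule associative-memory property.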
```mermaid
flowchart TD
    X["Input x ∈ ℝ^(B×T×D)"] --> QP["q_proj\n(Linear)"]
    X --> KP["k_proj\n(Linear)"]
    X --> VP["v_proj\n(Linear)"]
    X --> BP["β_proj\n(Linear → sigmoid)"]
    X --> FG["FineGrainedGating\n(low-rank bottleneck)"]
    X --> OG["OutputGate\ndown/up projections"]
    KP --> SC["ShortConv\n(depthwise, kernel=4, SiLU)"]
    SC --> KN["L2-Normalise keys"]
    KN --> DPLR["DPLRTransition\nDiag decay + rank-1 correction"]
    VP --> DPLR
    FG -- "α_t gates" --> DPLR
    BP -- "β_t scalar" --> DPLR
    DPLR --> SM["StateManager\nS_t ∈ ℝ^(B×H×K×V)"]
    SM --> RET["Retrieval\neinsum S_t^T q_t"]
    QP --> RET
    RET --> NORM["RMSNorm(head_dim)"]
    NORM --> MUL["⊙ Output Gate\nσ(W_up W_down x)"]
    OG --> MUL
    MUL --> OP["out_proj\n(Linear)"]
    OP --> Y["Output y ∈ ℝ^(B×T×D)"]
```
```mermaid
sequenceDiagram
    participant App
    participant KDALayer
    participant FGGating
    participant DPLRTrans
    participant StateManager
    App->>KDALayer: forward(x, state=None)
    KDALayer->>StateManager: initialize_state(B)
    StateManager-->>KDALayer: S_0 = zeros(B,H,K,V)
    loop For each token t in [0, T)
        KDALayer->>FGGating: forward(x[:,t,:])
        FGGating-->>KDALayer: α_t ∈ (0,1)^(B×H×K)
        KDALayer->>DPLRTrans: forward(S_{t-1}, k_t, v_t, α_t, β_t)
        DPLRTrans-->>KDALayer: S_t updated
        KDALayer->>KDALayer: o_t = RMSNorm(einsum(S_t, q_t))
    end
    KDALayer->>App: output (B,T,D), final_state (B,H,K,V)
```
```mermaid
flowchart LR
    I[Input Tokens] --> L1[KDA Layer]
    L1 --> L2[KDA Layer]
    L2 --> L3[KDA Layer]
    L3 --> L4[MLA Layer]
    L4 --> L5[KDA Layer]
    L5 --> L6[KDA Layer]
    L6 --> L7[KDA Layer]
    L7 --> L8[MLA Layer]
    L8 --> O[Output]
    style L4 fill:#e8c84e,color:#000
    style L8 fill:#e8c84e,color:#000
```
| Component | File | Purpose | Time | Space |
|---|---|---|---|---|
| `FineGrainedGating` | `src/kda/gating.py` | Per-channel α_t via W_down/W_up + sigmoid | O(B·T·D·rank) | O(D·rank) |
| `DPLRTransition` | `src/kda/dplr.py` | Two-step state transition; Gershgorin stability check | O(B·H·K·V) | O(B·H·K·V) |
| `StateManager` | `src/kda/state_manager.py` | S_t lifecycle: init, update, checkpoint, OOM guard | O(B·H·K·V) | O(B·H·K·V) |
| `KDALayer` | `src/kda/kda_layer.py` | Full KDA forward: projections → conv → gate → DPLR → norm → gate → out | O(B·T·H·K·V) | O(B·H·K·V) |
Each technology in this project was chosen deliberately. This section explains what each one is, what it does inside kimi-linear, and why it was selected over the alternatives.
What they are. Triton is an open-source GPU programming language and compiler developed by OpenAI. It lets you write GPU kernels in Python-like syntax that compile down to PTX (NVIDIA's low-level virtual instruction set, which the driver then translates to native machine code), achieving performance close to hand-written CUDA without requiring C++ expertise. CUDA kernels are the lower-level equivalent — raw C++ functions that run in parallel across thousands of GPU threads.
What they do here.
The triton_kernels.py module acts as a dispatch layer. When flash-linear-attention (FLA) is installed, chunk_kda_forward and fused_recurrent_kda_forward automatically route to FLA's production Triton kernels (fla.ops.kda.chunk_kda, fla.ops.kda.fused_recurrent_kda). These kernels:
- Fuse the WY representation, UT-transform state update, and intra-chunk attention into a single GPU kernel launch — eliminating the Python loop overhead and the intermediate tensor allocations that the pure-PyTorch path incurs.
- Use shared memory tiling to keep frequently-reused data (keys, gates, the running state) on-chip rather than round-tripping through HBM (GPU RAM), which is the dominant bottleneck at the sizes used in practice.
- Support bfloat16 and float16 with fused operations that reduce the number of memory reads/writes by keeping intermediate results in registers.
When FLA is not installed (CPU-only machine, no CUDA, CI environment), the module falls back transparently to the pure-PyTorch token loop — correct output, just slower.
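The dispatch pattern the text describes can be sketched as follows. The `fla.ops.kda.chunk_kda` import path is the one named above, but its exact call signature here is illustrative, and the fallback body is a placeholder standing in for the pure-PyTorch token loop:

```python
import torch

try:
    from fla.ops.kda import chunk_kda              # production Triton kernels
    HAS_TRITON = True
except ImportError:                                # CPU-only box, CI, FLA absent
    HAS_TRITON = False

def _pytorch_token_loop(q, k, v, g, beta):
    # Placeholder for the pure-PyTorch recurrence (one decayed delta-rule
    # update per token); here it only allocates correctly shaped outputs.
    B, H, T, D = q.shape
    return torch.zeros(B, H, T, D), torch.zeros(B, H, D, D)

def chunk_kda_forward(q, k, v, g, beta, chunk_size=64):
    """Route to the fused Triton path on CUDA, else the slow-but-correct loop."""
    if HAS_TRITON and q.is_cuda:
        return chunk_kda(q, k, v, g, beta)         # fused single-kernel path
    return _pytorch_token_loop(q, k, v, g, beta)
```

The caller never sees which path ran; both return the same `(output, final_state)` contract.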
Why Triton over alternatives.
| Option | Problem |
|---|---|
| Pure PyTorch (eager) | O(T) Python loop; each step materialises intermediate tensors; ~10–50× slower than fused kernel at T=2048 |
| `torch.compile` + eager | Reduces overhead but cannot tile across the recurrent state dimension; still memory-bandwidth-bound |
| Hand-written CUDA | Correct but requires C++ build toolchain, CUDA SDK, and per-architecture tuning; maintenance burden is high |
| Triton via FLA | Python-authored, auto-tuned launch configs, works on any sm70+ GPU (V100, A100, H100), active maintenance from fla-org |
The dispatch pattern — try FLA, fall back to PyTorch — means the codebase works on a laptop CPU for development and gets production speed on a GPU cluster without code changes.
What it is.
PyTorch is the dominant deep-learning framework for research and production. It provides n-dimensional tensor arithmetic, automatic differentiation (autograd), GPU memory management, and a large ecosystem of pre-built layers (nn.Module).
What it does here.
Every layer — KDALayer, MLALayer, ChunkwiseParallelKDA, KDAVLLMAdapter — is a torch.nn.Module. The recurrent state is a plain torch.Tensor. Einsum contractions (torch.einsum) express the DPLR and WY update equations in a form that is both readable and JIT-compilable. F.scaled_dot_product_attention in MLALayer automatically dispatches to Flash Attention 2 when available.
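As a minimal illustration of the last point (shapes are arbitrary; this is the standard PyTorch API, not repo-specific code), a causal call to `F.scaled_dot_product_attention` looks like:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) layout expected by scaled_dot_product_attention
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Dispatches to Flash Attention / memory-efficient kernels when available,
# falling back to the math implementation otherwise
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```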
Why PyTorch over alternatives.
| Option | Problem |
|---|---|
| JAX | Functional-only style conflicts with the stateful recurrence design; smaller ecosystem for model serving |
| MLX (Apple) | CUDA support is a first-class requirement; MLX targets Apple Silicon |
| TensorFlow 2 | Less expressive dynamic graph for research; declining community adoption |
PyTorch ≥ 2.4 provides stable nn.RMSNorm (used in our output normalisation), and 2.6 improved torch.compile support, which motivated the ≥ 2.6 minimum.
What it is.
torch.einsum evaluates Einstein-summation expressions — a compact notation for tensor contractions where repeated indices are summed over.
What it does here. The two hottest operations in KDA — the DPLR state update and the WY rank-BT correction — are expressed as einsums:
```python
# Outer product write: k (B,H,K) × e (B,H,V) → delta_S (B,H,K,V)
state += beta * torch.einsum("bhk,bhv->bhkv", k_t, e_t)

# Query retrieval: S (B,H,K,V) × q (B,H,K) → output (B,H,V)
o_t = torch.einsum("bhkv,bhk->bhv", state, q_t)

# UT rank-BT correction: w (B,H,BT,K) × y (B,H,BT,V) → (B,H,K,V)
delta = torch.einsum("bhtk,bhtv->bhkv", w, y)
```

Why einsum over explicit matmul/bmm.
Einsum expressions map directly to cuBLAS/cuBLASLt GEMM calls after PyTorch's contraction optimiser selects the best pairwise contraction order. Explicit matmul would require manual reshapes and transposes that obscure the mathematical intent. The trade-off is that einsum is slightly harder to read at first but much easier to verify against the paper equations.
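One way to see the equivalence: the outer-product einsum is the same contraction as a batched matmul over unsqueezed operands (a quick self-contained check, not repo code):

```python
import torch

B, H, K, V = 2, 8, 64, 64
k_t = torch.randn(B, H, K)
e_t = torch.randn(B, H, V)

# Einsum form: one readable expression
via_einsum = torch.einsum("bhk,bhv->bhkv", k_t, e_t)

# Matmul form: manual unsqueezes obscure the intent
via_matmul = k_t.unsqueeze(-1) @ e_t.unsqueeze(-2)   # (B,H,K,1) @ (B,H,1,V)

print(torch.allclose(via_einsum, via_matmul))  # True
```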
What it is. Root Mean Square Layer Normalisation normalises activations by dividing by their RMS rather than subtracting the mean and dividing by the full standard deviation (as LayerNorm does).
What it does here.
Applied per-head to the retrieved content vector o_t = S_t^T q_t before the output gate. Without this, the retrieved vector's magnitude grows unboundedly as the state S_t accumulates outer products — causing the output gate to saturate and gradients to vanish.
Why RMSNorm over LayerNorm.
RMSNorm omits the mean-subtraction step. Empirically (Zhang & Sennrich 2019; used in LLaMA, Mistral, Kimi Linear) this has equal or better training stability at ~10% lower compute cost. nn.RMSNorm is native in PyTorch ≥ 2.4, removing the need for a custom kernel.
What it is.
A depthwise (grouped) 1-D convolution where each channel is filtered independently — groups=C means zero cross-channel mixing. Kernel size 4 with causal (left) padding injects a 4-token receptive field.
What it does here. Applied to the key projection before L2-normalisation (§3.1 of arXiv:2510.26692):
```python
K = self.k_conv(K.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2)
K = F.silu(K)
```

The convolution gives each key a 4-step local context window — nearby tokens can influence what gets written into the state — without breaking the O(1) recurrent complexity because the final key at each step is still a single vector.
Why depthwise Conv1d over alternatives.
A full (non-depthwise) convolution would mix key channels and increase parameters by a factor of inner_dim. A self-attention local window would reintroduce quadratic complexity for the conv pass. Depthwise Conv1d is O(T·C·kernel) — essentially free — and adds exactly the right amount of local context per the paper specification.
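The construction can be sketched end-to-end (channel count is illustrative; `padding=kernel-1` pads both sides, and trimming the right side makes it causal):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C, kernel = 512, 4
# groups=C -> each channel filtered independently (depthwise)
conv = nn.Conv1d(C, C, kernel_size=kernel, groups=C, padding=kernel - 1)

x = torch.randn(2, 128, C)                       # (B, T, C)
y = conv(x.transpose(1, 2))                      # Conv1d wants (B, C, T)
y = y.narrow(2, 0, x.shape[1]).transpose(1, 2)   # trim right side -> causal
y = F.silu(y)
print(y.shape)  # torch.Size([2, 128, 512])
```

Output position t depends only on inputs t−3 … t, so no future token leaks backwards.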
What it provides here.
- `match`/`case` structural pattern matching for clean dispatch in state-management error handling.
- `X | Y` union type syntax in annotations (`Optional[Tensor]` → `Tensor | None`).
- `from __future__ import annotations` deferred evaluation, enabling forward references without quotes.
Why 3.10 minimum over 3.8/3.9.
nn.RMSNorm requires PyTorch 2.4, which dropped 3.8 support. Pattern matching and the union syntax materially improve code readability in the dispatch and error-handling paths. 3.10 is the oldest version still receiving security patches at the time of writing.
What it is. pytest is a Python testing framework that discovers and runs test functions/classes, provides rich assertion introspection, and supports fixtures, parametrize, and plugins.
What it does here.
130 tests across 8 files verify shape contracts, numerical correctness (no NaN/Inf), gradient flow, stateful continuity, error raising, and physics (gate decay suppresses old state). Fixtures (@pytest.fixture) provide reusable layer instances without duplicating setup code. --tb=short gives compact failure output in CI.
Why pytest over unittest.
unittest requires wrapping everything in classes that inherit TestCase and uses self.assertEqual style assertions. pytest uses bare assert statements, infers test discovery automatically, and its failure messages show the actual vs. expected values without any extra boilerplate.
What it is. Docker packages an application and its entire runtime environment (OS libraries, CUDA runtime, Python, pip dependencies) into a portable container image that runs identically on any Linux host with a Docker daemon.
What it does here.
- `Dockerfile.dev` mounts the source directory at runtime (`-v $(pwd):/workspace`) so code edits are reflected immediately without rebuilding — fast iteration during development.
- `Dockerfile` (production) bakes the source in at build time, producing a self-contained image suitable for deployment on Kubernetes, RunPod, or any container orchestration platform.
- Both images pin the CUDA runtime version, eliminating the "works on my machine" class of GPU driver mismatches.
Why Docker over conda/venv-only.
A bare venv does not capture system libraries (CUDA runtime, libcudnn, NCCL). conda captures more but is slower to resolve and not the standard in production serving. Docker produces an immutable, reproducible artefact that can be pushed to a registry and deployed without any environment setup on the target machine.
What it is. flash-linear-attention is the official reference implementation of KDA (and other linear attention variants) by the Moonshot/FLA team. It provides hand-optimised Triton kernels for the KDA chunk forward/backward passes.
What it does here.
When installed, HAS_TRITON becomes True and chunk_kda_forward routes to fla.ops.kda.chunk_kda — the same kernel that powers the production Kimi Linear model. This is the fastest available path on CUDA hardware.
Why optional.
Most development and CI runs happen on CPU or without the FLA package. Making it optional means the package installs cleanly on any machine (pip install kimi-linear) and degrades gracefully to the pure-PyTorch path — no broken import, no missing .so files.
- Python ≥ 3.10
- PyTorch ≥ 2.6 (CPU or CUDA)
- CUDA ≥ 12.0 for GPU acceleration (optional)
```shell
# From PyPI (once published)
pip install kimi-linear

# From GitHub — always up to date
pip install git+https://github.com/hkevin01/kimi-linear.git

# With optional Triton kernels (requires CUDA)
pip install "git+https://github.com/hkevin01/kimi-linear.git#egg=kimi-linear[fla]"

# With vLLM deployment support
pip install "git+https://github.com/hkevin01/kimi-linear.git#egg=kimi-linear[vllm]"
```

```shell
git clone https://github.com/hkevin01/kimi-linear.git
cd kimi-linear
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"
```

```shell
# Development (mounts source, hot-reload)
docker build -f docker/Dockerfile.dev -t kimi-linear:dev .
docker run --gpus all -it -v $(pwd):/workspace kimi-linear:dev

# Production
docker build -f docker/Dockerfile -t kimi-linear:latest .
```

```shell
python -c "import kda; print(kda.__version__)"
pytest tests/ -q
# 130 passed in ~1.6s
```

> [!NOTE]
> After pip install, the package is importable as `import kda` from any project in that environment. No need to be inside the kimi-linear directory.
After `pip install kimi-linear` (or the editable install), all imports use `import kda`.
```python
import torch
from kda import KDALayer

layer = KDALayer(
    hidden_dim=512,
    num_heads=8,
    head_dim=64,
    dropout=0.0,
    max_batch_size=32,
)

x = torch.randn(4, 128, 512)  # (batch, seq_len, hidden_dim)
output, state = layer(x)
print(output.shape)  # torch.Size([4, 128, 512])
print(state.shape)   # torch.Size([4, 8, 64, 64])
```

```python
from kda import KDALayer

layer = KDALayer(
    hidden_dim=512, num_heads=8, head_dim=64,
    use_chunk_parallel=True,  # WY + UT transform algorithm
    chunk_size=64,
)
x = torch.randn(4, 512, 512)
output, state = layer(x)
```

```python
from kda import MLALayer

mla = MLALayer(
    hidden_dim=512,
    num_heads=8,
    head_dim=64,
    kv_latent_dim=64,  # KV cache compressed to this size
)
x = torch.randn(4, 128, 512)
output, c_kv = mla(x)  # c_kv: (4, 128, 64) — store this as KV cache
print(f"KV cache compression: {mla.kv_cache_compression_ratio:.1f}x")

# Generation: pass cached c_kv for subsequent tokens
next_token = torch.randn(4, 1, 512)
out_step, c_kv_new = mla(next_token, kv_cache=c_kv)
```

```python
import torch.nn as nn
from kda import KDALayer, MLALayer

class HybridBlock(nn.Module):
    """3 KDA layers followed by 1 MLA — matches Kimi Linear deployment."""

    def __init__(self, hidden_dim=512, num_heads=8, head_dim=64):
        super().__init__()
        self.kda_layers = nn.ModuleList([
            KDALayer(hidden_dim, num_heads, head_dim) for _ in range(3)
        ])
        self.mla_layer = MLALayer(hidden_dim, num_heads, head_dim, kv_latent_dim=64)
        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(4)])

    def forward(self, x, kda_states=None):
        states = []
        if kda_states is None:
            kda_states = [None] * 3
        for i, kda in enumerate(self.kda_layers):
            out, s = kda(self.norms[i](x), state=kda_states[i])
            x = x + out
            states.append(s)
        out_mla, c_kv = self.mla_layer(self.norms[3](x))
        x = x + out_mla
        return x, states, c_kv

model = HybridBlock()
x = torch.randn(2, 128, 512)
out, kda_states, c_kv = model(x)
```

```python
from kda import KDALayer

layer = KDALayer(hidden_dim=512, num_heads=8, head_dim=64)

# Process long sequence in chunks — state carries context across boundaries
state = None
for chunk in chunks:  # each chunk: (B, chunk_len, D)
    output, state = layer(chunk, state=state)
```

```python
from kda import KDALayer, KDAVLLMAdapter

kda_layer = KDALayer(hidden_dim=512, num_heads=8, head_dim=64)
adapter = KDAVLLMAdapter(
    kda_layer=kda_layer,
    num_heads=8, key_dim=64, value_dim=64,
    max_blocks=1024,
)

# Prefill (encode context)
x_context = torch.randn(2, 256, 512)
out, state = adapter.prefill(x_context, seq_ids=[0, 1])

# Autoregressive decode (single token per step)
for step in range(100):
    x_token = torch.randn(2, 1, 512)
    out, state = adapter.decode_step(x_token, seq_ids=[0, 1])

adapter.free_sequence(0)
adapter.free_sequence(1)
```

```python
import torch
from kda import chunk_kda_forward, fused_recurrent_kda_forward, HAS_TRITON

print(f"Triton/FLA kernels active: {HAS_TRITON}")

# (B, H, T, d) convention
B, H, T, D = 2, 8, 128, 64
q = torch.randn(B, H, T, D)
k = torch.nn.functional.normalize(torch.randn(B, H, T, D), dim=-1)
v = torch.randn(B, H, T, D)
g = -torch.rand(B, H, T, 1) * 0.5        # log-space gate ≤ 0
beta = torch.rand(B, H, T, 1) * 0.5 + 0.5

# Chunkwise parallel (dispatches to Triton if FLA installed)
out, final_state = chunk_kda_forward(q, k, v, g, beta, chunk_size=64)

# Fused recurrent (optimal for T=1 decode steps)
out_r, state_r = fused_recurrent_kda_forward(q[:, :, :1, :], k[:, :, :1, :],
                                             v[:, :, :1, :], g[:, :, :1, :],
                                             beta[:, :, :1, :])
```

```python
from kda import FineGrainedGating, StateManager, DPLRTransition

gating = FineGrainedGating(hidden_dim=512, num_heads=8, head_dim=64)
state_mgr = StateManager(key_dim=64, value_dim=64, num_heads=8, max_batch_size=32)
dplr = DPLRTransition(key_dim=64, value_dim=64, num_heads=8)

x = torch.randn(4, 128, 512)
gates, _ = gating(x)                              # (4, 128, 8, 64)
state = state_mgr.initialize_state(batch_size=4)  # (4, 8, 64, 64)
```

```python
from kda import KDALayer

# Without short conv
layer_no_conv = KDALayer(hidden_dim=512, num_heads=8, head_dim=64,
                         use_short_conv=False)

# Minimal baseline (no short conv, no output gate)
layer_minimal = KDALayer(hidden_dim=512, num_heads=8, head_dim=64,
                         use_short_conv=False, use_output_gate=False)
```

Traditional linear attention uses a scalar gate per head. KDA uses a K-dimensional vector gate per head, computed via a low-rank bottleneck:

$\alpha_t = \sigma(W_{\text{up}} W_{\text{down}} x_t) \in (0, 1)^{H \times K}$
This gives the model fine-grained control over which memory dimensions to retain or decay at each step — comparable expressiveness to full attention's content-adaptive routing at O(1) state cost.
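A rough sketch of such a gate (the class name, `rank` argument, and default values here are illustrative assumptions, not the repo's exact `FineGrainedGating` signature):

```python
import torch
import torch.nn as nn

class LowRankGate(nn.Module):
    """Hypothetical per-channel forget gate with a low-rank bottleneck."""

    def __init__(self, hidden_dim: int, num_heads: int, head_dim: int, rank: int = 32):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim
        self.down = nn.Linear(hidden_dim, rank, bias=False)  # D -> rank
        self.up = nn.Linear(rank, num_heads * head_dim)      # rank -> H*K

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        # One gate value per head *and* per key channel, squashed to (0, 1)
        alpha = torch.sigmoid(self.up(self.down(x)))
        return alpha.view(B, T, self.num_heads, self.head_dim)

gate = LowRankGate(hidden_dim=512, num_heads=8, head_dim=64)
alpha = gate(torch.randn(4, 128, 512))
print(alpha.shape)  # torch.Size([4, 128, 8, 64])
```

The bottleneck keeps the parameter cost at O(D·rank) rather than O(D·H·K), which is why the gate stays cheap even at per-channel granularity.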
The general DPLR update requires O(K²·V) operations. KDA exploits the structural constraint that both rank-1 vectors of the transition are tied to the (L2-normalised) key, reducing the update to O(K·V):

```text
Step 1: S' = Diag(α_t) · S_{t-1}          ← element-wise broadcast
Step 2: S  = S' − β_t · k_t · (k_t⊤ S')   ← two einsum calls
Step 3: S  = S  + β_t · k_t · v_t⊤        ← outer-product write
```
A Gershgorin spectral-radius estimate is computed on each forward pass to warn when the transition matrix approaches instability (estimated spectral radius near or above 1).
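A sketch of one way such a check can work (this is the general Gershgorin idea applied to the per-head transition matrix, not the repo's exact code): build A = (I − β k kᵀ)·Diag(α) explicitly for a small K and bound its spectral radius by the maximum absolute row sum.

```python
import torch

K = 8
alpha = torch.rand(K)                                  # per-channel decay in (0, 1)
k = torch.nn.functional.normalize(torch.randn(K), dim=0)
beta = torch.tensor(0.9)

# Effective transition on the state: S_t = (I - beta k k^T) Diag(alpha) S_{t-1} + write
A = (torch.eye(K) - beta * torch.outer(k, k)) @ torch.diag(alpha)

row_bound = A.abs().sum(dim=1).max()        # Gershgorin-style upper bound, O(K^2)
true_radius = torch.linalg.eigvals(A).abs().max()  # exact, O(K^3) — too slow in the hot path

print(bool(true_radius <= row_bound + 1e-6))  # True
```

The bound is conservative (it can fire when the true radius is still below 1), but it costs only a row-sum, which is why a heuristic like this is viable inside a forward pass.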
A depthwise causal convolution (kernel size 4) is applied to the key projection before L2-normalisation:
```python
# Causal: pad left by kernel-1, then trim right side
K = self.k_conv(K.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2)
K = F.silu(K)
```

This injects local positional context into keys without breaking the recurrent complexity.
> [!WARNING]
> Short conv requires cross-chunk key history for exact chunk-boundary equivalence. When testing sequential chunking, use `use_short_conv=False` to verify state-carry correctness independently.
After retrieval, the per-head output is normalised and gated: $o_t = \sigma(W_{\text{up}} W_{\text{down}} x_t) \odot \text{RMSNorm}(S_t^\top q_t)$.
The RMSNorm prevents magnitude explosion across deep stacks, while the output gate adds expressiveness without extra state cost.
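The state-carry property mentioned in the warning above can be demonstrated with a toy gated delta-rule loop (a self-contained stand-in, not the repo's `KDALayer`): processing a sequence in two chunks while passing the final state forward must reproduce the single-pass output exactly.

```python
import torch

def recurrent(q, k, v, alpha, beta, S=None):
    """Minimal channel-gated delta-rule recurrence over (B, H, T, D) tensors."""
    B, H, T, D = q.shape
    S = torch.zeros(B, H, D, D) if S is None else S
    outs = []
    for t in range(T):
        S = alpha[:, :, t].unsqueeze(-1) * S                       # per-channel decay
        kt, vt, bt = k[:, :, t], v[:, :, t], beta[:, :, t]
        err = vt - torch.einsum("bhk,bhkv->bhv", kt, S)            # delta-rule error
        S = S + bt.unsqueeze(-1) * kt.unsqueeze(-1) * err.unsqueeze(-2)
        outs.append(torch.einsum("bhkv,bhk->bhv", S, q[:, :, t]))
    return torch.stack(outs, dim=2), S

B, H, T, D = 1, 2, 32, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
alpha = torch.rand(B, H, T, D)
beta = torch.rand(B, H, T, 1)

full, _ = recurrent(q, k, v, alpha, beta)

# Same sequence in two chunks, carrying the state across the boundary
o1, S = recurrent(q[:, :, :16], k[:, :, :16], v[:, :, :16], alpha[:, :, :16], beta[:, :, :16])
o2, _ = recurrent(q[:, :, 16:], k[:, :, 16:], v[:, :, 16:], alpha[:, :, 16:], beta[:, :, 16:], S=S)

print(torch.allclose(full, torch.cat([o1, o2], dim=2), atol=1e-5))  # True
```

With short conv enabled the equality would break at the boundary, because the first keys of chunk 2 also depend on the last three tokens of chunk 1.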
Run the benchmark suite:
```shell
python scripts/benchmark/run_benchmarks.py
```

The benchmark script measures:
- FineGrainedGating — throughput across sequence lengths and batch sizes
- DPLRTransition — two-step state update latency per head configuration
- KDALayer — end-to-end forward pass latency across hidden dims
```shell
# Run tests with coverage
pytest --cov=src tests/ -v

# Run only integration tests
pytest tests/kda/test_integration.py -v
```

Test distribution:
```mermaid
pie title Test Coverage by Module
    "test_gating.py" : 9
    "test_dplr.py" : 9
    "test_state_manager.py" : 9
    "test_kda_layer.py" : 6
    "test_integration.py" : 12
```
```mermaid
gantt
    title kimi-linear Development Roadmap
    dateFormat YYYY-MM-DD
    section Core KDA
    FineGrainedGating        :done, g1, 2025-11-01, 2025-11-15
    DPLRTransition           :done, g2, 2025-11-15, 2025-12-01
    StateManager             :done, g3, 2025-12-01, 2025-12-15
    KDALayer (full §3.1–3.2) :done, g4, 2025-12-15, 2026-01-15
    section Quality
    Unit + Integration Tests :done, q1, 2026-01-15, 2026-02-01
    Structured Spec Comments :done, q2, 2026-02-01, 2026-03-01
    section Performance
    Chunkwise Parallelization :active, p1, 2026-03-01, 2026-06-01
    Triton / CUDA Kernels     :p2, 2026-06-01, 2026-09-01
    section Integration
    MLA Reference Module      :i1, 2026-06-01, 2026-08-01
    vLLM / SGLang Plugin      :i2, 2026-09-01, 2026-12-01
```
| Phase | Goals | Target | Status |
|---|---|---|---|
| 1 — Core KDA | Gating, DPLR, State, Layer assembly | Q4 2025 | ✅ Complete |
| 2 — Quality | 45 tests, structured spec comments, Docker | Q1 2026 | ✅ Complete |
| 3 — Performance | Chunkwise parallel, WY rep, UT transform | Q2 2026 | ✅ Complete |
| 4 — Kernels | Triton DPLR kernel dispatch (FLA fallback) | Q3 2026 | ✅ Complete |
| 5 — Full Hybrid | MLA layer, 3:1 stack, vLLM integration | Q4 2026 | ✅ Complete |
| Component | Version | Stability | Tests | Known Limitations |
|---|---|---|---|---|
| `FineGrainedGating` | 1.0 | ✅ Stable | 9 | Low-rank factorisation fixes gate rank at init; no dynamic rank |
| `DPLRTransition` | 1.0 | ✅ Stable | 9 | Eigen check is Gershgorin heuristic |
| `StateManager` | 1.0 | ✅ Stable | 9 | Pre-alloc buffer requires max_batch_size at init |
| `KDALayer` | 1.2 | ✅ Stable | 6 + 12 integration | Short conv needs chunk-carry for exact equivalence |
| `ChunkwiseParallelKDA` | 1.0 | ✅ Stable | 11 | BT must be power-of-2; caller pads T |
| WY representation | 1.0 | ✅ Stable | 5 (in chunk tests) | O(BT²) Python loop; Triton path via FLA |
| UT transform | 1.0 | ✅ Stable | 3 (in chunk tests) | Rank-BT update grows memory linearly with chunk size |
| `MLALayer` | 1.0 | ✅ Stable | 11 | Full causal attention; no sparse variant |
| Triton/CUDA kernels | 1.0 | ✅ Stable | — | Dispatches to FLA when installed; PyTorch fallback |
| `KDAVLLMAdapter` | 1.0 | ✅ Stable | 14 | vLLM not required; standalone mode supported |
```text
kimi-linear/
├── src/
│   └── kda/
│       ├── __init__.py          # Package exports
│       ├── gating.py            # FineGrainedGating (KDA-GATE-*)
│       ├── dplr.py              # DPLRTransition (KDA-DPLR-*)
│       ├── state_manager.py     # StateManager (KDA-SM-*)
│       └── kda_layer.py         # KDALayer (KDA-LAYER-*)
├── tests/
│   └── kda/
│       ├── test_gating.py
│       ├── test_dplr.py
│       ├── test_state_manager.py
│       ├── test_kda_layer.py
│       └── test_integration.py  # End-to-end integration tests
├── scripts/
│   └── benchmark/
│       └── run_benchmarks.py
├── docs/
│   ├── project-plan.md
│   ├── PROJECT_STATUS.md
│   └── IMPLEMENTATION_SUMMARY.md
├── docker/
│   ├── Dockerfile
│   └── Dockerfile.dev
├── requirements.txt
├── setup.py
└── README.md
```
```shell
black src/ tests/    # format
pylint src/          # lint
mypy src/            # type-check
```

- Create module in `src/kda/`
- Add structured spec comment block (IDs: `KDA-<MODULE>-<TYPE>-<NNN>`)
- Write unit tests in `tests/kda/test_<module>.py`
- Export from `src/kda/__init__.py`
- Add integration coverage in `tests/kda/test_integration.py`
📐 Spec Comment Format
Every module, class, and public method carries a structured spec block:
```python
# ─────────────────────────────────────────────────────────────────────────────
# METHOD SPEC
# ID:             KDA-MODULE-FWD-001
# Requirement:    One precise, testable statement of what this method must do.
# Purpose:        Why this method exists and what objective it supports.
# Rationale:      Engineering reasoning behind the design choice.
# Inputs:         All arguments: name, type, units, valid ranges.
# Outputs:        Return values: type, shape, constraints.
# Preconditions:  What must be true before calling.
# Postconditions: What is guaranteed true after return.
# Side Effects:   State changes, I/O, counters updated.
# Failure Modes:  How the method fails and mitigation strategy.
# Verification:   Which tests cover this method.
# References:     Paper sections, standards, or algorithms implemented.
# ─────────────────────────────────────────────────────────────────────────────
```

🔁 Git Workflow
```mermaid
gitGraph
    commit id: "main"
    branch feature/component
    checkout feature/component
    commit id: "Add module"
    commit id: "Add tests"
    commit id: "Add spec comments"
    checkout main
    merge feature/component id: "PR merge"
    branch feature/kernel
    checkout feature/kernel
    commit id: "Triton kernel"
    checkout main
    merge feature/kernel id: "Kernel merge"
```

Branch naming: `feature/<component>`, `fix/<issue>`, `perf/<scope>`.
Contributions are welcome. Please open an issue before starting significant work.
📋 Contribution Guidelines
- Fork the repository
- Create a branch: `git checkout -b feature/your-feature`
- Make changes with tests: `pytest tests/ -q`
- Format: `black src/ tests/`
- Open a pull request against `main`
- All 45 existing tests must pass
- New public methods must have a structured spec comment block
- New components require unit tests + at least one integration test
- No raw `print()` — use `logging` throughout
- Docstrings not required for private helpers; required for public API
- `self._attr` prefix for instrumentation counters (`_fwd_calls`, `_fwd_time_ms`)
- `@property` shims for any renamed attributes to preserve backward compatibility
- `torch.einsum` preferred over explicit `matmul` for readability at contraction boundaries
If this implementation is useful in your work, please cite the original paper:
```bibtex
@article{kimiteam2025kimilinear,
  title   = {Kimi Linear: An Expressive, Efficient Attention Architecture},
  author  = {Kimi Team},
  journal = {arXiv preprint arXiv:2510.26692},
  year    = {2025}
}
```

This project is licensed under the MIT License — you are free to use, modify, and distribute it with attribution. See LICENSE for the full text.
The Kimi Linear architecture is described in arXiv:2510.26692 by the Kimi Team.