
⚡ kimi-linear

Open-source implementation of Kimi Delta Attention — the hybrid linear-attention architecture from arXiv:2510.26692.



🔭 Overview

kimi-linear implements Kimi Delta Attention (KDA) — the linear attention module at the heart of the Kimi Linear hybrid architecture described in arXiv:2510.26692.

KDA extends Gated DeltaNet with channel-wise (fine-grained) forget gating, replacing head-level scalar gates with per-channel vector gates $\alpha_t \in (0,1)^{d_k}$. This gives the finite-state RNN memory finer control over what context to retain or discard at each token step — improving long-context task performance without increasing the asymptotic O(B·H·K·V) state footprint.

The hybrid deployment stacks three KDA layers for every one full MLA (Multi-Head Latent Attention) layer (3:1 ratio), achieving up to 6× decoding throughput and 75% KV-cache reduction at 1M-token contexts versus a full-attention baseline.

Important

This implementation targets the KDA module only. The full Kimi Linear model (including MLA, MoE feed-forward, and training infra) is not included. This is a research reference implementation, not a production serving stack.

(back to top ↑)


✨ Key Features

Icon Feature Description Impact Status
🎛️ Fine-Grained Gating Per-channel $\alpha_t \in (0,1)^{d_k}$ via low-rank linear bottleneck Selective memory control ✅ Stable
🔄 DPLR Transition Two-step diagonal + rank-1 correction; O(K·V) not O(K²·V) 2× faster than general DPLR ✅ Stable
🧠 State Management Constant-memory RNN state with checkpointing and NaN guards O(1) per-token memory ✅ Stable
🔬 Short Conv on K Depthwise causal conv (kernel=4) on key projection (§3.1) Local context in keys ✅ Stable
📐 RMSNorm Output Per-head RMSNorm on retrieved content before output gate Numerical stability ✅ Stable
🚪 Output Gate Low-rank sigmoid gate $\sigma(W_\text{up}W_\text{down}x) \odot \text{norm}(o_t)$ (§3.2) Expressiveness ✅ Stable
🔢 Mixed Precision FP32 / FP16 / BF16 via PyTorch dtype Flexibility ✅ Stable
🧪 Test Suite 45 unit + integration tests across all components Coverage ✅ Stable

Reported performance from the original paper (at 1M-token context):

  • KV cache reduced by up to 75% vs full MLA baseline
  • Decoding throughput up to 6× faster than full attention
  • Prefill speedup of 2.3–2.9× at 512k–1M token range
  • MMLU-Pro (4k context) score of ≥ 51.0 maintained, staying competitive with the full-attention baseline

Tip

For inference-heavy workloads, enable use_short_conv=True (default) and use_output_gate=True (default) for the best accuracy; disable them only when comparing against the minimal ablation baseline.

(back to top ↑)


🏗️ Architecture

KDA State Update

At each token position $t$, the state $S_t \in \mathbb{R}^{d_k \times d_v}$ is updated as:

$$S_t = \bigl(\text{Diag}(\alpha_t) - \beta_t k_t k_t^\top \text{Diag}(\alpha_t)\bigr) S_{t-1} + \beta_t k_t v_t^\top$$

Computed in three sequential steps (no materialised K×K matrix):

  1. Diagonal decay: $S' = \text{Diag}(\alpha_t) \cdot S_{t-1}$ (element-wise broadcast)
  2. Rank-1 delta correction: $S_t = S' - \beta_t k_t (k_t^\top S')$ (two einsum calls)
  3. KV write: $S_t \mathrel{+}= \beta_t k_t v_t^\top$
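The three update steps can be sketched directly as PyTorch einsum calls. This is a standalone illustration with toy dimensions, not the package's internal code:

```python
import torch

torch.manual_seed(0)
B, H, K, V = 2, 4, 8, 8                       # toy sizes, not the package defaults
S_prev = torch.randn(B, H, K, V)              # previous state S_{t-1}
k_t = torch.nn.functional.normalize(torch.randn(B, H, K), dim=-1)  # L2-normalised key
v_t = torch.randn(B, H, V)
alpha_t = torch.rand(B, H, K)                 # per-channel forget gate in (0, 1)
beta_t = torch.rand(B, H, 1)                  # per-head write strength

# Step 1: diagonal decay (element-wise broadcast over the value dimension)
S_dec = alpha_t.unsqueeze(-1) * S_prev

# Step 2: rank-1 delta correction, computed as two einsum calls
kS = torch.einsum("bhk,bhkv->bhv", k_t, S_dec)                      # k_t^T S'
S_t = S_dec - beta_t.unsqueeze(-1) * torch.einsum("bhk,bhv->bhkv", k_t, kS)

# Step 3: KV outer-product write
S_t = S_t + beta_t.unsqueeze(-1) * torch.einsum("bhk,bhv->bhkv", k_t, v_t)

print(S_t.shape)  # torch.Size([2, 4, 8, 8])
```

The result is identical to materialising the full transition matrix $(\text{Diag}(\alpha_t) - \beta_t k_t k_t^\top \text{Diag}(\alpha_t))$ and multiplying, but never forms a K×K tensor.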

Component Diagram

flowchart TD
    X["Input x ∈ ℝ^(B×T×D)"] --> QP["q_proj\n(Linear)"]
    X --> KP["k_proj\n(Linear)"]
    X --> VP["v_proj\n(Linear)"]
    X --> BP["β_proj\n(Linear → sigmoid)"]
    X --> FG["FineGrainedGating\n(low-rank bottleneck)"]
    X --> OG["OutputGate\ndown/up projections"]

    KP --> SC["ShortConv\n(depthwise, kernel=4, SiLU)"]
    SC --> KN["L2-Normalise keys"]

    KN --> DPLR["DPLRTransition\nDiag decay + rank-1 correction"]
    VP --> DPLR
    FG -- "α_t gates" --> DPLR
    BP -- "β_t scalar" --> DPLR

    DPLR --> SM["StateManager\nS_t ∈ ℝ^(B×H×K×V)"]
    SM --> RET["Retrieval\neinsum S_t^T q_t"]
    QP --> RET

    RET --> NORM["RMSNorm(head_dim)"]
    NORM --> MUL["⊙ Output Gate\nσ(W_up W_down x)"]
    OG --> MUL
    MUL --> OP["out_proj\n(Linear)"]
    OP --> Y["Output y ∈ ℝ^(B×T×D)"]

Sequence Flow

sequenceDiagram
    participant App
    participant KDALayer
    participant FGGating
    participant DPLRTrans
    participant StateManager

    App->>KDALayer: forward(x, state=None)
    KDALayer->>StateManager: initialize_state(B)
    StateManager-->>KDALayer: S_0 = zeros(B,H,K,V)

    loop For each token t in [0, T)
        KDALayer->>FGGating: forward(x[:,t,:])
        FGGating-->>KDALayer: α_t ∈ (0,1)^(B×H×K)
        KDALayer->>DPLRTrans: forward(S_{t-1}, k_t, v_t, α_t, β_t)
        DPLRTrans-->>KDALayer: S_t updated
        KDALayer->>KDALayer: o_t = RMSNorm(einsum(S_t, q_t))
    end

    KDALayer->>App: output (B,T,D), final_state (B,H,K,V)

Layer Stack (3:1 Hybrid Deployment)

flowchart LR
    I[Input Tokens] --> L1[KDA Layer]
    L1 --> L2[KDA Layer]
    L2 --> L3[KDA Layer]
    L3 --> L4[MLA Layer]
    L4 --> L5[KDA Layer]
    L5 --> L6[KDA Layer]
    L6 --> L7[KDA Layer]
    L7 --> L8[MLA Layer]
    L8 --> O[Output]
    style L4 fill:#e8c84e,color:#000
    style L8 fill:#e8c84e,color:#000

Component Responsibilities

Component File Purpose Time Space
FineGrainedGating src/kda/gating.py Per-channel α_t via W_down/W_up + sigmoid O(B·T·D·rank) O(D·rank)
DPLRTransition src/kda/dplr.py Two-step state transition; Gershgorin stability check O(B·H·K·V) O(B·H·K·V)
StateManager src/kda/state_manager.py S_t lifecycle: init, update, checkpoint, OOM guard O(B·H·K·V) O(B·H·K·V)
KDALayer src/kda/kda_layer.py Full KDA forward: projections → conv → gate → DPLR → norm → gate → out O(B·T·H·K·V) O(B·H·K·V)

(back to top ↑)


🧰 Technology Stack

Each technology in this project was chosen deliberately. This section explains what each one is, what it does inside kimi-linear, and why it was selected over the alternatives.


🔥 Triton / CUDA Kernels (src/kda/triton_kernels.py)

What they are. Triton is an open-source GPU programming language and compiler developed by OpenAI. It lets you write GPU kernels in Python-like syntax that compile down to PTX (NVIDIA's intermediate GPU assembly), achieving performance close to hand-written CUDA without requiring C++ expertise. CUDA kernels are the lower-level equivalent — raw C++ functions that run in parallel across thousands of GPU threads.

What they do here. The triton_kernels.py module acts as a dispatch layer. When flash-linear-attention (FLA) is installed, chunk_kda_forward and fused_recurrent_kda_forward automatically route to FLA's production Triton kernels (fla.ops.kda.chunk_kda, fla.ops.kda.fused_recurrent_kda). These kernels:

  • Fuse the WY representation, UT-transform state update, and intra-chunk attention into a single GPU kernel launch — eliminating the Python loop overhead and the intermediate tensor allocations that the pure-PyTorch path incurs.
  • Use shared memory tiling to keep frequently-reused data (keys, gates, the running state) on-chip rather than round-tripping through HBM (GPU RAM), which is the dominant bottleneck at the sizes used in practice.
  • Support bfloat16 and float16 with fused operations that reduce the number of memory reads/writes by keeping intermediate results in registers.

When FLA is not installed (CPU-only machine, no CUDA, CI environment), the module falls back transparently to the pure-PyTorch token loop — correct output, just slower.

Why Triton over alternatives.

Option Problem
Pure PyTorch (eager) O(T) Python loop; each step materialises intermediate tensors; ~10–50× slower than fused kernel at T=2048
torch.compile + eager Reduces overhead but cannot tile across the recurrent state dimension; still memory-bandwidth-bound
Hand-written CUDA Correct but requires C++ build toolchain, CUDA SDK, and per-architecture tuning; maintenance burden is high
Triton via FLA Python-authored, auto-tuned launch configs, works on any sm70+ GPU (V100, A100, H100), active maintenance from fla-org

The dispatch pattern — try FLA, fall back to PyTorch — means the codebase works on a laptop CPU for development and gets production speed on a GPU cluster without code changes.
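The dispatch pattern can be sketched as below. The function names (kda_forward, _pytorch_reference) are illustrative, not the module's actual API, and the fallback is deliberately simplified: a gated outer-product recurrence without KDA's delta correction, enough to show the routing rather than the exact maths:

```python
import torch

try:
    # Production Triton kernel from flash-linear-attention (optional dependency)
    from fla.ops.kda import chunk_kda
    HAS_TRITON = True
except ImportError:
    chunk_kda = None
    HAS_TRITON = False

def _pytorch_reference(q, k, v, g, beta):
    """Naive token loop: correct but slow. Simplified for illustration
    (gated outer-product write, no rank-1 delta correction)."""
    B, H, T, D = q.shape
    S = q.new_zeros(B, H, D, D)                   # recurrent state
    out = torch.empty_like(v)
    for t in range(T):
        S = torch.exp(g[:, :, t]).unsqueeze(-1) * S   # g is a log-space decay <= 0
        S = S + beta[:, :, t].unsqueeze(-1) * torch.einsum(
            "bhk,bhv->bhkv", k[:, :, t], v[:, :, t])
        out[:, :, t] = torch.einsum("bhkv,bhk->bhv", S, q[:, :, t])
    return out, S

def kda_forward(q, k, v, g, beta):
    """Route to the fused kernel when available, else the pure-PyTorch loop."""
    if HAS_TRITON and q.is_cuda:
        # Argument order here is illustrative; consult FLA's documentation
        # for the exact chunk_kda signature before relying on it.
        return chunk_kda(q, k, v, g, beta)
    return _pytorch_reference(q, k, v, g, beta)

print(f"Triton/FLA kernels active: {HAS_TRITON}")
```

On a CPU-only machine the import fails (or `q.is_cuda` is False) and the reference loop runs; on a CUDA box with FLA installed, the same call hits the fused kernel.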


🔢 PyTorch 2.6+ (torch)

What it is. PyTorch is the dominant deep-learning framework for research and production. It provides n-dimensional tensor arithmetic, automatic differentiation (autograd), GPU memory management, and a large ecosystem of pre-built layers (nn.Module).

What it does here. Every layer — KDALayer, MLALayer, ChunkwiseParallelKDA, KDAVLLMAdapter — is a torch.nn.Module. The recurrent state is a plain torch.Tensor. Einsum contractions (torch.einsum) express the DPLR and WY update equations in a form that is both readable and JIT-compilable. F.scaled_dot_product_attention in MLALayer automatically dispatches to Flash Attention 2 when available.

Why PyTorch over alternatives.

Option Problem
JAX Functional-only style conflicts with the stateful recurrence design; smaller ecosystem for model serving
MLX (Apple) CUDA support is a first-class requirement; MLX targets Apple Silicon
TensorFlow 2 Less expressive dynamic graph for research; declining community adoption

nn.RMSNorm (used in our output normalisation) became stable in PyTorch 2.4, and 2.6 brought improved torch.compile support, which motivated the ≥ 2.6 minimum.


📐 torch.einsum

What it is. torch.einsum evaluates Einstein-summation expressions — a compact notation for tensor contractions where repeated indices are summed over.

What it does here. The two hottest operations in KDA — the DPLR state update and the WY rank-BT correction — are expressed as einsums:

# Outer product write: k (B,H,K) × e (B,H,V) → delta_S (B,H,K,V)
state += beta * torch.einsum("bhk,bhv->bhkv", k_t, e_t)

# Query retrieval: S (B,H,K,V) × q (B,H,K) → output (B,H,V)
o_t = torch.einsum("bhkv,bhk->bhv", state, q_t)

# UT rank-BT correction: w (B,H,BT,K) × y (B,H,BT,V) → (B,H,K,V)
delta = torch.einsum("bhtk,bhtv->bhkv", w, y)

Why einsum over explicit matmul/bmm. Einsum expressions map directly to cuBLAS/cuBLASLt GEMM calls after PyTorch's contraction optimiser selects the best pairwise contraction order. Explicit matmul would require manual reshapes and transposes that obscure the mathematical intent. The trade-off is that einsum is slightly harder to read at first but much easier to verify against the paper equations.


📏 nn.RMSNorm

What it is. Root Mean Square Layer Normalisation normalises activations by dividing by their RMS rather than subtracting the mean and dividing by the full standard deviation (as LayerNorm does).

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \varepsilon}} \cdot \gamma$$

What it does here. Applied per-head to the retrieved content vector o_t = S_t^T q_t before the output gate. Without this, the retrieved vector's magnitude grows unboundedly as the state S_t accumulates outer products — causing the output gate to saturate and gradients to vanish.

Why RMSNorm over LayerNorm. RMSNorm omits the mean-subtraction step. Empirically (Zhang & Sennrich 2019; used in LLaMA, Mistral, Kimi Linear) this has equal or better training stability at ~10% lower compute cost. nn.RMSNorm is native in PyTorch ≥ 2.4, removing the need for a custom kernel.


🌀 Depthwise nn.Conv1d (short convolution on keys)

What it is. A depthwise (grouped) 1-D convolution where each channel is filtered independently — groups=C means zero cross-channel mixing. Kernel size 4 with causal (left) padding injects a 4-token receptive field.

What it does here. Applied to the key projection before L2-normalisation (§3.1 of arXiv:2510.26692):

K = self.k_conv(K.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2)
K = F.silu(K)

The convolution gives each key a 4-step local context window — nearby tokens can influence what gets written into the state — without breaking the O(1) recurrent complexity because the final key at each step is still a single vector.

Why depthwise Conv1d over alternatives. A full (non-depthwise) convolution would mix key channels and increase parameters by a factor of inner_dim. A self-attention local window would reintroduce quadratic complexity for the conv pass. Depthwise Conv1d is O(T·C·kernel) — essentially free — and adds exactly the right amount of local context per the paper specification.
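A self-contained sketch of the causal depthwise convolution (sizes are arbitrary; the package wires this into the key projection). Padding both sides by kernel-1 and then trimming the right edge guarantees no output position sees future tokens:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, kernel = 2, 16, 32, 4          # toy sizes

# groups=C -> depthwise: each channel gets its own length-4 filter.
# padding=kernel-1 pads BOTH sides; trimming the right edge makes it causal.
conv = nn.Conv1d(C, C, kernel_size=kernel, groups=C, padding=kernel - 1)

k = torch.randn(B, T, C)                # keys laid out as (B, T, C)
out = F.silu(conv(k.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2))

# Causality check: perturbing the last token must not change earlier outputs
k2 = k.clone()
k2[:, T - 1, :] += 1.0
out2 = F.silu(conv(k2.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2))
assert torch.allclose(out[:, : T - 1], out2[:, : T - 1])
```

The causality assertion at the end is the property the chunk-boundary warning in §Core Capabilities relies on.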


🐍 Python 3.10+

What it provides here.

  • match/case structural pattern matching for clean dispatch in state management error handling.
  • X | Y union type syntax in annotations (Optional[Tensor] becomes Tensor | None).
  • __future__.annotations deferred evaluation, enabling forward references without quotes.

Why 3.10 minimum over 3.8/3.9. nn.RMSNorm requires PyTorch 2.4, which dropped 3.8 support. Pattern matching and the union syntax materially improve code readability in the dispatch and error-handling paths. 3.10 is the oldest version still receiving security patches at the time of writing.


🧪 pytest 9.x

What it is. pytest is a Python testing framework that discovers and runs test functions/classes, provides rich assertion introspection, and supports fixtures, parametrize, and plugins.

What it does here. 130 tests across 8 files verify shape contracts, numerical correctness (no NaN/Inf), gradient flow, stateful continuity, error raising, and physics (gate decay suppresses old state). Fixtures (@pytest.fixture) provide reusable layer instances without duplicating setup code. --tb=short gives compact failure output in CI.

Why pytest over unittest. unittest requires wrapping everything in classes that inherit TestCase and uses self.assertEqual style assertions. pytest uses bare assert statements, infers test discovery automatically, and its failure messages show the actual vs. expected values without any extra boilerplate.


🐋 Docker (docker/Dockerfile, docker/Dockerfile.dev)

What it is. Docker packages an application and its entire runtime environment (OS libraries, CUDA runtime, Python, pip dependencies) into a portable container image that runs identically on any Linux host with a Docker daemon.

What it does here.

  • Dockerfile.dev mounts the source directory at runtime (-v $(pwd):/workspace) so code edits are reflected immediately without rebuilding — fast iteration during development.
  • Dockerfile (production) bakes the source in at build time, producing a self-contained image suitable for deployment on Kubernetes, RunPod, or any container orchestration platform.
  • Both images pin the CUDA runtime version, eliminating the "works on my machine" class of GPU driver mismatches.

Why Docker over conda/venv-only. A bare venv does not capture system libraries (CUDA runtime, libcudnn, NCCL). conda captures more but is slower to resolve and not the standard in production serving. Docker produces an immutable, reproducible artefact that can be pushed to a registry and deployed without any environment setup on the target machine.


⚡ FLA — flash-linear-attention ([fla] optional extra)

What it is. flash-linear-attention is the official reference implementation of KDA (and other linear attention variants) by the Moonshot/FLA team. It provides hand-optimised Triton kernels for the KDA chunk forward/backward passes.

What it does here. When installed, HAS_TRITON becomes True and chunk_kda_forward routes to fla.ops.kda.chunk_kda — the same kernel that powers the production Kimi Linear model. This is the fastest available path on CUDA hardware.

Why optional. Most development and CI runs happen on CPU or without the FLA package. Making it optional means the package installs cleanly on any machine (pip install kimi-linear) and degrades gracefully to the pure-PyTorch path — no broken import, no missing .so files.

(back to top ↑)


⚙️ Setup & Installation

Prerequisites

  • Python ≥ 3.10
  • PyTorch ≥ 2.6 (CPU or CUDA)
  • CUDA ≥ 12.0 for GPU acceleration (optional)

Install into your own project (recommended)

# From PyPI (once published)
pip install kimi-linear

# From GitHub — always up to date
pip install git+https://github.com/hkevin01/kimi-linear.git

# With optional Triton kernels (requires CUDA)
pip install "git+https://github.com/hkevin01/kimi-linear.git#egg=kimi-linear[fla]"

# With vLLM deployment support
pip install "git+https://github.com/hkevin01/kimi-linear.git#egg=kimi-linear[vllm]"

Developer install (editable — for contributing)

git clone https://github.com/hkevin01/kimi-linear.git
cd kimi-linear
python -m venv .venv
source .venv/bin/activate
pip install -e ".[dev]"

Docker

# Development (mounts source, hot-reload)
docker build -f docker/Dockerfile.dev -t kimi-linear:dev .
docker run --gpus all -it -v $(pwd):/workspace kimi-linear:dev

# Production
docker build -f docker/Dockerfile -t kimi-linear:latest .

Verify installation

python -c "import kda; print(kda.__version__)"
pytest tests/ -q
# 130 passed in ~1.6s

Note

After pip install, the package is importable as import kda from any project in that environment. No need to be inside the kimi-linear directory.

(back to top ↑)


🚀 Usage

After pip install kimi-linear (or the editable install), all imports use import kda.

Single KDA layer

import torch
from kda import KDALayer

layer = KDALayer(
    hidden_dim=512,
    num_heads=8,
    head_dim=64,
    dropout=0.0,
    max_batch_size=32,
)

x = torch.randn(4, 128, 512)        # (batch, seq_len, hidden_dim)
output, state = layer(x)
print(output.shape)                  # torch.Size([4, 128, 512])
print(state.shape)                   # torch.Size([4, 8, 64, 64])

Chunkwise parallel (faster on GPU, T ≥ 512)

from kda import KDALayer

layer = KDALayer(
    hidden_dim=512, num_heads=8, head_dim=64,
    use_chunk_parallel=True,   # WY + UT transform algorithm
    chunk_size=64,
)

x = torch.randn(4, 512, 512)
output, state = layer(x)

MLA layer (full-attention complement)

from kda import MLALayer

mla = MLALayer(
    hidden_dim=512,
    num_heads=8,
    head_dim=64,
    kv_latent_dim=64,           # KV cache compressed to this size
)

x = torch.randn(4, 128, 512)
output, c_kv = mla(x)          # c_kv: (4, 128, 64) — store this as KV cache
print(f"KV cache compression: {mla.kv_cache_compression_ratio:.1f}x")

# Generation: pass cached c_kv for subsequent tokens
next_token = torch.randn(4, 1, 512)
out_step, c_kv_new = mla(next_token, kv_cache=c_kv)

3:1 hybrid stack (KDA layers + 1 MLA every 4)

import torch.nn as nn
from kda import KDALayer, MLALayer

class HybridBlock(nn.Module):
    """3 KDA layers followed by 1 MLA — matches Kimi Linear deployment."""
    def __init__(self, hidden_dim=512, num_heads=8, head_dim=64):
        super().__init__()
        self.kda_layers = nn.ModuleList([
            KDALayer(hidden_dim, num_heads, head_dim) for _ in range(3)
        ])
        self.mla_layer = MLALayer(hidden_dim, num_heads, head_dim, kv_latent_dim=64)
        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(4)])

    def forward(self, x, kda_states=None):
        states = []
        if kda_states is None:
            kda_states = [None] * 3
        for i, kda in enumerate(self.kda_layers):
            out, s = kda(self.norms[i](x), state=kda_states[i])
            x = x + out
            states.append(s)
        out_mla, c_kv = self.mla_layer(self.norms[3](x))
        x = x + out_mla
        return x, states, c_kv

model = HybridBlock()
x = torch.randn(2, 128, 512)
out, kda_states, c_kv = model(x)

Stateful chunked inference

from kda import KDALayer

layer = KDALayer(hidden_dim=512, num_heads=8, head_dim=64)

# Process long sequence in chunks — state carries context across boundaries
state = None
for chunk in chunks:                 # each chunk: (B, chunk_len, D)
    output, state = layer(chunk, state=state)

vLLM-style inference adapter

from kda import KDALayer, KDAVLLMAdapter

kda_layer = KDALayer(hidden_dim=512, num_heads=8, head_dim=64)
adapter = KDAVLLMAdapter(
    kda_layer=kda_layer,
    num_heads=8, key_dim=64, value_dim=64,
    max_blocks=1024,
)

# Prefill (encode context)
x_context = torch.randn(2, 256, 512)
out, state = adapter.prefill(x_context, seq_ids=[0, 1])

# Autoregressive decode (single token per step)
for step in range(100):
    x_token = torch.randn(2, 1, 512)
    out, state = adapter.decode_step(x_token, seq_ids=[0, 1])

adapter.free_sequence(0)
adapter.free_sequence(1)

Kernel dispatch (Triton when available)

import torch
from kda import chunk_kda_forward, fused_recurrent_kda_forward, HAS_TRITON

print(f"Triton/FLA kernels active: {HAS_TRITON}")

# (B, H, T, d) convention
B, H, T, D = 2, 8, 128, 64
q = torch.randn(B, H, T, D)
k = torch.nn.functional.normalize(torch.randn(B, H, T, D), dim=-1)
v = torch.randn(B, H, T, D)
g = -torch.rand(B, H, T, 1) * 0.5   # log-space gate ≤ 0
beta = torch.rand(B, H, T, 1) * 0.5 + 0.5

# Chunkwise parallel (dispatches to Triton if FLA installed)
out, final_state = chunk_kda_forward(q, k, v, g, beta, chunk_size=64)

# Fused recurrent (optimal for T=1 decode steps)
out_r, state_r = fused_recurrent_kda_forward(q[:, :, :1, :], k[:, :, :1, :],
                                              v[:, :, :1, :], g[:, :, :1, :],
                                              beta[:, :, :1, :])

Use individual components

from kda import FineGrainedGating, StateManager, DPLRTransition

gating = FineGrainedGating(hidden_dim=512, num_heads=8, head_dim=64)
state_mgr = StateManager(key_dim=64, value_dim=64, num_heads=8, max_batch_size=32)
dplr = DPLRTransition(key_dim=64, value_dim=64, num_heads=8)

x = torch.randn(4, 128, 512)
gates, _ = gating(x)                 # (4, 128, 8, 64)
state = state_mgr.initialize_state(batch_size=4)  # (4, 8, 64, 64)

Disable architectural options (ablation)

from kda import KDALayer

# Without short conv
layer_no_conv = KDALayer(hidden_dim=512, num_heads=8, head_dim=64,
                         use_short_conv=False)

# Minimal baseline (no short conv, no output gate)
layer_minimal = KDALayer(hidden_dim=512, num_heads=8, head_dim=64,
                         use_short_conv=False, use_output_gate=False)

(back to top ↑)


🔬 Core Capabilities

🎛️ Fine-Grained Channel-Wise Gating

Traditional linear attention uses a scalar gate per head. KDA uses a K-dimensional vector gate per head, computed via a low-rank bottleneck:

$$\alpha_t = \sigma\!\left(W_\text{up}\,\text{SiLU}(W_\text{down}\,x_t)\right) \quad \in (0,1)^{d_k}$$

This gives the model fine-grained control over which memory dimensions to retain or decay at each step — comparable expressiveness to full attention's content-adaptive routing at O(1) state cost.
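A minimal sketch of the low-rank gate (dimensions and variable names are illustrative, not the package's internals):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, num_heads, head_dim, rank = 512, 8, 64, 32    # illustrative sizes

w_down = nn.Linear(hidden_dim, rank, bias=False)          # D -> rank bottleneck
w_up = nn.Linear(rank, num_heads * head_dim, bias=False)  # rank -> H*K gate logits

x = torch.randn(4, 128, hidden_dim)                       # (B, T, D)
alpha = torch.sigmoid(w_up(F.silu(w_down(x))))            # every entry in (0, 1)
alpha = alpha.view(4, 128, num_heads, head_dim)           # per-head, per-channel

print(alpha.shape)  # torch.Size([4, 128, 8, 64])
```

The bottleneck keeps the gate's parameter count at O(D·rank) rather than O(D·H·K), which is what the feature table above refers to.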

🔄 Constrained DPLR Transition

The general DPLR update requires O(K²·V) operations. KDA exploits the structural constraint $a_t = \beta_t k_t$, $b_t = k_t \odot \alpha_t$ to reduce this to O(K·V):

Step 1: S' = Diag(α_t) · S_{t-1}          ← element-wise broadcast
Step 2: S  = S' - β_t · k_t · (k_t⊤ S')  ← two einsum calls
Step 3: S  = S  + β_t · k_t · v_t⊤        ← outer product write

A Gershgorin spectral radius estimate is computed each forward pass to warn when the transition matrix approaches instability ($\rho > 1.1$).

📡 Short Convolution on Keys (§3.1)

A depthwise causal convolution (kernel size 4) is applied to the key projection before L2-normalisation:

# Causal: pad left by kernel-1, then trim right side
K = self.k_conv(K.transpose(1, 2)).narrow(2, 0, T).transpose(1, 2)
K = F.silu(K)

This injects local positional context into keys without breaking the recurrent complexity.

Warning

Short conv requires cross-chunk key history for exact chunk-boundary equivalence. When testing sequential chunking, use use_short_conv=False to verify state-carry correctness independently.

📐 RMSNorm + Output Gate (§3.2)

After retrieval $o_t = S_t^\top q_t$, the output is normalised and gated:

$$y_t = \sigma\!\left(W_\text{up}\,W_\text{down}\,x_t\right) \odot \text{RMSNorm}(o_t)$$

The RMSNorm prevents magnitude explosion across deep stacks, while the output gate adds expressiveness without extra state cost.

(back to top ↑)


📊 Benchmarks

Run the benchmark suite:

python scripts/benchmark/run_benchmarks.py

The benchmark script measures:

  • FineGrainedGating — throughput across sequence lengths and batch sizes
  • DPLRTransition — two-step state update latency per head configuration
  • KDALayer — end-to-end forward pass latency across hidden dims

Run the tests:

# Run tests with coverage
pytest --cov=src tests/ -v

# Run only integration tests
pytest tests/kda/test_integration.py -v

Test distribution:

pie title Test Coverage by Module
    "test_gating.py" : 9
    "test_dplr.py" : 9
    "test_state_manager.py" : 9
    "test_kda_layer.py" : 6
    "test_integration.py" : 12

(back to top ↑)


🗺️ Project Roadmap

gantt
    title kimi-linear Development Roadmap
    dateFormat  YYYY-MM-DD
    section Core KDA
        FineGrainedGating         :done,    g1, 2025-11-01, 2025-11-15
        DPLRTransition            :done,    g2, 2025-11-15, 2025-12-01
        StateManager              :done,    g3, 2025-12-01, 2025-12-15
        KDALayer (full §3.1–3.2)  :done,    g4, 2025-12-15, 2026-01-15
    section Quality
        Unit + Integration Tests  :done,    q1, 2026-01-15, 2026-02-01
        Structured Spec Comments  :done,    q2, 2026-02-01, 2026-03-01
    section Performance
        Chunkwise Parallelization :active,  p1, 2026-03-01, 2026-06-01
        Triton / CUDA Kernels     :         p2, 2026-06-01, 2026-09-01
    section Integration
        MLA Reference Module      :         i1, 2026-06-01, 2026-08-01
        vLLM / SGLang Plugin      :         i2, 2026-09-01, 2026-12-01
Phase Goals Target Status
1 — Core KDA Gating, DPLR, State, Layer assembly Q4 2025 ✅ Complete
2 — Quality 45 tests, structured spec comments, Docker Q1 2026 ✅ Complete
3 — Performance Chunkwise parallel, WY rep, UT transform Q2 2026 ✅ Complete
4 — Kernels Triton DPLR kernel dispatch (FLA fallback) Q3 2026 ✅ Complete
5 — Full Hybrid MLA layer, 3:1 stack, vLLM integration Q4 2026 ✅ Complete

(back to top ↑)


📋 Implementation Status

Component Version Stability Tests Known Limitations
FineGrainedGating 1.0 ✅ Stable 9 Low-rank factorisation fixes gate rank at init; no dynamic rank
DPLRTransition 1.0 ✅ Stable 9 Eigen check is Gershgorin heuristic
StateManager 1.0 ✅ Stable 9 Pre-alloc buffer requires max_batch_size at init
KDALayer 1.2 ✅ Stable 6 + 12 integration Short conv needs chunk-carry for exact equivalence
ChunkwiseParallelKDA 1.0 ✅ Stable 11 BT must be power-of-2; caller pads T
WY representation 1.0 ✅ Stable 5 (in chunk tests) O(BT²) Python loop; Triton path via FLA
UT transform 1.0 ✅ Stable 3 (in chunk tests) Rank-BT update grows memory linearly with chunk size
MLALayer 1.0 ✅ Stable 11 Full causal attention; no sparse variant
Triton/CUDA kernels 1.0 ✅ Stable n/a Dispatches to FLA when installed; PyTorch fallback
KDAVLLMAdapter 1.0 ✅ Stable 14 vLLM not required; standalone mode supported

(back to top ↑)


🛠️ Development

Project Structure

kimi-linear/
├── src/
│   └── kda/
│       ├── __init__.py          # Package exports
│       ├── gating.py            # FineGrainedGating (KDA-GATE-*)
│       ├── dplr.py              # DPLRTransition   (KDA-DPLR-*)
│       ├── state_manager.py     # StateManager     (KDA-SM-*)
│       └── kda_layer.py         # KDALayer         (KDA-LAYER-*)
├── tests/
│   └── kda/
│       ├── test_gating.py
│       ├── test_dplr.py
│       ├── test_state_manager.py
│       ├── test_kda_layer.py
│       └── test_integration.py  # End-to-end integration tests
├── scripts/
│   └── benchmark/
│       └── run_benchmarks.py
├── docs/
│   ├── project-plan.md
│   ├── PROJECT_STATUS.md
│   └── IMPLEMENTATION_SUMMARY.md
├── docker/
│   ├── Dockerfile
│   └── Dockerfile.dev
├── requirements.txt
├── setup.py
└── README.md

Code Style

black src/ tests/     # format
pylint src/           # lint
mypy src/             # type-check

Adding a New Component

  1. Create module in src/kda/
  2. Add structured spec comment block (IDs: KDA-<MODULE>-<TYPE>-<NNN>)
  3. Write unit tests in tests/kda/test_<module>.py
  4. Export from src/kda/__init__.py
  5. Add integration coverage in tests/kda/test_integration.py

📐 Spec Comment Format

Every module, class, and public method carries a structured spec block:

# ─────────────────────────────────────────────────────────────────────────────
# METHOD SPEC
# ID:            KDA-MODULE-FWD-001
# Requirement:   One precise, testable statement of what this method must do.
# Purpose:       Why this method exists and what objective it supports.
# Rationale:     Engineering reasoning behind the design choice.
# Inputs:        All arguments: name, type, units, valid ranges.
# Outputs:       Return values: type, shape, constraints.
# Preconditions: What must be true before calling.
# Postconditions: What is guaranteed true after return.
# Side Effects:  State changes, I/O, counters updated.
# Failure Modes: How the method fails and mitigation strategy.
# Verification:  Which tests cover this method.
# References:    Paper sections, standards, or algorithms implemented.
# ─────────────────────────────────────────────────────────────────────────────

🔁 Git Workflow

gitGraph
    commit id: "main"
    branch feature/component
    checkout feature/component
    commit id: "Add module"
    commit id: "Add tests"
    commit id: "Add spec comments"
    checkout main
    merge feature/component id: "PR merge"
    branch feature/kernel
    checkout feature/kernel
    commit id: "Triton kernel"
    checkout main
    merge feature/kernel id: "Kernel merge"

Branch naming: feature/<component>, fix/<issue>, perf/<scope>.

(back to top ↑)


🤝 Contributing

Contributions are welcome. Please open an issue before starting significant work.

📋 Contribution Guidelines

Workflow

  1. Fork the repository
  2. Create a branch: git checkout -b feature/your-feature
  3. Make changes with tests: pytest tests/ -q
  4. Format: black src/ tests/
  5. Open a pull request against main

Requirements for merge

  • All existing tests must pass
  • New public methods must have a structured spec comment block
  • New components require unit tests + at least one integration test
  • No raw print() — use logging throughout
  • Docstrings not required for private helpers; required for public API

Code conventions

  • self._attr prefix for instrumentation counters (_fwd_calls, _fwd_time_ms)
  • @property shims for any renamed attributes to preserve backward compatibility
  • torch.einsum preferred over explicit matmul for readability at contraction boundaries

(back to top ↑)


📎 Citation

If this implementation is useful in your work, please cite the original paper:

@article{kimiteam2025kimilinear,
  title   = {Kimi Linear: An Expressive, Efficient Attention Architecture},
  author  = {Kimi Team},
  journal = {arXiv preprint arXiv:2510.26692},
  year    = {2025}
}

📄 License

This project is licensed under the MIT License — you are free to use, modify, and distribute it with attribution. See LICENSE for the full text.

The Kimi Linear architecture is described in arXiv:2510.26692 by the Kimi Team.

(back to top ↑)
