From 527e92945cfd32d84103b99ad8eb56027e399c8c Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 15:36:15 +0800 Subject: [PATCH 01/15] docs: add GMM Pallas kernel design for MoE support Phase 1 design for grouped matrix multiplication kernel that replaces tokamax/qwix backends. BF16 with float32 accumulation, custom_vjp for full differentiability. Co-Authored-By: Claude Opus 4.6 --- docs/plans/2026-04-06-gmm-kernel-design.md | 144 +++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 docs/plans/2026-04-06-gmm-kernel-design.md diff --git a/docs/plans/2026-04-06-gmm-kernel-design.md b/docs/plans/2026-04-06-gmm-kernel-design.md new file mode 100644 index 00000000..73377f5c --- /dev/null +++ b/docs/plans/2026-04-06-gmm-kernel-design.md @@ -0,0 +1,144 @@ +# GMM (Grouped Matrix Multiplication) Pallas Kernel Design + +**Date:** 2026-04-06 +**Status:** Approved +**Phase:** 1 (BF16 only, no quantization) + +## Goal + +Implement a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) to support MoE (Mixture-of-Experts) layers. This replaces the `tokamax` and `qwix` backends used in maxtext's megablox with a clean, self-contained implementation following tops conventions. + +## Semantics + +**GMM forward:** For each expert group `i` with rows `[start_i, end_i)`: +``` +out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :] +``` + +**TGMM (transposed GMM, for weight gradients):** For each group `i`: +``` +out[i, :, :] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] +``` + +Tokens in `lhs` are pre-sorted by expert assignment. `group_sizes[i]` gives the number of rows belonging to expert `i`. 
+ +## Public API + +```python +def gmm( + lhs: jnp.ndarray, # [m, k] bf16 - stacked token activations + rhs: jnp.ndarray, # [num_groups, k, n] bf16 - per-expert weights + group_sizes: jnp.ndarray, # [num_groups] int32 - token count per expert + tiling: tuple[int, int, int] = (128, 128, 128), # (tm, tk, tn) + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: # [m, n] bf16 +``` + +Fully differentiable via `jax.custom_vjp`: +- **dlhs** = `gmm(grad, rhs, group_sizes, transpose_rhs=True)` +- **drhs** = `tgmm(lhs, grad, group_sizes)` + +### Internal: `tgmm` + +```python +def tgmm( + lhs: jnp.ndarray, # [m, k] bf16 + rhs: jnp.ndarray, # [m, n] bf16 + group_sizes: jnp.ndarray, # [num_groups] int32 + tiling: tuple[int, int, int] = (128, 128, 128), + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: # [num_groups, k, n] bf16 +``` + +## Kernel Architecture + +### Grid Layout (gmm) + +``` +grid = (tiles_n, num_active_tiles, tiles_k) +dimension_semantics = ("parallel", "arbitrary", "arbitrary") +``` + +- `tiles_n = n // tn` -- parallelized over output columns +- `num_active_tiles` = total m-tiles across all groups (computed by `make_group_metadata`) +- `tiles_k = k // tk` -- sequential reduction dimension + +### Group Metadata (computed on host, passed via scalar prefetch) + +`make_group_metadata(group_sizes, m, tm)` produces: +- `group_offsets`: CSR-style cumulative row offsets, rounded to tm boundaries +- `group_ids`: maps each active m-tile index to its group +- `m_tile_ids`: maps each active m-tile index to its row-tile offset within the group + +### BlockSpecs + +| Tensor | Block shape | Index map | +|--------|-------------|-----------| +| `lhs` | `[tm, tk]` | `(m_tile_ids[grid_m], grid_k)` | +| `rhs` | `[1, tk, tn]` | `(group_ids[grid_m], grid_k, grid_n)` | +| `out` | `[tm, tn]` | `(m_tile_ids[grid_m], grid_n)` | + +When `transpose_rhs=True`, rhs block shape is `[1, tn, tk]`. + +### Kernel Body (gmm) + +1. 
Load `lhs_block [tm, tk]` and `rhs_block [tk, tn]` via BlockSpec +2. Accumulate `dot(lhs_block, rhs_block, preferred_element_type=float32)` into VMEM scratch `[tm, tn]` +3. On last k-tile: apply group-boundary mask (zero rows outside the group), cast to output dtype, store + +### Grid Layout (tgmm) + +``` +grid = (tiles_n, tiles_k, num_active_tiles) +dimension_semantics = ("parallel", "arbitrary", "arbitrary") +``` + +### Kernel Body (tgmm) + +For each active m-tile, accumulates `lhs_block^T [tk, tm] @ rhs_block [tm, tn]` into the output for the corresponding group. When the group changes between adjacent m-tiles, the accumulated result is stored and the accumulator reset. + +## Precision + +- Input dtype: bf16 +- Accumulation: float32 (via VMEM scratch) +- Output: cast back to bf16 +- `jax.lax.Precision.HIGHEST` on all dot products + +## File Layout + +``` +tops/ops/gmm/ + __init__.py # Public API: gmm() + gmm.py # Pallas kernels + custom_vjp + tgmm + metadata.py # make_group_metadata() + +tops/cpu/ops/gmm/ + __init__.py + naive.py # Pure JAX reference implementation + +tests/ops/gmm/ + test_gmm_tpu.py # Pallas vs CPU reference tests + conftest.py # GMM test fixtures +``` + +## Testing Strategy + +1. **CPU reference:** Pure JAX loop over groups with plain matmul +2. **Forward test:** Compare Pallas gmm output vs reference across configs +3. **Gradient test:** Compare custom_vjp gradients vs `jax.grad` of reference +4. **Configs:** Vary (m, k, n, num_groups) with distributions: + - Uniform group sizes + - Skewed (one large group, many small) + - Single group (degenerates to plain matmul) + - Empty groups (group_size=0) + - Sizes not divisible by tm +5. 
**Tolerances:** atol ~1e-2, rtol ~1e-2 (bf16 accumulation) + +## Future Work (Phase 2+) + +- Block-wise quantization: (128,128) for weights, (1,128) for activations +- `group_offset` for expert parallelism / sharded groups +- `existing_out` for accumulation into pre-existing buffers +- Double/triple buffering (`input_buffer_count`) +- Async DMA pipelining From 68ddcb3409985c2d7dd8a110f48e35e3d99b9603 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 15:50:15 +0800 Subject: [PATCH 02/15] docs: add GMM Pallas kernel implementation plan 5-task TDD implementation plan covering CPU reference, group metadata, GMM/TGMM kernels, custom_vjp, and public API with complete code. Co-Authored-By: Claude Opus 4.6 --- docs/plans/2026-04-06-gmm-kernel-impl.md | 1296 ++++++++++++++++++++++ 1 file changed, 1296 insertions(+) create mode 100644 docs/plans/2026-04-06-gmm-kernel-impl.md diff --git a/docs/plans/2026-04-06-gmm-kernel-impl.md b/docs/plans/2026-04-06-gmm-kernel-impl.md new file mode 100644 index 00000000..102b2f1c --- /dev/null +++ b/docs/plans/2026-04-06-gmm-kernel-impl.md @@ -0,0 +1,1296 @@ +# GMM (Grouped Matrix Multiplication) Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Implement a BF16 Pallas TPU kernel for Grouped Matrix Multiplication (GMM/TGMM) with custom_vjp, replacing tokamax/qwix backends used in maxtext megablox. + +**Architecture:** Port megablox's proven grid/metadata strategy (CSR-like tile scheduling) to a clean tops-style implementation. Forward kernel uses grid `(tiles_n, num_active_tiles, tiles_k)` with float32 VMEM accumulation. Backward uses `gmm(transpose_rhs=True)` for dlhs and a separate `tgmm` kernel for drhs. No quantization, no sharding (Phase 1). 
+ +**Tech Stack:** JAX, Pallas (TPU mosaic backend), `jax.experimental.pallas.tpu`, `jax.custom_vjp` + +--- + +### Task 1: CPU Reference Implementation + +**Files:** +- Create: `tops/cpu/ops/gmm/__init__.py` +- Create: `tops/cpu/ops/gmm/naive.py` +- Create: `tests/ops/gmm/__init__.py` +- Create: `tests/ops/gmm/test_cpu_ref.py` + +**Step 1: Create the CPU reference module** + +Create `tops/cpu/ops/gmm/__init__.py`: + +```python +from .naive import gmm_ref, tgmm_ref + +__all__ = ["gmm_ref", "tgmm_ref"] +``` + +Create `tops/cpu/ops/gmm/naive.py`: + +```python +"""Pure JAX CPU reference for Grouped Matrix Multiplication.""" + +import jax +import jax.numpy as jnp + +from tops.cpu.ops import cpu_reference + + +@cpu_reference +def gmm_ref( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + transpose_rhs: bool = False, +) -> jax.Array: + """Grouped matrix multiplication reference implementation. + + For each group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Args: + lhs: [m, k] input activations. + rhs: [num_groups, k, n] per-group weights. + If transpose_rhs=True, rhs is [num_groups, k, n] but used as + [num_groups, n, k] (transposed before matmul). + group_sizes: [num_groups] int32, number of rows per group. + transpose_rhs: If True, transpose each rhs[i] before matmul. + + Returns: + [m, output_dim] where output_dim = rhs.shape[2] if not transpose_rhs + else rhs.shape[1]. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + + m = lhs.shape[0] + num_groups = rhs.shape[0] + n = rhs.shape[1] if transpose_rhs else rhs.shape[2] + orig_dtype = lhs.dtype + + out = jnp.zeros((m, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_mat = rhs[i].astype(jnp.float32) + if transpose_rhs: + rhs_mat = rhs_mat.T + out = out.at[start:end].set(lhs_slice @ rhs_mat) + start = end + return out.astype(orig_dtype) + + +@cpu_reference +def tgmm_ref( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, +) -> jax.Array: + """Transposed grouped matrix multiplication reference implementation. + + For each group i with rows [start_i, end_i): + out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + + Args: + lhs: [m, k] input activations. + rhs: [m, n] gradient or second operand. + group_sizes: [num_groups] int32, number of rows per group. + + Returns: + [num_groups, k, n] per-group outer products. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + assert lhs.shape[0] == rhs.shape[0], ( + f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}" + ) + + k = lhs.shape[1] + n = rhs.shape[1] + num_groups = group_sizes.shape[0] + orig_dtype = lhs.dtype + + out = jnp.zeros((num_groups, k, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_slice = rhs[start:end].astype(jnp.float32) + out = out.at[i].set(lhs_slice.T @ rhs_slice) + start = end + return out.astype(orig_dtype) +``` + +**Step 2: Write tests for CPU reference** + +Create empty `tests/ops/gmm/__init__.py`. + +Create `tests/ops/gmm/test_cpu_ref.py`: + +```python +"""Verify CPU reference implementations for GMM/TGMM are correct.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax +import jax.numpy as jnp +import numpy as np + +from tops.cpu.ops.gmm import gmm_ref, tgmm_ref + + +def _make_gmm_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): + """Generate random inputs for GMM tests.""" + key = jax.random.PRNGKey(seed) + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) + rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) + gs = jnp.array(group_sizes, dtype=jnp.int32) + return lhs, rhs, gs + + +class TestGmmRef: + """Test gmm_ref against manual numpy computation.""" + + def test_single_group(self): + """Single group = standard matmul.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float32) + gs = jnp.array([2], 
dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = lhs # identity matmul + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_two_groups(self): + """Two groups with different weights.""" + lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) + rhs = jnp.array([ + [[2.0, 0.0], [0.0, 2.0]], # group 0: scale by 2 + [[0.0, 1.0], [1.0, 0.0]], # group 1: swap columns + ], dtype=jnp.float32) + gs = jnp.array([1, 2], dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = jnp.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_empty_group(self): + """Empty group produces zeros for those rows (none exist).""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) + rhs = jnp.array([ + [[1.0], [1.0]], # group 0: empty + [[1.0], [1.0]], # group 1: 1 row + ], dtype=jnp.float32) + gs = jnp.array([0, 1], dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = jnp.array([[3.0]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_transpose_rhs(self): + """transpose_rhs transposes each rhs[i] before matmul.""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) + rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.float32) + gs = jnp.array([1], dtype=jnp.int32) + # Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16] + out_normal = gmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-5) + # With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17] + out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True) + np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-5) + + +class TestTgmmRef: + """Test tgmm_ref against manual numpy computation.""" + + def test_single_group(self): + """Single group: lhs^T @ rhs.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], 
dtype=jnp.float32) + rhs = jnp.array([[5.0], [6.0]], dtype=jnp.float32) + gs = jnp.array([2], dtype=jnp.int32) + out = tgmm_ref(lhs, rhs, gs) + # lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]] + expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_two_groups(self): + """Two groups produce separate outer products.""" + lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.float32) + rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.float32) + gs = jnp.array([1, 2], dtype=jnp.int32) + out = tgmm_ref(lhs, rhs, gs) + # Group 0: [1]^T @ [4] = [[4]] + # Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]] + expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +``` + +**Step 3: Run tests** + +Run: `uv run pytest tests/ops/gmm/test_cpu_ref.py -v` +Expected: All PASS + +**Step 4: Commit** + +```bash +git add tops/cpu/ops/gmm/ tests/ops/gmm/ +git commit -m "feat(gmm): add CPU reference implementations for gmm and tgmm" +``` + +--- + +### Task 2: Group Metadata Helper + +**Files:** +- Create: `tops/ops/gmm/__init__.py` (empty initially) +- Create: `tops/ops/gmm/metadata.py` +- Create: `tests/ops/gmm/test_metadata.py` + +**Step 1: Write metadata tests** + +Create `tests/ops/gmm/test_metadata.py`: + +```python +"""Test group metadata construction for GMM kernel scheduling.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax.numpy as jnp +import numpy as np + +from tops.ops.gmm.metadata import make_group_metadata + + +class TestMakeGroupMetadata: + """Verify CSR-like metadata maps grid indices to correct groups/tiles.""" + + def test_uniform_groups_aligned(self): + """Groups perfectly aligned to tile boundaries.""" 
+ # 2 groups, 128 rows each, tm=128 -> 1 tile per group, 2 active tiles + gs = jnp.array([128, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=256, tm=128 + ) + assert int(num_tiles) == 2 + np.testing.assert_array_equal(offsets, [0, 128, 256]) + np.testing.assert_array_equal(gids[:2], [0, 1]) + np.testing.assert_array_equal(mids[:2], [0, 1]) + + def test_uniform_groups_multi_tile(self): + """Groups spanning multiple tiles.""" + # 2 groups, 256 rows each, tm=128 -> 2 tiles per group, 4 active tiles + gs = jnp.array([256, 256], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=512, tm=128 + ) + assert int(num_tiles) == 4 + np.testing.assert_array_equal(gids[:4], [0, 0, 1, 1]) + np.testing.assert_array_equal(mids[:4], [0, 1, 2, 3]) + + def test_shared_tile_at_boundary(self): + """Group boundary falls mid-tile -> tile visited twice.""" + # Group 0: 64 rows (not aligned to 128), Group 1: 64 rows + # Tile 0 (rows 0-127) is shared between both groups + gs = jnp.array([64, 64], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=128, tm=128 + ) + # Tile 0 visited twice: once for group 0, once for group 1 + assert int(num_tiles) == 2 + np.testing.assert_array_equal(gids[:2], [0, 1]) + np.testing.assert_array_equal(mids[:2], [0, 0]) + + def test_empty_group(self): + """Empty group (size=0) should not produce active tiles.""" + gs = jnp.array([0, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=128, tm=128 + ) + assert int(num_tiles) == 1 + assert int(gids[0]) == 1 + + def test_visit_empty_groups(self): + """With visit_empty_groups=True, empty groups get one tile each.""" + gs = jnp.array([0, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=128, tm=128, visit_empty_groups=True + ) + assert int(num_tiles) == 2 + 
np.testing.assert_array_equal(gids[:2], [0, 1]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +``` + +**Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/ops/gmm/test_metadata.py -v` +Expected: FAIL (ImportError - module doesn't exist yet) + +**Step 3: Implement metadata** + +Create empty `tops/ops/gmm/__init__.py`: + +```python +``` + +Create `tops/ops/gmm/metadata.py`: + +```python +"""Group metadata construction for GMM kernel scheduling. + +Builds CSR-like metadata arrays that map Pallas grid indices to (group_id, +m_tile_id) pairs. This enables the GMM kernel to process ragged groups of +varying sizes using a flat 1-D grid over m-tiles. +""" + +from typing import Any + +import jax.numpy as jnp + +GroupMetadata = tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + + +def make_group_metadata( + *, + group_sizes: jnp.ndarray, + m: int, + tm: int, + visit_empty_groups: bool = False, +) -> tuple[GroupMetadata, jnp.ndarray]: + """Build scheduling metadata for grouped matmul. + + Maps each grid index in the ``num_active_tiles`` dimension to a + ``(group_id, m_tile_id)`` pair so the Pallas kernel knows which group + and which row-tile to process. + + Args: + group_sizes: [num_groups] int32 -- number of rows per group. + m: Total number of rows in lhs (may exceed sum(group_sizes) due to + padding). + tm: Row-dimension tile size. + visit_empty_groups: If True, allocate one tile per empty group (needed + by tgmm to zero the output for empty groups). + + Returns: + (group_offsets, group_ids, m_tile_ids): Metadata arrays. + - group_offsets: [num_groups + 1] int32, CSR-style row offsets. + - group_ids: [tiles_m + num_groups - 1] int32, group for each + active tile. + - m_tile_ids: [tiles_m + num_groups - 1] int32, row-tile index + for each active tile. + num_active_tiles: Scalar int32, how many entries in group_ids / + m_tile_ids are valid. 
+ """ + num_groups = group_sizes.shape[0] + + # --- CSR-style offsets --- + group_ends = jnp.cumsum(group_sizes) + group_offsets = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), group_ends] + ) + + # --- Round boundaries to tile multiples --- + rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) + group_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]] + ) + rounded_group_starts = (group_starts // tm * tm).astype(jnp.int32) + + # --- Tiles per group --- + rounded_group_sizes = rounded_group_ends - rounded_group_starts + rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes) + group_tiles = rounded_group_sizes // tm + + if visit_empty_groups: + group_tiles = jnp.where(group_sizes == 0, 1, group_tiles) + + tiles_m = (m + tm - 1) // tm + total_len = tiles_m + num_groups - 1 + + # --- group_ids: map grid index -> group --- + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_tiles, + total_repeat_length=total_len, + ) + + # --- m_tile_ids: map grid index -> row-tile --- + # Tiles at group boundaries may be visited twice; count visits per tile. 
+ partial_tile_mask = jnp.logical_or( + (group_offsets[:-1] % tm) == 0, + group_sizes == 0, + ) + if visit_empty_groups: + partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask) + + partial_tile_ids = jnp.where( + partial_tile_mask, tiles_m, group_offsets[:-1] // tm + ) + tile_visits = ( + jnp.histogram( + partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1) + )[0] + + 1 + ) + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=total_len, + ) + + num_active_tiles = group_tiles.sum() + return (group_offsets, group_ids, m_tile_ids), num_active_tiles +``` + +**Step 4: Run tests** + +Run: `uv run pytest tests/ops/gmm/test_metadata.py -v` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add tops/ops/gmm/ tests/ops/gmm/test_metadata.py +git commit -m "feat(gmm): add make_group_metadata for kernel scheduling" +``` + +--- + +### Task 3: GMM Forward Pallas Kernel + Tests + +**Files:** +- Create: `tops/ops/gmm/gmm.py` +- Create: `tests/ops/gmm/test_gmm_tpu.py` + +**Step 1: Write forward test** + +Create `tests/ops/gmm/test_gmm_tpu.py`: + +```python +"""GMM Pallas kernel accuracy vs CPU reference. + +Forward: gmm_forward (tops.ops.gmm) vs gmm_ref (tops.cpu.ops.gmm) +Backward (tgmm): tgmm_forward vs tgmm_ref +Gradient: custom_vjp gradients vs jax.grad of CPU reference. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax +import jax.numpy as jnp + +from tops.ops.gmm.gmm import gmm_forward, tgmm_forward +from tops.cpu.ops.gmm import gmm_ref, tgmm_ref +from tests.utils import compare_tensor + + +# ============================================================================ +# Test Helpers +# ============================================================================ + +def _make_group_sizes(num_groups, total_m, distribution="uniform", seed=0): + """Generate group_sizes that sum to total_m.""" + key = jax.random.PRNGKey(seed) + if distribution == "uniform": + base = total_m // num_groups + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + remainder = total_m - base * num_groups + sizes = sizes.at[:remainder].add(1) + elif distribution == "skewed": + # First group gets half, rest split the remainder + first = total_m // 2 + rest = total_m - first + base = rest // (num_groups - 1) if num_groups > 1 else 0 + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + sizes = sizes.at[0].set(first) + remainder = total_m - first - base * (num_groups - 1) + if num_groups > 1: + sizes = sizes.at[1].add(remainder) + elif distribution == "single": + sizes = jnp.zeros(num_groups, dtype=jnp.int32) + sizes = sizes.at[0].set(total_m) + elif distribution == "with_empty": + # First group empty, rest uniform + base = total_m // (num_groups - 1) if num_groups > 1 else total_m + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + sizes = sizes.at[0].set(0) + remainder = total_m - base * (num_groups - 1) + sizes = sizes.at[1].add(remainder) + else: + raise ValueError(f"Unknown distribution: {distribution}") + return sizes + + +def _make_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): + """Generate random lhs, rhs for GMM.""" + key = jax.random.PRNGKey(seed) + k1, k2 = jax.random.split(key) + lhs = 
jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) + rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) + return lhs, rhs + + +# ============================================================================ +# Forward Test Cases +# ============================================================================ + +FWD_CASES = [ + # (m, k, n, num_groups, distribution, seed) + dict(m=128, k=128, n=128, ng=1, dist="single", seed=100), + dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=101), + dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=102), + dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=103), + dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=104), + dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=105), + dict(m=1024, k=128, n=128, ng=8, dist="uniform", seed=106), + dict(m=640, k=128, n=128, ng=5, dist="skewed", seed=107), +] + + +def _fwd_case_id(c): + return f"m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", FWD_CASES, ids=[_fwd_case_id(c) for c in FWD_CASES]) +def test_gmm_fwd_vs_cpu(cfg): + """gmm_forward (Pallas) should match gmm_ref (CPU).""" + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"]) + + out_ref = gmm_ref(lhs, rhs, gs) + out_pl = gmm_forward(lhs, rhs, gs) + + assert compare_tensor("gmm_fwd", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) + + +@pytest.mark.parametrize("cfg", FWD_CASES[:4], ids=[_fwd_case_id(c) for c in FWD_CASES[:4]]) +def test_gmm_fwd_transpose_rhs(cfg): + """gmm_forward with transpose_rhs should match gmm_ref.""" + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + lhs, rhs = _make_inputs(m, n, k, ng, gs, seed=cfg["seed"]) + + out_ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True) + out_pl = gmm_forward(lhs, rhs, gs, 
transpose_rhs=True) + + assert compare_tensor("gmm_fwd_T", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +``` + +**Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_fwd_vs_cpu -v --no-header -x` +Expected: FAIL (ImportError - `gmm_forward` doesn't exist) + +**Step 3: Implement GMM forward kernel** + +Create `tops/ops/gmm/gmm.py`: + +```python +"""Pallas TPU kernels for Grouped Matrix Multiplication. + +Implements gmm (forward) and tgmm (transposed, for weight gradients) +using the megablox-style grid scheduling with CSR-like group metadata. +""" + +import functools + +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +from tops.ops.gmm.metadata import make_group_metadata, GroupMetadata +from tops.ops.utils import get_interpret + + +# ============================================================================ +# Helpers +# ============================================================================ + +def _get_store_mask( + *, + grid_id: jnp.ndarray, + group_metadata: GroupMetadata, + tm: int, + tn: int, +) -> jnp.ndarray: + """Boolean mask [tm, tn] for rows belonging to the current group.""" + group_offsets, group_ids, m_tile_ids = group_metadata + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + m_id = m_tile_ids[grid_id] * tm + iota = lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id + return jnp.logical_and(iota >= group_start, iota < group_end) + + +def _get_group_size( + *, grid_id: jnp.ndarray, group_metadata: GroupMetadata +) -> jnp.ndarray: + """Number of rows in the current group.""" + group_offsets, group_ids = group_metadata[:2] + group_id = group_ids[grid_id] + return group_offsets[group_id + 1] - group_offsets[group_id] + + +# 
============================================================================ +# GMM Forward +# ============================================================================ + +@functools.partial( + jax.jit, + static_argnames=["tiling", "transpose_rhs", "preferred_element_type", "interpret"], +) +def gmm_forward( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + tiling: tuple[int, int, int] = (128, 128, 128), + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, + interpret: bool | None = None, +) -> jnp.ndarray: + """Grouped matrix multiplication: out[group_rows] = lhs[group_rows] @ rhs[group]. + + For each group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + When transpose_rhs=True, rhs[i] is transposed before the matmul. + + Args: + lhs: [m, k] bf16 input activations, rows sorted by group. + rhs: [num_groups, k, n] bf16 per-group weights. + group_sizes: [num_groups] int32, row count per group. sum <= m. + tiling: (tm, tk, tn) tile sizes. k must be divisible by tk, + n must be divisible by tn. + transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). + preferred_element_type: Output dtype for accumulation. + interpret: Pallas interpret mode. None = auto-detect from env. + + Returns: + [m, output_n] tensor where output_n = rhs.shape[2] if not + transpose_rhs, else rhs.shape[1]. 
+ """ + if interpret is None: + interpret = get_interpret() + + # --- Validate --- + assert lhs.ndim == 2, f"lhs must be 2D [m, k], got {lhs.ndim}D" + assert rhs.ndim == 3, f"rhs must be 3D [E, k, n], got {rhs.ndim}D" + assert group_sizes.dtype == jnp.int32, ( + f"group_sizes must be int32, got {group_sizes.dtype}" + ) + + # --- Shape info --- + m, k = lhs.shape + if transpose_rhs: + n = rhs.shape[1] + else: + n = rhs.shape[2] + + tm, tk, tn = tiling + assert k % tk == 0, f"k ({k}) must be divisible by tk ({tk})" + assert n % tn == 0, f"n ({n}) must be divisible by tn ({tn})" + + tiles_k = k // tk + tiles_n = n // tn + + # --- Group metadata --- + group_metadata, num_active_tiles = make_group_metadata( + group_sizes=group_sizes, m=m, tm=tm, visit_empty_groups=False + ) + + # --- Kernel --- + def kernel( + group_metadata_ref, + lhs_ref, + rhs_ref, + out_ref, + acc_ref, + ): + group_offsets, group_ids, m_tile_ids = group_metadata_ref + grid_id = pl.program_id(1) + k_i = pl.program_id(2) + + @pl.when(k_i == 0) + def _zero_acc(): + acc_ref[...] = jnp.zeros_like(acc_ref) + + lhs_block = lhs_ref[...] + rhs_block = rhs_ref[...] + + if transpose_rhs: + dims = (((1,), (1,)), ((), ())) + else: + dims = (((1,), (0,)), ((), ())) + + acc_ref[...] += lax.dot_general( + lhs_block, + rhs_block, + dimension_numbers=dims, + preferred_element_type=jnp.float32, + ) + + @pl.when(k_i == tiles_k - 1) + def _store(): + mask = _get_store_mask( + grid_id=grid_id, + group_metadata=(group_offsets, group_ids, m_tile_ids), + tm=tm, + tn=tn, + ) + out_ref[...] 
= lax.select( + mask, acc_ref[...], out_ref[...].astype(jnp.float32) + ).astype(preferred_element_type) + + # --- Index maps --- + def lhs_index_map(n_i, grid_id, k_i, group_metadata_ref): + _, _, m_tile_ids = group_metadata_ref + return m_tile_ids[grid_id], k_i + + def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): + _, group_ids, _ = group_metadata_ref + if transpose_rhs: + return group_ids[grid_id], n_i, k_i + else: + return group_ids[grid_id], k_i, n_i + + def out_index_map(n_i, grid_id, k_i, group_metadata_ref): + _, _, m_tile_ids = group_metadata_ref + return m_tile_ids[grid_id], n_i + + # --- BlockSpecs --- + lhs_spec = pl.BlockSpec((tm, tk), lhs_index_map) + if transpose_rhs: + rhs_spec = pl.BlockSpec((None, tn, tk), rhs_index_map) + else: + rhs_spec = pl.BlockSpec((None, tk, tn), rhs_index_map) + out_spec = pl.BlockSpec((tm, tn), out_index_map) + + # --- Launch --- + call_fn = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, + grid=(tiles_n, num_active_tiles, tiles_k), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary") + ), + interpret=interpret, + ) + + out = call_fn(group_metadata, lhs, rhs) + return out + + +# ============================================================================ +# TGMM (Transposed GMM for weight gradients) +# ============================================================================ + +@functools.partial( + jax.jit, + static_argnames=["tiling", "preferred_element_type", "interpret"], +) +def tgmm_forward( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + tiling: tuple[int, int, int] = (128, 128, 128), + preferred_element_type: jnp.dtype = jnp.float32, + interpret: bool | None = None, +) -> jnp.ndarray: + """Transposed grouped 
matrix multiplication for weight gradients. + + For each group i with rows [start_i, end_i): + out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + + Args: + lhs: [m, k] bf16 input activations. + rhs: [m, n] bf16 gradients. + group_sizes: [num_groups] int32. + tiling: (tm, tk, tn) tile sizes. + preferred_element_type: Output dtype. + interpret: Pallas interpret mode. + + Returns: + [num_groups, k, n] per-group weight gradients. + """ + if interpret is None: + interpret = get_interpret() + + # --- Validate --- + assert lhs.ndim == 2, f"lhs must be 2D [m, k], got {lhs.ndim}D" + assert rhs.ndim == 2, f"rhs must be 2D [m, n], got {rhs.ndim}D" + assert lhs.shape[0] == rhs.shape[0], ( + f"lhs and rhs must have same m, got {lhs.shape[0]} vs {rhs.shape[0]}" + ) + assert group_sizes.dtype == jnp.int32 + + # --- Shape info --- + m = lhs.shape[0] + k = lhs.shape[1] + n = rhs.shape[1] + num_groups = group_sizes.shape[0] + + tm, tk, tn = tiling + assert k % tk == 0, f"k ({k}) must be divisible by tk ({tk})" + assert n % tn == 0, f"n ({n}) must be divisible by tn ({tn})" + + tiles_k = k // tk + tiles_n = n // tn + + # --- Group metadata --- + group_metadata, num_active_tiles = make_group_metadata( + group_sizes=group_sizes, m=m, tm=tm, visit_empty_groups=True + ) + + # --- Kernel --- + def kernel( + group_metadata_ref, + lhs_ref, + rhs_ref, + out_ref, + acc_ref, + ): + group_offsets, group_ids, m_tile_ids = group_metadata_ref + grid_id = pl.program_id(2) + + group = group_ids[grid_id] + prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) + prev_group = group_ids[prev_grid_id] + group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group) + + @pl.when(group_has_changed) + def _zero_acc(): + acc_ref[...] 
= jnp.zeros_like(acc_ref) + + # Only compute if group has rows + has_rows = ( + _get_group_size(grid_id=grid_id, group_metadata=(group_offsets, group_ids, m_tile_ids)) + > 0 + ) + + @pl.when(has_rows) + def _compute(): + # Mask rows outside group + lhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=(group_offsets, group_ids, m_tile_ids), + tm=tm, + tn=tk, + ) + rhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=(group_offsets, group_ids, m_tile_ids), + tm=tm, + tn=tn, + ) + loaded_lhs = lax.select(lhs_mask, lhs_ref[...], jnp.zeros_like(lhs_ref)) + loaded_rhs = lax.select(rhs_mask, rhs_ref[...], jnp.zeros_like(rhs_ref)) + + # lhs^T [tk, tm] @ rhs [tm, tn] = [tk, tn] + acc_ref[...] += lax.dot( + loaded_lhs.swapaxes(0, 1), + loaded_rhs, + preferred_element_type=jnp.float32, + ) + + # Store when group is about to change + is_end = grid_id == (pl.num_programs(2) - 1) + next_grid_id = jnp.where(is_end, grid_id, grid_id + 1) + next_group = group_ids[next_grid_id] + group_is_changing = jnp.logical_or(is_end, group != next_group) + + @pl.when(group_is_changing) + def _store(): + out_ref[...] 
= acc_ref[...].astype(preferred_element_type) + + # --- Index maps --- + def lhs_index_map(n_i, k_i, grid_id, group_metadata_ref): + _, _, m_tile_ids = group_metadata_ref + return m_tile_ids[grid_id], k_i + + def rhs_index_map(n_i, k_i, grid_id, group_metadata_ref): + _, _, m_tile_ids = group_metadata_ref + return m_tile_ids[grid_id], n_i + + def out_index_map(n_i, k_i, grid_id, group_metadata_ref): + _, group_ids, _ = group_metadata_ref + return group_ids[grid_id], k_i, n_i + + # --- BlockSpecs --- + lhs_spec = pl.BlockSpec((tm, tk), lhs_index_map) + rhs_spec = pl.BlockSpec((tm, tn), rhs_index_map) + out_spec = pl.BlockSpec((None, tk, tn), out_index_map) + + # --- Launch --- + call_fn = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct( + (num_groups, k, n), preferred_element_type + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, + grid=(tiles_n, tiles_k, num_active_tiles), + scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary") + ), + interpret=interpret, + ) + + out = call_fn(group_metadata, lhs, rhs) + return out +``` + +**Step 4: Run forward tests** + +Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_fwd_vs_cpu -v` +Expected: All PASS + +**Step 5: Commit** + +```bash +git add tops/ops/gmm/gmm.py tests/ops/gmm/test_gmm_tpu.py +git commit -m "feat(gmm): add GMM forward Pallas kernel" +``` + +--- + +### Task 4: TGMM Tests + +**Files:** +- Modify: `tests/ops/gmm/test_gmm_tpu.py` + +**Step 1: Add TGMM tests to test_gmm_tpu.py** + +Append to `tests/ops/gmm/test_gmm_tpu.py`: + +```python +# ============================================================================ +# TGMM Test Cases +# ============================================================================ + +TGMM_CASES = [ + dict(m=128, k=128, n=128, ng=1, dist="single", seed=200), + 
dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=201), + dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=202), + dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=203), + dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=204), + dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=205), +] + + +def _tgmm_case_id(c): + return f"tgmm_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", TGMM_CASES, ids=[_tgmm_case_id(c) for c in TGMM_CASES]) +def test_tgmm_vs_cpu(cfg): + """tgmm_forward (Pallas) should match tgmm_ref (CPU).""" + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + + key = jax.random.PRNGKey(cfg["seed"]) + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) + rhs = jax.random.normal(k2, (m, n), dtype=jnp.float32).astype(jnp.bfloat16) + + out_ref = tgmm_ref(lhs, rhs, gs) + out_pl = tgmm_forward(lhs, rhs, gs) + + assert compare_tensor("tgmm", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) +``` + +**Step 2: Run TGMM tests** + +Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_tgmm_vs_cpu -v` +Expected: All PASS + +**Step 3: Commit** + +```bash +git add tests/ops/gmm/test_gmm_tpu.py +git commit -m "test(gmm): add tgmm accuracy tests" +``` + +--- + +### Task 5: Custom VJP + Public API + Gradient Tests + +**Files:** +- Modify: `tops/ops/gmm/__init__.py` +- Modify: `tops/ops/gmm/gmm.py` (add custom_vjp wrapper) +- Modify: `tops/ops/__init__.py` +- Modify: `tests/ops/gmm/test_gmm_tpu.py` (add gradient tests) + +**Step 1: Add gradient tests to test_gmm_tpu.py** + +Append to `tests/ops/gmm/test_gmm_tpu.py`: + +```python +# ============================================================================ +# Gradient Tests +# ============================================================================ + +from tops.ops.gmm import gmm + + +def 
_gmm_ref_differentiable(lhs, rhs, group_sizes): + """Differentiable CPU reference for gradient comparison.""" + m = lhs.shape[0] + num_groups = rhs.shape[0] + n = rhs.shape[2] + out = jnp.zeros((m, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_mat = rhs[i].astype(jnp.float32) + out = out.at[start:end].set(lhs_slice @ rhs_mat) + start = end + return out.sum() + + +GRAD_CASES = [ + dict(m=128, k=128, n=128, ng=1, dist="single", seed=300), + dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=301), + dict(m=384, k=128, n=128, ng=3, dist="skewed", seed=302), +] + + +def _grad_case_id(c): + return f"grad_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", GRAD_CASES, ids=[_grad_case_id(c) for c in GRAD_CASES]) +def test_gmm_gradient(cfg): + """custom_vjp gradients should match numerical/reference gradients.""" + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"], dtype=jnp.bfloat16) + + # Pallas gmm gradients + def pallas_loss(lhs, rhs): + return gmm(lhs, rhs, gs).sum() + + dlhs_pl, drhs_pl = jax.grad(pallas_loss, argnums=(0, 1))(lhs, rhs) + + # Reference gradients (CPU) + def ref_loss(lhs, rhs): + return _gmm_ref_differentiable(lhs, rhs, gs) + + dlhs_ref, drhs_ref = jax.grad(ref_loss, argnums=(0, 1))(lhs, rhs) + + assert compare_tensor("dlhs", dlhs_ref, dlhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) + assert compare_tensor("drhs", drhs_ref, drhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) +``` + +**Step 2: Run gradient test to verify it fails** + +Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_gradient -v --no-header -x` +Expected: FAIL (ImportError - `tops.ops.gmm.gmm` function doesn't exist) + +**Step 3: Add custom_vjp to gmm.py** + +Add to end of 
`tops/ops/gmm/gmm.py`: + +```python +# ============================================================================ +# Differentiable GMM with custom_vjp +# ============================================================================ + +@functools.partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + tiling: tuple[int, int, int] = (128, 128, 128), + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + """Differentiable grouped matrix multiplication. + + For each group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Supports automatic differentiation via custom_vjp: + - dlhs: computed via gmm with transposed rhs + - drhs: computed via tgmm + + Args: + lhs: [m, k] bf16 input activations, rows sorted by group. + rhs: [num_groups, k, n] bf16 per-group weights. + group_sizes: [num_groups] int32, row count per group. + tiling: (tm, tk, tn) tile sizes. + transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). + preferred_element_type: Output dtype. + + Returns: + [m, output_n] bf16 tensor. 
+ """ + return _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type)[0] + + +def _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type): + out = gmm_forward( + lhs, rhs, group_sizes, + tiling=tiling, + transpose_rhs=transpose_rhs, + preferred_element_type=preferred_element_type, + ) + return out, (lhs, rhs, group_sizes) + + +def _gmm_bwd(tiling, transpose_rhs, preferred_element_type, residuals, grad): + lhs, rhs, group_sizes = residuals + + # dlhs = grad @ rhs^T per group + dlhs = gmm_forward( + grad, rhs, group_sizes, + tiling=tiling, + transpose_rhs=not transpose_rhs, + preferred_element_type=preferred_element_type, + ).astype(lhs.dtype) + + # drhs = lhs^T @ grad per group + drhs = tgmm_forward( + lhs, grad, group_sizes, + tiling=tiling, + preferred_element_type=preferred_element_type, + ).astype(rhs.dtype) + + return dlhs, drhs, None + + +gmm.defvjp(_gmm_fwd, _gmm_bwd) +``` + +**Step 4: Create public API** + +Update `tops/ops/gmm/__init__.py`: + +```python +"""Public API for grouped matrix multiplication.""" + +from .gmm import gmm, gmm_forward, tgmm_forward + +__all__ = ["gmm", "gmm_forward", "tgmm_forward"] +``` + +Update `tops/ops/__init__.py` to add gmm: + +```python +"""Public API for tops.ops. + +All public interfaces are exported exclusively via this file. +Any interface not re-exported here is considered an internal implementation +detail with **no API stability guarantee**. 
+""" + +from .simple_gla import simple_gla +from .gmm import gmm + +__all__ = [ + "simple_gla", + "gmm", +] +``` + +**Step 5: Run gradient tests** + +Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_gradient -v` +Expected: All PASS + +**Step 6: Run all GMM tests** + +Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/ -v` +Expected: All PASS + +**Step 7: Lint** + +Run: `uv run ruff check tops/ops/gmm/ tops/cpu/ops/gmm/ tests/ops/gmm/` +Run: `uv run ruff format tops/ops/gmm/ tops/cpu/ops/gmm/ tests/ops/gmm/` + +**Step 8: Commit** + +```bash +git add tops/ops/gmm/ tops/ops/__init__.py tests/ops/gmm/ +git commit -m "feat(gmm): add custom_vjp wrapper and public API for differentiable GMM" +``` + +--- + +## Implementation Notes + +**Key differences from megablox:** +- No `qwix`/`qpl` dependency -- uses standard `lax.dot_general` and `pl.pallas_call` +- No `group_offset` / sharding support (Phase 1) +- No `existing_out` / input_output_aliases +- No `LutFn` tiling lookup -- static tuple only +- `num_scalar_prefetch=1` (only group_metadata, no group_offset) +- Clean tops-style assertions and docstrings + +**Testing approach:** +- CPU reference (`gmm_ref`, `tgmm_ref`) as ground truth +- `PALLAS_INTERPRET=1` for CPU-based Pallas testing during development +- On TPU: run without PALLAS_INTERPRET for native kernel execution +- Tolerance: atol=1e-2, rtol=1e-2 (bf16 accumulation in tiles) +- Gradient tolerance: atol=5e-2 (backward has compounded error) + +**File dependency graph:** +``` +tops/ops/gmm/metadata.py <- no internal deps +tops/ops/gmm/gmm.py <- metadata.py, tops.ops.utils +tops/ops/gmm/__init__.py <- gmm.py +tops/cpu/ops/gmm/naive.py <- tops.cpu.ops (cpu_reference) +tops/cpu/ops/gmm/__init__.py <- naive.py +``` From f2df07cd2639e35865536fd7e61f7b9b50ec9a49 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 15:55:51 +0800 Subject: [PATCH 03/15] feat(gmm): add CPU reference implementations for gmm 
and tgmm Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/__init__.py | 0 tests/ops/gmm/test_cpu_ref.py | 98 +++++++++++++++++++++++++++++++++++ tops/cpu/ops/gmm/__init__.py | 3 ++ tops/cpu/ops/gmm/naive.py | 98 +++++++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 tests/ops/gmm/__init__.py create mode 100644 tests/ops/gmm/test_cpu_ref.py create mode 100644 tops/cpu/ops/gmm/__init__.py create mode 100644 tops/cpu/ops/gmm/naive.py diff --git a/tests/ops/gmm/__init__.py b/tests/ops/gmm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ops/gmm/test_cpu_ref.py b/tests/ops/gmm/test_cpu_ref.py new file mode 100644 index 00000000..db1fa3d9 --- /dev/null +++ b/tests/ops/gmm/test_cpu_ref.py @@ -0,0 +1,98 @@ +"""Verify CPU reference implementations for GMM/TGMM are correct.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import jax.numpy as jnp +import numpy as np +import pytest + +from tops.cpu.ops.gmm import gmm_ref, tgmm_ref + + +class TestGmmRef: + """Test gmm_ref against manual numpy computation.""" + + def test_single_group(self): + """Single group = standard matmul.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float32) + gs = jnp.array([2], dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = lhs # identity matmul + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_two_groups(self): + """Two groups with different weights.""" + lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) + rhs = jnp.array( + [ + [[2.0, 0.0], [0.0, 2.0]], # group 0: scale by 2 + [[0.0, 1.0], [1.0, 0.0]], # group 1: swap columns + ], + dtype=jnp.float32, + ) + gs = jnp.array([1, 2], dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = jnp.array([[2.0, 0.0], [1.0, 0.0], [1.0, 1.0]], 
dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_empty_group(self): + """Empty group produces zeros for those rows (none exist).""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) + rhs = jnp.array( + [ + [[1.0], [1.0]], # group 0: empty + [[1.0], [1.0]], # group 1: 1 row + ], + dtype=jnp.float32, + ) + gs = jnp.array([0, 1], dtype=jnp.int32) + out = gmm_ref(lhs, rhs, gs) + expected = jnp.array([[3.0]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_transpose_rhs(self): + """transpose_rhs transposes each rhs[i] before matmul.""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) + rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.float32) + gs = jnp.array([1], dtype=jnp.int32) + # Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16] + out_normal = gmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-5) + # With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17] + out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True) + np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-5) + + +class TestTgmmRef: + """Test tgmm_ref against manual numpy computation.""" + + def test_single_group(self): + """Single group: lhs^T @ rhs.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + rhs = jnp.array([[5.0], [6.0]], dtype=jnp.float32) + gs = jnp.array([2], dtype=jnp.int32) + out = tgmm_ref(lhs, rhs, gs) + # lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]] + expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + + def test_two_groups(self): + """Two groups produce separate outer products.""" + lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.float32) + rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.float32) + gs = jnp.array([1, 2], 
dtype=jnp.int32)
+        out = tgmm_ref(lhs, rhs, gs)
+        # Group 0: [1]^T @ [4] = [[4]]
+        # Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]]
+        expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32)
+        np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tops/cpu/ops/gmm/__init__.py b/tops/cpu/ops/gmm/__init__.py
new file mode 100644
index 00000000..f8379c46
--- /dev/null
+++ b/tops/cpu/ops/gmm/__init__.py
@@ -0,0 +1,3 @@
+from .naive import gmm_ref, tgmm_ref
+
+__all__ = ["gmm_ref", "tgmm_ref"]
diff --git a/tops/cpu/ops/gmm/naive.py b/tops/cpu/ops/gmm/naive.py
new file mode 100644
index 00000000..fa106501
--- /dev/null
+++ b/tops/cpu/ops/gmm/naive.py
@@ -0,0 +1,98 @@
+"""Pure JAX CPU reference for Grouped Matrix Multiplication."""
+
+import jax
+import jax.numpy as jnp
+
+from tops.cpu.ops import cpu_reference
+
+
+@cpu_reference
+def gmm_ref(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool = False,
+) -> jax.Array:
+    """Grouped matrix multiplication reference implementation.
+
+    For each group i with rows [start_i, end_i):
+        out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]
+
+    Args:
+        lhs: [m, k] input activations.
+        rhs: [num_groups, k, n] per-group weights.
+            If transpose_rhs=True, rhs is instead [num_groups, n, k] and each
+            rhs[i] is transposed to [k, n] before the matmul.
+        group_sizes: [num_groups] int32, number of rows per group.
+        transpose_rhs: If True, transpose each rhs[i] before matmul.
+
+    Returns:
+        [m, output_dim] where output_dim = rhs.shape[2] if not transpose_rhs
+        else rhs.shape[1]. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + + m = lhs.shape[0] + num_groups = rhs.shape[0] + n = rhs.shape[1] if transpose_rhs else rhs.shape[2] + orig_dtype = lhs.dtype + + out = jnp.zeros((m, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_mat = rhs[i].astype(jnp.float32) + if transpose_rhs: + rhs_mat = rhs_mat.T + out = out.at[start:end].set(lhs_slice @ rhs_mat) + start = end + return out.astype(orig_dtype) + + +@cpu_reference +def tgmm_ref( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, +) -> jax.Array: + """Transposed grouped matrix multiplication reference implementation. + + For each group i with rows [start_i, end_i): + out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + + Args: + lhs: [m, k] input activations. + rhs: [m, n] gradient or second operand. + group_sizes: [num_groups] int32, number of rows per group. + + Returns: + [num_groups, k, n] per-group outer products. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + assert lhs.shape[0] == rhs.shape[0], ( + f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}" + ) + + k = lhs.shape[1] + n = rhs.shape[1] + num_groups = group_sizes.shape[0] + orig_dtype = lhs.dtype + + out = jnp.zeros((num_groups, k, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_slice = rhs[start:end].astype(jnp.float32) + out = out.at[i].set(lhs_slice.T @ rhs_slice) + start = end + return out.astype(orig_dtype) From fb5d636e504820065402854c313118430761739c Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:01:43 +0800 Subject: [PATCH 04/15] feat(gmm): add make_group_metadata for kernel scheduling Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_metadata.py | 76 ++++++++++++++++++++++ tops/ops/gmm/__init__.py | 0 tops/ops/gmm/metadata.py | 111 +++++++++++++++++++++++++++++++++ 3 files changed, 187 insertions(+) create mode 100644 tests/ops/gmm/test_metadata.py create mode 100644 tops/ops/gmm/__init__.py create mode 100644 tops/ops/gmm/metadata.py diff --git a/tests/ops/gmm/test_metadata.py b/tests/ops/gmm/test_metadata.py new file mode 100644 index 00000000..4ef2dbbf --- /dev/null +++ b/tests/ops/gmm/test_metadata.py @@ -0,0 +1,76 @@ +"""Test group metadata construction for GMM kernel scheduling.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax.numpy as jnp +import numpy as np + +from tops.ops.gmm.metadata import make_group_metadata + + +class TestMakeGroupMetadata: + """Verify CSR-like metadata maps grid indices to correct 
groups/tiles.""" + + def test_uniform_groups_aligned(self): + """Groups perfectly aligned to tile boundaries.""" + # 2 groups, 128 rows each, tm=128 -> 1 tile per group, 2 active tiles + gs = jnp.array([128, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=256, tm=128 + ) + assert int(num_tiles) == 2 + np.testing.assert_array_equal(offsets, [0, 128, 256]) + np.testing.assert_array_equal(gids[:2], [0, 1]) + np.testing.assert_array_equal(mids[:2], [0, 1]) + + def test_uniform_groups_multi_tile(self): + """Groups spanning multiple tiles.""" + # 2 groups, 256 rows each, tm=128 -> 2 tiles per group, 4 active tiles + gs = jnp.array([256, 256], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=512, tm=128 + ) + assert int(num_tiles) == 4 + np.testing.assert_array_equal(gids[:4], [0, 0, 1, 1]) + np.testing.assert_array_equal(mids[:4], [0, 1, 2, 3]) + + def test_shared_tile_at_boundary(self): + """Group boundary falls mid-tile -> tile visited twice.""" + # Group 0: 64 rows (not aligned to 128), Group 1: 64 rows + # Tile 0 (rows 0-127) is shared between both groups + gs = jnp.array([64, 64], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=128, tm=128 + ) + # Tile 0 visited twice: once for group 0, once for group 1 + assert int(num_tiles) == 2 + np.testing.assert_array_equal(gids[:2], [0, 1]) + np.testing.assert_array_equal(mids[:2], [0, 0]) + + def test_empty_group(self): + """Empty group (size=0) should not produce active tiles.""" + gs = jnp.array([0, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + group_sizes=gs, m=128, tm=128 + ) + assert int(num_tiles) == 1 + assert int(gids[0]) == 1 + + def test_visit_empty_groups(self): + """With visit_empty_groups=True, empty groups get one tile each.""" + gs = jnp.array([0, 128], dtype=jnp.int32) + (offsets, gids, mids), num_tiles = make_group_metadata( + 
group_sizes=gs, m=128, tm=128, visit_empty_groups=True + ) + assert int(num_tiles) == 2 + np.testing.assert_array_equal(gids[:2], [0, 1]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tops/ops/gmm/__init__.py b/tops/ops/gmm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tops/ops/gmm/metadata.py b/tops/ops/gmm/metadata.py new file mode 100644 index 00000000..46e288cb --- /dev/null +++ b/tops/ops/gmm/metadata.py @@ -0,0 +1,111 @@ +"""Group metadata construction for GMM kernel scheduling. + +Builds CSR-like metadata arrays that map Pallas grid indices to (group_id, +m_tile_id) pairs. This enables the GMM kernel to process ragged groups of +varying sizes using a flat 1-D grid over m-tiles. +""" + +from __future__ import annotations + +import jax.numpy as jnp + +GroupMetadata = tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] + + +def make_group_metadata( + *, + group_sizes: jnp.ndarray, + m: int, + tm: int, + visit_empty_groups: bool = False, +) -> tuple[GroupMetadata, jnp.ndarray]: + """Build scheduling metadata for grouped matmul. + + Maps each grid index in the ``num_active_tiles`` dimension to a + ``(group_id, m_tile_id)`` pair so the Pallas kernel knows which group + and which row-tile to process. + + Args: + group_sizes: [num_groups] int32 -- number of rows per group. + m: Total number of rows in lhs (may exceed sum(group_sizes) due to + padding). + tm: Row-dimension tile size. + visit_empty_groups: If True, allocate one tile per empty group (needed + by tgmm to zero the output for empty groups). + + Returns: + (group_offsets, group_ids, m_tile_ids): Metadata arrays. + - group_offsets: [num_groups + 1] int32, CSR-style row offsets. + - group_ids: [tiles_m + num_groups - 1] int32, group for each + active tile. + - m_tile_ids: [tiles_m + num_groups - 1] int32, row-tile index + for each active tile. + num_active_tiles: Scalar int32, how many entries in group_ids / + m_tile_ids are valid. 
+ """ + assert group_sizes.ndim == 1, "group_sizes must be 1-D" + assert m > 0, "m must be positive" + assert tm > 0, "tm must be positive" + + num_groups = group_sizes.shape[0] + + # --- CSR-style offsets --- + group_ends = jnp.cumsum(group_sizes) + group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends]) + + # --- Compute tile ranges for each group --- + group_starts = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]) + + # First tile index touched by each group (floor division) + first_tile = group_starts // tm + # Last tile index touched by each group (ceil division - 1, i.e. inclusive) + # For empty groups, last_tile < first_tile so they produce 0 tiles. + last_tile_plus_one = (group_ends + tm - 1) // tm + # Clamp empty groups to produce 0 tiles + tiles_per_group = jnp.where( + group_sizes == 0, + 0, + last_tile_plus_one - first_tile, + ).astype(jnp.int32) + + if visit_empty_groups: + tiles_per_group = jnp.where(group_sizes == 0, 1, tiles_per_group) + + tiles_m = (m + tm - 1) // tm + # Worst case: each group boundary can split a tile, adding at most + # (num_groups - 1) extra visits. + total_len = tiles_m + num_groups - 1 + + # --- group_ids: map grid index -> group --- + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + tiles_per_group, + total_repeat_length=total_len, + ) + + # --- m_tile_ids: map grid index -> row-tile --- + # For each group, the tile indices are first_tile, first_tile+1, ..., + # first_tile + tiles_per_group - 1. + # We build this by creating a per-slot offset within the group, then + # adding the group's first_tile. + + # First, compute the starting offset for each group's tiles in the + # output array using cumsum of tiles_per_group. + group_tile_offsets = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(tiles_per_group)] + ) + + # For empty groups with visit_empty_groups, first_tile doesn't make sense; + # use 0 as placeholder (the tile id for empty group visits). 
+ effective_first_tile = jnp.where(group_sizes == 0, 0, first_tile).astype(jnp.int32) + + # Build m_tile_ids using a scatter approach: + # For each active slot i, m_tile_ids[i] = first_tile[group_ids[i]] + local_offset + # where local_offset = i - group_tile_offsets[group_ids[i]] + + slot_indices = jnp.arange(total_len, dtype=jnp.int32) + local_offsets = slot_indices - group_tile_offsets[group_ids] + m_tile_ids = effective_first_tile[group_ids] + local_offsets + + num_active_tiles = tiles_per_group.sum() + return (group_offsets, group_ids, m_tile_ids), num_active_tiles From ed871dbcea0b55919b77c524b5031c812026be94 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:18:51 +0800 Subject: [PATCH 05/15] feat(gmm): implement GMM and TGMM Pallas TPU kernels Add two core Pallas kernel functions for grouped matrix multiplication: - gmm_forward: forward pass computing per-group matmuls with support for transposed rhs and data-dependent indexing via scalar prefetch - tgmm_forward: transposed GMM for weight gradients, accumulating lhs^T @ rhs per group with proper boundary masking Both kernels use PrefetchScalarGridSpec with group metadata for efficient ragged-batch processing. Tests validate accuracy against CPU reference implementations across uniform, skewed, single-group, and with-empty-group distributions using PALLAS_INTERPRET mode. 
Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_gmm_tpu.py | 147 +++++++++ tops/ops/gmm/gmm.py | 555 ++++++++++++++++++++++++++++++++++ 2 files changed, 702 insertions(+) create mode 100644 tests/ops/gmm/test_gmm_tpu.py create mode 100644 tops/ops/gmm/gmm.py diff --git a/tests/ops/gmm/test_gmm_tpu.py b/tests/ops/gmm/test_gmm_tpu.py new file mode 100644 index 00000000..2de3d4a6 --- /dev/null +++ b/tests/ops/gmm/test_gmm_tpu.py @@ -0,0 +1,147 @@ +"""GMM Pallas kernel accuracy vs CPU reference.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import pytest +import jax +import jax.numpy as jnp + +import numpy as np + +from tops.ops.gmm.gmm import gmm_forward, tgmm_forward +from tops.cpu.ops.gmm import gmm_ref, tgmm_ref + + +def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, max_ulp=1): + """Lightweight compare_tensor that works without torch.""" + if isinstance(gold, jax.Array): + gold = np.array(gold).astype(np.float64) + if isinstance(tensor, jax.Array): + tensor = np.array(tensor).astype(np.float64) + if gold.shape != tensor.shape: + print(f"[{name}] Shape mismatch: {gold.shape} vs {tensor.shape}. 
FAIL.") + return False + diff = np.abs(gold - tensor) + max_diff = np.max(diff) + max_val = np.max(np.abs(tensor)) + is_close = np.allclose(gold, tensor, atol=atol, rtol=rtol, equal_nan=True) + status = "PASS" if is_close else "FAIL" + print(f"[{name}] {status} max_val={max_val:.6e} max_diff={max_diff:.6e}") + return is_close + + +# Helpers +def _make_group_sizes(num_groups, total_m, distribution="uniform", seed=0): + key = jax.random.PRNGKey(seed) + if distribution == "uniform": + base = total_m // num_groups + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + remainder = total_m - base * num_groups + sizes = sizes.at[:remainder].add(1) + elif distribution == "skewed": + first = total_m // 2 + rest = total_m - first + base = rest // (num_groups - 1) if num_groups > 1 else 0 + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + sizes = sizes.at[0].set(first) + remainder = total_m - first - base * (num_groups - 1) + if num_groups > 1: + sizes = sizes.at[1].add(remainder) + elif distribution == "single": + sizes = jnp.zeros(num_groups, dtype=jnp.int32) + sizes = sizes.at[0].set(total_m) + elif distribution == "with_empty": + base = total_m // (num_groups - 1) if num_groups > 1 else total_m + sizes = jnp.full(num_groups, base, dtype=jnp.int32) + sizes = sizes.at[0].set(0) + remainder = total_m - base * (num_groups - 1) + sizes = sizes.at[1].add(remainder) + return sizes + + +def _make_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): + key = jax.random.PRNGKey(seed) + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) + rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) + return lhs, rhs + + +# Forward test cases +FWD_CASES = [ + dict(m=128, k=128, n=128, ng=1, dist="single", seed=100), + dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=101), + dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=102), + dict(m=384, k=256, n=128, ng=3, dist="skewed", 
seed=103), + dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=104), + dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=105), + dict(m=1024, k=128, n=128, ng=8, dist="uniform", seed=106), + dict(m=640, k=128, n=128, ng=5, dist="skewed", seed=107), +] + + +def _fwd_case_id(c): + return f"m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", FWD_CASES, ids=[_fwd_case_id(c) for c in FWD_CASES]) +def test_gmm_fwd_vs_cpu(cfg): + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"]) + out_ref = gmm_ref(lhs, rhs, gs) + out_pl = gmm_forward(lhs, rhs, gs) + assert compare_tensor("gmm_fwd", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) + + +@pytest.mark.parametrize( + "cfg", FWD_CASES[:4], ids=[_fwd_case_id(c) for c in FWD_CASES[:4]] +) +def test_gmm_fwd_transpose_rhs(cfg): + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + key = jax.random.PRNGKey(cfg["seed"]) + k1, k2 = jax.random.split(key) + # For transpose_rhs: lhs [m, k], rhs [ng, n, k] (rhs is transposed inside) + lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) + rhs = jax.random.normal(k2, (ng, n, k), dtype=jnp.float32).astype(jnp.bfloat16) + out_ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True) + out_pl = gmm_forward(lhs, rhs, gs, transpose_rhs=True) + assert compare_tensor("gmm_fwd_T", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) + + +# TGMM test cases +TGMM_CASES = [ + dict(m=128, k=128, n=128, ng=1, dist="single", seed=200), + dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=201), + dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=202), + dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=203), + dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=204), + dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=205), +] + + +def 
_tgmm_case_id(c): + return f"tgmm_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", TGMM_CASES, ids=[_tgmm_case_id(c) for c in TGMM_CASES]) +def test_tgmm_vs_cpu(cfg): + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + key = jax.random.PRNGKey(cfg["seed"]) + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) + rhs = jax.random.normal(k2, (m, n), dtype=jnp.float32).astype(jnp.bfloat16) + out_ref = tgmm_ref(lhs, rhs, gs) + out_pl = tgmm_forward(lhs, rhs, gs) + assert compare_tensor("tgmm", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py new file mode 100644 index 00000000..4fc2d77c --- /dev/null +++ b/tops/ops/gmm/gmm.py @@ -0,0 +1,555 @@ +"""Pallas TPU kernels for Grouped Matrix Multiplication (GMM). + +Provides two core operations: + +* ``gmm_forward`` -- forward grouped matmul: + for each group *i* with rows ``[start_i, end_i)``: + ``out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]`` + +* ``tgmm_forward`` -- transposed grouped matmul (weight gradient): + for each group *i* with rows ``[start_i, end_i)``: + ``out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]`` +""" + +from __future__ import annotations + +import functools +from typing import Any + +import jax +import jax.lax as lax +import jax.numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +from tops.ops.gmm.metadata import make_group_metadata +from tops.ops.utils import get_interpret + + +# --------------------------------------------------------------------------- +# Tiling helpers +# --------------------------------------------------------------------------- + + +def _validate_tiling( + tiling: tuple[int, int, int] | None, m: int, k: int, n: int +) -> 
tuple[int, int, int]: + """Return validated (tm, tk, tn) tile sizes. + + Args: + tiling: User-provided ``(tm, tk, tn)`` or *None* for defaults. + m: Total rows in lhs. + k: Contraction dimension. + n: Output columns (or rhs last dim). + + Returns: + ``(tm, tk, tn)`` with each dimension clamped to the actual size. + """ + if tiling is None: + tm = min(128, m) + tk = min(128, k) + tn = min(128, n) + else: + tm, tk, tn = tiling + return (tm, tk, tn) + + +# --------------------------------------------------------------------------- +# Masking helper +# --------------------------------------------------------------------------- + + +def _get_store_mask( + group_metadata_ref, + group_offsets_ref, + m_tile_ids_ref, + group_ids_ref, + grid_id: int, + tm: int, + tn: int, +) -> jax.Array: + """Build a ``[tm, tn]`` boolean mask that is True for valid rows. + + A tile at a group boundary may contain rows belonging to two groups. + Only the rows that belong to the current group should be stored. + + Args: + group_metadata_ref: Scalar-prefetch ref holding + ``(group_offsets, group_ids, m_tile_ids)``. + group_offsets_ref: Ref into group_offsets array. + m_tile_ids_ref: Ref into m_tile_ids array. + group_ids_ref: Ref into group_ids array. + grid_id: Current index in the ``num_active_tiles`` grid dimension. + tm: Row tile size. + tn: Column tile size. + + Returns: + Boolean array of shape ``[tm, tn]``. 
+ """ + group_id = group_ids_ref[grid_id] + group_start = group_offsets_ref[group_id] + group_end = group_offsets_ref[group_id + 1] + + m_tile_id = m_tile_ids_ref[grid_id] + row_start = m_tile_id * tm + row_indices = row_start + jnp.arange(tm, dtype=jnp.int32) + + valid_rows = (row_indices >= group_start) & (row_indices < group_end) + # Broadcast to [tm, tn] + return jnp.broadcast_to(valid_rows[:, None], (tm, tn)) + + +# =================================================================== +# GMM Forward +# =================================================================== + + +def _gmm_kernel( + # Scalar prefetch ref + group_metadata_ref, + # Input refs + lhs_ref, + rhs_ref, + # Output ref + out_ref, + # Scratch ref + acc_ref, + *, + tm: int, + tk: int, + tn: int, + tiles_k: int, + preferred_element_type: Any, + transpose_rhs: bool, +): + """GMM forward kernel body. + + Args: + group_metadata_ref: Scalar-prefetch ref containing + ``(group_offsets, group_ids, m_tile_ids)`` as a flat array. + lhs_ref: VMEM ref ``[tm, tk]``. + rhs_ref: VMEM ref ``[tk, tn]`` (or ``[tn, tk]`` when transpose_rhs). + out_ref: VMEM ref ``[tm, tn]``. + acc_ref: VMEM scratch ``[tm, tn]`` float32. + tm: Row tile size. + tk: K-dimension tile size. + tn: Column tile size. + tiles_k: Number of tiles along K. + preferred_element_type: Output dtype. + transpose_rhs: Whether rhs is stored transposed. + """ + n_i, grid_id, k_i = pl.program_id(0), pl.program_id(1), pl.program_id(2) + + # Unpack metadata from the scalar prefetch ref. + # group_metadata_ref is a tuple of 3 refs: (group_offsets, group_ids, m_tile_ids) + group_offsets_ref = group_metadata_ref[0] + group_ids_ref = group_metadata_ref[1] + m_tile_ids_ref = group_metadata_ref[2] + + # Zero accumulator on first k-tile. + @pl.when(k_i == 0) + def _zero(): + acc_ref[...] = jnp.zeros((tm, tn), dtype=jnp.float32) + + # Load blocks and accumulate. 
+ lhs_block = lhs_ref[...].astype(jnp.float32) + rhs_block = rhs_ref[...].astype(jnp.float32) + + if transpose_rhs: + dims = ((1,), (1,)), ((), ()) + else: + dims = ((1,), (0,)), ((), ()) + + acc_ref[...] += lax.dot_general(lhs_block, rhs_block, dims) + + # On last k-tile: apply mask and store. + @pl.when(k_i == tiles_k - 1) + def _store(): + mask = _get_store_mask( + group_metadata_ref, + group_offsets_ref, + m_tile_ids_ref, + group_ids_ref, + grid_id, + tm, + tn, + ) + acc = acc_ref[...] + existing = out_ref[...].astype(jnp.float32) + result = lax.select(mask, acc, existing) + out_ref[...] = result.astype(preferred_element_type) + + +def gmm_forward( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + tiling: tuple[int, int, int] | None = None, + transpose_rhs: bool = False, + preferred_element_type: Any = None, + interpret: bool | None = None, +) -> jax.Array: + """Grouped matrix multiplication forward pass. + + Computes per-group matmuls where rows of ``lhs`` are partitioned into + groups defined by ``group_sizes``, and each group multiplies against the + corresponding weight matrix in ``rhs``. + + Semantics:: + + for group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Args: + lhs: ``[m, k]`` input activations in bfloat16 / float32. + rhs: ``[num_groups, k, n]`` per-group weight matrices. + When *transpose_rhs* is True the layout is ``[num_groups, n, k]`` + and each slice is transposed before multiplication. + group_sizes: ``[num_groups]`` int32 giving the number of rows per group. + tiling: ``(tm, tk, tn)`` tile sizes or *None* for auto. + transpose_rhs: If True, each ``rhs[i]`` is ``[n, k]`` and transposed. + preferred_element_type: Output dtype; defaults to ``lhs.dtype``. + interpret: If *None* falls back to ``PALLAS_INTERPRET`` env var. + + Returns: + ``[m, n]`` output tensor. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + + if interpret is None: + interpret = get_interpret() + if preferred_element_type is None: + preferred_element_type = lhs.dtype + + m, k_lhs = lhs.shape + num_groups = rhs.shape[0] + + if transpose_rhs: + n, k_rhs = rhs.shape[1], rhs.shape[2] + else: + k_rhs, n = rhs.shape[1], rhs.shape[2] + assert k_lhs == k_rhs, f"lhs K ({k_lhs}) must match rhs K ({k_rhs})" + + k = k_lhs + tm, tk, tn = _validate_tiling(tiling, m, k, n) + + # Pad m to multiple of tm. + m_padded = ((m + tm - 1) // tm) * tm + if m_padded > m: + lhs = jnp.pad(lhs, ((0, m_padded - m), (0, 0))) + + # Pad k to multiple of tk. + k_padded = ((k + tk - 1) // tk) * tk + if k_padded > k: + lhs = jnp.pad(lhs, ((0, 0), (0, k_padded - k))) + if transpose_rhs: + rhs = jnp.pad(rhs, ((0, 0), (0, 0), (0, k_padded - k))) + else: + rhs = jnp.pad(rhs, ((0, 0), (0, k_padded - k), (0, 0))) + + # Pad n to multiple of tn. + n_padded = ((n + tn - 1) // tn) * tn + if n_padded > n: + if transpose_rhs: + rhs = jnp.pad(rhs, ((0, 0), (0, n_padded - n), (0, 0))) + else: + rhs = jnp.pad(rhs, ((0, 0), (0, 0), (0, n_padded - n))) + + tiles_k = k_padded // tk + tiles_n = n_padded // tn + + # Build group metadata. + group_metadata, num_active_tiles = make_group_metadata( + group_sizes=group_sizes, + m=m_padded, + tm=tm, + ) + group_offsets, group_ids, m_tile_ids = group_metadata + + grid = (tiles_n, int(num_active_tiles), tiles_k) + + # --- Index maps --- + # Note: BlockSpec multiplies returned block indices by the corresponding + # block_shape dimension. Axes with ``None`` in block_shape are element + # indices (not multiplied). 
+ def lhs_index_map(n_i, grid_id, k_i, group_metadata_ref): + del n_i + m_tile_ids_ref = group_metadata_ref[2] + return (m_tile_ids_ref[grid_id], k_i) + + if transpose_rhs: + rhs_block_shape = (None, tn, tk) + + def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): + group_ids_ref = group_metadata_ref[1] + return (group_ids_ref[grid_id], n_i, k_i) + else: + rhs_block_shape = (None, tk, tn) + + def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): + group_ids_ref = group_metadata_ref[1] + return (group_ids_ref[grid_id], k_i, n_i) + + def out_index_map(n_i, grid_id, k_i, group_metadata_ref): + del k_i + m_tile_ids_ref = group_metadata_ref[2] + return (m_tile_ids_ref[grid_id], n_i) + + out_shape = jax.ShapeDtypeStruct((m_padded, n_padded), preferred_element_type) + + kernel_fn = functools.partial( + _gmm_kernel, + tm=tm, + tk=tk, + tn=tn, + tiles_k=tiles_k, + preferred_element_type=preferred_element_type, + transpose_rhs=transpose_rhs, + ) + + result = pl.pallas_call( + kernel_fn, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=grid, + in_specs=[ + pl.BlockSpec((tm, tk), lhs_index_map), + pl.BlockSpec(rhs_block_shape, rhs_index_map), + ], + out_specs=pl.BlockSpec((tm, tn), out_index_map), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + ), + out_shape=out_shape, + interpret=interpret, + )( + (group_offsets, group_ids, m_tile_ids), + lhs, + rhs, + ) + + # Un-pad output. 
+ return result[:m, :n] + + +# =================================================================== +# TGMM Forward (Transposed Grouped Matrix Multiplication) +# =================================================================== + + +def _tgmm_kernel( + # Scalar prefetch ref + group_metadata_ref, + # Input refs + lhs_ref, + rhs_ref, + # Output ref + out_ref, + # Scratch ref + acc_ref, + *, + tm: int, + tk: int, + tn: int, + num_active_tiles: int, +): + """TGMM kernel body. + + Accumulates ``lhs^T @ rhs`` per group into the output. + + Args: + group_metadata_ref: Scalar-prefetch ref containing + ``(group_offsets, group_ids, m_tile_ids)``. + lhs_ref: VMEM ref ``[tm, tk]``. + rhs_ref: VMEM ref ``[tm, tn]``. + out_ref: VMEM ref ``[tk, tn]``. + acc_ref: VMEM scratch ``[tk, tn]`` float32. + tm: Row tile size. + tk: K-dimension tile size. + tn: Column tile size. + num_active_tiles: Total number of active tiles along the m dimension. + """ + n_i, k_i, grid_id = pl.program_id(0), pl.program_id(1), pl.program_id(2) + + group_offsets_ref = group_metadata_ref[0] + group_ids_ref = group_metadata_ref[1] + m_tile_ids_ref = group_metadata_ref[2] + + group_id = group_ids_ref[grid_id] + + # Determine if this is the first tile for the current group. + prev_group_id = jnp.where(grid_id > 0, group_ids_ref[grid_id - 1], -1) + is_new_group = (group_id != prev_group_id) | (grid_id == 0) + + # Zero accumulator when entering a new group. + @pl.when(is_new_group) + def _zero(): + acc_ref[...] = jnp.zeros((tk, tn), dtype=jnp.float32) + + # Compute row mask for group boundaries. + m_tile_id = m_tile_ids_ref[grid_id] + group_start = group_offsets_ref[group_id] + group_end = group_offsets_ref[group_id + 1] + row_start = m_tile_id * tm + row_indices = row_start + jnp.arange(tm, dtype=jnp.int32) + valid_rows = (row_indices >= group_start) & (row_indices < group_end) + + # Mask lhs and rhs rows outside group boundaries. 
+ lhs_block = lhs_ref[...].astype(jnp.float32) + rhs_block = rhs_ref[...].astype(jnp.float32) + + row_mask_lhs = jnp.broadcast_to(valid_rows[:, None], (tm, tk)) + row_mask_rhs = jnp.broadcast_to(valid_rows[:, None], (tm, tn)) + + lhs_block = jnp.where(row_mask_lhs, lhs_block, 0.0) + rhs_block = jnp.where(row_mask_rhs, rhs_block, 0.0) + + # Accumulate: lhs^T @ rhs -> [tk, tn] + acc_ref[...] += lax.dot(lhs_block.swapaxes(0, 1), rhs_block) + + # Store when group is about to change or this is the last active tile. + next_group_id = jnp.where( + grid_id < num_active_tiles - 1, + group_ids_ref[grid_id + 1], + -1, + ) + is_last_for_group = (group_id != next_group_id) | (grid_id == num_active_tiles - 1) + + @pl.when(is_last_for_group) + def _store(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +def tgmm_forward( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + tiling: tuple[int, int, int] | None = None, + preferred_element_type: Any = None, + interpret: bool | None = None, +) -> jax.Array: + """Transposed grouped matrix multiplication forward pass. + + Computes per-group weight gradients: for each group *i* with rows + ``[start_i, end_i)``: + + ``out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]`` + + Args: + lhs: ``[m, k]`` input activations in bfloat16 / float32. + rhs: ``[m, n]`` gradient tensor (or second operand). + group_sizes: ``[num_groups]`` int32, number of rows per group. + tiling: ``(tm, tk, tn)`` tile sizes or *None* for auto. + preferred_element_type: Output dtype; defaults to ``lhs.dtype``. + interpret: If *None* falls back to ``PALLAS_INTERPRET`` env var. + + Returns: + ``[num_groups, k, n]`` per-group outer products. 
+ """ + assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" + assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D" + assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" + assert lhs.shape[0] == rhs.shape[0], ( + f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}" + ) + + if interpret is None: + interpret = get_interpret() + if preferred_element_type is None: + preferred_element_type = lhs.dtype + + m, k = lhs.shape + n = rhs.shape[1] + num_groups = group_sizes.shape[0] + tm, tk, tn = _validate_tiling(tiling, m, k, n) + + # Pad m to multiple of tm. + m_padded = ((m + tm - 1) // tm) * tm + if m_padded > m: + lhs = jnp.pad(lhs, ((0, m_padded - m), (0, 0))) + rhs = jnp.pad(rhs, ((0, m_padded - m), (0, 0))) + + # Pad k to multiple of tk. + k_padded = ((k + tk - 1) // tk) * tk + if k_padded > k: + lhs = jnp.pad(lhs, ((0, 0), (0, k_padded - k))) + + # Pad n to multiple of tn. + n_padded = ((n + tn - 1) // tn) * tn + if n_padded > n: + rhs = jnp.pad(rhs, ((0, 0), (0, n_padded - n))) + + tiles_k = k_padded // tk + tiles_n = n_padded // tn + + # Build group metadata with visit_empty_groups=True. + group_metadata, num_active_tiles_val = make_group_metadata( + group_sizes=group_sizes, + m=m_padded, + tm=tm, + visit_empty_groups=True, + ) + group_offsets, group_ids, m_tile_ids = group_metadata + + num_active_tiles_int = int(num_active_tiles_val) + grid = (tiles_n, tiles_k, num_active_tiles_int) + + # --- Index maps --- + # Note: BlockSpec multiplies returned block indices by the corresponding + # block_shape dimension. Axes with ``None`` in block_shape are element + # indices (not multiplied). 
+ def lhs_index_map(n_i, k_i, grid_id, group_metadata_ref): + del n_i + m_tile_ids_ref = group_metadata_ref[2] + return (m_tile_ids_ref[grid_id], k_i) + + def rhs_index_map(n_i, k_i, grid_id, group_metadata_ref): + del k_i + m_tile_ids_ref = group_metadata_ref[2] + return (m_tile_ids_ref[grid_id], n_i) + + def out_index_map(n_i, k_i, grid_id, group_metadata_ref): + group_ids_ref = group_metadata_ref[1] + return (group_ids_ref[grid_id], k_i, n_i) + + out_shape = jax.ShapeDtypeStruct( + (num_groups, k_padded, n_padded), preferred_element_type + ) + + kernel_fn = functools.partial( + _tgmm_kernel, + tm=tm, + tk=tk, + tn=tn, + num_active_tiles=num_active_tiles_int, + ) + + result = pl.pallas_call( + kernel_fn, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=grid, + in_specs=[ + pl.BlockSpec((tm, tk), lhs_index_map), + pl.BlockSpec((tm, tn), rhs_index_map), + ], + out_specs=pl.BlockSpec((None, tk, tn), out_index_map), + scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + ), + out_shape=out_shape, + interpret=interpret, + )( + (group_offsets, group_ids, m_tile_ids), + lhs, + rhs, + ) + + # Un-pad output. 
+ return result[:, :k, :n] From 5c2587e8f3627b92d8cf3af62636d4ea55c4d615 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:23:30 +0800 Subject: [PATCH 06/15] fix(gmm): add @jax.jit decorators, fix traced grid dimensions - Add @jax.jit with static_argnames to gmm_forward and tgmm_forward - Remove int() casts on num_active_tiles (traced values under JIT) - Use pl.num_programs(2) in tgmm kernel instead of captured constant Co-Authored-By: Claude Opus 4.6 --- tops/ops/gmm/gmm.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py index 4fc2d77c..c2ff7261 100644 --- a/tops/ops/gmm/gmm.py +++ b/tops/ops/gmm/gmm.py @@ -180,6 +180,10 @@ def _store(): out_ref[...] = result.astype(preferred_element_type) +@functools.partial( + jax.jit, + static_argnames=["tiling", "transpose_rhs", "preferred_element_type", "interpret"], +) def gmm_forward( lhs: jax.Array, rhs: jax.Array, @@ -268,7 +272,7 @@ def gmm_forward( ) group_offsets, group_ids, m_tile_ids = group_metadata - grid = (tiles_n, int(num_active_tiles), tiles_k) + grid = (tiles_n, num_active_tiles, tiles_k) # --- Index maps --- # Note: BlockSpec multiplies returned block indices by the corresponding @@ -355,7 +359,6 @@ def _tgmm_kernel( tm: int, tk: int, tn: int, - num_active_tiles: int, ): """TGMM kernel body. @@ -371,7 +374,6 @@ def _tgmm_kernel( tm: Row tile size. tk: K-dimension tile size. tn: Column tile size. - num_active_tiles: Total number of active tiles along the m dimension. """ n_i, k_i, grid_id = pl.program_id(0), pl.program_id(1), pl.program_id(2) @@ -412,18 +414,23 @@ def _zero(): acc_ref[...] += lax.dot(lhs_block.swapaxes(0, 1), rhs_block) # Store when group is about to change or this is the last active tile. 
+ num_active = pl.num_programs(2) next_group_id = jnp.where( - grid_id < num_active_tiles - 1, + grid_id < num_active - 1, group_ids_ref[grid_id + 1], -1, ) - is_last_for_group = (group_id != next_group_id) | (grid_id == num_active_tiles - 1) + is_last_for_group = (group_id != next_group_id) | (grid_id == num_active - 1) @pl.when(is_last_for_group) def _store(): out_ref[...] = acc_ref[...].astype(out_ref.dtype) +@functools.partial( + jax.jit, + static_argnames=["tiling", "preferred_element_type", "interpret"], +) def tgmm_forward( lhs: jax.Array, rhs: jax.Array, @@ -495,8 +502,7 @@ def tgmm_forward( ) group_offsets, group_ids, m_tile_ids = group_metadata - num_active_tiles_int = int(num_active_tiles_val) - grid = (tiles_n, tiles_k, num_active_tiles_int) + grid = (tiles_n, tiles_k, num_active_tiles_val) # --- Index maps --- # Note: BlockSpec multiplies returned block indices by the corresponding @@ -525,7 +531,6 @@ def out_index_map(n_i, k_i, grid_id, group_metadata_ref): tm=tm, tk=tk, tn=tn, - num_active_tiles=num_active_tiles_int, ) result = pl.pallas_call( From 9e2ffa58b3d2fc9463cb40cf82a61f662500ed77 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:25:54 +0800 Subject: [PATCH 07/15] feat(gmm): add differentiable gmm with custom_vjp and public API Wire up custom_vjp for the gmm function so it supports jax.grad: - dlhs computed via gmm_forward with transposed rhs - drhs computed via tgmm_forward Export gmm, gmm_forward, tgmm_forward from tops.ops.gmm and add gmm to tops.ops public API. Add 3 gradient test cases comparing Pallas backward against a differentiable JAX reference. 
Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_gmm_tpu.py | 58 +++++++++++++++++++++++++ tops/ops/__init__.py | 2 + tops/ops/gmm/__init__.py | 5 +++ tops/ops/gmm/gmm.py | 79 +++++++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+) diff --git a/tests/ops/gmm/test_gmm_tpu.py b/tests/ops/gmm/test_gmm_tpu.py index 2de3d4a6..f6bf718f 100644 --- a/tests/ops/gmm/test_gmm_tpu.py +++ b/tests/ops/gmm/test_gmm_tpu.py @@ -14,6 +14,7 @@ import numpy as np from tops.ops.gmm.gmm import gmm_forward, tgmm_forward +from tops.ops.gmm import gmm from tops.cpu.ops.gmm import gmm_ref, tgmm_ref @@ -143,5 +144,62 @@ def test_tgmm_vs_cpu(cfg): assert compare_tensor("tgmm", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) +# ============================================================================ +# Gradient Tests +# ============================================================================ + + +def _gmm_ref_differentiable(lhs, rhs, group_sizes): + """Differentiable reference for gradient comparison (not @cpu_reference).""" + m = lhs.shape[0] + num_groups = rhs.shape[0] + n = rhs.shape[2] + out = jnp.zeros((m, n), dtype=jnp.float32) + start = 0 + for i in range(num_groups): + size = int(group_sizes[i]) + end = start + size + if size > 0: + lhs_slice = lhs[start:end].astype(jnp.float32) + rhs_mat = rhs[i].astype(jnp.float32) + out = out.at[start:end].set(lhs_slice @ rhs_mat) + start = end + return out.sum() + + +GRAD_CASES = [ + dict(m=128, k=128, n=128, ng=1, dist="single", seed=300), + dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=301), + dict(m=384, k=128, n=128, ng=3, dist="skewed", seed=302), +] + + +def _grad_case_id(c): + return f"grad_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" + + +@pytest.mark.parametrize("cfg", GRAD_CASES, ids=[_grad_case_id(c) for c in GRAD_CASES]) +def test_gmm_gradient(cfg): + """custom_vjp gradients should match reference gradients.""" + m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] + gs = 
_make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) + lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"], dtype=jnp.bfloat16) + + # Pallas gmm gradients + def pallas_loss(lhs, rhs): + return gmm(lhs, rhs, gs).sum() + + dlhs_pl, drhs_pl = jax.grad(pallas_loss, argnums=(0, 1))(lhs, rhs) + + # Reference gradients + def ref_loss(lhs, rhs): + return _gmm_ref_differentiable(lhs, rhs, gs) + + dlhs_ref, drhs_ref = jax.grad(ref_loss, argnums=(0, 1))(lhs, rhs) + + assert compare_tensor("dlhs", dlhs_ref, dlhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) + assert compare_tensor("drhs", drhs_ref, drhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tops/ops/__init__.py b/tops/ops/__init__.py index fc529e87..12b6a4c8 100644 --- a/tops/ops/__init__.py +++ b/tops/ops/__init__.py @@ -5,8 +5,10 @@ detail with **no API stability guarantee**. """ +from .gmm import gmm from .simple_gla import simple_gla __all__ = [ + "gmm", "simple_gla", ] diff --git a/tops/ops/gmm/__init__.py b/tops/ops/gmm/__init__.py index e69de29b..a5d357f4 100644 --- a/tops/ops/gmm/__init__.py +++ b/tops/ops/gmm/__init__.py @@ -0,0 +1,5 @@ +"""Public API for grouped matrix multiplication.""" + +from .gmm import gmm, gmm_forward, tgmm_forward + +__all__ = ["gmm", "gmm_forward", "tgmm_forward"] diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py index c2ff7261..2539bd5d 100644 --- a/tops/ops/gmm/gmm.py +++ b/tops/ops/gmm/gmm.py @@ -558,3 +558,82 @@ def out_index_map(n_i, k_i, grid_id, group_metadata_ref): # Un-pad output. 
return result[:, :k, :n] + + +# ============================================================================ +# Differentiable GMM with custom_vjp +# ============================================================================ + + +@functools.partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) +def gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + tiling: tuple[int, int, int] = (128, 128, 128), + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + """Differentiable grouped matrix multiplication. + + For each group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Supports automatic differentiation via custom_vjp: + - dlhs: computed via gmm with transposed rhs + - drhs: computed via tgmm + + Args: + lhs: [m, k] bf16 input activations, rows sorted by group. + rhs: [num_groups, k, n] bf16 per-group weights. + group_sizes: [num_groups] int32, row count per group. + tiling: (tm, tk, tn) tile sizes. + transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). + preferred_element_type: Output dtype. + + Returns: + [m, output_n] tensor. 
+ """ + return _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type)[ + 0 + ] + + +def _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type): + out = gmm_forward( + lhs, + rhs, + group_sizes, + tiling=tiling, + transpose_rhs=transpose_rhs, + preferred_element_type=preferred_element_type, + ) + return out, (lhs, rhs, group_sizes) + + +def _gmm_bwd(tiling, transpose_rhs, preferred_element_type, residuals, grad): + lhs, rhs, group_sizes = residuals + + # dlhs = grad @ rhs^T per group + dlhs = gmm_forward( + grad, + rhs, + group_sizes, + tiling=tiling, + transpose_rhs=not transpose_rhs, + preferred_element_type=preferred_element_type, + ).astype(lhs.dtype) + + # drhs = lhs^T @ grad per group + drhs = tgmm_forward( + lhs, + grad, + group_sizes, + tiling=tiling, + preferred_element_type=preferred_element_type, + ).astype(rhs.dtype) + + return dlhs, drhs, None + + +gmm.defvjp(_gmm_fwd, _gmm_bwd) From 22b84a9a0947ab7facf0c7195640b721d428c071 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:52:33 +0800 Subject: [PATCH 08/15] chore: remove planning docs from PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Design and implementation plan documents are not needed in the final codebase — the code and tests serve as the authoritative reference. 
Co-Authored-By: Claude Opus 4.6 --- docs/plans/2026-04-06-gmm-kernel-design.md | 144 --- docs/plans/2026-04-06-gmm-kernel-impl.md | 1296 -------------------- 2 files changed, 1440 deletions(-) delete mode 100644 docs/plans/2026-04-06-gmm-kernel-design.md delete mode 100644 docs/plans/2026-04-06-gmm-kernel-impl.md diff --git a/docs/plans/2026-04-06-gmm-kernel-design.md b/docs/plans/2026-04-06-gmm-kernel-design.md deleted file mode 100644 index 73377f5c..00000000 --- a/docs/plans/2026-04-06-gmm-kernel-design.md +++ /dev/null @@ -1,144 +0,0 @@ -# GMM (Grouped Matrix Multiplication) Pallas Kernel Design - -**Date:** 2026-04-06 -**Status:** Approved -**Phase:** 1 (BF16 only, no quantization) - -## Goal - -Implement a Pallas TPU kernel for Grouped Matrix Multiplication (GMM) to support MoE (Mixture-of-Experts) layers. This replaces the `tokamax` and `qwix` backends used in maxtext's megablox with a clean, self-contained implementation following tops conventions. - -## Semantics - -**GMM forward:** For each expert group `i` with rows `[start_i, end_i)`: -``` -out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i, :, :] -``` - -**TGMM (transposed GMM, for weight gradients):** For each group `i`: -``` -out[i, :, :] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] -``` - -Tokens in `lhs` are pre-sorted by expert assignment. `group_sizes[i]` gives the number of rows belonging to expert `i`. 
- -## Public API - -```python -def gmm( - lhs: jnp.ndarray, # [m, k] bf16 - stacked token activations - rhs: jnp.ndarray, # [num_groups, k, n] bf16 - per-expert weights - group_sizes: jnp.ndarray, # [num_groups] int32 - token count per expert - tiling: tuple[int, int, int] = (128, 128, 128), # (tm, tk, tn) - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, -) -> jnp.ndarray: # [m, n] bf16 -``` - -Fully differentiable via `jax.custom_vjp`: -- **dlhs** = `gmm(grad, rhs, group_sizes, transpose_rhs=True)` -- **drhs** = `tgmm(lhs, grad, group_sizes)` - -### Internal: `tgmm` - -```python -def tgmm( - lhs: jnp.ndarray, # [m, k] bf16 - rhs: jnp.ndarray, # [m, n] bf16 - group_sizes: jnp.ndarray, # [num_groups] int32 - tiling: tuple[int, int, int] = (128, 128, 128), - preferred_element_type: jnp.dtype = jnp.float32, -) -> jnp.ndarray: # [num_groups, k, n] bf16 -``` - -## Kernel Architecture - -### Grid Layout (gmm) - -``` -grid = (tiles_n, num_active_tiles, tiles_k) -dimension_semantics = ("parallel", "arbitrary", "arbitrary") -``` - -- `tiles_n = n // tn` -- parallelized over output columns -- `num_active_tiles` = total m-tiles across all groups (computed by `make_group_metadata`) -- `tiles_k = k // tk` -- sequential reduction dimension - -### Group Metadata (computed on host, passed via scalar prefetch) - -`make_group_metadata(group_sizes, m, tm)` produces: -- `group_offsets`: CSR-style cumulative row offsets, rounded to tm boundaries -- `group_ids`: maps each active m-tile index to its group -- `m_tile_ids`: maps each active m-tile index to its row-tile offset within the group - -### BlockSpecs - -| Tensor | Block shape | Index map | -|--------|-------------|-----------| -| `lhs` | `[tm, tk]` | `(m_tile_ids[grid_m], grid_k)` | -| `rhs` | `[1, tk, tn]` | `(group_ids[grid_m], grid_k, grid_n)` | -| `out` | `[tm, tn]` | `(m_tile_ids[grid_m], grid_n)` | - -When `transpose_rhs=True`, rhs block shape is `[1, tn, tk]`. - -### Kernel Body (gmm) - -1. 
Load `lhs_block [tm, tk]` and `rhs_block [tk, tn]` via BlockSpec -2. Accumulate `dot(lhs_block, rhs_block, preferred_element_type=float32)` into VMEM scratch `[tm, tn]` -3. On last k-tile: apply group-boundary mask (zero rows outside the group), cast to output dtype, store - -### Grid Layout (tgmm) - -``` -grid = (tiles_n, tiles_k, num_active_tiles) -dimension_semantics = ("parallel", "arbitrary", "arbitrary") -``` - -### Kernel Body (tgmm) - -For each active m-tile, accumulates `lhs_block^T [tk, tm] @ rhs_block [tm, tn]` into the output for the corresponding group. When the group changes between adjacent m-tiles, the accumulated result is stored and the accumulator reset. - -## Precision - -- Input dtype: bf16 -- Accumulation: float32 (via VMEM scratch) -- Output: cast back to bf16 -- `jax.lax.Precision.HIGHEST` on all dot products - -## File Layout - -``` -tops/ops/gmm/ - __init__.py # Public API: gmm() - gmm.py # Pallas kernels + custom_vjp + tgmm - metadata.py # make_group_metadata() - -tops/cpu/ops/gmm/ - __init__.py - naive.py # Pure JAX reference implementation - -tests/ops/gmm/ - test_gmm_tpu.py # Pallas vs CPU reference tests - conftest.py # GMM test fixtures -``` - -## Testing Strategy - -1. **CPU reference:** Pure JAX loop over groups with plain matmul -2. **Forward test:** Compare Pallas gmm output vs reference across configs -3. **Gradient test:** Compare custom_vjp gradients vs `jax.grad` of reference -4. **Configs:** Vary (m, k, n, num_groups) with distributions: - - Uniform group sizes - - Skewed (one large group, many small) - - Single group (degenerates to plain matmul) - - Empty groups (group_size=0) - - Sizes not divisible by tm -5. 
**Tolerances:** atol ~1e-2, rtol ~1e-2 (bf16 accumulation) - -## Future Work (Phase 2+) - -- Block-wise quantization: (128,128) for weights, (1,128) for activations -- `group_offset` for expert parallelism / sharded groups -- `existing_out` for accumulation into pre-existing buffers -- Double/triple buffering (`input_buffer_count`) -- Async DMA pipelining diff --git a/docs/plans/2026-04-06-gmm-kernel-impl.md b/docs/plans/2026-04-06-gmm-kernel-impl.md deleted file mode 100644 index 102b2f1c..00000000 --- a/docs/plans/2026-04-06-gmm-kernel-impl.md +++ /dev/null @@ -1,1296 +0,0 @@ -# GMM (Grouped Matrix Multiplication) Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Implement a BF16 Pallas TPU kernel for Grouped Matrix Multiplication (GMM/TGMM) with custom_vjp, replacing tokamax/qwix backends used in maxtext megablox. - -**Architecture:** Port megablox's proven grid/metadata strategy (CSR-like tile scheduling) to a clean tops-style implementation. Forward kernel uses grid `(tiles_n, num_active_tiles, tiles_k)` with float32 VMEM accumulation. Backward uses `gmm(transpose_rhs=True)` for dlhs and a separate `tgmm` kernel for drhs. No quantization, no sharding (Phase 1). 
- -**Tech Stack:** JAX, Pallas (TPU mosaic backend), `jax.experimental.pallas.tpu`, `jax.custom_vjp` - ---- - -### Task 1: CPU Reference Implementation - -**Files:** -- Create: `tops/cpu/ops/gmm/__init__.py` -- Create: `tops/cpu/ops/gmm/naive.py` -- Create: `tests/ops/gmm/__init__.py` -- Create: `tests/ops/gmm/test_cpu_ref.py` - -**Step 1: Create the CPU reference module** - -Create `tops/cpu/ops/gmm/__init__.py`: - -```python -from .naive import gmm_ref, tgmm_ref - -__all__ = ["gmm_ref", "tgmm_ref"] -``` - -Create `tops/cpu/ops/gmm/naive.py`: - -```python -"""Pure JAX CPU reference for Grouped Matrix Multiplication.""" - -import jax -import jax.numpy as jnp - -from tops.cpu.ops import cpu_reference - - -@cpu_reference -def gmm_ref( - lhs: jax.Array, - rhs: jax.Array, - group_sizes: jax.Array, - transpose_rhs: bool = False, -) -> jax.Array: - """Grouped matrix multiplication reference implementation. - - For each group i with rows [start_i, end_i): - out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] - - Args: - lhs: [m, k] input activations. - rhs: [num_groups, k, n] per-group weights. - If transpose_rhs=True, rhs is [num_groups, k, n] but used as - [num_groups, n, k] (transposed before matmul). - group_sizes: [num_groups] int32, number of rows per group. - transpose_rhs: If True, transpose each rhs[i] before matmul. - - Returns: - [m, output_dim] where output_dim = rhs.shape[2] if not transpose_rhs - else rhs.shape[1]. 
- """ - assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" - assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D" - assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" - - m = lhs.shape[0] - num_groups = rhs.shape[0] - n = rhs.shape[1] if transpose_rhs else rhs.shape[2] - orig_dtype = lhs.dtype - - out = jnp.zeros((m, n), dtype=jnp.float32) - start = 0 - for i in range(num_groups): - size = int(group_sizes[i]) - end = start + size - if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_mat = rhs[i].astype(jnp.float32) - if transpose_rhs: - rhs_mat = rhs_mat.T - out = out.at[start:end].set(lhs_slice @ rhs_mat) - start = end - return out.astype(orig_dtype) - - -@cpu_reference -def tgmm_ref( - lhs: jax.Array, - rhs: jax.Array, - group_sizes: jax.Array, -) -> jax.Array: - """Transposed grouped matrix multiplication reference implementation. - - For each group i with rows [start_i, end_i): - out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] - - Args: - lhs: [m, k] input activations. - rhs: [m, n] gradient or second operand. - group_sizes: [num_groups] int32, number of rows per group. - - Returns: - [num_groups, k, n] per-group outer products. 
- """ - assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" - assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D" - assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" - assert lhs.shape[0] == rhs.shape[0], ( - f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}" - ) - - k = lhs.shape[1] - n = rhs.shape[1] - num_groups = group_sizes.shape[0] - orig_dtype = lhs.dtype - - out = jnp.zeros((num_groups, k, n), dtype=jnp.float32) - start = 0 - for i in range(num_groups): - size = int(group_sizes[i]) - end = start + size - if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_slice = rhs[start:end].astype(jnp.float32) - out = out.at[i].set(lhs_slice.T @ rhs_slice) - start = end - return out.astype(orig_dtype) -``` - -**Step 2: Write tests for CPU reference** - -Create empty `tests/ops/gmm/__init__.py`. - -Create `tests/ops/gmm/test_cpu_ref.py`: - -```python -"""Verify CPU reference implementations for GMM/TGMM are correct.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parents[3])) - -import pytest -import jax -import jax.numpy as jnp -import numpy as np - -from tops.cpu.ops.gmm import gmm_ref, tgmm_ref - - -def _make_gmm_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): - """Generate random inputs for GMM tests.""" - key = jax.random.PRNGKey(seed) - k1, k2 = jax.random.split(key) - lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) - rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) - gs = jnp.array(group_sizes, dtype=jnp.int32) - return lhs, rhs, gs - - -class TestGmmRef: - """Test gmm_ref against manual numpy computation.""" - - def test_single_group(self): - """Single group = standard matmul.""" - lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) - rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float32) - gs = jnp.array([2], 
dtype=jnp.int32) - out = gmm_ref(lhs, rhs, gs) - expected = lhs # identity matmul - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) - - def test_two_groups(self): - """Two groups with different weights.""" - lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) - rhs = jnp.array([ - [[2.0, 0.0], [0.0, 2.0]], # group 0: scale by 2 - [[0.0, 1.0], [1.0, 0.0]], # group 1: swap columns - ], dtype=jnp.float32) - gs = jnp.array([1, 2], dtype=jnp.int32) - out = gmm_ref(lhs, rhs, gs) - expected = jnp.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) - - def test_empty_group(self): - """Empty group produces zeros for those rows (none exist).""" - lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) - rhs = jnp.array([ - [[1.0], [1.0]], # group 0: empty - [[1.0], [1.0]], # group 1: 1 row - ], dtype=jnp.float32) - gs = jnp.array([0, 1], dtype=jnp.int32) - out = gmm_ref(lhs, rhs, gs) - expected = jnp.array([[3.0]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) - - def test_transpose_rhs(self): - """transpose_rhs transposes each rhs[i] before matmul.""" - lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) - rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.float32) - gs = jnp.array([1], dtype=jnp.int32) - # Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16] - out_normal = gmm_ref(lhs, rhs, gs) - np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-5) - # With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17] - out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True) - np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-5) - - -class TestTgmmRef: - """Test tgmm_ref against manual numpy computation.""" - - def test_single_group(self): - """Single group: lhs^T @ rhs.""" - lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], 
dtype=jnp.float32) - rhs = jnp.array([[5.0], [6.0]], dtype=jnp.float32) - gs = jnp.array([2], dtype=jnp.int32) - out = tgmm_ref(lhs, rhs, gs) - # lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]] - expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) - - def test_two_groups(self): - """Two groups produce separate outer products.""" - lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.float32) - rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.float32) - gs = jnp.array([1, 2], dtype=jnp.int32) - out = tgmm_ref(lhs, rhs, gs) - # Group 0: [1]^T @ [4] = [[4]] - # Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]] - expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) -``` - -**Step 3: Run tests** - -Run: `uv run pytest tests/ops/gmm/test_cpu_ref.py -v` -Expected: All PASS - -**Step 4: Commit** - -```bash -git add tops/cpu/ops/gmm/ tests/ops/gmm/ -git commit -m "feat(gmm): add CPU reference implementations for gmm and tgmm" -``` - ---- - -### Task 2: Group Metadata Helper - -**Files:** -- Create: `tops/ops/gmm/__init__.py` (empty initially) -- Create: `tops/ops/gmm/metadata.py` -- Create: `tests/ops/gmm/test_metadata.py` - -**Step 1: Write metadata tests** - -Create `tests/ops/gmm/test_metadata.py`: - -```python -"""Test group metadata construction for GMM kernel scheduling.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parents[3])) - -import pytest -import jax.numpy as jnp -import numpy as np - -from tops.ops.gmm.metadata import make_group_metadata - - -class TestMakeGroupMetadata: - """Verify CSR-like metadata maps grid indices to correct groups/tiles.""" - - def test_uniform_groups_aligned(self): - """Groups perfectly aligned to tile boundaries.""" 
- # 2 groups, 128 rows each, tm=128 -> 1 tile per group, 2 active tiles - gs = jnp.array([128, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=256, tm=128 - ) - assert int(num_tiles) == 2 - np.testing.assert_array_equal(offsets, [0, 128, 256]) - np.testing.assert_array_equal(gids[:2], [0, 1]) - np.testing.assert_array_equal(mids[:2], [0, 1]) - - def test_uniform_groups_multi_tile(self): - """Groups spanning multiple tiles.""" - # 2 groups, 256 rows each, tm=128 -> 2 tiles per group, 4 active tiles - gs = jnp.array([256, 256], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=512, tm=128 - ) - assert int(num_tiles) == 4 - np.testing.assert_array_equal(gids[:4], [0, 0, 1, 1]) - np.testing.assert_array_equal(mids[:4], [0, 1, 2, 3]) - - def test_shared_tile_at_boundary(self): - """Group boundary falls mid-tile -> tile visited twice.""" - # Group 0: 64 rows (not aligned to 128), Group 1: 64 rows - # Tile 0 (rows 0-127) is shared between both groups - gs = jnp.array([64, 64], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128 - ) - # Tile 0 visited twice: once for group 0, once for group 1 - assert int(num_tiles) == 2 - np.testing.assert_array_equal(gids[:2], [0, 1]) - np.testing.assert_array_equal(mids[:2], [0, 0]) - - def test_empty_group(self): - """Empty group (size=0) should not produce active tiles.""" - gs = jnp.array([0, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128 - ) - assert int(num_tiles) == 1 - assert int(gids[0]) == 1 - - def test_visit_empty_groups(self): - """With visit_empty_groups=True, empty groups get one tile each.""" - gs = jnp.array([0, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128, visit_empty_groups=True - ) - assert int(num_tiles) == 2 - 
np.testing.assert_array_equal(gids[:2], [0, 1]) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) -``` - -**Step 2: Run tests to verify they fail** - -Run: `uv run pytest tests/ops/gmm/test_metadata.py -v` -Expected: FAIL (ImportError - module doesn't exist yet) - -**Step 3: Implement metadata** - -Create empty `tops/ops/gmm/__init__.py`: - -```python -``` - -Create `tops/ops/gmm/metadata.py`: - -```python -"""Group metadata construction for GMM kernel scheduling. - -Builds CSR-like metadata arrays that map Pallas grid indices to (group_id, -m_tile_id) pairs. This enables the GMM kernel to process ragged groups of -varying sizes using a flat 1-D grid over m-tiles. -""" - -from typing import Any - -import jax.numpy as jnp - -GroupMetadata = tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] - - -def make_group_metadata( - *, - group_sizes: jnp.ndarray, - m: int, - tm: int, - visit_empty_groups: bool = False, -) -> tuple[GroupMetadata, jnp.ndarray]: - """Build scheduling metadata for grouped matmul. - - Maps each grid index in the ``num_active_tiles`` dimension to a - ``(group_id, m_tile_id)`` pair so the Pallas kernel knows which group - and which row-tile to process. - - Args: - group_sizes: [num_groups] int32 -- number of rows per group. - m: Total number of rows in lhs (may exceed sum(group_sizes) due to - padding). - tm: Row-dimension tile size. - visit_empty_groups: If True, allocate one tile per empty group (needed - by tgmm to zero the output for empty groups). - - Returns: - (group_offsets, group_ids, m_tile_ids): Metadata arrays. - - group_offsets: [num_groups + 1] int32, CSR-style row offsets. - - group_ids: [tiles_m + num_groups - 1] int32, group for each - active tile. - - m_tile_ids: [tiles_m + num_groups - 1] int32, row-tile index - for each active tile. - num_active_tiles: Scalar int32, how many entries in group_ids / - m_tile_ids are valid. 
- """ - num_groups = group_sizes.shape[0] - - # --- CSR-style offsets --- - group_ends = jnp.cumsum(group_sizes) - group_offsets = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), group_ends] - ) - - # --- Round boundaries to tile multiples --- - rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) - group_starts = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]] - ) - rounded_group_starts = (group_starts // tm * tm).astype(jnp.int32) - - # --- Tiles per group --- - rounded_group_sizes = rounded_group_ends - rounded_group_starts - rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes) - group_tiles = rounded_group_sizes // tm - - if visit_empty_groups: - group_tiles = jnp.where(group_sizes == 0, 1, group_tiles) - - tiles_m = (m + tm - 1) // tm - total_len = tiles_m + num_groups - 1 - - # --- group_ids: map grid index -> group --- - group_ids = jnp.repeat( - jnp.arange(num_groups, dtype=jnp.int32), - group_tiles, - total_repeat_length=total_len, - ) - - # --- m_tile_ids: map grid index -> row-tile --- - # Tiles at group boundaries may be visited twice; count visits per tile. 
- partial_tile_mask = jnp.logical_or( - (group_offsets[:-1] % tm) == 0, - group_sizes == 0, - ) - if visit_empty_groups: - partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask) - - partial_tile_ids = jnp.where( - partial_tile_mask, tiles_m, group_offsets[:-1] // tm - ) - tile_visits = ( - jnp.histogram( - partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1) - )[0] - + 1 - ) - m_tile_ids = jnp.repeat( - jnp.arange(tiles_m, dtype=jnp.int32), - tile_visits.astype(jnp.int32), - total_repeat_length=total_len, - ) - - num_active_tiles = group_tiles.sum() - return (group_offsets, group_ids, m_tile_ids), num_active_tiles -``` - -**Step 4: Run tests** - -Run: `uv run pytest tests/ops/gmm/test_metadata.py -v` -Expected: All PASS - -**Step 5: Commit** - -```bash -git add tops/ops/gmm/ tests/ops/gmm/test_metadata.py -git commit -m "feat(gmm): add make_group_metadata for kernel scheduling" -``` - ---- - -### Task 3: GMM Forward Pallas Kernel + Tests - -**Files:** -- Create: `tops/ops/gmm/gmm.py` -- Create: `tests/ops/gmm/test_gmm_tpu.py` - -**Step 1: Write forward test** - -Create `tests/ops/gmm/test_gmm_tpu.py`: - -```python -"""GMM Pallas kernel accuracy vs CPU reference. - -Forward: gmm_forward (tops.ops.gmm) vs gmm_ref (tops.cpu.ops.gmm) -Backward (tgmm): tgmm_forward vs tgmm_ref -Gradient: custom_vjp gradients vs jax.grad of CPU reference. 
-""" - -from __future__ import annotations - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parents[3])) - -import pytest -import jax -import jax.numpy as jnp - -from tops.ops.gmm.gmm import gmm_forward, tgmm_forward -from tops.cpu.ops.gmm import gmm_ref, tgmm_ref -from tests.utils import compare_tensor - - -# ============================================================================ -# Test Helpers -# ============================================================================ - -def _make_group_sizes(num_groups, total_m, distribution="uniform", seed=0): - """Generate group_sizes that sum to total_m.""" - key = jax.random.PRNGKey(seed) - if distribution == "uniform": - base = total_m // num_groups - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - remainder = total_m - base * num_groups - sizes = sizes.at[:remainder].add(1) - elif distribution == "skewed": - # First group gets half, rest split the remainder - first = total_m // 2 - rest = total_m - first - base = rest // (num_groups - 1) if num_groups > 1 else 0 - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - sizes = sizes.at[0].set(first) - remainder = total_m - first - base * (num_groups - 1) - if num_groups > 1: - sizes = sizes.at[1].add(remainder) - elif distribution == "single": - sizes = jnp.zeros(num_groups, dtype=jnp.int32) - sizes = sizes.at[0].set(total_m) - elif distribution == "with_empty": - # First group empty, rest uniform - base = total_m // (num_groups - 1) if num_groups > 1 else total_m - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - sizes = sizes.at[0].set(0) - remainder = total_m - base * (num_groups - 1) - sizes = sizes.at[1].add(remainder) - else: - raise ValueError(f"Unknown distribution: {distribution}") - return sizes - - -def _make_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): - """Generate random lhs, rhs for GMM.""" - key = jax.random.PRNGKey(seed) - k1, k2 = jax.random.split(key) - lhs = 
jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) - rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) - return lhs, rhs - - -# ============================================================================ -# Forward Test Cases -# ============================================================================ - -FWD_CASES = [ - # (m, k, n, num_groups, distribution, seed) - dict(m=128, k=128, n=128, ng=1, dist="single", seed=100), - dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=101), - dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=102), - dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=103), - dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=104), - dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=105), - dict(m=1024, k=128, n=128, ng=8, dist="uniform", seed=106), - dict(m=640, k=128, n=128, ng=5, dist="skewed", seed=107), -] - - -def _fwd_case_id(c): - return f"m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", FWD_CASES, ids=[_fwd_case_id(c) for c in FWD_CASES]) -def test_gmm_fwd_vs_cpu(cfg): - """gmm_forward (Pallas) should match gmm_ref (CPU).""" - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"]) - - out_ref = gmm_ref(lhs, rhs, gs) - out_pl = gmm_forward(lhs, rhs, gs) - - assert compare_tensor("gmm_fwd", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) - - -@pytest.mark.parametrize("cfg", FWD_CASES[:4], ids=[_fwd_case_id(c) for c in FWD_CASES[:4]]) -def test_gmm_fwd_transpose_rhs(cfg): - """gmm_forward with transpose_rhs should match gmm_ref.""" - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - lhs, rhs = _make_inputs(m, n, k, ng, gs, seed=cfg["seed"]) - - out_ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True) - out_pl = gmm_forward(lhs, rhs, gs, 
transpose_rhs=True) - - assert compare_tensor("gmm_fwd_T", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) -``` - -**Step 2: Run test to verify it fails** - -Run: `uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_fwd_vs_cpu -v --no-header -x` -Expected: FAIL (ImportError - `gmm_forward` doesn't exist) - -**Step 3: Implement GMM forward kernel** - -Create `tops/ops/gmm/gmm.py`: - -```python -"""Pallas TPU kernels for Grouped Matrix Multiplication. - -Implements gmm (forward) and tgmm (transposed, for weight gradients) -using the megablox-style grid scheduling with CSR-like group metadata. -""" - -import functools - -import jax -from jax import lax -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu -import jax.numpy as jnp - -from tops.ops.gmm.metadata import make_group_metadata, GroupMetadata -from tops.ops.utils import get_interpret - - -# ============================================================================ -# Helpers -# ============================================================================ - -def _get_store_mask( - *, - grid_id: jnp.ndarray, - group_metadata: GroupMetadata, - tm: int, - tn: int, -) -> jnp.ndarray: - """Boolean mask [tm, tn] for rows belonging to the current group.""" - group_offsets, group_ids, m_tile_ids = group_metadata - group_id = group_ids[grid_id] - group_start = group_offsets[group_id] - group_end = group_offsets[group_id + 1] - m_id = m_tile_ids[grid_id] * tm - iota = lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id - return jnp.logical_and(iota >= group_start, iota < group_end) - - -def _get_group_size( - *, grid_id: jnp.ndarray, group_metadata: GroupMetadata -) -> jnp.ndarray: - """Number of rows in the current group.""" - group_offsets, group_ids = group_metadata[:2] - group_id = group_ids[grid_id] - return group_offsets[group_id + 1] - group_offsets[group_id] - - -# 
============================================================================ -# GMM Forward -# ============================================================================ - -@functools.partial( - jax.jit, - static_argnames=["tiling", "transpose_rhs", "preferred_element_type", "interpret"], -) -def gmm_forward( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - tiling: tuple[int, int, int] = (128, 128, 128), - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, - interpret: bool | None = None, -) -> jnp.ndarray: - """Grouped matrix multiplication: out[group_rows] = lhs[group_rows] @ rhs[group]. - - For each group i with rows [start_i, end_i): - out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] - - When transpose_rhs=True, rhs[i] is transposed before the matmul. - - Args: - lhs: [m, k] bf16 input activations, rows sorted by group. - rhs: [num_groups, k, n] bf16 per-group weights. - group_sizes: [num_groups] int32, row count per group. sum <= m. - tiling: (tm, tk, tn) tile sizes. k must be divisible by tk, - n must be divisible by tn. - transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). - preferred_element_type: Output dtype for accumulation. - interpret: Pallas interpret mode. None = auto-detect from env. - - Returns: - [m, output_n] tensor where output_n = rhs.shape[2] if not - transpose_rhs, else rhs.shape[1]. 
- """ - if interpret is None: - interpret = get_interpret() - - # --- Validate --- - assert lhs.ndim == 2, f"lhs must be 2D [m, k], got {lhs.ndim}D" - assert rhs.ndim == 3, f"rhs must be 3D [E, k, n], got {rhs.ndim}D" - assert group_sizes.dtype == jnp.int32, ( - f"group_sizes must be int32, got {group_sizes.dtype}" - ) - - # --- Shape info --- - m, k = lhs.shape - if transpose_rhs: - n = rhs.shape[1] - else: - n = rhs.shape[2] - - tm, tk, tn = tiling - assert k % tk == 0, f"k ({k}) must be divisible by tk ({tk})" - assert n % tn == 0, f"n ({n}) must be divisible by tn ({tn})" - - tiles_k = k // tk - tiles_n = n // tn - - # --- Group metadata --- - group_metadata, num_active_tiles = make_group_metadata( - group_sizes=group_sizes, m=m, tm=tm, visit_empty_groups=False - ) - - # --- Kernel --- - def kernel( - group_metadata_ref, - lhs_ref, - rhs_ref, - out_ref, - acc_ref, - ): - group_offsets, group_ids, m_tile_ids = group_metadata_ref - grid_id = pl.program_id(1) - k_i = pl.program_id(2) - - @pl.when(k_i == 0) - def _zero_acc(): - acc_ref[...] = jnp.zeros_like(acc_ref) - - lhs_block = lhs_ref[...] - rhs_block = rhs_ref[...] - - if transpose_rhs: - dims = (((1,), (1,)), ((), ())) - else: - dims = (((1,), (0,)), ((), ())) - - acc_ref[...] += lax.dot_general( - lhs_block, - rhs_block, - dimension_numbers=dims, - preferred_element_type=jnp.float32, - ) - - @pl.when(k_i == tiles_k - 1) - def _store(): - mask = _get_store_mask( - grid_id=grid_id, - group_metadata=(group_offsets, group_ids, m_tile_ids), - tm=tm, - tn=tn, - ) - out_ref[...] 
= lax.select( - mask, acc_ref[...], out_ref[...].astype(jnp.float32) - ).astype(preferred_element_type) - - # --- Index maps --- - def lhs_index_map(n_i, grid_id, k_i, group_metadata_ref): - _, _, m_tile_ids = group_metadata_ref - return m_tile_ids[grid_id], k_i - - def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): - _, group_ids, _ = group_metadata_ref - if transpose_rhs: - return group_ids[grid_id], n_i, k_i - else: - return group_ids[grid_id], k_i, n_i - - def out_index_map(n_i, grid_id, k_i, group_metadata_ref): - _, _, m_tile_ids = group_metadata_ref - return m_tile_ids[grid_id], n_i - - # --- BlockSpecs --- - lhs_spec = pl.BlockSpec((tm, tk), lhs_index_map) - if transpose_rhs: - rhs_spec = pl.BlockSpec((None, tn, tk), rhs_index_map) - else: - rhs_spec = pl.BlockSpec((None, tk, tn), rhs_index_map) - out_spec = pl.BlockSpec((tm, tn), out_index_map) - - # --- Launch --- - call_fn = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[lhs_spec, rhs_spec], - out_specs=out_spec, - grid=(tiles_n, num_active_tiles, tiles_k), - scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], - ), - compiler_params=pltpu.CompilerParams( - dimension_semantics=("parallel", "arbitrary", "arbitrary") - ), - interpret=interpret, - ) - - out = call_fn(group_metadata, lhs, rhs) - return out - - -# ============================================================================ -# TGMM (Transposed GMM for weight gradients) -# ============================================================================ - -@functools.partial( - jax.jit, - static_argnames=["tiling", "preferred_element_type", "interpret"], -) -def tgmm_forward( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - tiling: tuple[int, int, int] = (128, 128, 128), - preferred_element_type: jnp.dtype = jnp.float32, - interpret: bool | None = None, -) -> jnp.ndarray: - """Transposed grouped 
matrix multiplication for weight gradients. - - For each group i with rows [start_i, end_i): - out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] - - Args: - lhs: [m, k] bf16 input activations. - rhs: [m, n] bf16 gradients. - group_sizes: [num_groups] int32. - tiling: (tm, tk, tn) tile sizes. - preferred_element_type: Output dtype. - interpret: Pallas interpret mode. - - Returns: - [num_groups, k, n] per-group weight gradients. - """ - if interpret is None: - interpret = get_interpret() - - # --- Validate --- - assert lhs.ndim == 2, f"lhs must be 2D [m, k], got {lhs.ndim}D" - assert rhs.ndim == 2, f"rhs must be 2D [m, n], got {rhs.ndim}D" - assert lhs.shape[0] == rhs.shape[0], ( - f"lhs and rhs must have same m, got {lhs.shape[0]} vs {rhs.shape[0]}" - ) - assert group_sizes.dtype == jnp.int32 - - # --- Shape info --- - m = lhs.shape[0] - k = lhs.shape[1] - n = rhs.shape[1] - num_groups = group_sizes.shape[0] - - tm, tk, tn = tiling - assert k % tk == 0, f"k ({k}) must be divisible by tk ({tk})" - assert n % tn == 0, f"n ({n}) must be divisible by tn ({tn})" - - tiles_k = k // tk - tiles_n = n // tn - - # --- Group metadata --- - group_metadata, num_active_tiles = make_group_metadata( - group_sizes=group_sizes, m=m, tm=tm, visit_empty_groups=True - ) - - # --- Kernel --- - def kernel( - group_metadata_ref, - lhs_ref, - rhs_ref, - out_ref, - acc_ref, - ): - group_offsets, group_ids, m_tile_ids = group_metadata_ref - grid_id = pl.program_id(2) - - group = group_ids[grid_id] - prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) - prev_group = group_ids[prev_grid_id] - group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group) - - @pl.when(group_has_changed) - def _zero_acc(): - acc_ref[...] 
= jnp.zeros_like(acc_ref) - - # Only compute if group has rows - has_rows = ( - _get_group_size(grid_id=grid_id, group_metadata=(group_offsets, group_ids, m_tile_ids)) - > 0 - ) - - @pl.when(has_rows) - def _compute(): - # Mask rows outside group - lhs_mask = _get_store_mask( - grid_id=grid_id, - group_metadata=(group_offsets, group_ids, m_tile_ids), - tm=tm, - tn=tk, - ) - rhs_mask = _get_store_mask( - grid_id=grid_id, - group_metadata=(group_offsets, group_ids, m_tile_ids), - tm=tm, - tn=tn, - ) - loaded_lhs = lax.select(lhs_mask, lhs_ref[...], jnp.zeros_like(lhs_ref)) - loaded_rhs = lax.select(rhs_mask, rhs_ref[...], jnp.zeros_like(rhs_ref)) - - # lhs^T [tk, tm] @ rhs [tm, tn] = [tk, tn] - acc_ref[...] += lax.dot( - loaded_lhs.swapaxes(0, 1), - loaded_rhs, - preferred_element_type=jnp.float32, - ) - - # Store when group is about to change - is_end = grid_id == (pl.num_programs(2) - 1) - next_grid_id = jnp.where(is_end, grid_id, grid_id + 1) - next_group = group_ids[next_grid_id] - group_is_changing = jnp.logical_or(is_end, group != next_group) - - @pl.when(group_is_changing) - def _store(): - out_ref[...] 
= acc_ref[...].astype(preferred_element_type) - - # --- Index maps --- - def lhs_index_map(n_i, k_i, grid_id, group_metadata_ref): - _, _, m_tile_ids = group_metadata_ref - return m_tile_ids[grid_id], k_i - - def rhs_index_map(n_i, k_i, grid_id, group_metadata_ref): - _, _, m_tile_ids = group_metadata_ref - return m_tile_ids[grid_id], n_i - - def out_index_map(n_i, k_i, grid_id, group_metadata_ref): - _, group_ids, _ = group_metadata_ref - return group_ids[grid_id], k_i, n_i - - # --- BlockSpecs --- - lhs_spec = pl.BlockSpec((tm, tk), lhs_index_map) - rhs_spec = pl.BlockSpec((tm, tn), rhs_index_map) - out_spec = pl.BlockSpec((None, tk, tn), out_index_map) - - # --- Launch --- - call_fn = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct( - (num_groups, k, n), preferred_element_type - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - in_specs=[lhs_spec, rhs_spec], - out_specs=out_spec, - grid=(tiles_n, tiles_k, num_active_tiles), - scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], - ), - compiler_params=pltpu.CompilerParams( - dimension_semantics=("parallel", "arbitrary", "arbitrary") - ), - interpret=interpret, - ) - - out = call_fn(group_metadata, lhs, rhs) - return out -``` - -**Step 4: Run forward tests** - -Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_fwd_vs_cpu -v` -Expected: All PASS - -**Step 5: Commit** - -```bash -git add tops/ops/gmm/gmm.py tests/ops/gmm/test_gmm_tpu.py -git commit -m "feat(gmm): add GMM forward Pallas kernel" -``` - ---- - -### Task 4: TGMM Tests - -**Files:** -- Modify: `tests/ops/gmm/test_gmm_tpu.py` - -**Step 1: Add TGMM tests to test_gmm_tpu.py** - -Append to `tests/ops/gmm/test_gmm_tpu.py`: - -```python -# ============================================================================ -# TGMM Test Cases -# ============================================================================ - -TGMM_CASES = [ - dict(m=128, k=128, n=128, ng=1, dist="single", seed=200), - 
dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=201), - dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=202), - dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=203), - dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=204), - dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=205), -] - - -def _tgmm_case_id(c): - return f"tgmm_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", TGMM_CASES, ids=[_tgmm_case_id(c) for c in TGMM_CASES]) -def test_tgmm_vs_cpu(cfg): - """tgmm_forward (Pallas) should match tgmm_ref (CPU).""" - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - - key = jax.random.PRNGKey(cfg["seed"]) - k1, k2 = jax.random.split(key) - lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) - rhs = jax.random.normal(k2, (m, n), dtype=jnp.float32).astype(jnp.bfloat16) - - out_ref = tgmm_ref(lhs, rhs, gs) - out_pl = tgmm_forward(lhs, rhs, gs) - - assert compare_tensor("tgmm", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) -``` - -**Step 2: Run TGMM tests** - -Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_tgmm_vs_cpu -v` -Expected: All PASS - -**Step 3: Commit** - -```bash -git add tests/ops/gmm/test_gmm_tpu.py -git commit -m "test(gmm): add tgmm accuracy tests" -``` - ---- - -### Task 5: Custom VJP + Public API + Gradient Tests - -**Files:** -- Modify: `tops/ops/gmm/__init__.py` -- Modify: `tops/ops/gmm/gmm.py` (add custom_vjp wrapper) -- Modify: `tops/ops/__init__.py` -- Modify: `tests/ops/gmm/test_gmm_tpu.py` (add gradient tests) - -**Step 1: Add gradient tests to test_gmm_tpu.py** - -Append to `tests/ops/gmm/test_gmm_tpu.py`: - -```python -# ============================================================================ -# Gradient Tests -# ============================================================================ - -from tops.ops.gmm import gmm - - -def 
_gmm_ref_differentiable(lhs, rhs, group_sizes): - """Differentiable CPU reference for gradient comparison.""" - m = lhs.shape[0] - num_groups = rhs.shape[0] - n = rhs.shape[2] - out = jnp.zeros((m, n), dtype=jnp.float32) - start = 0 - for i in range(num_groups): - size = int(group_sizes[i]) - end = start + size - if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_mat = rhs[i].astype(jnp.float32) - out = out.at[start:end].set(lhs_slice @ rhs_mat) - start = end - return out.sum() - - -GRAD_CASES = [ - dict(m=128, k=128, n=128, ng=1, dist="single", seed=300), - dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=301), - dict(m=384, k=128, n=128, ng=3, dist="skewed", seed=302), -] - - -def _grad_case_id(c): - return f"grad_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", GRAD_CASES, ids=[_grad_case_id(c) for c in GRAD_CASES]) -def test_gmm_gradient(cfg): - """custom_vjp gradients should match numerical/reference gradients.""" - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"], dtype=jnp.bfloat16) - - # Pallas gmm gradients - def pallas_loss(lhs, rhs): - return gmm(lhs, rhs, gs).sum() - - dlhs_pl, drhs_pl = jax.grad(pallas_loss, argnums=(0, 1))(lhs, rhs) - - # Reference gradients (CPU) - def ref_loss(lhs, rhs): - return _gmm_ref_differentiable(lhs, rhs, gs) - - dlhs_ref, drhs_ref = jax.grad(ref_loss, argnums=(0, 1))(lhs, rhs) - - assert compare_tensor("dlhs", dlhs_ref, dlhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) - assert compare_tensor("drhs", drhs_ref, drhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) -``` - -**Step 2: Run gradient test to verify it fails** - -Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_gradient -v --no-header -x` -Expected: FAIL (ImportError - `tops.ops.gmm.gmm` function doesn't exist) - -**Step 3: Add custom_vjp to gmm.py** - -Add to end of 
`tops/ops/gmm/gmm.py`: - -```python -# ============================================================================ -# Differentiable GMM with custom_vjp -# ============================================================================ - -@functools.partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - tiling: tuple[int, int, int] = (128, 128, 128), - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, -) -> jnp.ndarray: - """Differentiable grouped matrix multiplication. - - For each group i with rows [start_i, end_i): - out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] - - Supports automatic differentiation via custom_vjp: - - dlhs: computed via gmm with transposed rhs - - drhs: computed via tgmm - - Args: - lhs: [m, k] bf16 input activations, rows sorted by group. - rhs: [num_groups, k, n] bf16 per-group weights. - group_sizes: [num_groups] int32, row count per group. - tiling: (tm, tk, tn) tile sizes. - transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). - preferred_element_type: Output dtype. - - Returns: - [m, output_n] bf16 tensor. 
- """ - return _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type)[0] - - -def _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type): - out = gmm_forward( - lhs, rhs, group_sizes, - tiling=tiling, - transpose_rhs=transpose_rhs, - preferred_element_type=preferred_element_type, - ) - return out, (lhs, rhs, group_sizes) - - -def _gmm_bwd(tiling, transpose_rhs, preferred_element_type, residuals, grad): - lhs, rhs, group_sizes = residuals - - # dlhs = grad @ rhs^T per group - dlhs = gmm_forward( - grad, rhs, group_sizes, - tiling=tiling, - transpose_rhs=not transpose_rhs, - preferred_element_type=preferred_element_type, - ).astype(lhs.dtype) - - # drhs = lhs^T @ grad per group - drhs = tgmm_forward( - lhs, grad, group_sizes, - tiling=tiling, - preferred_element_type=preferred_element_type, - ).astype(rhs.dtype) - - return dlhs, drhs, None - - -gmm.defvjp(_gmm_fwd, _gmm_bwd) -``` - -**Step 4: Create public API** - -Update `tops/ops/gmm/__init__.py`: - -```python -"""Public API for grouped matrix multiplication.""" - -from .gmm import gmm, gmm_forward, tgmm_forward - -__all__ = ["gmm", "gmm_forward", "tgmm_forward"] -``` - -Update `tops/ops/__init__.py` to add gmm: - -```python -"""Public API for tops.ops. - -All public interfaces are exported exclusively via this file. -Any interface not re-exported here is considered an internal implementation -detail with **no API stability guarantee**. 
-""" - -from .simple_gla import simple_gla -from .gmm import gmm - -__all__ = [ - "simple_gla", - "gmm", -] -``` - -**Step 5: Run gradient tests** - -Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/test_gmm_tpu.py::test_gmm_gradient -v` -Expected: All PASS - -**Step 6: Run all GMM tests** - -Run: `PALLAS_INTERPRET=1 uv run pytest tests/ops/gmm/ -v` -Expected: All PASS - -**Step 7: Lint** - -Run: `uv run ruff check tops/ops/gmm/ tops/cpu/ops/gmm/ tests/ops/gmm/` -Run: `uv run ruff format tops/ops/gmm/ tops/cpu/ops/gmm/ tests/ops/gmm/` - -**Step 8: Commit** - -```bash -git add tops/ops/gmm/ tops/ops/__init__.py tests/ops/gmm/ -git commit -m "feat(gmm): add custom_vjp wrapper and public API for differentiable GMM" -``` - ---- - -## Implementation Notes - -**Key differences from megablox:** -- No `qwix`/`qpl` dependency -- uses standard `lax.dot_general` and `pl.pallas_call` -- No `group_offset` / sharding support (Phase 1) -- No `existing_out` / input_output_aliases -- No `LutFn` tiling lookup -- static tuple only -- `num_scalar_prefetch=1` (only group_metadata, no group_offset) -- Clean tops-style assertions and docstrings - -**Testing approach:** -- CPU reference (`gmm_ref`, `tgmm_ref`) as ground truth -- `PALLAS_INTERPRET=1` for CPU-based Pallas testing during development -- On TPU: run without PALLAS_INTERPRET for native kernel execution -- Tolerance: atol=1e-2, rtol=1e-2 (bf16 accumulation in tiles) -- Gradient tolerance: atol=5e-2 (backward has compounded error) - -**File dependency graph:** -``` -tops/ops/gmm/metadata.py <- no internal deps -tops/ops/gmm/gmm.py <- metadata.py, tops.ops.utils -tops/ops/gmm/__init__.py <- gmm.py -tops/cpu/ops/gmm/naive.py <- tops.cpu.ops (cpu_reference) -tops/cpu/ops/gmm/__init__.py <- naive.py -``` From 4d7af49401c8b56034ff1811844d89c178e56f30 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 16:58:07 +0800 Subject: [PATCH 09/15] fix(gmm): correct transpose_rhs docstring in gmm_ref 
The docstring incorrectly stated rhs shape as [num_groups, k, n] when transpose_rhs=True. The actual input shape is [num_groups, n, k], with each slice transposed to [k, n] before matmul. Co-Authored-By: Claude Opus 4.6 --- tops/cpu/ops/gmm/naive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tops/cpu/ops/gmm/naive.py b/tops/cpu/ops/gmm/naive.py index fa106501..1a7ac059 100644 --- a/tops/cpu/ops/gmm/naive.py +++ b/tops/cpu/ops/gmm/naive.py @@ -21,8 +21,8 @@ def gmm_ref( Args: lhs: [m, k] input activations. rhs: [num_groups, k, n] per-group weights. - If transpose_rhs=True, rhs is [num_groups, k, n] but used as - [num_groups, n, k] (transposed before matmul). + If transpose_rhs=True, rhs is [num_groups, n, k] and each slice + is transposed to [k, n] before matmul. group_sizes: [num_groups] int32, number of rows per group. transpose_rhs: If True, transpose each rhs[i] before matmul. From afa058e5e34400a50825ab568e75412435fbfd52 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:01:07 +0800 Subject: [PATCH 10/15] refactor(gmm): remove Pallas kernels, use bf16 mul + f32 accumulation in refs Remove all Pallas TPU kernel code (gmm.py, metadata.py) and their tests (test_gmm_tpu.py, test_metadata.py). Retain CPU reference implementations as the ground truth. Change naive refs from "cast to f32 then multiply" to "bf16 multiply with f32 accumulation" via lax.dot(preferred_element_type=f32), matching TPU MXU semantics. Output is now f32 directly instead of casting back to input dtype. 
Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_gmm_tpu.py | 205 ----------- tests/ops/gmm/test_metadata.py | 76 ---- tops/cpu/ops/gmm/naive.py | 32 +- tops/ops/__init__.py | 2 - tops/ops/gmm/__init__.py | 4 - tops/ops/gmm/gmm.py | 639 --------------------------------- tops/ops/gmm/metadata.py | 111 ------ 7 files changed, 21 insertions(+), 1048 deletions(-) delete mode 100644 tests/ops/gmm/test_gmm_tpu.py delete mode 100644 tests/ops/gmm/test_metadata.py delete mode 100644 tops/ops/gmm/gmm.py delete mode 100644 tops/ops/gmm/metadata.py diff --git a/tests/ops/gmm/test_gmm_tpu.py b/tests/ops/gmm/test_gmm_tpu.py deleted file mode 100644 index f6bf718f..00000000 --- a/tests/ops/gmm/test_gmm_tpu.py +++ /dev/null @@ -1,205 +0,0 @@ -"""GMM Pallas kernel accuracy vs CPU reference.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parents[3])) - -import pytest -import jax -import jax.numpy as jnp - -import numpy as np - -from tops.ops.gmm.gmm import gmm_forward, tgmm_forward -from tops.ops.gmm import gmm -from tops.cpu.ops.gmm import gmm_ref, tgmm_ref - - -def compare_tensor(name, gold, tensor, atol=1e-5, rtol=1e-5, max_ulp=1): - """Lightweight compare_tensor that works without torch.""" - if isinstance(gold, jax.Array): - gold = np.array(gold).astype(np.float64) - if isinstance(tensor, jax.Array): - tensor = np.array(tensor).astype(np.float64) - if gold.shape != tensor.shape: - print(f"[{name}] Shape mismatch: {gold.shape} vs {tensor.shape}. 
FAIL.") - return False - diff = np.abs(gold - tensor) - max_diff = np.max(diff) - max_val = np.max(np.abs(tensor)) - is_close = np.allclose(gold, tensor, atol=atol, rtol=rtol, equal_nan=True) - status = "PASS" if is_close else "FAIL" - print(f"[{name}] {status} max_val={max_val:.6e} max_diff={max_diff:.6e}") - return is_close - - -# Helpers -def _make_group_sizes(num_groups, total_m, distribution="uniform", seed=0): - key = jax.random.PRNGKey(seed) - if distribution == "uniform": - base = total_m // num_groups - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - remainder = total_m - base * num_groups - sizes = sizes.at[:remainder].add(1) - elif distribution == "skewed": - first = total_m // 2 - rest = total_m - first - base = rest // (num_groups - 1) if num_groups > 1 else 0 - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - sizes = sizes.at[0].set(first) - remainder = total_m - first - base * (num_groups - 1) - if num_groups > 1: - sizes = sizes.at[1].add(remainder) - elif distribution == "single": - sizes = jnp.zeros(num_groups, dtype=jnp.int32) - sizes = sizes.at[0].set(total_m) - elif distribution == "with_empty": - base = total_m // (num_groups - 1) if num_groups > 1 else total_m - sizes = jnp.full(num_groups, base, dtype=jnp.int32) - sizes = sizes.at[0].set(0) - remainder = total_m - base * (num_groups - 1) - sizes = sizes.at[1].add(remainder) - return sizes - - -def _make_inputs(m, k, n, num_groups, group_sizes, seed=42, dtype=jnp.bfloat16): - key = jax.random.PRNGKey(seed) - k1, k2 = jax.random.split(key) - lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(dtype) - rhs = jax.random.normal(k2, (num_groups, k, n), dtype=jnp.float32).astype(dtype) - return lhs, rhs - - -# Forward test cases -FWD_CASES = [ - dict(m=128, k=128, n=128, ng=1, dist="single", seed=100), - dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=101), - dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=102), - dict(m=384, k=256, n=128, ng=3, dist="skewed", 
seed=103), - dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=104), - dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=105), - dict(m=1024, k=128, n=128, ng=8, dist="uniform", seed=106), - dict(m=640, k=128, n=128, ng=5, dist="skewed", seed=107), -] - - -def _fwd_case_id(c): - return f"m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", FWD_CASES, ids=[_fwd_case_id(c) for c in FWD_CASES]) -def test_gmm_fwd_vs_cpu(cfg): - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"]) - out_ref = gmm_ref(lhs, rhs, gs) - out_pl = gmm_forward(lhs, rhs, gs) - assert compare_tensor("gmm_fwd", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) - - -@pytest.mark.parametrize( - "cfg", FWD_CASES[:4], ids=[_fwd_case_id(c) for c in FWD_CASES[:4]] -) -def test_gmm_fwd_transpose_rhs(cfg): - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - key = jax.random.PRNGKey(cfg["seed"]) - k1, k2 = jax.random.split(key) - # For transpose_rhs: lhs [m, k], rhs [ng, n, k] (rhs is transposed inside) - lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) - rhs = jax.random.normal(k2, (ng, n, k), dtype=jnp.float32).astype(jnp.bfloat16) - out_ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True) - out_pl = gmm_forward(lhs, rhs, gs, transpose_rhs=True) - assert compare_tensor("gmm_fwd_T", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) - - -# TGMM test cases -TGMM_CASES = [ - dict(m=128, k=128, n=128, ng=1, dist="single", seed=200), - dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=201), - dict(m=512, k=128, n=256, ng=4, dist="uniform", seed=202), - dict(m=384, k=256, n=128, ng=3, dist="skewed", seed=203), - dict(m=256, k=128, n=128, ng=4, dist="with_empty", seed=204), - dict(m=512, k=256, n=256, ng=4, dist="uniform", seed=205), -] - - -def 
_tgmm_case_id(c): - return f"tgmm_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", TGMM_CASES, ids=[_tgmm_case_id(c) for c in TGMM_CASES]) -def test_tgmm_vs_cpu(cfg): - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, cfg["dist"], seed=cfg["seed"]) - key = jax.random.PRNGKey(cfg["seed"]) - k1, k2 = jax.random.split(key) - lhs = jax.random.normal(k1, (m, k), dtype=jnp.float32).astype(jnp.bfloat16) - rhs = jax.random.normal(k2, (m, n), dtype=jnp.float32).astype(jnp.bfloat16) - out_ref = tgmm_ref(lhs, rhs, gs) - out_pl = tgmm_forward(lhs, rhs, gs) - assert compare_tensor("tgmm", out_ref, out_pl, atol=1e-2, rtol=1e-2, max_ulp=4) - - -# ============================================================================ -# Gradient Tests -# ============================================================================ - - -def _gmm_ref_differentiable(lhs, rhs, group_sizes): - """Differentiable reference for gradient comparison (not @cpu_reference).""" - m = lhs.shape[0] - num_groups = rhs.shape[0] - n = rhs.shape[2] - out = jnp.zeros((m, n), dtype=jnp.float32) - start = 0 - for i in range(num_groups): - size = int(group_sizes[i]) - end = start + size - if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_mat = rhs[i].astype(jnp.float32) - out = out.at[start:end].set(lhs_slice @ rhs_mat) - start = end - return out.sum() - - -GRAD_CASES = [ - dict(m=128, k=128, n=128, ng=1, dist="single", seed=300), - dict(m=256, k=128, n=128, ng=2, dist="uniform", seed=301), - dict(m=384, k=128, n=128, ng=3, dist="skewed", seed=302), -] - - -def _grad_case_id(c): - return f"grad_m{c['m']}_k{c['k']}_n{c['n']}_ng{c['ng']}_{c['dist']}" - - -@pytest.mark.parametrize("cfg", GRAD_CASES, ids=[_grad_case_id(c) for c in GRAD_CASES]) -def test_gmm_gradient(cfg): - """custom_vjp gradients should match reference gradients.""" - m, k, n, ng = cfg["m"], cfg["k"], cfg["n"], cfg["ng"] - gs = _make_group_sizes(ng, m, 
cfg["dist"], seed=cfg["seed"]) - lhs, rhs = _make_inputs(m, k, n, ng, gs, seed=cfg["seed"], dtype=jnp.bfloat16) - - # Pallas gmm gradients - def pallas_loss(lhs, rhs): - return gmm(lhs, rhs, gs).sum() - - dlhs_pl, drhs_pl = jax.grad(pallas_loss, argnums=(0, 1))(lhs, rhs) - - # Reference gradients - def ref_loss(lhs, rhs): - return _gmm_ref_differentiable(lhs, rhs, gs) - - dlhs_ref, drhs_ref = jax.grad(ref_loss, argnums=(0, 1))(lhs, rhs) - - assert compare_tensor("dlhs", dlhs_ref, dlhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) - assert compare_tensor("drhs", drhs_ref, drhs_pl, atol=5e-2, rtol=5e-2, max_ulp=8) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/ops/gmm/test_metadata.py b/tests/ops/gmm/test_metadata.py deleted file mode 100644 index 4ef2dbbf..00000000 --- a/tests/ops/gmm/test_metadata.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Test group metadata construction for GMM kernel scheduling.""" - -from __future__ import annotations - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parents[3])) - -import pytest -import jax.numpy as jnp -import numpy as np - -from tops.ops.gmm.metadata import make_group_metadata - - -class TestMakeGroupMetadata: - """Verify CSR-like metadata maps grid indices to correct groups/tiles.""" - - def test_uniform_groups_aligned(self): - """Groups perfectly aligned to tile boundaries.""" - # 2 groups, 128 rows each, tm=128 -> 1 tile per group, 2 active tiles - gs = jnp.array([128, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=256, tm=128 - ) - assert int(num_tiles) == 2 - np.testing.assert_array_equal(offsets, [0, 128, 256]) - np.testing.assert_array_equal(gids[:2], [0, 1]) - np.testing.assert_array_equal(mids[:2], [0, 1]) - - def test_uniform_groups_multi_tile(self): - """Groups spanning multiple tiles.""" - # 2 groups, 256 rows each, tm=128 -> 2 tiles per group, 4 active tiles - gs = jnp.array([256, 256], 
dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=512, tm=128 - ) - assert int(num_tiles) == 4 - np.testing.assert_array_equal(gids[:4], [0, 0, 1, 1]) - np.testing.assert_array_equal(mids[:4], [0, 1, 2, 3]) - - def test_shared_tile_at_boundary(self): - """Group boundary falls mid-tile -> tile visited twice.""" - # Group 0: 64 rows (not aligned to 128), Group 1: 64 rows - # Tile 0 (rows 0-127) is shared between both groups - gs = jnp.array([64, 64], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128 - ) - # Tile 0 visited twice: once for group 0, once for group 1 - assert int(num_tiles) == 2 - np.testing.assert_array_equal(gids[:2], [0, 1]) - np.testing.assert_array_equal(mids[:2], [0, 0]) - - def test_empty_group(self): - """Empty group (size=0) should not produce active tiles.""" - gs = jnp.array([0, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128 - ) - assert int(num_tiles) == 1 - assert int(gids[0]) == 1 - - def test_visit_empty_groups(self): - """With visit_empty_groups=True, empty groups get one tile each.""" - gs = jnp.array([0, 128], dtype=jnp.int32) - (offsets, gids, mids), num_tiles = make_group_metadata( - group_sizes=gs, m=128, tm=128, visit_empty_groups=True - ) - assert int(num_tiles) == 2 - np.testing.assert_array_equal(gids[:2], [0, 1]) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tops/cpu/ops/gmm/naive.py b/tops/cpu/ops/gmm/naive.py index 1a7ac059..8b146346 100644 --- a/tops/cpu/ops/gmm/naive.py +++ b/tops/cpu/ops/gmm/naive.py @@ -1,7 +1,11 @@ -"""Pure JAX CPU reference for Grouped Matrix Multiplication.""" +"""Pure JAX CPU reference for Grouped Matrix Multiplication. + +Uses bf16 multiplication with f32 accumulation to match TPU MXU semantics. 
+""" import jax import jax.numpy as jnp +from jax import lax from tops.cpu.ops import cpu_reference @@ -18,6 +22,8 @@ def gmm_ref( For each group i with rows [start_i, end_i): out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + Uses bf16 multiplication with f32 accumulation (``preferred_element_type``). + Args: lhs: [m, k] input activations. rhs: [num_groups, k, n] per-group weights. @@ -37,7 +43,6 @@ def gmm_ref( m = lhs.shape[0] num_groups = rhs.shape[0] n = rhs.shape[1] if transpose_rhs else rhs.shape[2] - orig_dtype = lhs.dtype out = jnp.zeros((m, n), dtype=jnp.float32) start = 0 @@ -45,13 +50,15 @@ def gmm_ref( size = int(group_sizes[i]) end = start + size if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_mat = rhs[i].astype(jnp.float32) + lhs_slice = lhs[start:end] + rhs_mat = rhs[i] if transpose_rhs: rhs_mat = rhs_mat.T - out = out.at[start:end].set(lhs_slice @ rhs_mat) + out = out.at[start:end].set( + lax.dot(lhs_slice, rhs_mat, preferred_element_type=jnp.float32) + ) start = end - return out.astype(orig_dtype) + return out @cpu_reference @@ -65,6 +72,8 @@ def tgmm_ref( For each group i with rows [start_i, end_i): out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + Uses bf16 multiplication with f32 accumulation (``preferred_element_type``). + Args: lhs: [m, k] input activations. rhs: [m, n] gradient or second operand. 
@@ -83,7 +92,6 @@ def tgmm_ref( k = lhs.shape[1] n = rhs.shape[1] num_groups = group_sizes.shape[0] - orig_dtype = lhs.dtype out = jnp.zeros((num_groups, k, n), dtype=jnp.float32) start = 0 @@ -91,8 +99,10 @@ def tgmm_ref( size = int(group_sizes[i]) end = start + size if size > 0: - lhs_slice = lhs[start:end].astype(jnp.float32) - rhs_slice = rhs[start:end].astype(jnp.float32) - out = out.at[i].set(lhs_slice.T @ rhs_slice) + lhs_slice = lhs[start:end] + rhs_slice = rhs[start:end] + out = out.at[i].set( + lax.dot(lhs_slice.T, rhs_slice, preferred_element_type=jnp.float32) + ) start = end - return out.astype(orig_dtype) + return out diff --git a/tops/ops/__init__.py b/tops/ops/__init__.py index 12b6a4c8..fc529e87 100644 --- a/tops/ops/__init__.py +++ b/tops/ops/__init__.py @@ -5,10 +5,8 @@ detail with **no API stability guarantee**. """ -from .gmm import gmm from .simple_gla import simple_gla __all__ = [ - "gmm", "simple_gla", ] diff --git a/tops/ops/gmm/__init__.py b/tops/ops/gmm/__init__.py index a5d357f4..1f7910f8 100644 --- a/tops/ops/gmm/__init__.py +++ b/tops/ops/gmm/__init__.py @@ -1,5 +1 @@ """Public API for grouped matrix multiplication.""" - -from .gmm import gmm, gmm_forward, tgmm_forward - -__all__ = ["gmm", "gmm_forward", "tgmm_forward"] diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py deleted file mode 100644 index 2539bd5d..00000000 --- a/tops/ops/gmm/gmm.py +++ /dev/null @@ -1,639 +0,0 @@ -"""Pallas TPU kernels for Grouped Matrix Multiplication (GMM). 
- -Provides two core operations: - -* ``gmm_forward`` -- forward grouped matmul: - for each group *i* with rows ``[start_i, end_i)``: - ``out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i]`` - -* ``tgmm_forward`` -- transposed grouped matmul (weight gradient): - for each group *i* with rows ``[start_i, end_i)``: - ``out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]`` -""" - -from __future__ import annotations - -import functools -from typing import Any - -import jax -import jax.lax as lax -import jax.numpy as jnp -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu - -from tops.ops.gmm.metadata import make_group_metadata -from tops.ops.utils import get_interpret - - -# --------------------------------------------------------------------------- -# Tiling helpers -# --------------------------------------------------------------------------- - - -def _validate_tiling( - tiling: tuple[int, int, int] | None, m: int, k: int, n: int -) -> tuple[int, int, int]: - """Return validated (tm, tk, tn) tile sizes. - - Args: - tiling: User-provided ``(tm, tk, tn)`` or *None* for defaults. - m: Total rows in lhs. - k: Contraction dimension. - n: Output columns (or rhs last dim). - - Returns: - ``(tm, tk, tn)`` with each dimension clamped to the actual size. - """ - if tiling is None: - tm = min(128, m) - tk = min(128, k) - tn = min(128, n) - else: - tm, tk, tn = tiling - return (tm, tk, tn) - - -# --------------------------------------------------------------------------- -# Masking helper -# --------------------------------------------------------------------------- - - -def _get_store_mask( - group_metadata_ref, - group_offsets_ref, - m_tile_ids_ref, - group_ids_ref, - grid_id: int, - tm: int, - tn: int, -) -> jax.Array: - """Build a ``[tm, tn]`` boolean mask that is True for valid rows. - - A tile at a group boundary may contain rows belonging to two groups. - Only the rows that belong to the current group should be stored. 
- - Args: - group_metadata_ref: Scalar-prefetch ref holding - ``(group_offsets, group_ids, m_tile_ids)``. - group_offsets_ref: Ref into group_offsets array. - m_tile_ids_ref: Ref into m_tile_ids array. - group_ids_ref: Ref into group_ids array. - grid_id: Current index in the ``num_active_tiles`` grid dimension. - tm: Row tile size. - tn: Column tile size. - - Returns: - Boolean array of shape ``[tm, tn]``. - """ - group_id = group_ids_ref[grid_id] - group_start = group_offsets_ref[group_id] - group_end = group_offsets_ref[group_id + 1] - - m_tile_id = m_tile_ids_ref[grid_id] - row_start = m_tile_id * tm - row_indices = row_start + jnp.arange(tm, dtype=jnp.int32) - - valid_rows = (row_indices >= group_start) & (row_indices < group_end) - # Broadcast to [tm, tn] - return jnp.broadcast_to(valid_rows[:, None], (tm, tn)) - - -# =================================================================== -# GMM Forward -# =================================================================== - - -def _gmm_kernel( - # Scalar prefetch ref - group_metadata_ref, - # Input refs - lhs_ref, - rhs_ref, - # Output ref - out_ref, - # Scratch ref - acc_ref, - *, - tm: int, - tk: int, - tn: int, - tiles_k: int, - preferred_element_type: Any, - transpose_rhs: bool, -): - """GMM forward kernel body. - - Args: - group_metadata_ref: Scalar-prefetch ref containing - ``(group_offsets, group_ids, m_tile_ids)`` as a flat array. - lhs_ref: VMEM ref ``[tm, tk]``. - rhs_ref: VMEM ref ``[tk, tn]`` (or ``[tn, tk]`` when transpose_rhs). - out_ref: VMEM ref ``[tm, tn]``. - acc_ref: VMEM scratch ``[tm, tn]`` float32. - tm: Row tile size. - tk: K-dimension tile size. - tn: Column tile size. - tiles_k: Number of tiles along K. - preferred_element_type: Output dtype. - transpose_rhs: Whether rhs is stored transposed. - """ - n_i, grid_id, k_i = pl.program_id(0), pl.program_id(1), pl.program_id(2) - - # Unpack metadata from the scalar prefetch ref. 
- # group_metadata_ref is a tuple of 3 refs: (group_offsets, group_ids, m_tile_ids) - group_offsets_ref = group_metadata_ref[0] - group_ids_ref = group_metadata_ref[1] - m_tile_ids_ref = group_metadata_ref[2] - - # Zero accumulator on first k-tile. - @pl.when(k_i == 0) - def _zero(): - acc_ref[...] = jnp.zeros((tm, tn), dtype=jnp.float32) - - # Load blocks and accumulate. - lhs_block = lhs_ref[...].astype(jnp.float32) - rhs_block = rhs_ref[...].astype(jnp.float32) - - if transpose_rhs: - dims = ((1,), (1,)), ((), ()) - else: - dims = ((1,), (0,)), ((), ()) - - acc_ref[...] += lax.dot_general(lhs_block, rhs_block, dims) - - # On last k-tile: apply mask and store. - @pl.when(k_i == tiles_k - 1) - def _store(): - mask = _get_store_mask( - group_metadata_ref, - group_offsets_ref, - m_tile_ids_ref, - group_ids_ref, - grid_id, - tm, - tn, - ) - acc = acc_ref[...] - existing = out_ref[...].astype(jnp.float32) - result = lax.select(mask, acc, existing) - out_ref[...] = result.astype(preferred_element_type) - - -@functools.partial( - jax.jit, - static_argnames=["tiling", "transpose_rhs", "preferred_element_type", "interpret"], -) -def gmm_forward( - lhs: jax.Array, - rhs: jax.Array, - group_sizes: jax.Array, - tiling: tuple[int, int, int] | None = None, - transpose_rhs: bool = False, - preferred_element_type: Any = None, - interpret: bool | None = None, -) -> jax.Array: - """Grouped matrix multiplication forward pass. - - Computes per-group matmuls where rows of ``lhs`` are partitioned into - groups defined by ``group_sizes``, and each group multiplies against the - corresponding weight matrix in ``rhs``. - - Semantics:: - - for group i with rows [start_i, end_i): - out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] - - Args: - lhs: ``[m, k]`` input activations in bfloat16 / float32. - rhs: ``[num_groups, k, n]`` per-group weight matrices. - When *transpose_rhs* is True the layout is ``[num_groups, n, k]`` - and each slice is transposed before multiplication. 
- group_sizes: ``[num_groups]`` int32 giving the number of rows per group. - tiling: ``(tm, tk, tn)`` tile sizes or *None* for auto. - transpose_rhs: If True, each ``rhs[i]`` is ``[n, k]`` and transposed. - preferred_element_type: Output dtype; defaults to ``lhs.dtype``. - interpret: If *None* falls back to ``PALLAS_INTERPRET`` env var. - - Returns: - ``[m, n]`` output tensor. - """ - assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" - assert rhs.ndim == 3, f"rhs must be 3D, got {rhs.ndim}D" - assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" - - if interpret is None: - interpret = get_interpret() - if preferred_element_type is None: - preferred_element_type = lhs.dtype - - m, k_lhs = lhs.shape - num_groups = rhs.shape[0] - - if transpose_rhs: - n, k_rhs = rhs.shape[1], rhs.shape[2] - else: - k_rhs, n = rhs.shape[1], rhs.shape[2] - assert k_lhs == k_rhs, f"lhs K ({k_lhs}) must match rhs K ({k_rhs})" - - k = k_lhs - tm, tk, tn = _validate_tiling(tiling, m, k, n) - - # Pad m to multiple of tm. - m_padded = ((m + tm - 1) // tm) * tm - if m_padded > m: - lhs = jnp.pad(lhs, ((0, m_padded - m), (0, 0))) - - # Pad k to multiple of tk. - k_padded = ((k + tk - 1) // tk) * tk - if k_padded > k: - lhs = jnp.pad(lhs, ((0, 0), (0, k_padded - k))) - if transpose_rhs: - rhs = jnp.pad(rhs, ((0, 0), (0, 0), (0, k_padded - k))) - else: - rhs = jnp.pad(rhs, ((0, 0), (0, k_padded - k), (0, 0))) - - # Pad n to multiple of tn. - n_padded = ((n + tn - 1) // tn) * tn - if n_padded > n: - if transpose_rhs: - rhs = jnp.pad(rhs, ((0, 0), (0, n_padded - n), (0, 0))) - else: - rhs = jnp.pad(rhs, ((0, 0), (0, 0), (0, n_padded - n))) - - tiles_k = k_padded // tk - tiles_n = n_padded // tn - - # Build group metadata. 
- group_metadata, num_active_tiles = make_group_metadata( - group_sizes=group_sizes, - m=m_padded, - tm=tm, - ) - group_offsets, group_ids, m_tile_ids = group_metadata - - grid = (tiles_n, num_active_tiles, tiles_k) - - # --- Index maps --- - # Note: BlockSpec multiplies returned block indices by the corresponding - # block_shape dimension. Axes with ``None`` in block_shape are element - # indices (not multiplied). - def lhs_index_map(n_i, grid_id, k_i, group_metadata_ref): - del n_i - m_tile_ids_ref = group_metadata_ref[2] - return (m_tile_ids_ref[grid_id], k_i) - - if transpose_rhs: - rhs_block_shape = (None, tn, tk) - - def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): - group_ids_ref = group_metadata_ref[1] - return (group_ids_ref[grid_id], n_i, k_i) - else: - rhs_block_shape = (None, tk, tn) - - def rhs_index_map(n_i, grid_id, k_i, group_metadata_ref): - group_ids_ref = group_metadata_ref[1] - return (group_ids_ref[grid_id], k_i, n_i) - - def out_index_map(n_i, grid_id, k_i, group_metadata_ref): - del k_i - m_tile_ids_ref = group_metadata_ref[2] - return (m_tile_ids_ref[grid_id], n_i) - - out_shape = jax.ShapeDtypeStruct((m_padded, n_padded), preferred_element_type) - - kernel_fn = functools.partial( - _gmm_kernel, - tm=tm, - tk=tk, - tn=tn, - tiles_k=tiles_k, - preferred_element_type=preferred_element_type, - transpose_rhs=transpose_rhs, - ) - - result = pl.pallas_call( - kernel_fn, - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=grid, - in_specs=[ - pl.BlockSpec((tm, tk), lhs_index_map), - pl.BlockSpec(rhs_block_shape, rhs_index_map), - ], - out_specs=pl.BlockSpec((tm, tn), out_index_map), - scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], - ), - compiler_params=pltpu.CompilerParams( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - ), - out_shape=out_shape, - interpret=interpret, - )( - (group_offsets, group_ids, m_tile_ids), - lhs, - rhs, - ) - - # Un-pad output. 
- return result[:m, :n] - - -# =================================================================== -# TGMM Forward (Transposed Grouped Matrix Multiplication) -# =================================================================== - - -def _tgmm_kernel( - # Scalar prefetch ref - group_metadata_ref, - # Input refs - lhs_ref, - rhs_ref, - # Output ref - out_ref, - # Scratch ref - acc_ref, - *, - tm: int, - tk: int, - tn: int, -): - """TGMM kernel body. - - Accumulates ``lhs^T @ rhs`` per group into the output. - - Args: - group_metadata_ref: Scalar-prefetch ref containing - ``(group_offsets, group_ids, m_tile_ids)``. - lhs_ref: VMEM ref ``[tm, tk]``. - rhs_ref: VMEM ref ``[tm, tn]``. - out_ref: VMEM ref ``[tk, tn]``. - acc_ref: VMEM scratch ``[tk, tn]`` float32. - tm: Row tile size. - tk: K-dimension tile size. - tn: Column tile size. - """ - n_i, k_i, grid_id = pl.program_id(0), pl.program_id(1), pl.program_id(2) - - group_offsets_ref = group_metadata_ref[0] - group_ids_ref = group_metadata_ref[1] - m_tile_ids_ref = group_metadata_ref[2] - - group_id = group_ids_ref[grid_id] - - # Determine if this is the first tile for the current group. - prev_group_id = jnp.where(grid_id > 0, group_ids_ref[grid_id - 1], -1) - is_new_group = (group_id != prev_group_id) | (grid_id == 0) - - # Zero accumulator when entering a new group. - @pl.when(is_new_group) - def _zero(): - acc_ref[...] = jnp.zeros((tk, tn), dtype=jnp.float32) - - # Compute row mask for group boundaries. - m_tile_id = m_tile_ids_ref[grid_id] - group_start = group_offsets_ref[group_id] - group_end = group_offsets_ref[group_id + 1] - row_start = m_tile_id * tm - row_indices = row_start + jnp.arange(tm, dtype=jnp.int32) - valid_rows = (row_indices >= group_start) & (row_indices < group_end) - - # Mask lhs and rhs rows outside group boundaries. 
- lhs_block = lhs_ref[...].astype(jnp.float32) - rhs_block = rhs_ref[...].astype(jnp.float32) - - row_mask_lhs = jnp.broadcast_to(valid_rows[:, None], (tm, tk)) - row_mask_rhs = jnp.broadcast_to(valid_rows[:, None], (tm, tn)) - - lhs_block = jnp.where(row_mask_lhs, lhs_block, 0.0) - rhs_block = jnp.where(row_mask_rhs, rhs_block, 0.0) - - # Accumulate: lhs^T @ rhs -> [tk, tn] - acc_ref[...] += lax.dot(lhs_block.swapaxes(0, 1), rhs_block) - - # Store when group is about to change or this is the last active tile. - num_active = pl.num_programs(2) - next_group_id = jnp.where( - grid_id < num_active - 1, - group_ids_ref[grid_id + 1], - -1, - ) - is_last_for_group = (group_id != next_group_id) | (grid_id == num_active - 1) - - @pl.when(is_last_for_group) - def _store(): - out_ref[...] = acc_ref[...].astype(out_ref.dtype) - - -@functools.partial( - jax.jit, - static_argnames=["tiling", "preferred_element_type", "interpret"], -) -def tgmm_forward( - lhs: jax.Array, - rhs: jax.Array, - group_sizes: jax.Array, - tiling: tuple[int, int, int] | None = None, - preferred_element_type: Any = None, - interpret: bool | None = None, -) -> jax.Array: - """Transposed grouped matrix multiplication forward pass. - - Computes per-group weight gradients: for each group *i* with rows - ``[start_i, end_i)``: - - ``out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :]`` - - Args: - lhs: ``[m, k]`` input activations in bfloat16 / float32. - rhs: ``[m, n]`` gradient tensor (or second operand). - group_sizes: ``[num_groups]`` int32, number of rows per group. - tiling: ``(tm, tk, tn)`` tile sizes or *None* for auto. - preferred_element_type: Output dtype; defaults to ``lhs.dtype``. - interpret: If *None* falls back to ``PALLAS_INTERPRET`` env var. - - Returns: - ``[num_groups, k, n]`` per-group outer products. 
- """ - assert lhs.ndim == 2, f"lhs must be 2D, got {lhs.ndim}D" - assert rhs.ndim == 2, f"rhs must be 2D, got {rhs.ndim}D" - assert group_sizes.ndim == 1, f"group_sizes must be 1D, got {group_sizes.ndim}D" - assert lhs.shape[0] == rhs.shape[0], ( - f"lhs and rhs must have same m dim, got {lhs.shape[0]} vs {rhs.shape[0]}" - ) - - if interpret is None: - interpret = get_interpret() - if preferred_element_type is None: - preferred_element_type = lhs.dtype - - m, k = lhs.shape - n = rhs.shape[1] - num_groups = group_sizes.shape[0] - tm, tk, tn = _validate_tiling(tiling, m, k, n) - - # Pad m to multiple of tm. - m_padded = ((m + tm - 1) // tm) * tm - if m_padded > m: - lhs = jnp.pad(lhs, ((0, m_padded - m), (0, 0))) - rhs = jnp.pad(rhs, ((0, m_padded - m), (0, 0))) - - # Pad k to multiple of tk. - k_padded = ((k + tk - 1) // tk) * tk - if k_padded > k: - lhs = jnp.pad(lhs, ((0, 0), (0, k_padded - k))) - - # Pad n to multiple of tn. - n_padded = ((n + tn - 1) // tn) * tn - if n_padded > n: - rhs = jnp.pad(rhs, ((0, 0), (0, n_padded - n))) - - tiles_k = k_padded // tk - tiles_n = n_padded // tn - - # Build group metadata with visit_empty_groups=True. - group_metadata, num_active_tiles_val = make_group_metadata( - group_sizes=group_sizes, - m=m_padded, - tm=tm, - visit_empty_groups=True, - ) - group_offsets, group_ids, m_tile_ids = group_metadata - - grid = (tiles_n, tiles_k, num_active_tiles_val) - - # --- Index maps --- - # Note: BlockSpec multiplies returned block indices by the corresponding - # block_shape dimension. Axes with ``None`` in block_shape are element - # indices (not multiplied). 
- def lhs_index_map(n_i, k_i, grid_id, group_metadata_ref): - del n_i - m_tile_ids_ref = group_metadata_ref[2] - return (m_tile_ids_ref[grid_id], k_i) - - def rhs_index_map(n_i, k_i, grid_id, group_metadata_ref): - del k_i - m_tile_ids_ref = group_metadata_ref[2] - return (m_tile_ids_ref[grid_id], n_i) - - def out_index_map(n_i, k_i, grid_id, group_metadata_ref): - group_ids_ref = group_metadata_ref[1] - return (group_ids_ref[grid_id], k_i, n_i) - - out_shape = jax.ShapeDtypeStruct( - (num_groups, k_padded, n_padded), preferred_element_type - ) - - kernel_fn = functools.partial( - _tgmm_kernel, - tm=tm, - tk=tk, - tn=tn, - ) - - result = pl.pallas_call( - kernel_fn, - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=grid, - in_specs=[ - pl.BlockSpec((tm, tk), lhs_index_map), - pl.BlockSpec((tm, tn), rhs_index_map), - ], - out_specs=pl.BlockSpec((None, tk, tn), out_index_map), - scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], - ), - compiler_params=pltpu.CompilerParams( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - ), - out_shape=out_shape, - interpret=interpret, - )( - (group_offsets, group_ids, m_tile_ids), - lhs, - rhs, - ) - - # Un-pad output. - return result[:, :k, :n] - - -# ============================================================================ -# Differentiable GMM with custom_vjp -# ============================================================================ - - -@functools.partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - tiling: tuple[int, int, int] = (128, 128, 128), - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, -) -> jnp.ndarray: - """Differentiable grouped matrix multiplication. 
- - For each group i with rows [start_i, end_i): - out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] - - Supports automatic differentiation via custom_vjp: - - dlhs: computed via gmm with transposed rhs - - drhs: computed via tgmm - - Args: - lhs: [m, k] bf16 input activations, rows sorted by group. - rhs: [num_groups, k, n] bf16 per-group weights. - group_sizes: [num_groups] int32, row count per group. - tiling: (tm, tk, tn) tile sizes. - transpose_rhs: If True, use rhs as [num_groups, n, k] (transposed). - preferred_element_type: Output dtype. - - Returns: - [m, output_n] tensor. - """ - return _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type)[ - 0 - ] - - -def _gmm_fwd(lhs, rhs, group_sizes, tiling, transpose_rhs, preferred_element_type): - out = gmm_forward( - lhs, - rhs, - group_sizes, - tiling=tiling, - transpose_rhs=transpose_rhs, - preferred_element_type=preferred_element_type, - ) - return out, (lhs, rhs, group_sizes) - - -def _gmm_bwd(tiling, transpose_rhs, preferred_element_type, residuals, grad): - lhs, rhs, group_sizes = residuals - - # dlhs = grad @ rhs^T per group - dlhs = gmm_forward( - grad, - rhs, - group_sizes, - tiling=tiling, - transpose_rhs=not transpose_rhs, - preferred_element_type=preferred_element_type, - ).astype(lhs.dtype) - - # drhs = lhs^T @ grad per group - drhs = tgmm_forward( - lhs, - grad, - group_sizes, - tiling=tiling, - preferred_element_type=preferred_element_type, - ).astype(rhs.dtype) - - return dlhs, drhs, None - - -gmm.defvjp(_gmm_fwd, _gmm_bwd) diff --git a/tops/ops/gmm/metadata.py b/tops/ops/gmm/metadata.py deleted file mode 100644 index 46e288cb..00000000 --- a/tops/ops/gmm/metadata.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Group metadata construction for GMM kernel scheduling. - -Builds CSR-like metadata arrays that map Pallas grid indices to (group_id, -m_tile_id) pairs. This enables the GMM kernel to process ragged groups of -varying sizes using a flat 1-D grid over m-tiles. 
-""" - -from __future__ import annotations - -import jax.numpy as jnp - -GroupMetadata = tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray] - - -def make_group_metadata( - *, - group_sizes: jnp.ndarray, - m: int, - tm: int, - visit_empty_groups: bool = False, -) -> tuple[GroupMetadata, jnp.ndarray]: - """Build scheduling metadata for grouped matmul. - - Maps each grid index in the ``num_active_tiles`` dimension to a - ``(group_id, m_tile_id)`` pair so the Pallas kernel knows which group - and which row-tile to process. - - Args: - group_sizes: [num_groups] int32 -- number of rows per group. - m: Total number of rows in lhs (may exceed sum(group_sizes) due to - padding). - tm: Row-dimension tile size. - visit_empty_groups: If True, allocate one tile per empty group (needed - by tgmm to zero the output for empty groups). - - Returns: - (group_offsets, group_ids, m_tile_ids): Metadata arrays. - - group_offsets: [num_groups + 1] int32, CSR-style row offsets. - - group_ids: [tiles_m + num_groups - 1] int32, group for each - active tile. - - m_tile_ids: [tiles_m + num_groups - 1] int32, row-tile index - for each active tile. - num_active_tiles: Scalar int32, how many entries in group_ids / - m_tile_ids are valid. - """ - assert group_sizes.ndim == 1, "group_sizes must be 1-D" - assert m > 0, "m must be positive" - assert tm > 0, "tm must be positive" - - num_groups = group_sizes.shape[0] - - # --- CSR-style offsets --- - group_ends = jnp.cumsum(group_sizes) - group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends]) - - # --- Compute tile ranges for each group --- - group_starts = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]) - - # First tile index touched by each group (floor division) - first_tile = group_starts // tm - # Last tile index touched by each group (ceil division - 1, i.e. inclusive) - # For empty groups, last_tile < first_tile so they produce 0 tiles. 
- last_tile_plus_one = (group_ends + tm - 1) // tm - # Clamp empty groups to produce 0 tiles - tiles_per_group = jnp.where( - group_sizes == 0, - 0, - last_tile_plus_one - first_tile, - ).astype(jnp.int32) - - if visit_empty_groups: - tiles_per_group = jnp.where(group_sizes == 0, 1, tiles_per_group) - - tiles_m = (m + tm - 1) // tm - # Worst case: each group boundary can split a tile, adding at most - # (num_groups - 1) extra visits. - total_len = tiles_m + num_groups - 1 - - # --- group_ids: map grid index -> group --- - group_ids = jnp.repeat( - jnp.arange(num_groups, dtype=jnp.int32), - tiles_per_group, - total_repeat_length=total_len, - ) - - # --- m_tile_ids: map grid index -> row-tile --- - # For each group, the tile indices are first_tile, first_tile+1, ..., - # first_tile + tiles_per_group - 1. - # We build this by creating a per-slot offset within the group, then - # adding the group's first_tile. - - # First, compute the starting offset for each group's tiles in the - # output array using cumsum of tiles_per_group. - group_tile_offsets = jnp.concatenate( - [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(tiles_per_group)] - ) - - # For empty groups with visit_empty_groups, first_tile doesn't make sense; - # use 0 as placeholder (the tile id for empty group visits). 
- effective_first_tile = jnp.where(group_sizes == 0, 0, first_tile).astype(jnp.int32) - - # Build m_tile_ids using a scatter approach: - # For each active slot i, m_tile_ids[i] = first_tile[group_ids[i]] + local_offset - # where local_offset = i - group_tile_offsets[group_ids[i]] - - slot_indices = jnp.arange(total_len, dtype=jnp.int32) - local_offsets = slot_indices - group_tile_offsets[group_ids] - m_tile_ids = effective_first_tile[group_ids] + local_offsets - - num_active_tiles = tiles_per_group.sum() - return (group_offsets, group_ids, m_tile_ids), num_active_tiles From addc5d880313b608694416937666c359e03e05e7 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:02:35 +0800 Subject: [PATCH 11/15] test(gmm): use bfloat16 inputs in CPU reference tests Change all test inputs from float32 to bfloat16 to match the intended usage pattern (bf16 mul + f32 accumulation). Expected values remain f32 as the reference now outputs f32. Tolerances relaxed to 1e-2. 
Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_cpu_ref.py | 40 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/ops/gmm/test_cpu_ref.py b/tests/ops/gmm/test_cpu_ref.py index db1fa3d9..86335bb9 100644 --- a/tests/ops/gmm/test_cpu_ref.py +++ b/tests/ops/gmm/test_cpu_ref.py @@ -19,54 +19,54 @@ class TestGmmRef: def test_single_group(self): """Single group = standard matmul.""" - lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) - rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.float32) + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16) gs = jnp.array([2], dtype=jnp.int32) out = gmm_ref(lhs, rhs, gs) - expected = lhs # identity matmul - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + expected = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) def test_two_groups(self): """Two groups with different weights.""" - lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.float32) + lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16) rhs = jnp.array( [ [[2.0, 0.0], [0.0, 2.0]], # group 0: scale by 2 [[0.0, 1.0], [1.0, 0.0]], # group 1: swap columns ], - dtype=jnp.float32, + dtype=jnp.bfloat16, ) gs = jnp.array([1, 2], dtype=jnp.int32) out = gmm_ref(lhs, rhs, gs) expected = jnp.array([[2.0, 0.0], [1.0, 0.0], [1.0, 1.0]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) def test_empty_group(self): """Empty group produces zeros for those rows (none exist).""" - lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16) rhs = jnp.array( [ [[1.0], [1.0]], # group 0: empty [[1.0], [1.0]], # group 1: 1 row ], - dtype=jnp.float32, 
+ dtype=jnp.bfloat16, ) gs = jnp.array([0, 1], dtype=jnp.int32) out = gmm_ref(lhs, rhs, gs) expected = jnp.array([[3.0]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) def test_transpose_rhs(self): """transpose_rhs transposes each rhs[i] before matmul.""" - lhs = jnp.array([[1.0, 2.0]], dtype=jnp.float32) - rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.float32) + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16) gs = jnp.array([1], dtype=jnp.int32) # Without transpose: lhs [1,2] @ rhs [2,2] = [1*3+2*5, 1*4+2*6] = [13, 16] out_normal = gmm_ref(lhs, rhs, gs) - np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-5) + np.testing.assert_allclose(np.array(out_normal), [[13.0, 16.0]], atol=1e-2) # With transpose: lhs [1,2] @ rhs.T [2,2] = [1*3+2*4, 1*5+2*6] = [11, 17] out_transposed = gmm_ref(lhs, rhs, gs, transpose_rhs=True) - np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-5) + np.testing.assert_allclose(np.array(out_transposed), [[11.0, 17.0]], atol=1e-2) class TestTgmmRef: @@ -74,24 +74,24 @@ class TestTgmmRef: def test_single_group(self): """Single group: lhs^T @ rhs.""" - lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) - rhs = jnp.array([[5.0], [6.0]], dtype=jnp.float32) + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16) gs = jnp.array([2], dtype=jnp.int32) out = tgmm_ref(lhs, rhs, gs) # lhs^T [2,2] @ rhs [2,1] = [[1*5+3*6], [2*5+4*6]] = [[23], [34]] expected = jnp.array([[[23.0], [34.0]]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) def test_two_groups(self): """Two groups produce separate outer products.""" - lhs = 
jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.float32) - rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.float32) + lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16) gs = jnp.array([1, 2], dtype=jnp.int32) out = tgmm_ref(lhs, rhs, gs) # Group 0: [1]^T @ [4] = [[4]] # Group 1: [2,3]^T @ [5,6] = [[2*5+3*6]] = [[28]] expected = jnp.array([[[4.0]], [[28.0]]], dtype=jnp.float32) - np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-5) + np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) if __name__ == "__main__": From 7fd19cb953f90d4fa4f271111e10f6b832bd0eba Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:26:35 +0800 Subject: [PATCH 12/15] chore(gmm): add tokamax as optional TPU dependency Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0a5e44be..7c3e1ab8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ gpu = [ tpu = [ "jax[tpu]>=0.8.1", "torch", + "tokamax>=0.0.12", ] profile = [ "xprof==2.22.0", From 6577a37aadf60c7a32327d270093b03b195ff283 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:30:09 +0800 Subject: [PATCH 13/15] feat(gmm): add JIT-compilable gmm/tgmm using lax.scan Implements _gmm_impl and _tgmm_impl with scan-based grouped matmul using dynamic_slice for TPU/JIT compatibility. bf16 inputs with f32 accumulation to match TPU MXU semantics. Adds tests verifying JAX implementations match CPU reference outputs. 
Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_cpu_ref.py | 60 ++++++++++++++ tops/ops/gmm/__init__.py | 4 + tops/ops/gmm/gmm.py | 151 ++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+) create mode 100644 tops/ops/gmm/gmm.py diff --git a/tests/ops/gmm/test_cpu_ref.py b/tests/ops/gmm/test_cpu_ref.py index 86335bb9..b1b06635 100644 --- a/tests/ops/gmm/test_cpu_ref.py +++ b/tests/ops/gmm/test_cpu_ref.py @@ -12,6 +12,7 @@ import pytest from tops.cpu.ops.gmm import gmm_ref, tgmm_ref +from tops.ops.gmm import gmm, tgmm class TestGmmRef: @@ -94,5 +95,64 @@ def test_two_groups(self): np.testing.assert_allclose(np.array(out), np.array(expected), atol=1e-2) +class TestGmmJax: + """Test JIT-compilable gmm against CPU reference.""" + + def test_single_group(self): + """Single group: gmm matches gmm_ref.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=jnp.bfloat16) + gs = jnp.array([2], dtype=jnp.int32) + out = gmm(lhs, rhs, gs) + ref = gmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) + + def test_two_groups(self): + """Two groups: gmm matches gmm_ref.""" + lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16) + rhs = jnp.array( + [ + [[2.0, 0.0], [0.0, 2.0]], + [[0.0, 1.0], [1.0, 0.0]], + ], + dtype=jnp.bfloat16, + ) + gs = jnp.array([1, 2], dtype=jnp.int32) + out = gmm(lhs, rhs, gs) + ref = gmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) + + def test_transpose_rhs(self): + """transpose_rhs: gmm matches gmm_ref.""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16) + gs = jnp.array([1], dtype=jnp.int32) + out = gmm(lhs, rhs, gs, transpose_rhs=True) + ref = gmm_ref(lhs, rhs, gs, transpose_rhs=True) + np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) + + +class TestTgmmJax: + """Test 
JIT-compilable tgmm against CPU reference.""" + + def test_single_group(self): + """Single group: tgmm matches tgmm_ref.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[5.0], [6.0]], dtype=jnp.bfloat16) + gs = jnp.array([2], dtype=jnp.int32) + out = tgmm(lhs, rhs, gs) + ref = tgmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) + + def test_two_groups(self): + """Two groups: tgmm matches tgmm_ref.""" + lhs = jnp.array([[1.0], [2.0], [3.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[4.0], [5.0], [6.0]], dtype=jnp.bfloat16) + gs = jnp.array([1, 2], dtype=jnp.int32) + out = tgmm(lhs, rhs, gs) + ref = tgmm_ref(lhs, rhs, gs) + np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tops/ops/gmm/__init__.py b/tops/ops/gmm/__init__.py index 1f7910f8..1b99b467 100644 --- a/tops/ops/gmm/__init__.py +++ b/tops/ops/gmm/__init__.py @@ -1 +1,5 @@ """Public API for grouped matrix multiplication.""" + +from .gmm import gmm, tgmm + +__all__ = ["gmm", "tgmm"] diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py new file mode 100644 index 00000000..cc61ff70 --- /dev/null +++ b/tops/ops/gmm/gmm.py @@ -0,0 +1,151 @@ +"""JIT-compilable Grouped Matrix Multiplication for TPU. + +Uses lax.scan + dynamic_slice for TPU-compatible grouped matmul. +bf16 multiplication with f32 accumulation to match TPU MXU semantics. +""" + +import jax +import jax.numpy as jnp +from jax import lax + + +def _gmm_impl( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + transpose_rhs: bool = False, +) -> jax.Array: + """Core scan-based grouped matmul. + + For each group i with rows [start_i, end_i): + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Uses lax.scan over groups with dynamic_slice for JIT/TPU compatibility. + bf16 inputs, f32 accumulation via lax.dot preferred_element_type. 
+ + Args: + lhs: [m, k] input activations. + rhs: [num_groups, k, n] per-group weights. + If transpose_rhs=True, rhs is [num_groups, n, k]. + group_sizes: [num_groups] int32, number of rows per group. + transpose_rhs: If True, transpose each rhs[i] before matmul. + + Returns: + [m, n] output in float32. + """ + m, k = lhs.shape + n = rhs.shape[1] if transpose_rhs else rhs.shape[2] + + offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)]) + + # Pad lhs so dynamic_slice(start, m) never triggers OOB clamping. + lhs_padded = jnp.pad(lhs, ((0, m), (0, 0))) # [2m, k] + + def body(carry, i): + out_padded = carry # [2m, n] + start = offsets[i] + size = group_sizes[i] + + lhs_slice = lax.dynamic_slice(lhs_padded, (start, 0), (m, k)) + rhs_mat = rhs[i] + if transpose_rhs: + rhs_mat = rhs_mat.T + + prod = lax.dot(lhs_slice, rhs_mat, preferred_element_type=jnp.float32) + + # Zero out rows beyond this group's size. + valid = jnp.arange(m) < size + prod = prod * valid[:, None] + + out_padded = lax.dynamic_update_slice(out_padded, prod, (start, 0)) + return out_padded, None + + out_padded = jnp.zeros((2 * m, n), dtype=jnp.float32) + out_padded, _ = lax.scan(body, out_padded, jnp.arange(offsets.shape[0] - 1)) + return out_padded[:m] + + +def _tgmm_impl( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, +) -> jax.Array: + """Core scan-based transposed grouped matmul. + + For each group i with rows [start_i, end_i): + out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + + Uses lax.scan over groups with dynamic_slice for JIT/TPU compatibility. + bf16 inputs, f32 accumulation via lax.dot preferred_element_type. + + Args: + lhs: [m, k] input activations. + rhs: [m, n] second operand. + group_sizes: [num_groups] int32, number of rows per group. + + Returns: + [num_groups, k, n] per-group products in float32. 
+ """ + m, k = lhs.shape + n = rhs.shape[1] + + offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(group_sizes)]) + + lhs_padded = jnp.pad(lhs, ((0, m), (0, 0))) # [2m, k] + rhs_padded = jnp.pad(rhs, ((0, m), (0, 0))) # [2m, n] + + def body(_, i): + start = offsets[i] + size = group_sizes[i] + + lhs_slice = lax.dynamic_slice(lhs_padded, (start, 0), (m, k)) + rhs_slice = lax.dynamic_slice(rhs_padded, (start, 0), (m, n)) + + # Mask invalid rows to zero (masking one operand suffices). + valid = jnp.arange(m) < size + lhs_slice = lhs_slice * valid[:, None] + + result = lax.dot( + lhs_slice.T, rhs_slice, preferred_element_type=jnp.float32 + ) # [k, n] + return None, result + + _, results = lax.scan(body, None, jnp.arange(group_sizes.shape[0])) + return results # [num_groups, k, n] + + +# Temporary thin wrappers — will be replaced with custom_vjp in Task 4 +def gmm(lhs, rhs, group_sizes, transpose_rhs=False): + """Grouped matrix multiplication. + + Computes per-group matmul: for each group i with rows [start_i, end_i), + out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + + Args: + lhs: [m, k] input activations in bfloat16. + rhs: [num_groups, k, n] per-group weights in bfloat16. + If transpose_rhs=True, shape is [num_groups, n, k]. + group_sizes: [num_groups] int32, number of rows per group. + transpose_rhs: If True, transpose each rhs[i] before matmul. + + Returns: + [m, n] output in float32. + """ + return _gmm_impl(lhs, rhs, group_sizes, transpose_rhs) + + +def tgmm(lhs, rhs, group_sizes): + """Transposed grouped matrix multiplication. + + Computes per-group transposed matmul: for each group i, + out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + + Args: + lhs: [m, k] input activations in bfloat16. + rhs: [m, n] second operand in bfloat16. + group_sizes: [num_groups] int32, number of rows per group. + + Returns: + [num_groups, k, n] per-group products in float32. 
+ """ + return _tgmm_impl(lhs, rhs, group_sizes) From 26d10e3f3b6d4c8c691b187a4a3e21797f79543f Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:32:35 +0800 Subject: [PATCH 14/15] feat(gmm): add custom_vjp for differentiable gmm with forward/backward Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_cpu_ref.py | 64 +++++++++++++++++++++++++++++++++ tops/ops/gmm/gmm.py | 68 ++++++++++++++++++++++++++++------- 2 files changed, 120 insertions(+), 12 deletions(-) diff --git a/tests/ops/gmm/test_cpu_ref.py b/tests/ops/gmm/test_cpu_ref.py index b1b06635..024ba5d1 100644 --- a/tests/ops/gmm/test_cpu_ref.py +++ b/tests/ops/gmm/test_cpu_ref.py @@ -7,6 +7,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[3])) +import jax import jax.numpy as jnp import numpy as np import pytest @@ -154,5 +155,68 @@ def test_two_groups(self): np.testing.assert_allclose(np.array(out), np.array(ref), atol=1e-2) +class TestGmmGrad: + """Test gmm gradient via custom_vjp.""" + + def test_grad_lhs(self): + """Gradient w.r.t. lhs is correct.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16) + gs = jnp.array([2], dtype=jnp.int32) + + def loss_fn(x): + return gmm(x, rhs, gs).sum() + + grad = jax.grad(loss_fn)(lhs) + assert grad.shape == lhs.shape + assert not jnp.any(jnp.isnan(grad)) + + def test_grad_rhs(self): + """Gradient w.r.t. rhs is correct.""" + lhs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[1.0, 0.5], [0.5, 1.0]]], dtype=jnp.bfloat16) + gs = jnp.array([2], dtype=jnp.int32) + + def loss_fn(w): + return gmm(lhs, w, gs).sum() + + grad = jax.grad(loss_fn)(rhs) + assert grad.shape == rhs.shape + assert not jnp.any(jnp.isnan(grad)) + + def test_grad_both(self): + """Gradient w.r.t. 
both lhs and rhs, two groups.""" + lhs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype=jnp.bfloat16) + rhs = jnp.array( + [[[2.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]], + dtype=jnp.bfloat16, + ) + gs = jnp.array([1, 2], dtype=jnp.int32) + + def loss_fn(x, w): + return gmm(x, w, gs).sum() + + grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs) + assert grad_lhs.shape == lhs.shape + assert grad_rhs.shape == rhs.shape + assert not jnp.any(jnp.isnan(grad_lhs)) + assert not jnp.any(jnp.isnan(grad_rhs)) + + def test_grad_transpose_rhs(self): + """Gradient with transpose_rhs=True.""" + lhs = jnp.array([[1.0, 2.0]], dtype=jnp.bfloat16) + rhs = jnp.array([[[3.0, 4.0], [5.0, 6.0]]], dtype=jnp.bfloat16) + gs = jnp.array([1], dtype=jnp.int32) + + def loss_fn(x, w): + return gmm(x, w, gs, transpose_rhs=True).sum() + + grad_lhs, grad_rhs = jax.grad(loss_fn, argnums=(0, 1))(lhs, rhs) + assert grad_lhs.shape == lhs.shape + assert grad_rhs.shape == rhs.shape + assert not jnp.any(jnp.isnan(grad_lhs)) + assert not jnp.any(jnp.isnan(grad_rhs)) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tops/ops/gmm/gmm.py b/tops/ops/gmm/gmm.py index cc61ff70..24082a9d 100644 --- a/tops/ops/gmm/gmm.py +++ b/tops/ops/gmm/gmm.py @@ -4,6 +4,8 @@ bf16 multiplication with f32 accumulation to match TPU MXU semantics. """ +import functools + import jax import jax.numpy as jnp from jax import lax @@ -114,17 +116,50 @@ def body(_, i): return results # [num_groups, k, n] -# Temporary thin wrappers — will be replaced with custom_vjp in Task 4 -def gmm(lhs, rhs, group_sizes, transpose_rhs=False): - """Grouped matrix multiplication. 
+def _gmm_fwd(lhs, rhs, group_sizes, transpose_rhs): + """Forward rule: compute output and save residuals.""" + out = _gmm_impl(lhs, rhs, group_sizes, transpose_rhs) + return out, (lhs, rhs, group_sizes) + + +def _gmm_bwd(transpose_rhs, residuals, grad): + """Backward rule: compute dlhs via GMM, drhs via TGMM.""" + lhs, rhs, group_sizes = residuals + + # dlhs = grad @ W^T per group + dlhs = _gmm_impl(grad, rhs, group_sizes, not transpose_rhs) + + # drhs: depends on transpose_rhs + if transpose_rhs: + # rhs shape [G, n, k], drhs[i] = grad_slice^T @ lhs_slice + drhs = _tgmm_impl(grad, lhs, group_sizes) + else: + # rhs shape [G, k, n], drhs[i] = lhs_slice^T @ grad_slice + drhs = _tgmm_impl(lhs, grad, group_sizes) + + # group_sizes is int32, not differentiable + return dlhs, drhs, None + - Computes per-group matmul: for each group i with rows [start_i, end_i), +@functools.partial(jax.custom_vjp, nondiff_argnums=(3,)) +def gmm( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + transpose_rhs: bool = False, +) -> jax.Array: + """Grouped matrix multiplication. JIT-compilable, runs on TPU. + + For each group i with rows [start_i, end_i): out[start_i:end_i, :] = lhs[start_i:end_i, :] @ rhs[i] + Uses bf16 multiplication with f32 accumulation. + Differentiable via custom_vjp (forward: GMM, backward: GMM + TGMM). + Args: - lhs: [m, k] input activations in bfloat16. - rhs: [num_groups, k, n] per-group weights in bfloat16. - If transpose_rhs=True, shape is [num_groups, n, k]. + lhs: [m, k] input activations. + rhs: [num_groups, k, n] per-group weights. + If transpose_rhs=True, rhs is [num_groups, n, k]. group_sizes: [num_groups] int32, number of rows per group. transpose_rhs: If True, transpose each rhs[i] before matmul. @@ -134,15 +169,24 @@ def gmm(lhs, rhs, group_sizes, transpose_rhs=False): return _gmm_impl(lhs, rhs, group_sizes, transpose_rhs) -def tgmm(lhs, rhs, group_sizes): - """Transposed grouped matrix multiplication. 
+gmm.defvjp(_gmm_fwd, _gmm_bwd) + + +def tgmm( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, +) -> jax.Array: + """Transposed grouped matrix multiplication. JIT-compilable, runs on TPU. - Computes per-group transposed matmul: for each group i, + For each group i with rows [start_i, end_i): out[i] = lhs[start_i:end_i, :]^T @ rhs[start_i:end_i, :] + Uses bf16 multiplication with f32 accumulation. + Args: - lhs: [m, k] input activations in bfloat16. - rhs: [m, n] second operand in bfloat16. + lhs: [m, k] input activations. + rhs: [m, n] second operand. group_sizes: [num_groups] int32, number of rows per group. Returns: From 6d312fe03d10b7ab94341c7636044cc4c28e9245 Mon Sep 17 00:00:00 2001 From: sii-xinglong <253108540219@sii.edu.cn> Date: Mon, 6 Apr 2026 17:34:51 +0800 Subject: [PATCH 15/15] test(gmm): add JAX vs tokamax comparison tests for bf16 GMM Co-Authored-By: Claude Opus 4.6 --- tests/ops/gmm/test_gmm_vs_tokamax.py | 187 +++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 tests/ops/gmm/test_gmm_vs_tokamax.py diff --git a/tests/ops/gmm/test_gmm_vs_tokamax.py b/tests/ops/gmm/test_gmm_vs_tokamax.py new file mode 100644 index 00000000..5f9b1d74 --- /dev/null +++ b/tests/ops/gmm/test_gmm_vs_tokamax.py @@ -0,0 +1,187 @@ +"""Compare JAX GMM against tokamax GMM in bf16 on TPU. + +Tests skip automatically if tokamax is not installed or not on TPU. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tops.ops.gmm import gmm +from tops.cpu.ops.gmm import gmm_ref + +# Skip entire module if tokamax is not installed. +tokamax_kernel = pytest.importorskip( + "tokamax._src.ops.ragged_dot.pallas_mosaic_tpu_kernel" +) + +# Skip if not running on TPU. 
+pytestmark = pytest.mark.skipif( + jax.default_backend() != "tpu", + reason="tokamax comparison requires TPU", +) + + +# --------------------------------------------------------------------------- +# Test cases: TPU-aligned shapes (multiples of 128) +# --------------------------------------------------------------------------- + +CASES = [ + dict(m=128, k=128, n=128, num_groups=1, group_sizes=[128]), + dict(m=256, k=128, n=128, num_groups=4, group_sizes=[64, 64, 64, 64]), + dict(m=512, k=256, n=256, num_groups=8, group_sizes=[64] * 8), + dict(m=384, k=128, n=128, num_groups=4, group_sizes=[128, 64, 128, 64]), +] + + +def _case_id(case): + return f"m{case['m']}_k{case['k']}_n{case['n']}_g{case['num_groups']}" + + +def _make_inputs(case, key=jax.random.PRNGKey(42)): + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16) + rhs = jax.random.normal( + k2, (case["num_groups"], case["k"], case["n"]), dtype=jnp.bfloat16 + ) + gs = jnp.array(case["group_sizes"], dtype=jnp.int32) + return lhs, rhs, gs + + +def _make_inputs_transposed(case, key=jax.random.PRNGKey(42)): + k1, k2 = jax.random.split(key) + lhs = jax.random.normal(k1, (case["m"], case["k"]), dtype=jnp.bfloat16) + # For transpose_rhs, rhs shape is [num_groups, n, k] + rhs = jax.random.normal( + k2, (case["num_groups"], case["n"], case["k"]), dtype=jnp.bfloat16 + ) + gs = jnp.array(case["group_sizes"], dtype=jnp.int32) + return lhs, rhs, gs + + +def _call_tokamax_gmm(lhs, rhs, gs, transpose_rhs=False): + """Call tokamax GMM with the standard calling convention.""" + return tokamax_kernel.gmm( + lhs=lhs, + rhs=rhs, + group_sizes=gs, + precision=jax.lax.Precision.DEFAULT, + out_dtype=jnp.float32, + tiling=(128, 128, 128), + transpose_rhs=transpose_rhs, + interpret=False, + ) + + +# --------------------------------------------------------------------------- +# Forward tests +# --------------------------------------------------------------------------- + + +class 
TestGmmForwardVsTokamax: + """Compare forward output of JAX gmm vs tokamax gmm.""" + + @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES]) + def test_forward_bf16(self, case): + lhs, rhs, gs = _make_inputs(case) + jax_out = gmm(lhs, rhs, gs) + tokamax_out = _call_tokamax_gmm(lhs, rhs, gs) + np.testing.assert_allclose( + np.array(jax_out), + np.array(tokamax_out), + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES]) + def test_forward_transpose_rhs(self, case): + lhs, rhs, gs = _make_inputs_transposed(case) + jax_out = gmm(lhs, rhs, gs, transpose_rhs=True) + tokamax_out = _call_tokamax_gmm(lhs, rhs, gs, transpose_rhs=True) + np.testing.assert_allclose( + np.array(jax_out), + np.array(tokamax_out), + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.parametrize("case", CASES, ids=[_case_id(c) for c in CASES]) + def test_forward_vs_cpu_ref(self, case): + """Both JAX and tokamax should match the CPU reference.""" + lhs, rhs, gs = _make_inputs(case) + jax_out = gmm(lhs, rhs, gs) + tokamax_out = _call_tokamax_gmm(lhs, rhs, gs) + ref_out = gmm_ref(lhs, rhs, gs) + np.testing.assert_allclose( + np.array(jax_out), + np.array(ref_out), + atol=1e-2, + rtol=1e-2, + ) + np.testing.assert_allclose( + np.array(tokamax_out), + np.array(ref_out), + atol=1e-2, + rtol=1e-2, + ) + + +# --------------------------------------------------------------------------- +# Backward tests +# --------------------------------------------------------------------------- + + +class TestGmmBackwardVsTokamax: + """Compare backward gradients of JAX gmm vs tokamax gmm.""" + + @pytest.mark.parametrize("case", CASES[:2], ids=[_case_id(c) for c in CASES[:2]]) + def test_grad_lhs(self, case): + """dlhs should match between JAX and tokamax.""" + lhs, rhs, gs = _make_inputs(case) + + def jax_loss(x): + return gmm(x, rhs, gs).sum() + + def tokamax_loss(x): + return _call_tokamax_gmm(x, rhs, gs).sum() + + jax_grad = jax.grad(jax_loss)(lhs) + 
tokamax_grad = jax.grad(tokamax_loss)(lhs) + np.testing.assert_allclose( + np.array(jax_grad), + np.array(tokamax_grad), + atol=1e-1, + rtol=1e-1, + ) + + @pytest.mark.parametrize("case", CASES[:2], ids=[_case_id(c) for c in CASES[:2]]) + def test_grad_rhs(self, case): + """drhs should match between JAX and tokamax.""" + lhs, rhs, gs = _make_inputs(case) + + def jax_loss(w): + return gmm(lhs, w, gs).sum() + + def tokamax_loss(w): + return _call_tokamax_gmm(lhs, w, gs).sum() + + jax_grad = jax.grad(jax_loss)(rhs) + tokamax_grad = jax.grad(tokamax_loss)(rhs) + np.testing.assert_allclose( + np.array(jax_grad), + np.array(tokamax_grad), + atol=1e-1, + rtol=1e-1, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])