Skip to content

feat: FP8 blockwise-quantized GMM/TGMM Pallas kernels#184

Open
pengchengneo wants to merge 1 commit intomainfrom
feat/gmm-fp8-for-release-v0.3
Open

feat: FP8 blockwise-quantized GMM/TGMM Pallas kernels#184
pengchengneo wants to merge 1 commit intomainfrom
feat/gmm-fp8-for-release-v0.3

Conversation

@pengchengneo
Copy link
Copy Markdown

@pengchengneo pengchengneo commented Apr 9, 2026

Summary

  • Add block-wise FP8 (e4m3fn) quantized GMM and TGMM Pallas TPU kernels with subchannel loop support
  • Implement fp8_quantize / fp8_dual_quantize with per-block absmax calibration
  • High-level gmm() API with custom_vjp for automatic differentiation (forward/dgrad/wgrad)
  • Pure-JAX reference implementations (gmm_fp8_ref / tgmm_fp8_ref) with float64 accumulation for precision testing
  • Comprehensive test suite: 69 precision tests with SNR/cosine/Frobenius/ULP metrics

Changed files

Path Description
tops/ops/gmm/backend.py GMM/TGMM FP8 Pallas kernels + reference impls
tops/ops/gmm/quantize.py Block-wise FP8 quantization (fp8_quantize, fp8_dual_quantize)
tops/ops/gmm/ops.py High-level gmm() API with custom_vjp
tops/ops/gmm/common.py Shared utilities (is_tpu, etc.)
tops/ops/gmm/__init__.py Package exports
tests/ops/gmm/test_gmm_fp8.py GMM FP8 forward precision tests
tests/ops/gmm/test_tgmm_fp8.py TGMM FP8 wgrad precision tests
tests/ops/gmm/test_fp8_precision.py Cross-kernel FP8 precision comparison
tests/ops/gmm/test_quantize.py Quantization unit tests
tests/ops/gmm/test_*.py Additional test modules

Downstream usage

ant-pretrain feat/fp8-gmm branch depends on this via:

uv add git+https://github.com/primatrix/pallas-kernel.git@feat/gmm-fp8-for-release-v0.3

Unit tests passing on TPU (20/20): gmm_fp8 fwd, gmm_fp8 dgrad, tgmm_fp8 wgrad with ALModel production dimensions (M=8192, G=32).

Test plan

  • Downstream ant-pretrain CI unit tests pass (p99.9 ULP ≤ 256, cosine ≥ 0.99)
  • Run pallas-kernel internal test suite on TPU

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added grouped matrix multiplication (GMM) capabilities with support for standard and transposed variants
    • Added FP8 quantization support for memory-efficient computation
    • Added custom backward pass implementations with optional quantized gradients
    • Added configurable tiling and block-size parameters for performance optimization
  • Tests

    • Comprehensive test coverage for all GMM and FP8 operations
    • Precision validation tests comparing quantized vs non-quantized results
    • Gradient correctness verification and VJP backward pass testing

Add complete FP8 grouped matrix multiply implementation:
- Block-wise FP8 e4m3fn quantization with per-block absmax calibration
- GMM and TGMM Pallas TPU kernels with subchannel loop support
- High-level gmm() API with custom_vjp for automatic differentiation
- Default block_size=512 for optimal speed/precision tradeoff (1.35-1.53x vs BF16)
- Optimized fp8_dual_quantize sharing single f32 cast across row/column passes
- Comprehensive test suite: 69 precision tests with SNR/cosine/Frobenius metrics

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

Walkthrough

Added a complete grouped matrix multiplication (GMM) module to tops.ops.gmm with core implementations of standard and FP8-quantized kernels, reference implementations, high-level differentiable API with custom VJP, quantization utilities, and comprehensive test coverage across functionality, precision, and gradient correctness.

Changes

Cohort / File(s) Summary
Core GMM Infrastructure
tops/ops/gmm/__init__.py, tops/ops/gmm/common.py, tops/ops/gmm/quantize.py
Package API exports for GMM kernels and utilities; TPU detection and matmul dtype selection (is_tpu, tpu_generation, supports_bfloat16_matmul, select_input_dtype); FP8 block-wise quantization/dequantization with per-block scale calibration (fp8_quantize, fp8_dequantize, fp8_dual_quantize) using float8_e4m3fn format.
Backend Kernels
tops/ops/gmm/backend.py
Pallas TPU kernels for standard (gmm, tgmm) and FP8-quantized (gmm_fp8, tgmm_fp8) grouped matmul with optional transposition, group offsetting, and existing output accumulation; pure-JAX reference implementations (*_ref variants); group metadata construction for CSR-style tile/group mapping with optional empty-group handling.
High-Level Differentiable API
tops/ops/gmm/ops.py
Differentiable gmm() function with custom VJP supporting optional FP8 quantization, group offsetting, and backward quantization control; auto-selects interpret mode based on device; forward path dispatches to quantized/non-quantized backend kernels; backward computes dlhs/drhs via FP8 or BF16 paths depending on bwd_quantize setting.
Reference and Utility Tests
tests/ops/gmm/test_common.py, tests/ops/gmm/test_gmm_ref.py, tests/ops/gmm/test_quantize.py
TPU detection utilities validation; make_group_metadata CSR mapping and empty-group handling; gmm_ref/tgmm_ref correctness against NumPy baselines; FP8 quantization roundtrip fidelity, scale correctness, dtype/shape behavior, and edge cases (all-zero inputs, axis misalignment).
Kernel Forward Tests
tests/ops/gmm/test_gmm.py, tests/ops/gmm/test_tgmm.py
Standard GMM/TGMM kernel validation against reference across parameterized shapes/groups; transpose support, output shape/dtype correctness, dtype promotion behavior, existing output accumulation, callable vs static tiling dispatch; error handling for invalid ranks/shapes and non-integer group sizes; group offsetting with zero-padding of inactive rows.
FP8 Kernel Tests
tests/ops/gmm/test_gmm_fp8.py, tests/ops/gmm/test_tgmm_fp8.py
FP8-quantized GMM/TGMM kernel correctness vs reference and BF16 baseline across block sizes; subchannel coverage (tk/tm > block_size); finite output smoke tests; dtype/shape correctness; quantization/alignment constraint validation.
High-Level API and VJP Tests
tests/ops/gmm/test_gmm_vjp.py
Differentiable gmm() API forward correctness, VJP shape/dtype checks, numerical gradient verification via finite differences; group offsetting behavior; FP8 forward/backward smoke tests; bwd_quantize auto-follow, explicit control, and error cases; FP8 backward gradient closeness to BF16.
Comprehensive Precision Tests
tests/ops/gmm/test_fp8_precision.py
Multi-faceted FP8 precision validation including kernel vs BF16/FP8-ref comparisons, subchannel accumulation-order analysis, high-level API precision, VJP gradient finiteness and quality thresholds with forward/backward metric tables; deterministic test generators and shared assertion helpers.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

  • primatrix/ant-pretrain#207: Directly related—both implement FP8 per-block quantization for GMM, adding block-wise quantize/dequantize utilities and FP8 gmm/gmm_fp8 kernel paths with forward/backward support and comprehensive tests.

Suggested reviewers

  • labyrinth-ssr

Poem

🐰 Hops of joy for GMM so grand,
FP8 quantized across the land,
VJPs dance with backward grace,
TPU kernels at their pace,
Tests that measure every trace!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 67.83% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat: FP8 blockwise-quantized GMM/TGMM Pallas kernels' directly and clearly describes the primary feature added: FP8 blockwise-quantized kernels for grouped matrix multiplication operations.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/gmm-fp8-for-release-v0.3

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (13)
tests/ops/gmm/test_tgmm_fp8.py (1)

189-205: Move TestTgmmFp8Subchannel class before if __name__ block.

The TestTgmmFp8Subchannel class is defined after the if __name__ == "__main__" block (lines 189-190), which is unconventional. While pytest will still discover it, the standard practice is to place all test classes before the main block.

♻️ Reorganize file structure

Move lines 193-263 (the subchannel test section) to before line 189 (if __name__).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_tgmm_fp8.py` around lines 189 - 205, The
TestTgmmFp8Subchannel test class (and its SUBCHANNEL_CASES data) is defined
after the if __name__ == "__main__" block; move the entire subchannel test
section (SUBCHANNEL_CASES and class TestTgmmFp8Subchannel) to appear before the
if __name__ == "__main__" guard so all tests are declared prior to the script
entry point; ensure you relocate the block that defines SUBCHANNEL_CASES and the
TestTgmmFp8Subchannel class together to preserve references and imports.
tests/ops/gmm/test_common.py (1)

15-25: Consider adding a test for supports_bfloat16_matmul().

The function supports_bfloat16_matmul() is only indirectly tested through select_input_dtype. A direct test would improve coverage and document expected behavior.

🧪 Suggested test addition
def test_supports_bfloat16_matmul_returns_bool():
  from tops.ops.gmm.common import supports_bfloat16_matmul
  result = supports_bfloat16_matmul()
  assert isinstance(result, bool)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_common.py` around lines 15 - 25, Add a direct unit test
for supports_bfloat16_matmul() in tests/ops/gmm/test_common.py: import
supports_bfloat16_matmul from tops.ops.gmm.common, call it, and assert the
return is a bool (e.g., assert isinstance(result, bool)) so the function is
explicitly covered and documented; follow the existing test style used for
is_tpu() and tpu_generation().
tops/ops/gmm/common.py (2)

13-16: Consider narrowing the exception type.

The broad Exception catch (flagged by Ruff BLE001) is intentional here for resilience, but could mask unexpected errors. Consider catching more specific exceptions like IndexError (no devices) or RuntimeError (JAX initialization issues).

🔧 Proposed narrower exception handling
 def is_tpu() -> bool:
   """Check if the current default device is a TPU."""
   try:
     return "TPU" in jax.devices()[0].device_kind
-  except Exception:
+  except (IndexError, RuntimeError):
     return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/common.py` around lines 13 - 16, The current broad except in the
TPU-detection block should be narrowed to avoid masking unexpected errors:
replace the blanket `except Exception:` around the `return "TPU" in
jax.devices()[0].device_kind` check with specific exceptions such as
`IndexError` (no devices) and `RuntimeError` (JAX init issues) so only those are
swallowed and other errors propagate; update the exception tuple around the
jax.devices()[0].device_kind access accordingly.

19-21: Add error handling or document precondition for tpu_kind().

Unlike is_tpu(), this function doesn't handle the case when no devices are available. If called without first checking is_tpu(), it will raise an IndexError.

🛡️ Option: Add precondition documentation
 def tpu_kind() -> str:
-  """Return the TPU device kind string (e.g., 'TPU v5 lite')."""
+  """Return the TPU device kind string (e.g., 'TPU v5 lite').
+
+  Raises:
+    IndexError: If no JAX devices are available.
+  """
   return jax.devices()[0].device_kind
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/common.py` around lines 19 - 21, The tpu_kind() function can
raise IndexError when jax.devices() is empty; update tpu_kind() to check for
available devices (e.g., if not jax.devices():) and raise a clear, documented
exception (RuntimeError or ValueError) with guidance to call is_tpu() first, or
alternatively add a short docstring note that callers must ensure devices exist
(mention is_tpu()); reference the tpu_kind() function and the is_tpu() helper so
the change is easy to locate.
tests/ops/gmm/test_gmm_ref.py (3)

43-58: Prefix unused unpacked variables with underscore.

The static analysis correctly identifies that group_ids and m_tile_ids are unused in this test. Using underscore prefix documents the intentional discard.

♻️ Proposed fix
   def test_uneven_groups(self):
     """Groups of different sizes with partial tiles."""
     group_sizes = jnp.array([96, 32], dtype=jnp.int32)
-    (group_offsets, group_ids, m_tile_ids), num_tiles = make_group_metadata(
+    (group_offsets, _group_ids, _m_tile_ids), num_tiles = make_group_metadata(
       group_sizes=group_sizes,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_ref.py` around lines 43 - 58, The test_uneven_groups
function unpacks make_group_metadata into group_offsets, group_ids, m_tile_ids
but never uses group_ids or m_tile_ids; update the unpacking to explicitly
discard them by prefixing with underscores (e.g., change the tuple to
(group_offsets, _group_ids, _m_tile_ids), num_tiles = make_group_metadata(...))
so static analysis knows these values are intentionally unused while leaving
num_tiles and group_offsets unchanged.

87-114: Prefix unused unpacked variables with underscore.

Same issue at lines 90 and 105 - group_ids is unpacked but not used.

♻️ Proposed fix
   def test_visit_empty_groups(self):
     """Empty groups get one tile when visit_empty_groups=True."""
     group_sizes = jnp.array([64, 0, 64], dtype=jnp.int32)
-    (_, group_ids, _), num_tiles = make_group_metadata(
+    (_, _group_ids, _), num_tiles = make_group_metadata(
       ...
     )
     ...

   def test_no_visit_empty_groups(self):
     """Empty groups are skipped when visit_empty_groups=False."""
     group_sizes = jnp.array([64, 0, 64], dtype=jnp.int32)
-    (_, group_ids, _), num_tiles = make_group_metadata(
+    (_, _group_ids, _), num_tiles = make_group_metadata(
       ...
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_ref.py` around lines 87 - 114, The tests unpack an
unused variable named group_ids from the return of make_group_metadata in
test_visit_empty_groups and test_no_visit_empty_groups; update those unpacking
sites to prefix unused variables with an underscore (e.g., rename group_ids to
_group_ids or replace with a single leading-underscore placeholder) so the
linter/no-unused-variable warnings are resolved while keeping the call to
make_group_metadata and using num_tiles as before.

187-196: Consider using compare_tensor utility as per coding guidelines.

The coding guidelines specify using compare_tensor utility from tests/utils.py with appropriate tolerance parameters. While np.testing.assert_allclose works, using the project's standard utility ensures consistency.

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_ref.py` around lines 187 - 196, Replace the direct
NumPy assertion in test_gmm_ref_matches_manual with the project's compare_tensor
utility to follow guidelines: import and call compare_tensor(expected, result)
(or compare_tensor(np.array(result), expected) to match shapes) instead of
np.testing.assert_allclose, and pass the same tolerances used previously by
specifying atol=1e-5, rtol=1e-5 and a sensible max_ulp (e.g., max_ulp=4) so
gmm_ref(lhs, rhs, group_sizes) is compared to _manual_gmm(lhs, rhs, group_sizes)
using compare_tensor from tests/utils.py.
tests/ops/gmm/test_gmm_vjp.py (2)

160-164: Unused variable k2 from random key split.

The static analysis tool flagged that k2 is unpacked but never used. Since you're using np.random.RandomState for rhs_full instead, either use k2 consistently or use underscore prefix.

♻️ Prefix unused variable
     key = jax.random.PRNGKey(42)
-    k1, k2 = jax.random.split(key)
+    k1, _k2 = jax.random.split(key)
     lhs = jax.random.normal(k1, (M, K), dtype=jnp.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vjp.py` around lines 160 - 164, The variable k2
returned from jax.random.split(key) is unused which triggers a linter warning;
either consume k2 when generating rhs_full with JAX (e.g., use
jax.random.normal(k2, ...)) or mark it as intentionally unused by renaming to _
(replace "k1, k2 = jax.random.split(key)" with "k1, _ = jax.random.split(key)").
Update the code around the jax.random.split call and the rhs_full generation
(currently using np.random.RandomState) so the chosen approach is consistent
with how rhs_full is produced.

308-314: Add strict=True to zip() call.

The static analysis flagged this zip() lacks an explicit strict= parameter. Since grads_bf16, grads_fp8, and the names list should always have exactly 2 elements each, adding strict=True helps catch length mismatches.

♻️ Add strict parameter
-    for g_bf16, g_fp8, name in zip(grads_bf16, grads_fp8, ["dlhs", "drhs"]):
+    for g_bf16, g_fp8, name in zip(grads_bf16, grads_fp8, ["dlhs", "drhs"], strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_vjp.py` around lines 308 - 314, The zip over
grads_bf16, grads_fp8 and the literal names list should be strict to catch
length mismatches; update the zip(...) invocation in the loop that iterates over
grads_bf16 and grads_fp8 (the line using zip(grads_bf16, grads_fp8, ["dlhs",
"drhs"])) to pass strict=True so the loop raises if the iterables differ in
length, keeping the rest of the loop logic unchanged.
tests/ops/gmm/test_gmm_fp8.py (2)

191-194: Redundant import of fp8_quantize.

The fp8_quantize function is already imported at line 11. This inner import is unnecessary.

♻️ Remove redundant import
   def test_gmm_fp8_k_not_divisible(self):
     """K not divisible by block_size raises in fp8_quantize."""
     rng = np.random.RandomState(0)
     lhs = jnp.array(rng.randn(128, 192).astype(np.float32))
     with pytest.raises(AssertionError, match="divisible"):
-      from tops.ops.gmm.quantize import fp8_quantize
-
       fp8_quantize(lhs, axis=-1, block_size=128)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_fp8.py` around lines 191 - 194, Remove the redundant
inner import of fp8_quantize inside the pytest.raises block in
tests/ops/gmm/test_gmm_fp8.py; the symbol fp8_quantize is already imported at
the top of the file (line 11), so delete the line "from tops.ops.gmm.quantize
import fp8_quantize" within the with pytest.raises(...) block and ensure the
test uses the existing fp8_quantize reference.

87-92: Consider using compare_tensor utility per coding guidelines.

The test uses np.testing.assert_allclose directly, but the coding guidelines recommend using the compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations.

♻️ Suggested refactor
+from tests.utils import compare_tensor
+
 # ... in test method ...
-    np.testing.assert_allclose(
-      np.array(result),
-      np.array(expected),
-      rtol=RTOL,
-      atol=ATOL,
-    )
+    compare_tensor(result, expected, rtol=RTOL, atol=ATOL)

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_gmm_fp8.py` around lines 87 - 92, Replace the direct NumPy
comparison with the testing utility: instead of calling
np.testing.assert_allclose on the arrays `result` and `expected`, use the
`compare_tensor` helper from `tests/utils.py` to compare the kernel output; pass
`atol=ATOL`, `rtol=RTOL`, and an appropriate `max_ulp` value (e.g., 1 or a small
integer) to the call so `result` and `expected` are compared per project
guidelines. Ensure you import `compare_tensor` if not already imported and call
it with the same array inputs (converted if necessary) and the tolerance
variables used in this test.
tests/ops/gmm/test_tgmm.py (1)

75-80: Consider using compare_tensor utility per coding guidelines.

Similar to other test files in this PR, the tests use np.testing.assert_allclose directly rather than the compare_tensor utility recommended by the coding guidelines.

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/gmm/test_tgmm.py` around lines 75 - 80, The test currently uses
np.testing.assert_allclose to compare arrays; replace that with the project's
compare_tensor utility from tests/utils.py: call compare_tensor(result,
expected, atol=ATOL, rtol=RTOL, max_ulp=<appropriate value>) (or omit max_ulp if
not required) so the test uses the standard comparison helper; update the import
if missing and ensure you pass the same RTOL/ATOL variables and the expected
tensor converted to the same dtype/shape as result before calling
compare_tensor.
tops/ops/gmm/backend.py (1)

130-139: Histogram-based tile visit counting.

The use of jnp.histogram for counting tile visits is clever but may have edge cases with floating-point bin boundaries for very large tile counts. The current implementation works well for typical sizes but consider documenting this limitation.

The histogram approach is acceptable since tiles_m is typically small (M / tm where tm >= 128), but an alternative using jnp.bincount could be more explicit:

# Alternative (if needed for very large grids):
# tile_visits = jnp.bincount(partial_tile_ids, length=tiles_m) + 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/backend.py` around lines 130 - 139, The histogram-based counting
using jnp.histogram for computing tile_visits from partial_tile_ids can suffer
floating-point/bin-boundary edge cases for very large tiles_m; update the
implementation to use jnp.bincount(partial_tile_ids, length=tiles_m) + 1 instead
of jnp.histogram (or at minimum add a comment documenting the histogram
limitation), and ensure m_tile_ids (built from jnp.arange and tile_visits) still
uses the updated tile_visits shape/dtype (cast to jnp.int32) so that jnp.repeat
call remains correct.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tops/ops/gmm/ops.py`:
- Line 56: The docstring in tops.ops.gmm.ops wrongly states "block_size:
Quantization block size (default 128)" while the function/module actually
defines block_size: int = 512; update the docstring to reflect the true default
(change the docstring default to 512) for the block_size parameter in this
module (tops.ops.gmm.ops) so the documentation matches the function signature.

---

Nitpick comments:
In `@tests/ops/gmm/test_common.py`:
- Around line 15-25: Add a direct unit test for supports_bfloat16_matmul() in
tests/ops/gmm/test_common.py: import supports_bfloat16_matmul from
tops.ops.gmm.common, call it, and assert the return is a bool (e.g., assert
isinstance(result, bool)) so the function is explicitly covered and documented;
follow the existing test style used for is_tpu() and tpu_generation().

In `@tests/ops/gmm/test_gmm_fp8.py`:
- Around line 191-194: Remove the redundant inner import of fp8_quantize inside
the pytest.raises block in tests/ops/gmm/test_gmm_fp8.py; the symbol
fp8_quantize is already imported at the top of the file (line 11), so delete the
line "from tops.ops.gmm.quantize import fp8_quantize" within the with
pytest.raises(...) block and ensure the test uses the existing fp8_quantize
reference.
- Around line 87-92: Replace the direct NumPy comparison with the testing
utility: instead of calling np.testing.assert_allclose on the arrays `result`
and `expected`, use the `compare_tensor` helper from `tests/utils.py` to compare
the kernel output; pass `atol=ATOL`, `rtol=RTOL`, and an appropriate `max_ulp`
value (e.g., 1 or a small integer) to the call so `result` and `expected` are
compared per project guidelines. Ensure you import `compare_tensor` if not
already imported and call it with the same array inputs (converted if necessary)
and the tolerance variables used in this test.

In `@tests/ops/gmm/test_gmm_ref.py`:
- Around line 43-58: The test_uneven_groups function unpacks make_group_metadata
into group_offsets, group_ids, m_tile_ids but never uses group_ids or
m_tile_ids; update the unpacking to explicitly discard them by prefixing with
underscores (e.g., change the tuple to (group_offsets, _group_ids, _m_tile_ids),
num_tiles = make_group_metadata(...)) so static analysis knows these values are
intentionally unused while leaving num_tiles and group_offsets unchanged.
- Around line 87-114: The tests unpack an unused variable named group_ids from
the return of make_group_metadata in test_visit_empty_groups and
test_no_visit_empty_groups; update those unpacking sites to prefix unused
variables with an underscore (e.g., rename group_ids to _group_ids or replace
with a single leading-underscore placeholder) so the linter/no-unused-variable
warnings are resolved while keeping the call to make_group_metadata and using
num_tiles as before.
- Around line 187-196: Replace the direct NumPy assertion in
test_gmm_ref_matches_manual with the project's compare_tensor utility to follow
guidelines: import and call compare_tensor(expected, result) (or
compare_tensor(np.array(result), expected) to match shapes) instead of
np.testing.assert_allclose, and pass the same tolerances used previously by
specifying atol=1e-5, rtol=1e-5 and a sensible max_ulp (e.g., max_ulp=4) so
gmm_ref(lhs, rhs, group_sizes) is compared to _manual_gmm(lhs, rhs, group_sizes)
using compare_tensor from tests/utils.py.

In `@tests/ops/gmm/test_gmm_vjp.py`:
- Around line 160-164: The variable k2 returned from jax.random.split(key) is
unused which triggers a linter warning; either consume k2 when generating
rhs_full with JAX (e.g., use jax.random.normal(k2, ...)) or mark it as
intentionally unused by renaming to _ (replace "k1, k2 = jax.random.split(key)"
with "k1, _ = jax.random.split(key)"). Update the code around the
jax.random.split call and the rhs_full generation (currently using
np.random.RandomState) so the chosen approach is consistent with how rhs_full is
produced.
- Around line 308-314: The zip over grads_bf16, grads_fp8 and the literal names
list should be strict to catch length mismatches; update the zip(...) invocation
in the loop that iterates over grads_bf16 and grads_fp8 (the line using
zip(grads_bf16, grads_fp8, ["dlhs", "drhs"])) to pass strict=True so the loop
raises if the iterables differ in length, keeping the rest of the loop logic
unchanged.

In `@tests/ops/gmm/test_tgmm_fp8.py`:
- Around line 189-205: The TestTgmmFp8Subchannel test class (and its
SUBCHANNEL_CASES data) is defined after the if __name__ == "__main__" block;
move the entire subchannel test section (SUBCHANNEL_CASES and class
TestTgmmFp8Subchannel) to appear before the if __name__ == "__main__" guard so
all tests are declared prior to the script entry point; ensure you relocate the
block that defines SUBCHANNEL_CASES and the TestTgmmFp8Subchannel class together
to preserve references and imports.

In `@tests/ops/gmm/test_tgmm.py`:
- Around line 75-80: The test currently uses np.testing.assert_allclose to
compare arrays; replace that with the project's compare_tensor utility from
tests/utils.py: call compare_tensor(result, expected, atol=ATOL, rtol=RTOL,
max_ulp=<appropriate value>) (or omit max_ulp if not required) so the test uses
the standard comparison helper; update the import if missing and ensure you pass
the same RTOL/ATOL variables and the expected tensor converted to the same
dtype/shape as result before calling compare_tensor.

In `@tops/ops/gmm/backend.py`:
- Around line 130-139: The histogram-based counting using jnp.histogram for
computing tile_visits from partial_tile_ids can suffer
floating-point/bin-boundary edge cases for very large tiles_m; update the
implementation to use jnp.bincount(partial_tile_ids, length=tiles_m) + 1 instead
of jnp.histogram (or at minimum add a comment documenting the histogram
limitation), and ensure m_tile_ids (built from jnp.arange and tile_visits) still
uses the updated tile_visits shape/dtype (cast to jnp.int32) so that jnp.repeat
call remains correct.

In `@tops/ops/gmm/common.py`:
- Around line 13-16: The current broad except in the TPU-detection block should
be narrowed to avoid masking unexpected errors: replace the blanket `except
Exception:` around the `return "TPU" in jax.devices()[0].device_kind` check with
specific exceptions such as `IndexError` (no devices) and `RuntimeError` (JAX
init issues) so only those are swallowed and other errors propagate; update the
exception tuple around the jax.devices()[0].device_kind access accordingly.
- Around line 19-21: The tpu_kind() function can raise IndexError when
jax.devices() is empty; update tpu_kind() to check for available devices (e.g.,
if not jax.devices():) and raise a clear, documented exception (RuntimeError or
ValueError) with guidance to call is_tpu() first, or alternatively add a short
docstring note that callers must ensure devices exist (mention is_tpu());
reference the tpu_kind() function and the is_tpu() helper so the change is easy
to locate.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 59705eef-6096-4154-9ede-6ef5e116367a

📥 Commits

Reviewing files that changed from the base of the PR and between bfa3564 and 55f7c9f.

📒 Files selected for processing (15)
  • tests/ops/gmm/__init__.py
  • tests/ops/gmm/test_common.py
  • tests/ops/gmm/test_fp8_precision.py
  • tests/ops/gmm/test_gmm.py
  • tests/ops/gmm/test_gmm_fp8.py
  • tests/ops/gmm/test_gmm_ref.py
  • tests/ops/gmm/test_gmm_vjp.py
  • tests/ops/gmm/test_quantize.py
  • tests/ops/gmm/test_tgmm.py
  • tests/ops/gmm/test_tgmm_fp8.py
  • tops/ops/gmm/__init__.py
  • tops/ops/gmm/backend.py
  • tops/ops/gmm/common.py
  • tops/ops/gmm/ops.py
  • tops/ops/gmm/quantize.py

``None`` (default) follows ``quantize``. ``True`` uses FP8 kernels
in the backward pass (requires ``quantize=True``). ``False`` with
``quantize=True`` dequantizes residuals and uses BF16 backward kernels.
block_size: Quantization block size (default 128).
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Docstring default value mismatch for block_size.

The docstring states "block_size: Quantization block size (default 128)" but the actual default value on line 29 is block_size: int = 512.

📝 Fix docstring
-    block_size: Quantization block size (default 128).
+    block_size: Quantization block size (default 512).
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
block_size: Quantization block size (default 128).
block_size: Quantization block size (default 512).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tops/ops/gmm/ops.py` at line 56, The docstring in tops.ops.gmm.ops wrongly
states "block_size: Quantization block size (default 128)" while the
function/module actually defines block_size: int = 512; update the docstring to
reflect the true default (change the docstring default to 512) for the
block_size parameter in this module (tops.ops.gmm.ops) so the documentation
matches the function signature.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces comprehensive FP8 precision validation tests for GMM/TGMM kernels, covering various block sizes, forward and backward passes, and high-level API usage. It also adds new tests for GMM and TGMM kernels, FP8 GMM/TGMM, and FP8 quantization utilities, along with VJP correctness tests for the high-level GMM API. The core gmm and tgmm functions have been updated to support FP8 quantization and subchannel loops. Feedback includes addressing an inconsistency in the default block_size value in tops/ops/gmm/ops.py and refactoring fp8_dual_quantize in tops/ops/gmm/quantize.py to reduce code duplication.

transpose_rhs: bool = False,
quantize: bool = False,
bwd_quantize: bool | None = None,
block_size: int = 512,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's an inconsistency in the default value for block_size. Here, the default is 512, but the docstring on line 56 states it is 128. Several test helpers also use 128 as a default. The default value is 512 across other implementation files (quantize.py, backend.py). To ensure consistency, please align the default value in the signature with the documentation and intended usage. 128 seems to be a more common value in the tests.

Suggested change
block_size: int = 512,
block_size: int = 128,

Comment on lines +110 to +174
def fp8_dual_quantize(
x: jnp.ndarray,
block_size: int = 512,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Quantize *x* along both rows (axis=-1) and columns (axis=-2).

This is the TE-style dual quantization used for FP8 GMM: the forward
pass uses row-quantized operands while the backward pass uses
column-quantized operands for the transposed matmul.

Optimized to share the f32 cast across both orientations (avoiding
a redundant full-tensor type conversion).

Args:
x: [..., M, K] -- input tensor (at least 2-D).
block_size: Block size for both row and column quantization.

Returns:
(row_q, row_scale, col_q, col_scale) where
row_q, row_scale: quantized along axis=-1 (row-wise).
col_q, col_scale: quantized along axis=-2 (column-wise).
"""
ndim = x.ndim
assert ndim >= 2, f"Need at least 2D input, got {ndim}D"

eps = jnp.finfo(jnp.float32).eps

# Single f32 cast shared by both row and column quantization.
x_f32 = x.astype(jnp.float32)

# --- Row quantization (axis=-1) ---
ax_row = (ndim - 1) % ndim
dim_row = x.shape[ax_row]
assert dim_row % block_size == 0, (
f"Axis -1 size {dim_row} must be divisible by block_size {block_size}"
)
num_blocks_row = dim_row // block_size
shape_row = (*x.shape[:ax_row], num_blocks_row, block_size, *x.shape[ax_row + 1:])
blocks_row = jnp.reshape(x_f32, shape_row)
bax_row = ax_row + 1
absmax_row = jnp.max(jnp.abs(blocks_row), axis=bax_row, keepdims=True)
absmax_row = jnp.maximum(absmax_row, eps)
row_scale = (absmax_row / FP8_MAX).astype(jnp.float32)
row_q = jnp.clip(blocks_row / row_scale, -FP8_MAX, FP8_MAX).astype(jnp.float8_e4m3fn)
row_q = jnp.reshape(row_q, x.shape)
row_scale = jnp.squeeze(row_scale, axis=bax_row)

# --- Column quantization (axis=-2) ---
ax_col = (ndim - 2) % ndim
dim_col = x.shape[ax_col]
assert dim_col % block_size == 0, (
f"Axis -2 size {dim_col} must be divisible by block_size {block_size}"
)
num_blocks_col = dim_col // block_size
shape_col = (*x.shape[:ax_col], num_blocks_col, block_size, *x.shape[ax_col + 1:])
blocks_col = jnp.reshape(x_f32, shape_col)
bax_col = ax_col + 1
absmax_col = jnp.max(jnp.abs(blocks_col), axis=bax_col, keepdims=True)
absmax_col = jnp.maximum(absmax_col, eps)
col_scale = (absmax_col / FP8_MAX).astype(jnp.float32)
col_q = jnp.clip(blocks_col / col_scale, -FP8_MAX, FP8_MAX).astype(jnp.float8_e4m3fn)
col_q = jnp.reshape(col_q, x.shape)
col_scale = jnp.squeeze(col_scale, axis=bax_col)

return row_q, row_scale, col_q, col_scale
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function fp8_dual_quantize duplicates a significant amount of logic from fp8_quantize. To improve maintainability and reduce code duplication, you could refactor it to call fp8_quantize directly for each axis. While this removes the optimization of sharing the float32 cast, it greatly simplifies the code, and the performance impact of the extra cast is likely negligible.

def fp8_dual_quantize(
  x: jnp.ndarray,
  block_size: int = 512,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Quantize *x* along both rows (axis=-1) and columns (axis=-2).

  This is the TE-style dual quantization used for FP8 GMM: the forward
  pass uses row-quantized operands while the backward pass uses
  column-quantized operands for the transposed matmul.

  Args:
    x: [..., M, K] -- input tensor (at least 2-D).
    block_size: Block size for both row and column quantization.

  Returns:
    (row_q, row_scale, col_q, col_scale) where
      row_q, row_scale: quantized along axis=-1 (row-wise).
      col_q, col_scale: quantized along axis=-2 (column-wise).
  """
  assert x.ndim >= 2, f"Need at least 2D input, got {x.ndim}D"

  row_q, row_scale = fp8_quantize(x, axis=-1, block_size=block_size)
  col_q, col_scale = fp8_quantize(x, axis=-2, block_size=block_size)

  return row_q, row_scale, col_q, col_scale

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant