feat: FP8 blockwise-quantized GMM/TGMM Pallas kernels (#184)
pengchengneo wants to merge 1 commit into main
Conversation
Add complete FP8 grouped matrix multiply implementation:

- Block-wise FP8 e4m3fn quantization with per-block absmax calibration
- GMM and TGMM Pallas TPU kernels with subchannel loop support
- High-level gmm() API with custom_vjp for automatic differentiation
- Default block_size=512 for optimal speed/precision tradeoff (1.35-1.53x vs BF16)
- Optimized fp8_dual_quantize sharing a single f32 cast across row/column passes
- Comprehensive test suite: 69 precision tests with SNR/cosine/Frobenius metrics

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
📝 Walkthrough

Added a complete grouped matrix multiplication (GMM) module.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (13)
tests/ops/gmm/test_tgmm_fp8.py (1)
189-205: Move the TestTgmmFp8Subchannel class before the if __name__ block.

The TestTgmmFp8Subchannel class is defined after the if __name__ == "__main__" block (lines 189-190), which is unconventional. While pytest will still discover it, standard practice is to place all test classes before the main block.

♻️ Reorganize file structure

Move lines 193-263 (the subchannel test section) to before line 189 (if __name__).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_tgmm_fp8.py` around lines 189-205: the TestTgmmFp8Subchannel test class (and its SUBCHANNEL_CASES data) is defined after the if __name__ == "__main__" block; move the entire subchannel test section (SUBCHANNEL_CASES and class TestTgmmFp8Subchannel) before the if __name__ == "__main__" guard so all tests are declared prior to the script entry point. Relocate the SUBCHANNEL_CASES definition and the TestTgmmFp8Subchannel class together to preserve references and imports.

tests/ops/gmm/test_common.py (1)
15-25: Consider adding a test for supports_bfloat16_matmul().

The function supports_bfloat16_matmul() is only indirectly tested through select_input_dtype. A direct test would improve coverage and document expected behavior.

🧪 Suggested test addition

```python
def test_supports_bfloat16_matmul_returns_bool():
    from tops.ops.gmm.common import supports_bfloat16_matmul

    result = supports_bfloat16_matmul()
    assert isinstance(result, bool)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_common.py` around lines 15-25: add a direct unit test for supports_bfloat16_matmul(): import it from tops.ops.gmm.common, call it, and assert the return is a bool (e.g., assert isinstance(result, bool)) so the function is explicitly covered and documented; follow the existing test style used for is_tpu() and tpu_generation().

tops/ops/gmm/common.py (2)
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_common.py` around lines 15 - 25, Add a direct unit test for supports_bfloat16_matmul() in tests/ops/gmm/test_common.py: import supports_bfloat16_matmul from tops.ops.gmm.common, call it, and assert the return is a bool (e.g., assert isinstance(result, bool)) so the function is explicitly covered and documented; follow the existing test style used for is_tpu() and tpu_generation().tops/ops/gmm/common.py (2)
13-16: Consider narrowing the exception type.The broad
Exceptioncatch (flagged by Ruff BLE001) is intentional here for resilience, but could mask unexpected errors. Consider catching more specific exceptions likeIndexError(no devices) orRuntimeError(JAX initialization issues).🔧 Proposed narrower exception handling
def is_tpu() -> bool: """Check if the current default device is a TPU.""" try: return "TPU" in jax.devices()[0].device_kind - except Exception: + except (IndexError, RuntimeError): return False🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gmm/common.py` around lines 13 - 16, The current broad except in the TPU-detection block should be narrowed to avoid masking unexpected errors: replace the blanket `except Exception:` around the `return "TPU" in jax.devices()[0].device_kind` check with specific exceptions such as `IndexError` (no devices) and `RuntimeError` (JAX init issues) so only those are swallowed and other errors propagate; update the exception tuple around the jax.devices()[0].device_kind access accordingly.
19-21: Add error handling or document the precondition for tpu_kind().

Unlike is_tpu(), this function doesn't handle the case when no devices are available. If called without first checking is_tpu(), it will raise an IndexError.

🛡️ Option: Add precondition documentation

```diff
 def tpu_kind() -> str:
-    """Return the TPU device kind string (e.g., 'TPU v5 lite')."""
+    """Return the TPU device kind string (e.g., 'TPU v5 lite').
+
+    Raises:
+        IndexError: If no JAX devices are available.
+    """
     return jax.devices()[0].device_kind
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gmm/common.py` around lines 19-21: tpu_kind() can raise IndexError when jax.devices() is empty; either check for available devices (e.g., `if not jax.devices():`) and raise a clear, documented exception (RuntimeError or ValueError) with guidance to call is_tpu() first, or add a short docstring note that callers must ensure devices exist (mentioning is_tpu()).

tests/ops/gmm/test_gmm_ref.py (3)
43-58: Prefix unused unpacked variables with an underscore.

The static analysis correctly identifies that group_ids and m_tile_ids are unused in this test. An underscore prefix documents the intentional discard.

♻️ Proposed fix

```diff
 def test_uneven_groups(self):
     """Groups of different sizes with partial tiles."""
     group_sizes = jnp.array([96, 32], dtype=jnp.int32)
-    (group_offsets, group_ids, m_tile_ids), num_tiles = make_group_metadata(
+    (group_offsets, _group_ids, _m_tile_ids), num_tiles = make_group_metadata(
         group_sizes=group_sizes,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_ref.py` around lines 43-58: test_uneven_groups unpacks make_group_metadata into group_offsets, group_ids, m_tile_ids but never uses group_ids or m_tile_ids; update the unpacking to explicitly discard them by prefixing with underscores (e.g., `(group_offsets, _group_ids, _m_tile_ids), num_tiles = make_group_metadata(...)`) so static analysis knows these values are intentionally unused, leaving num_tiles and group_offsets unchanged.
87-114: Prefix unused unpacked variables with an underscore.

Same issue at lines 90 and 105: group_ids is unpacked but not used.

♻️ Proposed fix

```diff
 def test_visit_empty_groups(self):
     """Empty groups get one tile when visit_empty_groups=True."""
     group_sizes = jnp.array([64, 0, 64], dtype=jnp.int32)
-    (_, group_ids, _), num_tiles = make_group_metadata(
+    (_, _group_ids, _), num_tiles = make_group_metadata(
         ...
     )

 def test_no_visit_empty_groups(self):
     """Empty groups are skipped when visit_empty_groups=False."""
     group_sizes = jnp.array([64, 0, 64], dtype=jnp.int32)
-    (_, group_ids, _), num_tiles = make_group_metadata(
+    (_, _group_ids, _), num_tiles = make_group_metadata(
         ...
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_ref.py` around lines 87-114: test_visit_empty_groups and test_no_visit_empty_groups unpack an unused group_ids from make_group_metadata; prefix the unused variable with an underscore (e.g., rename group_ids to _group_ids) so the no-unused-variable warnings are resolved while keeping the call to make_group_metadata and the use of num_tiles as before.
187-196: Consider using the compare_tensor utility per coding guidelines.

The coding guidelines specify using the compare_tensor utility from tests/utils.py with appropriate tolerance parameters. While np.testing.assert_allclose works, using the project's standard utility ensures consistency.

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_ref.py` around lines 187-196: replace the direct NumPy assertion in test_gmm_ref_matches_manual with the project's compare_tensor utility: import and call compare_tensor(expected, result) instead of np.testing.assert_allclose, passing the same tolerances (atol=1e-5, rtol=1e-5) and a sensible max_ulp (e.g., max_ulp=4) so gmm_ref(lhs, rhs, group_sizes) is compared to _manual_gmm(lhs, rhs, group_sizes) using compare_tensor from tests/utils.py.

tests/ops/gmm/test_gmm_vjp.py (2)
160-164: Unused variable k2 from random key split.

The static analysis tool flagged that k2 is unpacked but never used. Since rhs_full is generated with np.random.RandomState instead, either use k2 consistently or prefix it with an underscore.

♻️ Prefix unused variable

```diff
 key = jax.random.PRNGKey(42)
-k1, k2 = jax.random.split(key)
+k1, _k2 = jax.random.split(key)
 lhs = jax.random.normal(k1, (M, K), dtype=jnp.float32)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_vjp.py` around lines 160-164: the variable k2 returned from jax.random.split(key) is unused, which triggers a linter warning; either consume k2 when generating rhs_full with JAX (e.g., jax.random.normal(k2, ...)) or mark it as intentionally unused by renaming it to `_`. Update the code around the jax.random.split call and the rhs_full generation (currently using np.random.RandomState) so the chosen approach is consistent with how rhs_full is produced.
308-314: Add strict=True to the zip() call.

The static analysis flagged that this zip() lacks an explicit strict= parameter. Since grads_bf16, grads_fp8, and the names list should always have exactly 2 elements each, adding strict=True helps catch length mismatches.

♻️ Add strict parameter

```diff
-for g_bf16, g_fp8, name in zip(grads_bf16, grads_fp8, ["dlhs", "drhs"]):
+for g_bf16, g_fp8, name in zip(grads_bf16, grads_fp8, ["dlhs", "drhs"], strict=True):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_vjp.py` around lines 308-314: make the zip over grads_bf16, grads_fp8, and the literal names list strict to catch length mismatches; pass strict=True to the zip(...) invocation in the loop so it raises if the iterables differ in length, keeping the rest of the loop logic unchanged.

tests/ops/gmm/test_gmm_fp8.py (2)
191-194: Redundant import of fp8_quantize.

The fp8_quantize function is already imported at line 11. This inner import is unnecessary.

♻️ Remove redundant import

```diff
 def test_gmm_fp8_k_not_divisible(self):
     """K not divisible by block_size raises in fp8_quantize."""
     rng = np.random.RandomState(0)
     lhs = jnp.array(rng.randn(128, 192).astype(np.float32))
     with pytest.raises(AssertionError, match="divisible"):
-        from tops.ops.gmm.quantize import fp8_quantize
         fp8_quantize(lhs, axis=-1, block_size=128)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_fp8.py` around lines 191-194: remove the redundant inner import of fp8_quantize inside the pytest.raises block; the symbol is already imported at the top of the file (line 11), so delete the line `from tops.ops.gmm.quantize import fp8_quantize` within the `with pytest.raises(...)` block and ensure the test uses the existing fp8_quantize reference.
87-92: Consider using the compare_tensor utility per coding guidelines.

The test uses np.testing.assert_allclose directly, but the coding guidelines recommend using the compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations.

♻️ Suggested refactor

```diff
+from tests.utils import compare_tensor
+
 # ... in test method ...
-np.testing.assert_allclose(
-    np.array(result),
-    np.array(expected),
-    rtol=RTOL,
-    atol=ATOL,
-)
+compare_tensor(result, expected, rtol=RTOL, atol=ATOL)
```

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_gmm_fp8.py` around lines 87-92: replace the direct NumPy comparison with the testing utility: instead of calling np.testing.assert_allclose on `result` and `expected`, use the `compare_tensor` helper from `tests/utils.py`, passing `atol=ATOL`, `rtol=RTOL`, and an appropriate `max_ulp` value (e.g., 1 or a small integer). Import `compare_tensor` if not already imported and call it with the same array inputs (converted if necessary) and the tolerance variables used in this test.

tests/ops/gmm/test_tgmm.py (1)
75-80: Consider using the compare_tensor utility per coding guidelines.

Similar to other test files in this PR, the tests use np.testing.assert_allclose directly rather than the compare_tensor utility recommended by the coding guidelines.

As per coding guidelines: "Use compare_tensor utility from tests/utils.py with appropriate tolerance parameters (atol, rtol, max_ulp) when comparing kernel outputs against reference implementations"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/gmm/test_tgmm.py` around lines 75-80: replace np.testing.assert_allclose with the project's compare_tensor utility from tests/utils.py: call compare_tensor(result, expected, atol=ATOL, rtol=RTOL, max_ulp=<appropriate value>) (or omit max_ulp if not required); update the import if missing and pass the same RTOL/ATOL variables, with the expected tensor converted to the same dtype/shape as result before calling compare_tensor.

tops/ops/gmm/backend.py (1)
130-139: Histogram-based tile visit counting.

The use of jnp.histogram for counting tile visits is clever but may have edge cases with floating-point bin boundaries for very large tile counts. The current implementation works well for typical sizes, but consider documenting this limitation.

The histogram approach is acceptable since tiles_m is typically small (M / tm where tm >= 128), but an alternative using jnp.bincount would be more explicit:

```python
# Alternative (if needed for very large grids):
# tile_visits = jnp.bincount(partial_tile_ids, length=tiles_m) + 1
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tops/ops/gmm/backend.py` around lines 130-139: the histogram-based counting using jnp.histogram for computing tile_visits from partial_tile_ids can suffer floating-point bin-boundary edge cases for very large tiles_m; update the implementation to use `jnp.bincount(partial_tile_ids, length=tiles_m) + 1` instead of jnp.histogram (or at minimum add a comment documenting the histogram limitation), and ensure m_tile_ids (built from jnp.arange and tile_visits) still uses the updated tile_visits shape/dtype (cast to jnp.int32) so the jnp.repeat call remains correct.
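As a concrete illustration of the bincount alternative, here is a numpy sketch of the tile-visit computation (numpy's `bincount` takes `minlength=` where `jnp.bincount` takes `length=`; variable names mirror the review comment, and the example data is hypothetical, not taken from backend.py):

```python
import numpy as np

# Suppose a grid of tiles_m = 6 row tiles, where group boundaries fall
# inside tiles 1 and 4, so those two tiles must be visited twice.
tiles_m = 6
partial_tile_ids = np.array([1, 4])

# Integer bins: bincount counts boundary crossings per tile exactly,
# with none of the floating-point bin-edge subtleties of histogram.
tile_visits = np.bincount(partial_tile_ids, minlength=tiles_m) + 1

# Repeat each tile id once per visit to build the m-tile schedule.
m_tile_ids = np.repeat(np.arange(tiles_m, dtype=np.int32), tile_visits)

print(tile_visits.tolist())  # [1, 2, 1, 1, 2, 1]
print(m_tile_ids.tolist())   # [0, 1, 1, 2, 3, 4, 4, 5]
```

Every tile is visited at least once, and each boundary crossing adds one extra visit to the tile that contains it.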
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 59705eef-6096-4154-9ede-6ef5e116367a
📒 Files selected for processing (15)
- tests/ops/gmm/__init__.py
- tests/ops/gmm/test_common.py
- tests/ops/gmm/test_fp8_precision.py
- tests/ops/gmm/test_gmm.py
- tests/ops/gmm/test_gmm_fp8.py
- tests/ops/gmm/test_gmm_ref.py
- tests/ops/gmm/test_gmm_vjp.py
- tests/ops/gmm/test_quantize.py
- tests/ops/gmm/test_tgmm.py
- tests/ops/gmm/test_tgmm_fp8.py
- tops/ops/gmm/__init__.py
- tops/ops/gmm/backend.py
- tops/ops/gmm/common.py
- tops/ops/gmm/ops.py
- tops/ops/gmm/quantize.py
```python
        ``None`` (default) follows ``quantize``. ``True`` uses FP8 kernels
        in the backward pass (requires ``quantize=True``). ``False`` with
        ``quantize=True`` dequantizes residuals and uses BF16 backward kernels.
    block_size: Quantization block size (default 128).
```
Docstring default value mismatch for block_size.
The docstring states "block_size: Quantization block size (default 128)" but the actual default value on line 29 is block_size: int = 512.
📝 Fix docstring

```diff
-    block_size: Quantization block size (default 128).
+    block_size: Quantization block size (default 512).
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tops/ops/gmm/ops.py` at line 56, The docstring in tops.ops.gmm.ops wrongly
states "block_size: Quantization block size (default 128)" while the
function/module actually defines block_size: int = 512; update the docstring to
reflect the true default (change the docstring default to 512) for the
block_size parameter in this module (tops.ops.gmm.ops) so the documentation
matches the function signature.
Code Review
This pull request introduces comprehensive FP8 precision validation tests for GMM/TGMM kernels, covering various block sizes, forward and backward passes, and high-level API usage. It also adds new tests for GMM and TGMM kernels, FP8 GMM/TGMM, and FP8 quantization utilities, along with VJP correctness tests for the high-level GMM API. The core gmm and tgmm functions have been updated to support FP8 quantization and subchannel loops. Feedback includes addressing an inconsistency in the default block_size value in tops/ops/gmm/ops.py and refactoring fp8_dual_quantize in tops/ops/gmm/quantize.py to reduce code duplication.
```python
    transpose_rhs: bool = False,
    quantize: bool = False,
    bwd_quantize: bool | None = None,
    block_size: int = 512,
```
There's an inconsistency in the default value for block_size. Here, the default is 512, but the docstring on line 56 states it is 128. Several test helpers also use 128 as a default, while the default is 512 across the other implementation files (quantize.py, backend.py). To ensure consistency, please align the default value in the signature with the documentation and intended usage; 128 seems to be the more common value in the tests.

```diff
-    block_size: int = 512,
+    block_size: int = 128,
```
```python
def fp8_dual_quantize(
    x: jnp.ndarray,
    block_size: int = 512,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Quantize *x* along both rows (axis=-1) and columns (axis=-2).

    This is the TE-style dual quantization used for FP8 GMM: the forward
    pass uses row-quantized operands while the backward pass uses
    column-quantized operands for the transposed matmul.

    Optimized to share the f32 cast across both orientations (avoiding
    a redundant full-tensor type conversion).

    Args:
        x: [..., M, K] -- input tensor (at least 2-D).
        block_size: Block size for both row and column quantization.

    Returns:
        (row_q, row_scale, col_q, col_scale) where
        row_q, row_scale: quantized along axis=-1 (row-wise).
        col_q, col_scale: quantized along axis=-2 (column-wise).
    """
    ndim = x.ndim
    assert ndim >= 2, f"Need at least 2D input, got {ndim}D"

    eps = jnp.finfo(jnp.float32).eps

    # Single f32 cast shared by both row and column quantization.
    x_f32 = x.astype(jnp.float32)

    # --- Row quantization (axis=-1) ---
    ax_row = (ndim - 1) % ndim
    dim_row = x.shape[ax_row]
    assert dim_row % block_size == 0, (
        f"Axis -1 size {dim_row} must be divisible by block_size {block_size}"
    )
    num_blocks_row = dim_row // block_size
    shape_row = (*x.shape[:ax_row], num_blocks_row, block_size, *x.shape[ax_row + 1:])
    blocks_row = jnp.reshape(x_f32, shape_row)
    bax_row = ax_row + 1
    absmax_row = jnp.max(jnp.abs(blocks_row), axis=bax_row, keepdims=True)
    absmax_row = jnp.maximum(absmax_row, eps)
    row_scale = (absmax_row / FP8_MAX).astype(jnp.float32)
    row_q = jnp.clip(blocks_row / row_scale, -FP8_MAX, FP8_MAX).astype(jnp.float8_e4m3fn)
    row_q = jnp.reshape(row_q, x.shape)
    row_scale = jnp.squeeze(row_scale, axis=bax_row)

    # --- Column quantization (axis=-2) ---
    ax_col = (ndim - 2) % ndim
    dim_col = x.shape[ax_col]
    assert dim_col % block_size == 0, (
        f"Axis -2 size {dim_col} must be divisible by block_size {block_size}"
    )
    num_blocks_col = dim_col // block_size
    shape_col = (*x.shape[:ax_col], num_blocks_col, block_size, *x.shape[ax_col + 1:])
    blocks_col = jnp.reshape(x_f32, shape_col)
    bax_col = ax_col + 1
    absmax_col = jnp.max(jnp.abs(blocks_col), axis=bax_col, keepdims=True)
    absmax_col = jnp.maximum(absmax_col, eps)
    col_scale = (absmax_col / FP8_MAX).astype(jnp.float32)
    col_q = jnp.clip(blocks_col / col_scale, -FP8_MAX, FP8_MAX).astype(jnp.float8_e4m3fn)
    col_q = jnp.reshape(col_q, x.shape)
    col_scale = jnp.squeeze(col_scale, axis=bax_col)

    return row_q, row_scale, col_q, col_scale
```
The function fp8_dual_quantize duplicates a significant amount of logic from fp8_quantize. To improve maintainability and reduce code duplication, you could refactor it to call fp8_quantize directly for each axis. While this removes the optimization of sharing the float32 cast, it greatly simplifies the code, and the performance impact of the extra cast is likely negligible.
```python
def fp8_dual_quantize(
    x: jnp.ndarray,
    block_size: int = 512,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Quantize *x* along both rows (axis=-1) and columns (axis=-2).

    This is the TE-style dual quantization used for FP8 GMM: the forward
    pass uses row-quantized operands while the backward pass uses
    column-quantized operands for the transposed matmul.

    Args:
        x: [..., M, K] -- input tensor (at least 2-D).
        block_size: Block size for both row and column quantization.

    Returns:
        (row_q, row_scale, col_q, col_scale) where
        row_q, row_scale: quantized along axis=-1 (row-wise).
        col_q, col_scale: quantized along axis=-2 (column-wise).
    """
    assert x.ndim >= 2, f"Need at least 2D input, got {x.ndim}D"
    row_q, row_scale = fp8_quantize(x, axis=-1, block_size=block_size)
    col_q, col_scale = fp8_quantize(x, axis=-2, block_size=block_size)
    return row_q, row_scale, col_q, col_scale
```
Summary
- Block-wise quantization (fp8_quantize / fp8_dual_quantize) with per-block absmax calibration
- gmm() API with custom_vjp for automatic differentiation (forward/dgrad/wgrad)
- Reference implementations (gmm_fp8_ref / tgmm_fp8_ref) with float64 accumulation for precision testing

Changed files
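For readers unfamiliar with the contract the reference implementations check, the grouped-matmul semantics can be sketched in numpy. This assumes the usual grouped-matmul convention of contiguous row groups; `gmm_ref_np` is an illustrative name, not the project's actual `gmm_ref`:

```python
import numpy as np


def gmm_ref_np(lhs: np.ndarray, rhs: np.ndarray, group_sizes) -> np.ndarray:
    """Grouped matmul sketch: contiguous row-group g of lhs [M, K] is
    multiplied by expert matrix rhs[g] from rhs [G, K, N]."""
    out = np.empty((lhs.shape[0], rhs.shape[-1]), dtype=np.float32)
    start = 0
    for g, size in enumerate(group_sizes):
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out


lhs = np.ones((4, 3), dtype=np.float32)
rhs = np.stack([np.eye(3), 2.0 * np.eye(3)]).astype(np.float32)
out = gmm_ref_np(lhs, rhs, [2, 2])
# Rows routed to expert 0 pass through the identity matrix;
# rows routed to expert 1 are doubled.
print(out[0].tolist(), out[3].tolist())  # [1.0, 1.0, 1.0] [2.0, 2.0, 2.0]
```

The Pallas kernels compute the same result tile-by-tile, which is why the group metadata (offsets, tile ids, visit counts) reviewed above exists at all.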
- tops/ops/gmm/backend.py
- tops/ops/gmm/quantize.py
- tops/ops/gmm/ops.py
- tops/ops/gmm/common.py
- tops/ops/gmm/__init__.py
- tests/ops/gmm/test_gmm_fp8.py
- tests/ops/gmm/test_tgmm_fp8.py
- tests/ops/gmm/test_fp8_precision.py
- tests/ops/gmm/test_quantize.py
- tests/ops/gmm/test_*.py

Downstream usage
The ant-pretrain feat/fp8-gmm branch depends on this via:

Unit tests passing on TPU (20/20): gmm_fp8 fwd, gmm_fp8 dgrad, tgmm_fp8 wgrad with ALModel production dimensions (M=8192, G=32).
Test plan
🤖 Generated with Claude Code