
Add fast validation for fp4 #24

Open
b-shi wants to merge 34 commits into subtile_mx from subtile_mx_fast1

Conversation

@b-shi b-shi commented Apr 8, 2026

Owner

Fast1 Data Initialization Mode (DataInitTypeA/B: 27)

Fast1 is a new client-side init mode for MX FP4 GEMMs that enables fast, closed-form correctness validation without a full CPU GEMM reference.

How it works:

  • ~50% of rows of A and ~50% of columns of B are randomly selected as "active"; the rest are zeroed.
  • All active rows of A share the same K-element pattern of {-1, 0, 1} values; all active columns of B share a separate K-element pattern.
  • Scale blocks are set to either 0 (zero-out) or max (pass-through), tiled consistently with the data pattern.
  • The reference output for any active (m, n) pair is alpha * dot(patA, patB), where dot(patA, patB) is a single integer-valued dot product over K, gated by the scale block pattern. Inactive positions produce 0.
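
The closed-form reference described in the list above can be sketched roughly as follows. This is an illustration only, not the code in this PR: the `Fast1Pattern` struct, `gatedDot`, and `fast1Reference` are hypothetical names, and the sketch assumes that a pass-through scale behaves as a factor of 1 and that a K-block contributes only when both the A and B scale blocks are pass-through.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the shared Fast1 patterns (not from the PR).
struct Fast1Pattern {
    std::vector<int8_t>  patA;      // K values in {-1, 0, 1}, shared by all active rows of A
    std::vector<int8_t>  patB;      // K values in {-1, 0, 1}, shared by all active columns of B
    std::vector<uint8_t> scaleOnA;  // per K-block: 1 = pass-through, 0 = zeroed (assumed model)
    std::vector<uint8_t> scaleOnB;  // same, for B
    std::size_t          blockSize; // number of K elements covered by one scale block
};

// Single O(K) dot product, skipping K-blocks whose scale is zeroed on either side.
int gatedDot(const Fast1Pattern& p) {
    assert(p.blockSize > 0);
    assert(p.patA.size() == p.patB.size());
    assert(p.scaleOnA.size() * p.blockSize >= p.patA.size());
    assert(p.scaleOnB.size() * p.blockSize >= p.patB.size());
    int acc = 0;
    for (std::size_t k = 0; k < p.patA.size(); ++k) {
        const std::size_t blk = k / p.blockSize;
        if (p.scaleOnA[blk] && p.scaleOnB[blk]) {
            acc += p.patA[k] * p.patB[k];
        }
    }
    return acc;
}

// Expected output at (m, n): alpha * dot if both the row and the column are active, else 0.
float fast1Reference(float alpha, int dot, bool rowActive, bool colActive) {
    return (rowActive && colActive) ? alpha * static_cast<float>(dot) : 0.0f;
}
```

Because every active row and column shares the same pattern, `gatedDot` only needs to run once per problem; each output element is then checked against one of two values.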

Why it's faster:
Computing the reference requires only one O(K) dot product instead of an O(M·N·K) CPU GEMM, which makes validation practical for large problem sizes (the element-wise comparison against the GPU output is still O(M·N)). The closed-form reference is exact for float output and bounded by a tolerance of K * epsilon for BFloat16 output.
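
As a rough sketch of the comparison rule above: float output is compared exactly, while BFloat16 output is accepted within K * epsilon. The function name `fast1Matches` is hypothetical, and two details are assumptions rather than facts from the PR: that epsilon means the bf16 machine epsilon (2^-7, from its 8 significand bits) and that the tolerance is absolute rather than relative.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Assumed meaning of "epsilon": bf16 machine epsilon, 2^-7.
constexpr float kBf16Epsilon = 0.0078125f;

// Returns true if a single output element agrees with the closed-form reference.
bool fast1Matches(float got, float expected, std::size_t K, bool outputIsBf16) {
    assert(K > 0);
    if (!outputIsBf16) {
        return got == expected; // float output: the reference is exact
    }
    const float tol = static_cast<float>(K) * kBf16Epsilon; // assumed absolute bound
    return std::fabs(got - expected) <= tol;
}
```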

Restrictions: MX FP4 A/B only; bias, activation, and E-output are not supported.

b-shi and others added 30 commits April 2, 2026 11:35
* Add sample subtile impl

* Move allocOffsetRegisters before setupNewTile

* Start adding GR offset calculation

* Rest of logic (no swizzling)

* refacto

* spgr offsets

* Add newserial code

* Add script to debug offsets

* Add unit test for GR offset calculation

* Grid display

* Fix both code and ref test function

* Add DPP quad perm to rocisa

* Apply swizzling (no rotation yet)

* Function swizzling + rotation + test

* Refactor test to have a single output array + add test for SGPRs

* Add debug mode to test + add dynamic wavegroup calculation based on MT

* Fix test runtime issue and check all vgpr offsets

* Add ref test code for 1x4 & 4x1

* Fix tests

* Fixed SGPR offset calculation for 2x2

* Fix more tests

* Add more tests

* Refactor tests

* simplify tests

* Remove unused script

* cleanup

* fix camelCase in ref test code

* cleanup

* Fix typo

---------

Co-authored-by: brianshi <brianshi@amd.com>
* Add tests

* as is

* Add permlane16_swap instruction to rocisa

* Ongoing progress

* Draft for partition A0/A1

* Wave partitioning

* Draft ref code in tests

* Handle 1x4 wavesplit param

* 2x2 test passing

* Draft 1x4 LR wave partitioning

* Fix alignment issue

* Integration testing

* Update integration test

* Fix swizzling pattern on GRA. Only swizzling on even LDS rows

* Subtile based test

* testing A

* Test both A and B

* Remove graonly mode

* Fix 1x4 case

* Move global offset for B after rest of the logic

* cleanup

* cleanup

* Fix ref test code for 4x1

* Fix spgr alloc issue

* Remove tmp test file

* Remove debug prints

* Add test case
* Emit ds_reads

* Add waits for LR and GR

* Init Acc VGPR to Zero

* Add missing bit_length on VLShiftLeftB32

* Insert SNop between VLShiftLeftB32 & VReadfirstlaneB32 for correctness

* Fix gra test ref code for 1x4

* Remove some debug prints
…nfigs (#7)

* 64x64

* Fix MFMA emit code

* Remove label

* cleanup

* cleanup

* Cleanup

* Update tests

* New GR offset calculation (no swizzling yet)

* Refacto

* cleanup

* Re-enable swizzling

* Fix SPGR alloc

* Update M0

* Tensile passing no swizzling

* Fix swizzling

* LDS padding

* as is

* Multiple bugfixes

* Fix 128x64

* Refactor pre-swizzling change

* Add wave specific rotation to swizzling

* Fix gra Test

* Fix LRA test

* Fix roundtrip test

* LdsNumBytes as int

* Use float type for bpe

* Cleanup

* Cleanup

* More cleanup

* cleanup

* Simplify _grSwizzleColIds

* Remove debug label

* Fix typo LDS size calculation
* Add fp4 mfma support

* Allow using Zeros, Ones and Identity for MX types

* Display scales for MX types

* Fix non subtileImpl path bug

* Fix display issue on MX types (PrintTensor option)
…ernel (#6)

* Add sample subtile impl

* Fix issues when disabling subtile impl

* GR Offset calculation (#1)

* Add sample subtile impl

* Move allocOffsetRegisters before setupNewTile

* Start adding GR offset calculation

* Rest of logic (no swizzling)

* refacto

* spgr offsets

* Add newserial code

* Add script to debug offsets

* Add unit test for GR offset calculation

* Grid display

* Fix both code and ref test function

* Add DPP quad perm to rocisa

* Apply swizzling (no rotation yet)

* Function swizzling + rotation + test

* Refactor test to have a single output array + add test for SGPRs

* Add debug mode to test + add dynamic wavegroup calculation based on MT

* Fix test runtime issue and check all vgpr offsets

* Add ref test code for 1x4 & 4x1

* Fix tests

* Fixed SGPR offset calculation for 2x2

* Fix more tests

* Add more tests

* Refactor tests

* simplify tests

* Remove unused script

* cleanup

* fix camelCase in ref test code

* cleanup

* Fix typo

---------

Co-authored-by: brianshi <brianshi@amd.com>

* Enable post-loop code generation, and add some subroutines

* LR offset calculation (#2)

* Add tests

* as is

* Add permlane16_swap instruction to rocisa

* Ongoing progress

* Draft for partition A0/A1

* Wave partitioning

* Draft ref code in tests

* Handle 1x4 wavesplit param

* 2x2 test passing

* Draft 1x4 LR wave partitioning

* Fix alignment issue

* Integration testing

* Update integration test

* Fix swizzling pattern on GRA. Only swizzling on even LDS rows

* Subtile based test

* testing A

* Test both A and B

* Remove graonly mode

* Fix 1x4 case

* Move global offset for B after rest of the logic

* cleanup

* cleanup

* Fix ref test code for 4x1

* Fix spgr alloc issue

* Remove tmp test file

* Remove debug prints

* Add test case

* Add GR load emit logic, and misc fixes (#3)

* gr emit fix

* Emit LR + init ACCVGPR (#4)

* Emit ds_reads

* Add waits for LR and GR

* Init Acc VGPR to Zero

* Add missing bit_length on VLShiftLeftB32

* Insert SNop between VLShiftLeftB32 & VReadfirstlaneB32 for correctness

* Fix gra test ref code for 1x4

* Remove some debug prints

* Add loop and ptr update code

* Update scale offset

* Add tests

* Address review

* Add scale roundtrip e2e test and constraint assertions

Add GR->LDS->LR roundtrip GPU test verifying scale offset consistency
across 4 tile configs x 2 matrices. Add power-of-2 assertion for
scaleBlockSize and matching scaleBlockSize assertions for A/B in
shared GR/LR offset computation. Pass kernel dict to compute_lds_sizes
instead of re-deriving MIWaveGroup from tile dimensions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update fixes

* Fix scale being skipped

* Add flag to print layout

* Fix missed merge conflicts

* Fix missed merge conflicts

* Refactor scale roundtrip test with gpu helper fns

* Fix extra spaces

* Fix tests

---------

Co-authored-by: brianshi <brianshi@amd.com>
Co-authored-by: sebvince <115461989+sebvince@users.noreply.github.com>
Co-authored-by: b-shi <bbbrianme@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
* Add optimized storeD code

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address comments in PR, add some misc fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Add pk_f16 cvt support

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Enable subtile impl only for gfx950

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
…MXInput (#11)

Problem:
When the tensilelite-client is configured with init-mxScaleA=One and
init-mxScaleB=One, the MX scale tensors should contain the value 1.0
(E8M0 byte 127). Instead, they contained 2.0 (E8M0 byte 128), and the
FP4 data tensors contained 0.5 instead of 1.0.

Root cause:
The mxDataGenerator library's DataGeneratorOptions has a forceDenorm
flag that defaults to true. When forceDenorm is true, the generator's
setOne<ocp_e2m1_mxfp4>() function uses a subnormal decomposition of
1.0: it sets the FP4 data to the subnormal value 0.5 (dataSubNormalOneMask)
and the E8M0 scale to 2.0 (Constants::E8M0_2 = 128), so that the product
0.5 * 2.0 = 1.0 is still correct. When forceDenorm is false, it uses the
normal decomposition: FP4 data = 1.0 (oneMask) and scale = 1.0
(Constants::E8M0_1 = 127).

The generateMXInput() function in mxDataGen.cpp never set this option,
inheriting the default forceDenorm=true. This caused init modes like
"Ones" to produce unexpected data/scale values even though the float
product was mathematically correct.
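
To make the two decompositions of 1.0 concrete, here is a small decoding sketch. It does not use the mxDataGenerator API; `decodeE8M0` and `decodeE2M1` are hypothetical helpers written from the format definitions (E8M0 is an exponent-only byte with bias 127; E2M1 has 1 sign bit, 2 exponent bits, and 1 mantissa bit).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// E8M0 scale: value = 2^(byte - 127). The byte 0xFF encodes NaN and is rejected here.
float decodeE8M0(uint8_t byte) {
    assert(byte != 0xFF);
    return std::ldexp(1.0f, static_cast<int>(byte) - 127);
}

// E2M1 FP4 element: low 3 bits select the magnitude, bit 3 is the sign.
float decodeE2M1(uint8_t nibble) {
    static const float magnitudes[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float mag = magnitudes[nibble & 0x7];
    return (nibble & 0x8) ? -mag : mag;
}
```

With these helpers, the subnormal form is `decodeE2M1(0x1) * decodeE8M0(128)` (0.5 * 2.0) and the normal form is `decodeE2M1(0x2) * decodeE8M0(127)` (1.0 * 1.0). Both products equal 1.0, which is why the float result was correct even though the stored bytes differed.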

Fix:
Set opt.forceDenorm = false in generateMXInput() so that deterministic
init modes (Ones, Identity, Sequential, etc.) produce the intuitive
normal-form data and scale values.

Impact:
No existing callers are affected:
- The hipblaslt client (testing_matmul.hpp) only allows hpl, trig_float,
  or uniform_01 init methods for MX data, all of which use the Bounded
  or TrigonometricFromFloat code paths that do not call setOne.
- The MXDataGen unit tests all use "Bounded" init method.
- Only the tensilelite client passes "Ones" (via initModeToMXMethod),
  which is the path this fixes.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
…kernel (#10)

Enable the MX FP4 scale emit code in the subtile-based kernel
---------
Co-authored-by: Koji Nakajima <Koji.Nakajima@amd.com>
Co-authored-by: Archana Ramalingam <Archana.Ramalingam@amd.com>
Co-authored-by: Brian Shi <brianshi@amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
* Remove duplicate scale loads

* Bugfix scheduler (#15)

* Fix test

* Single subtile size attempt

* Unrolling

* bug fix

* cleanup

* refactor allocator

* Simplify allocator

* Dont use allocator for scales

* Use unrolling only when number of partition is odd

* Remove duplicated buffer_load

* Fix duplicate scale load after rebase

* Fix merge conflicts

* Fix beta=0, address comments from PR

---------

Co-authored-by: sebvince <115461989+sebvince@users.noreply.github.com>
* Enable DU > 256, and reduce sgpr allocation

* Address comments from PR
* custom Scale init

* Add env-var fallback

* Fix build issue

* Fix an issue

* Support swizzle case

* Move scale init to mxdatagenerator

* Add tests

* Fixes

* Make mxDataGenerator visible to all tensilelite targets

* Update MXScaleBlockI/J comment
* Fix tensilelite test failures

* minor clean-up

* Add subtile test yaml
b-shi and others added 2 commits April 7, 2026 20:01
* Enable FixSrd2 for A/B

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Address comments from PR

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
@b-shi b-shi requested review from nakajee and sebvince April 8, 2026 13:45
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
@b-shi b-shi force-pushed the subtile_mx_fast1 branch from 9d5a99c to ca49a47 Compare April 8, 2026 13:50
case InitMode::TrigIndAbsCos:
case InitMode::Count:
throw std::runtime_error("Invalid InitMode.");
case InitMode::Fast1:

Collaborator

Maybe better to put this before line 522?

return MXScale(getValueWithUpperLowerBoundFP<float>());
}

// Fast1: random choice from {-1, 0, 1} — only MX FP4 (Float4x2) is supported.

Collaborator

The comment is a bit confusing.
Can this work with MXBlockA/B=0?

Owner Author

Yeah, this is only supported for MXFP4 for now. Will update the comment 😅
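
For readers following the thread, this is roughly what "random choice from {-1, 0, 1}" looks like once encoded as MX FP4 (E2M1) and packed two elements per byte. It is a sketch, not the code under review: `makeFast1Pattern`, `encodeE2M1`, and `packFp4x2` are hypothetical names, and placing the first element in the low nibble is an assumption about the Float4x2 layout.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Draw a K-element pattern with each entry chosen uniformly from {-1, 0, 1}.
std::vector<int8_t> makeFast1Pattern(std::size_t K, std::mt19937& rng) {
    std::uniform_int_distribution<int> pick(-1, 1);
    std::vector<int8_t> pat(K);
    for (auto& v : pat) {
        v = static_cast<int8_t>(pick(rng));
    }
    return pat;
}

// E2M1 encodings: 0 -> 0b0000, +1 -> 0b0010, -1 -> 0b1010 (sign bit set).
uint8_t encodeE2M1(int8_t v) {
    assert(v >= -1 && v <= 1);
    if (v == 0) {
        return 0x0;
    }
    return v > 0 ? 0x2 : 0xA;
}

// Pack two FP4 values per byte; low-nibble-first ordering is assumed.
std::vector<uint8_t> packFp4x2(const std::vector<int8_t>& pat) {
    assert(pat.size() % 2 == 0);
    std::vector<uint8_t> out(pat.size() / 2);
    for (std::size_t i = 0; i < out.size(); ++i) {
        const uint8_t lo = encodeE2M1(pat[2 * i]);
        const uint8_t hi = encodeE2M1(pat[2 * i + 1]);
        out[i] = static_cast<uint8_t>(lo | (hi << 4));
    }
    return out;
}
```

All three values are exactly representable in E2M1, so the products in the reference dot product stay integer-valued.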


5 participants