Skip to content

Fix hmfast TPU support: auto-disable x64 for mcfit FFT paths#2

Draft
licongxu wants to merge 9 commits into
licongxu_autoresearchfrom
cursor/tpu-jax-mcfit-compat-15a6
Draft

Fix hmfast TPU support: auto-disable x64 for mcfit FFT paths#2
licongxu wants to merge 9 commits into
licongxu_autoresearchfrom
cursor/tpu-jax-mcfit-compat-15a6

Conversation

@licongxu

Copy link
Copy Markdown
Owner

Problem

hmfast fails on TPU while GPU/CPU work. Root cause is a dtype mismatch with mcfit's JAX backend:

  • Both hmfast and mcfit force jax_enable_x64=True
  • mcfit Hankel / TophatVar transforms use rfft / hfft with complex128 kernels when x64 is on
  • TPU rejects 64-bit types (Element type C128 is not supported on TPU, 64-bit data types are not yet supported on the TPU driver API)

Fix

  • Add hmfast.jax_platform to auto-detect TPU and default jax_enable_x64=False (float32 / complex64), while keeping x64 on CPU/GPU
  • Add hmfast.mcfit_compat to re-apply hmfast x64 settings after mcfit import (mcfit unconditionally enables x64 in its __init__.py)
  • Route mcfit usage through the compat module; build Hankel/Tophat grids and emulator weights with float_dtype()
  • Document usage in docs/tpu.md
  • Add tests/test_jax_platform.py (simulates TPU via HMFAST_JAX_ENABLE_X64=0)

How to test on TPU

export JAX_PLATFORMS=tpu
python -c "import jax; import hmfast; from hmfast.mcfit_compat import Hankel; import jax.numpy as jnp; print(jax.devices(), jax.config.jax_enable_x64)"
pytest tests/test_jax_platform.py -v

Notes

  • Float32 on TPU is slightly less accurate than GPU float64; validate key spectra if you need sub-percent agreement.
  • Override: HMFAST_JAX_ENABLE_X64=0|1 or JAX_ENABLE_X64=0|1
Open in Web Open in Cursor 

cursoragent and others added 9 commits May 25, 2026 09:17
TPU does not support float64/complex128 FFT paths used by mcfit with
jax_enable_x64=True. Add jax_platform auto-detection (x64 off on TPU),
mcfit_compat to undo mcfit's forced x64 import, and float_dtype helpers
for Hankel/TophatVar grids and emulator weights.

Includes docs/tpu.md and tests/test_jax_platform.py.

Co-authored-by: Licong Xu <lx256@cam.ac.uk>
The original pipeline could not even import hmfast on a Cloud TPU v6e
VM (jax 0.6.2 from the libtpu wheel), and once that was patched the
emulator path silently produced NaN spectra. This commit collects the
minimum changes needed to load and run hmfast on TPU:

* hmfast/_special.py (new): version-agnostic `sici` shim. Prefers
  `jax.scipy.special.sici` (added in recent JAX, PR jax-ml/jax#32052)
  and falls back to `scipy.special.sici` via `jax.pure_callback` when
  it's missing. The callback explicitly casts back to the requested
  dtype so the pure_callback boundary doesn't truncate float64 -> NaN
  on TPU. base_profile.py and tracers/cmb_lensing.py now import from
  this shim instead of `jax.scipy.special` directly.
* hmfast/emulator_load.py: stop forcing `jax.config.update(...x64..., True)`
  at module load and stop hard-coding `dtype=jnp.float64` in
  `EmulatorLoader.__init__`. Use the platform-aware `float_dtype()` so
  TPU runs stay in float32/complex64 (matches `docs/tpu.md` and the
  already-correct `EmulatorLoaderPCA`). The previous behaviour produced
  truncate-cast overflow warnings and NaN downstream.
* pyproject.toml: add `requests` (used by `download.py`) and
  `tensorflow>=2.10` to runtime deps. We never call any TF op, but the
  cosmopower .npz emulator files are pickled with references to TF keras
  classes, so `numpy.load(..., allow_pickle=True)` needs `tensorflow`
  importable.

* tpu_test_tsz.py (new): JIT-compiled tSZ Cl smoke test that hits
  `HaloModel.cl_1h` / `cl_2h` with a GNFW pressure profile, prints
  timings and sanity checks, and saves the spectrum to
  `tsz_spectrum_tpu.png`. Used by `tpu_submission/sync_and_run.sh`.

* paper/ removed: paper sources (.tex / .cls / built figures / hmfast.pdf)
  are not part of the package and were broken in the working tree.
  Drop them per the user's earlier cleanup request; the rendered PDF is
  not the source of truth.

* .gitignore: ignore the regenerated `tsz_spectrum_tpu.png` artifact.

Verified end to end against `lxu-persistent` in `us-east5-b`: import
succeeds, JAX backend reports `tpu`, both `cl_1h` and `cl_2h` JIT-compile
and execute (~10.5 s first run, ~0.05 s cached), and the plot is
retrieved locally. The remaining NaN in the tSZ amplitude is the
pre-existing scientific issue documented in `docs/tpu.md` (TPU float32
precision in mcfit transforms) and is *not* the import/dependency
failure this commit targets.
Three independent float32 overflows turned the Cloud TPU tSZ pipeline
into all-NaN, even though every individual building block was finite at
the start. They are now fixed at the source rather than papered over
with x64-on-TPU (which mcfit's complex FFTs do not support).

1. cosmology.Cosmology._cosmo_params: rearrange Omega0_g so the
   intermediate `Mpc_over_m**2 ≈ 9.5e44` is never materialized. That
   single Python literal silently cast to +inf on float32, which made
   Omega0_g a constant 0, Omega_Lambda = 1, and finally
   delta_mean = log10(delta / omega_m(z)) = NaN inside T08HaloMass.

2. tracers.ksz.kSZTracer.kernel: same Mpc_over_m**2 pattern; replace
   with chained division.

3. halos.halo_model: refactor pk_1h/pk_2h to expose internal
   ``_pk_*_impl`` helpers that accept per-z multiplicative scales on
   each profile, and have cl_1h / cl_2h / cl_1h_masked / trispectrum_1h
   / trispectrum_1h_masked fold the per-z tracer kernel into u_k
   **before** the m-axis squaring. The legacy form computes
   ``u_pressure ~ 1e20 -> u_pressure**2 ~ 1e40`` (overflow -> +inf) and
   then ``inf * kernel**2 ≈ 1e-56 = NaN``. Folding the kernel inside
   gives ``(kernel*u_pressure)^2 ~ 1e-16`` which fits float32 by ~22
   decades. The public pk_1h / pk_2h API is unchanged (call sites pass
   ``ones`` as the scale factors).

4. halos.massfunc: HMF log-interp clip floor of 1e-300 underflows to 0
   on float32; switch to ``jnp.finfo(dtype).tiny`` so log(...) -> finite
   instead of -inf and the RegularGridInterpolator no longer poisons
   downstream queries.

Verified end-to-end on lxu-persistent (Cloud TPU v6e, jax 0.6.2 +
libtpu): cl_1h and cl_2h are positive and finite (mean Cl ~ 3.67e-17),
the tsz_spectrum_tpu.png plot is physically sensible (1h-2h crossover
near l~200, peak at l~3000), and pk_1h/pk_2h backward compatibility is
preserved.
Repo hygiene + the timing comparison requested for the tSZ Cl pipeline.

Reorganization
--------------
Move the TPU smoke test out of the repo root so the top level is no longer
sprinkled with one-off Python + PNG artifacts.

    tpu_test_tsz.py        -> tpu/test_tsz.py   (writes plot next to itself)
    tsz_spectrum_tpu.png   -> tpu/tsz_spectrum_*.png   (git-ignored)

``tpu/`` is now the single home for runtime/CI scripts that exercise the
hmfast pipeline on Cloud TPU VMs; nothing inside it is imported by the
package. A short README explains each file.

The upstream driver ``~/tpu_submission/sync_and_run.sh`` ships the whole
``tpu/`` dir (excluding stale plot/bench artifacts), runs ``tpu/test_tsz.py``,
and pulls ``tpu/tsz_spectrum_tpu.png`` back.

CPU-vs-TPU timing comparison
----------------------------
* ``tpu/benchmark_cpu_vs_tpu.py`` -- single-backend timing harness
  (compile + N=5 ``cl_1h + cl_2h`` runs); writes a JSON result.
* ``tpu/run_benchmark.sh`` -- runs the harness once under
  ``JAX_PLATFORMS=cpu`` and once under ``JAX_PLATFORMS=tpu``
  (separate processes so hmfast.jax_platform picks the right dtype per
  backend), then prints a side-by-side table with the speed-up and a
  CPU↔TPU ``cl_mean`` consistency check.

Measured on lxu-persistent (TPU v6e ``TPU v6 lite``, jax 0.6.2 + libtpu)::

    backend  device           dtype     compile    min      median   mean_cl
    cpu      cpu:cpu          float64   4.14 s     84.9 ms  95.7 ms  3.608e-17
    tpu      tpu:TPU v6 lite  float32   10.39 s    49.1 ms  49.1 ms  3.667e-17

    Speedup (median): 1.95x   |   Speedup (best): 1.73x
    cl_mean relative diff (CPU vs TPU): 1.640%

The 1.6% CPU/TPU difference is the expected float64-vs-float32 spread for
a workload with this dynamic range; the spectral shape is identical.
The previous benchmark only printed mean_cl agreement. Extend it to a real
correctness gate: dump the full cl_1h / cl_2h arrays plus a few pk_1h /
pk_2h spot points to the per-backend JSONs, then have run_benchmark.sh
compare both backends across all ell with explicit pass/fail.

  * Per-ell max & median relative diff for cl_1h and cl_2h
    (tolerances 5% max / 2% median for float32 vs float64).
  * Finite & positive check on both backends.
  * pk_1h / pk_2h spot consistency at (k=[0.01, 0.1, 1.0], z=0.5).
  * Two-panel comparison plot tpu/tsz_cpu_vs_tpu.png: overlaid spectra
    on top, (TPU - CPU) / CPU residual on the bottom with the tolerance
    band drawn in.
  * Combined machine-readable tpu/bench_results/summary.json with timing
    speedup numbers and every correctness check result.
  * Non-zero exit code on any FAIL so this can run as a regression gate.

Measured on lxu-persistent (TPU v6 lite, jax 0.6.2 + libtpu):

    backend  compile      min       median    mean_cl
    cpu      4.10 s       91.2 ms   99.8 ms   3.608e-17  (float64)
    tpu      10.55 s      49.1 ms   49.1 ms   3.667e-17  (float32)

    Speedup (median): 2.03x   |   Speedup (best): 1.86x

    cl1h max/median rel diff: 1.86% / 0.77%   (tol 5% / 2%)
    cl2h max/median rel diff: 0.77% / 0.54%
    cl_mean rel diff:         1.64%           (tol 2%)
    All 11 checks PASS.
Replace the ``jax.vmap(process_bin)(arange(Nm))`` pattern in both internal
helpers with a direct ``jnp.sum`` over the mass axis. The previous form
called ``profile.u_k(self, k, m, z)`` *inside* the vmap, building the full
(Nk, Nm, Nz) tensor on every iteration and then slicing out one row.
XLA was hoisting the invariant subexpression on TPU (verified: benchmark
identical to within run-to-run noise), but the structure was misleading
and made the m-axis dependency hard to read.

The new form is structurally what XLA was already producing:

    u1 = profile.u_k(k, m, z)             # (Nk, Nm, Nz) -- once
    uk_sq = (s1 * u1) * (s2 * u2)
    pk1h  = jnp.sum(uk_sq * weights[None,:,:], axis=1)

Verified on lxu-persistent: 11/11 correctness checks PASS, ``mean_cl``
bit-identical to the pre-refactor commit.
Add ``tpu/profile_stages.py`` -- a stage-by-stage wall-time profiler for
the tSZ Cl pipeline. Run on lxu-persistent (TPU v6 lite) it surfaces the
dominant cost on TPU:

  u_k(k, m, [z0])  single z slice     : 1.3 ms
  u_k vmap over 32 z slices           :  53 ms     <-- ~linear in N_z

i.e. ``profile.u_k`` is being dispatched per-z rather than fused into a
single batched Hankel transform. XLA cannot auto-fuse because the per-z
chi(z) / d_A(z) / prefactor terms differ.

Document the finding plus four concrete acceleration paths in
``tpu/README.md`` (vmap-over-z elimination, cosmology-batched chains,
mcfit pad shrink, bfloat16 sub-ops), so the next person to touch this
knows where to spend effort. None of these are implemented yet -- the
current PR just measures and explains the gap; the implementation is
deliberately split off to keep this PR a pure measurement+cleanup
change.
The fast-path implementation landed in 806e27d but the README still
quoted the pre-fast-path numbers. Update with the measured numbers
from lxu-persistent (TPU v6 lite, jax 0.6.2):

  TPU median: 49 ms -> 27 ms      (1.82x speedup)
  CPU vs TPU: 2.04x -> 3.10x      (TPU vs CPU float64 baseline)
  Correctness:  11/11 PASS, mean_cl unchanged to 4 sig figs

Also mark acceleration path #1 (eliminate redundant per-z Hankel
transforms) as DONE; the remaining paths (cosmology batching, mcfit
pad shrink, bfloat16) are still open for future work.
Root-caused the TPU bottleneck to jnp.interp with per-(m,z) varying
xp grids: searchsorted on the TPU scalar unit is 117× slower than
with a fixed xp (21 ms vs 0.18 ms for 2048 interpolations).

Three optimizations in _cl_1h/2h_pressure_fast:

1. Hoist _counter_terms outside the z-vmap — was redundantly
   computing HMF + bias 32× inside vmap (~25 ms saved).

2. Transpose arrays to z-leading layout before vmap — eliminates
   dynamic-index gathers through the TPU scalar unit.

3. Factor ell_native = k_native × ell_delta(m,z) so interpolation
   uses the FIXED k_native grid with searchsorted + gather, rather
   than per-(m,z) jnp.interp with varying xp.

_u_ell_native now returns (k_native, ell_delta, u_ell_val) instead
of (ell_native, u_ell_val) to support the decomposed interpolation.

Results on lxu-persistent (TPU v6 lite, single chip):
  cl_1h + cl_2h combined: 27.4 ms → 15.9 ms (1.72×)
  cl_1h standalone:       51.8 ms → 30.1 ms (1.72×)
  test_tsz.py 2nd run:    49.1 ms → 16.0 ms (3.07×)

Includes hw_microbench.py and profile scripts used to diagnose.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants