Skip to content

Add optional Numba JIT optimization for pure-Python components#205

Merged
JoyMonteiro merged 21 commits intodevelopfrom
numba-optimized-components
Mar 30, 2026
Merged

Add optional Numba JIT optimization for pure-Python components#205
JoyMonteiro merged 21 commits intodevelopfrom
numba-optimized-components

Conversation

@JoyMonteiro
Copy link
Copy Markdown
Member

Summary

  • Adds @njit-decorated kernels with column-parallel prange() loops for 7 pure-Python components: GrayLongwaveRadiation, GridScaleCondensation, HeldSuarez, DryConvectiveAdjustment, BergerSolarInsolation, SlabSurface, and Instellation
  • Numba is an optional dependency — all components fall back to pure Python when Numba is not installed
  • Benchmarks show ~13.8× speedup in a Grey GCM integration vs the unoptimized path
  • Adds a jit_compile() helper in climt/_core/backend.py that applies @njit when Numba is available and is a no-op otherwise
  • Adds parity tests for each optimized component (tests/test_*_optimization.py)
  • Removes the experimental JAX backend (will land on a separate differentiable-climt branch)

Bug fixes found during testing

Three bugs were introduced in the original numba kernel work and are fixed here:

  • GridScaleCondensation: precipitation sign was inverted (p_int[k] - p_int[k+1] should be p_int[k+1] - p_int[k])
  • BergerSolarInsolation: eccentricity formula used zeta coefficients (from CAM 3.0) as degrees when they are radians, producing wrong orbital parameters
  • GrayLongwaveRadiation: flux arrays were float64 but cached regression fixtures stored float32 (from the original implementation); moved to float64 throughout and regenerated fixtures
  • TestSlabSurface: restored test overrides (present on develop, missing from this branch) that set surface_material_density = sea_water_density before stepping tests

Test results

  • pytest tests/test_components.py — 173 passed
  • pytest tests/test_berger_insolation_optimization.py tests/test_dry_convection_optimization.py tests/test_gray_radiation_optimization.py tests/test_gsc_optimization.py tests/test_held_suarez_optimization.py tests/test_instellation_optimization.py tests/test_slab_surface_optimization.py — 7 passed

🤖 Generated with Claude Code

Joy Monteiro and others added 21 commits February 11, 2026 22:43
This commit introduces a high-performance, backend-agnostic implementation of the
Emanuel convection scheme and a reusable abstraction layer for future components.

Key Changes:
- Added : A universal layer for JAX/NumPy detection,
  JIT orchestration, and functional state updates using Numba overloads.
- Added : A pure functional,
  JIT-optimized implementation of the Emanuel scheme achieving ~118x speedup
  over the original Python version and ~75x over Fortran.
- Refactored  to utilize the new V2 implementation.
- Added comprehensive verification and benchmarking suites:
    - : Formal parity regression tests.
    - : side-by-side comparison of active tendencies.
    - : Performance measurement script.
- Added : Detailed technical report on architectural
  decisions and optimization results.

The new implementation maintains strict numerical parity (1e-12) with the original
Python code while providing a clear path toward full differentiability via JAX.
Key Changes:
- Introduced : A JAX-JIT optimized implementation using  and .
- Enhanced  with JAX-specific optimizations and vectorization.
- Refactored  to be backend-agnostic.
- Achieved perfect bit-wise parity between JAX x64, Numba, and Original implementations.
- Updated benchmarking and verification suites to support JAX performance analysis.
- Updated  with JAX-specific performance and architectural details.
…vection

Key Changes:
- Implemented  in  with Pytree-registered containers.
- Registered  in  top-level init for use with .
- Finalized  with  and  for GPU/XLA optimization.
- Refactored  in  to be backend-agnostic.
- Verified ~550x speedup with Numba CPU and enabled JAX Metal (GPU) path.
- Added  to verify gradient computation.
- Updated  with final performance benchmarks on Apple Silicon.
Key Changes:
- Refactored  into a dual-path JIT implementation (Numba/JAX).
- Verified ~3x GPU speedup at 100k columns on Apple M3 Pro.
- Fixed  to return raw JAX arrays for differentiable tracing.
- Added  for parity and grad verification.
- Added  for multi-backend scaling analysis.
- Implemented  placeholder in  for sympl compatibility.
- Remove parallel=True from @njit/@jit_compile across components to fix potential segfaults.
- Fix DryConvectiveAdjustment to use Numba-compatible loops and ensure differentiability.
- Update SlabSurface to correctly index vertical dimensions for surface extraction and fix TracerArrayConversionError in JAX.
- Enhance EmanuelConvectionPythonV3 with transpose bug fixes for diverse input shapes.
- Improve JaxBackend and JaxStateContainer to handle non-numeric data and ensure better sympl integration.
- Update test_emanuel_optimization.py to support JAX-based state containers.
- Add comprehensive benchmarks and optimization tests for DryConvectiveAdjustment, GridScaleCondensation, HeldSuarez, Instellation, and SlabSurface.
- Verify 13.8x speedup in Grey GCM benchmark using Numba-optimized path.
- GridScaleCondensation: fix precipitation sign (p_int[k] - p_int[k+1]
  was negative; reversed to p_int[k+1] - p_int[k])
- BergerSolarInsolation: fix eccentricity formula — zeta coefficients
  from CAM 3.0 are in radians, not degrees; remove the erroneous
  * pi/180 conversion that was applied to the whole argument
- GrayLongwaveRadiation: move tendency calculation out of @njit kernel
  so it runs in numpy (float64 throughout); regenerate cached fixtures
  which previously stored float32 fluxes from the original pure-Python
  implementation
- tests/test_components.py: restore TestSlabSurface overrides that set
  surface_material_density = sea_water_density before the stepping
  tests; these were present on develop but missing from this branch

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… plans

JAX support will live on a separate `differentiable-climt` branch.
This commit cleans up the numba-only branch:

- Remove climt/_core/jax_backend.py and all JAX exports from
  __init__.py and _core/__init__.py
- Remove JAX_KERNEL_FUSION.md, PROPOSAL_JAX_BACKEND.md,
  plans/PROPOSAL_JAX_BACKEND.md, tests/test_jax_differentiation.py
- Add numba @njit kernels for: DryConvectiveAdjustment, EmanuelV2/V3,
  HeldSuarez, Instellation, SlabSurface (already committed for GSC,
  Berger, GrayLW in previous commit)
- Add optimization parity tests for all seven components
- Add backend.py jit_compile helper and prange fallback
- Add implementation plans for numba PR and differentiable branch

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…nents

Reports Python vs Numba timing for 8192 columns × 30 levels using
._pyfunc (py_func) on each @njit kernel. Results: 32x–124x speedups
across HeldSuarez, GrayLW, Frierson tau, GSC, DryConv, Berger,
SlabSurface, and Instellation components.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…stabilization

- Implement Numba optimizations for EmanuelConvectionPythonV3.
- Add comprehensive benchmarks for Numba vs. Backend overhead (benchmark_numba_x_backend.py).
- Add UnytBackend support to Emanuel V3 for 3.5x speedup over DataArrayBackend.
- Stabilize test suite by adding constant reset to pytest conftest.py.
- Add Emanuel V3 parity tests and record benchmark results.
@JoyMonteiro JoyMonteiro merged commit 61e3107 into develop Mar 30, 2026
6 checks passed
@JoyMonteiro JoyMonteiro deleted the numba-optimized-components branch April 4, 2026 07:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant