Adding new numba optimizations #206

Merged
JoyMonteiro merged 25 commits into master from develop on Mar 30, 2026

@JoyMonteiro
Member

No description provided.

Joy Monteiro and others added 25 commits February 11, 2026 22:43
This commit introduces a high-performance, backend-agnostic implementation of the
Emanuel convection scheme and a reusable abstraction layer for future components.

Key Changes:
- Added : A universal layer for JAX/NumPy detection,
  JIT orchestration, and functional state updates using Numba overloads.
- Added : A pure functional,
  JIT-optimized implementation of the Emanuel scheme achieving ~118x speedup
  over the original Python version and ~75x over Fortran.
- Refactored  to utilize the new V2 implementation.
- Added comprehensive verification and benchmarking suites:
    - : Formal parity regression tests.
    - : side-by-side comparison of active tendencies.
    - : Performance measurement script.
- Added : Detailed technical report on architectural
  decisions and optimization results.

The new implementation maintains strict numerical parity (1e-12) with the original
Python code while providing a clear path toward full differentiability via JAX.
Key Changes:
- Introduced : A JAX-JIT optimized implementation using  and .
- Enhanced  with JAX-specific optimizations and vectorization.
- Refactored  to be backend-agnostic.
- Achieved perfect bit-wise parity between JAX x64, Numba, and Original implementations.
- Updated benchmarking and verification suites to support JAX performance analysis.
- Updated  with JAX-specific performance and architectural details.
…vection

Key Changes:
- Implemented  in  with Pytree-registered containers.
- Registered  in  top-level init for use with .
- Finalized  with  and  for GPU/XLA optimization.
- Refactored  in  to be backend-agnostic.
- Verified ~550x speedup with Numba CPU and enabled JAX Metal (GPU) path.
- Added  to verify gradient computation.
- Updated  with final performance benchmarks on Apple Silicon.
Key Changes:
- Refactored  into a dual-path JIT implementation (Numba/JAX).
- Verified ~3x GPU speedup at 100k columns on Apple M3 Pro.
- Fixed  to return raw JAX arrays for differentiable tracing.
- Added  for parity and grad verification.
- Added  for multi-backend scaling analysis.
- Implemented  placeholder in  for sympl compatibility.
- Remove parallel=True from @njit/@jit_compile across components to fix potential segfaults.
- Fix DryConvectiveAdjustment to use Numba-compatible loops and ensure differentiability.
- Update SlabSurface to correctly index vertical dimensions for surface extraction and fix TracerArrayConversionError in JAX.
- Enhance EmanuelConvectionPythonV3 with transpose bug fixes for diverse input shapes.
- Improve JaxBackend and JaxStateContainer to handle non-numeric data and ensure better sympl integration.
- Update test_emanuel_optimization.py to support JAX-based state containers.
- Add comprehensive benchmarks and optimization tests for DryConvectiveAdjustment, GridScaleCondensation, HeldSuarez, Instellation, and SlabSurface.
- Verify 13.8x speedup in Grey GCM benchmark using Numba-optimized path.
- GridScaleCondensation: fix precipitation sign (p_int[k] - p_int[k+1]
  was negative; reversed to p_int[k+1] - p_int[k])
- BergerSolarInsolation: fix eccentricity formula — zeta coefficients
  from CAM 3.0 are in radians, not degrees; remove the erroneous
  * pi/180 conversion that was applied to the whole argument
- GrayLongwaveRadiation: move tendency calculation out of @njit kernel
  so it runs in numpy (float64 throughout); regenerate cached fixtures
  which previously stored float32 fluxes from the original pure-Python
  implementation
- tests/test_components.py: restore TestSlabSurface overrides that set
  surface_material_density = sea_water_density before the stepping
  tests; these were present on develop but missing from this branch
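
The GridScaleCondensation sign fix above can be illustrated with a small NumPy sketch. It assumes interface pressures increase with index `k`, which is what the commit message implies; the array values, `GRAVITY`, and `layer_mass` are illustrative and not taken from the component.

```python
import numpy as np

GRAVITY = 9.80665  # m s^-2, standard gravity (illustrative constant)

# Assumption: interface pressures (Pa) increase with index k, as the commit implies
p_int = np.array([10000.0, 40000.0, 70000.0, 100000.0])

dp_old = p_int[:-1] - p_int[1:]  # p_int[k] - p_int[k+1]: negative under this ordering
dp_new = p_int[1:] - p_int[:-1]  # p_int[k+1] - p_int[k]: positive layer thickness

# Mass of air per unit area in each layer (kg m^-2); positive only with dp_new
layer_mass = dp_new / GRAVITY
print(dp_old, dp_new, layer_mass)
```

Any quantity accumulated from `dp / g`, such as column precipitation, inherits the sign of `dp`, so a negative thickness flips the sign of the result.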

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… plans

JAX support will live on a separate `differentiable-climt` branch.
This commit cleans up the numba-only branch:

- Remove climt/_core/jax_backend.py and all JAX exports from
  __init__.py and _core/__init__.py
- Remove JAX_KERNEL_FUSION.md, PROPOSAL_JAX_BACKEND.md,
  plans/PROPOSAL_JAX_BACKEND.md, tests/test_jax_differentiation.py
- Add numba @njit kernels for: DryConvectiveAdjustment, EmanuelV2/V3,
  HeldSuarez, Instellation, SlabSurface (already committed for GSC,
  Berger, GrayLW in previous commit)
- Add optimization parity tests for all seven components
- Add backend.py jit_compile helper and prange fallback
- Add implementation plans for numba PR and differentiable branch

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
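
A minimal sketch of what an optional-Numba `jit_compile` helper with a `prange` fallback might look like, for readers unfamiliar with the pattern. This is not the code in `backend.py`: the signature, the `HAS_NUMBA` flag, and the `column_mean` kernel are assumptions made for illustration.

```python
import numpy as np

try:
    from numba import njit, prange
    HAS_NUMBA = True
except ImportError:  # Numba not installed: run everything as plain Python
    HAS_NUMBA = False
    prange = range  # serial stand-in so kernels can still be written with prange


def jit_compile(func=None, **options):
    """Apply numba.njit when Numba is available, otherwise return func unchanged."""
    def decorate(f):
        return njit(**options)(f) if HAS_NUMBA else f
    # Support both bare @jit_compile and @jit_compile(cache=True) usage
    return decorate(func) if func is not None else decorate


@jit_compile
def column_mean(field):
    """Mean over the level axis of a (ncol, nlev) array (hypothetical kernel)."""
    ncol, nlev = field.shape
    out = np.empty(ncol)
    for i in prange(ncol):  # behaves like range unless parallel=True is requested
        total = 0.0
        for k in range(nlev):
            total += field[i, k]
        out[i] = total / nlev
    return out


field = np.arange(12.0).reshape(3, 4)
print(column_mean(field))
```

With this shape, components import only `jit_compile` and `prange`, and the same kernel source runs with or without Numba installed.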
…nents

Reports Python vs Numba timing for 8192 columns × 30 levels using
._pyfunc (py_func) on each @njit kernel. Results: 32x–124x speedups
across HeldSuarez, GrayLW, Frierson tau, GSC, DryConv, Berger,
SlabSurface, and Instellation components.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
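
The timing approach described above can be sketched as follows. Numba dispatchers expose the undecorated function as `py_func`, so the same kernel can be timed with and without compilation. The `relax_toward` kernel and `best_time` helper are hypothetical stand-ins, not the benchmark script from this PR, and the measured ratio will vary by machine.

```python
import time
import numpy as np

try:
    from numba import njit
except ImportError:
    def njit(func):  # assumption: no-op stand-in when Numba is absent
        return func


@njit
def relax_toward(temp, temp_eq, rate):
    """Newtonian relaxation tendency on a (ncol, nlev) array (toy kernel)."""
    ncol, nlev = temp.shape
    out = np.empty_like(temp)
    for i in range(ncol):
        for k in range(nlev):
            out[i, k] = -rate * (temp[i, k] - temp_eq[i, k])
    return out


def best_time(func, *args, repeats=3):
    """Return the fastest wall-clock time over a few repeated calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best


# Same problem size as quoted in the commit message: 8192 columns x 30 levels
temp = np.random.default_rng(0).uniform(200.0, 300.0, size=(8192, 30))
temp_eq = np.full_like(temp, 250.0)

# py_func is the original Python function; fall back to the kernel itself without Numba
python_kernel = getattr(relax_toward, "py_func", relax_toward)
relax_toward(temp, temp_eq, 1e-5)  # warm-up call so compilation is not timed

t_python = best_time(python_kernel, temp, temp_eq, 1e-5)
t_numba = best_time(relax_toward, temp, temp_eq, 1e-5)
print(f"python: {t_python:.4f} s, numba: {t_numba:.4f} s, ratio: {t_python / t_numba:.1f}x")
```

The warm-up call matters: the first invocation of an `@njit` function includes compilation, which would otherwise dominate the measured time.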
…stabilization

- Implement Numba optimizations for EmanuelConvectionPythonV3.
- Add comprehensive benchmarks for Numba vs. Backend overhead (benchmark_numba_x_backend.py).
- Add UnytBackend support to Emanuel V3 for 3.5x speedup over DataArrayBackend.
- Stabilize test suite by adding constant reset to pytest conftest.py.
- Add Emanuel V3 parity tests and record benchmark results.
Add optional Numba JIT optimization for pure-Python components
@JoyMonteiro JoyMonteiro merged commit 878a1ad into master Mar 30, 2026
0 of 10 checks passed