Add optional Numba JIT optimization for pure-Python components #205
Merged
JoyMonteiro merged 21 commits into develop on Mar 30, 2026
This commit introduces a high-performance, backend-agnostic implementation of the
Emanuel convection scheme and a reusable abstraction layer for future components.
Key Changes:
- Added : A universal layer for JAX/NumPy detection,
JIT orchestration, and functional state updates using Numba overloads.
- Added : A pure functional,
JIT-optimized implementation of the Emanuel scheme achieving ~118x speedup
over the original Python version and ~75x over Fortran.
- Refactored to utilize the new V2 implementation.
- Added comprehensive verification and benchmarking suites:
  - : Formal parity regression tests.
  - : Side-by-side comparison of active tendencies.
  - : Performance measurement script.
- Added : Detailed technical report on architectural
decisions and optimization results.
The new implementation maintains strict numerical parity (1e-12) with the original
Python code while providing a clear path toward full differentiability via JAX.
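
A parity check at the quoted 1e-12 tolerance could look roughly like the sketch below. The two kernels are hypothetical stand-ins (they are not the Emanuel scheme or any function from this PR); only the comparison pattern, an absolute tolerance with no relative slack, is the point.

```python
import numpy as np

ATOL = 1e-12  # parity tolerance quoted in the commit message


def reference_kernel(temperature, lapse):
    """Hypothetical slow reference: explicit Python loops."""
    out = np.empty_like(temperature)
    for i in range(temperature.shape[0]):
        for k in range(temperature.shape[1]):
            out[i, k] = temperature[i, k] - lapse * k
    return out


def optimized_kernel(temperature, lapse):
    """Hypothetical fast version: vectorized, same arithmetic."""
    levels = np.arange(temperature.shape[1], dtype=temperature.dtype)
    return temperature - lapse * levels


def check_parity(seed=0):
    """Compare both kernels on random input and return the largest difference."""
    rng = np.random.default_rng(seed)
    temperature = rng.uniform(200.0, 300.0, size=(16, 30))
    expected = reference_kernel(temperature, 6.5)
    actual = optimized_kernel(temperature, 6.5)
    # rtol=0 makes the check purely absolute, matching a stated 1e-12 bound.
    np.testing.assert_allclose(actual, expected, rtol=0.0, atol=ATOL)
    assert actual.dtype == expected.dtype  # catches float32/float64 drift
    return float(np.max(np.abs(actual - expected)))
```

Checking the dtype alongside the values is a cheap way to notice precision mismatches that a loose tolerance would hide.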
Key Changes:
- Introduced : A JAX-JIT optimized implementation using and .
- Enhanced with JAX-specific optimizations and vectorization.
- Refactored to be backend-agnostic.
- Achieved perfect bit-wise parity between JAX x64, Numba, and Original implementations.
- Updated benchmarking and verification suites to support JAX performance analysis.
- Updated with JAX-specific performance and architectural details.

…vection
Key Changes:
- Implemented in with Pytree-registered containers.
- Registered in top-level init for use with .
- Finalized with and for GPU/XLA optimization.
- Refactored in to be backend-agnostic.
- Verified ~550x speedup with Numba CPU and enabled JAX Metal (GPU) path.
- Added to verify gradient computation.
- Updated with final performance benchmarks on Apple Silicon.

Key Changes:
- Refactored into a dual-path JIT implementation (Numba/JAX).
- Verified ~3x GPU speedup at 100k columns on Apple M3 Pro.
- Fixed to return raw JAX arrays for differentiable tracing.
- Added for parity and grad verification.
- Added for multi-backend scaling analysis.
- Implemented placeholder in for sympl compatibility.

- Remove parallel=True from @njit/@jit_compile across components to fix potential segfaults.
- Fix DryConvectiveAdjustment to use Numba-compatible loops and ensure differentiability.
- Update SlabSurface to correctly index vertical dimensions for surface extraction and fix TracerArrayConversionError in JAX.
- Enhance EmanuelConvectionPythonV3 with transpose bug fixes for diverse input shapes.
- Improve JaxBackend and JaxStateContainer to handle non-numeric data and ensure better sympl integration.
- Update test_emanuel_optimization.py to support JAX-based state containers.
- Add comprehensive benchmarks and optimization tests for DryConvectiveAdjustment, GridScaleCondensation, HeldSuarez, Instellation, and SlabSurface.
- Verify 13.8x speedup in Grey GCM benchmark using Numba-optimized path.

- GridScaleCondensation: fix precipitation sign (p_int[k] - p_int[k+1] was negative; reversed to p_int[k+1] - p_int[k])
- BergerSolarInsolation: fix eccentricity formula: zeta coefficients from CAM 3.0 are in radians, not degrees; remove the erroneous * pi/180 conversion that was applied to the whole argument
- GrayLongwaveRadiation: move tendency calculation out of @njit kernel so it runs in numpy (float64 throughout); regenerate cached fixtures which previously stored float32 fluxes from the original pure-Python implementation
- tests/test_components.py: restore TestSlabSurface overrides that set surface_material_density = sea_water_density before the stepping tests; these were present on develop but missing from this branch
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

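
The precipitation sign fix comes down to the sign of the layer pressure thickness. The sketch below is illustrative only: it assumes, as the commit wording implies for this kernel, that interface pressure increases with index `k`, and the function name, argument names, and the mass-per-area formula `q * dp / g` are hypothetical rather than taken from the climt source.

```python
import numpy as np

GRAVITY = 9.80665  # m s^-2, standard gravity


def column_precipitation(p_int, condensed_q):
    """Condensate mass per unit area (kg m^-2) summed over one column.

    Assumption: p_int (Pa) increases with index k, so
    p_int[k + 1] - p_int[k] is a positive layer thickness.
    """
    total = 0.0
    for k in range(condensed_q.shape[0]):
        dp = p_int[k + 1] - p_int[k]  # positive under the stated ordering
        total += condensed_q[k] * dp / GRAVITY
    return total


p_int = np.array([20000.0, 50000.0, 100000.0])  # interface pressures, Pa
condensed_q = np.array([0.0, 1.0e-4])           # condensed specific humidity, kg/kg
precip = column_precipitation(p_int, condensed_q)
```

With the opposite vertical ordering the same expression yields a negative total, which is the kind of inverted sign the commit describes.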
… plans
JAX support will live on a separate `differentiable-climt` branch. This commit cleans up the numba-only branch:
- Remove climt/_core/jax_backend.py and all JAX exports from __init__.py and _core/__init__.py
- Remove JAX_KERNEL_FUSION.md, PROPOSAL_JAX_BACKEND.md, plans/PROPOSAL_JAX_BACKEND.md, tests/test_jax_differentiation.py
- Add numba @njit kernels for: DryConvectiveAdjustment, EmanuelV2/V3, HeldSuarez, Instellation, SlabSurface (already committed for GSC, Berger, GrayLW in previous commit)
- Add optimization parity tests for all seven components
- Add backend.py jit_compile helper and prange fallback
- Add implementation plans for numba PR and differentiable branch
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

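
An optional-dependency helper of the kind described (a `jit_compile` decorator plus a `prange` fallback) might be sketched as follows. This is not the code in `backend.py`; the option handling and the `column_means` kernel are assumptions made for illustration. It relies only on `numba.njit` and `numba.prange`, and on the documented behaviour that `prange` acts like `range` when `parallel=True` is not set.

```python
import numpy as np

try:
    from numba import njit, prange
    HAVE_NUMBA = True
except ImportError:  # Numba is optional
    HAVE_NUMBA = False
    prange = range  # serial fallback with the same loop semantics


def jit_compile(func=None, **options):
    """Apply numba.njit when Numba is installed; otherwise return func unchanged."""
    def decorate(f):
        return njit(**options)(f) if HAVE_NUMBA else f

    if func is None:       # used as @jit_compile(cache=True)
        return decorate
    return decorate(func)  # used as bare @jit_compile


@jit_compile
def column_means(field):
    """Hypothetical column kernel: mean over levels for each column."""
    n_cols, n_levels = field.shape
    out = np.empty(n_cols)
    for i in prange(n_cols):  # columns are independent of each other
        total = 0.0
        for k in range(n_levels):
            total += field[i, k]
        out[i] = total / n_levels
    return out
```

Because the fallback path is the undecorated function, the same kernel source runs with or without Numba, which is what makes the optimization optional.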
…nents
Reports Python vs Numba timing for 8192 columns × 30 levels using ._pyfunc (py_func) on each @njit kernel. Results: 32x–124x speedups across HeldSuarez, GrayLW, Frierson tau, GSC, DryConv, Berger, SlabSurface, and Instellation components.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

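
A timing comparison of this kind can be sketched as below. Numba dispatchers expose the undecorated function as `py_func`; everything else here (`python_version`, `best_time`, `speedup`, and the `scale_columns` kernel) is a hypothetical helper, not the benchmark script from this PR. The sketch falls back to the plain function when the kernel is not a Numba dispatcher, so it also runs without Numba installed.

```python
import time

import numpy as np


def python_version(kernel):
    """Return the plain function behind a Numba dispatcher, else the kernel itself."""
    return getattr(kernel, "py_func", kernel)


def best_time(fn, *args, repeats=3):
    """Best wall-clock time in seconds over a few repeats."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


def speedup(kernel, *args, repeats=3):
    """Ratio of pure-Python time to kernel time for the same arguments."""
    kernel(*args)  # warm-up call so JIT compilation is not timed
    slow = best_time(python_version(kernel), *args, repeats=repeats)
    fast = best_time(kernel, *args, repeats=repeats)
    return slow / max(fast, 1e-12)  # guard against a zero timer reading


def scale_columns(field):
    """Hypothetical kernel; decorate with numba.njit to measure a real speedup."""
    out = np.empty_like(field)
    for i in range(field.shape[0]):
        for k in range(field.shape[1]):
            out[i, k] = 2.0 * field[i, k]
    return out
```

Calling `speedup(scale_columns, np.ones((8192, 30)))` on an `@njit`-decorated version would mirror the 8192 columns × 30 levels setup mentioned above; on the undecorated function the ratio is close to 1.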
…stabilization
- Implement Numba optimizations for EmanuelConvectionPythonV3.
- Add comprehensive benchmarks for Numba vs. Backend overhead (benchmark_numba_x_backend.py).
- Add UnytBackend support to Emanuel V3 for 3.5x speedup over DataArrayBackend.
- Stabilize test suite by adding constant reset to pytest conftest.py.
- Add Emanuel V3 parity tests and record benchmark results.

…/climt into numba-optimized-components
Summary
- `@njit`-decorated kernels with column-parallel `prange()` loops for 7 pure-Python components: `GrayLongwaveRadiation`, `GridScaleCondensation`, `HeldSuarez`, `DryConvectiveAdjustment`, `BergerSolarInsolation`, `SlabSurface`, and `Instellation`
- `jit_compile()` helper in `climt/_core/backend.py` that applies `@njit` when Numba is available and is a no-op otherwise
- … `tests/test_*_optimization.py`)
- … `differentiable-climt` branch)

Bug fixes found during testing
Three bugs were introduced in the original Numba kernel work and are fixed here, along with a restored test override:
- `GridScaleCondensation`: precipitation sign was inverted (`p_int[k] - p_int[k+1]` should be `p_int[k+1] - p_int[k]`)
- `BergerSolarInsolation`: eccentricity formula used `zeta` coefficients (from CAM 3.0) as degrees when they are radians, producing wrong orbital parameters
- `GrayLongwaveRadiation`: flux arrays were float64 but cached regression fixtures stored float32 (from the original implementation); moved to float64 throughout and regenerated fixtures
- `TestSlabSurface`: restored test overrides (present on `develop`, missing from this branch) that set `surface_material_density = sea_water_density` before stepping tests

Test results
- `pytest tests/test_components.py`: 173 passed
- `pytest tests/test_berger_insolation_optimization.py tests/test_dry_convection_optimization.py tests/test_gray_radiation_optimization.py tests/test_gsc_optimization.py tests/test_held_suarez_optimization.py tests/test_instellation_optimization.py tests/test_slab_surface_optimization.py`: 7 passed

🤖 Generated with Claude Code