Merged
25 commits
c712283
adding all relevant files for emanuel convection
Feb 11, 2026
ab97148
updating
Feb 24, 2026
7e59a0c
adding files
Feb 24, 2026
fb5e0fa
fix(emanuel): resolve parity discrepancies with fortran implementation
JoyMonteiro Feb 25, 2026
430573f
feat(emanuel): add pure python sympl component and example notebook
JoyMonteiro Feb 25, 2026
b07c6c8
initial python port of emanuel convection is done!
JoyMonteiro Feb 26, 2026
57b3e16
Optimize Emanuel Convection and introduce universal backend abstraction
JoyMonteiro Feb 26, 2026
ef263be
Enable JAX support and differentiable path for Emanuel Convection
JoyMonteiro Feb 26, 2026
370c023
Implement JaxBackend and finalize differentiable path for Emanuel Con…
JoyMonteiro Feb 26, 2026
a9e4486
Optimize HeldSuarez and enhance JaxBackend for differentiation
JoyMonteiro Feb 26, 2026
530e725
Optimize components for Numba/JAX and fix array backend compatibility
JoyMonteiro Feb 26, 2026
245ef8f
adding a few things
JoyMonteiro Feb 27, 2026
95015a0
fix: correct three numba kernel bugs and restore SlabSurface test
JoyMonteiro Mar 17, 2026
7700802
chore: remove JAX backend; add numba optimizations and implementation…
JoyMonteiro Mar 17, 2026
300a7d6
Add benchmark script measuring Numba speedups for all optimized compo…
JoyMonteiro Mar 17, 2026
e55614e
feat: add Emanuel V3 optimizations, UnytBackend benchmarks, and test …
JoyMonteiro Mar 30, 2026
39ae471
Merge branch 'develop' into numba-optimized-components
JoyMonteiro Mar 30, 2026
6b58f90
fixed linter errors
JoyMonteiro Mar 30, 2026
ffda176
Merge branch 'numba-optimized-components' of https://github.com/CliMT…
JoyMonteiro Mar 30, 2026
0ee428c
fix test fail issues
JoyMonteiro Mar 30, 2026
1b11675
remove support for py39
JoyMonteiro Mar 30, 2026
61e3107
Merge pull request #205 from CliMT/numba-optimized-components
JoyMonteiro Mar 30, 2026
03a6d0c
Bump version: 0.18.5 → 0.19.0
JoyMonteiro Mar 30, 2026
4b82810
Bump version: 0.19.0 → 0.20.0
JoyMonteiro Mar 30, 2026
6adbb5e
Merge branch 'develop' of https://github.com/CliMT/climt into develop
JoyMonteiro Mar 30, 2026
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -15,7 +15,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest]
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
4 changes: 2 additions & 2 deletions .github/workflows/release_climt.yml
@@ -31,7 +31,7 @@ jobs:

       - name: Build on Linux
         env:
-          CIBW_BUILD: cp39-* cp310-* cp311-* cp312-*
+          CIBW_BUILD: cp310-* cp311-* cp312-*
           CIBW_SKIP: "*-musllinux_*"
           CIBW_ARCHS_LINUX: "x86_64"
           CIBW_ENVIRONMENT: "CC=gcc FC=gfortran CLIMT_ARCH=Linux"
@@ -40,7 +40,7 @@ jobs:

       - name: Build on macOS
         env:
-          CIBW_BUILD: cp39-* cp310-* cp311-* cp312-*
+          CIBW_BUILD: cp310-* cp311-* cp312-*
           CIBW_ARCHS_MACOS: "arm64"
           CIBW_ENVIRONMENT: "CLIMT_ARCH=Darwin MACOSX_DEPLOYMENT_TARGET=15.0"
         if: ${{ runner.os == 'macOS' }}
41 changes: 41 additions & 0 deletions EMANUEL_OPTIMIZATION.md
@@ -0,0 +1,41 @@
# Emanuel Convection Optimization Report

This document details the refactoring and optimization of the Emanuel convection scheme in `climt`. The primary goals were to achieve near-native performance, ensure compatibility with any array backend (NumPy, Unyt, JAX), and provide a path toward a fully differentiable model.

## 1. Architectural Changes

### Core Abstraction Layer (`climt/_core/backend.py`)
A universal abstraction layer was introduced to handle backend-specific operations transparently:
* **`get_array_namespace(*arrays)`**: Dynamically identifies the correct API (`numpy` or `jax.numpy`) based on input data.
* **`set_item(arr, idx, val)`**: A unified interface for updating arrays. It uses Numba `overload` for high-performance in-place updates on NumPy, and functional updates (`at[idx].set(val)`) for JAX.
* **`JaxBackend`**: A new `sympl` StateBackend that allows the entire model to run on JAX arrays, enabling differentiability through the `sympl` interface.
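
The dispatch described in the list above can be sketched as follows. This is a minimal illustration, not the code in `climt/_core/backend.py`: the function bodies are assumptions, the JAX detection by module name is a guess at one workable approach, and the Numba `overload` registration is left out.

```python
import importlib

import numpy as np


def get_array_namespace(*arrays):
    """Return `jax.numpy` if any input looks like a JAX array, else `numpy`.

    Assumption: JAX arrays are recognised by the module name of their type,
    so JAX is only imported when a JAX array is actually passed in.
    """
    for arr in arrays:
        module = type(arr).__module__
        if module == "jax" or module.startswith(("jax.", "jaxlib")):
            return importlib.import_module("jax.numpy")
    return np


def set_item(arr, idx, val):
    """Write `val` at `idx` and return the updated array.

    NumPy arrays (and subclasses) are mutated in place; anything else is
    assumed to expose the JAX-style functional update `arr.at[idx].set(val)`.
    """
    if isinstance(arr, np.ndarray):
        arr[idx] = val
        return arr
    return arr.at[idx].set(val)


temperature = np.full(4, 290.0)
xp = get_array_namespace(temperature)
updated = set_item(temperature, 0, 300.0)
print(xp.__name__, updated)
```

Because the functional path returns a new array, callers in this sketch always rebind the result (`arr = set_item(arr, idx, val)`), which is harmless on the in-place NumPy path.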

### Functional Refactoring Versions
* **V2 (`pure_python_v2.py`)**: Optimized for Numba JIT on NumPy/Unyt. Uses standard Python control flow. Highly performant on CPU.
* **V3 (`pure_python_v3.py`)**: Dual-path implementation optimized for both Numba and JAX XLA. Replaces vertical loops with `jax.lax.scan` and branches with `jnp.where` for the JAX path.
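
To illustrate the loop-to-scan rewrite that V3 is described as using, here is a toy vertical recurrence written both ways. It is not Emanuel physics: the "parcel cooling" rule, the function names, and the `scan` helper are all hypothetical. `scan` is a NumPy stand-in that mimics the `(carry, x) -> (carry, y)` contract of `jax.lax.scan` so the example runs without JAX; on the JAX path one would swap in `jax.lax.scan` and `jnp.where`.

```python
import numpy as np


def scan(step, carry, xs):
    """Hypothetical stand-in for `jax.lax.scan(step, carry, xs)`.

    Applies `step(carry, x) -> (carry, y)` along the leading axis of `xs`
    and stacks the per-step outputs.
    """
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def excess_loop(t_env, t_parcel, cooling):
    """Loop-and-branch style: explicit loops over columns and levels."""
    nlev, ncol = t_env.shape
    out = np.zeros_like(t_env)
    for i in range(ncol):
        parcel = t_parcel[i]
        for k in range(nlev):
            parcel = parcel - cooling
            if parcel > t_env[k, i]:  # data-dependent branch
                out[k, i] = parcel - t_env[k, i]
    return out


def excess_scan(t_env, t_parcel, cooling):
    """Scan style: one step function per level, branch replaced by `where`."""

    def step(parcel, env_k):
        parcel = parcel - cooling
        excess = np.where(parcel > env_k, parcel - env_k, 0.0)
        return parcel, excess

    _, out = scan(step, t_parcel, t_env)
    return out


rng = np.random.default_rng(0)
t_env = 280.0 + rng.normal(scale=3.0, size=(30, 8))  # (levels, columns)
t_parcel = np.full(8, 285.0)

loop_result = excess_loop(t_env, t_parcel, cooling=0.2)
scan_result = excess_scan(t_env, t_parcel, cooling=0.2)
print(np.allclose(loop_result, scan_result))
```

The scan form carries all columns through each level at once, which is what lets a tracing compiler see a fixed-shape step with no Python-level branching.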

## 2. Optimization Strategy & Performance

### Multi-Backend Performance (1000 Columns, 30 Levels, Apple M3 Pro)
| Backend | Platform | Time per Call (s) | Speedup vs. Original Python |
| :--- | :--- | :--- | :--- |
| Original Python | CPU (Serial) | 0.6607 | 1.0x |
| Fortran | CPU (Serial) | 0.0039 | ~170x |
| **V3 Numba JIT** | **CPU (Parallel)** | **0.0012** | **~550x** |
| **V3 JAX JIT** | **METAL (GPU)** | **0.2660** | **~2.5x** |
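
The speedup column follows directly from the timings: each entry is the Original Python time divided by that backend's time. A short check using only the numbers in the table (the dictionary and variable names are illustrative):

```python
# Time per call in seconds, copied from the table above.
timings_s = {
    "Original Python": 0.6607,
    "Fortran": 0.0039,
    "V3 Numba JIT": 0.0012,
    "V3 JAX JIT": 0.2660,
}

baseline = timings_s["Original Python"]
speedups = {name: baseline / t for name, t in timings_s.items()}

for name, factor in speedups.items():
    print(f"{name:16s} {factor:7.1f}x")
```

With these inputs the ratios come out at roughly 169x for Fortran, 551x for Numba, and 2.5x for JAX on Metal, matching the rounded figures in the table.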

**Note on GPU Performance**: The Emanuel scheme is dominated by branching logic, so the JAX path compiles to many small kernels. On Apple Silicon, the dispatch latency of these kernels currently makes the CPU/Numba path far faster than the GPU path for this algorithm (0.0012 s versus 0.2660 s per call in the table above, roughly 220x). The advantage of the JAX path is differentiability, not speed.

## 3. Verification and Differentiability

### Continuous Parity Check
* **Tolerance**: A strict tolerance of `1e-12` is enforced between the outputs of the Numba path and the original Python implementation.
* **JAX x64**: Precision was verified using `jax_enable_x64=True` on CPU, achieving perfect bit-wise parity. Note that METAL GPU currently only supports `float32`.
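
A parity check of this kind could look like the sketch below. It assumes the `1e-12` figure is an absolute tolerance (the report does not say whether it is absolute or relative), and `assert_parity` and the sample arrays are hypothetical, not taken from the test suite.

```python
import numpy as np

PARITY_ATOL = 1e-12  # tolerance quoted above; treated as absolute (assumption)


def assert_parity(reference, candidate, atol=PARITY_ATOL):
    """Raise AssertionError if any named output differs by more than `atol`."""
    for name, ref in reference.items():
        np.testing.assert_allclose(
            candidate[name], ref, rtol=0.0, atol=atol, err_msg=name
        )


# Made-up outputs standing in for the two implementations.
reference = {"air_temperature": np.array([1.0, 2.0, 3.0])}
within_tol = {"air_temperature": reference["air_temperature"] + 1e-13}
outside_tol = {"air_temperature": reference["air_temperature"] + 1e-9}

assert_parity(reference, within_tol)  # passes silently

try:
    assert_parity(reference, outside_tol)
    parity_failed = False
except AssertionError:
    parity_failed = True
print(parity_failed)
```

Setting `rtol=0.0` keeps the check purely absolute, so the threshold does not loosen for large-magnitude fields such as temperature.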

### Differentiability
The `V3` implementation is fully differentiable. Sensitivities can be computed using `jax.grad` through the `array_call` interface, providing a foundation for parameter optimization and sensitivity analysis.

## 4. Summary of New Components
* `climt/_core/jax_backend.py`: JAX-native state handling.
* `climt/_components/emanuel/pure_python_v3.py`: High-performance differentiable physics.
* `tests/test_jax_differentiation.py`: Verification of gradient flow.
69 changes: 0 additions & 69 deletions PROPOSAL_JAX_BACKEND.md

This file was deleted.

50 changes: 50 additions & 0 deletions benchmarks/benchmark_berger_insolation.py
@@ -0,0 +1,50 @@
import time
import numpy as np
import climt
from climt import BergerSolarInsolation, get_grid, get_default_state
import sympl
from datetime import datetime

def benchmark_berger_insolation(ncol=8192, iterations=100):
    grid = get_grid(nx=ncol, ny=1, nz=20)
    insolation = BergerSolarInsolation()

    # NumPy path
    sympl.set_backend(sympl.DataArrayBackend())
    state = get_default_state([insolation], grid_state=grid)
    state['time'] = datetime(2026, 6, 21, 12, 0)  # Summer solstice

    # Run once
    insolation(state)

    print(f"Benchmarking BergerSolarInsolation with {ncol} columns, {iterations} iterations")

    start = time.perf_counter()
    for _ in range(iterations):
        insolation(state)
    end = time.perf_counter()
    print(f"Berger Insolation Cython/NumPy: {end - start:.4f}s")

    # JAX path
    try:
        from climt import JaxBackend
        import jax
        jax.config.update("jax_enable_x64", True)
        jax.config.update('jax_platform_name', 'cpu')
        sympl.set_backend(JaxBackend())
        state_jax = get_default_state([insolation], grid_state=grid)
        state_jax['time'] = datetime(2026, 6, 21, 12, 0)

        # Run once to JIT
        insolation(state_jax)

        start = time.perf_counter()
        for _ in range(iterations):
            insolation(state_jax)
        end = time.perf_counter()
        print(f"Berger Insolation JAX: {end - start:.4f}s")
    except (ImportError, Exception) as e:
        print(f"JAX path failed or not available: {e}")

if __name__ == "__main__":
    benchmark_berger_insolation()
60 changes: 60 additions & 0 deletions benchmarks/benchmark_dry_convection.py
@@ -0,0 +1,60 @@
import time
import numpy as np
import climt
from climt import DryConvectiveAdjustment, get_grid, get_default_state
import sympl
from datetime import timedelta

def benchmark_dry_convection(ncol=1024, nlev=30, iterations=10):
    grid = get_grid(nx=ncol, ny=1, nz=nlev)
    dc = DryConvectiveAdjustment()

    # NumPy path
    sympl.set_backend(sympl.DataArrayBackend())
    state = get_default_state([dc], grid_state=grid)

    # Create an unstable profile
    # Temperature decreasing too fast with height (super-adiabatic)
    unstable_level = 5
    state['air_temperature'].values[:unstable_level, :, :] += 20.0
    state['specific_humidity'].values[:unstable_level, :, :] = 0.02

    timestep = timedelta(minutes=10)

    # Run once
    dc(state, timestep)

    print(f"Benchmarking DryConvectiveAdjustment with {ncol} columns, {iterations} iterations")

    start = time.perf_counter()
    for _ in range(iterations):
        dc(state, timestep)
    end = time.perf_counter()
    print(f"Dry Convection NumPy: {end - start:.4f}s")

    # JAX path
    try:
        from climt import JaxBackend
        import jax
        jax.config.update("jax_enable_x64", True)
        jax.config.update('jax_platform_name', 'cpu')
        sympl.set_backend(JaxBackend())
        state_jax = get_default_state([dc], grid_state=grid)

        # Unstable profile for JAX
        state_jax['air_temperature'].data = state_jax['air_temperature'].data.at[:unstable_level].add(20.0)
        state_jax['specific_humidity'].data = state_jax['specific_humidity'].data.at[:unstable_level].set(0.02)

        # Run once to JIT
        dc(state_jax, timestep)

        start = time.perf_counter()
        for _ in range(iterations):
            dc(state_jax, timestep)
        end = time.perf_counter()
        print(f"Dry Convection JAX: {end - start:.4f}s")
    except (ImportError, Exception) as e:
        print(f"JAX path failed or not available: {e}")

if __name__ == "__main__":
    benchmark_dry_convection()
58 changes: 58 additions & 0 deletions benchmarks/benchmark_emanuel_python.py
@@ -0,0 +1,58 @@
import time
from datetime import timedelta

import numpy as np
import sympl
from sympl._core.backend import DataArrayBackend

import climt
from climt import (
    EmanuelConvection,
    EmanuelConvectionPythonV3,
    get_default_state,
    get_grid,
)


def run_benchmark():
    print("Benchmarking Emanuel Convection: Fortran vs Pure Python...")

    sympl.set_backend(DataArrayBackend())
    grid = get_grid(nx=32, ny=32, nz=30)

    # Initialize components
    conv_fortran = EmanuelConvection()
    conv_python = EmanuelConvectionPythonV3()

    # Initial state
    state = get_default_state([conv_fortran], grid_state=grid)
    # Ensure specific humidity is reasonable for convection
    state["specific_humidity"].values[:] = 0.04
    state["air_temperature"].values[:] = 290.0

    timestep = timedelta(minutes=20)

    # Run Fortran version
    start = time.perf_counter()
    t_fort, d_fort = conv_fortran(state, timestep)
    dur_fort = time.perf_counter() - start
    print(f"  Fortran Duration: {dur_fort:.4f} s")

    # Run Python version
    start = time.perf_counter()
    t_py, d_py = conv_python(state, timestep)
    dur_py = time.perf_counter() - start
    print(f"  Python Duration: {dur_py:.4f} s")

    # Compare outputs
    print("\nVerifying Outputs...")
    t_vars = ["air_temperature", "specific_humidity"]
    for var in t_vars:
        diff = np.abs(t_fort[var].values - t_py[var].values)
        print(f"  {var} max diff: {np.max(diff):.2e}")

    print(f"\nSpeedup (Fortran/Python): {dur_fort / dur_py:.2f}x")


if __name__ == "__main__":
    run_benchmark()