Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion climt/_components/slab_surface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sympl import TendencyComponent, initialize_numpy_arrays_with_properties
import numpy as np
from sympl import TendencyComponent, initialize_numpy_arrays_with_properties


class SlabSurface(TendencyComponent):
Expand Down
2 changes: 1 addition & 1 deletion climt/_core/initialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from datetime import datetime

import numpy as np
Expand All @@ -9,7 +10,6 @@
get_constant,
set_constant,
)
import sys

if sys.version_info < (3, 9):
import importlib_resources
Expand Down
89 changes: 86 additions & 3 deletions climt/_core/unyt_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,33 @@ class UnytStateContainer:
dims (tuple of str): The names of the dimensions.
"""

def __init__(self, data, dims):
if not isinstance(data, (unyt.unyt_array, np.ndarray)):
def __init__(self, data, dims=None, attrs=None):
if not isinstance(
data, (unyt.unyt_array, np.ndarray, float, int, np.floating, np.integer)
):
# Ideally strict, but helpful to be flexible if easy.
# For now, strict to match design.
raise TypeError(
f"Data must be a unyt.unyt_array or numpy.ndarray, got {type(data)}"
)

if (
attrs is not None
and "units" in attrs
and not isinstance(data, unyt.unyt_array)
):
try:
sanitized_units = attrs["units"].replace("^", "**").replace(" ", "*")
data = unyt.unyt_array(data, sanitized_units)
except Exception:
pass

self.data = data
self.dims = tuple(dims)
self.dims = tuple(dims) if dims is not None else ()

def rename(self, name_dict):
new_dims = [name_dict.get(d, d) for d in self.dims]
return UnytStateContainer(self.data, new_dims)

def __repr__(self):
return f"UnytStateContainer(data={self.data}, dims={self.dims})"
Expand All @@ -123,6 +141,68 @@ def attrs(self):
return {"units": str(self.data.units)}
return {}

class _LocIndexer:
def __init__(self, container):
self.container = container

def __getitem__(self, key):
if isinstance(key, dict):
slices = []
for dim in self.container.dims:
if dim in key:
slices.append(key[dim])
else:
slices.append(slice(None))
return self.container[tuple(slices)]
return self.container[key]

def __setitem__(self, key, value):
if isinstance(key, dict):
slices = []
for dim in self.container.dims:
if dim in key:
slices.append(key[dim])
else:
slices.append(slice(None))
self.container[tuple(slices)] = value
else:
self.container[key] = value

@property
def loc(self):
return self._LocIndexer(self)

def transpose(self, *dims):
if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
dims = dims[0]
# map dim names to axis indices
perm = [self.dims.index(dim) for dim in dims]
return UnytStateContainer(self.data.transpose(perm), dims)

def __getitem__(self, key):
sliced_data = self.data[key]
if isinstance(sliced_data, (unyt.unyt_array, np.ndarray)):
if sliced_data.ndim == len(self.dims):
return UnytStateContainer(sliced_data, self.dims)
else:
# Provide dummy dimensions or just truncate if we cannot infer dropped dimension easily
# Truncate to match ndim for now
return UnytStateContainer(sliced_data, self.dims[: sliced_data.ndim])
return sliced_data

def __eq__(self, other):
if isinstance(other, UnytStateContainer):
return self.data == other.data
return self.data == other

def __setitem__(self, key, value):
if isinstance(value, UnytStateContainer):
self.data[key] = value.data
elif isinstance(value, unyt.unyt_array) and hasattr(self.data, "units"):
self.data[key] = value.to(self.data.units)
else:
self.data[key] = value

def to_units(self, units):
if not isinstance(self.data, unyt.unyt_array):
return self
Expand Down Expand Up @@ -398,3 +478,6 @@ def get_shape(self, state_value):
if not isinstance(state_value, UnytStateContainer):
raise TypeError(f"Expected UnytStateContainer, got {type(state_value)}")
return state_value.data.shape

def get_container_type(self):
return UnytStateContainer
9 changes: 5 additions & 4 deletions climt/_core/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sympl import jit, DataArray
import numpy as np
import functools

import numpy as np
from sympl import DataArray, jit


def ensure_contiguous_state(func):
@functools.wraps(func)
Expand Down Expand Up @@ -173,14 +174,14 @@ def calculate_q_sat(surf_temp, surf_press, Rd, Rv):
return eps * sat_vap_press / (surf_press - (1 - eps) * sat_vap_press)


@jit(nopython=True)
# @jit(nopython=True)
def bolton_q_sat(T, p, Rd, Rh2O):
es = 611.2 * np.exp(17.67 * (T - 273.15) / (T - 29.65))
epsilon = Rd / Rh2O
return epsilon * es / (p - (1 - epsilon) * es)


@jit(nopython=True)
# @jit(nopython=True)
def bolton_dqsat_dT(T, Lv, Rh2O, q_sat):
"""Uses the assumptions of equation 12 in Reed and Jablonowski, 2012. In
particular, assumes d(qsat)/dT is approximately epsilon/p*d(es)/dT"""
Expand Down
Binary file removed rad_conv_eq_unyt.nc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Loading