Skip to content
Open
8 changes: 8 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Releases


## Upcoming 0.9.7.post1

#### New features
The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this text to the new-features section of 0.9.7.dev0 — this is what we are currently working on. Also add a line to the itemized list with the PR number.




## 0.9.7.dev0

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
Expand Down
152 changes: 152 additions & 0 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,55 @@ def clip(self, a, a_min, a_max):
"""
raise NotImplementedError()

def real(self, a):
    """
    Return the real part of the tensor element-wise.

    This function follows the api from :any:`numpy.real`

    Parameters
    ----------
    a : tensor
        Input (possibly complex-valued) tensor.

    Returns
    -------
    tensor
        Real part of ``a``, with the same shape as ``a``.

    See: https://numpy.org/doc/stable/reference/generated/numpy.real.html
    """
    raise NotImplementedError()

def imag(self, a):
    """
    Return the imaginary part of the tensor element-wise.

    This function follows the api from :any:`numpy.imag`

    Parameters
    ----------
    a : tensor
        Input (possibly complex-valued) tensor.

    Returns
    -------
    tensor
        Imaginary part of ``a``, with the same shape as ``a``.

    See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html
    """
    raise NotImplementedError()

def conj(self, a):
    """
    Return the complex conjugate, element-wise.

    This function follows the api from :any:`numpy.conj`

    Parameters
    ----------
    a : tensor
        Input (possibly complex-valued) tensor.

    Returns
    -------
    tensor
        Element-wise complex conjugate of ``a``.

    See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html
    """
    raise NotImplementedError()

def arccos(self, a):
    """
    Trigonometric inverse cosine, element-wise.

    This function follows the api from :any:`numpy.arccos`

    Parameters
    ----------
    a : tensor
        Input tensor; values are expected in ``[-1, 1]`` for real results.

    Returns
    -------
    tensor
        Element-wise inverse cosine of ``a`` (angles in radians).

    See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html
    """
    raise NotImplementedError()

def astype(self, a, dtype):
    """
    Cast tensor to a given dtype.

    dtype can be a string (e.g. "complex128", "float64") or backend-specific
    dtype. Backend converts to the corresponding type.

    Parameters
    ----------
    a : tensor
        Input tensor.
    dtype : str or backend dtype
        Target dtype; strings use numpy-style names.

    Returns
    -------
    tensor
        ``a`` cast to the requested dtype.
    """
    raise NotImplementedError()

def repeat(self, a, repeats, axis=None):
r"""
Repeats elements of a tensor.
Expand Down Expand Up @@ -1294,6 +1343,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
    """Limit the values of ``a`` to the interval ``[a_min, a_max]``."""
    clipped = np.clip(a, a_min, a_max)
    return clipped

def real(self, a):
    """Element-wise real part (NumPy implementation of ``Backend.real``)."""
    out = np.real(a)
    return out

def imag(self, a):
    """Element-wise imaginary part (NumPy implementation of ``Backend.imag``)."""
    out = np.imag(a)
    return out

def conj(self, a):
    """Element-wise complex conjugate (NumPy implementation of ``Backend.conj``)."""
    out = np.conj(a)
    return out

def arccos(self, a):
    """Element-wise inverse cosine in radians (NumPy implementation)."""
    out = np.arccos(a)
    return out

def astype(self, a, dtype):
    """Cast ``a`` to ``dtype``; string names resolve to NumPy scalar types."""
    if isinstance(dtype, str):
        # Prefer the numpy attribute of the same name; otherwise let
        # np.dtype parse the string (raises TypeError on unknown names).
        resolved = getattr(np, dtype, None)
        if resolved is None:
            resolved = np.dtype(dtype)
        dtype = resolved
    return np.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
    """Repeat elements of ``a`` along ``axis`` (flattened when ``axis`` is None)."""
    out = np.repeat(a, repeats, axis)
    return out

Expand Down Expand Up @@ -1711,6 +1777,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
    """Limit the values of ``a`` to the interval ``[a_min, a_max]`` (JAX)."""
    clipped = jnp.clip(a, a_min, a_max)
    return clipped

def real(self, a):
    """Element-wise real part (JAX implementation of ``Backend.real``)."""
    out = jnp.real(a)
    return out

def imag(self, a):
    """Element-wise imaginary part (JAX implementation of ``Backend.imag``)."""
    out = jnp.imag(a)
    return out

def conj(self, a):
    """Element-wise complex conjugate (JAX implementation of ``Backend.conj``)."""
    out = jnp.conj(a)
    return out

def arccos(self, a):
    """Element-wise inverse cosine in radians (JAX implementation)."""
    out = jnp.arccos(a)
    return out

def astype(self, a, dtype):
    """Cast ``a`` to ``dtype``; string names resolve via ``jax.numpy`` attributes."""
    if isinstance(dtype, str):
        # Prefer the jnp attribute of the same name; otherwise fall back
        # to jnp.dtype parsing, mirroring the NumPy backend.
        resolved = getattr(jnp, dtype, None)
        if resolved is None:
            resolved = jnp.dtype(dtype)
        dtype = resolved
    return jnp.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
    """Repeat elements of ``a`` along ``axis`` (flattened when ``axis`` is None)."""
    out = jnp.repeat(a, repeats, axis)
    return out

Expand Down Expand Up @@ -2208,6 +2291,41 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
    """Limit the values of ``a`` to ``[a_min, a_max]`` (torch.clamp)."""
    clamped = torch.clamp(a, a_min, a_max)
    return clamped

def real(self, a):
    """Element-wise real part (Torch implementation of ``Backend.real``)."""
    out = torch.real(a)
    return out

def imag(self, a):
    """Element-wise imaginary part (Torch implementation of ``Backend.imag``)."""
    out = torch.imag(a)
    return out

def conj(self, a):
    """Element-wise complex conjugate (Torch implementation of ``Backend.conj``).

    Note: ``torch.conj`` returns a lazily-conjugated view (conjugate bit set),
    same as the original implementation.
    """
    out = torch.conj(a)
    return out

def arccos(self, a):
    """Element-wise inverse cosine in radians (Torch implementation)."""
    # torch.arccos is the documented alias of torch.acos.
    return torch.arccos(a)

def astype(self, a, dtype):
    """
    Cast tensor ``a`` to ``dtype``.

    Parameters
    ----------
    a : torch.Tensor
        Input tensor.
    dtype : str or torch.dtype
        Target dtype. Strings use numpy-style names ("float64",
        "complex128", ...) and are resolved to ``torch`` dtypes.

    Returns
    -------
    torch.Tensor
        ``a`` cast to the requested dtype.

    Raises
    ------
    ValueError
        If a string dtype cannot be resolved to a ``torch.dtype``.
    """
    if isinstance(dtype, str):
        # Map common numpy-style string dtypes to torch dtypes explicitly.
        # This makes backend.astype robust across torch versions and aliases.
        mapping = {
            "float32": torch.float32,
            "float64": torch.float64,
            "float": torch.float32,
            "double": torch.float64,
            "complex64": getattr(torch, "complex64", None),
            "complex128": getattr(torch, "complex128", None),
        }
        torch_dtype = mapping.get(dtype)
        if torch_dtype is None:
            # Fallback: try direct attribute lookup (e.g. torch.float16).
            # Only accept actual torch.dtype objects: getattr(torch, "abs")
            # would return a function and previously leaked through,
            # producing a cryptic TypeError from Tensor.to instead of the
            # intended ValueError below.
            candidate = getattr(torch, dtype, None)
            if isinstance(candidate, torch.dtype):
                torch_dtype = candidate
        if torch_dtype is None:
            raise ValueError(
                f"Unsupported dtype for TorchBackend.astype: {dtype!r}"
            )
        dtype = torch_dtype
    return a.to(dtype=dtype)

def repeat(self, a, repeats, axis=None):
    """Repeat tensor elements along ``axis`` (flattened when ``axis`` is None)."""
    out = torch.repeat_interleave(a, repeats, dim=axis)
    return out

Expand Down Expand Up @@ -2709,6 +2827,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
    """Limit the values of ``a`` to the interval ``[a_min, a_max]`` (CuPy)."""
    clipped = cp.clip(a, a_min, a_max)
    return clipped

def real(self, a):
    """Element-wise real part (CuPy implementation of ``Backend.real``)."""
    out = cp.real(a)
    return out

def imag(self, a):
    """Element-wise imaginary part (CuPy implementation of ``Backend.imag``)."""
    out = cp.imag(a)
    return out

def conj(self, a):
    """Element-wise complex conjugate (CuPy implementation of ``Backend.conj``)."""
    out = cp.conj(a)
    return out

def arccos(self, a):
    """Element-wise inverse cosine in radians (CuPy implementation)."""
    out = cp.arccos(a)
    return out

def astype(self, a, dtype):
    """Cast ``a`` to ``dtype``; string names resolve via CuPy attributes."""
    if isinstance(dtype, str):
        # Prefer the cp attribute of the same name; otherwise let
        # cp.dtype parse the string, mirroring the NumPy backend.
        resolved = getattr(cp, dtype, None)
        if resolved is None:
            resolved = cp.dtype(dtype)
        dtype = resolved
    return cp.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
    """Repeat elements of ``a`` along ``axis`` (flattened when ``axis`` is None)."""
    out = cp.repeat(a, repeats, axis)
    return out

Expand Down Expand Up @@ -3143,6 +3278,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
    """Limit the values of ``a`` to ``[a_min, a_max]`` (TensorFlow-NumPy)."""
    clipped = tnp.clip(a, a_min, a_max)
    return clipped

def real(self, a):
    """Element-wise real part (TensorFlow-NumPy implementation)."""
    out = tnp.real(a)
    return out

def imag(self, a):
    """Element-wise imaginary part (TensorFlow-NumPy implementation)."""
    out = tnp.imag(a)
    return out

def conj(self, a):
    """Element-wise complex conjugate (TensorFlow-NumPy implementation)."""
    out = tnp.conj(a)
    return out

def arccos(self, a):
    """Element-wise inverse cosine in radians (TensorFlow-NumPy implementation)."""
    out = tnp.arccos(a)
    return out

def astype(self, a, dtype):
    """Cast ``a`` to ``dtype``; string names resolve via ``tnp`` attributes."""
    if isinstance(dtype, str):
        # Prefer the tnp attribute of the same name; otherwise let
        # tnp.dtype parse the string, mirroring the NumPy backend.
        resolved = getattr(tnp, dtype, None)
        if resolved is None:
            resolved = tnp.dtype(dtype)
        dtype = resolved
    return tnp.array(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
    """Repeat elements of ``a`` along ``axis`` (flattened when ``axis`` is None)."""
    out = tnp.repeat(a, repeats, axis)
    return out

Expand Down
Loading
Loading