Skip to content

Comments

[WIP] Spectral-Grassmann OT#792

Open
thibaut-germain wants to merge 11 commits intoPythonOT:masterfrom
thibaut-germain:sgot
Open

[WIP] Spectral-Grassmann OT#792
thibaut-germain wants to merge 11 commits intoPythonOT:masterfrom
thibaut-germain:sgot

Conversation

@thibaut-germain
Copy link

Types of changes

Adding sgot file in the ot folder.

Motivation and context / Related issue

Keep track of SGOT implementation in POT.

How has this been tested (if it applies)

Not tested yet.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • [] All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title Sgot [WIP] Spactral-Gromov OT Feb 9, 2026
@rflamary rflamary changed the title [WIP] Spactral-Gromov OT [WIP] Spectral-Grassman OT Feb 9, 2026
@rflamary rflamary changed the title [WIP] Spectral-Grassman OT [WIP] Spectral-Grassmann OT Feb 9, 2026
@codecov
Copy link

codecov bot commented Feb 11, 2026

Codecov Report

❌ Patch coverage is 7.17131% with 233 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.77%. Comparing base (e164e78) to head (3f10111).
⚠️ Report is 2 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #792      +/-   ##
==========================================
- Coverage   96.77%   95.77%   -1.00%     
==========================================
  Files         107      108       +1     
  Lines       22342    22621     +279     
==========================================
+ Hits        21622    21666      +44     
- Misses        720      955     +235     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @osheasienna and @thibaut-germain this is a nice first step.

Here are below a few comments that we can discuss together

return C


def metric(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def metric(
def sgot_metric(

return prod ** (q / 2)


def ot_plan(C, Ws=None, Wt=None, nx=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is not needed, this is two lines and the ormalization wrt ws and wt are not oK because it rcan retrun very weird things

### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ###
#####################################################################################################################################
#####################################################################################################################################
def cost(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def cost(
def sgot_cost_matrix(

imag_scale=1.0,
nx=None,
):
"""Compute the SGOT cost matrix between two spectral decompositions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recall here the equation with eta and define with math teh different acceptable metrics

raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}")
lam2 = Dt

lam1 = nx.astype(lam1, "complex128")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that necessary? seems overkill to add a function to the backend for that . When and why does it fails?

logits_s = rng.randn(r)
logits_t = rng.randn(r)

Ws = np.exp(logits_s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simpler and return only positive values

Suggested change
Ws = np.exp(logits_s)
Ws = rng.rand(r)

"""Create test_cost for each trial: sweep over HPs and run cost()."""
grassmann_types = ["geodesic", "chordal", "procrustes", "martin"]
n_trials = 10
for _ in range(n_trials):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for trials ;)

def test_hyperparameter_sweep():
grassmann_types = ["geodesic", "chordal", "procrustes", "martin"]

for _ in range(10):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

#### New features
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add feature here

## Upcomming 0.9.7.post1

#### New features
The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this text to the new feature of 0.9.7.dev0 this is what we are working on. Also add a line in the Itemize with the PR number

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments from talking together

if grassman_metric == "procrustes":
return 2.0 * (1.0 - delta)
if grassman_metric == "martin":
return -nx.log(nx.clip(delta**2, eps, 1e300))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove upper threshold

C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx)

C2 = eta * C_lambda + (1.0 - eta) * C_grass
C = C2 ** (p / 2.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
C = C2 ** (p / 2.0)
C = nx.real(C2) ** (p / 2.0)

q=1,
r=2,
grassman_metric="chordal",
real_scale=1.0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets call this eigen_scaling and set it to None by default

nx=None,
):
"""Compute the SGOT metric between two spectral decompositions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add equation that illustrate p q and r

import numpy as np
import pytest

from ot.backend import get_backend
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from ot.backend import get_backend
from ot.backend import get_backend, torch, jax

rng = np.random.RandomState(0)


def rand_complex(shape):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def rand_complex(shape):
def rand_complex(shape,rng):

return real + 1j * imag


def random_atoms(d=8, r=4):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def random_atoms(d=8, r=4):
def random_atoms(d=8, r=4,seed=42):



@pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"])
def test_cost_backend_consistency(backend_name):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_cost_backend_consistency(backend_name):
def test_cost_backend_consistency(nx):

# ---------------------------------------------------------------------


def test_hyperparameter_sweep_cost(nx):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_hyperparameter_sweep_cost(nx):
def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants