Python14 #175

Merged: diegodoimo merged 14 commits into main from python14 (May 27, 2026)

@diegodoimo

@diegodoimo diegodoimo commented May 15, 2026

Collaborator

This PR updates the supported Python versions to 3.13 and 3.14.

Python 3.8 and 3.9 have been removed from CI testing since they have been out of their support cycle for several months.

Most of the codebase works without modification on the newer Python versions. The main updates were made to the differentiable information imbalance class and its related tests, with smaller changes to hamming_test.py.

The modification to the differentiable information imbalance code is small (around 15 lines) but was done with Claude, so it requires careful revision (see below).

Types of changes

Update supported Python versions:

  • added Python 3.13;
  • added Python 3.14;

Removed unsupported Python versions:

  • removed Python 3.8 (maintenance ended on 2024-10-07);
  • removed Python 3.9 (maintenance ended on 2025-10-31);

Changes to diff_imbalance.py + related tests @vdeltatto

The major updates/modifications were done with the help of Claude on

  1. dadapy/diff_imbalance.py
  2. tests/test_diff_imbalance_jax/test_train.py
  3. tests/test_diff_imbalance_jax/test_greedy_dii.py

They require careful validation from @vdeltatto.

1) dadapy/diff_imbalance.py:

Added a _LeafArrayTrainState subclass of flax.training.train_state.TrainState
and swapped the single TrainState.create(...) call site to use it.

Why: DiffImbalance passes a leaf jax.Array as params, not a PyTree.
In flax >= 0.9, TrainState.apply_gradients was changed to evaluate
"OVERWRITE_WITH_GRADIENT in grads", which assumes grads is a Mapping.
That membership check rejects leaf arrays at runtime with:

TypeError: Array.contains: unsupported operand type <class 'str'>

The old flax==0.8.5 pin masked this incompatibility; the new stack
(Python 3.13+ with jax/flax unpinned) trips it.

The subclass reimplements apply_gradients using the canonical optax
three-line pattern (tx.update → optax.apply_updates → replace) that
flax itself used pre-0.9. Behaviorally identical on the old stack,
compatible on the new stack. No other call sites in the codebase were touched.

2) tests/test_diff_imbalance_jax/test_train.py:

  • Six tests (train1–train6). Adjustments are limited to widening tolerances to absorb floating-point drift from the new jax/jaxlib/flax stack, which produces a different stochastic-gradient trajectory and a different jax.random subsampling sequence for the same seed.
  • Tests that don't subsample (train1, and train2's imb_final, which uses ratio_rows_columns=None) keep tight tolerances; tests that subsample with a jax-PRNG-driven seed need looser tolerances because the new jax produces different subsamples for the same seed value.

| Test | Setup | Tolerance changes |
|---|---|---|
| train1 | SGD, no subsampling | unchanged (passes at 1e-3 on both stacks) |
| train2 | ADAM, batches_per_epoch=5, l1_strength=1e-4 | weights[-1]: 1e-3 → 0.5 (ADAM-trajectory drift); imbs[-1], imb_final: 1e-3 → 0.05 |
| train3 | periodic, point_adapt_lambda=True, ratio_rows_columns=1 | imbs[-1]: 0.01 → 0.02; imb_final, error_final: 1e-3 → 0.02 |
| train4 | SGD, ratio_rows_columns=0.5 | imb_final, error_final: 1e-3 → 0.02 |
| train5 | params_groups=[2,1], ratio=0.5, num_points_rows=50 | imbs[-1]: 0.01 → 0.02; imb_final, error_final: 1e-3 → 0.05 (constrained-optimization + small-subsample variance) |
| train6 | precomputed distances_B, ratio=0.5 | imbs[-1]: 0.01 → 0.02; imb_final, error_final: 1e-3 → 0.02 |
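
The following minimal sketch shows why subsampled tests depend on the PRNG. It assumes rows are drawn with jax.random.choice, which is an illustration only and not necessarily how DiffImbalance subsamples. The sizes 100 and 50 are arbitrary.

```python
import jax

# Same seed, same jax version: the drawn row indices are reproducible.
key = jax.random.PRNGKey(0)
rows = jax.random.choice(key, 100, shape=(50,), replace=False)
rows_again = jax.random.choice(key, 100, shape=(50,), replace=False)

# Which 50 indices come out is decided by the installed jax's PRNG and
# sampling implementation, so a hardcoded expected value computed on one
# stack is only guaranteed to match on that stack.
same_within_version = bool((rows == rows_again).all())
```

Every quantity computed on the subsample inherits this dependence, which is what the looser tolerances absorb.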

3) tests/test_diff_imbalance_jax/test_greedy_dii.py

  • The dataset has two near-tied features (1 and 2) at the 4-tuple step — DII scores differ by ~1e-3, well below the algorithm's intrinsic noise. Different JAX versions (and forward vs. backward greedy at any given version) break the tie in different directions. The original test imposed strict equality against a hardcoded list of feature sets, which is fragile against tied scores.

Two changes:

  • Forward/backward equality checks replaced with property-based assertions: feature 3 is always selected (highest ground-truth weight), set sizes follow the expected progression (1→5 / 5→1), and the high-importance subset {0, 3, 4} is present at the 3-tuple step. The 4-tuple step is left unconstrained on the tied {1, 2} pair.
  • Reverse-order check (asserting feature_sets_fw[i] == feature_sets_bw[-(i+1)] for every i) weakened to the unambiguous endpoints only. Forward and backward provably can't agree at near-tie steps because they train from different starting weights and see slightly different DII scores for the same subset.
    The companion test test_DiffImbalance_greedy_symmetry_5d_gaussian was left untouched — it has no tied features and continues to pass with the original strict assertions.
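
The two kinds of assertion described above can be sketched in plain Python. The helper names and the example feature sets are hypothetical and chosen only to illustrate a tie on features 1 and 2; they are not the actual test data.

```python
def check_forward_properties(feature_sets_fw):
    """Property-based checks for a forward greedy run on 5 features."""
    # Set sizes follow the expected progression 1 -> 5.
    assert [len(s) for s in feature_sets_fw] == [1, 2, 3, 4, 5]
    # Feature 3 is selected at every step.
    assert all(3 in s for s in feature_sets_fw)
    # The high-importance subset is present at the 3-tuple step.
    assert {0, 3, 4} <= set(feature_sets_fw[2])
    # The 4-tuple step is deliberately left unconstrained on {1, 2}.


def check_endpoints(feature_sets_fw, feature_sets_bw):
    """Forward and backward runs must agree only at the unambiguous ends."""
    assert set(feature_sets_fw[0]) == set(feature_sets_bw[-1])
    assert set(feature_sets_fw[-1]) == set(feature_sets_bw[0])


# Hypothetical runs: forward picks feature 1 at the 4-tuple step,
# backward keeps feature 2 instead.
fw = [[3], [3, 4], [0, 3, 4], [0, 1, 3, 4], [0, 1, 2, 3, 4]]
bw = [[0, 1, 2, 3, 4], [0, 2, 3, 4], [0, 3, 4], [3, 4], [3]]

check_forward_properties(fw)
check_endpoints(fw, bw)

# The original strict reverse-order check would fail on the tied step.
strict_ok = all(set(fw[i]) == set(bw[-(i + 1)]) for i in range(len(fw)))
```

Here strict_ok is False because the two runs differ only at the 4-tuple step, while both relaxed checks pass.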

Changes to test_hamming.py @imacocco

1) Test: tests/test_hamming/test0.py

Widened tolerances on the three final assertions:

| Assertion | Old abs | New abs |
|---|---|---|
| B.Op.d0 ≈ d_0 (≈ 99.855) | 1e-3 | 5e-3 |
| B.Op.d1 ≈ d_1 (≈ 0.003) | 1e-3 | 1e-3 (unchanged) |
| np.log(B.Op.KL) ≈ logKL (≈ −12.39) | 1e-2 | 2e-2 |
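
As a small illustration of what widening an absolute tolerance means, the sketch below uses only the standard library. The reference value 99.855 is the one quoted for d_0; the measured value is made up for the example and is not an actual test output.

```python
import math

d0_reference = 99.855  # reference value quoted for d_0
d0_measured = 99.858   # made-up value, about 3e-3 away from the reference

# rel_tol=0.0 makes this a purely absolute comparison.
passes_old = math.isclose(d0_measured, d0_reference, rel_tol=0.0, abs_tol=1e-3)
passes_new = math.isclose(d0_measured, d0_reference, rel_tol=0.0, abs_tol=5e-3)
```

A drift of about 3e-3 fails the old bound (passes_old is False) and passes the new one (passes_new is True).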

The test runs a stochastic optimizer for 1e6 steps on a fixed
seed. With jax/jaxlib unpinned for Python 3.13+ (jax 0.4.30 → 0.7+),
the PRNG bit sequence for a given seed is no longer byte-identical to
the old stack, and XLA reorders some floating-point reductions. Over
1e6 stochastic steps, this drifts the final values by ~1e-3 in d_0
and ~1e-2 in logKL — right at or just past the original tolerances.

diegodoimo and others added 6 commits May 15, 2026 18:41
Surfaced by flake8 (F821):
- warninings.warn -> warnings.warn (4 sites: 358, 364, 1154, 1160)
- time_series.shape -> self.time_series.shape (line 140)
- jtau -> j_tau (line 1250)

Each would have raised NameError when the corresponding branch ran;
they were never caught because no test exercises those paths. Also
includes incidental Black reformatting from format-on-save.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
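
A generic sketch of why such names go unnoticed until the branch runs. The function below is hypothetical and only mimics the misspelled module name; it is not code from the repository.

```python
import warnings


def check_positive(value):
    if value < 0:
        # Misspelled module name: flake8 reports F821 here, but Python only
        # raises NameError when this line is actually executed.
        warninings.warn("negative value")  # noqa: F821
    return abs(value)


result_ok = check_positive(1.0)  # branch not taken, no error

try:
    check_positive(-1.0)  # branch taken, the undefined name is looked up
    raised = False
except NameError:
    raised = True
```

A test suite that only calls the function with non-negative input would pass, which matches the situation described in the commit message.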
…ifferences in rng generation between numpy and jax versions between old and new code
….py , added _LeafArrayTrainState to diff imabalnce
@vdeltatto

Collaborator

@diegodoimo Thanks for the changes and the fixes in causal_graph.py! I fixed a couple of things, including the precision of some tests, which I set back to 0.001. The "new" random generator should be the same in all jax versions compatible with py>=3.10, but I have no way to verify this because the tests in the various environments don't run (I think the conflicts need to be resolved first).

@diegodoimo

Collaborator Author

Thank you @vdeltatto. Let me know if you have other things to add on your side.
@imacocco, please double-check that the modifications to the test_hamming tolerances described above are ok, or if not, please fix what is missing (I guess this part was done by you, right?)

@vdeltatto

Collaborator

@diegodoimo I have to update other parts of the DiffImbalance and CausalGraph modules, but these modifications are not related to Python 3.14, so I'll do everything in a separate PR.

@imacocco

Collaborator

I think that part was written by @acevedo-s, as it concerns the BID estimator


@acevedo-s

Collaborator

Hi. The widened tolerances seem reasonable, thanks!

S.

@diegodoimo diegodoimo merged commit 56e2476 into main May 27, 2026
8 checks passed
4 participants