Python14#175
Conversation
Surfaced by flake8 (F821): - warninings.warn -> warnings.warn (4 sites: 358, 364, 1154, 1160) - time_series.shape -> self.time_series.shape (line 140) - jtau -> j_tau (line 1250) Each would have raised NameError when the corresponding branch ran; they were never caught because no test exercises those paths. Also includes incidental Black reformatting from format-on-save. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ifferences in rng generation between numpy and jax versions between old and new code
….py , added _LeafArrayTrainState to diff imabalnce
|
@diegodoimo Grazie per le modifiche e i fix in causal_graph.py! Ho sistemato due cose tra cui la precisione di alcuni test, rimettendola a 0.001. Il random generator "nuovo" dovrebbe essere lo stesso in tutte le versioni di jax compatibili con py>=3.10, non ho modo di verificarlo però perché i test nei vari environment non runnano (credo che vadano risolti i conflitti prima) |
|
Thank you @vdeltatto. Let me know if you have other things to add on your side. |
|
@diegodoimo I have to update other parts in the |
|
I think that part was written by @acevedo-s, as it concerns the BID estimator
|
|
Hi. The widened tolerances seem reasonable, thanks! S. |
This PR updates the supported Python versions to 3.13 and 3.14.
Python 3.8 and 3.9 have been removed from CI testing since they have been out of their support cycle for several months.
Most of the codebase works without modification on the newer Python versions. The main updates were made to the differentiable information imbalance class and its related tests, with smaller changes to hamming_test.py.
The modification to the differentiable information imbalance code is small (around 15 lines) but was done with Claude, so it requires careful revision (see below).
Types of changes
Update supported Python versions:
Removed unsupported unsupported Python versions:
Changes to diff_imbalance.py + related tests @vdeltatto
The major updates/modifications were done with the help of Claude on
dadapy/diff_imbalance.pytests/test_diff_imbalance_jax/test_train.pytests/test_diff_imbalance_jax/test_greedy_dii.pyThey require careful validation from @vdeltatto.
1)
dadapy/diff_imbalance.py:Added a
_LeafArrayTrainStatesubclass offlax.training.train_state.TrainStateand swapped the single
TrainState.create(...)call site to use it.Why:
DiffImbalancepasses a leafjax.Arrayasparams, not a PyTree.In
flax >= 0.9,TrainState.apply_gradientswas changed to doif OVERWRITE_WITH_GRADIENT in grads, which assumesgradsis a Mapping.That
incheck rejects leaf arrays at runtime with:TypeError: Array.contains: unsupported operand type <class 'str'>
The old
flax==0.8.5pin masked this incompatibility; the new stack(Python 3.13+ with
jax/flaxunpinned) trips it.The subclass reimplements
apply_gradientsusing the canonical optaxthree-line pattern (
tx.update→optax.apply_updates→replace) thatflax itself used pre-0.9. Behaviorally identical on the old stack,
compatible on the new stack. No other call sites in the codebase were touched.
2)
tests/test_diff_imbalance_jax/test_train.py:train11e-3on both stackstrain2batches_per_epoch=5,l1_strength=1e-4weights[-1]: 1e-3 → 0.5 (ADAM-trajectory drift);imbs[-1],imb_final: 1e-3 → 0.05train3point_adapt_lambda=True,ratio_rows_columns=1imbs[-1]: 0.01 → 0.02;imb_final,error_final: 1e-3 → 0.02train4ratio_rows_columns=0.5imb_final,error_final: 1e-3 → 0.02train5params_groups=[2,1],ratio=0.5,num_points_rows=50imbs[-1]: 0.01 → 0.02;imb_final,error_final: 1e-3 → 0.05 (constrained-optimization + small-subsample variance)train6distances_B,ratio=0.5imbs[-1]: 0.01 → 0.02;imb_final,error_final: 1e-3 → 0.023) tests/test_diff_imbalance_jax/test_greedy_dii.py
Two changes:
The companion test test_DiffImbalance_greedy_symmetry_5d_gaussian was left untouched — it has no tied features and continues to pass with the original strict assertions.
Changes to test_hamming.py @imacocco
1) Test:
tests/test_hamming/test0.pyWidened tolerances on the three final assertions:
absabsB.Op.d0 ≈ d_0(≈ 99.855)1e-35e-3B.Op.d1 ≈ d_1(≈ 0.003)1e-31e-3(unchanged)np.log(B.Op.KL) ≈ logKL(≈ −12.39)1e-22e-2The test runs a stochastic optimizer for 1e6 steps on a fixed
seed. With
jax/jaxlibunpinned for Python 3.13+ (jax 0.4.30 → 0.7+),the PRNG bit sequence for a given seed is no longer byte-identical to
the old stack, and XLA reorders some floating-point reductions. Over
1e6 stochastic steps, this drifts the final values by ~1e-3 in
d_0and ~1e-2 in
logKL— right at or just past the original tolerances.