Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .coveragerc

This file was deleted.

8 changes: 4 additions & 4 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ on:

jobs:
run:
continue-on-error: True
continue-on-error: False
runs-on: ${{ matrix.os }}

strategy:
matrix:
os: [ubuntu-latest]
os: [ubuntu-24.04]
python-version: [3.12]

timeout-minutes: 25

steps:
- uses: actions/checkout@master
- uses: actions/setup-python@master
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
25 changes: 9 additions & 16 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@ on:

jobs:
run:
continue-on-error: True
continue-on-error: False
runs-on: ${{ matrix.os }}

strategy:
matrix:
include:
- {os: "macOS-14", python-version: "3.10"}
- {os: "ubuntu-22.04", python-version: "3.12"}
- {os: "ubuntu-22.04", python-version: "3.11"}
- {os: "ubuntu-22.04", python-version: "3.10"}
- {os: "ubuntu-22.04", python-version: "3.9"}
- {os: "ubuntu-24.04", python-version: "3.14"}
- {os: "ubuntu-24.04", python-version: "3.13"}
- {os: "ubuntu-24.04", python-version: "3.12"}
- {os: "ubuntu-24.04", python-version: "3.11"}
- {os: "ubuntu-24.04", python-version: "3.10"}
#os: [ubuntu-latest] # TODO: add macos-latest, windows-latest ?
#python-version: [3.8, 3.9, 3.10, 3.11]

timeout-minutes: 30

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -41,11 +41,4 @@ jobs:
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v1
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# file: ./coverage.xml
# flags: unittests
# name: codecov-umbrella
# fail_ci_if_error: true

9 changes: 6 additions & 3 deletions dadapy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""DADApy: a Python package for Distance-based Analysis of DAta-manifolds."""

import sys

from ._utils.utils import * # noqa: F401,F403
from .base import Base
from .clustering import Clustering
Expand Down Expand Up @@ -31,9 +29,14 @@
"NeighGraph",
]

if sys.version_info >= (3, 9):
try:
from .causal_graph import CausalGraph
from .diff_imbalance import DiffImbalance
from .hamming import BID, Hamming

__all__ += ["CausalGraph", "DiffImbalance", "BID", "Hamming"]
except (ImportError, RuntimeError):
# JAX-dependent classes unavailable (e.g., this exception is raised in
# joblib worker subprocesses where GPU context cannot be re-initialized
# from the parent process).
pass
44 changes: 24 additions & 20 deletions dadapy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
The *causal_graph* module contains the *CausalGraph* class, which inherits from the *DiffImbalance* class.

The code can be runned on gpu using the command
jax.config.update('jax_platform_name', 'gpu') # set 'cpu' or 'gpu'
jax.config.update('jax_platforms', 'gpu') # set 'cpu' or 'gpu'
"""

import itertools
Expand Down Expand Up @@ -451,9 +451,9 @@ def optimize_present_to_future( # noqa: C901
data_A=coords_present,
data_B=coords_future,
periods_A=self.periods,
periods_B=None
if self.periods is None
else self.periods[target_var],
periods_B=(
None if self.periods is None else self.periods[target_var]
),
seed=self.seed,
num_epochs=num_epochs,
batches_per_epoch=batches_per_epoch,
Expand Down Expand Up @@ -567,7 +567,7 @@ def _ancestors(self, adj_matrix):
"""
G = nx.DiGraph(adj_matrix)
auto_sets = []
for var in np.arange(adj_matrix.shape[0]):
for var in range(adj_matrix.shape[0]):
auto_sets.append(sorted(nx.ancestors(G, var) | {var}))
return auto_sets

Expand Down Expand Up @@ -1090,12 +1090,12 @@ def find_mediators(graph, node_start, node_end):

# initialize output variables
nvars = 1 + embedding_dim_present + embedding_dim_present
imbs_training[
community_name_cause, community_name_effect
] = np.zeros((len(time_lags), num_epochs + 1))
weights_final[
community_name_cause, community_name_effect
] = np.zeros((len(time_lags), nvars))
imbs_training[community_name_cause, community_name_effect] = (
np.zeros((len(time_lags), num_epochs + 1))
)
weights_final[community_name_cause, community_name_effect] = (
np.zeros((len(time_lags), nvars))
)
communities_ordered = np.concatenate(
(
[community_name_cause], # don't repeat (single slice)
Expand Down Expand Up @@ -1124,9 +1124,9 @@ def find_mediators(graph, node_start, node_end):
community_name_cause, community_name_effect
] = [communities_ordered, lags_ordered]
if compute_imb_final:
imbs_final[
community_name_cause, community_name_effect
] = np.zeros(len(time_lags))
imbs_final[community_name_cause, community_name_effect] = (
np.zeros(len(time_lags))
)
if compute_error:
errors_final[
community_name_cause, community_name_effect
Expand Down Expand Up @@ -1297,12 +1297,16 @@ def find_mediators(graph, node_start, node_end):
dii = DiffImbalance(
data_A=coords_A,
data_B=coords_future,
periods_A=None
if self.periods is None
else self.periods[variables_A],
periods_B=None
if self.periods is None
else self.periods[community_effect],
periods_A=(
None
if self.periods is None
else self.periods[variables_A]
),
periods_B=(
None
if self.periods is None
else self.periods[community_effect]
),
seed=self.seed,
num_epochs=num_epochs,
batches_per_epoch=batches_per_epoch,
Expand Down
6 changes: 3 additions & 3 deletions dadapy/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,9 @@ def _multimodality_test_v2(
current_saddle = saddle_density[i, 0]

if check == 1:
saddle_indices[
to_remove, -1
] = 0 # the couple center1, center2 is removed
saddle_indices[to_remove, -1] = (
0 # the couple center1, center2 is removed
)
margin1 = max_a1 / max_sum_err1
margin2 = max_a2 / max_sum_err2

Expand Down
22 changes: 19 additions & 3 deletions dadapy/diff_imbalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
The only method supposed to be called by the user is 'train', which carries out the automatic optimization ot the
Differential Information as a function of the weights of the features in the first distance space.
The code can be runned on gpu using the command
jax.config.update('jax_platform_name', 'gpu') # set 'cpu' or 'gpu'
jax.config.update('jax_platforms', 'gpu') # set 'cpu' or 'gpu'
"""

import warnings
Expand Down Expand Up @@ -79,6 +79,22 @@ def _compute_dist2_matrix_scaling(
return dist2_matrix


class _LeafArrayTrainState(train_state.TrainState):
# Variant of flax.training.train_state.TrainState that accepts a leaf
# jax.Array as params. flax>=0.9 added an OVERWRITE_WITH_GRADIENT membership
# check in apply_gradients that assumes a Mapping/PyTree and breaks on a
# leaf array.
def apply_gradients(self, *, grads, **kwargs):
updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)


# CLASS TO OPTIMIZE THE DIFFERENTIAL INFORMATION IMBALANCE
# ----------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -858,7 +874,7 @@ def _init_optimizer(self):
optimizer = opt_class(self.lr_schedule)

# Initialize training state
self.state = train_state.TrainState.create(
self.state = _LeafArrayTrainState.create(
apply_fn=self._distance_A,
params=self.params_init if self.state is None else self.state.params,
tx=optimizer,
Expand Down Expand Up @@ -1218,7 +1234,7 @@ def forward_greedy_feature_selection( # noqa: C901
selected_indices = valid_indices[best_valid_indices]

# Convert indices to lists for consistent processing
selected_features = [[idx] for idx in selected_indices]
selected_features = [[int(idx)] for idx in selected_indices]

# Add the best single feature to results
best_feature = selected_features[0]
Expand Down
Loading
Loading