From 128197701adfe6407551a5af9441fc820908449d Mon Sep 17 00:00:00 2001 From: gbrunin Date: Wed, 30 Oct 2024 11:38:03 +0100 Subject: [PATCH 1/3] Slightly relaxed tensorflow version. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 28781da1..262f0fe7 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ packages=setuptools.find_packages(), install_requires=[ "pandas~=1.5", - "tensorflow~=2.10,<2.12", + "tensorflow~=2.10,<=2.12", "pymatgen>=2023", "matminer~=0.9", "numpy>=1.24", From 5bf7bc33e62ace9d976053f07a73fbac6904b123 Mon Sep 17 00:00:00 2001 From: gbrunin Date: Fri, 8 Aug 2025 12:57:34 +0200 Subject: [PATCH 2/3] Added sample_weight treatment in fit genetic. --- modnet/hyper_opt/fit_genetic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index 7d0b8b7f..d64d1708 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -465,13 +465,21 @@ def function_fitness( else: n_splits = num_nested_folds train_val_datas = [] + sample_weights = [] for train, val in splits: train_val_datas.append(self.train_data.split((train, val))) + if "sample_weight" in pop[0].fit_params: + sample_weights.append(pop[0].fit_params["sample_weight"][train]) + else: + sample_weights.append(None) tasks = [] for i, individual in enumerate(pop): for j in range(n_splits): train_data, val_data = train_val_datas[j] + sample_weight = sample_weights[j] + if sample_weight is not None: + individual.fit_params["sample_weight"] = sample_weight tasks += [ { "individual": individual, From 1dd188891f7008e903802b47bab1c80b4aa28038 Mon Sep 17 00:00:00 2001 From: gbrunin Date: Fri, 8 Aug 2025 15:55:51 +0200 Subject: [PATCH 3/3] This fixes a bug that was triggered when early stopping was achieved, not sure why it was not triggered otherwise. --- modnet/hyper_opt/fit_genetic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index d64d1708..9e6ea7f7 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -661,6 +661,10 @@ def run( self.best_model = EnsembleMODNetModel(models=ensemble) """ + if "sample_weight" in fit_params: + self.best_individual.fit_params["sample_weight"] = fit_params[ + "sample_weight" + ] self.best_model = self.best_individual.refit_model( self.data, n_models=refit, n_jobs=n_jobs or 1, fast=fast )