diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index 7d0b8b7f..9e6ea7f7 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -465,13 +465,21 @@ def function_fitness( else: n_splits = num_nested_folds train_val_datas = [] + sample_weights = [] for train, val in splits: train_val_datas.append(self.train_data.split((train, val))) + if "sample_weight" in pop[0].fit_params: + sample_weights.append(pop[0].fit_params["sample_weight"][train]) + else: + sample_weights.append(None) tasks = [] for i, individual in enumerate(pop): for j in range(n_splits): train_data, val_data = train_val_datas[j] + sample_weight = sample_weights[j] + if sample_weight is not None: + individual.fit_params["sample_weight"] = sample_weight tasks += [ { "individual": individual, @@ -653,6 +661,10 @@ def run( self.best_model = EnsembleMODNetModel(models=ensemble) """ + if "sample_weight" in fit_params: + self.best_individual.fit_params["sample_weight"] = fit_params[ + "sample_weight" + ] self.best_model = self.best_individual.refit_model( self.data, n_models=refit, n_jobs=n_jobs or 1, fast=fast )