From 96f06cbd8785f677b65f0d53ddf283dbe29bb9da Mon Sep 17 00:00:00 2001 From: hmacdope Date: Mon, 20 Apr 2026 12:41:24 +1000 Subject: [PATCH 1/2] Add return_members option to CommitteeRegressor.predict() Adds return_members=False to _predict() and predict(). When True, the raw per-member predictions are returned as the last element of the tuple with shape (n_samples, n_tasks, n_members). Composes cleanly with return_std: callers can request any combination of mean, std, and member predictions. Closes #464 Co-Authored-By: Claude Sonnet 4.6 --- openadmet/models/active_learning/committee.py | 45 ++++++++++++------- .../active_learning/test_active_learning.py | 33 ++++++++++++++ 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/openadmet/models/active_learning/committee.py b/openadmet/models/active_learning/committee.py index 4f8bf7fd..96ecc6a9 100644 --- a/openadmet/models/active_learning/committee.py +++ b/openadmet/models/active_learning/committee.py @@ -353,7 +353,7 @@ def query(self, X, query_strategy: str = None, **kwargs): return _ACQUISITION_FUNCTIONS[query_strategy](mean, std, **kwargs) - def _predict(self, X, return_std=False, **kwargs): + def _predict(self, X, return_std=False, return_members=False, **kwargs): """ Make predictions using the committee model. @@ -363,35 +363,43 @@ def _predict(self, X, return_std=False, **kwargs): The input samples to predict. return_std : bool, optional Whether to return the standard deviation of the predictions. + return_members : bool, optional + Whether to return the raw per-member predictions of shape + (n_samples, n_tasks, n_members). When True, returned as the + last element of the tuple. **kwargs : dict Additional keyword arguments to pass to the committee's predict method. Returns ------- - array-like - Predicted values or probabilities, depending on the committee's implementation. + array-like or tuple + mean, or (mean, std), or (mean, members), or (mean, std, members) + depending on the values of return_std and return_members. """ - # Make predictions + # Make predictions: (n_samples, n_tasks, n_members) preds = np.stack([model.predict(X, **kwargs) for model in self.models], axis=-1) # Compute mean mean = np.mean(preds, axis=-1) - # Skip std if not requested - if return_std is False: + if not return_std and not return_members: return mean - # Compute standard deviation - std = np.std(preds, axis=-1) + result = (mean,) - # Calibrate std if calibration model is available - if self.calibrated: - std = self._get_calibration_function()(std) + if return_std: + std = np.std(preds, axis=-1) + if self.calibrated: + std = self._get_calibration_function()(std) + result += (std,) + + if return_members: + result += (preds,) - return mean, std + return result - def predict(self, X, return_std=False, **kwargs): + def predict(self, X, return_std=False, return_members=False, **kwargs): """ Make predictions using the committee model. @@ -401,13 +409,18 @@ def predict(self, X, return_std=False, **kwargs): The input samples to predict. return_std : bool, optional Whether to return the standard deviation of the predictions. + return_members : bool, optional + Whether to return the raw per-member predictions of shape + (n_samples, n_tasks, n_members). When True, returned as the + last element of the tuple. **kwargs : dict Additional keyword arguments to pass to the committee's predict method. Returns ------- - array-like - Predicted values or probabilities, depending on the committee's implementation. + array-like or tuple + mean, or (mean, std), or (mean, members), or (mean, std, members) + depending on the values of return_std and return_members. """ if return_std is True and not self.calibrated: @@ -415,7 +428,7 @@ def predict(self, X, return_std=False, **kwargs): "Standard deviation not calibrated: consider calling `calibrate_uncertainty`." ) - return self._predict(X, return_std=return_std, **kwargs) + return self._predict(X, return_std=return_std, return_members=return_members, **kwargs) def _save_calibration_model(self, path: PathLike = "calibration_model.pkl"): # Save calibration model diff --git a/openadmet/models/tests/unit/active_learning/test_active_learning.py b/openadmet/models/tests/unit/active_learning/test_active_learning.py index 362d20a6..90c759eb 100644 --- a/openadmet/models/tests/unit/active_learning/test_active_learning.py +++ b/openadmet/models/tests/unit/active_learning/test_active_learning.py @@ -155,6 +155,39 @@ def deserialize(self, param_path, serial_path): pass +def test_return_members(toy_data): + """Test that return_members exposes per-member predictions with correct shape.""" + X_train, _, X_test, y_train, _, _ = toy_data + n_members = 4 + n_tasks = 1 + + committee = CommitteeRegressor.train( + X_train, + y_train, + mod_class=MockCommitteeModel, + mod_params={}, + n_models=n_members, + use_bagging=False, + ) + + # return_members only + mean, members = committee.predict(X_test, return_members=True) + assert members.shape == (X_test.shape[0], n_tasks, n_members) + assert mean.shape == (X_test.shape[0], n_tasks) + assert_allclose(mean, np.mean(members, axis=-1)) + + # return_members + return_std + mean2, std, members2 = committee.predict(X_test, return_std=True, return_members=True) + assert members2.shape == (X_test.shape[0], n_tasks, n_members) + assert std.shape == (X_test.shape[0], n_tasks) + assert_allclose(mean2, np.mean(members2, axis=-1)) + assert_allclose(std, np.std(members2, axis=-1)) + + # Neither flag — plain mean returned (not a tuple) + result = committee.predict(X_test) + assert isinstance(result, np.ndarray) + + def test_committee_bagging_logic(toy_data): """Test that use_bagging flag correctly controls bootstrap aggregation.""" X_train, _, _, y_train, _, _ = toy_data From e0128d107a2576fd0d09d0e81911a6963a5565b0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 02:41:44 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openadmet/models/active_learning/committee.py | 4 +++- .../models/tests/unit/active_learning/test_active_learning.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/openadmet/models/active_learning/committee.py b/openadmet/models/active_learning/committee.py index 96ecc6a9..5fbc615f 100644 --- a/openadmet/models/active_learning/committee.py +++ b/openadmet/models/active_learning/committee.py @@ -428,7 +428,9 @@ def predict(self, X, return_std=False, return_members=False, **kwargs): "Standard deviation not calibrated: consider calling `calibrate_uncertainty`." ) - return self._predict(X, return_std=return_std, return_members=return_members, **kwargs) + return self._predict( + X, return_std=return_std, return_members=return_members, **kwargs + ) def _save_calibration_model(self, path: PathLike = "calibration_model.pkl"): # Save calibration model diff --git a/openadmet/models/tests/unit/active_learning/test_active_learning.py b/openadmet/models/tests/unit/active_learning/test_active_learning.py index 90c759eb..60c415b1 100644 --- a/openadmet/models/tests/unit/active_learning/test_active_learning.py +++ b/openadmet/models/tests/unit/active_learning/test_active_learning.py @@ -177,7 +177,9 @@ def test_return_members(toy_data): assert_allclose(mean, np.mean(members, axis=-1)) # return_members + return_std - mean2, std, members2 = committee.predict(X_test, return_std=True, return_members=True) + mean2, std, members2 = committee.predict( + X_test, return_std=True, return_members=True + ) assert members2.shape == (X_test.shape[0], n_tasks, n_members) assert std.shape == (X_test.shape[0], n_tasks) assert_allclose(mean2, np.mean(members2, axis=-1))