Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions openadmet/models/active_learning/committee.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def query(self, X, query_strategy: str = None, **kwargs):

return _ACQUISITION_FUNCTIONS[query_strategy](mean, std, **kwargs)

def _predict(self, X, return_std=False, return_members=False, **kwargs):
    """
    Make predictions using the committee model.

    Every model in ``self.models`` predicts on ``X``; the individual outputs
    are stacked along a new trailing axis and reduced over that axis.

    Parameters
    ----------
    X
        The input samples to predict. Forwarded unchanged to each member's
        ``predict`` method.
    return_std : bool, optional
        Whether to return the standard deviation of the member predictions.
        When ``self.calibrated`` is truthy, the standard deviation is passed
        through the callable returned by ``self._get_calibration_function()``.
    return_members : bool, optional
        Whether to return the raw per-member predictions. Members sit on the
        last axis, i.e. shape (n_samples, n_tasks, n_members) when each member
        returns (n_samples, n_tasks). When True, returned as the last element
        of the tuple.
    **kwargs : dict
        Additional keyword arguments passed to each member's predict method.

    Returns
    -------
    array-like or tuple
        mean, or (mean, std), or (mean, members), or (mean, std, members)
        depending on the values of return_std and return_members.
        The bare mean (not a 1-tuple) is returned when both flags are falsy.

    """
    # Make predictions: members are stacked on a new last axis
    preds = np.stack([model.predict(X, **kwargs) for model in self.models], axis=-1)

    # Committee prediction is the mean over members
    mean = np.mean(preds, axis=-1)

    # Nothing extra requested: return the bare array so existing callers
    # that expect a single array keep working
    if not return_std and not return_members:
        return mean

    result = (mean,)

    if return_std:
        # Spread across members (population std, NumPy's default ddof=0)
        std = np.std(preds, axis=-1)
        # Calibrate std if a calibration model is available
        if self.calibrated:
            std = self._get_calibration_function()(std)
        result += (std,)

    if return_members:
        # Raw, uncalibrated per-member predictions always come last
        result += (preds,)

    return result

def predict(self, X, return_std=False, return_members=False, **kwargs):
    """
    Make predictions using the committee model.

    Public wrapper around ``_predict`` that additionally logs a warning when
    a standard deviation is requested but ``self.calibrated`` is falsy.

    Parameters
    ----------
    X
        The input samples to predict.
    return_std : bool, optional
        Whether to return the standard deviation of the predictions.
    return_members : bool, optional
        Whether to return the raw per-member predictions of shape
        (n_samples, n_tasks, n_members). When True, returned as the
        last element of the tuple.
    **kwargs : dict
        Additional keyword arguments to pass to the committee's predict method.

    Returns
    -------
    array-like or tuple
        mean, or (mean, std), or (mean, members), or (mean, std, members)
        depending on the values of return_std and return_members.

    """
    # Truthiness, not ``is True``: ``_predict`` computes the std for any
    # truthy flag (e.g. 1 or a NumPy bool), so the warning must fire for
    # exactly the same inputs.
    if return_std and not self.calibrated:
        logger.warning(
            "Standard deviation not calibrated: consider calling `calibrate_uncertainty`."
        )

    return self._predict(
        X, return_std=return_std, return_members=return_members, **kwargs
    )

def _save_calibration_model(self, path: PathLike = "calibration_model.pkl"):
# Save calibration model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,41 @@ def deserialize(self, param_path, serial_path):
pass


def test_return_members(toy_data):
    """Check that ``return_members`` surfaces per-member predictions.

    Exercises three call shapes of ``predict`` on a committee trained
    without bagging: members only, members together with the standard
    deviation, and neither flag.
    """
    X_train, _, X_test, y_train, _, _ = toy_data
    n_members, n_tasks = 4, 1
    per_sample_shape = (X_test.shape[0], n_tasks)
    per_member_shape = (X_test.shape[0], n_tasks, n_members)

    committee = CommitteeRegressor.train(
        X_train,
        y_train,
        mod_class=MockCommitteeModel,
        mod_params={},
        n_models=n_members,
        use_bagging=False,
    )

    # Members without std: the member stack is the last tuple element
    mean_only, stacked = committee.predict(X_test, return_members=True)
    assert stacked.shape == per_member_shape
    assert mean_only.shape == per_sample_shape
    assert_allclose(mean_only, np.mean(stacked, axis=-1))

    # Members with std: (mean, std, members).
    # NOTE(review): the std comparison presumes a freshly trained committee
    # is uncalibrated, so std is the raw spread — confirm against `train`.
    mean_full, spread, stacked_full = committee.predict(
        X_test, return_std=True, return_members=True
    )
    assert stacked_full.shape == per_member_shape
    assert spread.shape == per_sample_shape
    assert_allclose(mean_full, np.mean(stacked_full, axis=-1))
    assert_allclose(spread, np.std(stacked_full, axis=-1))

    # Neither flag: a plain array comes back, not a tuple
    plain = committee.predict(X_test)
    assert isinstance(plain, np.ndarray)


def test_committee_bagging_logic(toy_data):
"""Test that use_bagging flag correctly controls bootstrap aggregation."""
X_train, _, _, y_train, _, _ = toy_data
Expand Down
Loading