diff --git a/openadmet/models/inference/inference.py b/openadmet/models/inference/inference.py index 797e9c7e..1e575685 100644 --- a/openadmet/models/inference/inference.py +++ b/openadmet/models/inference/inference.py @@ -93,16 +93,16 @@ def load_anvil_model_and_metadata(model_dir): def _generate_pairwise_df( - data, input_col, feat, predictions, predictions_tag, std_tag + data, input_col, feat, predictions, predictions_tag, std_tag, task_idx=0 ) -> pd.DataFrame: """Generate a DataFrame for pairwise predictions.""" smiles = data[input_col].values pairwise_dataset = PairwiseAugmentedDataset(smiles, None, how=feat.how_to_pair) pairs = pairwise_dataset.idxs # list of (i, j) tuples - smiles_i = [smiles[i] for i, j in pairs] - smiles_j = [smiles[j] for i, j in pairs] - pred = predictions[:, j] + smiles_i = [smiles[ii] for ii, jj in pairs] + smiles_j = [smiles[jj] for ii, jj in pairs] + pred = predictions[:, task_idx] pairwise_df = pd.DataFrame( { @@ -112,7 +112,7 @@ def _generate_pairwise_df( } ) - pairwise_df[std_tag] = pd.Series(predictions[:, j], index=pairwise_df.index) + pairwise_df[std_tag] = pd.Series(predictions[:, task_idx], index=pairwise_df.index) pairwise_df[input_col] = ( pairwise_df[f"{input_col}_i"] + " - " + pairwise_df[f"{input_col}_j"] @@ -268,7 +268,13 @@ def predict( "Detected pairwise featurizer, generating pairwise output DataFrame" ) data = _generate_pairwise_df( - data, input_col, feat, predictions, predictions_tag, std_tag + data, + input_col, + feat, + predictions, + predictions_tag, + std_tag, + task_idx=j, ) else: diff --git a/openadmet/models/tests/unit/inference/test_inference.py b/openadmet/models/tests/unit/inference/test_inference.py index df341b21..f64d0819 100644 --- a/openadmet/models/tests/unit/inference/test_inference.py +++ b/openadmet/models/tests/unit/inference/test_inference.py @@ -1,9 +1,12 @@ from pathlib import Path -import pandas as pd import os +from types import SimpleNamespace + +import numpy as np +import pandas as pd import pytest -from openadmet.models.inference.inference import predict +import openadmet.models.inference.inference as inference_module from openadmet.models.tests.unit.datafiles import ( pred_test_data_csv, anvil_lgbm_trained_model_dir, @@ -36,7 +39,7 @@ def test_predict(model_dir, request): output_path = None debug = False - result = predict( + result = inference_module.predict( input_path, input_col, model_dir, @@ -48,3 +51,116 @@ def test_predict(model_dir, request): # Check if the result is a DataFrame assert isinstance(result, pd.DataFrame) + + +def test_generate_pairwise_df_uses_task_idx_column(): + data = pd.DataFrame({"SMILES": ["CCO", "CCN"]}) + predictions = np.array([[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]]) + feat = SimpleNamespace(how_to_pair="ut") + + pairwise_df = inference_module._generate_pairwise_df( + data=data, + input_col="SMILES", + feat=feat, + predictions=predictions, + predictions_tag="pred", + std_tag="std", + task_idx=1, + ) + + assert pairwise_df["pred"].tolist() == [11.0, 12.0, 13.0] + assert pairwise_df["std"].tolist() == [11.0, 12.0, 13.0] + + +def test_predict_single_task_pairwise_uses_column_zero(monkeypatch): + class DummyPairwiseFeaturizer: + def __init__(self, how_to_pair="ut"): + self.how_to_pair = how_to_pair + + def featurize(self, smiles): + return np.zeros((len(smiles), 1)), np.arange(len(smiles)) + + class DummyModel: + estimator = "dummy" + + def predict(self, X_feat, accelerator="cpu"): + return np.array([[7.0], [8.0], [9.0]]) + + def fake_loader(_): + return ( + DummyModel(), + DummyPairwiseFeaturizer(), + SimpleNamespace(tag="PAIR"), + SimpleNamespace(target_cols=["task0"]), + ) + + monkeypatch.setattr(inference_module, "PairwiseFeaturizer", DummyPairwiseFeaturizer) + monkeypatch.setattr(inference_module, "load_anvil_model_and_metadata", fake_loader) + + input_df = pd.DataFrame({"SMILES": ["CCO", "CCN"]}) + result = inference_module.predict( + input_path=input_df, + input_col="SMILES", + model_dir="dummy_model", + write_csv=False, + output_csv=None, + debug=False, + accelerator="cpu", + log=False, + ) + + pred_col = "OADMET_PRED_PAIR_task0" + assert pred_col in result.columns + assert result[pred_col].tolist() == [7.0, 8.0, 9.0] + + +def test_predict_pairwise_multitask_passes_task_idx(monkeypatch): + class DummyPairwiseFeaturizer: + def __init__(self, how_to_pair="ut"): + self.how_to_pair = how_to_pair + + def featurize(self, smiles): + return np.zeros((len(smiles), 1)), np.arange(len(smiles)) + + class DummyModel: + estimator = "dummy" + + def predict(self, X_feat, accelerator="cpu"): + return np.array([[1.0, 2.0, 3.0]]) + + def fake_loader(_): + return ( + DummyModel(), + DummyPairwiseFeaturizer(), + SimpleNamespace(tag="PAIR"), + SimpleNamespace(target_cols=["task0", "task1", "task2"]), + ) + + task_idxs = [] + + def fake_generate_pairwise_df( + data, input_col, feat, predictions, predictions_tag, std_tag, task_idx=0 + ): + task_idxs.append(task_idx) + data[predictions_tag] = task_idx + data[std_tag] = task_idx + return data + + monkeypatch.setattr(inference_module, "PairwiseFeaturizer", DummyPairwiseFeaturizer) + monkeypatch.setattr(inference_module, "load_anvil_model_and_metadata", fake_loader) + monkeypatch.setattr( + inference_module, "_generate_pairwise_df", fake_generate_pairwise_df + ) + + inference_module.predict( + input_path=pd.DataFrame({"SMILES": ["CCO"]}), + input_col="SMILES", + model_dir="dummy_model", + write_csv=False, + output_csv=None, + debug=False, + accelerator="cpu", + log=False, + ) + + assert task_idxs == [0, 1, 2]