Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions openadmet/models/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,16 @@ def load_anvil_model_and_metadata(model_dir):


def _generate_pairwise_df(
data, input_col, feat, predictions, predictions_tag, std_tag
data, input_col, feat, predictions, predictions_tag, std_tag, task_idx=0
) -> pd.DataFrame:
"""Generate a DataFrame for pairwise predictions."""
smiles = data[input_col].values
pairwise_dataset = PairwiseAugmentedDataset(smiles, None, how=feat.how_to_pair)
pairs = pairwise_dataset.idxs # list of (i, j) tuples

smiles_i = [smiles[i] for i, j in pairs]
smiles_j = [smiles[j] for i, j in pairs]
pred = predictions[:, j]
smiles_i = [smiles[ii] for ii, jj in pairs]
smiles_j = [smiles[jj] for ii, jj in pairs]
pred = predictions[:, task_idx]

pairwise_df = pd.DataFrame(
{
Expand All @@ -112,7 +112,7 @@ def _generate_pairwise_df(
}
)

pairwise_df[std_tag] = pd.Series(predictions[:, j], index=pairwise_df.index)
pairwise_df[std_tag] = pd.Series(predictions[:, task_idx], index=pairwise_df.index)

pairwise_df[input_col] = (
pairwise_df[f"{input_col}_i"] + " - " + pairwise_df[f"{input_col}_j"]
Expand Down Expand Up @@ -268,7 +268,13 @@ def predict(
"Detected pairwise featurizer, generating pairwise output DataFrame"
)
data = _generate_pairwise_df(
data, input_col, feat, predictions, predictions_tag, std_tag
data,
input_col,
feat,
predictions,
predictions_tag,
std_tag,
task_idx=j,
)

else:
Expand Down
122 changes: 119 additions & 3 deletions openadmet/models/tests/unit/inference/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pathlib import Path
import pandas as pd
import os
from types import SimpleNamespace

import numpy as np
import pandas as pd
import pytest

from openadmet.models.inference.inference import predict
import openadmet.models.inference.inference as inference_module
from openadmet.models.tests.unit.datafiles import (
pred_test_data_csv,
anvil_lgbm_trained_model_dir,
Expand Down Expand Up @@ -36,7 +39,7 @@ def test_predict(model_dir, request):
output_path = None
debug = False

result = predict(
result = inference_module.predict(
input_path,
input_col,
model_dir,
Expand All @@ -48,3 +51,116 @@ def test_predict(model_dir, request):

# Check if the result is a DataFrame
assert isinstance(result, pd.DataFrame)


def test_generate_pairwise_df_uses_task_idx_column():
data = pd.DataFrame({"SMILES": ["CCO", "CCN"]})
predictions = np.array([[1.0, 11.0], [2.0, 12.0], [3.0, 13.0]])
feat = SimpleNamespace(how_to_pair="ut")

pairwise_df = inference_module._generate_pairwise_df(
data=data,
input_col="SMILES",
feat=feat,
predictions=predictions,
predictions_tag="pred",
std_tag="std",
task_idx=1,
)

assert pairwise_df["pred"].tolist() == [11.0, 12.0, 13.0]
assert pairwise_df["std"].tolist() == [11.0, 12.0, 13.0]


def test_predict_single_task_pairwise_uses_column_zero(monkeypatch):
class DummyPairwiseFeaturizer:
def __init__(self, how_to_pair="ut"):
self.how_to_pair = how_to_pair

def featurize(self, smiles):
return np.zeros((len(smiles), 1)), np.arange(len(smiles))

class DummyModel:
estimator = "dummy"

def predict(self, X_feat, accelerator="cpu"):
return np.array([[7.0], [8.0], [9.0]])

def fake_loader(_):
return (
DummyModel(),
DummyPairwiseFeaturizer(),
SimpleNamespace(tag="PAIR"),
SimpleNamespace(target_cols=["task0"]),
)

monkeypatch.setattr(inference_module, "PairwiseFeaturizer", DummyPairwiseFeaturizer)
monkeypatch.setattr(inference_module, "load_anvil_model_and_metadata", fake_loader)

input_df = pd.DataFrame({"SMILES": ["CCO", "CCN"]})
result = inference_module.predict(
input_path=input_df,
input_col="SMILES",
model_dir="dummy_model",
write_csv=False,
output_csv=None,
debug=False,
accelerator="cpu",
log=False,
)

pred_col = "OADMET_PRED_PAIR_task0"
assert pred_col in result.columns
assert result[pred_col].tolist() == [7.0, 8.0, 9.0]


def test_predict_pairwise_multitask_passes_task_idx(monkeypatch):
class DummyPairwiseFeaturizer:
def __init__(self, how_to_pair="ut"):
self.how_to_pair = how_to_pair

def featurize(self, smiles):
return np.zeros((len(smiles), 1)), np.arange(len(smiles))

class DummyModel:
estimator = "dummy"

def predict(self, X_feat, accelerator="cpu"):
return np.array([[1.0, 2.0, 3.0]])

def fake_loader(_):
return (
DummyModel(),
DummyPairwiseFeaturizer(),
SimpleNamespace(tag="PAIR"),
SimpleNamespace(target_cols=["task0", "task1", "task2"]),
)

task_idxs = []

def fake_generate_pairwise_df(
data, input_col, feat, predictions, predictions_tag, std_tag, task_idx=0
):
task_idxs.append(task_idx)
data[predictions_tag] = task_idx
data[std_tag] = task_idx
return data

monkeypatch.setattr(inference_module, "PairwiseFeaturizer", DummyPairwiseFeaturizer)
monkeypatch.setattr(inference_module, "load_anvil_model_and_metadata", fake_loader)
monkeypatch.setattr(
inference_module, "_generate_pairwise_df", fake_generate_pairwise_df
)

inference_module.predict(
input_path=pd.DataFrame({"SMILES": ["CCO"]}),
input_col="SMILES",
model_dir="dummy_model",
write_csv=False,
output_csv=None,
debug=False,
accelerator="cpu",
log=False,
)

assert task_idxs == [0, 1, 2]