Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions pygod/metric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Author: Yingtong Dou <ytongdou@gmail.com>, Kay Liu <zliu234@uic.edu>
# License: BSD 2 clause

import warnings

from sklearn.metrics import (
roc_auc_score,
average_precision_score,
Expand All @@ -29,6 +31,14 @@ def eval_roc_auc(label, score):
roc_auc : float
Average ROC-AUC score across different labels.
"""
unique_labels = label.unique()
if len(unique_labels) < 2:
warnings.warn(
"Only one class is present in y_true. ROC AUC score is "
"not defined in that case. Returning 0.0.",
UserWarning
)
return 0.0

roc_auc = roc_auc_score(y_true=label, y_score=score)
return roc_auc
Expand All @@ -54,10 +64,29 @@ def eval_recall_at_k(label, score, k=None):
recall_at_k : float
Recall for top k instances with the highest outlier scores.
"""
num_labels = int(sum(label))

if num_labels == 0:
warnings.warn(
"No positive labels found in y_true. Returning 0.0.",
UserWarning
)
return 0.0

if k is None:
k = sum(label)
recall_at_k = sum(label[score.topk(k).indices]) / sum(label)
k = num_labels

# Validate k parameter
if not isinstance(k, int):
raise TypeError(f"k must be an integer, got {type(k).__name__}")

if k < 0:
raise ValueError(f"k must be non-negative, got {k}")

# Clamp k to valid range
k = min(k, len(label))

recall_at_k = sum(label[score.topk(k).indices]) / num_labels
return recall_at_k


Expand All @@ -81,9 +110,26 @@ def eval_precision_at_k(label, score, k=None):
precision_at_k : float
Precision for top k instances with the highest outlier scores.
"""

if k is None:
k = sum(label)
k = int(sum(label))

# Validate k parameter
if not isinstance(k, int):
raise TypeError(f"k must be an integer, got {type(k).__name__}")

if k < 0:
raise ValueError(f"k must be non-negative, got {k}")

if k == 0:
warnings.warn(
"k is 0, which results in division by zero. Returning 0.0.",
UserWarning
)
return 0.0

# Clamp k to valid range
k = min(k, len(label))

precision_at_k = sum(label[score.topk(k).indices]) / k
return precision_at_k

Expand Down
91 changes: 91 additions & 0 deletions pygod/test/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import unittest
import warnings
from numpy.testing import assert_allclose

import torch
Expand Down Expand Up @@ -46,3 +47,93 @@ def test_eval_average_precision(self):
def test_eval_f1(self):
    """eval_f1 should agree with sklearn's f1_score on the same inputs.

    NOTE(review): self.y and self.pred are fixtures defined in this class's
    setUp, which is outside the visible diff — presumably a label tensor and
    a binary prediction tensor; confirm against the full test file.
    """
    assert_allclose(f1_score(self.y, self.pred),
                    eval_f1(self.y, self.pred))


class TestMetricEdgeCases(unittest.TestCase):
    """Edge-case and boundary-condition coverage for the metric functions."""

    def setUp(self):
        # Degenerate (single-class) and mixed-class label fixtures.
        self.y_all_zeros = torch.tensor([0, 0, 0, 0, 0])
        self.y_all_ones = torch.tensor([1, 1, 1, 1, 1])
        self.y_mixed = torch.tensor([0, 0, 1, 1])
        # Outlier-score fixtures matching the label lengths above.
        self.score = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
        self.score_mixed = torch.tensor([0.1, 0.2, 0.3, 0.4])

    def _assert_zero_with_warning(self, call, fragment):
        """Invoke *call*; expect a 0.0 result and exactly one warning
        whose message contains *fragment*."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            outcome = call()
            self.assertEqual(outcome, 0.0)
            self.assertEqual(len(caught), 1)
            self.assertIn(fragment, str(caught[0].message))

    # --- eval_roc_auc edge cases ---

    def test_eval_roc_auc_all_zeros(self):
        """All-negative labels: eval_roc_auc warns and falls back to 0.0."""
        self._assert_zero_with_warning(
            lambda: eval_roc_auc(self.y_all_zeros, self.score),
            "Only one class")

    def test_eval_roc_auc_all_ones(self):
        """All-positive labels: eval_roc_auc warns and falls back to 0.0."""
        self._assert_zero_with_warning(
            lambda: eval_roc_auc(self.y_all_ones, self.score),
            "Only one class")

    # --- eval_recall_at_k edge cases ---

    def test_eval_recall_at_k_no_positives(self):
        """Without any positive label, recall@k warns and returns 0.0."""
        self._assert_zero_with_warning(
            lambda: eval_recall_at_k(self.y_all_zeros, self.score),
            "No positive labels")

    def test_eval_recall_at_k_zero_k(self):
        """k=0 yields a recall of exactly 0.0."""
        self.assertEqual(
            eval_recall_at_k(self.y_mixed, self.score_mixed, k=0), 0.0)

    def test_eval_recall_at_k_large_k(self):
        """k beyond len(label) is clamped; the result stays within [0, 1]."""
        value = eval_recall_at_k(self.y_mixed, self.score_mixed, k=100)
        self.assertGreaterEqual(value, 0.0)
        self.assertLessEqual(value, 1.0)

    def test_eval_recall_at_k_negative_k(self):
        """A negative k is rejected with ValueError."""
        self.assertRaises(ValueError, eval_recall_at_k,
                          self.y_mixed, self.score_mixed, k=-1)

    def test_eval_recall_at_k_float_k(self):
        """A non-integer k is rejected with TypeError."""
        self.assertRaises(TypeError, eval_recall_at_k,
                          self.y_mixed, self.score_mixed, k=2.5)

    # --- eval_precision_at_k edge cases ---

    def test_eval_precision_at_k_zero_k(self):
        """k=0 would divide by zero, so precision@k warns and returns 0.0."""
        self._assert_zero_with_warning(
            lambda: eval_precision_at_k(self.y_mixed, self.score_mixed, k=0),
            "division by zero")

    def test_eval_precision_at_k_large_k(self):
        """k beyond len(label) is clamped; the result stays within [0, 1]."""
        value = eval_precision_at_k(self.y_mixed, self.score_mixed, k=100)
        self.assertGreaterEqual(value, 0.0)
        self.assertLessEqual(value, 1.0)

    def test_eval_precision_at_k_negative_k(self):
        """A negative k is rejected with ValueError."""
        self.assertRaises(ValueError, eval_precision_at_k,
                          self.y_mixed, self.score_mixed, k=-1)

    def test_eval_precision_at_k_float_k(self):
        """A non-integer k is rejected with TypeError."""
        self.assertRaises(TypeError, eval_precision_at_k,
                          self.y_mixed, self.score_mixed, k=2.5)