diff --git a/src/mapcv/splitter.py b/src/mapcv/splitter.py index b187d3d..3f51d30 100644 --- a/src/mapcv/splitter.py +++ b/src/mapcv/splitter.py @@ -31,13 +31,9 @@ def _classify_entry(entry: ManifestEntry) -> int: """ counts = entry["per_class_pixel_counts"] if counts: - has_bg = "0" in counts - has_labeled = any(k != "0" for k in counts) - if has_bg and not has_labeled: - return 0 - if has_labeled and not has_bg: - return 1 - return 2 + if "0" in counts: + return 0 if len(counts) == 1 else 2 + return 1 if entry["empty_ratio"] >= 1.0: return 0 if entry["empty_ratio"] <= 0.0: