From a11735452f7745bf3d1a7cacd90dc91bd7f9d051 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Fri, 29 May 2026 14:15:33 +0200 Subject: [PATCH 1/7] test: download/contaminants --- tests/download/test_contaminants.py | 393 ++++++++++++++++++++++++++++ 1 file changed, 393 insertions(+) create mode 100644 tests/download/test_contaminants.py diff --git a/tests/download/test_contaminants.py b/tests/download/test_contaminants.py new file mode 100644 index 0000000..2dbfa52 --- /dev/null +++ b/tests/download/test_contaminants.py @@ -0,0 +1,393 @@ +""" +Unit tests for ``proteopy.download.contaminants``. + +Download fidelity is checked by SHA-256 hashing the resulting file and +comparing it against the hash of the expected payload (computed inline +from in-code byte constants for mocked tests, or pinned for the two +real-download tests). +""" + +import hashlib +import importlib +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +# ``proteopy.download.__init__`` re-exports ``contaminants`` (the function), +# shadowing the submodule name on attribute lookup. Resolve via sys.modules +# to get the actual module object for ``patch.object``. +contam_mod = importlib.import_module("proteopy.download.contaminants") +from proteopy.download.contaminants import ( # noqa: E402 + _is_uniprot_accession, + check_uniprot_accession_nr, + contaminants, +) + + +# --------------------------------------------------------------------------- +# In-code FASTA payloads +# --------------------------------------------------------------------------- + +FRANKENFIELD_RAW = ( + b">sp|P12345|HUMAN_PROT first description\n" + b"MAAAAACDEFGHIKLMNPQRSTVWY\n" + b">sp|Cont_P67890|MOUSE_PROT contaminant entry\n" + b"MGGGGGHHHIIIKKKLLL\n" + b">sp|AAAA1|MANUAL_ID manually curated entry\n" + b"KKLLLMMNN\n" +) + +# Byte-exact output produced by ``_format_fasta`` with the +# ``_format_frankenfield_header`` formatter applied to FRANKENFIELD_RAW. +# Only difference: the ``Cont_`` prefix on the second accession is stripped. +FRANKENFIELD_FORMATTED = ( + b">sp|P12345|HUMAN_PROT first description\n" + b"MAAAAACDEFGHIKLMNPQRSTVWY\n" + b">sp|P67890|MOUSE_PROT contaminant entry\n" + b"MGGGGGHHHIIIKKKLLL\n" + b">sp|AAAA1|MANUAL_ID manually curated entry\n" + b"KKLLLMMNN\n" +) + +GPM_RAW = ( + b">sp|P00001|CRAP_ENTRY1 example cRAP entry\n" + b"MAAAACDEF\n" + b">sp|P00002|CRAP_ENTRY2 second entry\n" + b"MGGGGGHHH\n" +) + +# Pinned hashes for real-download tests. Re-pin if upstream rotates content. +# Recorded 2026-05-11 from the URLs in ``contam_mod._SOURCE_MAP``. +EXPECTED_FRANKENFIELD_REMOTE_HASH = ( + "b4c1c74438e3d60ee93546a4b717225da318d7c4c30344b91fef4cb8cf6e9f89" +) +EXPECTED_GPM_REMOTE_HASH = ( + "4b0e6e97ab1d618baa38be612a787d884a781e07f627b919c4b29ac064db5382" +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sha256(path: Path) -> str: + return hashlib.sha256(path.read_bytes()).hexdigest() + + +def _sha256_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _make_download_mock(payload: bytes): + """Return a fake ``_download(url, dest)`` that writes payload to dest.""" + + def fake(url, dest): + Path(dest).write_bytes(payload) + + return fake + + +# --------------------------------------------------------------------------- +# 1. _is_uniprot_accession +# --------------------------------------------------------------------------- + + +class TestIsUniprotAccession: + @pytest.mark.parametrize( + "accession", + [ + "P12345", # Swiss-Prot, [OPQ] branch + "O00001", # Swiss-Prot, [OPQ] branch with zeros + "Q9Y6K9", # Swiss-Prot, mixed alphanumeric + "A0A0A0A0A1", # TrEMBL 10-char, [A-NR-Z] branch + "P12345-2", # isoform suffix + "A0A0A0A0A1-12", # TrEMBL with two-digit isoform + ], + ) + def test_valid_accessions(self, accession): + assert _is_uniprot_accession(accession) is True + + @pytest.mark.parametrize( + "accession", + [ + "", # empty + "p12345", # lowercase + "P1234", # too short + "P123456", # too long for [OPQ] branch + "X12345", # second char not a letter-then-3 pattern + "12345P", # leading digit + "1ABCDE", # leading digit + "P12345-", # dangling isoform separator + "P12345-123", # isoform suffix too long + ], + ) + def test_invalid_accessions(self, accession): + assert _is_uniprot_accession(accession) is False + + +# --------------------------------------------------------------------------- +# 2. check_uniprot_accession_nr +# --------------------------------------------------------------------------- + + +class TestCheckUniprotAccessionNr: + def test_valid_returns_none(self): + assert check_uniprot_accession_nr("P12345") is None + + @pytest.mark.parametrize("accession", ["", "p12345", "BADID", "12345"]) + def test_invalid_raises_value_error(self, accession): + with pytest.raises(ValueError, match="not a valid UniProt accession"): + check_uniprot_accession_nr(accession) + + +# --------------------------------------------------------------------------- +# 3. contaminants (mocked + opt-in real-download) +# +# Frankenfield-header and FASTA-rewrite behaviour is exercised through the +# public ``contaminants()`` function only; the private helpers +# ``_format_frankenfield_header`` and ``_format_fasta`` are not tested +# directly. +# --------------------------------------------------------------------------- + + +class TestContaminants: + # -- error inputs -------------------------------------------------------- + + def test_unsupported_source_raises(self): + with pytest.raises(ValueError, match="Unsupported source"): + contaminants(source="bogus") + + # -- happy paths with hash checks --------------------------------------- + + def test_gpm_crap_hash_matches_raw_payload(self, tmp_path): + dst = tmp_path / "gpm.fasta" + with patch.object( + contam_mod, "_download", _make_download_mock(GPM_RAW), + ): + result = contaminants(source="gpm_crap", path=dst) + + assert result == dst + assert result.exists() + assert _sha256(result) == _sha256_bytes(GPM_RAW) + + def test_frankenfield_hash_matches_formatted_payload(self, tmp_path): + """Confirms ``Cont_`` prefixes are stripped byte-exactly.""" + dst = tmp_path / "frank.fasta" + with patch.object( + contam_mod, + "_download", + _make_download_mock(FRANKENFIELD_RAW), + ): + result = contaminants(source="frankenfield2022", path=dst) + + assert result == dst + assert result.exists() + assert _sha256(result) == _sha256_bytes(FRANKENFIELD_FORMATTED) + + @pytest.mark.parametrize( + "source,payload,expected", + [ + ("gpm_crap", GPM_RAW, GPM_RAW), + ("frankenfield2022", FRANKENFIELD_RAW, FRANKENFIELD_FORMATTED), + ], + ) + def test_returns_path_to_destination( + self, tmp_path, source, payload, expected, + ): + dst = tmp_path / f"{source}.fasta" + with patch.object( + contam_mod, "_download", _make_download_mock(payload), + ): + result = contaminants(source=source, path=dst) + + assert isinstance(result, Path) + assert result == dst + assert _sha256(result) == _sha256_bytes(expected) + + # -- default path with date suffix -------------------------------------- + + @pytest.mark.parametrize( + "source, payload, expected, default_stem", + [ + ( + "gpm_crap", GPM_RAW, GPM_RAW, + "contaminants_gpm-crap", + ), + ( + "frankenfield2022", FRANKENFIELD_RAW, FRANKENFIELD_FORMATTED, + "contaminants_frankenfield2022", + ), + ], + ) + def test_default_path_appends_md5_digest( + self, tmp_path, monkeypatch, source, payload, expected, default_stem, + ): + """ + With ``path=None`` the function writes to a default file in the + current working directory whose stem carries the first 8 hex chars + of the MD5 of the final candidate bytes (post-formatting). The + internal ``TemporaryDirectory`` must be torn down even when the + caller did not supply a path. + """ + monkeypatch.chdir(tmp_path) + expected_md5 = hashlib.md5(expected).hexdigest()[:8] + + captured_temp_dirs = [] + orig_td = tempfile.TemporaryDirectory + + def spy_td(*args, **kwargs): + obj = orig_td(*args, **kwargs) + captured_temp_dirs.append(Path(obj.name)) + return obj + + monkeypatch.setattr(tempfile, "TemporaryDirectory", spy_td) + + with patch.object( + contam_mod, "_download", _make_download_mock(payload), + ): + result = contaminants(source=source, path=None) + + expected_rel = Path(f"{default_stem}_{expected_md5}.fasta") + assert result == expected_rel + assert (tmp_path / expected_rel).exists() + assert _sha256(tmp_path / expected_rel) == _sha256_bytes(expected) + + # The internal temp directory must be cleaned up on the success + # path even when ``path`` was not supplied. + assert len(captured_temp_dirs) == 1 + assert not captured_temp_dirs[0].exists() + + # -- parent directory creation ----------------------------------------- + + def test_parent_directory_is_created(self, tmp_path): + dst = tmp_path / "nested" / "sub" / "x.fasta" + assert not dst.parent.exists() + + with patch.object( + contam_mod, "_download", _make_download_mock(GPM_RAW), + ): + result = contaminants(source="gpm_crap", path=dst) + + assert dst.parent.is_dir() + assert result.exists() + + # -- force=False / force=True ------------------------------------------ + + def test_existing_file_force_false_raises_file_exists(self, tmp_path): + """Pre-existing destination + ``force=False`` must raise + ``FileExistsError`` and leave the existing bytes untouched. The + downloader must not even be invoked.""" + dst = tmp_path / "exists.fasta" + sentinel = b"pre-existing bytes that must not be overwritten\n" + dst.write_bytes(sentinel) + + with patch.object(contam_mod, "_download") as patched: + patched.side_effect = _make_download_mock(GPM_RAW) + with pytest.raises(FileExistsError, match="already exists"): + contaminants(source="gpm_crap", path=dst, force=False) + + assert patched.call_count == 0 + assert dst.read_bytes() == sentinel + + @pytest.mark.parametrize( + "source,payload,expected", + [ + ("gpm_crap", GPM_RAW, GPM_RAW), + ("frankenfield2022", FRANKENFIELD_RAW, FRANKENFIELD_FORMATTED), + ], + ) + def test_existing_file_force_true_overwrites( + self, tmp_path, source, payload, expected, + ): + dst = tmp_path / f"{source}.fasta" + dst.write_bytes(b"pre-existing bytes that must not be overwritten\n") + + with patch.object( + contam_mod, "_download", _make_download_mock(payload), + ): + result = contaminants(source=source, path=dst, force=True) + + assert result == dst + assert _sha256(result) == _sha256_bytes(expected) + + # -- formatter failure cleans up temp file ----------------------------- + + @pytest.mark.parametrize( + "bad_payload, error_match", + [ + ( + b">sp|P12345 missing third pipe segment\nMAAAA\n", + "exactly three", + ), + ( + b">sp|P12345|HUMAN_PROT|extra desc\nMAAAA\n", + "exactly three", + ), + ( + b">sp|BADID|HUMAN_PROT desc\nMAAAA\n", + "not a valid UniProt accession", + ), + ], + ) + def test_formatter_failure_propagates_and_cleans_up_temp( + self, tmp_path, monkeypatch, bad_payload, error_match, + ): + """ + All Frankenfield-header validation errors raised by the internal + formatter must propagate out of ``contaminants()`` and the + ``TemporaryDirectory`` used internally must be torn down on + failure. Verified through the public function only — no direct + call to the ``_format_frankenfield_header`` / ``_format_fasta`` + helpers. + """ + def fake_download(url, dest): + Path(dest).write_bytes(bad_payload) + + captured = [] + orig_td = tempfile.TemporaryDirectory + + def spy_td(*args, **kwargs): + obj = orig_td(*args, **kwargs) + captured.append(Path(obj.name)) + return obj + + monkeypatch.setattr(tempfile, "TemporaryDirectory", spy_td) + + dst = tmp_path / "frank.fasta" + with patch.object(contam_mod, "_download", fake_download): + with pytest.raises(ValueError, match=error_match): + contaminants(source="frankenfield2022", path=dst) + + assert len(captured) == 1 + assert not captured[0].exists() + + # -- opt-in real-download tests ---------------------------------------- + + def test_real_frankenfield_download(self, tmp_path): + """ + Real download from the Hao lab GitHub mirror. Hash pinned to current + upstream contents; re-pin ``EXPECTED_FRANKENFIELD_REMOTE_HASH`` if + Hao lab updates the FASTA (see ``_SOURCE_MAP['frankenfield2022']``). + """ + dst = tmp_path / "frankenfield_real.fasta" + result = contaminants(source="frankenfield2022", path=dst) + + assert result.exists() + assert result.stat().st_size > 1024 + assert _sha256(result) == EXPECTED_FRANKENFIELD_REMOTE_HASH + + def test_real_gpm_crap_download(self, tmp_path): + """ + Real download from GPM cRAP via FTP. Hash pinned to current upstream + contents; re-pin ``EXPECTED_GPM_REMOTE_HASH`` if upstream rotates + (see ``_SOURCE_MAP['gpm_crap']``). + """ + dst = tmp_path / "gpm_real.fasta" + result = contaminants(source="gpm_crap", path=dst) + + assert result.exists() + assert result.stat().st_size > 1024 + assert _sha256(result) == EXPECTED_GPM_REMOTE_HASH From f3b70d1ec9023a67c392a27b5c8b43ec90faf606 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Fri, 29 May 2026 14:16:29 +0200 Subject: [PATCH 2/7] refactor: pp.remove_contaminants, abstracted protein_id reader --- proteopy/pp/filtering.py | 46 +------- proteopy/utils/__init__.py | 1 + proteopy/utils/parsers.py | 234 ++++++++++++++++++++++++++++++++++++- tests/pp/test_filtering.py | 4 +- 4 files changed, 241 insertions(+), 44 deletions(-) diff --git a/proteopy/pp/filtering.py b/proteopy/pp/filtering.py index 5f48839..078104d 100644 --- a/proteopy/pp/filtering.py +++ b/proteopy/pp/filtering.py @@ -1,14 +1,13 @@ import warnings -from pathlib import Path from typing import Callable import numpy as np import pandas as pd import scipy.sparse as sp import anndata as ad -from Bio import SeqIO from proteopy.utils.functools import partial_with_docsig from proteopy.utils.anndata import check_proteodata, is_proteodata +from proteopy.utils.parsers import read_protein_ids def filter_axis( @@ -637,48 +636,13 @@ def remove_contaminants( """ check_proteodata(adata) - if header_parser is None: - def header_parser(header: str) -> str: - parts = header.split("|") - return parts[1] if len(parts) > 1 else header - - def _load_contaminant_ids_from_fasta(fasta_path: Path) -> set[str]: - contaminant_ids = set() - for record in SeqIO.parse(fasta_path, "fasta"): - parsed = header_parser(record.id) - if parsed == "": - warnings.warn( - f"Header parser returned empty ID for record '{record.id}'.", - ) - continue - contaminant_ids.add(parsed) - return contaminant_ids - - def _load_contaminant_ids_from_table(table_path: Path, sep: str) -> set[str]: - series = pd.read_csv(table_path, sep=sep, usecols=[0]).iloc[:, 0] - series = series.dropna().astype(str) - return set(series.tolist()) - - cont_path = Path(contaminant_path) - if not cont_path.exists(): - raise FileNotFoundError(f"Contaminant file not found at {cont_path}") - if protein_key not in adata.var.columns: raise KeyError(f"`protein_key`='{protein_key}' not found in adata.var") - suffix = cont_path.suffix.lower() - match suffix: - case ".fasta" | ".fa" | ".faa": - contaminant_ids = _load_contaminant_ids_from_fasta(cont_path) - case ".csv": - contaminant_ids = _load_contaminant_ids_from_table(cont_path, ",") - case ".tsv": - contaminant_ids = _load_contaminant_ids_from_table(cont_path, "\t") - case _: - raise ValueError( - "Unsupported contaminant file type. Use FASTA (.fasta/.fa/.faa), " - "CSV (.csv), or TSV (.tsv).", - ) + contaminant_ids = read_protein_ids( + contaminant_path, + header_parser=header_parser, + ) proteins = adata.var[protein_key] keep_mask = ~proteins.isin(contaminant_ids) diff --git a/proteopy/utils/__init__.py b/proteopy/utils/__init__.py index 160fcb5..18f9089 100644 --- a/proteopy/utils/__init__.py +++ b/proteopy/utils/__init__.py @@ -3,4 +3,5 @@ check_proteodata, ) from .array import is_log_transformed +from .parsers import read_protein_ids from .stat_tests import volcano_plot diff --git a/proteopy/utils/parsers.py b/proteopy/utils/parsers.py index 023d713..5578345 100644 --- a/proteopy/utils/parsers.py +++ b/proteopy/utils/parsers.py @@ -1,10 +1,13 @@ +import os import re import warnings -from typing import Dict, Optional, List +from pathlib import Path +from typing import Callable, Dict, Optional, List import anndata as ad import numpy as np import pandas as pd +from Bio import SeqIO from proteopy.utils.string import sanitize_string @@ -626,3 +629,232 @@ def _resolve_hclustv_profile_key( ) return profile_key + + +# --------------------------------------------------------------------------- +# Reusable readers for proteomics-related identifier files +# --------------------------------------------------------------------------- + +_FASTA_SUFFIXES = (".fasta", ".fa", ".faa") +_TABULAR_SEPARATORS = {".csv": ",", ".tsv": "\t"} + + +def _default_fasta_header_parser(header: str) -> str: + """ + Return the second pipe-separated token of a FASTA header, falling + back to the full header when no pipe is present. + + Example: ``"sp|P12345|HUMAN"`` -> ``"P12345"``. + """ + parts = header.split("|") + return parts[1] if len(parts) > 1 else header + + +def _read_fasta_protein_ids( + file_path: Path, + header_parser: Callable[[str], str], +) -> set[str]: + """Parse protein IDs from a FASTA file using ``header_parser``.""" + ids: set[str] = set() + for record in SeqIO.parse(file_path, "fasta"): + parsed = header_parser(record.id) + if not isinstance(parsed, str): + warnings.warn( + f"Header parser returned non-string " + f"({type(parsed).__name__}) for record " + f"'{record.id}'; skipping.", + UserWarning, + ) + continue + parsed = parsed.strip() + if parsed == "": + warnings.warn( + f"Header parser returned empty ID for record " + f"'{record.id}'.", + UserWarning, + ) + continue + ids.add(parsed) + return ids + + +def _read_tabular_protein_ids( + file_path: Path, + sep: str, + has_header: bool, +) -> set[str]: + """Parse protein IDs from the first column of a CSV / TSV file.""" + header_arg = 0 if has_header else None + try: + df = pd.read_csv( + file_path, + sep=sep, + usecols=[0], + header=header_arg, + ) + except pd.errors.EmptyDataError as exc: + raise ValueError( + f"Tabular file is empty: {file_path}" + ) from exc + + series = df.iloc[:, 0].dropna().astype(str).str.strip() + series = series[series != ""] + return set(series.tolist()) + + +def _validate_read_protein_ids_input( + path, + header_parser, + has_header, +) -> Path: + """Validate ``read_protein_ids`` inputs and return resolved path.""" + if not isinstance(path, (str, os.PathLike)): + raise TypeError( + f"`path` must be a str or os.PathLike, " + f"got {type(path).__name__}." + ) + if header_parser is not None and not callable(header_parser): + raise TypeError( + f"`header_parser` must be callable or None, " + f"got {type(header_parser).__name__}." + ) + if not isinstance(has_header, bool): + raise TypeError( + f"`has_header` must be a bool, " + f"got {type(has_header).__name__}." + ) + + file_path = Path(path) + if not file_path.exists(): + raise FileNotFoundError(f"File not found at {file_path}") + if not file_path.is_file(): + raise IsADirectoryError( + f"Path is not a regular file: {file_path}" + ) + return file_path + + +def read_protein_ids( + path: str | os.PathLike, + header_parser: Callable[[str], str] | None = None, + has_header: bool = True, +) -> set[str]: + """ + Read protein identifiers from a FASTA, CSV, or TSV file. + + Parameters + ---------- + path : str | os.PathLike + Path to the source file. FASTA (``.fasta`` / ``.fa`` / ``.faa``), + CSV (``.csv``), or TSV (``.tsv``) are supported. + header_parser : callable, optional + Function to extract protein IDs from FASTA headers. Defaults to + splitting the header on ``"|"`` and returning the second + element, falling back to the full header. Ignored (with a + warning) for tabular formats. + has_header : bool, optional + For CSV / TSV files, whether the first row is a header line. + Set to ``False`` for plain single-column ID lists. Ignored + for FASTA files. + + Returns + ------- + set of str + Unique protein identifiers parsed from ``path``. For tabular + files the first column is used. Whitespace is stripped and + empty strings are excluded. + + Raises + ------ + TypeError + If ``path`` is not a str / PathLike, ``header_parser`` is + neither callable nor ``None``, or ``has_header`` is not a bool. + FileNotFoundError + If ``path`` does not exist. + IsADirectoryError + If ``path`` exists but is not a regular file. + ValueError + If ``path`` has an unsupported suffix, or the tabular file is + empty. + + Warns + ----- + UserWarning + Emitted (and the record skipped) when ``header_parser`` returns + a non-string or an empty / whitespace-only string for a FASTA + record. Emitted when ``header_parser`` is supplied alongside a + tabular file (the parser is ignored). Emitted when zero IDs + are parsed from the file. + + Examples + -------- + Default FASTA parser (accession is the second pipe-separated field): + + >>> import tempfile + >>> from pathlib import Path + >>> with tempfile.TemporaryDirectory() as d: + ... p = Path(d) / "contaminants.fasta" + ... _ = p.write_text( + ... ">sp|P12345|HUMAN_A\\nACDEF\\n" + ... ">sp|P67890|HUMAN_B\\nGHIKL\\n" + ... ) + ... print(sorted(read_protein_ids(p))) + ['P12345', 'P67890'] + + Custom header parser: + + >>> with tempfile.TemporaryDirectory() as d: + ... p = Path(d) / "headers.fasta" + ... _ = p.write_text( + ... ">x__protein_0\\nACDEF\\n" + ... ">x__protein_1\\nGHIKL\\n" + ... ) + ... print(sorted(read_protein_ids( + ... p, header_parser=lambda h: h.split("__")[1], + ... ))) + ['protein_0', 'protein_1'] + + Single-column TSV without a header row: + + >>> with tempfile.TemporaryDirectory() as d: + ... p = Path(d) / "ids.tsv" + ... _ = p.write_text("P00001\\nP00002\\nP00003\\n") + ... print(sorted(read_protein_ids(p, has_header=False))) + ['P00001', 'P00002', 'P00003'] + """ + file_path = _validate_read_protein_ids_input( + path, header_parser, has_header, + ) + + user_provided_parser = header_parser is not None + if header_parser is None: + header_parser = _default_fasta_header_parser + + # -- Dispatch on file suffix + suffix = file_path.suffix.lower() + if suffix in _FASTA_SUFFIXES: + ids = _read_fasta_protein_ids(file_path, header_parser) + elif suffix in _TABULAR_SEPARATORS: + if user_provided_parser: + warnings.warn( + f"`header_parser` is ignored for tabular files " + f"(got suffix '{suffix}').", + UserWarning, + ) + ids = _read_tabular_protein_ids( + file_path, + sep=_TABULAR_SEPARATORS[suffix], + has_header=has_header, + ) + else: + raise ValueError( + "Unsupported file type. Use FASTA " + "(.fasta/.fa/.faa), CSV (.csv), or TSV (.tsv)." + ) + + if not ids: + warnings.warn( + f"No protein IDs parsed from {file_path}.", + UserWarning, + ) + return ids diff --git a/tests/pp/test_filtering.py b/tests/pp/test_filtering.py index c86b633..2446f6f 100644 --- a/tests/pp/test_filtering.py +++ b/tests/pp/test_filtering.py @@ -1863,7 +1863,7 @@ def test_missing_contaminant_file_raises(self, tmp_path): with pytest.raises( FileNotFoundError, - match=r"Contaminant file not found", + match=r"File not found", ): remove_contaminants( adata, @@ -1889,7 +1889,7 @@ def test_unsupported_file_type_raises(self, tmp_path): with pytest.raises( ValueError, - match=r"Unsupported contaminant file type", + match=r"Unsupported file type", ): remove_contaminants( adata, From 1e4193ff52a4206b0a593e193267ad38a8d23067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Fri, 29 May 2026 14:17:01 +0200 Subject: [PATCH 3/7] feature: ann.contaminants --- proteopy/ann/__init__.py | 1 + proteopy/ann/contaminants.py | 161 +++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 proteopy/ann/contaminants.py diff --git a/proteopy/ann/__init__.py b/proteopy/ann/__init__.py index 3cf7188..7d77d71 100644 --- a/proteopy/ann/__init__.py +++ b/proteopy/ann/__init__.py @@ -1,2 +1,3 @@ from .base_anndata import var, obs, samples from .proteins import proteins_from_csv +from .contaminants import contaminants diff --git a/proteopy/ann/contaminants.py b/proteopy/ann/contaminants.py new file mode 100644 index 0000000..0def548 --- /dev/null +++ b/proteopy/ann/contaminants.py @@ -0,0 +1,161 @@ +""" +Annotation helpers for marking contaminant variables in an AnnData. +""" +import os +import warnings +from typing import Callable + +from anndata import AnnData + +from proteopy.utils.anndata import check_proteodata, is_proteodata +from proteopy.utils.parsers import read_protein_ids + + +def _print_contaminant_summary( + adata: AnnData, + mask_values, + protein_key: str, + key_added: str, +) -> None: + """Print a level-aware summary of annotated contaminants.""" + _, level = is_proteodata(adata) + if level == "peptide": + n_pep = int(mask_values.sum()) + if n_pep == 0: + n_prot = 0 + else: + n_prot = int( + adata.var.loc[mask_values, protein_key].nunique() + ) + print( + f"Annotated {n_pep} peptides from {n_prot} " + f"contaminating proteins at adata.var['{key_added}']." + ) + elif level == "protein": + n_prot = int(mask_values.sum()) + print( + f"Annotated {n_prot} contaminating proteins at " + f"adata.var['{key_added}']." + ) + else: + n = int(mask_values.sum()) + print( + f"Annotated {n} contaminating variables at " + f"adata.var['{key_added}']." + ) + + +def contaminants( + adata: AnnData, + contaminant_path: str | os.PathLike, + *, + protein_key: str = "protein_id", + key_added: str = "is_contaminant", + header_parser: Callable[[str], str] | None = None, + has_header: bool = True, + inplace: bool = True, + verbose: bool = False, +) -> AnnData | None: + """ + Annotate contaminant variables by flagging them in ``adata.var``. + + Reads protein identifiers from a contaminant list (FASTA / CSV / + TSV) and writes a boolean column ``adata.var[key_added]`` that is + ``True`` where ``adata.var[protein_key]`` matches an entry in the + list. Unlike :func:`proteopy.pp.remove_contaminants`, the + contaminants are kept in the AnnData and only flagged, so the + annotation can be used for downstream filtering decisions, QC + plots, or contaminant-aware normalization. + + Parameters + ---------- + adata : AnnData + :class:`~anndata.AnnData` annotated data matrix. + contaminant_path : str | os.PathLike + Path to the contaminant list. Supported formats: FASTA + (``.fasta`` / ``.fa`` / ``.faa``), CSV (``.csv``), TSV + (``.tsv``). See :func:`proteopy.utils.parsers.read_protein_ids` + for parsing details. + protein_key : str, optional + Column in ``adata.var`` holding protein identifiers used for + matching. Defaults to ``"protein_id"``. + key_added : str, optional + Name of the boolean column written to ``adata.var``. Defaults + to ``"is_contaminant"``. + header_parser : callable, optional + Function to extract protein IDs from FASTA headers. Defaults + to splitting the header on ``"|"`` and returning the second + element, falling back to the full header. Ignored (with a + warning) for tabular formats. + has_header : bool, optional + For CSV / TSV files, whether the first row is a header line. + Set to ``False`` for plain single-column ID lists. + + Returns + ------- + AnnData or None + ``None`` when ``inplace=True``; otherwise the annotated copy + of ``adata``. + + Raises + ------ + KeyError + If ``protein_key`` is not a column of ``adata.var``. + + Warns + ----- + UserWarning + Emitted when ``key_added`` already exists in ``adata.var`` + (the column is overwritten). All warnings raised by + :func:`~proteopy.utils.parsers.read_protein_ids` (empty / + non-string parsed IDs, no IDs parsed, parser ignored for + tabular files) are propagated. + + See Also + -------- + proteopy.pp.remove_contaminants : Remove rather than annotate. + proteopy.utils.parsers.read_protein_ids : Underlying file reader. + + Examples + -------- + >>> import proteopy as pr + >>> adata = pr.datasets.example_protein_data() # doctest: +SKIP + >>> pr.ann.contaminants( + ... adata, + ... "contaminants.fasta", + ... verbose=True, + ... ) # doctest: +SKIP + >>> adata.var["is_contaminant"].sum() # doctest: +SKIP + """ + check_proteodata(adata) + + if protein_key not in adata.var.columns: + raise KeyError( + f"`protein_key`='{protein_key}' not found in adata.var" + ) + if key_added in adata.var.columns: + warnings.warn( + f"`key_added`='{key_added}' already exists in adata.var; " + "overwriting.", + UserWarning, + ) + + adata_target = adata if inplace else adata.copy() + + contaminant_ids = read_protein_ids( + contaminant_path, + header_parser=header_parser, + has_header=has_header, + ) + + mask = adata_target.var[protein_key].isin(contaminant_ids) + mask_values = mask.astype(bool).values + adata_target.var[key_added] = mask_values + + if verbose: + _print_contaminant_summary( + adata_target, mask_values, protein_key, key_added, + ) + + check_proteodata(adata_target) + return None if inplace else adata_target From 7ff90888f48efb309e54957aec9ba010fc810fe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Fri, 29 May 2026 14:26:22 +0200 Subject: [PATCH 4/7] corrected pre-commit issues --- proteopy/ann/contaminants.py | 16 ++-- proteopy/pp/filtering.py | 61 +++++++------- proteopy/utils/parsers.py | 123 +++++++++++++++------------- tests/download/test_contaminants.py | 83 +++++++++++++------ 4 files changed, 165 insertions(+), 118 deletions(-) diff --git a/proteopy/ann/contaminants.py b/proteopy/ann/contaminants.py index 0def548..dfaa80f 100644 --- a/proteopy/ann/contaminants.py +++ b/proteopy/ann/contaminants.py @@ -1,9 +1,10 @@ """ Annotation helpers for marking contaminant variables in an AnnData. """ + import os import warnings -from typing import Callable +from collections.abc import Callable from anndata import AnnData @@ -24,9 +25,7 @@ def _print_contaminant_summary( if n_pep == 0: n_prot = 0 else: - n_prot = int( - adata.var.loc[mask_values, protein_key].nunique() - ) + n_prot = int(adata.var.loc[mask_values, protein_key].nunique()) print( f"Annotated {n_pep} peptides from {n_prot} " f"contaminating proteins at adata.var['{key_added}']." @@ -130,9 +129,7 @@ def contaminants( check_proteodata(adata) if protein_key not in adata.var.columns: - raise KeyError( - f"`protein_key`='{protein_key}' not found in adata.var" - ) + raise KeyError(f"`protein_key`='{protein_key}' not found in adata.var") if key_added in adata.var.columns: warnings.warn( f"`key_added`='{key_added}' already exists in adata.var; " @@ -154,7 +151,10 @@ def contaminants( if verbose: _print_contaminant_summary( - adata_target, mask_values, protein_key, key_added, + adata_target, + mask_values, + protein_key, + key_added, ) check_proteodata(adata_target) diff --git a/proteopy/pp/filtering.py b/proteopy/pp/filtering.py index 078104d..6c345d0 100644 --- a/proteopy/pp/filtering.py +++ b/proteopy/pp/filtering.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable +from collections.abc import Callable import numpy as np import pandas as pd import scipy.sparse as sp @@ -84,7 +84,7 @@ def filter_axis( axis_i = 1 - axis axis_labels = adata.obs_names if axis == 0 else adata.var_names - completeness = None # assigned below when min_fraction is set + completeness = None # assigned below when min_fraction is set if group_by is not None: metadata = adata.obs if axis == 1 else adata.var @@ -126,7 +126,9 @@ def filter_axis( if not completeness_by_group: completeness = pd.Series(0, index=axis_labels, dtype=float) else: - completeness = pd.concat(completeness_by_group, axis=1).max(axis=1) + completeness = pd.concat(completeness_by_group, axis=1).max( + axis=1 + ) else: if sp.issparse(X): counts = pd.Series(X.getnnz(axis=axis_i), index=axis_labels) @@ -157,7 +159,9 @@ def filter_axis( check_proteodata(adata) return None else: - adata_filtered = adata[mask_filt, :] if axis == 0 else adata[:, mask_filt] + adata_filtered = ( + adata[mask_filt, :] if axis == 0 else adata[:, mask_filt] + ) check_proteodata(adata_filtered) return adata_filtered @@ -173,7 +177,7 @@ def filter_axis( filter_axis, axis=0, docstr_header=docstr_header, - ) +) docstr_header = """ Filter observations based on data completeness. @@ -187,7 +191,7 @@ def filter_axis( axis=0, min_count=None, docstr_header=docstr_header, - ) +) docstr_header = """ Filter variables based on non-missing value content. @@ -200,7 +204,7 @@ def filter_axis( filter_axis, axis=1, docstr_header=docstr_header, - ) +) docstr_header = """ Filter variables based on data completeness. @@ -214,7 +218,7 @@ def filter_axis( axis=1, min_count=None, docstr_header=docstr_header, - ) +) def filter_proteins_by_peptide_count( @@ -223,7 +227,7 @@ def filter_proteins_by_peptide_count( max_count=None, protein_col="protein_id", inplace=True, - ): +): """ Filter proteins by their peptide count. @@ -247,9 +251,9 @@ def filter_proteins_by_peptide_count( """ check_proteodata(adata) if is_proteodata(adata)[1] != "peptide": - raise ValueError(( + raise ValueError( "`AnnData` object must be in ProteoData peptide format." - )) + ) if min_count is None and max_count is None: warnings.warn("Pass at least one argument: min_count | max_count") @@ -264,7 +268,9 @@ def filter_proteins_by_peptide_count( if max_count is not None: if max_count < 0: raise ValueError("`max_count` must be non-negative.") - if (min_count is not None and max_count is not None) and (min_count > max_count): + if (min_count is not None and max_count is not None) and ( + min_count > max_count + ): raise ValueError("`min_count` cannot be greater than `max_count`.") if protein_col not in adata.var.columns: @@ -311,7 +317,7 @@ def filter_samples_by_category_count( min_count=None, max_count=None, inplace=True, - ): +): """ Filter observations by the frequency of their category value in a ``.vars`` metadata column. @@ -353,7 +359,9 @@ def filter_samples_by_category_count( raise ValueError("`min_count` cannot be greater than `max_count`.") if category_col not in adata.obs.columns: - raise KeyError(f"`category_col`='{category_col}' not found in adata.obs") + raise KeyError( + f"`category_col`='{category_col}' not found in adata.obs" + ) obs_series = adata.obs[category_col] counts = obs_series.value_counts(dropna=False) @@ -398,26 +406,22 @@ def _validate_remove_zero_variance_vars_input( ) if not isinstance(atol, (int, float)): raise TypeError( - f"`atol` must be a numeric value, " - f"got {type(atol).__name__}." + f"`atol` must be a numeric value, " f"got {type(atol).__name__}." ) if atol < 0: raise ValueError("`atol` must be non-negative.") if not isinstance(inplace, bool): raise TypeError( - f"`inplace` must be a bool, " - f"got {type(inplace).__name__}." + f"`inplace` must be a bool, " f"got {type(inplace).__name__}." ) if not isinstance(verbose, bool): raise TypeError( - f"`verbose` must be a bool, " - f"got {type(verbose).__name__}." + f"`verbose` must be a bool, " f"got {type(verbose).__name__}." ) if group_by is not None: if group_by not in adata.obs.columns: raise KeyError( - f"`group_by`='{group_by}' not found " - f"in adata.obs" + f"`group_by`='{group_by}' not found " f"in adata.obs" ) if adata.obs[group_by].isna().any(): raise ValueError( @@ -531,7 +535,11 @@ def remove_zero_variance_vars( ['p1'] """ _validate_remove_zero_variance_vars_input( - adata, group_by, atol, inplace, verbose, + adata, + group_by, + atol, + inplace, + verbose, ) check_proteodata(adata) X = adata.X @@ -559,10 +567,7 @@ def remove_zero_variance_vars( if idx.size == 0: continue Xg = X[idx, :] - Xg_arr = ( - Xg.toarray() if sp.issparse(Xg) - else np.asarray(Xg) - ) + Xg_arr = Xg.toarray() if sp.issparse(Xg) else np.asarray(Xg) with warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) vg = np.nanvar(Xg_arr, axis=0, ddof=0) @@ -607,7 +612,7 @@ def remove_contaminants( protein_key="protein_id", header_parser: Callable[[str], str] | None = None, inplace=True, - ): +): """ Remove variables whose protein identifier matches a contaminant FASTA entry. diff --git a/proteopy/utils/parsers.py b/proteopy/utils/parsers.py index 5578345..acf356f 100644 --- a/proteopy/utils/parsers.py +++ b/proteopy/utils/parsers.py @@ -2,7 +2,8 @@ import re import warnings from pathlib import Path -from typing import Callable, Dict, Optional, List +from typing import Dict, Optional, List +from collections.abc import Callable import anndata as ad import numpy as np @@ -19,7 +20,9 @@ } -def parse_tumor_subclass(df: pd.DataFrame, col: str = "tumor_class") -> pd.DataFrame: +def parse_tumor_subclass( + df: pd.DataFrame, col: str = "tumor_class" +) -> pd.DataFrame: """ Parse a less-structured tumor_class column into: - main_tumor_type @@ -56,7 +59,6 @@ def parse_tumor_subclass(df: pd.DataFrame, col: str = "tumor_class") -> pd.DataF df = df.copy() df.index.name = None - # Compile patterns once # Genetic markers to capture (exact phrases) genetic_marker_patterns = [ @@ -67,14 +69,20 @@ def parse_tumor_subclass(df: pd.DataFrame, col: str = "tumor_class") -> pd.DataF ] # subclass and subtype helpers - subclass_bracket_pat = re.compile(r"\[([^\]]*subclass[^\]]*)\]", re.IGNORECASE) + subclass_bracket_pat = re.compile( + r"\[([^\]]*subclass[^\]]*)\]", re.IGNORECASE + ) subclass_pat = re.compile(r"\bsubclass\b[^\),;\]]*", re.IGNORECASE) - subtype_bracket_pat = re.compile(r"\[([^\]]*subtype[^\]]*)\]", re.IGNORECASE) + subtype_bracket_pat = re.compile( + r"\[([^\]]*subtype[^\]]*)\]", re.IGNORECASE + ) # 'subtype ...' subtype_after_pat = re.compile(r"\bsubtype\b[^\),;\]]*", re.IGNORECASE) # '... subtype' (capture up to 3 words before subtype) - subtype_before_pat = re.compile(r"(?:\b[\w/-]+\s+){1,3}\bsubtype\b", re.IGNORECASE) + subtype_before_pat = re.compile( + r"(?:\b[\w/-]+\s+){1,3}\bsubtype\b", re.IGNORECASE + ) # Splitter on comma or the word 'and' splitter = re.compile(r"\s*,\s*|\s+\band\b\s+", re.IGNORECASE) @@ -82,11 +90,13 @@ def parse_tumor_subclass(df: pd.DataFrame, col: str = "tumor_class") -> pd.DataF def strip_wrappers(s: str) -> str: s = s.strip() # remove enclosing brackets or parentheses only if they enclose the whole chunk - if len(s) >= 2 and ((s[0] == "[" and s[-1] == "]") or (s[0] == "(" and s[-1] == ")")): + if len(s) >= 2 and ( + (s[0] == "[" and s[-1] == "]") or (s[0] == "(" and s[-1] == ")") + ): s = s[1:-1].strip() return s.strip(" ,;") - def dedupe_keep_order(items: List[str]) -> List[str]: + def dedupe_keep_order(items: list[str]) -> list[str]: seen = set() out = [] for x in items: @@ -101,7 +111,7 @@ def normalize_case(val: str) -> str: # Keep original chunk case for readability return val.strip() - def parse_one(value: Optional[str]) -> Dict[str, Optional[str]]: + def parse_one(value: str | None) -> dict[str, str | None]: if value is None or (isinstance(value, float) and np.isnan(value)): return { "main_tumor_type": None, @@ -112,11 +122,11 @@ def parse_one(value: Optional[str]) -> Dict[str, Optional[str]]: } remaining = str(value).strip() - markers: List[str] = [] - subclass_val: Optional[str] = None - subtype_val: Optional[str] = None - rest_parts: List[str] = [] - main_tumor_type: Optional[str] = None + markers: list[str] = [] + subclass_val: str | None = None + subtype_val: str | None = None + rest_parts: list[str] = [] + main_tumor_type: str | None = None while True: # Split into tokens @@ -133,7 +143,7 @@ def parse_one(value: Optional[str]) -> Dict[str, Optional[str]]: remaining_next = ", ".join(tokens[:-1]) chunk_work = chunk - consumed_spans: List[tuple] = [] + consumed_spans: list[tuple] = [] def record_span(m): if m: @@ -176,7 +186,9 @@ def record_span(m): record_span(m) # Compute residual of this chunk after removing matches - residual = strip_wrappers(_remove_spans(chunk_work, consumed_spans)) + residual = strip_wrappers( + _remove_spans(chunk_work, consumed_spans) + ) if residual: rest_parts.append(residual) @@ -184,7 +196,9 @@ def record_span(m): if remaining_next is None: # Final chunk: this defines main_tumor_type (after removing matched parts) # If residual is empty (i.e., the entire chunk was a match), fall back to cleaned chunk - main_tumor_type = residual if residual else strip_wrappers(chunk_work) + main_tumor_type = ( + residual if residual else strip_wrappers(chunk_work) + ) break else: remaining = remaining_next @@ -208,7 +222,7 @@ def record_span(m): "rest": rest, } - def _remove_spans(text: str, spans: List[tuple]) -> str: + def _remove_spans(text: str, spans: list[tuple]) -> str: if not spans: return text spans_sorted = sorted(spans) @@ -226,12 +240,12 @@ def _remove_spans(text: str, spans: List[tuple]) -> str: parsed = df[col].apply(parse_one) parsed_df = pd.DataFrame(list(parsed)) df_list = [ - df.reset_index()[['index', col]], - parsed_df.reset_index(drop=True) + df.reset_index()[["index", col]], + parsed_df.reset_index(drop=True), ] - new_df = pd.concat(df_list, axis=1) - new_df = new_df.set_index('index') + new_df = pd.concat(df_list, axis=1) + new_df = new_df.set_index("index") # Add original index new_df = new_df.loc[df.index,] @@ -240,19 +254,22 @@ def _remove_spans(text: str, spans: List[tuple]) -> str: def diann_run(s, warn=False): - match = re.search(r'_(\d+)_T', s) + match = re.search(r"_(\d+)_T", s) if match: - return 'Run_' + match.group(1) + return "Run_" + match.group(1) - match = re.search(r'(?<=_)(?:N?\d{2,5}(?:_[A-Za-z0-9]+)*_[A-Za-z]+|N?\d{5}|N?\d{2}_\d{4}[A-Za-z]?_[A-Za-z]+)(?=_T1_DIA)', s) + match = re.search( + r"(?<=_)(?:N?\d{2,5}(?:_[A-Za-z0-9]+)*_[A-Za-z]+|N?\d{5}|N?\d{2}_\d{4}[A-Za-z]?_[A-Za-z]+)(?=_T1_DIA)", + s, + ) if match: - return 'Run_' + match.group(0) + return "Run_" + match.group(0) if warn: - warnings.warn(f'No match for string:\n{s}') - return 'no_parse_match' + warnings.warn(f"No match for string:\n{s}") + return "no_parse_match" - raise ValueError(f'No match for string:\n{s}') + raise ValueError(f"No match for string:\n{s}") def _pretty_design_label(label: str) -> str: @@ -331,8 +348,7 @@ def parse_stat_test_varm_slot( if layer_part: if adata is not None and adata.layers: layer_map = { - sanitize_string(name): name - for name in adata.layers.keys() + sanitize_string(name): name for name in adata.layers.keys() } if layer_part in layer_map: layer = layer_map[layer_part] @@ -342,7 +358,7 @@ def parse_stat_test_varm_slot( f"must contain the sanitized layer part for back-" f"mapping. '{layer_part}' not found in adata varm layers" f"(unsanitized): {adata.layers}." - ) + ) else: layer = layer_part @@ -369,8 +385,7 @@ def parse_stat_test_varm_slot( ) else: raise ValueError( - "Design must use '_vs_' or " - "'_vs_rest'." + "Design must use '_vs_' or " "'_vs_rest'." ) test_info = { @@ -422,8 +437,8 @@ def _parse_hclustv_key_components(key: str) -> tuple[str, str, str] | None: def _resolve_hclustv_keys( adata: ad.AnnData, - linkage_key: str = 'auto', - values_key: str = 'auto', + linkage_key: str = "auto", + values_key: str = "auto", verbose: bool = True, ) -> tuple[str, str]: """ @@ -433,16 +448,14 @@ def _resolve_hclustv_keys( the resolved key names. """ linkage_candidates = [ - key for key in adata.uns.keys() - if key.startswith("hclustv_linkage;") + key for key in adata.uns.keys() if key.startswith("hclustv_linkage;") ] values_candidates = [ - key for key in adata.uns.keys() - if key.startswith("hclustv_values;") + key for key in adata.uns.keys() if key.startswith("hclustv_values;") ] - linkage_auto = linkage_key == 'auto' - values_auto = values_key == 'auto' + linkage_auto = linkage_key == "auto" + values_auto = values_key == "auto" if linkage_auto: if len(linkage_candidates) == 0: @@ -505,7 +518,7 @@ def _resolve_hclustv_keys( def _resolve_hclustv_cluster_key( adata: ad.AnnData, - cluster_key: str = 'auto', + cluster_key: str = "auto", verbose: bool = True, ) -> str: """ @@ -539,11 +552,10 @@ def _resolve_hclustv_cluster_key( If the specified ``cluster_key`` is not found in ``adata.var``. """ cluster_candidates = [ - col for col in adata.var.columns - if col.startswith("hclustv_cluster;") + col for col in adata.var.columns if col.startswith("hclustv_cluster;") ] - if cluster_key == 'auto': + if cluster_key == "auto": if len(cluster_candidates) == 0: raise ValueError( "No cluster annotations found in adata.var. " @@ -569,7 +581,7 @@ def _resolve_hclustv_cluster_key( def _resolve_hclustv_profile_key( adata: ad.AnnData, - profile_key: str = 'auto', + profile_key: str = "auto", verbose: bool = True, ) -> str: """ @@ -603,11 +615,10 @@ def _resolve_hclustv_profile_key( If the specified ``profile_key`` is not found in ``adata.uns``. """ profile_candidates = [ - key for key in adata.uns.keys() - if key.startswith("hclustv_profiles;") + key for key in adata.uns.keys() if key.startswith("hclustv_profiles;") ] - if profile_key == 'auto': + if profile_key == "auto": if len(profile_candidates) == 0: raise ValueError( "No cluster profiles found in adata.uns. " @@ -693,9 +704,7 @@ def _read_tabular_protein_ids( header=header_arg, ) except pd.errors.EmptyDataError as exc: - raise ValueError( - f"Tabular file is empty: {file_path}" - ) from exc + raise ValueError(f"Tabular file is empty: {file_path}") from exc series = df.iloc[:, 0].dropna().astype(str).str.strip() series = series[series != ""] @@ -728,9 +737,7 @@ def _validate_read_protein_ids_input( if not file_path.exists(): raise FileNotFoundError(f"File not found at {file_path}") if not file_path.is_file(): - raise IsADirectoryError( - f"Path is not a regular file: {file_path}" - ) + raise IsADirectoryError(f"Path is not a regular file: {file_path}") return file_path @@ -823,7 +830,9 @@ def read_protein_ids( ['P00001', 'P00002', 'P00003'] """ file_path = _validate_read_protein_ids_input( - path, header_parser, has_header, + path, + header_parser, + has_header, ) user_provided_parser = header_parser is not None diff --git a/tests/download/test_contaminants.py b/tests/download/test_contaminants.py index 2dbfa52..24f7a59 100644 --- a/tests/download/test_contaminants.py +++ b/tests/download/test_contaminants.py @@ -99,11 +99,11 @@ class TestIsUniprotAccession: @pytest.mark.parametrize( "accession", [ - "P12345", # Swiss-Prot, [OPQ] branch - "O00001", # Swiss-Prot, [OPQ] branch with zeros - "Q9Y6K9", # Swiss-Prot, mixed alphanumeric - "A0A0A0A0A1", # TrEMBL 10-char, [A-NR-Z] branch - "P12345-2", # isoform suffix + "P12345", # Swiss-Prot, [OPQ] branch + "O00001", # Swiss-Prot, [OPQ] branch with zeros + "Q9Y6K9", # Swiss-Prot, mixed alphanumeric + "A0A0A0A0A1", # TrEMBL 10-char, [A-NR-Z] branch + "P12345-2", # isoform suffix "A0A0A0A0A1-12", # TrEMBL with two-digit isoform ], ) @@ -113,15 +113,15 @@ def test_valid_accessions(self, accession): @pytest.mark.parametrize( "accession", [ - "", # empty - "p12345", # lowercase - "P1234", # too short - "P123456", # too long for [OPQ] branch - "X12345", # second char not a letter-then-3 pattern - "12345P", # leading digit - "1ABCDE", # leading digit - "P12345-", # dangling isoform separator - "P12345-123", # isoform suffix too long + "", # empty + "p12345", # lowercase + "P1234", # too short + "P123456", # too long for [OPQ] branch + "X12345", # second char not a letter-then-3 pattern + "12345P", # leading digit + "1ABCDE", # leading digit + "P12345-", # dangling isoform separator + "P12345-123", # isoform suffix too long ], ) def test_invalid_accessions(self, accession): @@ -165,7 +165,9 @@ def test_unsupported_source_raises(self): def test_gpm_crap_hash_matches_raw_payload(self, tmp_path): dst = tmp_path / "gpm.fasta" with patch.object( - contam_mod, "_download", _make_download_mock(GPM_RAW), + contam_mod, + "_download", + _make_download_mock(GPM_RAW), ): result = contaminants(source="gpm_crap", path=dst) @@ -195,11 +197,17 @@ def test_frankenfield_hash_matches_formatted_payload(self, tmp_path): ], ) def test_returns_path_to_destination( - self, tmp_path, source, payload, expected, + self, + tmp_path, + source, + payload, + expected, ): dst = tmp_path / f"{source}.fasta" with patch.object( - contam_mod, "_download", _make_download_mock(payload), + contam_mod, + "_download", + _make_download_mock(payload), ): result = contaminants(source=source, path=dst) @@ -213,17 +221,27 @@ def test_returns_path_to_destination( "source, payload, expected, default_stem", [ ( - "gpm_crap", GPM_RAW, GPM_RAW, + "gpm_crap", + GPM_RAW, + GPM_RAW, "contaminants_gpm-crap", ), ( - "frankenfield2022", FRANKENFIELD_RAW, FRANKENFIELD_FORMATTED, + "frankenfield2022", + FRANKENFIELD_RAW, + FRANKENFIELD_FORMATTED, "contaminants_frankenfield2022", ), ], ) def test_default_path_appends_md5_digest( - self, tmp_path, monkeypatch, source, payload, expected, default_stem, + self, + tmp_path, + monkeypatch, + source, + payload, + expected, + default_stem, ): """ With ``path=None`` the function writes to a default file in the @@ -246,7 +264,9 @@ def spy_td(*args, **kwargs): monkeypatch.setattr(tempfile, "TemporaryDirectory", spy_td) with patch.object( - contam_mod, "_download", _make_download_mock(payload), + contam_mod, + "_download", + _make_download_mock(payload), ): result = contaminants(source=source, path=None) @@ -267,7 +287,9 @@ def test_parent_directory_is_created(self, tmp_path): assert not dst.parent.exists() with patch.object( - contam_mod, "_download", _make_download_mock(GPM_RAW), + contam_mod, + "_download", + _make_download_mock(GPM_RAW), ): result = contaminants(source="gpm_crap", path=dst) @@ -300,13 +322,19 @@ def test_existing_file_force_false_raises_file_exists(self, tmp_path): ], ) def test_existing_file_force_true_overwrites( - self, tmp_path, source, payload, expected, + self, + tmp_path, + source, + payload, + expected, ): dst = tmp_path / f"{source}.fasta" dst.write_bytes(b"pre-existing bytes that must not be overwritten\n") with patch.object( - contam_mod, "_download", _make_download_mock(payload), + contam_mod, + "_download", + _make_download_mock(payload), ): result = contaminants(source=source, path=dst, force=True) @@ -333,7 +361,11 @@ def test_existing_file_force_true_overwrites( ], ) def test_formatter_failure_propagates_and_cleans_up_temp( - self, tmp_path, monkeypatch, bad_payload, error_match, + self, + tmp_path, + monkeypatch, + bad_payload, + error_match, ): """ All Frankenfield-header validation errors raised by the internal @@ -343,6 +375,7 @@ def test_formatter_failure_propagates_and_cleans_up_temp( call to the ``_format_frankenfield_header`` / ``_format_fasta`` helpers. """ + def fake_download(url, dest): Path(dest).write_bytes(bad_payload) From 1e939f8d45ba5c69545d1b9c0d894e02581b16c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:23:03 +0200 Subject: [PATCH 5/7] docstring ann.contaminants, refactoring --- proteopy/ann/contaminants.py | 36 +- proteopy/pl/stats.py | 807 +++++++++++++++++++---------------- 2 files changed, 472 insertions(+), 371 deletions(-) diff --git a/proteopy/ann/contaminants.py b/proteopy/ann/contaminants.py index dfaa80f..4da23d7 100644 --- a/proteopy/ann/contaminants.py +++ b/proteopy/ann/contaminants.py @@ -117,14 +117,36 @@ def contaminants( Examples -------- + Flag two of three proteins as contaminants using an inline FASTA list. + The default header parser extracts the accession from the second + pipe-separated field (Swiss-Prot style): + + >>> import tempfile + >>> from pathlib import Path + >>> import numpy as np + >>> import pandas as pd + >>> from anndata import AnnData >>> import proteopy as pr - >>> adata = pr.datasets.example_protein_data() # doctest: +SKIP - >>> pr.ann.contaminants( - ... adata, - ... "contaminants.fasta", - ... verbose=True, - ... ) # doctest: +SKIP - >>> adata.var["is_contaminant"].sum() # doctest: +SKIP + >>> with tempfile.TemporaryDirectory() as d: + ... fasta = Path(d) / "contaminants.fasta" + ... _ = fasta.write_text( + ... ">sp|P00001|HUMAN_A\\nACDEF\\n" + ... ">sp|P00002|HUMAN_B\\nGHIKL\\n" + ... ) + ... adata = AnnData( + ... X=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + ... obs=pd.DataFrame( + ... {"sample_id": ["s1", "s2"]}, index=["s1", "s2"], + ... ), + ... var=pd.DataFrame( + ... {"protein_id": ["P00001", "P00002", "P00003"]}, + ... index=["P00001", "P00002", "P00003"], + ... ), + ... ) + ... pr.ann.contaminants(adata, fasta, verbose=True) + ... print(int(adata.var["is_contaminant"].sum())) + Annotated 2 contaminating proteins at adata.var['is_contaminant']. + 2 """ check_proteodata(adata) diff --git a/proteopy/pl/stats.py b/proteopy/pl/stats.py index 4d9ed69..5e328cc 100644 --- a/proteopy/pl/stats.py +++ b/proteopy/pl/stats.py @@ -1,6 +1,9 @@ import warnings from pathlib import Path -from typing import Any, Sequence +from typing import Any +from collections.abc import Sequence +from typing import Any +from collections.abc import Callable, Sequence import uuid import numpy as np @@ -41,14 +44,9 @@ def _validate_completeness_args( # noqa: C901 check_proteodata(adata) if axis not in (0, 1): - raise ValueError( - "`axis` must be either 0 (var) or 1 (obs)." - ) + raise ValueError("`axis` must be either 0 (var) or 1 (obs).") - if ( - group_by_resolution is not None - and group_by_partition is not None - ): + if group_by_resolution is not None and group_by_partition is not None: raise ValueError( "`group_by_resolution` and `group_by_partition` " "are mutually exclusive. Provide one or neither." @@ -63,18 +61,13 @@ def _validate_completeness_args( # noqa: C901 if fraction_thresh is not None and ( fraction_thresh < 0 or fraction_thresh > 1 ): - raise ValueError( - "`fraction_thresh` must be between 0 and 1." - ) + raise ValueError("`fraction_thresh` must be between 0 and 1.") if bin_width is not None and bin_width <= 0: - raise ValueError( - "`bin_width` must be a positive number." - ) + raise ValueError("`bin_width` must be a positive number.") - if ( - group_by_resolution is None - and (min_count is not None or min_fraction is not None) + if group_by_resolution is None and ( + min_count is not None or min_fraction is not None ): warnings.warn( "`min_count` and `min_fraction` are only used when " @@ -88,15 +81,12 @@ def _validate_completeness_args( # noqa: C901 matrix = adata.X else: if layer not in adata.layers: - raise KeyError( - f"Layer '{layer}' not found in adata.layers." - ) + raise KeyError(f"Layer '{layer}' not found in adata.layers.") matrix = adata.layers[layer] if matrix is None: raise ValueError( - "Selected matrix is empty; cannot compute " - "completeness." + "Selected matrix is empty; cannot compute " "completeness." ) n_obs, n_vars = matrix.shape @@ -113,14 +103,10 @@ def _validate_completeness_args( # noqa: C901 grouping_frame = adata.var if axis_length == 0: - raise ValueError( - "Cannot compute completeness on empty axis." - ) + raise ValueError("Cannot compute completeness on empty axis.") if n_items == 0: - raise ValueError( - "No items to compute completeness for." - ) + raise ValueError("No items to compute completeness for.") if order is not None and group_by_partition is None: warnings.warn( @@ -130,24 +116,36 @@ def _validate_completeness_args( # noqa: C901 ) return [ - matrix, axis_labels, n_items, axis_length, - grouping_frame, min_count, min_fraction, + matrix, + axis_labels, + n_items, + axis_length, + grouping_frame, + min_count, + min_fraction, ] def _summary_stats(values): """Return a single-row DataFrame of summary statistics.""" - s = pd.Series(values) if not isinstance( - values, pd.Series, - ) else values - return pd.DataFrame({ - "count": [s.count()], - "mean": [s.mean()], - "median": [s.median()], - "std": [s.std()], - "min": [s.min()], - "max": [s.max()], - }) + s = ( + pd.Series(values) + if not isinstance( + values, + pd.Series, + ) + else values + ) + return pd.DataFrame( + { + "count": [s.count()], + "mean": [s.mean()], + "median": [s.median()], + "std": [s.std()], + "min": [s.min()], + "max": [s.max()], + } + ) def _count_nonmissing(mat, ax, zero_to_na): @@ -198,9 +196,7 @@ def _resolve_partition_order(order, available): order = [order] else: order = list(order) - missing = [ - g for g in order if g not in available - ] + missing = [g for g in order if g not in available] if missing: raise ValueError( "Unknown group(s) in `order`: " @@ -211,7 +207,10 @@ def _resolve_partition_order(order, available): def _group_completeness_counts( - matrix, axis, g_mask, zero_to_na, + matrix, + axis, + g_mask, + zero_to_na, ): """Count non-missing values per item within a group mask.""" if axis == 0: @@ -249,13 +248,13 @@ def _plot_completeness_partition( group_series = grouping_frame[group_by_partition] available = list(group_series.dropna().unique()) unique_groups = _resolve_partition_order( - order, available, + order, + available, ) if len(unique_groups) == 0: raise ValueError( - "No groups found for the given " - "`group_by_partition` column.", + "No groups found for the given " "`group_by_partition` column.", ) # -- compute completeness per item within each group @@ -263,35 +262,40 @@ def _plot_completeness_partition( for g in unique_groups: g_mask = (group_series == g).values counts_g, g_size = _group_completeness_counts( - matrix, axis, g_mask, zero_to_na, + matrix, + axis, + g_mask, + zero_to_na, ) fracs = counts_g / g_size for f in fracs: - records.append( - {"Group": str(g), "Completeness": f} - ) + records.append({"Group": str(g), "Completeness": f}) long_df = pd.DataFrame(records) if print_stats: print("Global:") - print(_summary_stats( - long_df["Completeness"], - ).to_string( - index=False, float_format="%.4f", - )) + print( + _summary_stats( + long_df["Completeness"], + ).to_string( + index=False, + float_format="%.4f", + ) + ) per_group = ( long_df.groupby("Group")["Completeness"] - .agg(["count", "mean", "median", - "std", "min", "max"]) + .agg(["count", "mean", "median", "std", "min", "max"]) .reindex( [str(g) for g in unique_groups], ) ) print(f"\nPer {group_by_partition}:") - print(per_group.to_string( - float_format="%.4f", - )) + print( + per_group.to_string( + float_format="%.4f", + ) + ) print() if ax is None: @@ -307,8 +311,7 @@ def _plot_completeness_partition( ax=_ax, ) _ax.set_title( - f"Completeness per {axis_labels[0]} " - f"by '{group_by_partition}'", + f"Completeness per {axis_labels[0]} " f"by '{group_by_partition}'", ) _ax.set_xlabel(group_by_partition) _ax.set_ylabel( @@ -352,9 +355,12 @@ def _plot_completeness_ungrouped( if print_stats: print("Global:") - print(_summary_stats(fractions).to_string( - index=False, float_format="%.4f", - )) + print( + _summary_stats(fractions).to_string( + index=False, + float_format="%.4f", + ) + ) print() if ax is None: @@ -379,7 +385,8 @@ def _plot_completeness_ungrouped( ) _ax.legend() plt.setp( - _ax.get_xticklabels(), rotation=xlabel_rotation, + _ax.get_xticklabels(), + rotation=xlabel_rotation, ) return fig, _ax @@ -409,15 +416,12 @@ def _plot_completeness_resolution( ) group_series = grouping_frame[group_by_resolution] - unique_groups = list( - group_series.dropna().unique() - ) + unique_groups = list(group_series.dropna().unique()) n_groups = len(unique_groups) if n_groups == 0: raise ValueError( - "No groups found for the given " - "`group_by_resolution` column.", + "No groups found for the given " "`group_by_resolution` column.", ) # Default threshold: min_count=1 @@ -431,13 +435,14 @@ def _plot_completeness_resolution( for g in unique_groups: g_mask = (group_series == g).values counts_g, group_size = _group_completeness_counts( - matrix, axis, g_mask, zero_to_na, + matrix, + axis, + g_mask, + zero_to_na, ) if use_fraction: - detected = ( - counts_g / group_size >= min_fraction - ) + detected = counts_g / group_size >= min_fraction else: detected = counts_g >= min_count @@ -447,11 +452,14 @@ def _plot_completeness_resolution( if print_stats: print("Global:") - print(_summary_stats( - detection_fractions, - ).to_string( - index=False, float_format="%.4f", - )) + print( + _summary_stats( + detection_fractions, + ).to_string( + index=False, + float_format="%.4f", + ) + ) print() if ax is None: @@ -460,19 +468,18 @@ def _plot_completeness_resolution( _ax = ax fig = _ax.get_figure() sns.histplot( - detection_fractions, bins=bin_edges, ax=_ax, + detection_fractions, + bins=bin_edges, + ax=_ax, ) if use_fraction: - threshold_label = ( - f"min_fraction={min_fraction}" - ) + threshold_label = f"min_fraction={min_fraction}" else: threshold_label = f"min_count={min_count}" _ax.set_title( - f"'{group_by_resolution}' completeness " - f"per {axis_labels[0]}", + f"'{group_by_resolution}' completeness " f"per {axis_labels[0]}", ) _ax.set_xlabel( f"Fraction of '{group_by_resolution}' groups " @@ -488,7 +495,8 @@ def _plot_completeness_resolution( ) _ax.legend() plt.setp( - _ax.get_xticklabels(), rotation=xlabel_rotation, + _ax.get_xticklabels(), + rotation=xlabel_rotation, ) return fig, _ax @@ -583,9 +591,15 @@ def completeness( The Matplotlib Axes object used for plotting. """ validated = _validate_completeness_args( - adata, axis, layer, order, - group_by_resolution, group_by_partition, - min_count, min_fraction, fraction_thresh, + adata, + axis, + layer, + order, + group_by_resolution, + group_by_partition, + min_count, + min_fraction, + fraction_thresh, bin_width, ) matrix = validated[0] @@ -597,7 +611,9 @@ def completeness( min_fraction = validated[6] bin_edges = np.arange( - 0.0, 1.0 + bin_width * 2, bin_width, + 0.0, + 1.0 + bin_width * 2, + bin_width, ) if group_by_partition is not None: @@ -922,30 +938,37 @@ def _append_unique(seq, value) -> None: def _n_var_summary_stats(series): """Return a one-row DataFrame of count summary stats.""" - return pd.DataFrame({ - "mean_count": [series.mean()], - "std_count": [series.std()], - "median_count": [series.median()], - "min_count": [series.min()], - "max_count": [series.max()], - }) + return pd.DataFrame( + { + "mean_count": [series.mean()], + "std_count": [series.std()], + "median_count": [series.median()], + "min_count": [series.min()], + "max_count": [series.max()], + } + ) def _add_pct_cols(df, total): """Add percentage columns to *df* in place.""" for col in [ - "mean", "std", "median", "min", "max", + "mean", + "std", + "median", + "min", + "max", ]: - df[f"{col}_pct"] = ( - df[f"{col}_count"] / total * 100 - ) + df[f"{col}_pct"] = df[f"{col}_count"] / total * 100 def _print_stats_df(df): """Print a DataFrame with one-decimal formatting.""" - print(df.to_string( - index=False, float_format="%.1f", - )) + print( + df.to_string( + index=False, + float_format="%.1f", + ) + ) _AGG_STATS = { @@ -976,16 +999,12 @@ def _validate_n_var_per_sample_args( # noqa: C901 "'peptide', 'protein', or None." ) if level == "peptide" and data_level == "protein": - raise ValueError( - "Cannot count peptides from " - "protein-level data." - ) + raise ValueError("Cannot count peptides from " "protein-level data.") # -- Mutual exclusivity if group_by is not None and order_by is not None: raise ValueError( - "`group_by` and `order_by` cannot be " - "used together." + "`group_by` and `order_by` cannot be " "used together." ) # -- Validate layer @@ -993,55 +1012,37 @@ def _validate_n_var_per_sample_args( # noqa: C901 matrix = adata.X else: if layer not in adata.layers: - raise KeyError( - f"Layer '{layer}' not found in " - "adata.layers." - ) + raise KeyError(f"Layer '{layer}' not found in " "adata.layers.") matrix = adata.layers[layer] if matrix is None: raise ValueError( - "Selected layer is empty; cannot " - "compute variable counts." + "Selected layer is empty; cannot " "compute variable counts." ) # -- Validate group_by column if group_by is not None: if group_by not in adata.obs.columns: - raise KeyError( - f"Column '{group_by}' not found " - "in adata.obs." - ) + raise KeyError(f"Column '{group_by}' not found " "in adata.obs.") # -- Validate order_by column if order_by is not None: if order_by not in adata.obs.columns: - raise KeyError( - f"Column '{order_by}' not found " - "in adata.obs." - ) + raise KeyError(f"Column '{order_by}' not found " "in adata.obs.") # -- Validate order elements if order is not None: if group_by is not None: - valid = set( - adata.obs[group_by].dropna().unique() - ) + valid = set(adata.obs[group_by].dropna().unique()) source = f"adata.obs['{group_by}']" elif order_by is not None: - valid = set( - adata.obs[order_by].dropna().unique() - ) + valid = set(adata.obs[order_by].dropna().unique()) source = f"adata.obs['{order_by}']" else: valid = set(adata.obs_names) source = "adata.obs_names" - invalid = [ - o for o in order if o not in valid - ] + invalid = [o for o in order if o not in valid] if invalid: - invalid_str = ", ".join( - map(str, invalid) - ) + invalid_str = ", ".join(map(str, invalid)) raise ValueError( f"Unknown value(s) in `order`: " f"{invalid_str}. Valid values " @@ -1064,7 +1065,11 @@ def _valid_mask(matrix, zero_to_na): def _n_var_count_per_sample( - matrix, zero_to_na, level, data_level, adata, + matrix, + zero_to_na, + level, + data_level, + adata, ): """Count non-missing vars per sample. @@ -1086,7 +1091,8 @@ def _n_var_count_per_sample( n_proteins = protein_codes.max() + 1 # OR-reduce peptide columns into protein columns prot_detected = np.zeros( - (valid.shape[0], n_proteins), dtype=bool, + (valid.shape[0], n_proteins), + dtype=bool, ) np.maximum.at( prot_detected, @@ -1103,8 +1109,13 @@ def _n_var_count_per_sample( def _n_var_derive_totals( - counts_array, level, data_level, - percentage, ylabel, title, adata, + counts_array, + level, + data_level, + percentage, + ylabel, + title, + adata, ): """Derive totals, percentage, ylabel, and title.""" if level == "protein" and data_level == "peptide": @@ -1115,12 +1126,9 @@ def _n_var_derive_totals( if percentage: if total_vars == 0: raise ValueError( - "Cannot compute percentage: " - "no variables found." + "Cannot compute percentage: " "no variables found." ) - counts_array = ( - counts_array / total_vars - ) * 100 + counts_array = (counts_array / total_vars) * 100 # -- Resolve y-axis label if ylabel is None: @@ -1128,15 +1136,9 @@ def _n_var_derive_totals( # -- Resolve title if title is None: - if level == "protein" or ( - level is None - and data_level == "protein" - ): + if level == "protein" or (level is None and data_level == "protein"): entity = "proteins" - elif level == "peptide" or ( - level is None - and data_level == "peptide" - ): + elif level == "peptide" or (level is None and data_level == "peptide"): entity = "peptides" else: entity = "variables" @@ -1146,7 +1148,10 @@ def _n_var_derive_totals( def _n_var_print_group_stats( - counts, stats_df, group_by, total_vars, + counts, + stats_df, + group_by, + total_vars, ): """Print global and per-group statistics.""" global_df = _n_var_summary_stats(counts["count"]) @@ -1160,27 +1165,38 @@ def _n_var_print_group_stats( def _n_var_resolve_bar_colors( - color_scheme, group_order, stats_df, group_by, + color_scheme, + group_order, + stats_df, + group_by, ): """Resolve bar colors from a color scheme.""" if color_scheme is None: return None colors = _resolve_color_scheme( - color_scheme, group_order, + color_scheme, + group_order, ) if colors is None: return None - return [ - colors[group_order.index(grp)] - for grp in stats_df[group_by] - ] + return [colors[group_order.index(grp)] for grp in stats_df[group_by]] def _n_var_group_by_path( - counts, adata, group_by, order, - color_scheme, total_vars, ylabel, title, - print_stats, figsize, xlabel_rotation, - save, show, ax=None, + counts, + adata, + group_by, + order, + color_scheme, + total_vars, + ylabel, + title, + print_stats, + figsize, + xlabel_rotation, + save, + show, + ax=None, ): """Plot mean +/- std bar chart grouped by an obs column.""" group_df = adata.obs[[group_by]].copy() @@ -1188,23 +1204,23 @@ def _n_var_group_by_path( "obs", ).reset_index() counts = pd.merge( - counts, group_df, on="obs", how="left", + counts, + group_df, + on="obs", + how="left", ) counts = counts.dropna(subset=[group_by]) if counts.empty: raise ValueError( - "No observations remain after " - "aligning `group_by` labels.", + "No observations remain after " "aligning `group_by` labels.", ) group_values = counts[group_by] if isinstance( - group_values.dtype, pd.CategoricalDtype, + group_values.dtype, + pd.CategoricalDtype, ): - group_values = ( - group_values.cat - .remove_unused_categories() - ) + group_values = group_values.cat.remove_unused_categories() counts[group_by] = group_values available_groups: list[Any] = [] @@ -1212,7 +1228,9 @@ def _n_var_group_by_path( _append_unique(available_groups, value) group_order = _n_var_resolve_group_order( - order, available_groups, group_values, + order, + available_groups, + group_values, ) # Append any groups not yet in order @@ -1221,29 +1239,30 @@ def _n_var_group_by_path( # -- Compute per-group statistics stats_df = ( - counts.groupby(group_by, observed=True)[ - "count" - ] + counts.groupby(group_by, observed=True)["count"] .agg(**_AGG_STATS) .reindex(group_order) ) stats_df = stats_df.dropna( subset=["mean_count"], ) - stats_df["std_count"] = ( - stats_df["std_count"].fillna(0.0) - ) + stats_df["std_count"] = stats_df["std_count"].fillna(0.0) stats_df = stats_df.reset_index() if print_stats: _n_var_print_group_stats( - counts, stats_df, group_by, total_vars, + counts, + stats_df, + group_by, + total_vars, ) # -- Plot grouped bar chart bar_colors = _n_var_resolve_bar_colors( - color_scheme, group_order, - stats_df, group_by, + color_scheme, + group_order, + stats_df, + group_by, ) if ax is not None: @@ -1272,7 +1291,8 @@ def _n_var_group_by_path( if save is not None: fig.savefig( - save, dpi=300, + save, + dpi=300, bbox_inches="tight", ) if show: @@ -1281,7 +1301,9 @@ def _n_var_group_by_path( def _n_var_resolve_group_order( - order, available_groups, group_values, + order, + available_groups, + group_values, ): """Resolve group ordering from order arg or categories.""" if order: @@ -1289,7 +1311,8 @@ def _n_var_resolve_group_order( group_order: list[Any] = [] for grp in order: if not _contains_value( - group_order, grp, + group_order, + grp, ): group_order.append(grp) return group_order @@ -1305,30 +1328,32 @@ def _n_var_resolve_group_order( def _n_var_resolve_obs_ordering( - counts, obs_df, group_key, order, - available_groups, ascending, + counts, + obs_df, + group_key, + order, + available_groups, + ascending, ): """Resolve observation ordering for the per-obs bar path.""" has_grouping = group_key != "_group" if has_grouping: group_order = _n_var_resolve_group_order( - order, available_groups, obs_df[group_key], + order, + available_groups, + obs_df[group_key], ) for grp in available_groups: _append_unique(group_order, grp) cat_index_map: dict[str, list[str]] = {} for grp in group_order: - obs_list = obs_df.loc[ - obs_df[group_key] == grp, "obs" - ].tolist() + obs_list = obs_df.loc[obs_df[group_key] == grp, "obs"].tolist() if obs_list: cat_index_map[str(grp)] = obs_list x_ordered = [ - obs - for obs_list in cat_index_map.values() - for obs in obs_list + obs for obs_list in cat_index_map.values() for obs in obs_list ] else: if order: @@ -1336,11 +1361,13 @@ def _n_var_resolve_obs_ordering( x_ordered: list[Any] = [] for obs_name in order: _append_unique( - x_ordered, obs_name, + x_ordered, + obs_name, ) for obs_name in counts["obs"]: _append_unique( - x_ordered, obs_name, + x_ordered, + obs_name, ) else: if ascending is not None: @@ -1349,24 +1376,30 @@ def _n_var_resolve_obs_ordering( ascending=ascending, kind="mergesort", ) - x_ordered = sorted_counts[ - "obs" - ].tolist() + x_ordered = sorted_counts["obs"].tolist() else: - x_ordered = counts[ - "obs" - ].tolist() + x_ordered = counts["obs"].tolist() cat_index_map = {"all": x_ordered} return x_ordered, cat_index_map def _n_var_plot_per_obs( - counts, x_ordered, cat_index_map, - group_key, order_by, total_vars, - color_scheme, ylabel, title, - print_stats, figsize, xlabel_rotation, - order_by_label_rotation, save, show, + counts, + x_ordered, + cat_index_map, + group_key, + order_by, + total_vars, + color_scheme, + ylabel, + title, + print_stats, + figsize, + xlabel_rotation, + order_by_label_rotation, + save, + show, ax=None, ): """Plot per-observation bars with group labels.""" @@ -1383,7 +1416,8 @@ def _n_var_plot_per_obs( _print_stats_df(global_df) print_df = ( counts.groupby( - order_by, observed=True, + order_by, + observed=True, )["count"] .agg(**_AGG_STATS) .reset_index() @@ -1399,24 +1433,20 @@ def _n_var_plot_per_obs( _print_stats_df(print_df) # -- Resolve colors - counts[group_key] = ( - counts[group_key].astype(str) - ) + counts[group_key] = counts[group_key].astype(str) unique_groups = list(cat_index_map.keys()) colors = _resolve_color_scheme( - color_scheme, unique_groups, + color_scheme, + unique_groups, ) plot_kwargs = {} if colors is not None: color_map = { - str(grp): colors[i] - for i, grp in enumerate(unique_groups) + str(grp): colors[i] for i, grp in enumerate(unique_groups) } - plot_kwargs["color"] = ( - counts[group_key].map(color_map).to_list() - ) + plot_kwargs["color"] = counts[group_key].map(color_map).to_list() # -- Plot per-observation bars if ax is not None: @@ -1442,10 +1472,8 @@ def _n_var_plot_per_obs( _ax.set_ylabel(ylabel) # -- Add group labels above bars - obs_idx_map = { - obs: i for i, obs in enumerate(x_ordered) - } - ymax = counts['count'].max() + obs_idx_map = {obs: i for i, obs in enumerate(x_ordered)} + ymax = counts["count"].max() for cat, obs_list in cat_index_map.items(): if not obs_list: continue @@ -1457,10 +1485,10 @@ def _n_var_plot_per_obs( x=mid_idx, y=ymax * 1.05, s=cat, - ha='center', - va='bottom', + ha="center", + va="bottom", fontsize=8, - fontweight='bold', + fontweight="bold", rotation=order_by_label_rotation, ) @@ -1469,7 +1497,9 @@ def _n_var_plot_per_obs( if save is not None: fig.savefig( - save, dpi=300, bbox_inches='tight', + save, + dpi=300, + bbox_inches="tight", ) if show: plt.show() @@ -1489,7 +1519,7 @@ def n_var_per_sample( group_by: str | None = None, print_stats: bool = False, figsize: tuple[float, float] = (6.0, 4.0), - color_scheme: str | dict | Sequence | Colormap | callable | None = None, + color_scheme: str | dict | Sequence | Colormap | Callable | None = None, title: str | None = None, ylabel: str | None = None, xlabel_rotation: float = 90, @@ -1591,24 +1621,33 @@ def n_var_per_sample( ... order=["LBaso", "Ortho"], ... ) """ - data_level, level, matrix = ( - _validate_n_var_per_sample_args( - adata, level, group_by, order_by, - order, layer, - ) + data_level, level, matrix = _validate_n_var_per_sample_args( + adata, + level, + group_by, + order_by, + order, + layer, ) # -- Count non-missing vars per sample counts_array = _n_var_count_per_sample( - matrix, zero_to_na, level, data_level, adata, + matrix, + zero_to_na, + level, + data_level, + adata, ) # -- Derive totals, percentage, ylabel, and title - total_vars, counts_array, ylabel, title = ( - _n_var_derive_totals( - counts_array, level, data_level, - percentage, ylabel, title, adata, - ) + total_vars, counts_array, ylabel, title = _n_var_derive_totals( + counts_array, + level, + data_level, + percentage, + ylabel, + title, + adata, ) # -- Build counts DataFrame @@ -1625,15 +1664,13 @@ def n_var_per_sample( if ascending is not None: if group_by is not None: warnings.warn( - "`ascending` is ignored when " - "`group_by` is set.", + "`ascending` is ignored when " "`group_by` is set.", UserWarning, stacklevel=2, ) elif order is not None: warnings.warn( - "`ascending` is ignored when " - "`order` is set explicitly.", + "`ascending` is ignored when " "`order` is set explicitly.", UserWarning, stacklevel=2, ) @@ -1641,17 +1678,25 @@ def n_var_per_sample( # -- group_by path: mean +/- std bar plot per group if group_by is not None: return _n_var_group_by_path( - counts, adata, group_by, order, - color_scheme, total_vars, ylabel, - title, print_stats, figsize, - xlabel_rotation, save, show, ax, + counts, + adata, + group_by, + order, + color_scheme, + total_vars, + ylabel, + title, + print_stats, + figsize, + xlabel_rotation, + save, + show, + ax, ) # -- Per-observation bar plot (with optional order_by) has_grouping = order_by is not None - group_key = ( - order_by if has_grouping else "_group" - ) + group_key = order_by if has_grouping else "_group" # Attach grouping column to counts if has_grouping: @@ -1661,7 +1706,10 @@ def n_var_per_sample( "obs", ).reset_index() counts = pd.merge( - counts, obs, on="obs", how="left", + counts, + obs, + on="obs", + how="left", ) else: counts[group_key] = counts["obs"] @@ -1678,20 +1726,20 @@ def n_var_per_sample( obs_df[group_key].dtype, pd.CategoricalDtype, ): - obs_df[group_key] = ( - obs_df[group_key].astype("category") - ) + obs_df[group_key] = obs_df[group_key].astype("category") available_groups: list[Any] = [] for value in obs_df[group_key]: _append_unique(available_groups, value) # -- Resolve observation ordering - x_ordered, cat_index_map = ( - _n_var_resolve_obs_ordering( - counts, obs_df, group_key, order, - available_groups, ascending, - ) + x_ordered, cat_index_map = _n_var_resolve_obs_ordering( + counts, + obs_df, + group_key, + order, + available_groups, + ascending, ) counts["obs"] = pd.Categorical( @@ -1703,11 +1751,22 @@ def n_var_per_sample( # -- Plot per-observation bars return _n_var_plot_per_obs( - counts, x_ordered, cat_index_map, - group_key, order_by, total_vars, - color_scheme, ylabel, title, - print_stats, figsize, xlabel_rotation, - order_by_label_rotation, save, show, ax, + counts, + x_ordered, + cat_index_map, + group_key, + order_by, + total_vars, + color_scheme, + ylabel, + title, + print_stats, + figsize, + xlabel_rotation, + order_by_label_rotation, + save, + show, + ax, ) @@ -1887,14 +1946,18 @@ def _ordered_categories(series: pd.Series) -> list[Any]: if selected_categories is not None: first_level_order = [ - category for category in selected_categories if category in first_level_order + category + for category in selected_categories + if category in first_level_order ] if order is not None: if isinstance(order, str): specified = [order] else: specified = list(order) - unknown_specified = [cat for cat in specified if cat not in first_level_order] + unknown_specified = [ + cat for cat in specified if cat not in first_level_order + ] if unknown_specified: raise ValueError( "Order values not present in the first category column: " @@ -1948,13 +2011,13 @@ def _ordered_categories(series: pd.Series) -> list[Any]: _ax.yaxis.set_major_locator(MaxNLocator(integer=True)) _ax.set_xlabel(first_cat_col) - _ax.set_ylabel('#') + _ax.set_ylabel("#") ha = ( - 'right' if xlabel_rotation > 0 - else 'left' if xlabel_rotation < 0 - else 'center' - ) + "right" + if xlabel_rotation > 0 + else "left" if xlabel_rotation < 0 else "center" + ) plt.setp(_ax.get_xticklabels(), rotation=xlabel_rotation, ha=ha) fig.tight_layout() @@ -2063,9 +2126,15 @@ def n_cat1_per_cat2_hist( ) lower, upper = bin_range if lower >= upper: - raise ValueError("bin_range lower bound must be less than upper bound.") + raise ValueError( + "bin_range lower bound must be less than upper bound." + ) - temp_col = "__proteopy_axis_index__" if first_category == "index" else first_category + temp_col = ( + "__proteopy_axis_index__" + if first_category == "index" + else first_category + ) data = frame[[second_category]].copy() if first_category == "index": index_values = adata.obs_names if axis == 0 else adata.var_names @@ -2125,9 +2194,8 @@ def n_cat1_per_cat2_hist( return _ax -docstr_header = ( - "Plot the distribution of the number of first-category entries per second category." - ) + +docstr_header = "Plot the distribution of the number of first-category entries per second category." n_peptides_per_protein = partial_with_docsig( n_cat1_per_cat2_hist, first_category="peptide_id", @@ -2295,13 +2363,17 @@ def cv_by_group( if temp_key_name is not None: del adata.varm[temp_key_name] - df_melted = cv_df.melt(var_name="Group", value_name="CV", ignore_index=False) + df_melted = cv_df.melt( + var_name="Group", value_name="CV", ignore_index=False + ) df_melted = df_melted.reset_index(drop=True) if order is None: order = unique_groups else: - missing = [grp for grp in order if grp not in df_melted["Group"].unique()] + missing = [ + grp for grp in order if grp not in df_melted["Group"].unique() + ] if missing: raise ValueError( "Requested ordering includes groups with no CV data: " @@ -2316,14 +2388,16 @@ def cv_by_group( if print_stats: cv_values = df_melted["CV"].dropna() - global_summary = pd.DataFrame({ - "Count": [cv_values.count()], - "Min": [round(cv_values.min(), 4)], - "Max": [round(cv_values.max(), 4)], - "Median": [round(cv_values.median(), 4)], - "Mean": [round(cv_values.mean(), 4)], - "Std": [round(cv_values.std(), 4)], - }) + global_summary = pd.DataFrame( + { + "Count": [cv_values.count()], + "Min": [round(cv_values.min(), 4)], + "Max": [round(cv_values.max(), 4)], + "Median": [round(cv_values.median(), 4)], + "Mean": [round(cv_values.mean(), 4)], + "Std": [round(cv_values.std(), 4)], + } + ) print("Global CV Summary:") print(global_summary.to_string(index=False)) print() @@ -2353,14 +2427,13 @@ def cv_by_group( if total_count > 0 else 0.0 ) - global_thresh = pd.DataFrame({ - "Count below": [int(below_count)], - "Percentage below": [pct], - }) - print( - f"Global Threshold Summary " - f"(hline={hline}):" + global_thresh = pd.DataFrame( + { + "Count below": [int(below_count)], + "Percentage below": [pct], + } ) + print(f"Global Threshold Summary " f"(hline={hline}):") print(global_thresh.to_string(index=False)) print() @@ -2368,14 +2441,14 @@ def _thresh_stats(group_cv): n_below = (group_cv < hline).sum() n_total = group_cv.count() pct_below = ( - round(n_below / n_total * 100, 4) - if n_total > 0 - else 0.0 + round(n_below / n_total * 100, 4) if n_total > 0 else 0.0 + ) + return pd.Series( + { + "Count below": int(n_below), + "Percentage below": pct_below, + } ) - return pd.Series({ - "Count below": int(n_below), - "Percentage below": pct_below, - }) per_group_thresh = ( df_melted.groupby("Group")["CV"] @@ -2383,10 +2456,7 @@ def _thresh_stats(group_cv): .unstack() .reindex(order) ) - print( - f"Per-Group Threshold Summary " - f"(hline={hline}):" - ) + print(f"Per-Group Threshold Summary " f"(hline={hline}):") print(per_group_thresh.to_string()) print() @@ -2546,7 +2616,9 @@ def sample_correlation_matrix( matrix = adata.layers[layer] if matrix is None: - raise ValueError("Selected matrix is empty; cannot compute correlations.") + raise ValueError( + "Selected matrix is empty; cannot compute correlations." + ) if matrix.shape != expected_shape: raise ValueError( @@ -2555,7 +2627,9 @@ def sample_correlation_matrix( ) if isinstance(matrix, pd.DataFrame): - vals = matrix.reindex(index=adata.obs_names, columns=adata.var_names).copy() + vals = matrix.reindex( + index=adata.obs_names, columns=adata.var_names + ).copy() else: if sparse.issparse(matrix): # correlation requires dense values; convert temporarily @@ -2609,7 +2683,9 @@ def sample_correlation_matrix( sns.color_palette(n_colors=len(cats)) if len(cats) > 0 else [] ) - palette = {str(cat): color for cat, color in zip(cats, resolved_colors)} + palette = { + str(cat): color for cat, color in zip(cats, resolved_colors) + } groups_str = groups.astype("string") row_color_series = groups_str.map(palette) @@ -2623,7 +2699,9 @@ def sample_correlation_matrix( ) legend_handles = [ - Patch(facecolor=palette[str(cat)], edgecolor="none", label=str(cat)) + Patch( + facecolor=palette[str(cat)], edgecolor="none", label=str(cat) + ) for cat in cats ] @@ -2636,7 +2714,9 @@ def sample_correlation_matrix( ) row_colors = ( - row_color_series.to_numpy() if row_color_series is not None else None + row_color_series.to_numpy() + if row_color_series is not None + else None ) # ---- hierarchical clustering on (1 - r) @@ -2648,30 +2728,29 @@ def sample_correlation_matrix( # ---- optional statistics printout if print_stats and n > 1: # 1) Overall off-diagonal summary - summary = pd.DataFrame({ - "min": [np.nanmin(offdiag)], - "max": [np.nanmax(offdiag)], - "mean": [np.nanmean(offdiag)], - "median": [np.nanmedian(offdiag)], - "std": [np.nanstd(offdiag)], - }) - print( - f"Sample correlation summary " - f"(off-diagonal, {method}):" + summary = pd.DataFrame( + { + "min": [np.nanmin(offdiag)], + "max": [np.nanmax(offdiag)], + "mean": [np.nanmean(offdiag)], + "median": [np.nanmedian(offdiag)], + "std": [np.nanstd(offdiag)], + } ) + print(f"Sample correlation summary " f"(off-diagonal, {method}):") print(summary.to_string(index=False)) print() # 2) Per-sample mean correlation (dendrogram order) mask = ~np.eye(n, dtype=bool) - per_sample_mean = np.nanmean( - np.where(mask, A, np.nan), axis=1 - ) + per_sample_mean = np.nanmean(np.where(mask, A, np.nan), axis=1) heatmap_order = leaves_list(Z) - per_sample_df = pd.DataFrame({ - "sample_id": corr_df.index[heatmap_order], - "mean_corr": per_sample_mean[heatmap_order], - }) + per_sample_df = pd.DataFrame( + { + "sample_id": corr_df.index[heatmap_order], + "mean_corr": per_sample_mean[heatmap_order], + } + ) print("Per-sample mean correlation:") print(per_sample_df.to_string(index=False)) print() @@ -2680,44 +2759,35 @@ def sample_correlation_matrix( if margin_color is not None: if margin_color not in adata.obs.columns: raise KeyError( - f"Column '{margin_color}' not found " - f"in adata.obs." + f"Column '{margin_color}' not found " f"in adata.obs." ) - groups_ps = adata.obs.loc[ - corr_df.index, margin_color - ] + groups_ps = adata.obs.loc[corr_df.index, margin_color] unique_groups = groups_ps.dropna().unique() group_rows = [] for grp in sorted(unique_groups): - grp_idx = groups_ps[ - groups_ps == grp - ].index + grp_idx = groups_ps[groups_ps == grp].index other_idx = groups_ps[ (groups_ps != grp) & groups_ps.notna() ].index within = corr_df.loc[grp_idx, grp_idx] - within_vals = within.values[ - ~np.eye(len(grp_idx), dtype=bool) - ] + within_vals = within.values[~np.eye(len(grp_idx), dtype=bool)] mean_within = ( - np.nanmean(within_vals) - if len(within_vals) > 0 - else np.nan + np.nanmean(within_vals) if len(within_vals) > 0 else np.nan ) if len(other_idx) > 0: between_vals = corr_df.loc[ grp_idx, other_idx ].values.ravel() - mean_between = np.nanmean( - between_vals - ) + mean_between = np.nanmean(between_vals) else: mean_between = np.nan - group_rows.append({ - "group": grp, - "mean_within": mean_within, - "mean_between": mean_between, - }) + group_rows.append( + { + "group": grp, + "mean_within": mean_within, + "mean_between": mean_between, + } + ) group_df = pd.DataFrame(group_rows) print("Per-group mean correlation:") print(group_df.to_string(index=False)) @@ -2731,7 +2801,7 @@ def sample_correlation_matrix( row_colors=row_colors, col_colors=row_colors if row_colors is not None else None, cmap=cmap, - center=center_val, + center=center_val, figsize=figsize, xticklabels=xticklabels, yticklabels=yticklabels, @@ -2744,8 +2814,8 @@ def sample_correlation_matrix( handles=legend_handles, title=margin_color, bbox_to_anchor=(1.05, 1), - loc='upper left', - borderaxespad=0., + loc="upper left", + borderaxespad=0.0, frameon=False, ) @@ -2783,7 +2853,10 @@ def hclustv_profiles_heatmap( row_cluster: bool = True, col_cluster: bool = True, cbar_pos: tuple[float, float, float, float] | None = ( - 0.02, 0.8, 0.05, 0.18 + 0.02, + 0.8, + 0.05, + 0.18, ), tree_kws: dict | None = None, xticklabels: bool = True, @@ -2923,19 +2996,19 @@ def hclustv_profiles_heatmap( raise KeyError(f"Column '{order_by}' not found in adata.obs.") # order_by and col_cluster are mutually exclusive; disable clustering if col_cluster: - print(( + print( "`order_by` parameter is incompatible with `col_cluster=True`. " "`col_cluster` has been overridden." - )) + ) col_cluster = False # Validate order parameter if order is not None: if col_cluster: - print(( + print( "`order` parameter is incompatible with `col_cluster=True`. " "`col_cluster` has been overridden." - )) + ) col_cluster = False order = list(order) if order_by is None and group_by is None: @@ -3061,8 +3134,9 @@ def hclustv_profiles_heatmap( ) sorted_idx = ( pd.Series(order_col_values, index=filtered_cols) - .sort_values().index - ) + .sort_values() + .index + ) else: # Use categorical order if categorical, sorted order otherwise if isinstance(order_col_values.dtype, pd.CategoricalDtype): @@ -3072,10 +3146,14 @@ def hclustv_profiles_heatmap( categories=cat_order, ordered=True, ) - sorted_idx = pd.Series( - order_col_values, - index=z_df_filled.columns, - ).sort_values().index + sorted_idx = ( + pd.Series( + order_col_values, + index=z_df_filled.columns, + ) + .sort_values() + .index + ) else: sorted_idx = order_col_values.sort_values().index z_df_filled = z_df_filled[sorted_idx] @@ -3105,7 +3183,8 @@ def hclustv_profiles_heatmap( if resolved_colors is None: resolved_colors = ( sns.color_palette("husl", n_colors=len(unique_cats)) - if len(unique_cats) > 0 else [] + if len(unique_cats) > 0 + else [] ) color_map = dict(zip(unique_cats, resolved_colors)) col_colors = pd.Series( From e856dd58c0daf36cb149153fbb81af04242cb980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:49:14 +0200 Subject: [PATCH 6/7] test/filtering refactored --- tests/pp/test_filtering.py | 221 ++++++++++++++++++++++--------------- 1 file changed, 134 insertions(+), 87 deletions(-) diff --git a/tests/pp/test_filtering.py b/tests/pp/test_filtering.py index 2446f6f..5bd4535 100644 --- a/tests/pp/test_filtering.py +++ b/tests/pp/test_filtering.py @@ -17,12 +17,12 @@ def _make_adata_filter_obs_base() -> AnnData: n = np.nan X = np.array( [ - [1, 1, 2, 2, 3], # obs0: complete - [n, 1, 2, 2, 3], # obs1: 4/5 complete - [n, n, 2, 2, 3], # obs2: 3/5 complete - [n, n, n, 2, 3], # obs3: 2/5 complete - [0, 1, 2, 2, 3], # obs4: complete and a zero - [0, n, 2, 2, 3], # obs5: 4/5 complete and a zero + [1, 1, 2, 2, 3], # obs0: complete + [n, 1, 2, 2, 3], # obs1: 4/5 complete + [n, n, 2, 2, 3], # obs2: 3/5 complete + [n, n, n, 2, 3], # obs3: 2/5 complete + [0, 1, 2, 2, 3], # obs4: complete and a zero + [0, n, 2, 2, 3], # obs5: 4/5 complete and a zero ], dtype=float, ) @@ -38,9 +38,9 @@ def _make_adata_filter_obs_groupby_singletons() -> AnnData: n = np.nan X = np.array( [ - [n, n], # obs0 - [1, n], # obs1 - [1, 1], # obs2 + [n, n], # obs0 + [1, n], # obs1 + [1, 1], # obs2 ], dtype=float, ) @@ -62,11 +62,11 @@ def _make_adata_filter_obs_groupby() -> AnnData: n = np.nan X = np.array( [ - [1, 1, 2, 2, 3], # obs0: both groups complete - [1, n, 2, 2, 3], # obs1: group 0 -> 1/2 complete - [1, 1, 2, 2, n], # obs2: group 1 -> 2/3 incomplete - [1, n, 2, 2, n], # obs3: g0 1/2, g1 2/3 - [1, n, 2, n, n], # obs4: g0 1/2, g1 1/3 + [1, 1, 2, 2, 3], # obs0: both groups complete + [1, n, 2, 2, 3], # obs1: group 0 -> 1/2 complete + [1, 1, 2, 2, n], # obs2: group 1 -> 2/3 incomplete + [1, n, 2, 2, n], # obs3: g0 1/2, g1 2/3 + [1, n, 2, n, n], # obs4: g0 1/2, g1 1/3 ], dtype=float, ) @@ -106,8 +106,15 @@ def _make_adata_filter_obs_groupby_na() -> AnnData: { "protein_id": var_names, "group": [ - "g1", "g1", "g2", "g2", "g2", - np.nan, np.nan, np.nan, np.nan, + "g1", + "g1", + "g2", + "g2", + "g2", + np.nan, + np.nan, + np.nan, + np.nan, ], }, index=var_names, @@ -213,8 +220,15 @@ def _make_adata_filter_var_groupby_na() -> AnnData: { "sample_id": obs_names, "group": [ - "g1", "g1", "g2", "g2", "g2", - np.nan, np.nan, np.nan, np.nan, + "g1", + "g1", + "g2", + "g2", + "g2", + np.nan, + np.nan, + np.nan, + np.nan, ], }, index=obs_names, @@ -224,13 +238,18 @@ def _make_adata_filter_var_groupby_na() -> AnnData: { "protein_id": var_names, "group": [ - "g1", "g1", "g2", "g2", np.nan, + "g1", + "g1", + "g2", + "g2", + np.nan, ], }, index=var_names, ) return AnnData(X=X, obs=obs, var=var) + # ── helpers: remove_zero_variance_vars ────────────────────────────── @@ -313,9 +332,7 @@ def _make_adata_rzv_all_vary() -> AnnData: Expected kept (atol=1e-8): [p0, p1, p2] (nothing removed). """ X = np.array( - [[1.0, 10.0, 100.0], - [2.0, 20.0, 200.0], - [3.0, 30.0, 300.0]], + [[1.0, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]], ) obs_names = ["s0", "s1", "s2"] var_names = ["p0", "p1", "p2"] @@ -414,8 +431,7 @@ def _make_adata_rzv_groupby() -> AnnData: ) obs_names = [f"s{i}" for i in range(5)] obs = pd.DataFrame( - {"sample_id": obs_names, - "group": ["g1", "g1", "g2", "g2", "g2"]}, + {"sample_id": obs_names, "group": ["g1", "g1", "g2", "g2", "g2"]}, index=obs_names, ) var_names = [f"p{i}" for i in range(5)] @@ -442,8 +458,7 @@ def _make_adata_rzv_groupby_singletons() -> AnnData: ) obs_names = [f"s{i}" for i in range(5)] obs = pd.DataFrame( - {"sample_id": obs_names, - "group": ["g1", "g2", "g3", "g4", "g5"]}, + {"sample_id": obs_names, "group": ["g1", "g2", "g3", "g4", "g5"]}, index=obs_names, ) var_names = ["p0", "p1", "p2"] @@ -470,8 +485,7 @@ def _make_adata_rzv_groupby_allnan_one_group() -> AnnData: ) obs_names = ["s0", "s1", "s2", "s3"] obs = pd.DataFrame( - {"sample_id": obs_names, - "group": ["A", "A", "B", "B"]}, + {"sample_id": obs_names, "group": ["A", "A", "B", "B"]}, index=obs_names, ) var_names = ["p0", "p1"] @@ -495,8 +509,7 @@ def _make_adata_rzv_groupby_single_group() -> AnnData: ) obs_names = ["s0", "s1", "s2", "s3"] obs = pd.DataFrame( - {"sample_id": obs_names, - "group": ["A", "A", "A", "A"]}, + {"sample_id": obs_names, "group": ["A", "A", "A", "A"]}, index=obs_names, ) var_names = ["p0", "p1", "p2"] @@ -548,9 +561,7 @@ def _make_adata_remove_contaminants_base() -> AnnData: obs_names = [f"obs{i}" for i in range(5)] obs = pd.DataFrame({"sample_id": obs_names}, index=obs_names) var_names = [f"protein_{i}" for i in range(5)] - var = pd.DataFrame({ - "protein_id": var_names - }, index=var_names) + var = pd.DataFrame({"protein_id": var_names}, index=var_names) return AnnData(X=X, obs=obs, var=var) @@ -918,8 +929,11 @@ def test_filter_axis_var_min_fraction_and_min_count(): cases = { (0.4, 3): [ - "protein_0", "protein_1", "protein_2", - "protein_4", "protein_5", + "protein_0", + "protein_1", + "protein_2", + "protein_4", + "protein_5", ], (1.0, 5): ["protein_0", "protein_4"], (0.0, 0): list(adata.var_names), @@ -968,7 +982,9 @@ def test_filter_axis_var_zero_to_na(): ) assert returned is None assert list(adata_inplace.var_names) == [ - "protein_0", "protein_1", "protein_4", + "protein_0", + "protein_1", + "protein_4", ] @@ -1410,7 +1426,9 @@ def test_atol_boundary_equal_variance_removed(self): ], ) obs = pd.DataFrame({"sample_id": ["s0", "s1"]}, index=["s0", "s1"]) - var = pd.DataFrame({"protein_id": ["p0", "p1", "p2"]}, index=["p0", "p1", "p2"]) + var = pd.DataFrame( + {"protein_id": ["p0", "p1", "p2"]}, index=["p0", "p1", "p2"] + ) adata = AnnData(X=X, obs=obs, var=var) filtered = remove_zero_variance_vars(adata, atol=1.0, inplace=False) @@ -1423,7 +1441,9 @@ def test_atol_zero_keeps_tiny_variance(self): # p1: var≈3.3e-17 (> atol) → kept # p2: var≈0.667 (> atol) → kept filtered = remove_zero_variance_vars( - adata, atol=0.0, inplace=False, + adata, + atol=0.0, + inplace=False, ) assert filtered is not None assert list(filtered.var_names) == ["p1", "p2"] @@ -1432,7 +1452,9 @@ def test_large_atol_removes_everything(self): adata = _make_adata_rzv_all_vary() # all vars have variance < 1e10 → all removed filtered = remove_zero_variance_vars( - adata, atol=1e10, inplace=False, + adata, + atol=1e10, + inplace=False, ) assert filtered is not None assert list(filtered.var_names) == [] @@ -1441,7 +1463,8 @@ def test_large_atol_removes_everything(self): def test_negative_atol_raises(self): adata = _make_adata_rzv_base() with pytest.raises( - ValueError, match=r"`atol` must be non-negative.", + ValueError, + match=r"`atol` must be non-negative.", ): remove_zero_variance_vars(adata, atol=-2) @@ -1451,7 +1474,9 @@ def test_negative_atol_raises(self): def test_groupby_removes_zero_in_any_group(self, inplace): adata = _make_adata_rzv_groupby() result = remove_zero_variance_vars( - adata, group_by="group", inplace=inplace, + adata, + group_by="group", + inplace=inplace, ) target = adata if inplace else result if inplace: @@ -1469,7 +1494,9 @@ def test_groupby_singleton_groups_removes_all(self, inplace): match=r"at least one group", ): result = remove_zero_variance_vars( - adata, group_by="group", inplace=inplace, + adata, + group_by="group", + inplace=inplace, ) target = adata if inplace else result if inplace: @@ -1482,21 +1509,28 @@ def test_groupby_singleton_groups_removes_all(self, inplace): def test_groupby_all_nan_in_one_group_warns(self): adata = _make_adata_rzv_groupby_allnan_one_group() with pytest.warns( - UserWarning, match=r"at least one group", + UserWarning, + match=r"at least one group", ): filtered = remove_zero_variance_vars( - adata, group_by="group", inplace=False, + adata, + group_by="group", + inplace=False, ) assert list(filtered.var_names) == ["p1"] def test_groupby_single_group_matches_global(self): adata = _make_adata_rzv_groupby_single_group() filtered_grouped = remove_zero_variance_vars( - adata, group_by="group", inplace=False, + adata, + group_by="group", + inplace=False, ) adata2 = _make_adata_rzv_groupby_single_group() filtered_global = remove_zero_variance_vars( - adata2, group_by=None, inplace=False, + adata2, + group_by=None, + inplace=False, ) assert ( list(filtered_grouped.var_names) @@ -1508,7 +1542,9 @@ def test_groupby_categorical_column(self): adata = _make_adata_rzv_groupby() adata.obs["group"] = pd.Categorical(adata.obs["group"]) filtered = remove_zero_variance_vars( - adata, group_by="group", inplace=False, + adata, + group_by="group", + inplace=False, ) assert list(filtered.var_names) == ["p0"] @@ -1534,7 +1570,8 @@ def test_groupby_missing_column_raises(self): @pytest.mark.parametrize("bad_adata", ["not-anndata", 42, None]) def test_invalid_adata_type(self, bad_adata): with pytest.raises( - TypeError, match=r"`adata` must be an AnnData object", + TypeError, + match=r"`adata` must be an AnnData object", ): remove_zero_variance_vars(adata=bad_adata) @@ -1542,7 +1579,8 @@ def test_invalid_adata_type(self, bad_adata): def test_invalid_group_by_type(self, bad_group_by): adata = _make_adata_rzv_base() with pytest.raises( - TypeError, match=r"`group_by` must be a string or None", + TypeError, + match=r"`group_by` must be a string or None", ): remove_zero_variance_vars(adata, group_by=bad_group_by) @@ -1550,7 +1588,8 @@ def test_invalid_group_by_type(self, bad_group_by): def test_invalid_atol_type(self, bad_atol): adata = _make_adata_rzv_base() with pytest.raises( - TypeError, match=r"`atol` must be a numeric value", + TypeError, + match=r"`atol` must be a numeric value", ): remove_zero_variance_vars(adata, atol=bad_atol) @@ -1558,7 +1597,8 @@ def test_invalid_atol_type(self, bad_atol): def test_invalid_inplace_type(self, bad_inplace): adata = _make_adata_rzv_base() with pytest.raises( - TypeError, match=r"`inplace` must be a bool", + TypeError, + match=r"`inplace` must be a bool", ): remove_zero_variance_vars(adata, inplace=bad_inplace) @@ -1566,7 +1606,8 @@ def test_invalid_inplace_type(self, bad_inplace): def test_invalid_verbose_type(self, bad_verbose): adata = _make_adata_rzv_base() with pytest.raises( - TypeError, match=r"`verbose` must be a bool", + TypeError, + match=r"`verbose` must be a bool", ): remove_zero_variance_vars(adata, verbose=bad_verbose) @@ -1575,7 +1616,9 @@ def test_invalid_verbose_type(self, bad_verbose): def test_verbose_reports_correct_counts(self, capsys): adata = _make_adata_rzv_base() remove_zero_variance_vars( - adata, inplace=True, verbose=True, + adata, + inplace=True, + verbose=True, ) captured = capsys.readouterr() assert "5 variables present" in captured.out @@ -1585,7 +1628,9 @@ def test_verbose_reports_correct_counts(self, capsys): def test_verbose_false_prints_nothing(self, capsys): adata = _make_adata_rzv_base() remove_zero_variance_vars( - adata, inplace=True, verbose=False, + adata, + inplace=True, + verbose=False, ) captured = capsys.readouterr() assert captured.out == "" @@ -1615,17 +1660,15 @@ def test_idempotency(self): second = remove_zero_variance_vars(first, inplace=False) assert list(first.var_names) == list(second.var_names) np.testing.assert_array_equal( - np.asarray(first.X), np.asarray(second.X), + np.asarray(first.X), + np.asarray(second.X), ) def test_kept_var_values_unchanged(self): adata = _make_adata_rzv_base() original_X = adata.X.copy() filtered = remove_zero_variance_vars(adata, inplace=False) - kept_idx = [ - list(adata.var_names).index(v) - for v in filtered.var_names - ] + kept_idx = [list(adata.var_names).index(v) for v in filtered.var_names] np.testing.assert_array_equal( np.asarray(filtered.X), original_X[:, kept_idx], @@ -1638,7 +1681,8 @@ def test_peptide_level_data_basic(self, inplace): adata = _make_adata_rzv_peptide_level() with pytest.warns(UserWarning, match=r"1 variable\(s\)"): result = remove_zero_variance_vars( - adata, inplace=inplace, + adata, + inplace=inplace, ) target = adata if inplace else result if inplace: @@ -1658,7 +1702,9 @@ def test_peptide_level_data_with_groupby(self): # pep3: all-NaN in both groups → removed (warning) with pytest.warns(UserWarning, match=r"at least one group"): filtered = remove_zero_variance_vars( - adata, group_by="group", inplace=False, + adata, + group_by="group", + inplace=False, ) assert filtered is not None assert list(filtered.var_names) == ["pep0", "pep2"] @@ -1669,12 +1715,7 @@ def test_peptide_level_data_with_groupby(self): class TestRemoveContaminants: @pytest.fixture def fasta(self, tmp_path): - fasta_content = ( - ">sp|protein_1\n" - "AAAA\n" - ">sp|protein_2\n" - "CCCC\n" - ) + fasta_content = ">sp|protein_1\n" "AAAA\n" ">sp|protein_2\n" "CCCC\n" fasta_path = tmp_path / "test.fasta" fasta_path.write_text(fasta_content) return fasta_path @@ -1683,9 +1724,7 @@ def fasta(self, tmp_path): def csv_file(self, tmp_path): csv_path = tmp_path / "contaminants.csv" csv_path.write_text( - "contaminant,source\n" - "protein_2,db\n" - "protein_4,db\n", + "contaminant,source\n" "protein_2,db\n" "protein_4,db\n", ) return csv_path @@ -1693,9 +1732,7 @@ def csv_file(self, tmp_path): def tsv_file(self, tmp_path): tsv_path = tmp_path / "contaminants.tsv" tsv_path.write_text( - "contaminant\tcomment\n" - "protein_0\ta\n" - "protein_3\tb\n", + "contaminant\tcomment\n" "protein_0\ta\n" "protein_3\tb\n", ) return tsv_path @@ -1714,7 +1751,9 @@ def test_fasta_filters_expected_proteins(self, fasta, inplace): target = adata if inplace else result assert list(target.var_names) == [ - "protein_0", "protein_3", "protein_4", + "protein_0", + "protein_3", + "protein_4", ] assert target.n_obs == 5 @@ -1727,8 +1766,7 @@ def test_fasta_filters_expected_proteins(self, fasta, inplace): def test_no_matching_contaminants_keeps_all_variables(self, tmp_path): fasta_path = tmp_path / "none_match.fasta" fasta_path.write_text( - ">sp|not_present_a\nAAAA\n" - ">sp|not_present_b\nCCCC\n", + ">sp|not_present_a\nAAAA\n" ">sp|not_present_b\nCCCC\n", ) adata = _make_adata_remove_contaminants_base() @@ -1749,7 +1787,9 @@ def test_csv_filters_using_first_column(self, csv_file): inplace=False, ) assert list(filtered.var_names) == [ - "protein_0", "protein_1", "protein_3", + "protein_0", + "protein_1", + "protein_3", ] def test_tsv_filters_using_first_column(self, tsv_file): @@ -1760,7 +1800,9 @@ def test_tsv_filters_using_first_column(self, tsv_file): inplace=False, ) assert list(filtered.var_names) == [ - "protein_1", "protein_2", "protein_4", + "protein_1", + "protein_2", + "protein_4", ] def test_custom_protein_key_column(self, tmp_path): @@ -1769,12 +1811,15 @@ def test_custom_protein_key_column(self, tmp_path): # in a different order to confirm filtering uses protein_key, # not var_names or var.index adata.var["uniprot_id"] = [ - "Q99714", "P12345", "P67890", "O75822", "Q9Y6K9", + "Q99714", + "P12345", + "P67890", + "O75822", + "Q9Y6K9", ] fasta_path = tmp_path / "custom_key.fasta" fasta_path.write_text( - ">sp|P12345\nAAAA\n" - ">sp|P67890\nCCCC\n", + ">sp|P12345\nAAAA\n" ">sp|P67890\nCCCC\n", ) filtered = remove_contaminants( @@ -1784,14 +1829,15 @@ def test_custom_protein_key_column(self, tmp_path): inplace=False, ) assert list(filtered.var_names) == [ - "protein_0", "protein_3", "protein_4", + "protein_0", + "protein_3", + "protein_4", ] def test_custom_header_parser_is_used(self, tmp_path): fasta_path = tmp_path / "custom_header.fasta" fasta_path.write_text( - ">contam__protein_0\nAAAA\n" - ">contam__protein_4\nCCCC\n", + ">contam__protein_0\nAAAA\n" ">contam__protein_4\nCCCC\n", ) adata = _make_adata_remove_contaminants_base() @@ -1802,14 +1848,15 @@ def test_custom_header_parser_is_used(self, tmp_path): inplace=False, ) assert list(filtered.var_names) == [ - "protein_1", "protein_2", "protein_3", + "protein_1", + "protein_2", + "protein_3", ] def test_header_parser_empty_id_warns_and_skips(self, tmp_path): fasta_path = tmp_path / "empty_id.fasta" fasta_path.write_text( - ">sp|protein_1\nAAAA\n" - ">sp|protein_2\nCCCC\n", + ">sp|protein_1\nAAAA\n" ">sp|protein_2\nCCCC\n", ) adata = _make_adata_remove_contaminants_base() From 5fe5e0791767e0bee157f76c7bc231cc2ed8fa7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Levente=20Temesv=C3=A1ri-Nagy?= <147416790+leventetn@users.noreply.github.com> Date: Tue, 9 Jun 2026 00:33:16 +0200 Subject: [PATCH 7/7] download/contaminants format_fasta corrected for Windows --- proteopy/download/contaminants.py | 1 + 1 file changed, 1 insertion(+) diff --git a/proteopy/download/contaminants.py b/proteopy/download/contaminants.py index e024262..d87fab3 100644 --- a/proteopy/download/contaminants.py +++ b/proteopy/download/contaminants.py @@ -74,6 +74,7 @@ def _format_fasta( destination_path, "w", encoding="utf-8", + newline="\n", ) as dest, ): for line in src: