diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 835f6b5..a5b15dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,30 +62,8 @@ jobs: - name: Check distributions run: python -m twine check dist/* - - name: Smoke install built wheel - shell: bash - run: | - set -euo pipefail - WHEEL="$(ls dist/*.whl | head -n 1)" - python -m venv .smoke_venv - .smoke_venv/bin/python -m pip install --upgrade pip - .smoke_venv/bin/python -m pip install "$WHEEL" - mkdir -p .smoke_outside_checkout - cd .smoke_outside_checkout - ../.smoke_venv/bin/melite --version - ../.smoke_venv/bin/python -c " - import melite - expected = ['Config', 'load_datasets', 'plot_cv_distributions', 'predict', '__version__'] - assert melite.__all__ == expected, melite.__all__ - for name in expected: - assert hasattr(melite, name), f'{name} missing' - assert 'load_dataset' not in melite.__all__, 'load_dataset must not be top-level public API' - assert 'ResultManager' not in melite.__all__, 'ResultManager must not be top-level public API' - assert not hasattr(melite, 'Pipeline'), 'Pipeline must not be public' - from melite.result_manager import ResultManager - assert ResultManager is not None, 'ResultManager internal import missing' - print(melite.__version__, 'wheel OK') - " + - name: Smoke installed wheel toy workflow + run: python scripts/smoke_install_wheel.py - name: Smoke install sdist shell: bash diff --git a/CHANGELOG.md b/CHANGELOG.md index 67388d8..99cd477 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [v0.2.1] - 2026-05-27 + +### Changed +- `[models].active` / `ACTIVE_MODELS` now controls which model families are + trained during benchmarking. +- `melite export` now uses strict dataset loading for registry-based datasets + and no longer falls back to `arr.files[0]` when an `.npz` file lacks `X`. +- Added an installed-wheel smoke test that builds the wheel, installs it + outside the repository checkout, runs a toy `[datasets.toy]` smoke benchmark, + exports row 0 non-interactively, and verifies generated artifacts. + +### Compatibility +- The top-level public API remains unchanged: + `Config`, `load_datasets`, `plot_cv_distributions`, `predict`, and + `__version__`. +- Legacy `reduction_type` + `level` export rows remain supported, but + individual legacy `.npz` files must now contain an explicit `X` array. + +--- + ## [v0.2.0] - 2026-05-26 ### Added diff --git a/CITATION.cff b/CITATION.cff index 9c9a2b0..2816b3f 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -2,8 +2,8 @@ cff-version: 1.2.0 message: "If you use this software, please cite it as below." type: software title: "MELITE: Multi-model Evaluation and Learning for Inference-ready Tabular Experiments" -version: "0.2.0" -date-released: "2026-05-26" +version: "0.2.1" +date-released: "2026-05-27" authors: - family-names: "Contreras-Torres" given-names: "Flavio F." diff --git a/README.md b/README.md index 0851733..fac5610 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![CI](https://github.com/NanoBiostructuresRG/melite/actions/workflows/ci.yml/badge.svg)](https://github.com/NanoBiostructuresRG/melite/actions/workflows/ci.yml) [![License: LGPL v3](https://img.shields.io/badge/License-LGPL_v3-blue.svg)](LICENSE) -[![Version](https://img.shields.io/badge/version-v0.1.11-blue.svg)]() +[![Version](https://img.shields.io/badge/version-v0.2.1-blue.svg)]() [![Python](https://img.shields.io/badge/python-3.11%20%7C%203.12-blue)]() **MELITE** is a pre-stable Python toolkit for tabular classification @@ -21,7 +21,7 @@ Project: MELITE PyPI distribution: melite Import package: melite CLI: melite -Version: 0.2.0 +Version: 0.2.1 License: LGPL-3.0-or-later Status: alpha / pre-stable ``` @@ -133,6 +133,16 @@ Registered datasets are loaded strictly: missing files, missing `X`, non-2D or non-numeric `X`, length mismatches, and embedded `y` mismatches fail the run. Legacy `[benchmark].reduction_types` and `levels` configs are still accepted and are normalized into equivalent dataset entries such as `PCA70` and `UMAP90`. + +Model families are controlled by `[models].active`: + +```toml +[models] +active = ["svc", "rf", "xgb"] +``` + +Remove a key to skip that family during training. Valid keys are `svc`, `rf`, +and `xgb`. ## CLI @@ -165,7 +175,7 @@ from melite import __version__ ``` Modules not listed above are importable directly but are not part of the public -contract and may change before 0.2.0. +contract and may change before 1.0. ## Input Format @@ -196,13 +206,14 @@ Local inputs and generated artifacts such as `raw/`, `data/`, `output/`, ## Validation -The current `dev/v0.2.0` branch targets: +The current `dev/v0.2.1` branch targets: ```bash python -m pytest tests/ -v --basetemp=.review_pytest_tmp -o cache_dir=.review_pytest_cache mkdocs build --strict -python -m build +python -m build --no-isolation python -m twine check dist/* +python scripts/smoke_install_wheel.py melite --help melite run --help melite export --help @@ -216,7 +227,7 @@ If you use MELITE in your research, please cite it using the metadata in ```text Contreras-Torres, F. F., & Murrieta, A. C. (2026). MELITE: Multi-model -Evaluation and Learning for Inference-ready Tabular Experiments (0.1.11). +Evaluation and Learning for Inference-ready Tabular Experiments (0.2.1). Tecnologico de Monterrey. https://github.com/NanoBiostructuresRG/melite ``` diff --git a/docs/api.md b/docs/api.md index 050a425..006a767 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1,7 +1,7 @@ # API Reference MELITE exposes an intended public API through five symbols. The project is -pre-stable, so this API may change before 0.2.0. Internal modules are importable +pre-stable, so this API may change before 1.0. Internal modules are importable directly but are not part of the public contract. ```python diff --git a/docs/index.md b/docs/index.md index 2e74641..0ad6237 100644 --- a/docs/index.md +++ b/docs/index.md @@ -21,7 +21,7 @@
CI - Version + Version Python versions License: LGPL v3+
@@ -29,9 +29,9 @@ !!! note "Pre-stable" - MELITE is currently in alpha-stage development (`v0.1.x`). Publication on + MELITE is currently in alpha-stage development (`v0.2.x`). Publication on PyPI is prepared under the package name `melite`. Public APIs may - change before 0.2.0. + change before 1.0. ## Workflow @@ -118,32 +118,72 @@ industrial features, or manually selected numeric features. MELITE uses a dataset registry under `[datasets.]`. Each `dataset_id` names one concrete numeric `X` matrix candidate. -```toml -[datasets.morgan_r2_2048] -path = "data/morgan_r2_2048.npz" -label_path = "raw/labels.npy" -family = "fingerprints" -method = "Morgan" - -[datasets.rdkit_descriptors] -path = "data/rdkit_descriptors.npz" -label_path = "raw/labels.npy" -family = "descriptors" -method = "RDKit" - -[datasets.pca85] -path = "data/PCA85.npz" -label_path = "raw/labels.npy" -family = "dimensionality" -method = "PCA" -level = 85 -``` +
+
+ Registry pattern + One dataset id, one numeric matrix. +

Use metadata for reporting and traceability; execution follows the + registered files, not hardcoded dataset families.

+
+
+ +=== "Fingerprints" + + ```toml + [datasets.morgan_r2_2048] + path = "data/morgan_r2_2048.npz" + label_path = "raw/labels.npy" + family = "fingerprints" + method = "Morgan" + variant = "radius2_2048" + ``` + + `morgan_r2_2048` is just a user-defined id. MELITE treats it as a concrete + feature matrix candidate and reports the metadata with its results. + +=== "Descriptors" + + ```toml + [datasets.rdkit_descriptors] + path = "data/rdkit_descriptors.npz" + label_path = "raw/labels.npy" + family = "descriptors" + method = "RDKit" + description = "Curated numeric descriptor table" + ``` + + Descriptor tables follow the same strict contract: numeric, two-dimensional + `X`, plus a label vector loaded from `label_path`. + +=== "Dimensionality" + + ```toml + [datasets.pca85] + path = "data/PCA85.npz" + label_path = "raw/labels.npy" + family = "dimensionality" + method = "PCA" + level = 85 + + [datasets.umap90] + path = "data/UMAP90.npz" + label_path = "raw/labels.npy" + family = "dimensionality" + method = "UMAP" + level = 90 + ``` + + PCA and UMAP are ordinary dataset entries. `method` and `level` preserve + legacy reporting context without driving special execution logic. Required fields are `path` and `label_path`; optional metadata fields are `family`, `method`, `variant`, `level`, and `description`. Legacy `[benchmark].reduction_types` and `levels` configs are still normalized into dataset entries when `[datasets]` is absent. +Each `.npz` dataset must contain an explicit `X` array; missing `X` fails +strict dataset loading. + ## Quick Example ```bash diff --git a/docs/installation.md b/docs/installation.md index 783b8c5..6ea4947 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -49,5 +49,5 @@ melite --version Expected version for this release: ```text -MELITE 0.1.11 +MELITE 0.2.1 ``` diff --git a/docs/release.md b/docs/release.md index d32d7f4..b7b211d 100644 --- a/docs/release.md +++ b/docs/release.md @@ -1,29 +1,19 @@ # Release Notes -MELITE `0.2.0` introduces the generalized tabular dataset registry and keeps -legacy PCA/UMAP configuration compatibility. +MELITE `0.2.1` hardens the generalized tabular dataset workflow while +preserving the top-level public API. -## 0.2.0 Highlights +## 0.2.1 Highlights -- Registers concrete tabular matrices under `[datasets.]`. -- Requires `path` and `label_path`; preserves optional metadata fields - `family`, `method`, `variant`, `level`, and `description`. -- Runs benchmarks through strict `cfg.DATASETS` loading. -- Exports dataset-based artifacts such as `Model_SVC_morgan_r2_2048.pkl`. -- Falls back to legacy `reduction_type` + `level` export rows for older CSVs. - -## 0.1.11 Highlights - -MELITE `0.1.11` prepared the project documentation and package metadata for -the first PyPI publication as `melite`. - -- Uses final release metadata version `0.1.11`. -- Clarifies that MELITE is tabular at the modeling level and consumes numeric - `X` and `y` arrays. -- Documented generalized `[datasets.*]` definitions as a future direction at - that time. -- Does not change functional training, selection, export, prediction, or CLI - behavior. +- `[models].active` controls which model families are trained. +- Export uses strict dataset loading and requires explicit `X` in individual + `.npz` files. +- Installed-wheel smoke validation runs and exports a toy `[datasets.toy]` + workflow outside the repository checkout. +- The public API remains `Config`, `load_datasets`, `plot_cv_distributions`, + `predict`, and `__version__`. +- Legacy `reduction_type` + `level` export rows remain supported, but + individual legacy `.npz` files must contain an explicit `X` array. ## Validation Targets @@ -32,8 +22,9 @@ Before release, validate: ```bash mkdocs build --strict python -m pytest tests/ -v --basetemp=.review_pytest_tmp -o cache_dir=.review_pytest_cache -python -m build +python -m build --no-isolation python -m twine check dist/* +python scripts/smoke_install_wheel.py melite --help melite run --help melite export --help @@ -42,6 +33,6 @@ melite --version ## Full Changelog -The complete version history is maintained in the repository changelog: +The complete release history is maintained in the repository changelog: --8<-- "CHANGELOG.md" diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index ce869fe..3957f5a 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -460,6 +460,73 @@ width: 2rem; } +.ms-dataset-panel { + background: + linear-gradient(135deg, color-mix(in srgb, var(--ms-amber) 14%, transparent), transparent 40%), + linear-gradient(180deg, var(--ms-surface), var(--ms-surface-soft)); + border: 1px solid var(--ms-border); + border-radius: 0.75rem; + box-shadow: 0 0.55rem 1.45rem rgba(40, 37, 84, 0.09); + margin: 1rem 0 0.8rem; + padding: 1rem; +} + +.ms-dataset-panel__intro { + border-left: 0.22rem solid var(--ms-amber); + padding-left: 0.8rem; +} + +.ms-dataset-panel__kicker { + color: var(--ms-text-muted); + display: block; + font-size: 0.62rem; + font-weight: 800; + letter-spacing: 0.09em; + line-height: 1.1; + margin-bottom: 0.35rem; + text-transform: uppercase; +} + +.ms-dataset-panel strong { + color: var(--ms-steel-dark); + display: block; + font-size: 1rem; + line-height: 1.25; +} + +.ms-dataset-panel p { + color: var(--ms-text-muted); + margin: 0.45rem 0 0; + max-width: 42rem; +} + +.md-typeset .tabbed-set { + margin: 0.8rem 0 1.1rem; +} + +.md-typeset .tabbed-labels { + gap: 0.25rem; +} + +.md-typeset .tabbed-labels > label { + border-radius: 0.45rem 0.45rem 0 0; + color: var(--ms-text-muted); + font-weight: 750; + padding-inline: 0.75rem; +} + +.md-typeset .tabbed-labels > label:hover, +.md-typeset .tabbed-set input:checked + label { + color: var(--ms-steel-dark); +} + +.md-typeset .tabbed-content { + background: var(--ms-surface); + border: 1px solid var(--ms-border); + border-radius: 0 0.55rem 0.55rem 0.55rem; + padding: 0.85rem 0.95rem 0.45rem; +} + .md-typeset table:not([class]) td { vertical-align: top; } @@ -545,4 +612,8 @@ .md-typeset .ms-flow__item small { font-size: 0.54rem; } + + .md-typeset .tabbed-content { + padding: 0.75rem 0.7rem 0.35rem; + } } diff --git a/melite/export_best_model.py b/melite/export_best_model.py index fd5126a..b63f1e4 100644 --- a/melite/export_best_model.py +++ b/melite/export_best_model.py @@ -20,7 +20,7 @@ import numpy as np import pandas as pd from .config import Config -from .load_dataset import _load_one_dataset +from .load_dataset import load_datasets, _load_one_dataset from .plot_metrics import plot_cv_distributions from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_validate, RepeatedStratifiedKFold @@ -70,20 +70,19 @@ def __init__(self, cfg: Config): def load_row(self, row: pd.Series) -> Tuple[np.ndarray, np.ndarray]: """Load the dataset referenced by a result row. - New v0.2.0 result rows are resolved by their ``dataset`` id in + Dataset-registry result rows are resolved by their ``dataset`` id in ``cfg.DATASETS``. Older result rows without ``dataset`` fall back to the legacy ``reduction_type`` + ``level`` lookup. """ if "dataset" in row and _has_value(row.get("dataset")): dataset_id = str(row.get("dataset")) try: - spec = self._cfg.DATASETS[dataset_id] + dataset = load_datasets(self._cfg)[dataset_id] except KeyError as exc: raise KeyError( f"Dataset '{dataset_id}' from results.csv is not registered " "in cfg.DATASETS." ) from exc - dataset = _load_one_dataset(dataset_id, spec) return dataset["X"], dataset["y"] return self.load(row.reduction_type, int(row.level)) @@ -129,9 +128,19 @@ def _try_individual_file(self, reduction: str, level: int) -> np.ndarray | None: fp = self._data_root / f"{reduction}{level}.npz" if not fp.exists(): return None - arr = np.load(fp) - self._ensure_labels() - return arr["X"] if "X" in arr.files else arr[arr.files[0]] + dataset_id = f"{reduction}{level}" + spec = { + "path": fp, + "label_path": Path(self._cfg.PATHS["INPUT"]) / "labels.npy", + "metadata": { + "family": "dimensionality", + "method": reduction, + "level": level, + }, + } + dataset = _load_one_dataset(dataset_id, spec) + self._labels = dataset["y"] + return dataset["X"] def _try_aggregated_file(self, reduction: str, level: int) -> np.ndarray | None: fp = self._data_root / f"{reduction}s.npz" @@ -142,7 +151,23 @@ def _try_aggregated_file(self, reduction: str, level: int) -> np.ndarray | None: key = pattern.format(rtype=reduction, lvl=level) if key in arr: self._ensure_labels() - return arr[key] + X = arr[key] + if X.ndim != 2: + raise ValueError( + f"Legacy dataset '{reduction}{level}' X must be 2D; " + f"got shape {X.shape}." + ) + if not np.issubdtype(X.dtype, np.number): + raise ValueError( + f"Legacy dataset '{reduction}{level}' X must be numeric; " + f"got dtype {X.dtype}." + ) + if len(self._labels) != X.shape[0]: + raise ValueError( + f"Legacy dataset '{reduction}{level}' X/y length mismatch: " + f"X has {X.shape[0]} rows, y has {len(self._labels)} labels." + ) + return X raise KeyError(f"Level {level} not found inside {fp.name}.") def _ensure_labels(self) -> None: diff --git a/melite/model_training.py b/melite/model_training.py index 2c9fd63..3db527e 100644 --- a/melite/model_training.py +++ b/melite/model_training.py @@ -86,6 +86,23 @@ def __init__(self, config): "rf": lambda: RandomForestClassifier(random_state=rs, n_jobs=-1), "xgb": lambda: XGBClassifier(eval_metric="logloss", random_state=rs, n_jobs=-1), } + self.active_models = self._validate_active_models() + + def _validate_active_models(self): + active_models = getattr(self.config, "ACTIVE_MODELS", list(self.model_builders)) + if not active_models: + raise ValueError("ACTIVE_MODELS must contain at least one model key.") + + unknown = [model for model in active_models if model not in self.model_builders] + if unknown: + unknown_models = ", ".join(unknown) + valid_models = ", ".join(self.model_builders) + raise ValueError( + f"Unknown active model(s): {unknown_models}. " + f"Valid model keys are: {valid_models}." + ) + + return list(active_models) def _build_cv_strategy(self): cv_cfg = self.config.get_cv_config() @@ -169,7 +186,7 @@ def cross_validate_model(self, model, X_train, y_train): def train_and_select_best_model(self, X_train, y_train, reduction_type, level): """Train all active models and return the best configuration. - For each model key in ``model_builders``, runs grid search followed by + For each configured active model key, runs grid search followed by cross-validation. The model with the highest mean F1-macro is selected. Parameters @@ -209,7 +226,7 @@ def train_and_select_best_model(self, X_train, y_train, reduction_type, level): "auc": None, "auc_std": None, } - for model_name in self.model_builders: + for model_name in self.active_models: model = self.model_builders[model_name]() param_grid = self._filter_param_grid(model_name) tuned_model, params = self.perform_grid_search(model, X_train, y_train, param_grid) diff --git a/melite/version.py b/melite/version.py index 3ce2a66..88dadf4 100644 --- a/melite/version.py +++ b/melite/version.py @@ -6,7 +6,7 @@ and imported by ``result_manager`` to stamp generated reports. """ -__version__ = "0.2.0" +__version__ = "0.2.1" PROJECT_NAME = "MELITE" PROJECT_VERSION = __version__ PROJECT_STATUS = "alpha" diff --git a/mkdocs.yml b/mkdocs.yml index 07b9534..40799d0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -38,7 +38,7 @@ theme: - content.tabs.link extra_css: - - stylesheets/extra.css?v=0.1.11-readiness + - stylesheets/extra.css?v=0.2.1-docs extra: social: diff --git a/pyproject.toml b/pyproject.toml index a034a66..7ccd5aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ dev = [ "pytest>=7.0", "build>=1.0", + "hatchling>=1.18", "twine>=4.0", ] docs = [ diff --git a/scripts/smoke_install_wheel.py b/scripts/smoke_install_wheel.py new file mode 100644 index 0000000..6d306d3 --- /dev/null +++ b/scripts/smoke_install_wheel.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Installed-wheel smoke test for the MELITE toy dataset workflow. + +The script builds a wheel from the current checkout, installs it into a +temporary virtual environment, creates a tiny strict ``[datasets.toy]`` +configuration outside the repository, runs ``melite run --smoke``, exports row +0 non-interactively, and verifies the expected artifacts. +""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +import numpy as np + + +REPO_ROOT = Path(__file__).resolve().parents[1] +EXPECTED_API = [ + "Config", + "load_datasets", + "plot_cv_distributions", + "predict", + "__version__", +] + + +def _run(cmd: list[str | Path], *, cwd: Path, env: dict[str, str] | None = None) -> None: + display = " ".join(str(part) for part in cmd) + print(f"[smoke] {display}") + subprocess.run([str(part) for part in cmd], cwd=cwd, env=env, check=True) + + +def _venv_python(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +def _venv_melite(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "melite.exe" + return venv_dir / "bin" / "melite" + + +def _build_wheel(dist_dir: Path) -> Path: + if dist_dir.exists(): + shutil.rmtree(dist_dir) + _run( + [sys.executable, "-m", "build", "--wheel", "--no-isolation", "--outdir", dist_dir], + cwd=REPO_ROOT, + ) + wheels = sorted(dist_dir.glob("melite-*.whl")) + if len(wheels) != 1: + raise RuntimeError(f"Expected exactly one MELITE wheel in {dist_dir}, got {wheels}") + return wheels[0] + + +def _write_toy_project(work_dir: Path) -> Path: + raw_dir = work_dir / "raw" + data_dir = work_dir / "data" + output_dir = work_dir / "output" + raw_dir.mkdir() + data_dir.mkdir() + output_dir.mkdir() + + X = np.array([ + [0.0, 0.1, 1.0], + [0.1, 0.0, 0.9], + [0.2, 0.1, 1.1], + [0.1, 0.2, 1.0], + [0.2, 0.0, 0.8], + [0.0, 0.2, 1.2], + [1.0, 1.1, 0.0], + [1.1, 1.0, 0.1], + [0.9, 1.2, 0.0], + [1.2, 0.9, 0.2], + [1.0, 1.0, 0.0], + [1.1, 1.2, 0.1], + ], dtype=float) + y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=np.int64) + + label_path = raw_dir / "labels.npy" + dataset_path = data_dir / "toy.npz" + np.save(label_path, y) + np.savez(dataset_path, X=X, y=y) + + config_path = work_dir / "toy_config.toml" + config_path.write_text( + f""" +[paths] +input = "{raw_dir.as_posix()}/" +dataset = "{data_dir.as_posix()}/" +output = "{output_dir.as_posix()}/" + +[benchmark] +random_state = 42 + +[cv] +n_splits = 3 +n_repeats = 1 + +[cv_smoke] +n_splits = 3 +n_repeats = 1 + +[models] +active = ["svc"] + +[datasets.toy] +path = "{dataset_path.as_posix()}" +label_path = "{label_path.as_posix()}" +family = "smoke" +method = "toy" +description = "Installed wheel smoke dataset" +""".lstrip(), + encoding="utf-8", + ) + return config_path + + +def _check_imports(python: Path) -> None: + code = f""" +import pathlib +import melite +expected = {EXPECTED_API!r} +assert melite.__all__ == expected, melite.__all__ +for name in expected: + assert hasattr(melite, name), f"{{name}} missing" +assert 'load_dataset' not in melite.__all__ +assert 'ResultManager' not in melite.__all__ +repo = pathlib.Path({str(REPO_ROOT)!r}).resolve() +module_path = pathlib.Path(melite.__file__).resolve() +assert repo not in module_path.parents, module_path +print(melite.__version__, module_path) +""" + _run([python, "-c", code], cwd=REPO_ROOT.parent) + + +def _verify_outputs(work_dir: Path) -> None: + output_dir = work_dir / "output" + results_csv = output_dir / "results.csv" + model_path = output_dir / "Model_SVC_toy.pkl" + figure_path = output_dir / "figures" / "SVC_toy.png" + + for path in (results_csv, model_path, figure_path): + if not path.exists(): + raise AssertionError(f"Expected artifact was not created: {path}") + + with open(results_csv, newline="", encoding="utf-8") as f: + rows = list(csv.DictReader(f)) + if len(rows) != 1: + raise AssertionError(f"Expected one result row, got {len(rows)}") + row = rows[0] + if row["dataset"] != "toy": + raise AssertionError(f"Expected dataset 'toy', got {row['dataset']!r}") + if row["model_name"] != "SVC": + raise AssertionError(f"Expected model 'SVC', got {row['model_name']!r}") + if row["smoke"] != "True": + raise AssertionError(f"Expected smoke=True row, got {row['smoke']!r}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--keep-temp", + action="store_true", + help="Keep the temporary smoke directory for debugging.", + ) + args = parser.parse_args() + + temp_root = Path(tempfile.mkdtemp(prefix="melite-wheel-smoke-")).resolve() + print(f"[smoke] temp root: {temp_root}") + try: + wheel = _build_wheel(temp_root / "dist") + venv_dir = temp_root / "venv" + work_dir = temp_root / "work" + work_dir.mkdir() + + _run([sys.executable, "-m", "venv", "--system-site-packages", venv_dir], cwd=temp_root) + python = _venv_python(venv_dir) + melite = _venv_melite(venv_dir) + + _run([python, "-m", "pip", "install", "--no-deps", wheel], cwd=temp_root) + _check_imports(python) + + config_path = _write_toy_project(work_dir) + _run([melite, "run", "--config", config_path, "--smoke"], cwd=work_dir) + _run( + [ + melite, + "export", + "--config", + config_path, + "--row", + "0", + "--csv", + work_dir / "output" / "results.csv", + "--outdir", + work_dir / "output", + "--force", + ], + cwd=work_dir, + ) + _verify_outputs(work_dir) + print("[smoke] installed-wheel toy workflow passed") + finally: + if args.keep_temp: + print(f"[smoke] kept temp root: {temp_root}") + else: + shutil.rmtree(temp_root, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/tests/test_export.py b/tests/test_export.py index 629606b..4fa6c81 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -44,6 +44,14 @@ def _write_npz(tmp_path, name, X, y): return path +def _write_npz_without_X(tmp_path, name, Z): + data_dir = tmp_path / "data" + data_dir.mkdir(exist_ok=True) + path = data_dir / f"{name}.npz" + np.savez(path, Z=Z) + return path + + def _write_results_csv(path, fieldnames, row): path.parent.mkdir(exist_ok=True) with open(path, "w", newline="", encoding="utf-8") as f: @@ -152,7 +160,81 @@ def test_export_dataset_row_uses_dataset_id_for_artifact(monkeypatch, tmp_path): assert (tmp_path / "output" / "Model_SVC_morgan_r2_2048.pkl").exists() -def test_export_legacy_row_falls_back_to_reduction_and_level(monkeypatch, tmp_path): +def test_export_dataset_row_uses_strict_load_datasets(monkeypatch, tmp_path): + import melite.export_best_model as export_module + + label_path, y = _write_labels(tmp_path) + cfg = _make_config(tmp_path) + cfg.DATASETS = { + "maccs": { + "path": str(tmp_path / "data" / "maccs.npz"), + "label_path": str(label_path), + "metadata": {"family": "fingerprints", "method": "MACCS"}, + } + } + csv_path = tmp_path / "output" / "results.csv" + _write_results_csv( + csv_path, + ["dataset", "model_name", "parameters", "smoke"], + { + "dataset": "maccs", + "model_name": "SVC", + "parameters": "{'kernel': 'linear', 'C': 1}", + "smoke": False, + }, + ) + calls = [] + + def fake_load_datasets(config): + calls.append(config) + return { + "maccs": { + "X": np.ones((20, 5)), + "y": y, + "metadata": {"family": "fingerprints", "method": "MACCS"}, + } + } + + monkeypatch.setattr(export_module, "load_datasets", fake_load_datasets) + monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel())) + monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None) + + Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run() + + assert calls == [cfg] + assert (tmp_path / "output" / "Model_SVC_maccs.pkl").exists() + + +def test_export_dataset_npz_without_X_fails_clearly(monkeypatch, tmp_path): + label_path, _ = _write_labels(tmp_path) + dataset_path = _write_npz_without_X(tmp_path, "maccs", np.ones((20, 5))) + cfg = _make_config(tmp_path) + cfg.DATASETS = { + "maccs": { + "path": str(dataset_path), + "label_path": str(label_path), + "metadata": {"family": "fingerprints", "method": "MACCS"}, + } + } + csv_path = tmp_path / "output" / "results.csv" + _write_results_csv( + csv_path, + ["dataset", "model_name", "parameters", "smoke"], + { + "dataset": "maccs", + "model_name": "SVC", + "parameters": "{'kernel': 'linear', 'C': 1}", + "smoke": False, + }, + ) + monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel())) + monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="Required key 'X' not found"): + Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run() + + +def test_export_legacy_row_with_valid_X_and_labels_succeeds(monkeypatch, tmp_path): _, y = _write_labels(tmp_path) _write_npz(tmp_path, "PCA70", np.ones((20, 5)), y) cfg = _make_config(tmp_path) @@ -182,6 +264,37 @@ def test_export_legacy_row_falls_back_to_reduction_and_level(monkeypatch, tmp_pa assert (tmp_path / "output" / "Model_SVC_PCA70.pkl").exists() +def test_export_legacy_npz_without_X_does_not_fallback_to_first_key( + monkeypatch, tmp_path +): + _write_labels(tmp_path) + _write_npz_without_X(tmp_path, "PCA70", np.ones((20, 5))) + cfg = _make_config(tmp_path) + csv_path = tmp_path / "output" / "results.csv" + _write_results_csv( + csv_path, + [ + "reduction_type", "level", "model_name", "parameters", + "f1_macro", "accuracy", "auc_roc", "smoke", + ], + { + "reduction_type": "PCA", + "level": 70, + "model_name": "SVC", + "parameters": "{'kernel': 'linear', 'C': 1}", + "f1_macro": 0.8, + "accuracy": 0.8, + "auc_roc": 0.9, + "smoke": False, + }, + ) + monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel())) + monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None) + + with pytest.raises(ValueError, match="Required key 'X' not found"): + Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run() + + def test_cv_plot_uses_dataset_id_for_figure(monkeypatch, tmp_path): import melite.export_best_model as export_module diff --git a/tests/test_model_training.py b/tests/test_model_training.py new file mode 100644 index 0000000..776888e --- /dev/null +++ b/tests/test_model_training.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for MELITE model training selection behavior.""" + +from types import SimpleNamespace + +import numpy as np +import pytest + +from melite.model_training import MultiModelTrainer + + +def _config(active_models): + return SimpleNamespace( + ACTIVE_MODELS=active_models, + RANDOM_STATE=42, + PARAM_GRID=[ + {"model": ["svc"], "C": [1], "kernel": ["linear"]}, + {"model": ["rf"], "n_estimators": [10]}, + {"model": ["xgb"], "n_estimators": [10]}, + ], + get_cv_config=lambda: { + "n_splits": 2, + "n_repeats": 1, + "random_state": 42, + }, + ) + + +def _trainer_with_fake_training(active_models): + trainer = MultiModelTrainer(_config(active_models)) + trainer.model_builders = { + "svc": lambda: "svc-estimator", + "rf": lambda: "rf-estimator", + "xgb": lambda: "xgb-estimator", + } + calls = [] + + def fake_grid_search(model, X_train, y_train, param_grid): + model_name = model.split("-")[0] + calls.append((model_name, param_grid)) + return f"{model_name}-tuned", {"model": model_name} + + def fake_cross_validate(model, X_train, y_train): + model_name = model.split("-")[0] + scores = {"svc": 0.7, "rf": 0.8, "xgb": 0.6} + return scores[model_name], 0.01, 0.75, 0.02, 0.85, 0.03 + + trainer.perform_grid_search = fake_grid_search + trainer.cross_validate_model = fake_cross_validate + return trainer, calls + + +def test_active_models_svc_trains_only_svc(): + trainer, calls = _trainer_with_fake_training(["svc"]) + + result = trainer.train_and_select_best_model( + np.ones((4, 2)), + np.array([0, 1, 0, 1]), + "PCA", + 70, + ) + + assert [model_name for model_name, _ in calls] == ["svc"] + assert result[0] == "svc-tuned" + assert result[1] == {"model": "svc"} + + +def test_active_models_rf_trains_only_rf(): + trainer, calls = _trainer_with_fake_training(["rf"]) + + result = trainer.train_and_select_best_model( + np.ones((4, 2)), + np.array([0, 1, 0, 1]), + "PCA", + 70, + ) + + assert [model_name for model_name, _ in calls] == ["rf"] + assert result[0] == "rf-tuned" + assert result[1] == {"model": "rf"} + + +def test_invalid_active_model_raises_clear_error(): + with pytest.raises(ValueError, match="Unknown active model\\(s\\): knn"): + MultiModelTrainer(_config(["svc", "knn"])) + + +def test_empty_active_models_raises_clear_error(): + with pytest.raises(ValueError, match="ACTIVE_MODELS must contain at least one"): + MultiModelTrainer(_config([]))