diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 835f6b5..a5b15dc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -62,30 +62,8 @@ jobs:
- name: Check distributions
run: python -m twine check dist/*
- - name: Smoke install built wheel
- shell: bash
- run: |
- set -euo pipefail
- WHEEL="$(ls dist/*.whl | head -n 1)"
- python -m venv .smoke_venv
- .smoke_venv/bin/python -m pip install --upgrade pip
- .smoke_venv/bin/python -m pip install "$WHEEL"
- mkdir -p .smoke_outside_checkout
- cd .smoke_outside_checkout
- ../.smoke_venv/bin/melite --version
- ../.smoke_venv/bin/python -c "
- import melite
- expected = ['Config', 'load_datasets', 'plot_cv_distributions', 'predict', '__version__']
- assert melite.__all__ == expected, melite.__all__
- for name in expected:
- assert hasattr(melite, name), f'{name} missing'
- assert 'load_dataset' not in melite.__all__, 'load_dataset must not be top-level public API'
- assert 'ResultManager' not in melite.__all__, 'ResultManager must not be top-level public API'
- assert not hasattr(melite, 'Pipeline'), 'Pipeline must not be public'
- from melite.result_manager import ResultManager
- assert ResultManager is not None, 'ResultManager internal import missing'
- print(melite.__version__, 'wheel OK')
- "
+ - name: Smoke installed wheel toy workflow
+ run: python scripts/smoke_install_wheel.py
- name: Smoke install sdist
shell: bash
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 67388d8..99cd477 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
---
+## [v0.2.1] - 2026-05-27
+
+### Changed
+- `[models].active` / `ACTIVE_MODELS` now controls which model families are
+ trained during benchmarking.
+- `melite export` now uses strict dataset loading for registry-based datasets
+ and no longer falls back to `arr.files[0]` when an `.npz` file lacks `X`.
+- Added an installed-wheel smoke test that builds the wheel, installs it
+ outside the repository checkout, runs a toy `[datasets.toy]` smoke benchmark,
+ exports row 0 non-interactively, and verifies generated artifacts.
+
+### Compatibility
+- The top-level public API remains unchanged:
+ `Config`, `load_datasets`, `plot_cv_distributions`, `predict`, and
+ `__version__`.
+- Legacy `reduction_type` + `level` export rows remain supported, but
+ individual legacy `.npz` files must now contain an explicit `X` array.
+
+---
+
## [v0.2.0] - 2026-05-26
### Added
diff --git a/CITATION.cff b/CITATION.cff
index 9c9a2b0..2816b3f 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -2,8 +2,8 @@ cff-version: 1.2.0
message: "If you use this software, please cite it as below."
type: software
title: "MELITE: Multi-model Evaluation and Learning for Inference-ready Tabular Experiments"
-version: "0.2.0"
-date-released: "2026-05-26"
+version: "0.2.1"
+date-released: "2026-05-27"
authors:
- family-names: "Contreras-Torres"
given-names: "Flavio F."
diff --git a/README.md b/README.md
index 0851733..fac5610 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
[](https://github.com/NanoBiostructuresRG/melite/actions/workflows/ci.yml)
[](LICENSE)
-[]()
+[]()
[]()
**MELITE** is a pre-stable Python toolkit for tabular classification
@@ -21,7 +21,7 @@ Project: MELITE
PyPI distribution: melite
Import package: melite
CLI: melite
-Version: 0.2.0
+Version: 0.2.1
License: LGPL-3.0-or-later
Status: alpha / pre-stable
```
@@ -133,6 +133,16 @@ Registered datasets are loaded strictly: missing files, missing `X`, non-2D or
non-numeric `X`, length mismatches, and embedded `y` mismatches fail the run.
Legacy `[benchmark].reduction_types` and `levels` configs are still accepted
and are normalized into equivalent dataset entries such as `PCA70` and `UMAP90`.
+
+Model families are controlled by `[models].active`:
+
+```toml
+[models]
+active = ["svc", "rf", "xgb"]
+```
+
+Remove a key to skip that family during training. Valid keys are `svc`, `rf`,
+and `xgb`.
## CLI
@@ -165,7 +175,7 @@ from melite import __version__
```
Modules not listed above are importable directly but are not part of the public
-contract and may change before 0.2.0.
+contract and may change before 1.0.
## Input Format
@@ -196,13 +206,14 @@ Local inputs and generated artifacts such as `raw/`, `data/`, `output/`,
## Validation
-The current `dev/v0.2.0` branch targets:
+The current `dev/v0.2.1` branch targets:
```bash
python -m pytest tests/ -v --basetemp=.review_pytest_tmp -o cache_dir=.review_pytest_cache
mkdocs build --strict
-python -m build
+python -m build --no-isolation
python -m twine check dist/*
+python scripts/smoke_install_wheel.py
melite --help
melite run --help
melite export --help
@@ -216,7 +227,7 @@ If you use MELITE in your research, please cite it using the metadata in
```text
Contreras-Torres, F. F., & Murrieta, A. C. (2026). MELITE: Multi-model
-Evaluation and Learning for Inference-ready Tabular Experiments (0.1.11).
+Evaluation and Learning for Inference-ready Tabular Experiments (0.2.1).
Tecnologico de Monterrey. https://github.com/NanoBiostructuresRG/melite
```
diff --git a/docs/api.md b/docs/api.md
index 050a425..006a767 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1,7 +1,7 @@
# API Reference
MELITE exposes an intended public API through five symbols. The project is
-pre-stable, so this API may change before 0.2.0. Internal modules are importable
+pre-stable, so this API may change before 1.0. Internal modules are importable
directly but are not part of the public contract.
```python
diff --git a/docs/index.md b/docs/index.md
index 2e74641..0ad6237 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -21,7 +21,7 @@
@@ -29,9 +29,9 @@
!!! note "Pre-stable"
- MELITE is currently in alpha-stage development (`v0.1.x`). Publication on
+ MELITE is currently in alpha-stage development (`v0.2.x`). Publication on
PyPI is prepared under the package name `melite`. Public APIs may
- change before 0.2.0.
+ change before 1.0.
## Workflow
@@ -118,32 +118,72 @@ industrial features, or manually selected numeric features.
MELITE uses a dataset registry under `[datasets.<dataset_id>]`. Each
`dataset_id` names one concrete numeric `X` matrix candidate.
-```toml
-[datasets.morgan_r2_2048]
-path = "data/morgan_r2_2048.npz"
-label_path = "raw/labels.npy"
-family = "fingerprints"
-method = "Morgan"
-
-[datasets.rdkit_descriptors]
-path = "data/rdkit_descriptors.npz"
-label_path = "raw/labels.npy"
-family = "descriptors"
-method = "RDKit"
-
-[datasets.pca85]
-path = "data/PCA85.npz"
-label_path = "raw/labels.npy"
-family = "dimensionality"
-method = "PCA"
-level = 85
-```
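+<div class="ms-dataset-panel">
+  <div class="ms-dataset-panel__intro">
+    <span class="ms-dataset-panel__kicker">Registry pattern</span>
+    <strong>One dataset id, one numeric matrix.</strong>
+    <p>Use metadata for reporting and traceability; execution follows the
+    registered files, not hardcoded dataset families.</p>
+  </div>
+</div>
+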
+=== "Fingerprints"
+
+ ```toml
+ [datasets.morgan_r2_2048]
+ path = "data/morgan_r2_2048.npz"
+ label_path = "raw/labels.npy"
+ family = "fingerprints"
+ method = "Morgan"
+ variant = "radius2_2048"
+ ```
+
+ `morgan_r2_2048` is just a user-defined id. MELITE treats it as a concrete
+ feature matrix candidate and reports the metadata with its results.
+
+=== "Descriptors"
+
+ ```toml
+ [datasets.rdkit_descriptors]
+ path = "data/rdkit_descriptors.npz"
+ label_path = "raw/labels.npy"
+ family = "descriptors"
+ method = "RDKit"
+ description = "Curated numeric descriptor table"
+ ```
+
+ Descriptor tables follow the same strict contract: numeric, two-dimensional
+ `X`, plus a label vector loaded from `label_path`.
+
+=== "Dimensionality"
+
+ ```toml
+ [datasets.pca85]
+ path = "data/PCA85.npz"
+ label_path = "raw/labels.npy"
+ family = "dimensionality"
+ method = "PCA"
+ level = 85
+
+ [datasets.umap90]
+ path = "data/UMAP90.npz"
+ label_path = "raw/labels.npy"
+ family = "dimensionality"
+ method = "UMAP"
+ level = 90
+ ```
+
+ PCA and UMAP are ordinary dataset entries. `method` and `level` preserve
+ legacy reporting context without driving special execution logic.
Required fields are `path` and `label_path`; optional metadata fields are
`family`, `method`, `variant`, `level`, and `description`. Legacy
`[benchmark].reduction_types` and `levels` configs are still normalized into
dataset entries when `[datasets]` is absent.
+Each `.npz` dataset must contain an explicit `X` array; missing `X` fails
+strict dataset loading.
+
## Quick Example
```bash
diff --git a/docs/installation.md b/docs/installation.md
index 783b8c5..6ea4947 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -49,5 +49,5 @@ melite --version
Expected version for this release:
```text
-MELITE 0.1.11
+MELITE 0.2.1
```
diff --git a/docs/release.md b/docs/release.md
index d32d7f4..b7b211d 100644
--- a/docs/release.md
+++ b/docs/release.md
@@ -1,29 +1,19 @@
# Release Notes
-MELITE `0.2.0` introduces the generalized tabular dataset registry and keeps
-legacy PCA/UMAP configuration compatibility.
+MELITE `0.2.1` hardens the generalized tabular dataset workflow while
+preserving the top-level public API.
-## 0.2.0 Highlights
+## 0.2.1 Highlights
- - Registers concrete tabular matrices under `[datasets.<dataset_id>]`.
-- Requires `path` and `label_path`; preserves optional metadata fields
- `family`, `method`, `variant`, `level`, and `description`.
-- Runs benchmarks through strict `cfg.DATASETS` loading.
-- Exports dataset-based artifacts such as `Model_SVC_morgan_r2_2048.pkl`.
-- Falls back to legacy `reduction_type` + `level` export rows for older CSVs.
-
-## 0.1.11 Highlights
-
-MELITE `0.1.11` prepared the project documentation and package metadata for
-the first PyPI publication as `melite`.
-
-- Uses final release metadata version `0.1.11`.
-- Clarifies that MELITE is tabular at the modeling level and consumes numeric
- `X` and `y` arrays.
-- Documented generalized `[datasets.*]` definitions as a future direction at
- that time.
-- Does not change functional training, selection, export, prediction, or CLI
- behavior.
+- `[models].active` controls which model families are trained.
+- Export uses strict dataset loading and requires explicit `X` in individual
+ `.npz` files.
+- Installed-wheel smoke validation runs and exports a toy `[datasets.toy]`
+ workflow outside the repository checkout.
+- The public API remains `Config`, `load_datasets`, `plot_cv_distributions`,
+ `predict`, and `__version__`.
+- Legacy `reduction_type` + `level` export rows remain supported, but
+ individual legacy `.npz` files must contain an explicit `X` array.
## Validation Targets
@@ -32,8 +22,9 @@ Before release, validate:
```bash
mkdocs build --strict
python -m pytest tests/ -v --basetemp=.review_pytest_tmp -o cache_dir=.review_pytest_cache
-python -m build
+python -m build --no-isolation
python -m twine check dist/*
+python scripts/smoke_install_wheel.py
melite --help
melite run --help
melite export --help
@@ -42,6 +33,6 @@ melite --version
## Full Changelog
-The complete version history is maintained in the repository changelog:
+The complete release history is maintained in the repository changelog:
--8<-- "CHANGELOG.md"
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
index ce869fe..3957f5a 100644
--- a/docs/stylesheets/extra.css
+++ b/docs/stylesheets/extra.css
@@ -460,6 +460,73 @@
width: 2rem;
}
+.ms-dataset-panel {
+ background:
+ linear-gradient(135deg, color-mix(in srgb, var(--ms-amber) 14%, transparent), transparent 40%),
+ linear-gradient(180deg, var(--ms-surface), var(--ms-surface-soft));
+ border: 1px solid var(--ms-border);
+ border-radius: 0.75rem;
+ box-shadow: 0 0.55rem 1.45rem rgba(40, 37, 84, 0.09);
+ margin: 1rem 0 0.8rem;
+ padding: 1rem;
+}
+
+.ms-dataset-panel__intro {
+ border-left: 0.22rem solid var(--ms-amber);
+ padding-left: 0.8rem;
+}
+
+.ms-dataset-panel__kicker {
+ color: var(--ms-text-muted);
+ display: block;
+ font-size: 0.62rem;
+ font-weight: 800;
+ letter-spacing: 0.09em;
+ line-height: 1.1;
+ margin-bottom: 0.35rem;
+ text-transform: uppercase;
+}
+
+.ms-dataset-panel strong {
+ color: var(--ms-steel-dark);
+ display: block;
+ font-size: 1rem;
+ line-height: 1.25;
+}
+
+.ms-dataset-panel p {
+ color: var(--ms-text-muted);
+ margin: 0.45rem 0 0;
+ max-width: 42rem;
+}
+
+.md-typeset .tabbed-set {
+ margin: 0.8rem 0 1.1rem;
+}
+
+.md-typeset .tabbed-labels {
+ gap: 0.25rem;
+}
+
+.md-typeset .tabbed-labels > label {
+ border-radius: 0.45rem 0.45rem 0 0;
+ color: var(--ms-text-muted);
+ font-weight: 750;
+ padding-inline: 0.75rem;
+}
+
+.md-typeset .tabbed-labels > label:hover,
+.md-typeset .tabbed-set input:checked + label {
+ color: var(--ms-steel-dark);
+}
+
+.md-typeset .tabbed-content {
+ background: var(--ms-surface);
+ border: 1px solid var(--ms-border);
+ border-radius: 0 0.55rem 0.55rem 0.55rem;
+ padding: 0.85rem 0.95rem 0.45rem;
+}
+
.md-typeset table:not([class]) td {
vertical-align: top;
}
@@ -545,4 +612,8 @@
.md-typeset .ms-flow__item small {
font-size: 0.54rem;
}
+
+ .md-typeset .tabbed-content {
+ padding: 0.75rem 0.7rem 0.35rem;
+ }
}
diff --git a/melite/export_best_model.py b/melite/export_best_model.py
index fd5126a..b63f1e4 100644
--- a/melite/export_best_model.py
+++ b/melite/export_best_model.py
@@ -20,7 +20,7 @@
import numpy as np
import pandas as pd
from .config import Config
-from .load_dataset import _load_one_dataset
+from .load_dataset import load_datasets, _load_one_dataset
from .plot_metrics import plot_cv_distributions
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate, RepeatedStratifiedKFold
@@ -70,20 +70,19 @@ def __init__(self, cfg: Config):
def load_row(self, row: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
"""Load the dataset referenced by a result row.
- New v0.2.0 result rows are resolved by their ``dataset`` id in
+ Dataset-registry result rows are resolved by their ``dataset`` id in
``cfg.DATASETS``. Older result rows without ``dataset`` fall back to
the legacy ``reduction_type`` + ``level`` lookup.
"""
if "dataset" in row and _has_value(row.get("dataset")):
dataset_id = str(row.get("dataset"))
try:
- spec = self._cfg.DATASETS[dataset_id]
+ dataset = load_datasets(self._cfg)[dataset_id]
except KeyError as exc:
raise KeyError(
f"Dataset '{dataset_id}' from results.csv is not registered "
"in cfg.DATASETS."
) from exc
- dataset = _load_one_dataset(dataset_id, spec)
return dataset["X"], dataset["y"]
return self.load(row.reduction_type, int(row.level))
@@ -129,9 +128,19 @@ def _try_individual_file(self, reduction: str, level: int) -> np.ndarray | None:
fp = self._data_root / f"{reduction}{level}.npz"
if not fp.exists():
return None
- arr = np.load(fp)
- self._ensure_labels()
- return arr["X"] if "X" in arr.files else arr[arr.files[0]]
+ dataset_id = f"{reduction}{level}"
+ spec = {
+ "path": fp,
+ "label_path": Path(self._cfg.PATHS["INPUT"]) / "labels.npy",
+ "metadata": {
+ "family": "dimensionality",
+ "method": reduction,
+ "level": level,
+ },
+ }
+ dataset = _load_one_dataset(dataset_id, spec)
+ self._labels = dataset["y"]
+ return dataset["X"]
def _try_aggregated_file(self, reduction: str, level: int) -> np.ndarray | None:
fp = self._data_root / f"{reduction}s.npz"
@@ -142,7 +151,23 @@ def _try_aggregated_file(self, reduction: str, level: int) -> np.ndarray | None:
key = pattern.format(rtype=reduction, lvl=level)
if key in arr:
self._ensure_labels()
- return arr[key]
+ X = arr[key]
+ if X.ndim != 2:
+ raise ValueError(
+ f"Legacy dataset '{reduction}{level}' X must be 2D; "
+ f"got shape {X.shape}."
+ )
+ if not np.issubdtype(X.dtype, np.number):
+ raise ValueError(
+ f"Legacy dataset '{reduction}{level}' X must be numeric; "
+ f"got dtype {X.dtype}."
+ )
+ if len(self._labels) != X.shape[0]:
+ raise ValueError(
+ f"Legacy dataset '{reduction}{level}' X/y length mismatch: "
+ f"X has {X.shape[0]} rows, y has {len(self._labels)} labels."
+ )
+ return X
raise KeyError(f"Level {level} not found inside {fp.name}.")
def _ensure_labels(self) -> None:
diff --git a/melite/model_training.py b/melite/model_training.py
index 2c9fd63..3db527e 100644
--- a/melite/model_training.py
+++ b/melite/model_training.py
@@ -86,6 +86,23 @@ def __init__(self, config):
"rf": lambda: RandomForestClassifier(random_state=rs, n_jobs=-1),
"xgb": lambda: XGBClassifier(eval_metric="logloss", random_state=rs, n_jobs=-1),
}
+ self.active_models = self._validate_active_models()
+
+ def _validate_active_models(self):
+ active_models = getattr(self.config, "ACTIVE_MODELS", list(self.model_builders))
+ if not active_models:
+ raise ValueError("ACTIVE_MODELS must contain at least one model key.")
+
+ unknown = [model for model in active_models if model not in self.model_builders]
+ if unknown:
+ unknown_models = ", ".join(unknown)
+ valid_models = ", ".join(self.model_builders)
+ raise ValueError(
+ f"Unknown active model(s): {unknown_models}. "
+ f"Valid model keys are: {valid_models}."
+ )
+
+ return list(active_models)
def _build_cv_strategy(self):
cv_cfg = self.config.get_cv_config()
@@ -169,7 +186,7 @@ def cross_validate_model(self, model, X_train, y_train):
def train_and_select_best_model(self, X_train, y_train, reduction_type, level):
"""Train all active models and return the best configuration.
- For each model key in ``model_builders``, runs grid search followed by
+ For each configured active model key, runs grid search followed by
cross-validation. The model with the highest mean F1-macro is selected.
Parameters
@@ -209,7 +226,7 @@ def train_and_select_best_model(self, X_train, y_train, reduction_type, level):
"auc": None, "auc_std": None,
}
- for model_name in self.model_builders:
+ for model_name in self.active_models:
model = self.model_builders[model_name]()
param_grid = self._filter_param_grid(model_name)
tuned_model, params = self.perform_grid_search(model, X_train, y_train, param_grid)
diff --git a/melite/version.py b/melite/version.py
index 3ce2a66..88dadf4 100644
--- a/melite/version.py
+++ b/melite/version.py
@@ -6,7 +6,7 @@
and imported by ``result_manager`` to stamp generated reports.
"""
-__version__ = "0.2.0"
+__version__ = "0.2.1"
PROJECT_NAME = "MELITE"
PROJECT_VERSION = __version__
PROJECT_STATUS = "alpha"
diff --git a/mkdocs.yml b/mkdocs.yml
index 07b9534..40799d0 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -38,7 +38,7 @@ theme:
- content.tabs.link
extra_css:
- - stylesheets/extra.css?v=0.1.11-readiness
+ - stylesheets/extra.css?v=0.2.1-docs
extra:
social:
diff --git a/pyproject.toml b/pyproject.toml
index a034a66..7ccd5aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
dev = [
"pytest>=7.0",
"build>=1.0",
+ "hatchling>=1.18",
"twine>=4.0",
]
docs = [
diff --git a/scripts/smoke_install_wheel.py b/scripts/smoke_install_wheel.py
new file mode 100644
index 0000000..6d306d3
--- /dev/null
+++ b/scripts/smoke_install_wheel.py
@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Installed-wheel smoke test for the MELITE toy dataset workflow.
+
+The script builds a wheel from the current checkout, installs it into a
+temporary virtual environment, creates a tiny strict ``[datasets.toy]``
+configuration outside the repository, runs ``melite run --smoke``, exports row
+0 non-interactively, and verifies the expected artifacts.
+"""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+from pathlib import Path
+
+import numpy as np
+
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+EXPECTED_API = [
+ "Config",
+ "load_datasets",
+ "plot_cv_distributions",
+ "predict",
+ "__version__",
+]
+
+
+def _run(cmd: list[str | Path], *, cwd: Path, env: dict[str, str] | None = None) -> None:
+ """Echo a command with a ``[smoke]`` prefix, then execute it.
+
+ Every element of ``cmd`` is converted with ``str()`` before it is passed to
+ ``subprocess.run`` together with ``cwd`` and ``env``. Because ``check=True``
+ is set, a non-zero exit status raises ``subprocess.CalledProcessError``.
+ """
+ display = " ".join(str(part) for part in cmd)
+ print(f"[smoke] {display}")
+ subprocess.run([str(part) for part in cmd], cwd=cwd, env=env, check=True)
+
+
+def _venv_python(venv_dir: Path) -> Path:
+ """Return the interpreter path inside ``venv_dir``.
+
+ ``Scripts/python.exe`` is used when ``os.name == "nt"``; otherwise
+ ``bin/python``. The path is only composed here, not checked for existence.
+ """
+ if os.name == "nt":
+ return venv_dir / "Scripts" / "python.exe"
+ return venv_dir / "bin" / "python"
+
+
+def _venv_melite(venv_dir: Path) -> Path:
+ """Return the path of the ``melite`` console script inside ``venv_dir``.
+
+ ``Scripts/melite.exe`` is used when ``os.name == "nt"``; otherwise
+ ``bin/melite``. The path is only composed here, not checked for existence.
+ """
+ if os.name == "nt":
+ return venv_dir / "Scripts" / "melite.exe"
+ return venv_dir / "bin" / "melite"
+
+
+def _build_wheel(dist_dir: Path) -> Path:
+ """Build a wheel from ``REPO_ROOT`` into ``dist_dir`` and return its path.
+
+ An existing ``dist_dir`` is deleted first so stale wheels cannot be picked
+ up. Raises ``RuntimeError`` unless exactly one ``melite-*.whl`` file is
+ present afterwards.
+ """
+ if dist_dir.exists():
+ shutil.rmtree(dist_dir)
+ # NOTE(review): ``--no-isolation`` presumably requires the build backend to
+ # be installed for ``sys.executable`` already -- confirm for CI images.
+ _run(
+ [sys.executable, "-m", "build", "--wheel", "--no-isolation", "--outdir", dist_dir],
+ cwd=REPO_ROOT,
+ )
+ wheels = sorted(dist_dir.glob("melite-*.whl"))
+ if len(wheels) != 1:
+ raise RuntimeError(f"Expected exactly one MELITE wheel in {dist_dir}, got {wheels}")
+ return wheels[0]
+
+
+def _write_toy_project(work_dir: Path) -> Path:
+ """Create a toy project under ``work_dir`` and return its config path.
+
+ Creates ``raw/``, ``data/`` and ``output/``, saves a 12x3 float matrix and
+ a 12-element label vector, and writes ``toy_config.toml`` that registers
+ the matrix as ``[datasets.toy]`` with ``svc`` as the only active model.
+ The three ``mkdir()`` calls pass no ``exist_ok`` argument.
+ """
+ raw_dir = work_dir / "raw"
+ data_dir = work_dir / "data"
+ output_dir = work_dir / "output"
+ raw_dir.mkdir()
+ data_dir.mkdir()
+ output_dir.mkdir()
+
+ # Two groups of six rows: rows 0-5 are labelled 0 and rows 6-11 are
+ # labelled 1 in ``y`` below.
+ X = np.array([
+ [0.0, 0.1, 1.0],
+ [0.1, 0.0, 0.9],
+ [0.2, 0.1, 1.1],
+ [0.1, 0.2, 1.0],
+ [0.2, 0.0, 0.8],
+ [0.0, 0.2, 1.2],
+ [1.0, 1.1, 0.0],
+ [1.1, 1.0, 0.1],
+ [0.9, 1.2, 0.0],
+ [1.2, 0.9, 0.2],
+ [1.0, 1.0, 0.0],
+ [1.1, 1.2, 0.1],
+ ], dtype=float)
+ y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=np.int64)
+
+ label_path = raw_dir / "labels.npy"
+ dataset_path = data_dir / "toy.npz"
+ # The same labels are stored twice: standalone in labels.npy and embedded
+ # as ``y`` next to ``X`` inside the .npz archive.
+ np.save(label_path, y)
+ np.savez(dataset_path, X=X, y=y)
+
+ config_path = work_dir / "toy_config.toml"
+ # Paths are written with as_posix(); .lstrip() removes the newline that
+ # follows the opening triple quote.
+ config_path.write_text(
+ f"""
+[paths]
+input = "{raw_dir.as_posix()}/"
+dataset = "{data_dir.as_posix()}/"
+output = "{output_dir.as_posix()}/"
+
+[benchmark]
+random_state = 42
+
+[cv]
+n_splits = 3
+n_repeats = 1
+
+[cv_smoke]
+n_splits = 3
+n_repeats = 1
+
+[models]
+active = ["svc"]
+
+[datasets.toy]
+path = "{dataset_path.as_posix()}"
+label_path = "{label_path.as_posix()}"
+family = "smoke"
+method = "toy"
+description = "Installed wheel smoke dataset"
+""".lstrip(),
+ encoding="utf-8",
+ )
+ return config_path
+
+
+def _check_imports(python: Path) -> None:
+ """Run an import check for ``melite`` with the given interpreter.
+
+ The generated snippet asserts that ``melite.__all__`` equals
+ ``EXPECTED_API``, that every listed name is an attribute of the package,
+ that ``load_dataset`` and ``ResultManager`` are absent from ``__all__``,
+ and that the imported module file does not live under ``REPO_ROOT``.
+ The snippet is executed with ``REPO_ROOT.parent`` as working directory.
+ """
+ # ``{{name}}`` is doubled so this outer f-string emits a literal ``{name}``
+ # for the f-string inside the generated snippet.
+ code = f"""
+import pathlib
+import melite
+expected = {EXPECTED_API!r}
+assert melite.__all__ == expected, melite.__all__
+for name in expected:
+ assert hasattr(melite, name), f"{{name}} missing"
+assert 'load_dataset' not in melite.__all__
+assert 'ResultManager' not in melite.__all__
+repo = pathlib.Path({str(REPO_ROOT)!r}).resolve()
+module_path = pathlib.Path(melite.__file__).resolve()
+assert repo not in module_path.parents, module_path
+print(melite.__version__, module_path)
+"""
+ _run([python, "-c", code], cwd=REPO_ROOT.parent)
+
+
+def _verify_outputs(work_dir: Path) -> None:
+ """Check the artifacts that the run and export steps should have produced.
+
+ Requires ``results.csv``, ``Model_SVC_toy.pkl`` and ``figures/SVC_toy.png``
+ under ``work_dir/output``, and exactly one CSV row whose ``dataset`` is
+ ``toy``, ``model_name`` is ``SVC`` and ``smoke`` is the string ``True``.
+ Raises ``AssertionError`` on the first failed check.
+ """
+ output_dir = work_dir / "output"
+ results_csv = output_dir / "results.csv"
+ model_path = output_dir / "Model_SVC_toy.pkl"
+ figure_path = output_dir / "figures" / "SVC_toy.png"
+
+ for path in (results_csv, model_path, figure_path):
+ if not path.exists():
+ raise AssertionError(f"Expected artifact was not created: {path}")
+
+ with open(results_csv, newline="", encoding="utf-8") as f:
+ rows = list(csv.DictReader(f))
+ if len(rows) != 1:
+ raise AssertionError(f"Expected one result row, got {len(rows)}")
+ row = rows[0]
+ # NOTE(review): a missing column surfaces as KeyError from the lookups
+ # below rather than as AssertionError.
+ if row["dataset"] != "toy":
+ raise AssertionError(f"Expected dataset 'toy', got {row['dataset']!r}")
+ if row["model_name"] != "SVC":
+ raise AssertionError(f"Expected model 'SVC', got {row['model_name']!r}")
+ # csv.DictReader yields strings, hence the comparison against "True".
+ if row["smoke"] != "True":
+ raise AssertionError(f"Expected smoke=True row, got {row['smoke']!r}")
+
+
+def main() -> None:
+ """Run the installed-wheel smoke workflow end to end.
+
+ Steps: build the wheel into a fresh temporary root, create a virtual
+ environment there, install the wheel, check the public imports, write the
+ toy project, run ``melite run --smoke``, export row 0 with ``--force``, and
+ verify the artifacts. The temporary root is removed afterwards unless
+ ``--keep-temp`` was given.
+ """
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--keep-temp",
+ action="store_true",
+ help="Keep the temporary smoke directory for debugging.",
+ )
+ args = parser.parse_args()
+
+ temp_root = Path(tempfile.mkdtemp(prefix="melite-wheel-smoke-")).resolve()
+ print(f"[smoke] temp root: {temp_root}")
+ try:
+ wheel = _build_wheel(temp_root / "dist")
+ venv_dir = temp_root / "venv"
+ work_dir = temp_root / "work"
+ work_dir.mkdir()
+
+ # NOTE(review): ``--system-site-packages`` together with ``--no-deps``
+ # below presumably means the runtime dependencies must already be
+ # installed for the invoking interpreter -- confirm for CI.
+ _run([sys.executable, "-m", "venv", "--system-site-packages", venv_dir], cwd=temp_root)
+ python = _venv_python(venv_dir)
+ melite = _venv_melite(venv_dir)
+
+ _run([python, "-m", "pip", "install", "--no-deps", wheel], cwd=temp_root)
+ _check_imports(python)
+
+ config_path = _write_toy_project(work_dir)
+ _run([melite, "run", "--config", config_path, "--smoke"], cwd=work_dir)
+ _run(
+ [
+ melite,
+ "export",
+ "--config",
+ config_path,
+ "--row",
+ "0",
+ "--csv",
+ work_dir / "output" / "results.csv",
+ "--outdir",
+ work_dir / "output",
+ "--force",
+ ],
+ cwd=work_dir,
+ )
+ _verify_outputs(work_dir)
+ print("[smoke] installed-wheel toy workflow passed")
+ finally:
+ # Cleanup runs on success and failure; rmtree errors are ignored.
+ if args.keep_temp:
+ print(f"[smoke] kept temp root: {temp_root}")
+ else:
+ shutil.rmtree(temp_root, ignore_errors=True)
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/test_export.py b/tests/test_export.py
index 629606b..4fa6c81 100644
--- a/tests/test_export.py
+++ b/tests/test_export.py
@@ -44,6 +44,14 @@ def _write_npz(tmp_path, name, X, y):
return path
+def _write_npz_without_X(tmp_path, name, Z):
+ """Write ``data/<name>.npz`` under ``tmp_path`` with ``Z`` as its only key.
+
+ The archive deliberately has no ``X`` entry. Returns the archive path.
+ """
+ data_dir = tmp_path / "data"
+ data_dir.mkdir(exist_ok=True)
+ path = data_dir / f"{name}.npz"
+ np.savez(path, Z=Z)
+ return path
+
+
def _write_results_csv(path, fieldnames, row):
path.parent.mkdir(exist_ok=True)
with open(path, "w", newline="", encoding="utf-8") as f:
@@ -152,7 +160,81 @@ def test_export_dataset_row_uses_dataset_id_for_artifact(monkeypatch, tmp_path):
assert (tmp_path / "output" / "Model_SVC_morgan_r2_2048.pkl").exists()
-def test_export_legacy_row_falls_back_to_reduction_and_level(monkeypatch, tmp_path):
+def test_export_dataset_row_uses_strict_load_datasets(monkeypatch, tmp_path):
+ """A ``dataset`` result row is resolved through ``load_datasets``.
+
+ ``load_datasets`` is replaced on the export module by a fake that records
+ its argument; the test expects exactly one call with ``cfg`` and the
+ ``Model_SVC_maccs.pkl`` artifact to exist afterwards.
+ """
+ import melite.export_best_model as export_module
+
+ label_path, y = _write_labels(tmp_path)
+ cfg = _make_config(tmp_path)
+ # No maccs.npz is written in this test; the registered path is only a
+ # config value, so the data can only come from the patched loader.
+ cfg.DATASETS = {
+ "maccs": {
+ "path": str(tmp_path / "data" / "maccs.npz"),
+ "label_path": str(label_path),
+ "metadata": {"family": "fingerprints", "method": "MACCS"},
+ }
+ }
+ csv_path = tmp_path / "output" / "results.csv"
+ _write_results_csv(
+ csv_path,
+ ["dataset", "model_name", "parameters", "smoke"],
+ {
+ "dataset": "maccs",
+ "model_name": "SVC",
+ "parameters": "{'kernel': 'linear', 'C': 1}",
+ "smoke": False,
+ },
+ )
+ calls = []
+
+ def fake_load_datasets(config):
+ calls.append(config)
+ return {
+ "maccs": {
+ "X": np.ones((20, 5)),
+ "y": y,
+ "metadata": {"family": "fingerprints", "method": "MACCS"},
+ }
+ }
+
+ monkeypatch.setattr(export_module, "load_datasets", fake_load_datasets)
+ # Model construction and the CV/plot step are stubbed out.
+ monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel()))
+ monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None)
+
+ Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run()
+
+ assert calls == [cfg]
+ assert (tmp_path / "output" / "Model_SVC_maccs.pkl").exists()
+
+
+def test_export_dataset_npz_without_X_fails_clearly(monkeypatch, tmp_path):
+ """Exporting a registered dataset whose archive lacks ``X`` must fail.
+
+ The archive holds only a ``Z`` array. ``load_datasets`` is not patched
+ here, and the run is expected to raise ``ValueError`` matching
+ ``Required key 'X' not found``.
+ """
+ label_path, _ = _write_labels(tmp_path)
+ dataset_path = _write_npz_without_X(tmp_path, "maccs", np.ones((20, 5)))
+ cfg = _make_config(tmp_path)
+ cfg.DATASETS = {
+ "maccs": {
+ "path": str(dataset_path),
+ "label_path": str(label_path),
+ "metadata": {"family": "fingerprints", "method": "MACCS"},
+ }
+ }
+ csv_path = tmp_path / "output" / "results.csv"
+ _write_results_csv(
+ csv_path,
+ ["dataset", "model_name", "parameters", "smoke"],
+ {
+ "dataset": "maccs",
+ "model_name": "SVC",
+ "parameters": "{'kernel': 'linear', 'C': 1}",
+ "smoke": False,
+ },
+ )
+ monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel()))
+ monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None)
+
+ # NOTE(review): the message is presumably produced by the strict dataset
+ # loader -- confirm against melite.load_dataset.
+ with pytest.raises(ValueError, match="Required key 'X' not found"):
+ Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run()
+
+
+def test_export_legacy_row_with_valid_X_and_labels_succeeds(monkeypatch, tmp_path):
_, y = _write_labels(tmp_path)
_write_npz(tmp_path, "PCA70", np.ones((20, 5)), y)
cfg = _make_config(tmp_path)
@@ -182,6 +264,37 @@ def test_export_legacy_row_falls_back_to_reduction_and_level(monkeypatch, tmp_pa
assert (tmp_path / "output" / "Model_SVC_PCA70.pkl").exists()
+def test_export_legacy_npz_without_X_does_not_fallback_to_first_key(
+ monkeypatch, tmp_path
+):
+ """A legacy row pointing at an archive without ``X`` must fail.
+
+ The result row has ``reduction_type``/``level`` columns and no ``dataset``
+ column, and ``PCA70.npz`` holds only a ``Z`` array. The run is expected to
+ raise ``ValueError`` matching ``Required key 'X' not found``.
+ """
+ _write_labels(tmp_path)
+ _write_npz_without_X(tmp_path, "PCA70", np.ones((20, 5)))
+ cfg = _make_config(tmp_path)
+ csv_path = tmp_path / "output" / "results.csv"
+ _write_results_csv(
+ csv_path,
+ [
+ "reduction_type", "level", "model_name", "parameters",
+ "f1_macro", "accuracy", "auc_roc", "smoke",
+ ],
+ {
+ "reduction_type": "PCA",
+ "level": 70,
+ "model_name": "SVC",
+ "parameters": "{'kernel': 'linear', 'C': 1}",
+ "f1_macro": 0.8,
+ "accuracy": 0.8,
+ "auc_roc": 0.9,
+ "smoke": False,
+ },
+ )
+ monkeypatch.setattr(Finalizer, "_build_model", staticmethod(lambda *_: DummyModel()))
+ monkeypatch.setattr(Finalizer, "_cv_and_plot", lambda *args, **kwargs: None)
+
+ with pytest.raises(ValueError, match="Required key 'X' not found"):
+ Finalizer(csv_path, tmp_path / "output", cfg, row_index=0).run()
+
+
def test_cv_plot_uses_dataset_id_for_figure(monkeypatch, tmp_path):
import melite.export_best_model as export_module
diff --git a/tests/test_model_training.py b/tests/test_model_training.py
new file mode 100644
index 0000000..776888e
--- /dev/null
+++ b/tests/test_model_training.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Tests for MELITE model training selection behavior."""
+
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+
+from melite.model_training import MultiModelTrainer
+
+
+def _config(active_models):
+ """Return a ``SimpleNamespace`` standing in for the trainer's config.
+
+ It exposes ``ACTIVE_MODELS`` (the argument), ``RANDOM_STATE``, a three-entry
+ ``PARAM_GRID`` (one entry each for ``svc``, ``rf`` and ``xgb``) and a
+ ``get_cv_config`` callable returning a fixed dict.
+ """
+ return SimpleNamespace(
+ ACTIVE_MODELS=active_models,
+ RANDOM_STATE=42,
+ PARAM_GRID=[
+ {"model": ["svc"], "C": [1], "kernel": ["linear"]},
+ {"model": ["rf"], "n_estimators": [10]},
+ {"model": ["xgb"], "n_estimators": [10]},
+ ],
+ get_cv_config=lambda: {
+ "n_splits": 2,
+ "n_repeats": 1,
+ "random_state": 42,
+ },
+ )
+
+
+def _trainer_with_fake_training(active_models):
+ """Return ``(trainer, calls)`` with training replaced by fakes.
+
+ The trainer is constructed first, then ``model_builders``,
+ ``perform_grid_search`` and ``cross_validate_model`` are overwritten on the
+ instance. ``calls`` collects one ``(model_name, param_grid)`` tuple per
+ fake grid search.
+ """
+ trainer = MultiModelTrainer(_config(active_models))
+ # The builders now return plain strings; the fakes below recover the model
+ # key from the text before the first "-".
+ trainer.model_builders = {
+ "svc": lambda: "svc-estimator",
+ "rf": lambda: "rf-estimator",
+ "xgb": lambda: "xgb-estimator",
+ }
+ calls = []
+
+ def fake_grid_search(model, X_train, y_train, param_grid):
+ model_name = model.split("-")[0]
+ calls.append((model_name, param_grid))
+ return f"{model_name}-tuned", {"model": model_name}
+
+ def fake_cross_validate(model, X_train, y_train):
+ model_name = model.split("-")[0]
+ scores = {"svc": 0.7, "rf": 0.8, "xgb": 0.6}
+ # NOTE(review): six values, presumably (f1, f1_std, accuracy,
+ # accuracy_std, auc, auc_std) -- confirm against cross_validate_model.
+ return scores[model_name], 0.01, 0.75, 0.02, 0.85, 0.03
+
+ trainer.perform_grid_search = fake_grid_search
+ trainer.cross_validate_model = fake_cross_validate
+ return trainer, calls
+
+
+def test_active_models_svc_trains_only_svc():
+ """With ``["svc"]`` active, only ``svc`` reaches the fake grid search."""
+ trainer, calls = _trainer_with_fake_training(["svc"])
+
+ result = trainer.train_and_select_best_model(
+ np.ones((4, 2)),
+ np.array([0, 1, 0, 1]),
+ "PCA",
+ 70,
+ )
+
+ assert [model_name for model_name, _ in calls] == ["svc"]
+ # The first two result items are what the fake grid search returned.
+ assert result[0] == "svc-tuned"
+ assert result[1] == {"model": "svc"}
+
+
+def test_active_models_rf_trains_only_rf():
+ """With ``["rf"]`` active, only ``rf`` reaches the fake grid search."""
+ trainer, calls = _trainer_with_fake_training(["rf"])
+
+ result = trainer.train_and_select_best_model(
+ np.ones((4, 2)),
+ np.array([0, 1, 0, 1]),
+ "PCA",
+ 70,
+ )
+
+ assert [model_name for model_name, _ in calls] == ["rf"]
+ # The first two result items are what the fake grid search returned.
+ assert result[0] == "rf-tuned"
+ assert result[1] == {"model": "rf"}
+
+
+def test_invalid_active_model_raises_clear_error():
+ """An unregistered key (``knn``) makes the constructor raise ``ValueError``."""
+ with pytest.raises(ValueError, match="Unknown active model\\(s\\): knn"):
+ MultiModelTrainer(_config(["svc", "knn"]))
+
+
+def test_empty_active_models_raises_clear_error():
+ """An empty ``ACTIVE_MODELS`` list makes the constructor raise ``ValueError``."""
+ with pytest.raises(ValueError, match="ACTIVE_MODELS must contain at least one"):
+ MultiModelTrainer(_config([]))