diff --git a/python/trlib/README.md b/python/trlib/README.md index bd8bd90a..153e2186 100644 --- a/python/trlib/README.md +++ b/python/trlib/README.md @@ -184,9 +184,12 @@ interactive menu — those live in `tr/tr2` only. - **Single instance per process.** TR backend uses COMMON blocks. Two concurrent `Trlib()` instances share state; the second `tr_init` resets globals. -- **No graphics, no MPI, no OpenMP API.** Graphics symbols exist but - are not reachable from the 5 exported entry points; the loader uses - `RTLD_LAZY` so dangling graphics references never resolve. +- **No PGPlot/GSAF graphics, no MPI, no OpenMP API.** Fortran-side + graphics symbols exist but are not reachable from the 5 exported + entry points; the loader uses `RTLD_LAZY` so dangling graphics + references never resolve. Python-side visualization is provided + separately via `trlib.plot` (matplotlib backend, optional dependency) + — see the Plot section below. - **String parameters not yet wired** (e.g. `KNAMEQ`, `KNAMTR`). See `docs/superpowers/specs/2026-04-17-tr-library-design.md` §4.3. - **Unregistered namelist keys** — any name missing from @@ -226,6 +229,116 @@ top-level license. Bug reports and PRs are welcome; please keep wrapper changes minimal — the C ABI is the stable layer, so new parameters should be added to the Fortran registry first. +## TOML config 実行方法 + +`python -m trlib ` でパラメータ設定 + 計算 + プロットを 1 +コマンドで実行できます。CLI 版の namelist 入力 (`./tr < tr_iter01.in`) +を Python 側に置き換えるための入口です。 + +サンプル config はリポジトリ内に同梱しています: + +- `python/trlib/samples/iter01.toml` — `test_run/inputs/tr_iter01.in` を TOML 化したもの +- `python/trlib/samples/tst2.toml` — `test_run/inputs/tr_tst2.in` を TOML 化したもの + +```bash +# 計算 + プロットを一気に実行 (libtrapi.so + matplotlib が必要) +python -m trlib python/trlib/samples/iter01.toml + +# 設定だけ確認 (ライブラリ未ビルドでも OK) +python -m trlib python/trlib/samples/iter01.toml --dry-run + +# NTMAX を上書き +python -m trlib python/trlib/samples/tst2.toml --ntmax 5 + +# プロットだけスキップ (matplotlib が無い環境向け) +python -m trlib python/trlib/samples/iter01.toml --no-plots +``` + +### TOML スキーマ + +```toml +[module] +name = "tr" +ntmax = 100 # NTMAX scalar への alias + +[scalars] +RR = 3.0 +NSMAX = 4 + +[arrays] +PN = [0.7, 0.315, 0.315, 0.035] # 1-origin リスト +# PA = { 2 = 1.0 } # sparse dict 形式 (PA(1) は default) + +[strings] +KNAMEQ = "eqdata.ITER" + +[[plots]] # 配列 of tables で複数プロット +variable = "RNT" +output = "file" # window | file | return +format = "png" +path = "./plots/rnt.png" +title = "温度密度プロファイル" +``` + +### Exit code + +| code | 意味 | +|---|---| +| 0 | 正常終了 | +| 1 | ライブラリ / 計算エラー | +| 2 | config エラー (未存在ファイル / TOML 構文エラー / matplotlib 不足) | + +## プロット + +`trlib.plot` は :mod:`matplotlib` をオプション依存とする可視化レイヤー +です。`trlib` 本体は matplotlib なしでも import できますが、`plot()` を +呼ぶと `ImportError` になります。 + +### 対話的に呼ぶ + +```bash +python -c "from trlib import Trlib; \ + tr = Trlib(); tr.run(0); tr.plot('RNT'); tr.close()" +``` + +### 利用可能な変数を確認する + +```python +from trlib import Trlib +print(Trlib.plot_available()) # ['AJ', 'ALI', 'BETA0', ..., 'WPT'] +``` + +`VARIABLE_INFO` 辞書 (in `trlib/plot.py`) に登録された変数のみが描画でき +ます。新しい変数を追加するときはこの dict にエントリを足してください。 + +### 出力モード + +| `output` | 用途 | 戻り値 | +|---|---|---| +| `"window"` (default) | CLI / インタラクティブ表示 | `None` | +| `"file"` | バッチ / CI で画像保存 | `pathlib.Path` | +| `"return"` | notebook embed / 後処理 | `matplotlib.figure.Figure` | + +### サンプル + +```python +from trlib import Trlib + +with Trlib() as tr: + tr.set_params(RR=8.5, RA=2.0, BB=5.3, NSMAX=2, DT=0.1) + tr.run(50) + tr.plot("RNT", output="file", path="./rnt.png") # PNG 保存 + fig = tr.plot("AJ", output="return") # Figure 受け取り +``` + +### matplotlib が無い環境での挙動 + +`pip install matplotlib` を行わずに `trlib.plot` を import / `.plot()` を +呼ぶと、明示的な `ImportError` が発生します。`__main__` では plot +セクションが無い限り matplotlib を import しないため、`--no-plots` を +付ければ matplotlib 未インストール環境でも `python -m trlib` を実行でき +ます。 + ## See also - `docs/superpowers/specs/2026-04-17-tr-library-design.md` — full diff --git a/python/trlib/__main__.py b/python/trlib/__main__.py new file mode 100644 index 00000000..4e2f8166 --- /dev/null +++ b/python/trlib/__main__.py @@ -0,0 +1,170 @@ +"""``python -m trlib `` — TOML-driven runner. + +Pipeline:: + + init → apply_config → run(ntmax) → get_state → run_plots → finalize + +Exit codes: + +* 0 — success +* 1 — library / calculation error +* 2 — config error (missing file, malformed TOML, unknown variable) + +Flags: + +* ``--dry-run`` — parse and validate the TOML, print a summary, but + skip every libtrapi.so call. Useful for CI and quick verification. +* ``--ntmax N`` — override ``[module].ntmax`` / ``[scalars].NTMAX``. +* ``--no-plots`` — apply scalars/arrays and run, but skip every plot + spec. Handy when matplotlib is not installed in the runner env. +* ``--help`` — argparse-generated usage. +""" +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import List, Optional, Sequence + +from .loader import apply_config, load_config, run_plots, run_sweep_plots + + +_EXIT_OK = 0 +_EXIT_LIB = 1 +_EXIT_CONFIG = 2 + + +def build_parser() -> argparse.ArgumentParser: + """Return the argparse parser used by :func:`main`.""" + parser = argparse.ArgumentParser( + prog="python -m trlib", + description=( + "Run a TASK/TR simulation defined by a TOML config file and " + "optionally render plots. See python/trlib/samples/ for " + "example configurations." + ), + ) + parser.add_argument( + "config", type=Path, + help="Path to a TOML config file.", + ) + parser.add_argument( + "--dry-run", action="store_true", + help="Parse the config and print a summary; skip every library call.", + ) + parser.add_argument( + "--ntmax", type=int, default=None, + help="Override the NTMAX scalar from the config.", + ) + parser.add_argument( + "--no-plots", action="store_true", + help="Skip [plot] / [[plots]] execution even if the config defines them.", + ) + return parser + + +def _summarise(cfg: dict) -> str: + """Return a short human-readable summary of a parsed config.""" + lines = [] + module = cfg.get("module") or {} + lines.append(f"module: {module.get('name', '')}") + ntmax = cfg.get("scalars", {}).get("NTMAX") + if ntmax is not None: + lines.append(f"NTMAX: {ntmax}") + lines.append(f"scalars: {len(cfg.get('scalars', {}))} keys") + lines.append(f"arrays: {len(cfg.get('arrays', {}))} keys") + lines.append(f"strings: {len(cfg.get('strings', {}))} keys") + lines.append(f"plots: {len(cfg.get('plots', []))} specs") + return "\n".join(lines) + + +def main(argv: Optional[Sequence[str]] = None) -> int: + """Program entry point. Returns a process exit code.""" + parser = build_parser() + args = parser.parse_args(argv) + + # --- Config load --------------------------------------------------- + if not args.config.exists(): + print(f"[trlib] config not found: {args.config}", file=sys.stderr) + return _EXIT_CONFIG + try: + cfg = load_config(args.config) + except Exception as exc: + print(f"[trlib] failed to parse {args.config}: {exc}", file=sys.stderr) + return _EXIT_CONFIG + + # --- CLI overrides ------------------------------------------------- + if args.ntmax is not None: + cfg.setdefault("scalars", {})["NTMAX"] = int(args.ntmax) + if args.no_plots: + cfg["plots"] = [] + + # --- Summary + dry run -------------------------------------------- + print(f"[trlib] loaded {args.config}") + print(_summarise(cfg)) + if args.dry_run: + print("[trlib] --dry-run: skipping library calls") + return _EXIT_OK + + ntmax = int(cfg.get("scalars", {}).get("NTMAX", 0)) + + # --- Live run ------------------------------------------------------ + # Lazy import: importing ``Trlib`` triggers _ffi.load_library() which + # probes libtrapi.so. Doing this only inside the live branch keeps + # --dry-run functional on systems where the .so is not built. + try: + from . import Trlib + except Exception as exc: # pragma: no cover - extremely unusual + print(f"[trlib] cannot import Trlib: {exc}", file=sys.stderr) + return _EXIT_LIB + + try: + # State-dependent plots run inside the with-block (need live tr). + with Trlib() as tr: + apply_config(tr, cfg) + tr.run(ntmax=ntmax) + state = tr.get_state() + print( + f"[trlib] run complete: NT={state.nt} NRMAX={state.nrmax} " + f"NSMAX={state.nsmax}" + ) + if cfg.get("plots"): + try: + results = run_plots(tr, cfg) + except ImportError as exc: + print(f"[trlib] plot backend unavailable: {exc}", + file=sys.stderr) + return _EXIT_CONFIG + except (KeyError, ValueError, TypeError) as exc: + # User-config errors (unknown variable, bad output=, etc.) + print(f"[trlib] plot config error: {exc}", + file=sys.stderr) + return _EXIT_CONFIG + for name, descriptor in results: + print(f"[trlib] plot {name} -> {descriptor}") + # Sweep plots run AFTER the outer Trlib closes — each sweep + # sample needs its own tr_init/tr_run/tr_finalize cycle and + # would collide with the still-live outer instance. + if cfg.get("plots"): + try: + sweep_results = run_sweep_plots(cfg) + except ImportError as exc: + print(f"[trlib] plot backend unavailable: {exc}", + file=sys.stderr) + return _EXIT_CONFIG + except (KeyError, ValueError, TypeError) as exc: + # User-config errors in sweep spec (missing range, unknown y, etc.) + print(f"[trlib] sweep config error: {exc}", + file=sys.stderr) + return _EXIT_CONFIG + for name, descriptor in sweep_results: + print(f"[trlib] plot {name} -> {descriptor}") + except Exception as exc: + print(f"[trlib] library error: {exc}", file=sys.stderr) + return _EXIT_LIB + + return _EXIT_OK + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/python/trlib/loader.py b/python/trlib/loader.py new file mode 100644 index 00000000..3e113793 --- /dev/null +++ b/python/trlib/loader.py @@ -0,0 +1,300 @@ +"""TOML configuration loader for :mod:`trlib`. + +The schema (mirrored from ``project_toml_sample_runner.md``):: + + [module] + name = "tr" + ntmax = 100 # overrides NTMAX scalar if present + + [scalars] + RR = 3.0 + NSMAX = 4 + + [arrays] + PN = [0.7, 0.315, 0.315, 0.035] # 1-origin list + # PN = {1 = 0.7, 2 = 0.315} # sparse dict form + + [strings] + KNAMEQ = "eqdata.ITER" + + [[plots]] + variable = "RNT" + output = "file" # window | file | return + format = "png" + path = "./plots/rnt.png" + title = "RNT profile" + + [plot] # singular form also accepted + variable = "RT" + output = "file" + path = "./plots/rt.png" + +The module exports three pure functions: + +* :func:`load_config` — parse TOML into a structured dict. +* :func:`apply_config` — replay ``scalars``/``arrays``/``strings`` onto + a live :class:`~trlib.Trlib` instance. +* :func:`run_plots` — execute all ``[plot]`` / ``[[plots]]`` specs. + +``load_config`` never touches libtrapi.so, so tests can exercise TOML +parsing without the shared library present. +""" +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Dict, IO, List, Mapping, Tuple, Union + +# tomllib is stdlib since Python 3.11. Fall back to the third-party +# ``tomli`` package for older runtimes (matches the request in the +# design spec). +try: + import tomllib as _toml # type: ignore[import-not-found] +except ImportError: # pragma: no cover - 3.10 fallback + try: + import tomli as _toml # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover + raise ImportError( + "trlib.loader requires Python 3.11+ (stdlib tomllib) or " + "`pip install tomli` on 3.10 and earlier." + ) from exc + + +PathLike = Union[str, "os.PathLike[str]"] +ConfigInput = Union[PathLike, IO[bytes], IO[str], bytes, str] + + +# ===================================================================== +# Parsing +# ===================================================================== +def _loads(raw: Union[bytes, str]) -> Dict[str, Any]: + """Wrap ``tomllib.loads`` so both ``bytes`` and ``str`` inputs work.""" + if isinstance(raw, bytes): + raw = raw.decode("utf-8") + return _toml.loads(raw) + + +def _read_stream(src: ConfigInput) -> Dict[str, Any]: + """Read TOML data from a path, stream, or raw str/bytes. + + Accepted shapes: + + * :class:`os.PathLike` (e.g. :class:`pathlib.Path`) — read as binary file. + * ``str`` / ``bytes`` — treated as a filesystem path if it looks like + one, else as inline TOML text. + * file-like with ``.read()`` returning ``str`` / ``bytes``. + """ + # PathLike (Path, etc.) — always a path. + if isinstance(src, os.PathLike): + with open(src, "rb") as fh: + return _toml.load(fh) + if isinstance(src, (bytes, str)): + if _looks_like_toml(src): + return _loads(src) + # Treat as path. + with open(src, "rb") as fh: + return _toml.load(fh) + if hasattr(src, "read"): + data = src.read() + return _loads(data) + raise TypeError(f"unsupported TOML source: {type(src).__name__}") + + +def _looks_like_toml(s: Union[bytes, str]) -> bool: + """Heuristic: strings containing newlines or ``=`` are inline TOML, + everything else is treated as a filesystem path.""" + text = s.decode("utf-8", errors="ignore") if isinstance(s, bytes) else s + # An actual path does not contain '\n' or a top-level '=' sign or + # square brackets. This is a small heuristic; callers that want + # precise behaviour can wrap inline text in ``io.BytesIO``. + return ("\n" in text) or ("=" in text and len(text) > 40) or ("[" in text and "]" in text and "=" in text) + + +def load_config(path_or_stream: ConfigInput) -> Dict[str, Any]: + """Parse TOML config and return a structured dict. + + The result has a stable shape regardless of which sections the + author supplied: + + ``{"module": {...}, "scalars": {...}, "arrays": {...}, + "strings": {...}, "plots": [...]}`` + + Unknown keys are preserved under ``"raw"`` so future loader versions + can read older configs. + """ + data = _read_stream(path_or_stream) + if not isinstance(data, Mapping): # pragma: no cover - tomllib invariant + raise ValueError("TOML root must be a table") + + module = dict(data.get("module", {})) if isinstance(data.get("module"), Mapping) else {} + scalars = dict(data.get("scalars", {})) if isinstance(data.get("scalars"), Mapping) else {} + arrays = dict(data.get("arrays", {})) if isinstance(data.get("arrays"), Mapping) else {} + strings = dict(data.get("strings", {})) if isinstance(data.get("strings"), Mapping) else {} + + # Collect [plot] and [[plots]] into a single list. Order: singular + # first, then array-of-tables, so authors can "pin" one default + # plot before the declarative batch. + plots: List[Dict[str, Any]] = [] + singular = data.get("plot") + if isinstance(singular, Mapping): + plots.append(dict(singular)) + multi = data.get("plots") + if isinstance(multi, list): + for entry in multi: + if isinstance(entry, Mapping): + plots.append(dict(entry)) + + # Module-level ntmax propagates into scalars if the author did not + # also set NTMAX explicitly. This matches the schema-spec precedent + # that `[module] ntmax = 100` is a convenience alias. + if "ntmax" in module and "NTMAX" not in scalars: + scalars["NTMAX"] = int(module["ntmax"]) + + return { + "module": module, + "scalars": scalars, + "arrays": arrays, + "strings": strings, + "plots": plots, + "raw": dict(data), + } + + +# ===================================================================== +# Application +# ===================================================================== +# ``Any`` instead of ``Trlib`` on the annotation keeps this module +# importable without pulling in libtrapi.so (Trlib.__init__ loads the +# shared library eagerly). +def apply_config(tr: Any, cfg: Mapping[str, Any]) -> None: + """Replay ``scalars`` / ``arrays`` / ``strings`` onto a live Trlib. + + Order of application: strings first (some strings, e.g. ``KNAMEQ``, + are read inside tr_init's follow-up calls), then scalars, then + arrays — matching the ``fixtures/tr_iter01_params.py::apply`` + precedent. + """ + for name, value in cfg.get("strings", {}).items(): + tr.set_param_str(name, str(value)) + for name, value in cfg.get("scalars", {}).items(): + tr.set_param(name, float(value)) + for name, arr in cfg.get("arrays", {}).items(): + _apply_array(tr, name, arr) + + +def _apply_array(tr: Any, name: str, arr: Any) -> None: + """Apply an ``arrays`` entry (list or ``{idx: val}`` dict).""" + if isinstance(arr, Mapping): + for k, v in arr.items(): + tr.set_param(f"{name}[{int(k)}]", float(v)) + elif isinstance(arr, (list, tuple)): + for i, v in enumerate(arr, start=1): + tr.set_param(f"{name}[{i}]", float(v)) + else: + raise ValueError( + f"[arrays] {name!r} must be a list or {{idx: val}} dict, " + f"got {type(arr).__name__}" + ) + + +# ===================================================================== +# Plot execution +# ===================================================================== +def run_plots(tr: Any, cfg: Mapping[str, Any]) -> List[Tuple[str, Any]]: + """Execute every state-dependent plot spec in ``cfg`` against a live + :class:`Trlib`. **Sweep plots are skipped** here — they require their + own isolated Trlib lifecycle (see :func:`run_sweep_plots`) because the + Fortran COMMON-block backend is single-instance per process. + + Returns a list of ``(varname, output_descriptor)`` tuples where + ``output_descriptor`` is the return value of :func:`trlib.plot.plot` + (:class:`Path` when ``output="file"``, :class:`Figure` when + ``output="return"``, :obj:`None` when ``output="window"``). + """ + from . import plot as _plot_mod # lazy: matplotlib is optional + + plots = cfg.get("plots", []) + results: List[Tuple[str, Any]] = [] + if not plots: + return results + + # Cache the state once per run so multiple plots don't re-run the + # simulation backend. + state = tr.get_state() + for spec in plots: + if not isinstance(spec, Mapping): + continue + # Sweep plots are deferred to run_sweep_plots() (own Trlib lifecycle). + # Skip them here regardless of whether they also carry a `variable` + # key, otherwise they would be processed twice. + if spec.get("kind") == "sweep": + continue + varname = spec.get("variable") + if not varname: + continue + kw = _plot_kwargs(spec) + descriptor = _plot_mod.plot(varname, state=state, **kw) + results.append((varname, descriptor)) + return results + + +def run_sweep_plots(cfg: Mapping[str, Any]) -> List[Tuple[str, Any]]: + """Execute sweep / compare plot specs that need their own Trlib + lifecycle. MUST be called AFTER any outer ``with Trlib()`` block has + exited — sweeps open fresh Trlib instances internally and would + collide with a still-live caller instance. + """ + from . import plot as _plot_mod # lazy + + plots = cfg.get("plots", []) + results: List[Tuple[str, Any]] = [] + if not plots: + return results + for spec in plots: + if not isinstance(spec, Mapping): + continue + kind = spec.get("kind") + if kind == "sweep": + results.append(_run_sweep_spec(spec, _plot_mod)) + return results + + +def _plot_kwargs(spec: Mapping[str, Any]) -> Dict[str, Any]: + """Translate a ``[[plots]]`` entry into :func:`trlib.plot.plot` kwargs.""" + kw: Dict[str, Any] = {} + for key in ("output", "format", "path", "title", "overlay"): + if key in spec: + kw[key] = spec[key] + return kw + + +def _run_sweep_spec(spec: Mapping[str, Any], plot_mod: Any) -> Tuple[str, Any]: + """Helper for ``kind = "sweep"`` entries. + + No `tr` argument: plot_mod.plot_sweep manages its own per-sample + Trlib lifecycle internally. + """ + param = spec["param"] + y = spec["y"] + rng = tuple(spec["range"]) + if len(rng) != 3: + raise ValueError( + f"sweep 'range' must be [start, stop, n_samples]; got {rng!r}" + ) + kw = _plot_kwargs(spec) + # Sweep-specific kwargs that _plot_kwargs (shared with regular plots) + # does not extract. Forward them only when the user supplied them. + if "ntmax" in spec: + kw["ntmax"] = spec["ntmax"] + if "base_params" in spec: + kw["base_params"] = spec["base_params"] + descriptor = plot_mod.plot_sweep(param, y, range=rng, **kw) + return (f"sweep:{param}->{y}", descriptor) + + +__all__ = [ + "load_config", + "apply_config", + "run_plots", + "run_sweep_plots", +] diff --git a/python/trlib/plot.py b/python/trlib/plot.py new file mode 100644 index 00000000..8eb29f3a --- /dev/null +++ b/python/trlib/plot.py @@ -0,0 +1,459 @@ +"""Visualization layer for :mod:`trlib`. + +Reference implementation of the "declarative plot(varname)" API described +in ``project_visualization_followup.md``. Other modules (ti / wr / wrx / +fp / eq) will copy this file and adjust their ``VARIABLE_INFO`` + +``_extract_series`` when they reach Phase L+viz. + +Design points: + +* :func:`plot_available` returns the list of variable names that have + registered metadata in :data:`VARIABLE_INFO`. Callers (human or LLM via + MCP) can check support before asking for a plot. +* :func:`plot` takes a :class:`~trlib.state.TrState` snapshot and draws a + 1D profile (or a species-stacked profile for 2D fields like ``RN`` / + ``RT``). +* ``output="window"`` shows an interactive figure (``plt.show``); + ``output="file"`` saves to disk and returns the :class:`pathlib.Path`; + ``output="return"`` returns the :class:`matplotlib.figure.Figure` so + notebook callers can embed it in their cell output. +* :func:`plot_sweep` runs an ad-hoc parameter sweep over a single + scalar and plots the resulting scalar quantity. It creates and + destroys a :class:`~trlib.Trlib` context internally. + +matplotlib is an **optional** dependency. If it is missing, importing +:mod:`trlib.plot` raises :class:`ImportError` with an actionable +message; ``trlib`` itself stays importable either way. +""" +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +try: + import matplotlib # type: ignore[import-not-found] + import matplotlib.pyplot as plt # type: ignore[import-not-found] + from matplotlib.figure import Figure # type: ignore[import-not-found] +except ImportError as _mpl_err: # pragma: no cover - exercised in env without mpl + raise ImportError( + "trlib.plot requires matplotlib. Install it via " + "`pip install matplotlib` or mark the feature as optional by " + "not importing trlib.plot." + ) from _mpl_err + +from .state import TrState + + +# ===================================================================== +# Variable metadata (labels, axis info, units). +# +# Hard-coded dict is a deliberate short-term choice — keeping it in +# Python is easier to edit and review than pulling from JSON/YAML. +# Future revisions may move this to a schema file shared with the MCP +# `describe_state_schema` tool. +# ===================================================================== + + +VARIABLE_INFO: Dict[str, Dict[str, Any]] = { + # --- 1D radial profiles (length nrmax) -------------------------- + "AJ": { + "label": "AJ (current density profile)", + "ylabel": "AJ [MA/m^2]", + "xaxis": "rg", + "kind": "profile_1d", + "dim": 1, + }, + "QP": { + "label": "QP (safety factor profile)", + "ylabel": "q", + "xaxis": "rg", + "kind": "profile_1d", + "dim": 1, + }, + # --- 2D species-stacked profiles (nrmax x nsmax) ---------------- + "RN": { + "label": "RN (density profile per species)", + "ylabel": "n [10^20 m^-3]", + "xaxis": "rg", + "kind": "profile_2d", + "dim": 2, + }, + "RT": { + "label": "RT (temperature profile per species)", + "ylabel": "T [keV]", + "xaxis": "rg", + "kind": "profile_2d", + "dim": 2, + }, + # --- Aliases matching the diagnostic-plot naming used in + # the Fortran GSAF screens (RNT = density-like, RWT = temperature- + # like). Keeping both aliases and canonical names means a TOML + # author can write either. ---------------------------------------- + "RNT": { + "label": "RNT (density profile, species-stacked)", + "ylabel": "n [10^20 m^-3]", + "xaxis": "rg", + "kind": "profile_2d", + "dim": 2, + "alias_of": "RN", + }, + "RWT": { + "label": "RWT (temperature profile, species-stacked)", + "ylabel": "T [keV]", + "xaxis": "rg", + "kind": "profile_2d", + "dim": 2, + "alias_of": "RT", + }, + # --- Scalar-over-sweep plots are handled via plot_sweep(); these + # live in the same registry so plot_available() can surface them. + "T": { + "label": "T (time)", + "ylabel": "t [s]", + "kind": "scalar", + "dim": 0, + }, + "WPT": { + "label": "WPT (stored energy)", + "ylabel": "W [MJ]", + "kind": "scalar", + "dim": 0, + }, + "AJT": { + "label": "AJT (total plasma current)", + "ylabel": "Ip [MA]", + "kind": "scalar", + "dim": 0, + }, + "Q0": { + "label": "Q0 (on-axis safety factor)", + "ylabel": "q(0)", + "kind": "scalar", + "dim": 0, + }, + "BETA0": {"label": "BETA0", "ylabel": "beta_0", "kind": "scalar", "dim": 0}, + "BETAP0": {"label": "BETAP0", "ylabel": "beta_p0", "kind": "scalar", "dim": 0}, + "BETAA": {"label": "BETAA", "ylabel": "beta_a", "kind": "scalar", "dim": 0}, + "BETAN": {"label": "BETAN", "ylabel": "beta_N", "kind": "scalar", "dim": 0}, + "TAUE1": {"label": "TAUE1 (confinement time)", "ylabel": "tau_E1 [s]", "kind": "scalar", "dim": 0}, + "TAUE2": {"label": "TAUE2 (confinement time)", "ylabel": "tau_E2 [s]", "kind": "scalar", "dim": 0}, + "ZEFF0": {"label": "ZEFF0 (effective charge)", "ylabel": "Z_eff(0)", "kind": "scalar", "dim": 0}, + "ALI": {"label": "ALI (internal inductance)", "ylabel": "l_i", "kind": "scalar", "dim": 0}, + "RQ1": {"label": "RQ1 (q=1 surface radius)", "ylabel": "r(q=1)", "kind": "scalar", "dim": 0}, +} + + +# ===================================================================== +# Public API +# ===================================================================== +def plot_available() -> List[str]: + """Return the canonical list of variables that have plot support. + + The list is sorted alphabetically so automated callers (MCP / + doctests) get a stable ordering. + """ + return sorted(VARIABLE_INFO.keys()) + + +def _resolve_varname(varname: str) -> Tuple[str, Dict[str, Any]]: + """Translate aliases and return (canonical_name, info_dict).""" + if varname not in VARIABLE_INFO: + raise KeyError( + f"variable {varname!r} has no plot support. " + f"Use trlib.plot.plot_available() to list supported variables." + ) + info = VARIABLE_INFO[varname] + canonical = info.get("alias_of", varname) + return canonical, info + + +def _profile_xaxis(n: int) -> List[float]: + """Return ``[0, 1/(n-1), ..., 1]`` as a normalised radial axis.""" + if n <= 0: + return [] + if n == 1: + return [0.0] + return [i / (n - 1) for i in range(n)] + + +def _in_notebook() -> bool: + """Return True when called from within an IPython / Jupyter kernel.""" + try: + from IPython import get_ipython # type: ignore[import-not-found] + except Exception: + return False + ip = get_ipython() + if ip is None: + return False + # ZMQInteractiveShell = notebook / qtconsole; TerminalInteractiveShell = plain ipython + return "ZMQInteractive" in type(ip).__name__ + + +def _draw_profile( + fig: "Figure", + varname: str, + info: Dict[str, Any], + state: TrState, + *, + overlay: bool, + title: Optional[str], + display_name: Optional[str] = None, +) -> None: + """Populate ``fig`` with a 1D or 2D profile plot. + + `varname` is the canonical TrState attribute (e.g. RN). `display_name` + is the user-requested name (e.g. RNT) which may differ when an alias + was resolved; it's used in legend labels so users see the name they + asked for. + """ + ax = fig.gca() + kind = info.get("kind") + label_name = display_name or varname + + if kind == "profile_1d": + data: List[float] = list(getattr(state, varname)) + x = _profile_xaxis(len(data)) + ax.plot(x, data, marker=".", linewidth=1.0, label=label_name) + elif kind == "profile_2d": + data2d: List[List[float]] = list(getattr(state, varname)) + x = _profile_xaxis(len(data2d)) + nsmax = state.nsmax + for j in range(nsmax): + series = [row[j] for row in data2d] + ax.plot(x, series, marker=".", linewidth=1.0, + label=f"{label_name}[*,{j + 1}]") + ax.legend(loc="best", fontsize="small") + elif kind == "scalar": + # A scalar cannot be plotted from a single snapshot. Instead we + # draw a trivial bar so the caller still gets a figure — this + # also exercises the same code path the sweep plot will use. + value = state.scalars.get(varname, 0.0) + ax.bar([label_name], [value]) + ax.set_ylabel(info.get("ylabel", label_name)) + else: # pragma: no cover - exhaustive guard + raise ValueError(f"unknown plot kind: {kind!r} for {varname}") + + # Profile plots default to "rg" x-axis; scalar bar charts have no + # meaningful x-axis category so leave it blank. + if kind in ("profile_1d", "profile_2d"): + ax.set_xlabel(info.get("xaxis", "rg")) + ax.set_ylabel(info.get("ylabel", label_name)) + else: + ax.set_xlabel(info.get("xaxis", "")) + ax.set_title(title or info.get("label", label_name)) + ax.grid(True, alpha=0.3) + # overlay=True means "caller will add more curves"; leave the axes alone. + if not overlay: + fig.tight_layout() + + +def plot( + varname: str, + *, + state: Optional[TrState] = None, + output: str = "window", + format: str = "png", + path: Optional[Union[str, Path]] = None, + title: Optional[str] = None, + overlay: bool = False, + figure: Optional["Figure"] = None, + **kwargs: Any, +) -> Union["Figure", Path, None]: + """Plot a variable from a :class:`TrState`. + + Parameters + ---------- + varname: + Key into :data:`VARIABLE_INFO`. Raises :class:`KeyError` if unknown. + state: + The :class:`TrState` to draw from. If None, callers must pass a + ``figure`` already populated with data — useful for the ``overlay`` + code path used by :class:`Trlib.plot`. + output: + ``"window"`` (default, interactive), ``"file"`` (save and return + Path), or ``"return"`` (return :class:`Figure`). + format: + When ``output="file"``: file extension (png / pdf / svg / jpg / eps). + path: + Output path (only for ``output="file"``). Derived from ``varname`` + and ``format`` when omitted. + title: + Figure title. Defaults to the ``"label"`` field in ``VARIABLE_INFO``. + overlay: + Reserved for future multi-variable overlays. Currently forwarded + to the drawing routine, which skips the final ``tight_layout`` + call so the caller can keep adding curves before finalising. + figure: + Optional existing :class:`Figure` to reuse (for overlays). + + Returns + ------- + * :class:`pathlib.Path` when ``output="file"`` + * :class:`matplotlib.figure.Figure` when ``output="return"`` + * :obj:`None` when ``output="window"`` + """ + canonical, info = _resolve_varname(varname) + if output not in ("window", "file", "return"): + raise ValueError( + f"output must be one of 'window' / 'file' / 'return', got {output!r}" + ) + + if state is None and figure is None: + raise ValueError( + "plot() needs either a TrState (state=...) or an existing " + "figure to render into. Call Trlib.plot() for the common case." + ) + + fig = figure if figure is not None else plt.figure() + if state is not None: + _draw_profile( + fig, canonical, info, state, + overlay=overlay, title=title, + display_name=varname, # preserve user-facing alias (e.g. "RNT") + ) + + if output == "return": + return fig + if output == "file": + # Default filename uses the user-requested varname (preserving + # aliases like "RNT") rather than the resolved canonical so the + # saved file matches what the caller asked for. + out_path = Path(path) if path else Path(f"{varname}.{format}") + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, format=format) + plt.close(fig) + return out_path + + # output == "window" + if _in_notebook(): + plt.show(block=False) + else: # pragma: no cover - cannot drive a real window in tests + plt.show(block=True) + return None + + +# ===================================================================== +# Sweep plotting +# ===================================================================== +def plot_sweep( + param: str, + y: str, + *, + sweep_range: Tuple[float, float, int] = None, + output: str = "window", + format: str = "png", + path: Optional[Union[str, Path]] = None, + title: Optional[str] = None, + ntmax: int = 0, + base_params: Optional[Dict[str, Any]] = None, + range: Tuple[float, float, int] = None, # backward-compat alias + **kwargs: Any, +) -> Union["Figure", Path, None]: + """Run a 1D scan over ``param`` and plot the scalar ``y`` on the y-axis. + + Parameters + ---------- + param: + Scalar parameter name passed to :meth:`Trlib.set_param`. + y: + Scalar key into :attr:`TrState.scalars`. + sweep_range (or ``range`` for backward compat): + ``(start, stop, n_samples)`` triple (inclusive endpoints). + output / format / path / title: + Same semantics as :func:`plot`. + ntmax: + Number of time-steps to advance per sample (default 0 to keep the + scan fast; pass a positive number to exercise the real transport). + base_params: + Optional dict of scalars applied to every sample point before + setting ``param``. + """ + from .trlib import Trlib # local import to avoid cycle + + # Accept the legacy keyword name `range` but prefer `sweep_range`. + rng = sweep_range if sweep_range is not None else range + if rng is None: + raise TypeError("plot_sweep requires sweep_range=(start, stop, n)") + + # Validate output up front (mirror plot() behaviour) so a TOML typo + # like `output = "bogus"` fails loudly instead of silently falling + # through to the window branch. + if output not in ("window", "file", "return"): + raise ValueError( + f"output must be one of 'window' / 'file' / 'return', got {output!r}" + ) + + # plot_sweep currently only handles scalar y (state.scalars[y]). Profile + # variables like RN/RNT/AJ exist in VARIABLE_INFO but are not scalars, + # so accepting them would silently produce a flat-zero plot. Reject + # explicitly with a helpful message. + if y not in VARIABLE_INFO: + raise KeyError( + f"y variable {y!r} has no plot support. See plot_available()." + ) + y_info = VARIABLE_INFO[y] + if y_info.get("kind") != "scalar": + scalars = sorted(k for k, v in VARIABLE_INFO.items() + if v.get("kind") == "scalar") + raise ValueError( + f"plot_sweep y={y!r} is a {y_info.get('kind','?')} variable; " + f"only scalars are supported. Choose one of: {scalars}" + ) + start, stop, n = rng + n = int(n) # TOML may pass float; range() requires int + if n < 2: + raise ValueError("plot_sweep needs at least 2 samples") + # Use the builtin via __builtins__ since `range` parameter shadows it. + import builtins as _builtins + xs = [float(start) + (float(stop) - float(start)) * i / (n - 1) + for i in _builtins.range(n)] + ys: List[float] = [] + + # Sweep MUST own its own Trlib lifecycle AND re-init per sample: + # - One outer Trlib + iteration would inherit cumulative state from the + # previous sample (tr.run advances from the previous end state). + # - Even with ntmax=0, scalars derived at tr_init time (WPT, AJT, Q0) + # would freeze at the first sample's value. + # So: open a fresh Trlib() for each sample. tr_finalize cleans up, + # then the next iteration's tr_init re-reads defaults + applies params. + # This means each sample is a completely independent run. + for x in xs: + with Trlib() as tr: + if base_params: + for k, v in base_params.items(): + tr.set_param(k, float(v)) + tr.set_param(param, float(x)) + tr.run(ntmax=ntmax) + state = tr.get_state() + ys.append(float(state.scalars.get(y, 0.0))) + + fig = plt.figure() + ax = fig.gca() + ax.plot(xs, ys, marker="o", linewidth=1.0) + ax.set_xlabel(param) + ax.set_ylabel(VARIABLE_INFO[y].get("ylabel", y)) + ax.set_title(title or f"{y} vs {param} ({n} samples)") + ax.grid(True, alpha=0.3) + fig.tight_layout() + + if output == "return": + return fig + if output == "file": + out_path = Path(path) if path else Path(f"sweep_{param}_{y}.{format}") + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, format=format) + plt.close(fig) + return out_path + if _in_notebook(): + plt.show(block=False) + else: # pragma: no cover + plt.show(block=True) + return None + + +__all__ = [ + "VARIABLE_INFO", + "plot", + "plot_available", + "plot_sweep", +] diff --git a/python/trlib/samples/iter01.toml b/python/trlib/samples/iter01.toml new file mode 100644 index 00000000..014c4448 --- /dev/null +++ b/python/trlib/samples/iter01.toml @@ -0,0 +1,73 @@ +# TASK/TR ITER01 reference configuration. +# +# Mirrors test_run/inputs/tr_iter01.in so the Python runner reproduces +# the same physics case without the namelist CLI. Run with: +# +# python -m trlib python/trlib/samples/iter01.toml +# +# Add --dry-run to inspect the parsed config without loading +# libtrapi.so. + +[module] +name = "tr" +ntmax = 100 + +[scalars] +MODELG = 3 +NSMAX = 4 +PROFN2 = 0.15 +MDLNF = 1 +PNBR0 = 0.0 +PNBRW = 1.0 +PNBENG = 1000.0 +PNBRTG = 6.2 +PICCD = 0.1 +PICR0 = 0.0 +PICRW = 0.2 +PICNPR = 5.0 +PECCD = 0.5 +PECR0 = 1.1 +PECRW = 0.05 +PECNPR = 5.0 +PLHCD = 1.0 +PLHR0 = 0.8 +PLHRW = 0.2 +PLHNPR = 2.0 +DT = 0.02 +NTSTEP = 100 +NTMAX = 100 +RIPS = 2.0 +RIPE = 7.0 + +[arrays] +PN = [0.7, 0.315, 0.315, 0.035] +PNS = [0.1, 0.045, 0.045, 0.005] +PT = [1.0, 1.0, 1.0, 1.0] +PTS = [0.1, 0.1, 0.1, 0.1] + +[strings] +KNAMEQ = "eqdata.ITER01" + +# --- visualization ----------------------------------------------------- +# Each [[plots]] entry is forwarded to trlib.plot.plot(variable, ...). + +[[plots]] +variable = "RNT" +output = "file" +format = "png" +path = "./plots/iter01_rnt.png" +title = "ITER01: density profile" + +[[plots]] +variable = "RWT" +output = "file" +format = "png" +path = "./plots/iter01_rwt.png" +title = "ITER01: temperature profile" + +[[plots]] +variable = "QP" +output = "file" +format = "png" +path = "./plots/iter01_qp.png" +title = "ITER01: safety factor" diff --git a/python/trlib/samples/tst2.toml b/python/trlib/samples/tst2.toml new file mode 100644 index 00000000..1923a6de --- /dev/null +++ b/python/trlib/samples/tst2.toml @@ -0,0 +1,57 @@ +# TASK/TR TST-2 reference configuration. +# +# Mirrors test_run/inputs/tr_tst2.in. Run with: +# +# python -m trlib python/trlib/samples/tst2.toml + +[module] +name = "tr" +ntmax = 10 + +[scalars] +MODELG = 3 +NSMAX = 2 +PROFN1 = 2.0 +PROFN2 = 1.0 +MDLIMP = 3 +PNC = 0.00001 +PLHCD = 0.0 +PLHR0 = 0.15 +PLHRW = 0.05 +PLHNPR = 4.0 +RIPS = 0.015 +RIPE = 0.015 +NTSTEP = 1 +NGTSTP = 1 +NGRSTP = 10 +DT = 1.0e-5 +NTMAX = 10 +PLHTOT = 0.0 + +[arrays] +# PA(2) = 1.0, PZ(2) = 1.0 — sparse dict form so PA(1)/PZ(1) inherit +# the tr_init defaults (matches fixtures/tr_tst2_params.py). +PA = { 2 = 1.0 } +PZ = { 2 = 1.0 } +PN = [0.010, 0.010] +PNS = [0.001, 0.001] +PT = [0.010, 0.0010] +PTS = [0.001, 0.0001] + +[strings] +KNAMEQ = "eqdata.TST-2" + +# --- visualization --------------------------------------------------- +[[plots]] +variable = "RNT" +output = "file" +format = "png" +path = "./plots/tst2_rnt.png" +title = "TST-2: density profile" + +[[plots]] +variable = "RWT" +output = "file" +format = "png" +path = "./plots/tst2_rwt.png" +title = "TST-2: temperature profile" diff --git a/python/trlib/tests/test_loader.py b/python/trlib/tests/test_loader.py new file mode 100644 index 00000000..e2536765 --- /dev/null +++ b/python/trlib/tests/test_loader.py @@ -0,0 +1,191 @@ +"""Unit tests for :mod:`trlib.loader`. + +Pure parsing / application tests — libtrapi.so is not needed because +we use a :class:`_FakeTrlib` double to record calls. +""" +from __future__ import annotations + +import io +import sys +import tempfile +import textwrap +import unittest +from pathlib import Path + +HERE = Path(__file__).resolve() +PYTHON_ROOT = HERE.parents[2] +if str(PYTHON_ROOT) not in sys.path: + sys.path.insert(0, str(PYTHON_ROOT)) + +from trlib import loader # noqa: E402 + + +class _FakeTrlib: + """Record set_param / set_param_str calls for assertion. + + The loader shouldn't care that this isn't a real :class:`Trlib`; it + only touches ``set_param`` / ``set_param_str`` / ``get_state``. + """ + + def __init__(self): + self.scalar_calls = [] + self.string_calls = [] + self.closed = False + + def set_param(self, name, value): + self.scalar_calls.append((name, value)) + + def set_param_str(self, name, value): + self.string_calls.append((name, value)) + + def get_state(self): + # Minimal stand-in so run_plots can iterate; the real plot call + # path is exercised in tests/test_plot.py. + raise NotImplementedError + + +_SAMPLE_TOML = textwrap.dedent(""" + [module] + name = "tr" + ntmax = 42 + + [scalars] + RR = 3.0 + NSMAX = 4 + + [arrays] + PN = [0.7, 0.315, 0.315, 0.035] + PA = { 2 = 1.0 } + + [strings] + KNAMEQ = "eqdata.ITER" + + [[plots]] + variable = "RNT" + output = "file" + format = "png" + path = "./plots/rnt.png" + + [[plots]] + variable = "AJ" + output = "return" +""").strip() + + +class TestLoadConfig(unittest.TestCase): + + def test_parse_inline_string(self): + cfg = loader.load_config(_SAMPLE_TOML) + self.assertEqual(cfg["module"]["name"], "tr") + # ntmax alias propagates into NTMAX scalar. + self.assertEqual(cfg["scalars"]["NTMAX"], 42) + self.assertEqual(cfg["scalars"]["RR"], 3.0) + self.assertEqual(cfg["arrays"]["PN"][0], 0.7) + self.assertEqual(cfg["strings"]["KNAMEQ"], "eqdata.ITER") + self.assertEqual(len(cfg["plots"]), 2) + self.assertEqual(cfg["plots"][0]["variable"], "RNT") + + def test_parse_bytes_stream(self): + buf = io.BytesIO(_SAMPLE_TOML.encode("utf-8")) + cfg = loader.load_config(buf) + self.assertEqual(cfg["module"]["name"], "tr") + + def test_parse_file_path(self): + with tempfile.NamedTemporaryFile( + "w", suffix=".toml", delete=False, encoding="utf-8" + ) as fh: + fh.write(_SAMPLE_TOML) + tmp = Path(fh.name) + try: + cfg = loader.load_config(tmp) + self.assertEqual(cfg["scalars"]["RR"], 3.0) + finally: + tmp.unlink() + + def test_ntmax_alias_does_not_overwrite_explicit_ntmax(self): + raw = textwrap.dedent(""" + [module] + ntmax = 10 + + [scalars] + NTMAX = 99 + """).strip() + cfg = loader.load_config(raw) + self.assertEqual(cfg["scalars"]["NTMAX"], 99) + + def test_missing_sections_default_empty(self): + cfg = loader.load_config("[module]\nname='tr'\n") + self.assertEqual(cfg["scalars"], {}) + self.assertEqual(cfg["arrays"], {}) + self.assertEqual(cfg["strings"], {}) + self.assertEqual(cfg["plots"], []) + + def test_singular_plot_section(self): + raw = textwrap.dedent(""" + [plot] + variable = "RT" + output = "return" + """).strip() + cfg = loader.load_config(raw) + self.assertEqual(len(cfg["plots"]), 1) + self.assertEqual(cfg["plots"][0]["variable"], "RT") + + +class TestApplyConfig(unittest.TestCase): + + def test_apply_scalars_arrays_strings(self): + cfg = loader.load_config(_SAMPLE_TOML) + fake = _FakeTrlib() + loader.apply_config(fake, cfg) + + # Strings applied first. + self.assertEqual(fake.string_calls, [("KNAMEQ", "eqdata.ITER")]) + + # Scalars expanded. + scalar_names = [n for n, _ in fake.scalar_calls] + self.assertIn("RR", scalar_names) + self.assertIn("NSMAX", scalar_names) + self.assertIn("NTMAX", scalar_names) + + # Array with 1-origin index. + self.assertIn(("PN[1]", 0.7), fake.scalar_calls) + self.assertIn(("PN[4]", 0.035), fake.scalar_calls) + # Sparse dict form: only PA[2] set. + pa_keys = [n for n, _ in fake.scalar_calls if n.startswith("PA[")] + self.assertEqual(pa_keys, ["PA[2]"]) + + def test_apply_rejects_unsupported_array_shape(self): + cfg = { + "strings": {}, + "scalars": {}, + "arrays": {"PN": "not a list"}, + "plots": [], + } + fake = _FakeTrlib() + with self.assertRaises(ValueError): + loader.apply_config(fake, cfg) + + +class TestSampleTomlFiles(unittest.TestCase): + """Make sure the shipped samples parse and include plot specs.""" + + SAMPLES = PYTHON_ROOT / "trlib" / "samples" + + def test_iter01_parses(self): + cfg = loader.load_config(self.SAMPLES / "iter01.toml") + self.assertEqual(cfg["module"]["name"], "tr") + self.assertEqual(cfg["scalars"]["NSMAX"], 4) + self.assertEqual(len(cfg["arrays"]["PN"]), 4) + self.assertTrue(any(p.get("variable") == "RNT" for p in cfg["plots"])) + + def test_tst2_parses(self): + cfg = loader.load_config(self.SAMPLES / "tst2.toml") + self.assertEqual(cfg["module"]["name"], "tr") + self.assertEqual(cfg["scalars"]["NSMAX"], 2) + # Sparse dict form preserved. + self.assertEqual(cfg["arrays"]["PA"], {"2": 1.0}) + self.assertTrue(cfg["plots"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/trlib/tests/test_main.py b/python/trlib/tests/test_main.py new file mode 100644 index 00000000..0141b2b4 --- /dev/null +++ b/python/trlib/tests/test_main.py @@ -0,0 +1,114 @@ +"""Smoke tests for ``python -m trlib``. + +The full library run is gated on libtrapi.so being present; the +``--dry-run`` and ``--help`` paths are always exercised. +""" +from __future__ import annotations + +import io +import sys +import unittest +from contextlib import redirect_stderr, redirect_stdout +from pathlib import Path + +HERE = Path(__file__).resolve() +PYTHON_ROOT = HERE.parents[2] +REPO = HERE.parents[3] +if str(PYTHON_ROOT) not in sys.path: + sys.path.insert(0, str(PYTHON_ROOT)) + +from trlib import __main__ as trlib_main # noqa: E402 + +DEFAULT_SO = REPO / "tr" / "libtrapi.so" +SAMPLE_ITER01 = PYTHON_ROOT / "trlib" / "samples" / "iter01.toml" +SAMPLE_TST2 = PYTHON_ROOT / "trlib" / "samples" / "tst2.toml" + + +class TestMainCli(unittest.TestCase): + + def test_help_exits_zero(self): + # argparse exits with SystemExit(0) on --help. + buf = io.StringIO() + with self.assertRaises(SystemExit) as cm, redirect_stdout(buf): + trlib_main.main(["--help"]) + self.assertEqual(cm.exception.code, 0) + out = buf.getvalue() + self.assertIn("python -m trlib", out) + self.assertIn("--dry-run", out) + + def test_missing_config_returns_2(self): + buf = io.StringIO() + with redirect_stderr(buf): + rc = trlib_main.main(["/no/such/path.toml"]) + self.assertEqual(rc, 2) + self.assertIn("config not found", buf.getvalue()) + + def test_dry_run_iter01_returns_0(self): + out = io.StringIO() + with redirect_stdout(out): + rc = trlib_main.main([str(SAMPLE_ITER01), "--dry-run"]) + self.assertEqual(rc, 0) + text = out.getvalue() + self.assertIn("module: tr", text) + self.assertIn("NTMAX", text) + self.assertIn("--dry-run", text) + + def test_dry_run_tst2_returns_0(self): + out = io.StringIO() + with redirect_stdout(out): + rc = trlib_main.main([str(SAMPLE_TST2), "--dry-run"]) + self.assertEqual(rc, 0) + + def test_ntmax_override_visible_in_summary(self): + out = io.StringIO() + with redirect_stdout(out): + rc = trlib_main.main( + [str(SAMPLE_TST2), "--dry-run", "--ntmax", "7"] + ) + self.assertEqual(rc, 0) + self.assertIn("NTMAX: 7", out.getvalue()) + + def test_no_plots_zeroes_plot_count(self): + out = io.StringIO() + with redirect_stdout(out): + rc = trlib_main.main( + [str(SAMPLE_ITER01), "--dry-run", "--no-plots"] + ) + self.assertEqual(rc, 0) + self.assertIn("plots: 0", out.getvalue()) + + def test_malformed_config_returns_2(self): + import tempfile + with tempfile.NamedTemporaryFile( + "w", suffix=".toml", delete=False, encoding="utf-8" + ) as fh: + fh.write("this is = not [valid] = toml\n") + tmp = Path(fh.name) + try: + buf = io.StringIO() + with redirect_stderr(buf): + rc = trlib_main.main([str(tmp)]) + self.assertEqual(rc, 2) + self.assertIn("failed to parse", buf.getvalue()) + finally: + tmp.unlink() + + +@unittest.skipUnless( + DEFAULT_SO.exists(), + f"libtrapi.so not built at {DEFAULT_SO}; run `make -C tr libtrapi.so`", +) +class TestMainLibrary(unittest.TestCase): + + def test_iter01_full_run_returns_0(self): + # Disable plots so the test doesn't depend on filesystem write + # permissions to ./plots/. The library path itself is the value + # we want to exercise here. + rc = trlib_main.main( + [str(SAMPLE_ITER01), "--no-plots", "--ntmax", "0"] + ) + self.assertEqual(rc, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/trlib/tests/test_plot.py b/python/trlib/tests/test_plot.py new file mode 100644 index 00000000..c5598627 --- /dev/null +++ b/python/trlib/tests/test_plot.py @@ -0,0 +1,144 @@ +"""Unit tests for :mod:`trlib.plot`. + +Skipped entirely if matplotlib is not installed. The non-matplotlib +bits (``plot_available``, ``VARIABLE_INFO`` structure) are re-checked +in :mod:`tests.test_loader` so the TOML parser's reliance on the +metadata is always exercised. +""" +from __future__ import annotations + +import sys +import tempfile +import unittest +from pathlib import Path + +HERE = Path(__file__).resolve() +PYTHON_ROOT = HERE.parents[2] # .../python +if str(PYTHON_ROOT) not in sys.path: + sys.path.insert(0, str(PYTHON_ROOT)) + + +try: + import matplotlib # noqa: F401 + HAS_MPL = True +except ImportError: # pragma: no cover - CI-only + HAS_MPL = False + + +@unittest.skipUnless(HAS_MPL, "matplotlib not installed") +class TestPlotModule(unittest.TestCase): + + def setUp(self): + # Force a non-interactive backend so the test never tries to + # open a window. + import matplotlib + matplotlib.use("Agg", force=True) + + from trlib import plot as plot_mod + from trlib.state import TrState + self.plot_mod = plot_mod + self.TrState = TrState + + def _fake_state(self, nr=4, ns=2): + """Build a small :class:`TrState` without touching libtrapi.so.""" + scalars = { + "T": 0.0, "WPT": 1.0, "AJT": 2.0, "Q0": 0.9, + "BETA0": 0.01, "BETAP0": 0.02, "BETAA": 0.03, "BETAN": 0.04, + "TAUE1": 0.5, "TAUE2": 0.6, "ZEFF0": 1.0, "ALI": 0.7, "RQ1": 1.1, + } + rn = [[0.5 + 0.1 * i + 0.01 * j for j in range(ns)] for i in range(nr)] + rt = [[1.0 - 0.1 * i + 0.01 * j for j in range(ns)] for i in range(nr)] + aj = [0.2 * i for i in range(nr)] + qp = [1.0 + 0.1 * i for i in range(nr)] + return self.TrState( + nt=0, nrmax=nr, nsmax=ns, scalars=scalars, + RN=rn, RT=rt, AJ=aj, QP=qp, + ) + + # --- metadata tests ---------------------------------------------- + def test_plot_available_sorted(self): + names = self.plot_mod.plot_available() + self.assertEqual(names, sorted(names)) + # Must include the core profile variables called out in the spec. + for key in ("RN", "RT", "AJ", "QP", "RNT", "RWT"): + self.assertIn(key, names) + + def test_variable_info_structure(self): + for name, info in self.plot_mod.VARIABLE_INFO.items(): + self.assertIn("label", info, f"{name} missing 'label'") + self.assertIn("kind", info, f"{name} missing 'kind'") + self.assertIn("dim", info, f"{name} missing 'dim'") + self.assertIn(info["kind"], ("profile_1d", "profile_2d", "scalar")) + self.assertIn(info["dim"], (0, 1, 2)) + + def test_plot_unknown_variable_raises(self): + with self.assertRaises(KeyError): + self.plot_mod.plot("NOT_A_REAL_VAR", state=self._fake_state()) + + def test_plot_invalid_output_raises(self): + with self.assertRaises(ValueError): + self.plot_mod.plot( + "AJ", state=self._fake_state(), output="bogus", + ) + + def test_plot_requires_state_or_figure(self): + with self.assertRaises(ValueError): + self.plot_mod.plot("AJ") + + # --- rendering tests --------------------------------------------- + def test_plot_return_mode_returns_figure(self): + fig = self.plot_mod.plot( + "AJ", state=self._fake_state(), output="return", + ) + from matplotlib.figure import Figure + self.assertIsInstance(fig, Figure) + + def test_plot_file_mode_writes_png(self): + with tempfile.TemporaryDirectory() as tmp: + target = Path(tmp) / "aj.png" + out = self.plot_mod.plot( + "AJ", state=self._fake_state(), + output="file", format="png", path=str(target), + ) + self.assertEqual(Path(out), target) + self.assertTrue(target.exists()) + self.assertGreater(target.stat().st_size, 0) + + def test_plot_profile_2d_species_stacked(self): + # RN is 2D (nrmax x nsmax). Make sure we don't blow up. + fig = self.plot_mod.plot( + "RN", state=self._fake_state(nr=5, ns=3), output="return", + ) + from matplotlib.figure import Figure + self.assertIsInstance(fig, Figure) + + def test_plot_alias_rnt_to_rn(self): + # RNT is an alias of RN. Both paths should succeed without the + # alias leaking into the figure title. + fig = self.plot_mod.plot( + "RNT", state=self._fake_state(), output="return", + ) + from matplotlib.figure import Figure + self.assertIsInstance(fig, Figure) + + def test_plot_scalar_draws_bar(self): + # Scalars can't be 1D-plotted from a snapshot; the module falls + # back to a single-bar chart so the caller still gets a figure. + fig = self.plot_mod.plot( + "WPT", state=self._fake_state(), output="return", + ) + from matplotlib.figure import Figure + self.assertIsInstance(fig, Figure) + + def test_plot_file_creates_parent_dir(self): + with tempfile.TemporaryDirectory() as tmp: + target = Path(tmp) / "nested" / "out" / "rn.png" + out = self.plot_mod.plot( + "RN", state=self._fake_state(), + output="file", format="png", path=str(target), + ) + self.assertTrue(Path(out).exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/trlib/trlib.py b/python/trlib/trlib.py index b802558a..b62b9798 100644 --- a/python/trlib/trlib.py +++ b/python/trlib/trlib.py @@ -150,5 +150,38 @@ def get_state(self) -> TrState: raise_for_ierr("tr_get_state", ierr) return TrState.from_c(c) + # --- visualization -------------------------------------------------- + def plot(self, varname: str, **kwargs): + """Plot a state variable from the current simulation snapshot. + + Thin wrapper around :func:`trlib.plot.plot` that captures the + current :class:`TrState` via :meth:`get_state` and forwards + keyword arguments verbatim (``output``, ``format``, ``path``, + ``title``, ``overlay``, etc.). See :mod:`trlib.plot` for the + full argument reference. + + ``matplotlib`` is an optional dependency; importing the plot + module lazily lets ``trlib`` stay usable when matplotlib is + absent. If you call this method without matplotlib installed + the :class:`ImportError` from :mod:`trlib.plot` will propagate. + """ + from . import plot as _plot_mod + state = kwargs.pop("state", None) + if state is None: + state = self.get_state() + return _plot_mod.plot(varname, state=state, **kwargs) + + @staticmethod + def plot_available(): + """Return the variable names that :meth:`plot` can draw. + + Thin wrapper around :func:`trlib.plot.plot_available`. Kept as a + static method so callers can probe the plot backend without + instantiating ``Trlib()`` (and therefore without loading + libtrapi.so). + """ + from . import plot as _plot_mod + return _plot_mod.plot_available() + __all__ = ["Trlib"]