diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..6d5349e --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +--- +# Set update schedule for GitHub Actions + +version: 2 +updates: + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every month + interval: "monthly" diff --git a/README.md b/README.md index 09180e2..26346d2 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,59 @@ Stereo analysis methods implemented in Eventdisplay provide direction / energies Output is a single ROOT tree called `StereoAnalysis` with the same number of events as the input tree. +### Training Stereo Reconstruction Models + +The stereo regression training pipeline uses multi-target XGBoost to predict residuals (deviations from baseline reconstructions): + +**Targets:** `[Xoff_residual, Yoff_residual, E_residual]` (residuals on direction and energy as reconstruction by the BDT stereo reconstruction method) + +**Key techniques:** + +- **Target standardization:** Targets are mean-centered and scaled to unit variance during training +- **Energy-bin weighting:** Events are weighted inversely by energy bin density; bins with fewer than 10 events are excluded from training to prevent overfitting on low-statistics regions +- **Multiplicity weighting:** Higher-multiplicity events (more telescopes) receive higher sample weights to prioritize high-confidence reconstructions +- **Per-target SHAP importance:** Feature importance values computed during training for each target and cached for later analysis + +**Training command:** + +```bash +eventdisplay-ml-train-xgb-stereo \ + --input_file_list train_files.txt \ + --model_prefix models/stereo_model \ + --max_events 100000 \ + --train_test_fraction 0.5 \ + --max_cores 8 +``` + +**Output:** Joblib model file containing: + +- XGBoost trained model object +- Target standardization scalers (mean/std) +- Feature list and SHAP importance rankings +- Training metadata (random state, hyperparameters) + +### Applying Stereo Reconstruction Models + +The apply pipeline loads trained models and makes predictions: + +**Key safeguards:** + +- Invalid energy values (≤0 or NaN) produce NaN outputs but preserve all input event rows +- Missing standardization parameters raise ValueError (prevents silent data corruption) +- Output row count always equals input row count + +**Apply command:** + +```bash +eventdisplay-ml-apply-xgb-stereo \ + --input_file_list apply_files.txt \ + --output_file_list output_files.txt \ + --model_prefix models/stereo_model +``` + + +**Output:** ROOT files with `StereoAnalysis` tree containing reconstructed Xoff, Yoff, and log10(E). + ## Gamma/hadron separation using XGBoost Gamma/hadron separation is performed using XGB Boost classification trees. Features are image parameters and stereo reconstruction parameters provided by Eventdisplay. @@ -27,6 +80,223 @@ The zenith angle dependence is accounted for by including the zenith angle as a Output is a single ROOT tree called `Classification` with the same number of events as the input tree. It contains the classification prediction (`Gamma_Prediction`) and boolean flags (e.g. `Is_Gamma_75` for 75% signal efficiency cut). +## Diagnostic Tools + +The committed regression diagnostics in this branch are: + +### SHAP feature-importance summary + + Tests: Feature importance + +- Load per-target SHAP importances cached in the trained model file +- Create one top-20 feature plot per residual target (`Xoff_residual`, `Yoff_residual`, `E_residual`) + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated PNGs + +Run: + +```bash + eventdisplay-ml-diagnostic-shap-summary \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Outputs: + +- `diagnostics/shap_importance_Xoff_residual.png` +- `diagnostics/shap_importance_Yoff_residual.png` +- `diagnostics/shap_importance_E_residual.png` + +### Permutation importance + +- Rebuild the held-out test split from the model metadata and original input files +- Shuffle one feature at a time and measure the relative RMSE increase per residual target +- Validate predictive dependence on features rather than cached model attribution + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots +- `--top_n`: number of top features to include in the plot (optional) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-permutation-importance \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ \ + --top_n 20 +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-permutation-importance \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --output_dir diagnostics/ +``` + +Output: + +- `diagnostics/permutation_importance.png` + +Notes: + +- This diagnostic is slower than the SHAP summary because it rebuilds the processed test split. +- It is the better choice when you want to measure actual performance sensitivity to each feature. + +### Generalization gap + +- Read the cached train/test RMSE summary written during training +- Compare final train and test RMSE for each residual target +- Quantify the overfitting gap after training is complete + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-generalization-gap \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-generalization-gap \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --output_dir diagnostics/ +``` + +Output: + +- `diagnostics/generalization_gap.png` + +Notes: + +- This diagnostic measures final overfitting by comparing train and test residual RMSE. +- Older model files without cached metrics fall back to rebuilding the original train/test split. +- Unlike `plot_training_evaluation.py`, it summarizes final RMSE, not the per-iteration XGBoost training history. + +### Partial Dependence Plots + +- Visualize how each feature influences model predictions +- Prove the model captures physics by checking that multiplicity reduces corrections and baselines show smooth relationships + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots (optional; default: `diagnostics`) +- `--features`: space-separated list of features to plot (optional; default: `DispNImages Xoff_weighted_bdt Yoff_weighted_bdt ErecS`) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-partial-dependence \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ \ + --features DispNImages Xoff_weighted_bdt ErecS +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-partial-dependence \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --features Xoff_weighted_bdt Yoff_weighted_bdt +``` + +Output: + +- `diagnostics/partial_dependence.png` (grid of feature × target subplots) + +Notes: + +- PDP displays predicted residual output as a function of a single feature while holding others constant +- Multiplicity effect: high-multiplicity events should show smaller corrections (negative slope) +- Baseline stability: baseline features (e.g., `weighted_bdt`) should show smooth, linear relationships +- This diagnostic rebuilds the held-out test split and is slower than SHAP summary + +### Residual Normality Diagnostics + +- Validate that model residuals follow a normal distribution +- Detect outlier events and check for systematic biases in reconstruction errors + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots (optional; default: `diagnostics`) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-residual-normality \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-residual-normality \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt +``` + +Output: + +- Residual normality statistics printed to console: + - Mean and standard deviation per target + - Kolmogorov-Smirnov test p-value (normality test) + - Anderson-Darling test statistic and critical value + - Skewness and kurtosis + - Q-Q plot R² value + - Number of outliers (>3σ) per target +- `diagnostics/residual_diagnostics.png` (single 2xN grid; generated on cache miss when reconstruction is required) + +Notes: + +- Residual normality stats are cached during training and loaded from the model file for fast retrieval +- Diagnostic plots (histograms, Q-Q plots) are only generated when the split must be reconstructed +- Invalid KS test or Anderson-Darling results (NaN/inf) are reported as special values +- Outlier counts help identify events with unusually large reconstruction errors + +### Training-evaluation curves + +- Plot XGBoost training vs validation metric curves +- Useful for checking convergence and overfitting behavior + +Required inputs: + +- `--model_file`: trained model `.joblib` containing an XGBoost model +- `--output_file`: output image path (optional; if omitted, plot is shown interactively) + +Run: + +```bash +eventdisplay-ml-plot-training-evaluation \ + --model_file models/stereo_model.joblib \ + --output_file diagnostics/training_curves.png +``` + +Output: + +- Figure with one panel per tracked metric (for example `rmse`), showing training and test curves. + ## Generative AI disclosure Generative AI tools (including Claude, ChatGPT, and Gemini) were used to assist with code development, debugging, and documentation drafting. All AI-assisted outputs were reviewed, validated, and, where necessary, modified by the authors to ensure accuracy and reliability. diff --git a/docs/changes/53.feature.md b/docs/changes/53.feature.md index 2e545ea..cc91512 100644 --- a/docs/changes/53.feature.md +++ b/docs/changes/53.feature.md @@ -1,14 +1,37 @@ -Fix critical bugs in stereo regression pipeline: +## Stereo Regression: Training on Residuals with Standardization and Energy Weighting -- **Fixed double log10 application**: E_residual was being computed with log10(ErecS) that had already been log10'd. Now ErecS/Erec remain in linear space during training/apply; log10 applied explicitly when needed. -- **Fixed energy bin weighting**: Bins with fewer than 10 events now correctly get zero weight instead of being clamped; weight sorting preserves bin order. -- **Fixed standardization inversion**: Added proper loading and validation of target_mean/target_std scalers in stereo apply pipeline to prevent KeyError crashes. -- **Fixed ErecS validation**: Safe log10 computation during apply avoids RuntimeWarning for invalid values; all output rows preserved even with invalid energy. -- **Fixed evaluation metrics**: ErecS in evaluation now properly converted to log10 space for energy resolution comparison. -- **Fixed FutureWarning**: Series positional indexing converted to numpy arrays for future pandas compatibility. +### Architectural Change -New features and improvements: +- **Training targets changed from absolute to residual values**: Models now predict residuals (deviations from baseline reconstructions) rather than absolute directions/energies. This allows XGBoost to learn corrections to existing Eventdisplay reconstructions (DispBDT, intersection method) and leverage their baseline accuracy as a starting point. -- **Comprehensive test coverage**: Added `test_regression_apply.py` with full unit test suite covering standardization inversion, residual computation, ErecS handling, and final prediction reconstruction. -- **Improved error messages**: Clear, actionable error messages when standardization parameters are missing or mismatched in apply pipeline. -- **Data preservation guarantee**: Stereo apply pipeline now preserves all input rows even when encountering invalid energy values, ensuring output count equals input count. +### Critical Bug Fixes + +- **Fixed double log10 application**: Energy residuals computed in linear space; log10 applied explicitly during evaluation +- **Fixed standardization inversion**: Apply pipeline now loads and validates target_mean/target_std scalers (prevents KeyError) +- **Fixed energy-bin weighting**: Bins with <10 events get zero weight; correct inverse weighting for balanced training +- **Fixed ErecS validation**: Safe log10 computation during apply; all input rows preserved in output +- **Fixed evaluation metrics**: Energy resolution compared in log10 space with proper baseline alignment +- **Fixed FutureWarning**: Series positional indexing converted to numpy arrays for pandas compatibility + +### New Features + +- **Target standardization in training**: Residuals standardized to mean=0, std=1 during training to enable multi-target learning with balanced learning signals (direction and energy equally weighted) +- **Energy-bin weighted training**: Events weighted inversely by energy bin density; bins with <10 events excluded to prevent overfitting on low-statistics regions +- **Per-target SHAP importance caching**: Feature importances computed once during training for each target (Xoff_residual, Yoff_residual, E_residual), cached for diagnostic tools +- **Diagnostic scripts**: + - `diagnostic_shap_summary.py`: Top-20 feature importance plots per residual target + - `plot_training_evaluation.py`: Energy resolution and residual distribution visualization +- **Comprehensive test suites**: 20 new tests covering residual computation, standardization, energy weighting, apply inference +- **Robust error handling**: Clear messages for missing scalers; guaranteed row-count preservation in apply pipeline + +### Enhanced Diagnostic Pipeline + +- **Generalization-gap metrics cached during training**: Train/test RMSE, gap %, and generalization ratio computed and cached in the model artifact, enabling fast overfitting assessment without recomputation +- **Residual normality statistics cached during training**: Normality tests (Kolmogorov-Smirnov, Anderson-Darling), distribution shape metrics (skewness, kurtosis, Q-Q R²), and outlier counts computed once during training and cached for fast retrieval +- **Diagnostic reconstruction from model metadata**: All regression diagnostics (generalization-gap, partial-dependence, residual-normality) now reconstruct the held-out test split from stored model metadata + input file list, enabling reproducibility and offline analysis without CSV exports +- **Cache-first diagnostic workflows**: Diagnostic scripts load cached metrics first (fast) with graceful fallback to reconstruction if cache unavailable (backward compatible with older models) +- **CLI entry points for all diagnostics**: + - `eventdisplay-ml-diagnostic-generalization-gap`: Quantify overfitting via train/test RMSE comparison + - `eventdisplay-ml-diagnostic-partial-dependence`: Validate model captures physics via partial dependence curves + - `eventdisplay-ml-diagnostic-residual-normality`: Validate residual normality and detect outliers +- **Fixed sklearn FutureWarning**: Partial dependence plots convert feature data to float64 to avoid integer dtype warnings in newer scikit-learn versions diff --git a/pyproject.toml b/pyproject.toml index 8927dc5..b4f7445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,11 @@ urls."documentation" = "https://github.com/Eventdisplay/Eventdisplay-ML" urls."repository" = "https://github.com/Eventdisplay/Eventdisplay-ML" scripts.eventdisplay-ml-apply-xgb-classify = "eventdisplay_ml.scripts.apply_xgb_classify:main" scripts.eventdisplay-ml-apply-xgb-stereo = "eventdisplay_ml.scripts.apply_xgb_stereo:main" +scripts.eventdisplay-ml-diagnostic-generalization-gap = "eventdisplay_ml.scripts.diagnostic_generalization_gap:main" +scripts.eventdisplay-ml-diagnostic-partial-dependence = "eventdisplay_ml.scripts.diagnostic_partial_dependence:main" +scripts.eventdisplay-ml-diagnostic-permutation-importance = "eventdisplay_ml.scripts.diagnostic_permutation_importance:main" +scripts.eventdisplay-ml-diagnostic-residual-normality = "eventdisplay_ml.scripts.diagnostic_residual_normality:main" +scripts.eventdisplay-ml-diagnostic-shap-summary = "eventdisplay_ml.scripts.diagnostic_shap_summary:main" scripts.eventdisplay-ml-plot-classification-performance-metrics = "eventdisplay_ml.scripts.plot_classification_performance_metrics:main" scripts.eventdisplay-ml-plot-classification-gamma-efficiency = "eventdisplay_ml.scripts.plot_classification_gamma_efficiency:main" scripts.eventdisplay-ml-plot-training-evaluation = "eventdisplay_ml.scripts.plot_training_evaluation:main" diff --git a/src/eventdisplay_ml/diagnostic_utils.py b/src/eventdisplay_ml/diagnostic_utils.py new file mode 100644 index 0000000..aae271e --- /dev/null +++ b/src/eventdisplay_ml/diagnostic_utils.py @@ -0,0 +1,401 @@ +"""Utilities for inspecting cached diagnostic data in model joblib files. + +The current stereo regression cache stores per-target SHAP importances under +``models[]["shap_importance"]``. This module supports that layout while +remaining compatible with older cache keys where possible. +""" + +import logging + +import joblib +import numpy as np +import pandas as pd +from scipy import stats +from sklearn.metrics import mean_squared_error +from sklearn.model_selection import train_test_split + +from eventdisplay_ml.data_processing import load_training_data + +_logger = logging.getLogger(__name__) + + +def _load_model_cfg(model_file): + """Load full model dictionary and the first model configuration entry.""" + model_dict = joblib.load(model_file) + models = model_dict.get("models", {}) + model_cfg = next(iter(models.values())) if models else None + return model_dict, model_cfg + + +def load_stereo_regression_split(model_file, input_file_list=None): + """Load a stereo regression model and reconstruct its train/test split. + + Parameters + ---------- + model_file : str + Path to trained model joblib file. + input_file_list : str or None, optional + Optional override for the input file list stored in the model metadata. + + Returns + ------- + tuple + Trained model, reconstructed x_train, y_train, x_test, y_test, + feature names, target names, and full model metadata. + """ + _logger.info(f"Loading model from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + raise ValueError(f"No models found in model file: {model_file}") + + model = model_cfg.get("model") + if model is None: + raise ValueError(f"No trained model object found in model file: {model_file}") + + file_list = input_file_list or model_dict.get("input_file_list") + if not file_list: + raise ValueError( + "No input file list available. Provide --input_file_list or retrain with " + "input_file_list stored in the model metadata." + ) + + _logger.info(f"Rebuilding training data from input file list: {file_list}") + df = load_training_data(model_dict, file_list, "stereo_analysis") + + features = model_cfg.get("features") or model_dict.get("features", []) + targets = model_dict.get("targets", ["Xoff_residual", "Yoff_residual", "E_residual"]) + if not features: + raise ValueError(f"No feature list found in model file: {model_file}") + + x_data = df[features] + y_data = df[targets] + + _logger.info("Reconstructing train/test split from model metadata") + x_train, x_test, y_train, y_test = train_test_split( + x_data, + y_data, + train_size=model_dict.get("train_test_fraction", 0.5), + random_state=model_dict.get("random_state", None), + ) + + _logger.info( + "Reconstructed split with %d training and %d test events", + len(x_train), + len(x_test), + ) + return model, x_train, y_train, x_test, y_test, features, targets, model_dict + + +def predict_unscaled_residuals(model, x_data, model_dict, target_names): + """Predict residual targets and inverse-standardize them to original scale.""" + preds_scaled = model.predict(x_data) + + target_mean_cfg = model_dict.get("target_mean") + target_std_cfg = model_dict.get("target_std") + if not target_mean_cfg or not target_std_cfg: + raise ValueError( + "Missing target standardization parameters (target_mean/target_std) in model " + "file. This diagnostic requires a residual-trained stereo model." + ) + + target_mean = np.array([target_mean_cfg[target] for target in target_names], dtype=np.float64) + target_std = np.array([target_std_cfg[target] for target in target_names], dtype=np.float64) + + preds = preds_scaled * target_std + target_mean + return pd.DataFrame(preds, columns=target_names, index=x_data.index) + + +def compute_generalization_metrics(y_train, y_train_pred, y_test, y_test_pred, target_names): + """Compute train/test RMSE and relative generalization gap per target.""" + metrics = {} + + for target_name in target_names: + rmse_train = np.sqrt(mean_squared_error(y_train[target_name], y_train_pred[target_name])) + rmse_test = np.sqrt(mean_squared_error(y_test[target_name], y_test_pred[target_name])) + + if rmse_train == 0: + gap_pct = 0.0 if rmse_test == 0 else np.inf + else: + gap_pct = (rmse_test - rmse_train) / rmse_train * 100 + + # Unitless generalization ratio: >1 means worse on test than train. + if rmse_train == 0: + gen_ratio = 1.0 if rmse_test == 0 else np.inf + else: + gen_ratio = rmse_test / rmse_train + + metrics[target_name] = { + "rmse_train": float(rmse_train), + "rmse_test": float(rmse_test), + "gap_pct": float(gap_pct), + "gen_ratio": float(gen_ratio), + } + + return metrics + + +def load_cached_generalization_metrics(model_file): + """Load cached train/test RMSE summary from a model file if available.""" + _logger.info(f"Loading cached generalization metrics from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + metrics = model_cfg.get("generalization_metrics") + if not isinstance(metrics, dict) or not metrics: + _logger.warning("No cached generalization metrics found in model file") + return model_dict, None + + _logger.info("Loaded cached generalization metrics for %d targets", len(metrics)) + return model_dict, metrics + + +def compute_residual_normality_stats(y_test, y_test_pred, target_names): + """Compute Gaussian fit parameters and normality tests for residuals.""" + stats_dict = {} + + for target_name in target_names: + residuals = y_test[target_name].values - y_test_pred[target_name].values + residuals_clean = residuals[~np.isnan(residuals)] + + if len(residuals_clean) == 0: + _logger.warning(f"Skipping {target_name}: no finite residuals") + continue + + # Gaussian parameters + mean = float(np.mean(residuals_clean)) + std = float(np.std(residuals_clean)) + + # Normality tests + _, p_ks = stats.kstest(residuals_clean, "norm", args=(mean, std)) + ad_result = stats.anderson(residuals_clean, dist="norm", method="interpolate") + ad_stat = float(ad_result.statistic) + # With method='interpolate', ad_result is SignificanceResult with pvalue + # (not AndersonResult, which has critical_values) + ad_crit_5 = float(ad_result.pvalue if hasattr(ad_result, "pvalue") else np.nan) + + # Skewness and kurtosis + skewness = float(stats.skew(residuals_clean)) + kurtosis = float(stats.kurtosis(residuals_clean)) + + # Quantile-Quantile test (visual) + _, (_, _, qq_r) = stats.probplot(residuals_clean, dist="norm") + qq_r2 = float(qq_r**2) + + # Outlier count + n_outliers = int(np.sum(np.abs(residuals_clean) > 3 * std)) + + stats_dict[target_name] = { + "mean": mean, + "std": std, + "p_ks": float(p_ks), + "ad_stat": ad_stat, + "ad_pvalue": ad_crit_5, + "skewness": skewness, + "kurtosis": kurtosis, + "qq_r2": qq_r2, + "n_outliers": n_outliers, + "n_samples": len(residuals_clean), + } + + return stats_dict + + +def load_cached_residual_normality_stats(model_file): + """Load cached residual normality statistics from a model file if available.""" + _logger.info(f"Loading cached residual normality stats from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + normality_stats = model_cfg.get("residual_normality_stats") + if not isinstance(normality_stats, dict) or not normality_stats: + _logger.warning("No cached residual normality statistics found in model file") + return model_dict, None + + _logger.info("Loaded cached residual normality stats for %d targets", len(normality_stats)) + return model_dict, normality_stats + + +def load_model_and_importance(model_file, target_name=None): + """ + Load model dict and precomputed feature importance. + + Parameters + ---------- + model_file : str + Path to joblib model file. + + Returns + ------- + dict + Full model dictionary with model metadata. + dict or None + Precomputed feature importances {feature_name: importance_value} for + the selected target. + """ + _logger.info(f"Loading model from {model_file}") + + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + shap_importance = model_cfg.get("shap_importance") + features = model_cfg.get("features") or model_dict.get("features", []) + + if shap_importance is None: + # Backward compatibility for legacy key if present. + shap_importance = model_cfg.get("feature_importances") + + if shap_importance is None or not features: + _logger.warning("No cached feature importances found in model file") + return model_dict, None + + if isinstance(shap_importance, dict): + if not shap_importance: + _logger.warning("Cached SHAP importance dictionary is empty") + return model_dict, None + + selected_target = target_name or next(iter(shap_importance)) + importances = shap_importance.get(selected_target) + if importances is None: + _logger.warning( + "Target %r not found in cached SHAP importance; available targets: %s", + selected_target, + list(shap_importance.keys()), + ) + return model_dict, None + + importance_dict = dict(zip(features, importances, strict=False)) + _logger.info( + "Loaded cached SHAP importances for target %s (%d features)", + selected_target, + len(importance_dict), + ) + return model_dict, importance_dict + + importance_dict = dict(zip(features, shap_importance, strict=False)) + _logger.info("Loaded cached feature importances (%d features)", len(importance_dict)) + return model_dict, importance_dict + + +def get_cached_shap_explainer(model_file): + """ + Retrieve cached SHAP explainer if available. + + Parameters + ---------- + model_file : str + Path to joblib model file. + + Returns + ------- + shap.TreeExplainer or None + Cached SHAP explainer, or None if not available. + """ + _logger.info(f"Loading SHAP explainer from {model_file}") + _, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + return None + + explainer = model_cfg.get("shap_explainer") + if explainer is not None: + _logger.info("Successfully loaded cached SHAP explainer") + return explainer + + _logger.warning("No cached SHAP explainer found. Will compute on-the-fly.") + return None + + +def importance_dataframe(model_file, top_n=25, target_name=None): + """ + Get feature importance as a sorted pandas DataFrame. + + Parameters + ---------- + model_file : str + Path to joblib model file. + top_n : int, optional + Return only top N features by importance (default 25). + + Returns + ------- + pd.DataFrame + DataFrame with columns ["Feature", "Importance"] sorted by importance. + """ + _, importance_dict = load_model_and_importance(model_file, target_name=target_name) + + if importance_dict is None: + _logger.error("Cannot create importance dataframe: no importances found") + return pd.DataFrame() + + df = pd.DataFrame( + { + "Feature": list(importance_dict.keys()), + "Importance": list(importance_dict.values()), + } + ).sort_values("Importance", ascending=False) + + if top_n: + df = df.head(top_n) + + return df + + +def validate_cached_data(model_file): + """ + Check what data is cached in the model file. + + Parameters + ---------- + model_file : str + Path to joblib model file. + + Returns + ------- + dict + Summary of cached data availability. + """ + model_dict, model_cfg = _load_model_cfg(model_file) + model_cfg = model_cfg or {} + + shap_importance = model_cfg.get("shap_importance") + has_shap_importance = shap_importance is not None + shap_targets = list(shap_importance.keys()) if isinstance(shap_importance, dict) else [] + + summary = { + "has_model": "model" in model_cfg, + "has_features": "features" in model_cfg, + "has_shap_importance": has_shap_importance, + "has_generalization_metrics": "generalization_metrics" in model_cfg, + "has_residual_normality_stats": "residual_normality_stats" in model_cfg, + "has_feature_importances": "feature_importances" in model_cfg, # legacy key + "has_shap_explainer": "shap_explainer" in model_cfg, + "has_target_mean": "target_mean" in model_dict, + "has_target_std": "target_std" in model_dict, + "n_features": len(model_cfg.get("features", [])), + "n_targets_with_shap": len(shap_targets), + "shap_targets": shap_targets, + "generalization_targets": list(model_cfg.get("generalization_metrics", {}).keys()), + "residual_normality_targets": list(model_cfg.get("residual_normality_stats", {}).keys()), + "n_importances": len(model_cfg.get("feature_importances", [])) + if "feature_importances" in model_cfg + else 0, + } + + if isinstance(shap_importance, dict): + summary["n_importances_per_target"] = { + target: len(values) for target, values in shap_importance.items() + } + else: + summary["n_importances_per_target"] = {} + + return summary diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index dc663d3..9937e36 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -13,7 +13,7 @@ import xgboost as xgb from sklearn.model_selection import train_test_split -from eventdisplay_ml import data_processing, features, utils +from eventdisplay_ml import data_processing, diagnostic_utils, features, utils from eventdisplay_ml.data_processing import ( energy_in_bins, flatten_feature_data, @@ -631,13 +631,20 @@ def train_regression(df, model_configs): y_train_scaled, sample_weight=weights_train, eval_set=eval_set, - verbose=True, + verbose=False, ) _logger.info( f"Training stopped at iteration {model.best_iteration} " f"(best score: {model.best_score:.4f})" ) + y_train_pred_scaled = model.predict(x_train) + y_train_pred = pd.DataFrame( + y_train_pred_scaled * y_std.values + y_mean.values, + columns=model_configs["targets"], + index=y_train.index, + ) + # Predict on scaled targets and inverse transform back to original scale y_pred_scaled = model.predict(x_test) y_pred = pd.DataFrame( @@ -646,11 +653,27 @@ def train_regression(df, model_configs): index=y_test.index, ) + generalization_metrics = diagnostic_utils.compute_generalization_metrics( + y_train, + y_train_pred, + y_test, + y_pred, + model_configs["targets"], + ) + + residual_normality_stats = diagnostic_utils.compute_residual_normality_stats( + y_test, + y_pred, + model_configs["targets"], + ) + shap_importance = evaluate_regression_model( model, x_test, y_pred, y_test, df, x_cols, y_data, name ) cfg["model"] = model cfg["features"] = x_cols # Store feature names for later use + cfg["generalization_metrics"] = generalization_metrics + cfg["residual_normality_stats"] = residual_normality_stats cfg["shap_importance"] = shap_importance # Store per-target SHAP importance from evaluation return model_configs diff --git a/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py b/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py new file mode 100644 index 0000000..04479ee --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py @@ -0,0 +1,195 @@ +r"""Generalization Ratio: Quantify overfitting gap between train and test performance. + +Uses cached train/test RMSE values written during training when available. +For older model files without cached metrics, it falls back to rebuilding the +original train/test split from the stored input metadata. + +Usage: + python diagnostic_generalization_gap.py \\ + --model_file trained_stereo.joblib \\ + --output_dir diagnostics/ +""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from eventdisplay_ml import diagnostic_utils + +_logger = logging.getLogger(__name__) + + +def compute_rmse_and_gaps(model, x_train, y_train, x_test, y_test, model_dict, target_names): + """Compute RMSE for train and test, derive generalization metrics.""" + _logger.info("Computing train and test RMSE...") + + y_train_pred = diagnostic_utils.predict_unscaled_residuals( + model, + x_train, + model_dict, + target_names, + ) + y_test_pred = diagnostic_utils.predict_unscaled_residuals( + model, + x_test, + model_dict, + target_names, + ) + + metrics = diagnostic_utils.compute_generalization_metrics( + y_train, + y_train_pred, + y_test, + y_test_pred, + target_names, + ) + + for target_name in target_names: + gap_pct = metrics[target_name]["gap_pct"] + _logger.info(f"\n{target_name}:") + _logger.info(f" Train RMSE: {metrics[target_name]['rmse_train']:.6f}") + _logger.info(f" Test RMSE: {metrics[target_name]['rmse_test']:.6f}") + _logger.info(f" Gap: {gap_pct:.2f}%") + _logger.info(f" Generalization: {'PASS' if gap_pct < 10 else 'WARN'} (threshold <10%)") + + return metrics + + +def plot_generalization_metrics(metrics, output_dir): + """Create visualization of train/test RMSE and generalization gap.""" + _logger.info("Creating generalization plots...") + + targets = list(metrics.keys()) + + # Plot 1: Train vs Test RMSE + _, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # Subplot 1: RMSE comparison + rmse_train = [metrics[target]["rmse_train"] for target in targets] + rmse_test = [metrics[target]["rmse_test"] for target in targets] + + x_pos = np.arange(len(targets)) + width = 0.35 + + axes[0].bar(x_pos - width / 2, rmse_train, width, label="Train RMSE", alpha=0.8) + axes[0].bar(x_pos + width / 2, rmse_test, width, label="Test RMSE", alpha=0.8) + axes[0].set_ylabel("RMSE") + axes[0].set_title("Training vs Test Performance") + axes[0].set_xticks(x_pos) + axes[0].set_xticklabels(targets, rotation=15) + axes[0].legend() + axes[0].grid(axis="y", alpha=0.3) + + # Subplot 2: Generalization gap + gaps = [metrics[target]["gap_pct"] for target in targets] + colors = ["green" if gap < 10 else "orange" if gap < 15 else "red" for gap in gaps] + + axes[1].bar(targets, gaps, color=colors, alpha=0.7) + axes[1].axhline( + y=10, + color="green", + linestyle="--", + linewidth=2, + label="Safe threshold (10%)", + ) + axes[1].axhline( + y=15, + color="orange", + linestyle="--", + linewidth=2, + label="Warning (15%)", + ) + axes[1].set_ylabel("Gap (%)") + axes[1].set_title("Generalization Gap: (Test-Train)/Train x 100%") + axes[1].set_xticks(x_pos) + axes[1].set_xticklabels(targets, rotation=15) + axes[1].legend() + axes[1].grid(axis="y", alpha=0.3) + + plt.tight_layout() + output_path = Path(output_dir) / "generalization_gap.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + _logger.info(f"Saved generalization plot to {output_path}") + plt.close() + + +def diagnose_overfitting(metrics): + """Summary diagnosis of overfitting status.""" + _logger.info("\n%s", "=" * 60) + _logger.info("OVERFITTING DIAGNOSIS") + _logger.info("%s", "=" * 60) + + all_gaps = [metrics[target]["gap_pct"] for target in metrics] + mean_gap = np.mean(all_gaps) + + _logger.info(f"\nMean Generalization Gap: {mean_gap:.2f}%") + + if mean_gap < 5: + status = "EXCELLENT - Model shows minimal overfitting" + elif mean_gap < 10: + status = "GOOD - Model generalization is safe" + elif mean_gap < 15: + status = "ACCEPTABLE - Minor overfitting, monitor carefully" + else: + status = "WARNING - Significant overfitting detected" + + _logger.info(f"Status: {status}") + _logger.info("\nPer-target breakdown:") + for target, data in metrics.items(): + _logger.info(f" {target}: {data['gap_pct']:.2f}% gap") + + +def main(): + """Load cached generalization metrics or fall back to reconstructing the split.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_file", required=True, help="Path to trained model joblib file") + parser.add_argument( + "--input_file_list", + default=None, + help=( + "Optional override for the training input file list. If omitted, the path stored " + "in the model file is used." + ), + ) + parser.add_argument("--output_dir", default="diagnostics", help="Output directory") + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + _, metrics = diagnostic_utils.load_cached_generalization_metrics(args.model_file) + + if metrics is None: + _logger.info( + "Cached generalization metrics are unavailable; rebuilding the train/test split" + ) + model, x_train, y_train, x_test, y_test, _, target_names, model_dict = ( + diagnostic_utils.load_stereo_regression_split( + args.model_file, + args.input_file_list, + ) + ) + + metrics = compute_rmse_and_gaps( + model, + x_train, + y_train, + x_test, + y_test, + model_dict, + target_names, + ) + else: + _logger.info("Using cached generalization metrics from the model file") + + plot_generalization_metrics(metrics, args.output_dir) + diagnose_overfitting(metrics) + + +if __name__ == "__main__": + main() diff --git a/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py b/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py new file mode 100644 index 0000000..4b71fd3 --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py @@ -0,0 +1,192 @@ +r"""Partial Dependence Plots (PDP): Prove the model captures physics, not chaos. + +Plots predicted residual output as a function of a single feature while holding +others constant. For stereo reconstruction, proves that the model correctly +reduces corrections for high-multiplicity events and increases them for sparse data. + +Usage: + eventdisplay-ml-diagnostic-partial-dependence \\ + --model_file trained_stereo.joblib \\ + --output_dir diagnostics/ \\ + --features DispNImages Xoff_weighted_bdt ErecS +""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.inspection import partial_dependence + +from eventdisplay_ml import diagnostic_utils + +_logger = logging.getLogger(__name__) + + +def load_data_and_model(model_file, input_file_list=None): + """Load trained model and reconstruct the held-out test split.""" + model, _, _, x_test, _, features, target_names, _ = ( + diagnostic_utils.load_stereo_regression_split( + model_file, + input_file_list, + ) + ) + return model, x_test, features, target_names + + +def compute_partial_dependence(model, x_test, features_to_plot): + """Compute partial dependence for selected features.""" + _logger.info(f"Computing partial dependence for {len(features_to_plot)} features...") + + # Convert all features to float to avoid sklearn warnings about integer dtypes + x_test_float = x_test.astype(np.float64) + + pdp_data = {} + + for feat_name in features_to_plot: + if feat_name not in x_test_float.columns: + _logger.warning(f"Feature {feat_name} not found in data") + continue + + feat_idx = x_test_float.columns.get_loc(feat_name) + + # Compute PDP for each target + pd_result = partial_dependence( + model, + x_test_float, + [feat_idx], + grid_resolution=50, + percentiles=(0.05, 0.95), + ) + + pd_values = pd_result.get("average") + if pd_values is None: + pd_values = pd_result.get("average_predictions") + if pd_values is None: + raise ValueError( + "Could not find partial dependence output in result. " + "Expected key 'average' or 'average_predictions'." + ) + + pdp_data[feat_name] = { + "grid": pd_result["grid_values"][0], + "pd_values": np.asarray(pd_values), # shape: (n_targets, n_grid) + } + + return pdp_data + + +def plot_partial_dependence(pdp_data, output_dir, target_names): + """Create PDP plots for each feature x target combination.""" + _logger.info("Creating partial dependence plots...") + + features = list(pdp_data.keys()) + if not features: + _logger.warning("No valid features available for partial dependence plotting") + return + + # Create a grid of subplots: features x targets + _, axes = plt.subplots(len(features), len(target_names), figsize=(15, 5 * len(features))) + + if len(features) == 1: + axes = axes.reshape(1, -1) + if len(target_names) == 1: + axes = axes.reshape(-1, 1) + + for feat_idx, feat_name in enumerate(features): + for target_idx, target_name in enumerate(target_names): + ax = axes[feat_idx, target_idx] + + grid = pdp_data[feat_name]["grid"] + # pd_values shape: (n_targets, n_grid_points) + pd_vals = pdp_data[feat_name]["pd_values"][target_idx] + + ax.plot(grid, pd_vals, linewidth=2.5, marker="o", markersize=4, color="steelblue") + ax.fill_between(grid, pd_vals * 0.95, pd_vals * 1.05, alpha=0.2) + + ax.set_xlabel(feat_name) + ax.set_ylabel(f"Predicted {target_name}") + ax.set_title(f"{feat_name} -> {target_name}") + ax.grid(alpha=0.3) + + plt.tight_layout() + output_path = Path(output_dir) / "partial_dependence.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + _logger.info(f"Saved PDP plots to {output_path}") + plt.close() + + +def diagnose_physics(pdp_data): + """Check if PDP shows physically sensible behavior.""" + _logger.info("\n%s", "=" * 60) + _logger.info("PHYSICS VALIDATION: Partial Dependence Analysis") + _logger.info("%s", "=" * 60) + + # Check 1: Multiplicity effect (should reduce corrections for high multiplicity) + if "DispNImages" in pdp_data: + grid = pdp_data["DispNImages"]["grid"] + pd_vals = pdp_data["DispNImages"]["pd_values"][0] # first target + + slope = (pd_vals[-1] - pd_vals[0]) / (grid[-1] - grid[0]) + + _logger.info("\nMultiplicity Effect (DispNImages):") + _logger.info(f" Slope of PDP: {slope:.6f}") + if slope < 0: + _logger.info(" CORRECT - More telescopes → smaller corrections needed") + else: + _logger.info(" WARNING - Unexpected behavior (fewer telescopes → larger corr)") + + # Check 2: Baseline stability (should show smooth, monotonic response) + baseline_features = [feat for feat in pdp_data if "weighted_bdt" in feat or "intersect" in feat] + for feat_name in baseline_features: + pd_vals = pdp_data[feat_name]["pd_values"][0] + + # Compute smoothness: ratio of diff magnitudes + diffs = np.abs(np.diff(pd_vals)) + smoothness = np.std(diffs) / (np.mean(diffs) + 1e-6) + + _logger.info(f"\n{feat_name}:") + _logger.info(f" Smoothness index: {smoothness:.4f}") + if smoothness < 0.3: + _logger.info(" GOOD - Smooth, linear relationship (learned physics)") + elif smoothness < 0.6: + _logger.info(" OK - Some noise but generally smooth") + else: + _logger.info(" WARNING - Chaotic relationship (possible overtraining)") + + +def main(): + """Rebuild held-out data from model metadata and run PDP diagnostics.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_file", required=True, help="Path to trained model joblib file") + parser.add_argument( + "--input_file_list", + default=None, + help=( + "Optional override for the training input file list. If omitted, the path stored " + "in the model file is used." + ), + ) + parser.add_argument("--output_dir", default="diagnostics", help="Output directory") + parser.add_argument( + "--features", + nargs="+", + default=["DispNImages", "Xoff_weighted_bdt", "Yoff_weighted_bdt", "ErecS"], + help="Features to plot PDP for", + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + model, x_test, _, target_names = load_data_and_model(args.model_file, args.input_file_list) + pdp_data = compute_partial_dependence(model, x_test, args.features) + plot_partial_dependence(pdp_data, args.output_dir, target_names) + diagnose_physics(pdp_data) + + +if __name__ == "__main__": + main() diff --git a/src/eventdisplay_ml/scripts/diagnostic_permutation_importance.py b/src/eventdisplay_ml/scripts/diagnostic_permutation_importance.py new file mode 100644 index 0000000..38a66ed --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_permutation_importance.py @@ -0,0 +1,181 @@ +"""Permutation importance for stereo regression models. + +This diagnostic rebuilds the held-out test split from the model metadata and the +original training input files, then shuffles features one-by-one and measures +the degradation in residual RMSE. + +Usage: + python diagnostic_permutation_importance.py \ + --model_file trained_stereo.joblib \ + --output_dir diagnostics/ \ + --top_n 20 +""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.metrics import mean_squared_error + +from eventdisplay_ml import diagnostic_utils + +_logger = logging.getLogger(__name__) + + +def compute_baseline_rmse(model, x_test, y_test, model_dict, target_names): + """Compute baseline RMSE on unshuffled test set.""" + y_pred = diagnostic_utils.predict_unscaled_residuals(model, x_test, model_dict, target_names) + + baseline_rmse = {} + for target in target_names: + mse = mean_squared_error(y_test[target], y_pred[target]) + baseline_rmse[target] = np.sqrt(mse) + _logger.info(f" Baseline RMSE ({target}): {baseline_rmse[target]:.6f}") + + return baseline_rmse, y_pred + + +def permutation_importance(model, x_test, y_test, baseline_rmse, target_names, model_dict): + """Compute permutation importance for each feature.""" + _logger.info("Computing permutation importance...") + + importance = {target: {} for target in target_names} + + for feat_idx, feat_name in enumerate(x_test.columns): + x_shuffled = x_test.copy() + x_shuffled.iloc[:, feat_idx] = ( + x_shuffled.iloc[:, feat_idx] + .sample( + frac=1, + random_state=model_dict.get("random_state", None), + ) + .to_numpy() + ) + + y_pred_shuffled = diagnostic_utils.predict_unscaled_residuals( + model, + x_shuffled, + model_dict, + target_names, + ) + + for target_name in target_names: + mse_shuffled = mean_squared_error(y_test[target_name], y_pred_shuffled[target_name]) + rmse_shuffled = np.sqrt(mse_shuffled) + + # Importance = relative RMSE increase when feature is shuffled + relative_importance = (rmse_shuffled - baseline_rmse[target_name]) / baseline_rmse[ + target_name + ] + importance[target_name][feat_name] = relative_importance + + return importance + + +def plot_permutation_importance(importance, output_path, top_n=20): + """Plot permutation importance for each target.""" + _logger.info(f"Creating permutation importance plots (top {top_n})...") + + targets = list(importance.keys()) + _, axes = plt.subplots(1, len(targets), figsize=(5 * len(targets), 8)) + if len(targets) == 1: + axes = [axes] + + for ax, target_name in zip(axes, targets): + imp_df = ( + pd.DataFrame( + list(importance[target_name].items()), + columns=["feature", "importance"], + ) + .sort_values("importance", ascending=True) + .tail(top_n) + ) + colors = ["red" if x < 0 else "green" for x in imp_df["importance"]] + + ax.barh(imp_df["feature"], imp_df["importance"] * 100.0, color=colors, alpha=0.7) + ax.set_xlabel("Relative RMSE Increase (%)") + ax.set_title(f"Permutation Importance\n{target_name}") + ax.axvline(x=0, color="black", linestyle="--", linewidth=1) + ax.grid(axis="x", alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + _logger.info(f"Saved permutation importance plot to {output_path}") + plt.close() + + +def diagnose_baseline_anchoring(importance): + """Check if model is anchored in conventional baselines.""" + _logger.info("=== Physics Check: Baseline Anchoring ===") + + baseline_features = ["Xoff_weighted_bdt", "Yoff_weighted_bdt", "ErecS"] + + for target, imp_dict in importance.items(): + _logger.info(f"\n{target}:") + + baseline_contrib = sum( + imp_dict.get(feat, 0) for feat in baseline_features if feat in imp_dict + ) + total_contrib = sum(imp for imp in imp_dict.values() if imp > 0) + + anchor_pct = (baseline_contrib / total_contrib * 100) if total_contrib > 0 else 0 + + _logger.info(f" Baseline features contribution: {anchor_pct:.1f}%") + _logger.info(" (Expect >70% for well-anchored model)") + + for feat in baseline_features: + if feat in imp_dict: + _logger.info(f" {feat}: {imp_dict[feat] * 100:.2f}%") + + +def main(): + """Rebuild the held-out test split and compute permutation importance.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_file", required=True, help="Path to trained model joblib file") + parser.add_argument( + "--input_file_list", + default=None, + help=( + "Optional override for the training input file list. If omitted, the path stored " + "in the model file is used." + ), + ) + parser.add_argument("--output_dir", default="diagnostics", help="Output directory") + parser.add_argument("--top_n", type=int, default=20, help="Top N features to plot") + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + model, _, _, x_test, y_test, _, target_names, model_dict = ( + diagnostic_utils.load_stereo_regression_split( + args.model_file, + args.input_file_list, + ) + ) + + _logger.info("Computing baseline RMSE...") + baseline_rmse, _ = compute_baseline_rmse(model, x_test, y_test, model_dict, target_names) + + importance = permutation_importance( + model, + x_test, + y_test, + baseline_rmse, + target_names, + model_dict, + ) + + output_path = Path(args.output_dir) / "permutation_importance.png" + plot_permutation_importance(importance, output_path, args.top_n) + + diagnose_baseline_anchoring(importance) + + +if __name__ == "__main__": + main() diff --git a/src/eventdisplay_ml/scripts/diagnostic_residual_normality.py b/src/eventdisplay_ml/scripts/diagnostic_residual_normality.py new file mode 100644 index 0000000..fd2025f --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_residual_normality.py @@ -0,0 +1,253 @@ +r"""Residual Normality & Outlier Check: Validate statistical quality of predictions. + +Tests if residuals are Gaussian and centered at zero. Non-normal residuals indicate +the model is failing on specific event types (e.g., edge-of-camera or low-multiplicity). +Heavy tails indicate outliers; skewness indicates systematic bias. + +Usage: + eventdisplay-ml-diagnostic-residual-normality \\ + --model_file trained_stereo.joblib \\ + --output_dir diagnostics/ +""" + +import argparse +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +from scipy import stats + +from eventdisplay_ml import diagnostic_utils + +_logger = logging.getLogger(__name__) + + +def load_predictions_and_targets(model_file, input_file_list=None): + """Rebuild test split and compute predicted residuals.""" + model, _, _, x_test, y_test, _, target_names, model_dict = ( + diagnostic_utils.load_stereo_regression_split( + model_file, + input_file_list, + ) + ) + y_pred = diagnostic_utils.predict_unscaled_residuals(model, x_test, model_dict, target_names) + return y_pred, y_test, target_names + + +def compute_residuals(y_pred, y_true, target_names): + """Compute residuals (true - predicted). Standard ML residual convention.""" + residuals = {} + + for target_name in target_names: + if target_name in y_true.columns and target_name in y_pred.columns: + residuals[target_name] = y_true[target_name].values - y_pred[target_name].values + + return residuals + + +def compute_normality_stats(residuals): + """Compute Gaussian fit parameters and normality tests.""" + stats_dict = {} + + for target_name, resid in residuals.items(): + # Remove NaN values + resid_clean = resid[~np.isnan(resid)] + if len(resid_clean) == 0: + _logger.warning(f"Skipping {target_name}: no finite residuals") + continue + + # Gaussian parameters + mean = np.mean(resid_clean) + std = np.std(resid_clean) + + # Normality tests + _, p_ks = stats.kstest(resid_clean, "norm", args=(mean, std)) + ad_result = stats.anderson(resid_clean, dist="norm", method="interpolate") + ad_stat = ad_result.statistic + # With method='interpolate', ad_result has pvalue attribute + ad_pvalue = ad_result.pvalue if hasattr(ad_result, "pvalue") else np.nan + + # Skewness and kurtosis + skewness = stats.skew(resid_clean) + kurtosis = stats.kurtosis(resid_clean) + + # Quantile-Quantile test (visual) + _, (_, _, qq_r) = stats.probplot(resid_clean, dist="norm") + qq_r2 = qq_r**2 + + stats_dict[target_name] = { + "mean": mean, + "std": std, + "p_ks": p_ks, + "ad_stat": ad_stat, + "ad_pvalue": ad_pvalue, + "skewness": skewness, + "kurtosis": kurtosis, + "qq_r2": qq_r2, + "n_outliers": np.sum(np.abs(resid_clean) > 3 * std), + "n_samples": len(resid_clean), + } + + return stats_dict + + +def plot_residual_diagnostics(residuals, stats_dict, output_dir): + """Create comprehensive residual diagnostic plots.""" + _logger.info("Creating residual diagnostic plots...") + + target_names = list(residuals.keys()) + if not target_names: + _logger.warning("No residual targets to plot") + return + + # Create a 2xN grid: histogram + Q-Q plot for each target + _, axes = plt.subplots(2, len(target_names), figsize=(5 * len(target_names), 10)) + + if len(target_names) == 1: + axes = axes.reshape(2, 1) + + for col_idx, target_name in enumerate(target_names): + resid = residuals[target_name][~np.isnan(residuals[target_name])] + stat = stats_dict[target_name] + + # Row 0: Histogram with Gaussian overlay + ax_hist = axes[0, col_idx] + ax_hist.hist(resid, bins=50, density=True, alpha=0.7, color="steelblue", edgecolor="black") + + # Overlay Gaussian + x_range = np.linspace(resid.min(), resid.max(), 100) + gaussian = stats.norm.pdf(x_range, stat["mean"], stat["std"]) + ax_hist.plot(x_range, gaussian, "r-", linewidth=2, label="Normal fit") + + ax_hist.axvline(stat["mean"], color="green", linestyle="--", linewidth=2, label="Mean") + ax_hist.set_xlabel("Residual value") + ax_hist.set_ylabel("Density") + ax_hist.set_title(f"{target_name}\nmu={stat['mean']:.4f}, sigma={stat['std']:.4f}") + ax_hist.legend(fontsize=8) + ax_hist.grid(alpha=0.3) + + # Row 1: Q-Q plot + ax_qq = axes[1, col_idx] + stats.probplot(resid, dist="norm", plot=ax_qq) + ax_qq.set_title(f"Q-Q Plot (R²={stat['qq_r2']:.4f})") + ax_qq.grid(alpha=0.3) + + plt.tight_layout() + output_path = Path(output_dir) / "residual_diagnostics.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + _logger.info(f"Saved residual diagnostics to {output_path}") + plt.close() + + +def diagnose_residual_quality(residuals, stats_dict): + """Provide detailed diagnosis of residual quality.""" + _logger.info("\n%s", "=" * 60) + _logger.info("RESIDUAL NORMALITY & OUTLIER ANALYSIS") + _logger.info("%s", "=" * 60) + + for target_name, stat in stats_dict.items(): + _logger.info(f"\n{target_name}:") + _logger.info(f" Mean: {stat['mean']:.6f} (expect ~0)") + _logger.info(f" Std: {stat['std']:.6f}") + + if np.abs(stat["mean"]) < stat["std"] * 0.1: + _logger.info(" GOOD - Residuals centered at zero") + elif np.abs(stat["mean"]) < stat["std"] * 0.2: + _logger.info(" OK - Small systematic offset") + else: + _logger.info(" WARNING - Significant bias detected") + + _logger.info(f"\n Skewness: {stat['skewness']:.4f}") + if np.abs(stat["skewness"]) < 0.2: + _logger.info(" GOOD - Symmetric distribution") + elif np.abs(stat["skewness"]) < 0.5: + _logger.info(" OK - Mild asymmetry") + else: + _logger.info(" WARNING - Strong skew (model failing on certain events)") + + _logger.info(f"\n Kurtosis: {stat['kurtosis']:.4f}") + if np.abs(stat["kurtosis"]) < 0.5: + _logger.info(" GOOD - Gaussian-like tails") + elif np.abs(stat["kurtosis"]) < 1.0: + _logger.info(" OK - Slightly heavy/light tails") + else: + _logger.info(" WARNING - Heavy tails (outliers present)") + + _logger.info(f"\n Outliers (>3sigma): {stat['n_outliers']} events") + outlier_pct = stat["n_outliers"] / stat["n_samples"] * 100 + if outlier_pct < 0.3: + _logger.info(" GOOD - Minimal outliers") + elif outlier_pct < 1.0: + _logger.info(" OK - Few outliers") + else: + _logger.info(" WARNING - Excessive outliers") + + _logger.info(f"\n Kolmogorov-Smirnov test: p={stat['p_ks']:.4f}") + if stat["p_ks"] > 0.05: + _logger.info(" Gaussian hypothesis NOT rejected (p > 0.05)") + else: + _logger.info(" Distribution deviates from Gaussian (p < 0.05)") + + _logger.info( + f"\n Anderson-Darling normality: stat={stat['ad_stat']:.4f}, " + f"p-value={stat['ad_pvalue']:.4f}" + ) + if stat["ad_pvalue"] > 0.05: + _logger.info(" Anderson-Darling does not reject normality (p > 0.05)") + else: + _logger.info(" Anderson-Darling rejects normality (p < 0.05)") + + _logger.info(f"\n Q-Q R²: {stat['qq_r2']:.4f}") + if stat["qq_r2"] > 0.98: + _logger.info(" Excellent Gaussian fit") + elif stat["qq_r2"] > 0.95: + _logger.info(" Good Gaussian fit") + else: + _logger.info(" Fair fit, consider investigating tails") + + +def main(): + """Load cached residual normality stats or fall back to reconstructing the split.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_file", required=True, help="Path to trained model joblib file") + parser.add_argument( + "--input_file_list", + default=None, + help=( + "Optional override for the training input file list. If omitted, the path stored " + "in the model file is used." + ), + ) + parser.add_argument("--output_dir", default="diagnostics", help="Output directory") + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + _, stats_dict = diagnostic_utils.load_cached_residual_normality_stats(args.model_file) + + residuals = None + if stats_dict is None: + _logger.info("Cached residual normality statistics unavailable; rebuilding from test split") + y_pred, y_true, target_names = load_predictions_and_targets( + args.model_file, + args.input_file_list, + ) + residuals = compute_residuals(y_pred, y_true, target_names) + stats_dict = compute_normality_stats(residuals) + plot_residual_diagnostics(residuals, stats_dict, args.output_dir) + else: + _logger.info("Using cached residual normality statistics from the model file") + _logger.info( + "Note: Diagnostic plots skipped when using cached statistics; " + "rerun without cache to regenerate plots" + ) + + diagnose_residual_quality(residuals or {}, stats_dict) + + +if __name__ == "__main__": + main() diff --git a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py index 3729b0a..7eb41e2 100644 --- a/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py +++ b/src/eventdisplay_ml/scripts/diagnostic_shap_summary.py @@ -55,14 +55,12 @@ def plot_feature_importance(features, importances, target_name, output_dir): output_dir : str Output directory for plot. """ - # Create DataFrame and sort by importance importance_df = ( pd.DataFrame({"Feature": features, "Importance": importances}) .sort_values("Importance", ascending=True) .tail(20) ) - # Create plot _, ax = plt.subplots(figsize=(10, 8)) colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(importance_df))) @@ -96,8 +94,9 @@ def main(): Path(args.output_dir).mkdir(parents=True, exist_ok=True) - # Load model configuration with cached data - model_cfg, _model_dict = load_model_config(args.model_file) + _logger.info("=== SHAP Feature Importance Summary ===") + + model_cfg, _ = load_model_config(args.model_file) shap_importance = model_cfg.get("shap_importance") features = model_cfg.get("features") @@ -119,7 +118,7 @@ def main(): _logger.info(f"\nProcessing {target_name}...") plot_feature_importance(features, importances, target_name, args.output_dir) - _logger.info(f"\n✓ All plots saved to {args.output_dir}") + _logger.info(f"\nPlots saved to {args.output_dir}") if __name__ == "__main__": diff --git a/tests/test_train_regression_standardization.py b/tests/test_train_regression_standardization.py index 1f47646..2a389d3 100644 --- a/tests/test_train_regression_standardization.py +++ b/tests/test_train_regression_standardization.py @@ -6,7 +6,7 @@ import pandas as pd import pytest -from eventdisplay_ml import models +from eventdisplay_ml import diagnostic_utils, models @pytest.fixture @@ -247,8 +247,77 @@ def test_train_regression_complete_workflow( assert "models" in result assert "xgboost" in result["models"] assert "model" in result["models"]["xgboost"] + assert "generalization_metrics" in result["models"]["xgboost"] assert "shap_importance" in result["models"]["xgboost"] + def test_generalization_metrics_cached_per_target( + self, regression_training_df, regression_model_config + ): + """Verify train/test RMSE summary is cached in the model config.""" + result = models.train_regression(regression_training_df, regression_model_config) + + metrics = result["models"]["xgboost"]["generalization_metrics"] + assert set(metrics) == set(regression_model_config["targets"]) + + for target in regression_model_config["targets"]: + assert set(metrics[target]) == {"rmse_train", "rmse_test", "gap_pct", "gen_ratio"} + assert np.isfinite(metrics[target]["rmse_train"]) + assert np.isfinite(metrics[target]["rmse_test"]) + + def test_generalization_metrics_match_training_predictions( + self, regression_training_df, regression_model_config + ): + """Verify cached generalization metrics match the model predictions used in training.""" + df = regression_training_df + cfg = regression_model_config + + with patch("xgboost.XGBRegressor") as mock_xgb: + mock_model = MagicMock() + mock_model.best_iteration = 5 + mock_model.best_score = 0.1 + + def _predict(x_values): + return np.zeros((len(x_values), len(cfg["targets"]))) + + mock_model.predict.side_effect = _predict + mock_xgb.return_value = mock_model + + with patch("eventdisplay_ml.models.evaluate_regression_model") as mock_eval: + mock_eval.return_value = {} + result = models.train_regression(df, cfg) + + from sklearn.model_selection import train_test_split + + x_cols = [col for col in df.columns if col not in cfg["targets"]] + _, _, y_train, y_test = train_test_split( + df[x_cols], + df[cfg["targets"]], + train_size=cfg["train_test_fraction"], + random_state=cfg["random_state"], + ) + + target_mean = np.array([result["target_mean"][target] for target in cfg["targets"]]) + y_train_pred = pd.DataFrame( + np.tile(target_mean, (len(y_train), 1)), + columns=cfg["targets"], + index=y_train.index, + ) + y_test_pred = pd.DataFrame( + np.tile(target_mean, (len(y_test), 1)), + columns=cfg["targets"], + index=y_test.index, + ) + + expected_metrics = diagnostic_utils.compute_generalization_metrics( + y_train, + y_train_pred, + y_test, + y_test_pred, + cfg["targets"], + ) + + assert result["models"]["xgboost"]["generalization_metrics"] == expected_metrics + def test_scaled_predictions_unscaled_correctly( self, regression_training_df, regression_model_config ):