diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 37f98d8..f3dc37b 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -203,9 +203,11 @@ eventdisplay-ml-plot-classification-gamma-efficiency --help **Model Architecture**: - **Stereo reconstruction**: Multi-output regression (XGBoost) - - Targets: `[MCxoff, MCyoff, MCe0]` (x offset, y offset, log energy) + - Targets: `[Xoff_residual, Yoff_residual, E_residual]` (residuals relative to DispBDT) + - Residuals computed as: MC truth - DispBDT prediction + - During inference: final prediction = DispBDT baseline + predicted residual - Single model handles all telescope multiplicities (2-4+ telescopes) - - Features: Telescope-level arrays + event-level parameters + - Features: Telescope-level arrays + event-level parameters (including DispBDT results) - **Classification**: Binary classification (XGBoost) - Target: Gamma vs hadron (implicit in training data split) diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..6d5349e --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +--- +# Set update schedule for GitHub Actions + +version: 2 +updates: + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every month + interval: "monthly" diff --git a/README.md b/README.md index 09180e2..26346d2 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,59 @@ Stereo analysis methods implemented in Eventdisplay provide direction / energies Output is a single ROOT tree called `StereoAnalysis` with the same number of events as the input tree. 
+### Training Stereo Reconstruction Models + +The stereo regression training pipeline uses multi-target XGBoost to predict residuals (deviations from baseline reconstructions): + +**Targets:** `[Xoff_residual, Yoff_residual, E_residual]` (residuals on direction and energy as reconstructed by the BDT stereo reconstruction method) + +**Key techniques:** + +- **Target standardization:** Targets are mean-centered and scaled to unit variance during training +- **Energy-bin weighting:** Events are weighted inversely by energy bin density; bins with fewer than 10 events are excluded from training to prevent overfitting on low-statistics regions +- **Multiplicity weighting:** Higher-multiplicity events (more telescopes) receive higher sample weights to prioritize high-confidence reconstructions +- **Per-target SHAP importance:** Feature importance values computed during training for each target and cached for later analysis + +**Training command:** + +```bash +eventdisplay-ml-train-xgb-stereo \ + --input_file_list train_files.txt \ + --model_prefix models/stereo_model \ + --max_events 100000 \ + --train_test_fraction 0.5 \ + --max_cores 8 +``` + +**Output:** Joblib model file containing: + +- XGBoost trained model object +- Target standardization scalers (mean/std) +- Feature list and SHAP importance rankings +- Training metadata (random state, hyperparameters) + +### Applying Stereo Reconstruction Models + +The apply pipeline loads trained models and makes predictions: + +**Key safeguards:** + +- Invalid energy values (≤0 or NaN) produce NaN outputs but preserve all input event rows +- Missing standardization parameters raise ValueError (prevents silent data corruption) +- Output row count always equals input row count + +**Apply command:** + +```bash +eventdisplay-ml-apply-xgb-stereo \ + --input_file_list apply_files.txt \ + --output_file_list output_files.txt \ + --model_prefix models/stereo_model +``` + + +**Output:** ROOT files with `StereoAnalysis` tree containing 
reconstructed Xoff, Yoff, and log10(E). + ## Gamma/hadron separation using XGBoost Gamma/hadron separation is performed using XGB Boost classification trees. Features are image parameters and stereo reconstruction parameters provided by Eventdisplay. @@ -27,6 +80,223 @@ The zenith angle dependence is accounted for by including the zenith angle as a Output is a single ROOT tree called `Classification` with the same number of events as the input tree. It contains the classification prediction (`Gamma_Prediction`) and boolean flags (e.g. `Is_Gamma_75` for 75% signal efficiency cut). +## Diagnostic Tools + +The committed regression diagnostics in this branch are: + +### SHAP feature-importance summary + + Tests: Feature importance + +- Load per-target SHAP importances cached in the trained model file +- Create one top-20 feature plot per residual target (`Xoff_residual`, `Yoff_residual`, `E_residual`) + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated PNGs + +Run: + +```bash + eventdisplay-ml-diagnostic-shap-summary \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Outputs: + +- `diagnostics/shap_importance_Xoff_residual.png` +- `diagnostics/shap_importance_Yoff_residual.png` +- `diagnostics/shap_importance_E_residual.png` + +### Permutation importance + +- Rebuild the held-out test split from the model metadata and original input files +- Shuffle one feature at a time and measure the relative RMSE increase per residual target +- Validate predictive dependence on features rather than cached model attribution + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots +- `--top_n`: number of top features to include in the plot (optional) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-permutation-importance \ + 
--model_file models/stereo_model.joblib \ + --output_dir diagnostics/ \ + --top_n 20 +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-permutation-importance \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --output_dir diagnostics/ +``` + +Output: + +- `diagnostics/permutation_importance.png` + +Notes: + +- This diagnostic is slower than the SHAP summary because it rebuilds the processed test split. +- It is the better choice when you want to measure actual performance sensitivity to each feature. + +### Generalization gap + +- Read the cached train/test RMSE summary written during training +- Compare final train and test RMSE for each residual target +- Quantify the overfitting gap after training is complete + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-generalization-gap \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-generalization-gap \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --output_dir diagnostics/ +``` + +Output: + +- `diagnostics/generalization_gap.png` + +Notes: + +- This diagnostic measures final overfitting by comparing train and test residual RMSE. +- Older model files without cached metrics fall back to rebuilding the original train/test split. +- Unlike `plot_training_evaluation.py`, it summarizes final RMSE, not the per-iteration XGBoost training history. 
+ +### Partial Dependence Plots + +- Visualize how each feature influences model predictions +- Prove the model captures physics by checking that multiplicity reduces corrections and baselines show smooth relationships + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots (optional; default: `diagnostics`) +- `--features`: space-separated list of features to plot (optional; default: `DispNImages Xoff_weighted_bdt Yoff_weighted_bdt ErecS`) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash +eventdisplay-ml-diagnostic-partial-dependence \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ \ + --features DispNImages Xoff_weighted_bdt ErecS +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-partial-dependence \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt \ + --features Xoff_weighted_bdt Yoff_weighted_bdt +``` + +Output: + +- `diagnostics/partial_dependence.png` (grid of feature × target subplots) + +Notes: + +- PDP displays predicted residual output as a function of a single feature while holding others constant +- Multiplicity effect: high-multiplicity events should show smaller corrections (negative slope) +- Baseline stability: baseline features (e.g., `weighted_bdt`) should show smooth, linear relationships +- This diagnostic rebuilds the held-out test split and is slower than SHAP summary + +### Residual Normality Diagnostics + +- Validate that model residuals follow a normal distribution +- Detect outlier events and check for systematic biases in reconstruction errors + +Required inputs: + +- `--model_file`: trained stereo model `.joblib` +- `--output_dir`: directory for generated plots (optional; default: `diagnostics`) +- `--input_file_list`: optional override if the path stored in the model metadata is no longer valid + +Run: + +```bash 
+eventdisplay-ml-diagnostic-residual-normality \ + --model_file models/stereo_model.joblib \ + --output_dir diagnostics/ +``` + +Optional override: + +```bash +eventdisplay-ml-diagnostic-residual-normality \ + --model_file models/stereo_model.joblib \ + --input_file_list files.txt +``` + +Output: + +- Residual normality statistics printed to console: + - Mean and standard deviation per target + - Kolmogorov-Smirnov test p-value (normality test) + - Anderson-Darling test statistic and critical value + - Skewness and kurtosis + - Q-Q plot R² value + - Number of outliers (>3σ) per target +- `diagnostics/residual_diagnostics.png` (single 2xN grid; generated on cache miss when reconstruction is required) + +Notes: + +- Residual normality stats are cached during training and loaded from the model file for fast retrieval +- Diagnostic plots (histograms, Q-Q plots) are only generated when the split must be reconstructed +- Invalid KS test or Anderson-Darling results (NaN/inf) are reported as special values +- Outlier counts help identify events with unusually large reconstruction errors + +### Training-evaluation curves + +- Plot XGBoost training vs validation metric curves +- Useful for checking convergence and overfitting behavior + +Required inputs: + +- `--model_file`: trained model `.joblib` containing an XGBoost model +- `--output_file`: output image path (optional; if omitted, plot is shown interactively) + +Run: + +```bash +eventdisplay-ml-plot-training-evaluation \ + --model_file models/stereo_model.joblib \ + --output_file diagnostics/training_curves.png +``` + +Output: + +- Figure with one panel per tracked metric (for example `rmse`), showing training and test curves. + ## Generative AI disclosure Generative AI tools (including Claude, ChatGPT, and Gemini) were used to assist with code development, debugging, and documentation drafting. 
All AI-assisted outputs were reviewed, validated, and, where necessary, modified by the authors to ensure accuracy and reliability. diff --git a/docs/changes/53.feature.md b/docs/changes/53.feature.md new file mode 100644 index 0000000..cc91512 --- /dev/null +++ b/docs/changes/53.feature.md @@ -0,0 +1,37 @@ +## Stereo Regression: Training on Residuals with Standardization and Energy Weighting + +### Architectural Change + +- **Training targets changed from absolute to residual values**: Models now predict residuals (deviations from baseline reconstructions) rather than absolute directions/energies. This allows XGBoost to learn corrections to existing Eventdisplay reconstructions (DispBDT, intersection method) and leverage their baseline accuracy as a starting point. + +### Critical Bug Fixes + +- **Fixed double log10 application**: Energy residuals computed in linear space; log10 applied explicitly during evaluation +- **Fixed standardization inversion**: Apply pipeline now loads and validates target_mean/target_std scalers (prevents KeyError) +- **Fixed energy-bin weighting**: Bins with <10 events get zero weight; correct inverse weighting for balanced training +- **Fixed ErecS validation**: Safe log10 computation during apply; all input rows preserved in output +- **Fixed evaluation metrics**: Energy resolution compared in log10 space with proper baseline alignment +- **Fixed FutureWarning**: Series positional indexing converted to numpy arrays for pandas compatibility + +### New Features + +- **Target standardization in training**: Residuals standardized to mean=0, std=1 during training to enable multi-target learning with balanced learning signals (direction and energy equally weighted) +- **Energy-bin weighted training**: Events weighted inversely by energy bin density; bins with <10 events excluded to prevent overfitting on low-statistics regions +- **Per-target SHAP importance caching**: Feature importances computed once during training for each target 
(Xoff_residual, Yoff_residual, E_residual), cached for diagnostic tools +- **Diagnostic scripts**: + - `diagnostic_shap_summary.py`: Top-20 feature importance plots per residual target + - `plot_training_evaluation.py`: Energy resolution and residual distribution visualization +- **Comprehensive test suites**: 20 new tests covering residual computation, standardization, energy weighting, apply inference +- **Robust error handling**: Clear messages for missing scalers; guaranteed row-count preservation in apply pipeline + +### Enhanced Diagnostic Pipeline + +- **Generalization-gap metrics cached during training**: Train/test RMSE, gap %, and generalization ratio computed and cached in the model artifact, enabling fast overfitting assessment without recomputation +- **Residual normality statistics cached during training**: Normality tests (Kolmogorov-Smirnov, Anderson-Darling), distribution shape metrics (skewness, kurtosis, Q-Q R²), and outlier counts computed once during training and cached for fast retrieval +- **Diagnostic reconstruction from model metadata**: All regression diagnostics (generalization-gap, partial-dependence, residual-normality) now reconstruct the held-out test split from stored model metadata + input file list, enabling reproducibility and offline analysis without CSV exports +- **Cache-first diagnostic workflows**: Diagnostic scripts load cached metrics first (fast) with graceful fallback to reconstruction if cache unavailable (backward compatible with older models) +- **CLI entry points for all diagnostics**: + - `eventdisplay-ml-diagnostic-generalization-gap`: Quantify overfitting via train/test RMSE comparison + - `eventdisplay-ml-diagnostic-partial-dependence`: Validate model captures physics via partial dependence curves + - `eventdisplay-ml-diagnostic-residual-normality`: Validate residual normality and detect outliers +- **Fixed sklearn FutureWarning**: Partial dependence plots convert feature data to float64 to avoid integer dtype 
warnings in newer scikit-learn versions diff --git a/pyproject.toml b/pyproject.toml index f593ab2..b4f7445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,14 @@ urls."documentation" = "https://github.com/Eventdisplay/Eventdisplay-ML" urls."repository" = "https://github.com/Eventdisplay/Eventdisplay-ML" scripts.eventdisplay-ml-apply-xgb-classify = "eventdisplay_ml.scripts.apply_xgb_classify:main" scripts.eventdisplay-ml-apply-xgb-stereo = "eventdisplay_ml.scripts.apply_xgb_stereo:main" +scripts.eventdisplay-ml-diagnostic-generalization-gap = "eventdisplay_ml.scripts.diagnostic_generalization_gap:main" +scripts.eventdisplay-ml-diagnostic-partial-dependence = "eventdisplay_ml.scripts.diagnostic_partial_dependence:main" +scripts.eventdisplay-ml-diagnostic-permutation-importance = "eventdisplay_ml.scripts.diagnostic_permutation_importance:main" +scripts.eventdisplay-ml-diagnostic-residual-normality = "eventdisplay_ml.scripts.diagnostic_residual_normality:main" +scripts.eventdisplay-ml-diagnostic-shap-summary = "eventdisplay_ml.scripts.diagnostic_shap_summary:main" scripts.eventdisplay-ml-plot-classification-performance-metrics = "eventdisplay_ml.scripts.plot_classification_performance_metrics:main" scripts.eventdisplay-ml-plot-classification-gamma-efficiency = "eventdisplay_ml.scripts.plot_classification_gamma_efficiency:main" +scripts.eventdisplay-ml-plot-training-evaluation = "eventdisplay_ml.scripts.plot_training_evaluation:main" scripts.eventdisplay-ml-train-xgb-classify = "eventdisplay_ml.scripts.train_xgb_classify:main" scripts.eventdisplay-ml-train-xgb-stereo = "eventdisplay_ml.scripts.train_xgb_stereo:main" diff --git a/src/eventdisplay_ml/config.py b/src/eventdisplay_ml/config.py index 30be4f3..8e6fd86 100644 --- a/src/eventdisplay_ml/config.py +++ b/src/eventdisplay_ml/config.py @@ -239,5 +239,8 @@ def configure_apply(analysis_type): ) model_configs["energy_bins_log10_tev"] = par.get("energy_bins_log10_tev", []) model_configs["zenith_bins_deg"] 
= par.get("zenith_bins_deg", []) + if analysis_type == "stereo_analysis": + model_configs["target_mean"] = par.get("target_mean") + model_configs["target_std"] = par.get("target_std") return model_configs diff --git a/src/eventdisplay_ml/data_processing.py b/src/eventdisplay_ml/data_processing.py index 41d0516..c2fa8b5 100644 --- a/src/eventdisplay_ml/data_processing.py +++ b/src/eventdisplay_ml/data_processing.py @@ -871,11 +871,36 @@ def load_training_data(model_configs, file_list, analysis_type): max_tel_per_type=model_configs.get("max_tel_per_type", None), preview_rows=model_configs.get("preview_rows", 20), ) + + # Filter out events with invalid energy reconstruction for stereo training + if analysis_type == "stereo_analysis": + n_before_erec_filter = len(df_flat) + valid_erec_mask = (df_flat["ErecS"] > 0) & np.isfinite(df_flat["ErecS"]) + df_flat = df_flat[valid_erec_mask] + n_removed_erec = n_before_erec_filter - len(df_flat) + if n_removed_erec > 0: + _logger.info( + f"Removed {n_removed_erec} events with ErecS <= 0 or NaN " + f"(fraction removed: {n_removed_erec / n_before_erec_filter:.4f})" + ) + if analysis_type == "stereo_analysis": + mc_xoff = _to_numpy_1d(df["MCxoff"], np.float32)[valid_erec_mask] + mc_yoff = _to_numpy_1d(df["MCyoff"], np.float32)[valid_erec_mask] + mc_e0 = _to_numpy_1d(df["MCe0"], np.float32)[valid_erec_mask] + + disp_xoff = df_flat["Xoff_weighted_bdt"].values + disp_yoff = df_flat["Yoff_weighted_bdt"].values + disp_erec = df_flat["ErecS"].values + + # Compute log energies (ErecS already filtered > 0) + mc_e0_log = np.where(mc_e0 > 0, np.log10(mc_e0), np.nan) + disp_erec_log = np.log10(disp_erec) # Safe since already filtered > 0 + new_cols = { - "MCxoff": _to_numpy_1d(df["MCxoff"], np.float32), - "MCyoff": _to_numpy_1d(df["MCyoff"], np.float32), - "MCe0": np.log10(_to_numpy_1d(df["MCe0"], np.float32)), + "Xoff_residual": mc_xoff - disp_xoff, + "Yoff_residual": mc_yoff - disp_yoff, + "E_residual": mc_e0_log - disp_erec_log, } elif 
analysis_type == "classification": new_cols = { @@ -887,6 +912,22 @@ def load_training_data(model_configs, file_list, analysis_type): for col_name, values in new_cols.items(): df_flat[col_name] = values + # Filter out events with NaN in residuals (can't train on these) + if analysis_type == "stereo_analysis": + n_before_nan_filter = len(df_flat) + valid_mask = ( + np.isfinite(df_flat["Xoff_residual"]) + & np.isfinite(df_flat["Yoff_residual"]) + & np.isfinite(df_flat["E_residual"]) + ) + df_flat = df_flat[valid_mask] + n_removed = n_before_nan_filter - len(df_flat) + if n_removed > 0: + _logger.info( + f"Removed {n_removed} events with NaN residuals " + f"(fraction removed: {n_removed / n_before_nan_filter:.4f})" + ) + dfs.append(df_flat) del df @@ -1159,14 +1200,24 @@ def extra_columns(df, analysis_type, training, index, tel_config=None, observato data["ze_bin"] = _to_numpy_1d(df["ze_bin"], np.float32) df_extra = pd.DataFrame(data, index=index) - apply_clip_intervals( - df_extra, - apply_log10=[ + # For stereo_analysis, Erec/ErecS must remain in linear space for residual computation + # (log10 is applied explicitly when computing E_residual = log10(MC) - log10(ErecS)) + # For classification, Erec/ErecS can be log10'd as features + if analysis_type == "stereo_analysis": + apply_log10_list = [ + "EChi2S", + "EmissionHeightChi2", + ] + else: + apply_log10_list = [ "EChi2S", "EmissionHeightChi2", "Erec", "ErecS", - ], + ] + apply_clip_intervals( + df_extra, + apply_log10=apply_log10_list, ) return df_extra diff --git a/src/eventdisplay_ml/diagnostic_utils.py b/src/eventdisplay_ml/diagnostic_utils.py new file mode 100644 index 0000000..aae271e --- /dev/null +++ b/src/eventdisplay_ml/diagnostic_utils.py @@ -0,0 +1,401 @@ +"""Utilities for inspecting cached diagnostic data in model joblib files. + +The current stereo regression cache stores per-target SHAP importances under +``models[]["shap_importance"]``. 
This module supports that layout while +remaining compatible with older cache keys where possible. +""" + +import logging + +import joblib +import numpy as np +import pandas as pd +from scipy import stats +from sklearn.metrics import mean_squared_error +from sklearn.model_selection import train_test_split + +from eventdisplay_ml.data_processing import load_training_data + +_logger = logging.getLogger(__name__) + + +def _load_model_cfg(model_file): + """Load full model dictionary and the first model configuration entry.""" + model_dict = joblib.load(model_file) + models = model_dict.get("models", {}) + model_cfg = next(iter(models.values())) if models else None + return model_dict, model_cfg + + +def load_stereo_regression_split(model_file, input_file_list=None): + """Load a stereo regression model and reconstruct its train/test split. + + Parameters + ---------- + model_file : str + Path to trained model joblib file. + input_file_list : str or None, optional + Optional override for the input file list stored in the model metadata. + + Returns + ------- + tuple + Trained model, reconstructed x_train, y_train, x_test, y_test, + feature names, target names, and full model metadata. + """ + _logger.info(f"Loading model from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + raise ValueError(f"No models found in model file: {model_file}") + + model = model_cfg.get("model") + if model is None: + raise ValueError(f"No trained model object found in model file: {model_file}") + + file_list = input_file_list or model_dict.get("input_file_list") + if not file_list: + raise ValueError( + "No input file list available. Provide --input_file_list or retrain with " + "input_file_list stored in the model metadata." 
+ ) + + _logger.info(f"Rebuilding training data from input file list: {file_list}") + df = load_training_data(model_dict, file_list, "stereo_analysis") + + features = model_cfg.get("features") or model_dict.get("features", []) + targets = model_dict.get("targets", ["Xoff_residual", "Yoff_residual", "E_residual"]) + if not features: + raise ValueError(f"No feature list found in model file: {model_file}") + + x_data = df[features] + y_data = df[targets] + + _logger.info("Reconstructing train/test split from model metadata") + x_train, x_test, y_train, y_test = train_test_split( + x_data, + y_data, + train_size=model_dict.get("train_test_fraction", 0.5), + random_state=model_dict.get("random_state", None), + ) + + _logger.info( + "Reconstructed split with %d training and %d test events", + len(x_train), + len(x_test), + ) + return model, x_train, y_train, x_test, y_test, features, targets, model_dict + + +def predict_unscaled_residuals(model, x_data, model_dict, target_names): + """Predict residual targets and inverse-standardize them to original scale.""" + preds_scaled = model.predict(x_data) + + target_mean_cfg = model_dict.get("target_mean") + target_std_cfg = model_dict.get("target_std") + if not target_mean_cfg or not target_std_cfg: + raise ValueError( + "Missing target standardization parameters (target_mean/target_std) in model " + "file. This diagnostic requires a residual-trained stereo model." 
+ ) + + target_mean = np.array([target_mean_cfg[target] for target in target_names], dtype=np.float64) + target_std = np.array([target_std_cfg[target] for target in target_names], dtype=np.float64) + + preds = preds_scaled * target_std + target_mean + return pd.DataFrame(preds, columns=target_names, index=x_data.index) + + +def compute_generalization_metrics(y_train, y_train_pred, y_test, y_test_pred, target_names): + """Compute train/test RMSE and relative generalization gap per target.""" + metrics = {} + + for target_name in target_names: + rmse_train = np.sqrt(mean_squared_error(y_train[target_name], y_train_pred[target_name])) + rmse_test = np.sqrt(mean_squared_error(y_test[target_name], y_test_pred[target_name])) + + if rmse_train == 0: + gap_pct = 0.0 if rmse_test == 0 else np.inf + else: + gap_pct = (rmse_test - rmse_train) / rmse_train * 100 + + # Unitless generalization ratio: >1 means worse on test than train. + if rmse_train == 0: + gen_ratio = 1.0 if rmse_test == 0 else np.inf + else: + gen_ratio = rmse_test / rmse_train + + metrics[target_name] = { + "rmse_train": float(rmse_train), + "rmse_test": float(rmse_test), + "gap_pct": float(gap_pct), + "gen_ratio": float(gen_ratio), + } + + return metrics + + +def load_cached_generalization_metrics(model_file): + """Load cached train/test RMSE summary from a model file if available.""" + _logger.info(f"Loading cached generalization metrics from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + metrics = model_cfg.get("generalization_metrics") + if not isinstance(metrics, dict) or not metrics: + _logger.warning("No cached generalization metrics found in model file") + return model_dict, None + + _logger.info("Loaded cached generalization metrics for %d targets", len(metrics)) + return model_dict, metrics + + +def compute_residual_normality_stats(y_test, y_test_pred, target_names): + 
"""Compute Gaussian fit parameters and normality tests for residuals.""" + stats_dict = {} + + for target_name in target_names: + residuals = y_test[target_name].values - y_test_pred[target_name].values + residuals_clean = residuals[~np.isnan(residuals)] + + if len(residuals_clean) == 0: + _logger.warning(f"Skipping {target_name}: no finite residuals") + continue + + # Gaussian parameters + mean = float(np.mean(residuals_clean)) + std = float(np.std(residuals_clean)) + + # Normality tests + _, p_ks = stats.kstest(residuals_clean, "norm", args=(mean, std)) + ad_result = stats.anderson(residuals_clean, dist="norm", method="interpolate") + ad_stat = float(ad_result.statistic) + # With method='interpolate', ad_result is SignificanceResult with pvalue + # (not AndersonResult, which has critical_values) + ad_crit_5 = float(ad_result.pvalue if hasattr(ad_result, "pvalue") else np.nan) + + # Skewness and kurtosis + skewness = float(stats.skew(residuals_clean)) + kurtosis = float(stats.kurtosis(residuals_clean)) + + # Quantile-Quantile test (visual) + _, (_, _, qq_r) = stats.probplot(residuals_clean, dist="norm") + qq_r2 = float(qq_r**2) + + # Outlier count + n_outliers = int(np.sum(np.abs(residuals_clean) > 3 * std)) + + stats_dict[target_name] = { + "mean": mean, + "std": std, + "p_ks": float(p_ks), + "ad_stat": ad_stat, + "ad_pvalue": ad_crit_5, + "skewness": skewness, + "kurtosis": kurtosis, + "qq_r2": qq_r2, + "n_outliers": n_outliers, + "n_samples": len(residuals_clean), + } + + return stats_dict + + +def load_cached_residual_normality_stats(model_file): + """Load cached residual normality statistics from a model file if available.""" + _logger.info(f"Loading cached residual normality stats from {model_file}") + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + normality_stats = model_cfg.get("residual_normality_stats") + if not isinstance(normality_stats, 
dict) or not normality_stats: + _logger.warning("No cached residual normality statistics found in model file") + return model_dict, None + + _logger.info("Loaded cached residual normality stats for %d targets", len(normality_stats)) + return model_dict, normality_stats + + +def load_model_and_importance(model_file, target_name=None): + """ + Load model dict and precomputed feature importance. + + Parameters + ---------- + model_file : str + Path to joblib model file. + + Returns + ------- + dict + Full model dictionary with model metadata. + dict or None + Precomputed feature importances {feature_name: importance_value} for + the selected target. + """ + _logger.info(f"Loading model from {model_file}") + + model_dict, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + _logger.warning("No models found in model file") + return model_dict, None + + shap_importance = model_cfg.get("shap_importance") + features = model_cfg.get("features") or model_dict.get("features", []) + + if shap_importance is None: + # Backward compatibility for legacy key if present. 
+ shap_importance = model_cfg.get("feature_importances") + + if shap_importance is None or not features: + _logger.warning("No cached feature importances found in model file") + return model_dict, None + + if isinstance(shap_importance, dict): + if not shap_importance: + _logger.warning("Cached SHAP importance dictionary is empty") + return model_dict, None + + selected_target = target_name or next(iter(shap_importance)) + importances = shap_importance.get(selected_target) + if importances is None: + _logger.warning( + "Target %r not found in cached SHAP importance; available targets: %s", + selected_target, + list(shap_importance.keys()), + ) + return model_dict, None + + importance_dict = dict(zip(features, importances, strict=False)) + _logger.info( + "Loaded cached SHAP importances for target %s (%d features)", + selected_target, + len(importance_dict), + ) + return model_dict, importance_dict + + importance_dict = dict(zip(features, shap_importance, strict=False)) + _logger.info("Loaded cached feature importances (%d features)", len(importance_dict)) + return model_dict, importance_dict + + +def get_cached_shap_explainer(model_file): + """ + Retrieve cached SHAP explainer if available. + + Parameters + ---------- + model_file : str + Path to joblib model file. + + Returns + ------- + shap.TreeExplainer or None + Cached SHAP explainer, or None if not available. + """ + _logger.info(f"Loading SHAP explainer from {model_file}") + _, model_cfg = _load_model_cfg(model_file) + + if model_cfg is None: + return None + + explainer = model_cfg.get("shap_explainer") + if explainer is not None: + _logger.info("Successfully loaded cached SHAP explainer") + return explainer + + _logger.warning("No cached SHAP explainer found. Will compute on-the-fly.") + return None + + +def importance_dataframe(model_file, top_n=25, target_name=None): + """ + Get feature importance as a sorted pandas DataFrame. + + Parameters + ---------- + model_file : str + Path to joblib model file. 
def validate_cached_data(model_file):
    """
    Check what data is cached in the model file.

    Parameters
    ----------
    model_file : str
        Path to joblib model file.

    Returns
    -------
    dict
        Summary of cached data availability.
    """
    model_dict, model_cfg = _load_model_cfg(model_file)
    model_cfg = model_cfg or {}

    shap_importance = model_cfg.get("shap_importance")

    # Per-target bookkeeping only applies when importances are stored as a dict.
    target_names = []
    per_target_counts = {}
    if isinstance(shap_importance, dict):
        target_names = list(shap_importance.keys())
        per_target_counts = {name: len(values) for name, values in shap_importance.items()}

    summary = {
        "has_model": "model" in model_cfg,
        "has_features": "features" in model_cfg,
        "has_shap_importance": shap_importance is not None,
        "has_generalization_metrics": "generalization_metrics" in model_cfg,
        "has_residual_normality_stats": "residual_normality_stats" in model_cfg,
        "has_feature_importances": "feature_importances" in model_cfg,  # legacy key
        "has_shap_explainer": "shap_explainer" in model_cfg,
        # Standardization scalers live on the top-level dict, not the per-model config.
        "has_target_mean": "target_mean" in model_dict,
        "has_target_std": "target_std" in model_dict,
        "n_features": len(model_cfg.get("features", [])),
        "n_targets_with_shap": len(target_names),
        "shap_targets": target_names,
        "generalization_targets": list(model_cfg.get("generalization_metrics", {}).keys()),
        "residual_normality_targets": list(model_cfg.get("residual_normality_stats", {}).keys()),
        # Missing legacy key counts as zero importances.
        "n_importances": len(model_cfg.get("feature_importances", [])),
    }
    summary["n_importances_per_target"] = per_target_counts

    return summary
+ """ + # Compute metrics on original-scale predictions mse = mean_squared_error(y_test, y_pred) - _logger.info(f"{name} Mean Squared Error (All targets): {mse:.4f}") + _logger.info(f"{name} Mean Squared Error (All targets, unscaled): {mse:.4f}") mae = mean_absolute_error(y_test, y_pred) - _logger.info(f"{name} Mean Absolute Error (All targets): {mae:.4f}") + _logger.info(f"{name} Mean Absolute Error (All targets, unscaled): {mae:.4f}") target_variance(y_test, y_pred, y_data.columns) feature_importance(model, x_cols, y_data.columns, name) + + shap_importance_dict = {} if name == "xgboost": - shap_feature_importance(model, x_test, y_data.columns) + shap_importance_dict = shap_feature_importance(model, x_test, y_data.columns) if shap_per_energy: shap_feature_importance_by_energy(model, x_test, df, y_test, y_data.columns) - df_pred = pd.DataFrame(y_pred, columns=target_features("stereo_analysis")) calculate_resolution( - df_pred, + y_pred, y_test, df, percentiles=[68, 90, 95], @@ -99,13 +118,16 @@ def evaluate_regression_model( name=name, ) + return shap_importance_dict + def target_variance(y_test, y_pred, targets): """Calculate and log variance explained per target.""" y_test_np = y_test.to_numpy() if hasattr(y_test, "to_numpy") else y_test - mse_values = np.mean((y_test_np - y_pred) ** 2, axis=0) - variance_values = np.var(y_test_np, axis=0) + # Force numpy arrays so integer indexing is positional and future-proof. 
+ mse_values = np.asarray(np.mean((y_test_np - y_pred) ** 2, axis=0)) + variance_values = np.asarray(np.var(y_test_np, axis=0)) _logger.info("--- Performance Per Target ---") for i, name in enumerate(targets): @@ -127,14 +149,40 @@ def target_variance(y_test, y_pred, targets): def calculate_resolution(y_pred, y_test, df, percentiles, log_e_min, log_e_max, n_bins, name): """Compute angular and energy resolution based on predictions.""" + # Model predicts residuals, so reconstruct full predictions and MC truth + # from residuals and DispBDT baseline + _logger.debug( + f"Evaluation: y_test indices min={y_test.index.min()}, max={y_test.index.max()}, len={len(y_test)}" + ) + _logger.debug( + f"Evaluation: df shape={df.shape}, index min={df.index.min()}, max={df.index.max()}" + ) + + disp_xoff = df.loc[y_test.index, "Xoff_weighted_bdt"].values + disp_yoff = df.loc[y_test.index, "Yoff_weighted_bdt"].values + + # Handle ErecS with proper checks for valid values + erec_s = df.loc[y_test.index, "ErecS"].values + disp_erec_log = np.where(erec_s > 0, np.log10(erec_s), np.nan) + + # Reconstruct MC truth from residuals in y_test (residual = MC_true - DispBDT) + mc_xoff_true = y_test["Xoff_residual"].values + disp_xoff + mc_yoff_true = y_test["Yoff_residual"].values + disp_yoff + mc_e0_true = y_test["E_residual"].values + disp_erec_log + + # Reconstruct predictions from residual predictions + mc_xoff_pred = y_pred["Xoff_residual"].values + disp_xoff + mc_yoff_pred = y_pred["Yoff_residual"].values + disp_yoff + mc_e0_pred = y_pred["E_residual"].values + disp_erec_log + results_df = pd.DataFrame( { - "MCxoff_true": y_test["MCxoff"].values, - "MCyoff_true": y_test["MCyoff"].values, - "MCxoff_pred": y_pred["MCxoff"].values, - "MCyoff_pred": y_pred["MCyoff"].values, - "MCe0_pred": y_pred["MCe0"].values, - "MCe0": df.loc[y_test.index, "MCe0"].values, + "MCxoff_true": mc_xoff_true, + "MCyoff_true": mc_yoff_true, + "MCxoff_pred": mc_xoff_pred, + "MCyoff_pred": mc_yoff_pred, + 
"MCe0_pred": mc_e0_pred, + "MCe0": mc_e0_true, } ) @@ -143,6 +191,12 @@ def calculate_resolution(y_pred, y_test, df, percentiles, log_e_min, log_e_max, if col in df.columns: results_df[col] = df.loc[y_test.index, col].values + # Convert ErecS to log10 space for energy resolution comparison + # (ErecS is stored in linear space in the DataFrame, but needs log10 for rel_error calc) + if "ErecS" in results_df.columns: + erec_s_val = results_df["ErecS"].values + results_df["ErecS"] = np.where(erec_s_val > 0, np.log10(erec_s_val), np.nan) + # Calculate angular resolution for BDT prediction results_df["DeltaTheta"] = np.hypot( results_df["MCxoff_true"] - results_df["MCxoff_pred"], @@ -238,7 +292,13 @@ def _log_importance_table(target_label, values, x_cols, name): def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top=25): - """Feature importance using SHAP values for native multi-target XGBoost.""" + """Feature importance using SHAP values for native multi-target XGBoost. + + Returns + ------- + dict + Dictionary mapping target names to importance arrays (per-target SHAP values). 
+ """ x_sample = x_data.sample(n=min(len(x_data), max_points), random_state=None) n_features = len(x_data.columns) n_targets = len(target_names) @@ -247,10 +307,15 @@ def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top= shap_vals = model.get_booster().predict(dmatrix, pred_contribs=True) shap_vals = shap_vals.reshape(len(x_sample), n_targets, n_features + 1) + # Store per-target importance values + importance_dict = {} + for i, target in enumerate(target_names): target_shap = shap_vals[:, i, :-1] imp = np.abs(target_shap).mean(axis=0) + importance_dict[target] = imp # Store the importance array + idx = np.argsort(imp)[::-1] _logger.info(f"=== SHAP Importance for {target} ===") @@ -258,6 +323,8 @@ def shap_feature_importance(model, x_data, target_names, max_points=1000, n_top= if j < n_features: _logger.info(f"{x_data.columns[j]:25s} {imp[j]:.6e}") + return importance_dict + def shap_feature_importance_by_energy( model, diff --git a/src/eventdisplay_ml/features.py b/src/eventdisplay_ml/features.py index d217fb7..b4818d7 100644 --- a/src/eventdisplay_ml/features.py +++ b/src/eventdisplay_ml/features.py @@ -16,7 +16,7 @@ def target_features(analysis_type): List of target feature names. 
""" if analysis_type == "stereo_analysis": - return ["MCxoff", "MCyoff", "MCe0"] # sequence matters + return ["Xoff_residual", "Yoff_residual", "E_residual"] # sequence matters if "classification" in analysis_type: return [] raise ValueError(f"Unknown analysis type: {analysis_type}") @@ -40,6 +40,7 @@ def excluded_features(analysis_type, ntel): """ if analysis_type == "stereo_analysis": return { + # Pointing corrections applied during preprocessing *[f"fpointing_dx_{i}" for i in range(ntel)], *[f"fpointing_dy_{i}" for i in range(ntel)], } @@ -130,7 +131,8 @@ def _regression_features(training): "Ycore", ] if training: - return [*target_features("stereo_analysis"), *var] + # Load MC truth values (residuals will be computed from these) + return ["MCxoff", "MCyoff", "MCe0", *var] return var diff --git a/src/eventdisplay_ml/hyper_parameters.py b/src/eventdisplay_ml/hyper_parameters.py index 364d36f..0a24979 100644 --- a/src/eventdisplay_ml/hyper_parameters.py +++ b/src/eventdisplay_ml/hyper_parameters.py @@ -10,10 +10,12 @@ "xgboost": { "model": None, "hyper_parameters": { - "n_estimators": 1000, - "learning_rate": 0.1, # Shrinkage - "max_depth": 10, - "min_child_weight": 5.0, # Equivalent to MinNodeSize=1.0% for XGBoost + "n_estimators": 10000, + "early_stopping_rounds": 50, + "eval_metric": ["rmse"], + "learning_rate": 0.02, # Shrinkage + "max_depth": 7, + "min_child_weight": 10.0, # Equivalent to MinNodeSize=1.0% for XGBoost "objective": "reg:squarederror", "n_jobs": 8, "random_state": None, diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index c6ed8f6..9937e36 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -13,7 +13,7 @@ import xgboost as xgb from sklearn.model_selection import train_test_split -from eventdisplay_ml import data_processing, features, utils +from eventdisplay_ml import data_processing, diagnostic_utils, features, utils from eventdisplay_ml.data_processing import ( energy_in_bins, 
flatten_feature_data, @@ -34,7 +34,10 @@ def save_models(model_configs): - """Save trained models to files.""" + """Save trained models to files. + + Models already have per-target SHAP importance values cached during evaluation. + """ joblib.dump( model_configs, utils.output_file_name( @@ -213,8 +216,15 @@ def load_regression_models(model_prefix, model_name): "features": model_data.get("features", []), } } + par = {} + for key in ("target_mean", "target_std"): + if key in model_data: + par[key] = model_data[key] + else: + _logger.warning("Missing '%s' in regression model file: %s", key, model_path) + _logger.info("Loaded regression model.") - return models, {} + return models, par def apply_regression_models(df, model_configs): @@ -261,9 +271,59 @@ def apply_regression_models(df, model_configs): data_processing.print_variable_statistics(flatten_data) model = model_data["model"] - preds = model.predict(flatten_data) + preds_scaled = model.predict(flatten_data) + + # Inverse transform predictions from standardized space back to original scale + # Model was trained on standardized targets (mean=0, std=1) + target_mean_cfg = model_configs.get("target_mean") + target_std_cfg = model_configs.get("target_std") + if not target_mean_cfg or not target_std_cfg: + raise ValueError( + "Missing target standardization parameters (target_mean/target_std). " + "Regenerate the regression model or load a model file that includes them." 
+ ) + + target_mean = np.array( + [ + target_mean_cfg["Xoff_residual"], + target_mean_cfg["Yoff_residual"], + target_mean_cfg["E_residual"], + ] + ) + target_std = np.array( + [ + target_std_cfg["Xoff_residual"], + target_std_cfg["Yoff_residual"], + target_std_cfg["E_residual"], + ] + ) + + # Inverse standardization: y = y_scaled * std + mean + preds = preds_scaled * target_std + target_mean + + # Model predicts residuals, so add them to DispBDT baseline + # Extract DispBDT predictions from the flattened data + disp_xoff = flatten_data["Xoff_weighted_bdt"].values + disp_yoff = flatten_data["Yoff_weighted_bdt"].values + erec_s = flatten_data["ErecS"].values + valid_erec_mask = (erec_s > 0) & np.isfinite(erec_s) + if not np.all(valid_erec_mask): + n_invalid = np.count_nonzero(~valid_erec_mask) + _logger.warning( + "Found %d events with ErecS <= 0 or non-finite during apply; " + "keeping entries but setting log10(ErecS) to NaN.", + n_invalid, + ) + # Compute log10 only for valid values to avoid RuntimeWarning + disp_erec_log = np.full_like(erec_s, np.nan, dtype=np.float64) + disp_erec_log[valid_erec_mask] = np.log10(erec_s[valid_erec_mask]) + + # Add residual predictions to baseline + pred_xoff = preds[:, 0] + disp_xoff + pred_yoff = preds[:, 1] + disp_yoff + pred_erec_log = preds[:, 2] + disp_erec_log - return preds[:, 0], preds[:, 1], preds[:, 2] + return pred_xoff, pred_yoff, pred_erec_log def apply_classification_models(df, model_configs, threshold_keys): @@ -509,31 +569,50 @@ def train_regression(df, model_configs): _logger.warning("Skipping training due to empty data.") return None - x_cols = df.columns.difference(model_configs["targets"]) + # Exclude target residuals from features + excluded_cols = set(model_configs["targets"]) + x_cols = [col for col in df.columns if col not in excluded_cols] _logger.info(f"Features ({len(x_cols)}): {', '.join(list(x_cols))}") model_configs["features"] = list(x_cols) x_data, y_data = df[x_cols], df[model_configs["targets"]] - 
# Calculate energy bin weights for balancing - bin_result = _log_energy_bin_counts(df) - sample_weights = bin_result[2] if bin_result else None - - if sample_weights is not None: - x_train, x_test, y_train, y_test, weights_train, _ = train_test_split( - x_data, - y_data, - sample_weights, - train_size=model_configs.get("train_test_fraction", 0.5), - random_state=model_configs.get("random_state", None), - ) - else: - x_train, x_test, y_train, y_test = train_test_split( - x_data, - y_data, - train_size=model_configs.get("train_test_fraction", 0.5), - random_state=model_configs.get("random_state", None), - ) - weights_train = None + # Split data first to avoid data leakage in weight computation + x_train, x_test, y_train, y_test = train_test_split( + x_data, + y_data, + train_size=model_configs.get("train_test_fraction", 0.5), + random_state=model_configs.get("random_state", None), + ) + + # Verify indices are preserved correctly + _logger.info( + f"Train indices: min={y_train.index.min()}, max={y_train.index.max()}, len={len(y_train)}" + ) + _logger.info( + f"Test indices: min={y_test.index.min()}, max={y_test.index.max()}, len={len(y_test)}" + ) + + # Calculate energy bin weights for balancing ONLY on training data + # This avoids data leakage from test set distribution + df_train = df.loc[y_train.index] + bin_result = _log_energy_bin_counts(df_train) + weights_train = bin_result[2] if bin_result else None + + # Standardize targets to prevent energy from dominating direction in multi-target learning + # Compute mean and std from training data only + y_mean = y_train.mean() + y_std = y_train.std() + + _logger.info("Target standardization (training set):") + for target in model_configs["targets"]: + _logger.info(f" {target}: mean={y_mean[target]:.6f}, std={y_std[target]:.6f}") + + y_train_scaled = (y_train - y_mean) / y_std + y_test_scaled = (y_test - y_mean) / y_std + + # Store scalers for later use during inference + model_configs["target_mean"] = y_mean.to_dict() + 
model_configs["target_std"] = y_std.to_dict() _logger.info(f"Training events: {len(x_train)}, Testing events: {len(x_test)}") if weights_train is not None: @@ -542,12 +621,60 @@ def train_regression(df, model_configs): f"std={weights_train.std():.3f})" ) + eval_set = [(x_train, y_train_scaled), (x_test, y_test_scaled)] + for name, cfg in model_configs.get("models", {}).items(): _logger.info(f"Training {name}") model = xgb.XGBRegressor(**cfg.get("hyper_parameters", {})) - model.fit(x_train, y_train, sample_weight=weights_train) - evaluate_regression_model(model, x_test, y_test, df, x_cols, y_data, name) + model.fit( + x_train, + y_train_scaled, + sample_weight=weights_train, + eval_set=eval_set, + verbose=False, + ) + _logger.info( + f"Training stopped at iteration {model.best_iteration} " + f"(best score: {model.best_score:.4f})" + ) + + y_train_pred_scaled = model.predict(x_train) + y_train_pred = pd.DataFrame( + y_train_pred_scaled * y_std.values + y_mean.values, + columns=model_configs["targets"], + index=y_train.index, + ) + + # Predict on scaled targets and inverse transform back to original scale + y_pred_scaled = model.predict(x_test) + y_pred = pd.DataFrame( + y_pred_scaled * y_std.values + y_mean.values, + columns=model_configs["targets"], + index=y_test.index, + ) + + generalization_metrics = diagnostic_utils.compute_generalization_metrics( + y_train, + y_train_pred, + y_test, + y_pred, + model_configs["targets"], + ) + + residual_normality_stats = diagnostic_utils.compute_residual_normality_stats( + y_test, + y_pred, + model_configs["targets"], + ) + + shap_importance = evaluate_regression_model( + model, x_test, y_pred, y_test, df, x_cols, y_data, name + ) cfg["model"] = model + cfg["features"] = x_cols # Store feature names for later use + cfg["generalization_metrics"] = generalization_metrics + cfg["residual_normality_stats"] = residual_normality_stats + cfg["shap_importance"] = shap_importance # Store per-target SHAP importance from evaluation return 
model_configs @@ -584,11 +711,12 @@ def train_classification(df, model_configs): ) _logger.info(f"Training events: {len(x_train)}, Testing events: {len(x_test)}") + eval_set = [(x_train, y_train), (x_test, y_test)] for name, cfg in model_configs.get("models", {}).items(): _logger.info(f"Training {name}") model = xgb.XGBClassifier(**cfg.get("hyper_parameters", {})) - model.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=True) + model.fit(x_train, y_train, eval_set=eval_set, verbose=True) evaluate_classification_model(model, x_test, y_test, full_df, x_data.columns.tolist(), name) cfg["model"] = model cfg["efficiency"] = evaluation_efficiency(name, model, x_test, y_test) @@ -607,24 +735,37 @@ def _log_energy_bin_counts(df): - counts_dict: dict mapping intervals to event counts - weights_array: np.ndarray of inverse-count weights for each event (normalized for both energy and multiplicity) - Returns None if MCe0 not found. + Returns None if E_residual not found. """ - if "MCe0" not in df: - _logger.warning("MCe0 not found; skipping energy-bin availability printout.") + # Reconstruct MC truth energy from residual + DispBDT baseline + if "E_residual" not in df or "ErecS" not in df: + _logger.warning("E_residual or ErecS not found; skipping energy-bin availability printout.") return None + # Handle ErecS with proper checks for valid values (> 0) + erec_s = df["ErecS"].values + disp_erec_log = np.where(erec_s > 0, np.log10(erec_s), np.nan) + mc_e0 = df["E_residual"].values + disp_erec_log + bins = np.linspace(_EVAL_LOG_E_MIN, _EVAL_LOG_E_MAX, _EVAL_LOG_E_BINS + 1) - categories = pd.cut(df["MCe0"], bins=bins, include_lowest=True) - counts = categories.value_counts(sort=False) + categories = pd.cut(mc_e0, bins=bins, include_lowest=True) + counts = pd.Series(categories).value_counts(sort=False).sort_index() _logger.info("Training events per energy bin (log10 E true):") for interval, count in counts.items(): _logger.info(f" {interval.left:.2f} to 
{interval.right:.2f} : {int(count)}") # Calculate inverse-count weights for balancing (events in low-count bins get higher weight) - bin_indices = pd.cut(df["MCe0"], bins=bins, include_lowest=True, labels=False) + # Bins with fewer than 10 events get zero weight (excluded from training) + bin_indices = pd.cut(mc_e0, bins=bins, include_lowest=True, labels=False) count_per_bin = counts.values - inverse_counts = 1.0 / np.maximum(count_per_bin, 1) - inverse_counts = inverse_counts / inverse_counts.mean() + # Only invert counts >= 10 to avoid divide-by-zero warning + inverse_counts = np.zeros_like(count_per_bin, dtype=np.float64) + mask = count_per_bin >= 10 + inverse_counts[mask] = 1.0 / count_per_bin[mask] + # Normalize by mean of non-zero weights only + valid_weights = inverse_counts[inverse_counts > 0] + if len(valid_weights) > 0: + inverse_counts = inverse_counts / valid_weights.mean() # Assign weight to each event based on its energy bin w_energy = np.ones(len(df), dtype=np.float32) diff --git a/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py b/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py new file mode 100644 index 0000000..04479ee --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_generalization_gap.py @@ -0,0 +1,195 @@ +r"""Generalization Ratio: Quantify overfitting gap between train and test performance. + +Uses cached train/test RMSE values written during training when available. +For older model files without cached metrics, it falls back to rebuilding the +original train/test split from the stored input metadata. 
def compute_rmse_and_gaps(model, x_train, y_train, x_test, y_test, model_dict, target_names):
    """Compute RMSE for train and test, derive generalization metrics."""
    _logger.info("Computing train and test RMSE...")

    # Predict on both splits in original (unscaled) residual space.
    split_predictions = [
        diagnostic_utils.predict_unscaled_residuals(model, split, model_dict, target_names)
        for split in (x_train, x_test)
    ]

    metrics = diagnostic_utils.compute_generalization_metrics(
        y_train,
        split_predictions[0],
        y_test,
        split_predictions[1],
        target_names,
    )

    for target in target_names:
        entry = metrics[target]
        gap = entry["gap_pct"]
        verdict = "PASS" if gap < 10 else "WARN"
        _logger.info(f"\n{target}:")
        _logger.info(f" Train RMSE: {entry['rmse_train']:.6f}")
        _logger.info(f" Test RMSE: {entry['rmse_test']:.6f}")
        _logger.info(f" Gap: {gap:.2f}%")
        _logger.info(f" Generalization: {verdict} (threshold <10%)")

    return metrics
RMSE", alpha=0.8) + axes[0].set_ylabel("RMSE") + axes[0].set_title("Training vs Test Performance") + axes[0].set_xticks(x_pos) + axes[0].set_xticklabels(targets, rotation=15) + axes[0].legend() + axes[0].grid(axis="y", alpha=0.3) + + # Subplot 2: Generalization gap + gaps = [metrics[target]["gap_pct"] for target in targets] + colors = ["green" if gap < 10 else "orange" if gap < 15 else "red" for gap in gaps] + + axes[1].bar(targets, gaps, color=colors, alpha=0.7) + axes[1].axhline( + y=10, + color="green", + linestyle="--", + linewidth=2, + label="Safe threshold (10%)", + ) + axes[1].axhline( + y=15, + color="orange", + linestyle="--", + linewidth=2, + label="Warning (15%)", + ) + axes[1].set_ylabel("Gap (%)") + axes[1].set_title("Generalization Gap: (Test-Train)/Train x 100%") + axes[1].set_xticks(x_pos) + axes[1].set_xticklabels(targets, rotation=15) + axes[1].legend() + axes[1].grid(axis="y", alpha=0.3) + + plt.tight_layout() + output_path = Path(output_dir) / "generalization_gap.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + _logger.info(f"Saved generalization plot to {output_path}") + plt.close() + + +def diagnose_overfitting(metrics): + """Summary diagnosis of overfitting status.""" + _logger.info("\n%s", "=" * 60) + _logger.info("OVERFITTING DIAGNOSIS") + _logger.info("%s", "=" * 60) + + all_gaps = [metrics[target]["gap_pct"] for target in metrics] + mean_gap = np.mean(all_gaps) + + _logger.info(f"\nMean Generalization Gap: {mean_gap:.2f}%") + + if mean_gap < 5: + status = "EXCELLENT - Model shows minimal overfitting" + elif mean_gap < 10: + status = "GOOD - Model generalization is safe" + elif mean_gap < 15: + status = "ACCEPTABLE - Minor overfitting, monitor carefully" + else: + status = "WARNING - Significant overfitting detected" + + _logger.info(f"Status: {status}") + _logger.info("\nPer-target breakdown:") + for target, data in metrics.items(): + _logger.info(f" {target}: {data['gap_pct']:.2f}% gap") + + +def main(): + """Load cached 
generalization metrics or fall back to reconstructing the split.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model_file", required=True, help="Path to trained model joblib file") + parser.add_argument( + "--input_file_list", + default=None, + help=( + "Optional override for the training input file list. If omitted, the path stored " + "in the model file is used." + ), + ) + parser.add_argument("--output_dir", default="diagnostics", help="Output directory") + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + _, metrics = diagnostic_utils.load_cached_generalization_metrics(args.model_file) + + if metrics is None: + _logger.info( + "Cached generalization metrics are unavailable; rebuilding the train/test split" + ) + model, x_train, y_train, x_test, y_test, _, target_names, model_dict = ( + diagnostic_utils.load_stereo_regression_split( + args.model_file, + args.input_file_list, + ) + ) + + metrics = compute_rmse_and_gaps( + model, + x_train, + y_train, + x_test, + y_test, + model_dict, + target_names, + ) + else: + _logger.info("Using cached generalization metrics from the model file") + + plot_generalization_metrics(metrics, args.output_dir) + diagnose_overfitting(metrics) + + +if __name__ == "__main__": + main() diff --git a/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py b/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py new file mode 100644 index 0000000..4b71fd3 --- /dev/null +++ b/src/eventdisplay_ml/scripts/diagnostic_partial_dependence.py @@ -0,0 +1,192 @@ +r"""Partial Dependence Plots (PDP): Prove the model captures physics, not chaos. + +Plots predicted residual output as a function of a single feature while holding +others constant. For stereo reconstruction, proves that the model correctly +reduces corrections for high-multiplicity events and increases them for sparse data. 
def load_data_and_model(model_file, input_file_list=None):
    """Load trained model and reconstruct the held-out test split."""
    split = diagnostic_utils.load_stereo_regression_split(model_file, input_file_list)
    # split layout: (model, x_train, y_train, x_test, y_test, features, target_names, model_dict)
    model = split[0]
    x_test = split[3]
    features = split[5]
    target_names = split[6]
    return model, x_test, features, target_names
def plot_partial_dependence(pdp_data, output_dir, target_names):
    """Create PDP plots for each feature x target combination.

    Parameters
    ----------
    pdp_data : dict
        Mapping of feature name to {"grid": ..., "pd_values": ...} as produced
        by compute_partial_dependence; "pd_values" has shape (n_targets, n_grid).
    output_dir : str or Path
        Directory where "partial_dependence.png" is written.
    target_names : list of str
        Regression target names (one subplot column per target).
    """
    _logger.info("Creating partial dependence plots...")

    features = list(pdp_data.keys())
    if not features:
        _logger.warning("No valid features available for partial dependence plotting")
        return

    # Create a grid of subplots: features x targets
    _, axes = plt.subplots(len(features), len(target_names), figsize=(15, 5 * len(features)))

    # plt.subplots returns a bare Axes (1x1), a 1-D array (one dimension is 1),
    # or a 2-D array. The previous reshape calls crashed for the 1x1 case
    # (a bare Axes has no .reshape); normalize unconditionally to 2-D instead.
    axes = np.asarray(axes).reshape(len(features), len(target_names))

    for feat_idx, feat_name in enumerate(features):
        for target_idx, target_name in enumerate(target_names):
            ax = axes[feat_idx, target_idx]

            grid = pdp_data[feat_name]["grid"]
            # pd_values shape: (n_targets, n_grid_points)
            pd_vals = pdp_data[feat_name]["pd_values"][target_idx]

            ax.plot(grid, pd_vals, linewidth=2.5, marker="o", markersize=4, color="steelblue")
            ax.fill_between(grid, pd_vals * 0.95, pd_vals * 1.05, alpha=0.2)

            ax.set_xlabel(feat_name)
            ax.set_ylabel(f"Predicted {target_name}")
            ax.set_title(f"{feat_name} -> {target_name}")
            ax.grid(alpha=0.3)

    plt.tight_layout()
    output_path = Path(output_dir) / "partial_dependence.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    _logger.info(f"Saved PDP plots to {output_path}")
    plt.close()
def main():
    """Rebuild held-out data from model metadata and run PDP diagnostics.

    Command-line entry point: parses arguments, reconstructs the held-out
    test split from the trained model file, computes partial dependence for
    the requested features, writes the plot, and logs the physics checks.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_file", required=True, help="Path to trained model joblib file")
    parser.add_argument(
        "--input_file_list",
        default=None,
        help=(
            "Optional override for the training input file list. If omitted, the path stored "
            "in the model file is used."
        ),
    )
    parser.add_argument("--output_dir", default="diagnostics", help="Output directory")
    parser.add_argument(
        "--features",
        nargs="+",
        # Defaults cover the multiplicity feature and the DispBDT baseline inputs.
        default=["DispNImages", "Xoff_weighted_bdt", "Yoff_weighted_bdt", "ErecS"],
        help="Features to plot PDP for",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Ensure the output directory exists before any plot is written.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    model, x_test, _, target_names = load_data_and_model(args.model_file, args.input_file_list)
    pdp_data = compute_partial_dependence(model, x_test, args.features)
    plot_partial_dependence(pdp_data, args.output_dir, target_names)
    diagnose_physics(pdp_data)
def compute_baseline_rmse(model, x_test, y_test, model_dict, target_names):
    """Compute baseline RMSE on unshuffled test set."""
    y_pred = diagnostic_utils.predict_unscaled_residuals(model, x_test, model_dict, target_names)

    baseline_rmse = {}
    for target in target_names:
        # Root-mean-square error of the unscaled residual predictions.
        baseline_rmse[target] = np.sqrt(mean_squared_error(y_test[target], y_pred[target]))
        _logger.info(f" Baseline RMSE ({target}): {baseline_rmse[target]:.6f}")

    return baseline_rmse, y_pred


def permutation_importance(model, x_test, y_test, baseline_rmse, target_names, model_dict):
    """Compute permutation importance for each feature."""
    _logger.info("Computing permutation importance...")

    # Same seed for every column keeps the permutation reproducible.
    seed = model_dict.get("random_state", None)
    importance = {target: {} for target in target_names}

    for column_position, column_name in enumerate(x_test.columns):
        # Shuffle one column while leaving every other feature untouched.
        x_permuted = x_test.copy()
        x_permuted.iloc[:, column_position] = (
            x_permuted.iloc[:, column_position].sample(frac=1, random_state=seed).to_numpy()
        )

        y_pred_permuted = diagnostic_utils.predict_unscaled_residuals(
            model,
            x_permuted,
            model_dict,
            target_names,
        )

        for target in target_names:
            shuffled_rmse = np.sqrt(
                mean_squared_error(y_test[target], y_pred_permuted[target])
            )
            # Importance = relative RMSE increase caused by destroying this feature.
            importance[target][column_name] = (
                shuffled_rmse - baseline_rmse[target]
            ) / baseline_rmse[target]

    return importance
def plot_permutation_importance(importance, output_path, top_n=20):
    """Plot permutation importance for each target."""
    _logger.info(f"Creating permutation importance plots (top {top_n})...")

    target_list = list(importance)
    _, axes = plt.subplots(1, len(target_list), figsize=(5 * len(target_list), 8))
    if len(target_list) == 1:
        axes = [axes]

    for panel, target in zip(axes, target_list):
        # Rank features by importance and keep only the strongest top_n.
        ranked = (
            pd.DataFrame(
                list(importance[target].items()),
                columns=["feature", "importance"],
            )
            .sort_values("importance", ascending=True)
            .tail(top_n)
        )
        # Negative importance (shuffling helped) is drawn in red.
        bar_colors = ["red" if value < 0 else "green" for value in ranked["importance"]]

        panel.barh(ranked["feature"], ranked["importance"] * 100.0, color=bar_colors, alpha=0.7)
        panel.set_xlabel("Relative RMSE Increase (%)")
        panel.set_title(f"Permutation Importance\n{target}")
        panel.axvline(x=0, color="black", linestyle="--", linewidth=1)
        panel.grid(axis="x", alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    _logger.info(f"Saved permutation importance plot to {output_path}")
    plt.close()


def diagnose_baseline_anchoring(importance):
    """Check if model is anchored in conventional baselines."""
    _logger.info("=== Physics Check: Baseline Anchoring ===")

    baseline_features = ["Xoff_weighted_bdt", "Yoff_weighted_bdt", "ErecS"]

    for target, target_importance in importance.items():
        _logger.info(f"\n{target}:")

        # Share of the positive importance carried by the baseline features.
        anchored = sum(
            target_importance.get(name, 0)
            for name in baseline_features
            if name in target_importance
        )
        positive_total = sum(value for value in target_importance.values() if value > 0)
        anchor_pct = (anchored / positive_total * 100) if positive_total > 0 else 0

        _logger.info(f" Baseline features contribution: {anchor_pct:.1f}%")
        _logger.info(" (Expect >70% for well-anchored model)")

        for name in baseline_features:
            if name in target_importance:
                _logger.info(f" {name}: {target_importance[name] * 100:.2f}%")


def main():
    """Rebuild the held-out test split and compute permutation importance."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_file", required=True, help="Path to trained model joblib file")
    parser.add_argument(
        "--input_file_list",
        default=None,
        help=(
            "Optional override for the training input file list. If omitted, the path stored "
            "in the model file is used."
        ),
    )
    parser.add_argument("--output_dir", default="diagnostics", help="Output directory")
    parser.add_argument("--top_n", type=int, default=20, help="Top N features to plot")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    model, _, _, x_test, y_test, _, target_names, model_dict = (
        diagnostic_utils.load_stereo_regression_split(
            args.model_file,
            args.input_file_list,
        )
    )

    _logger.info("Computing baseline RMSE...")
    baseline_rmse, _ = compute_baseline_rmse(model, x_test, y_test, model_dict, target_names)

    importance = permutation_importance(
        model,
        x_test,
        y_test,
        baseline_rmse,
        target_names,
        model_dict,
    )

    plot_permutation_importance(
        importance, Path(args.output_dir) / "permutation_importance.png", args.top_n
    )
    diagnose_baseline_anchoring(importance)


if __name__ == "__main__":
    main()
r"""Residual Normality & Outlier Check: Validate statistical quality of predictions.

Tests if residuals are Gaussian and centered at zero. Non-normal residuals indicate
the model is failing on specific event types (e.g., edge-of-camera or low-multiplicity).
Heavy tails indicate outliers; skewness indicates systematic bias.

Usage:
    eventdisplay-ml-diagnostic-residual-normality \\
        --model_file trained_stereo.joblib \\
        --output_dir diagnostics/
"""

import argparse
import logging
from pathlib import Path

import numpy as np
from scipy import stats

_logger = logging.getLogger(__name__)


def load_predictions_and_targets(model_file, input_file_list=None):
    """Rebuild test split and compute predicted residuals.

    Parameters
    ----------
    model_file : str
        Path to the trained model joblib file.
    input_file_list : str, optional
        Override for the training input file list stored in the model file.

    Returns
    -------
    tuple
        (predicted residuals, true residuals, target names).
    """
    # Imported lazily so the pure-statistics helpers below remain usable
    # without the full eventdisplay_ml stack.
    from eventdisplay_ml import diagnostic_utils

    model, _, _, x_test, y_test, _, target_names, model_dict = (
        diagnostic_utils.load_stereo_regression_split(
            model_file,
            input_file_list,
        )
    )
    y_pred = diagnostic_utils.predict_unscaled_residuals(model, x_test, model_dict, target_names)
    return y_pred, y_test, target_names


def compute_residuals(y_pred, y_true, target_names):
    """Compute residuals (true - predicted). Standard ML residual convention."""
    residuals = {}

    for target_name in target_names:
        # Only targets present in both frames contribute; others are skipped.
        if target_name in y_true.columns and target_name in y_pred.columns:
            residuals[target_name] = y_true[target_name].values - y_pred[target_name].values

    return residuals


def compute_normality_stats(residuals):
    """Compute Gaussian fit parameters and normality tests.

    Parameters
    ----------
    residuals : dict
        Mapping of target name to residual array.

    Returns
    -------
    dict
        Per-target statistics (Gaussian fit, test statistics, outlier counts).
    """
    stats_dict = {}

    for target_name, resid in residuals.items():
        # Remove NaN values so the statistical tests only see finite entries.
        resid_clean = resid[~np.isnan(resid)]
        if len(resid_clean) == 0:
            _logger.warning(f"Skipping {target_name}: no finite residuals")
            continue

        # Gaussian parameters
        mean = np.mean(resid_clean)
        std = np.std(resid_clean)

        # Kolmogorov-Smirnov test against the fitted normal distribution.
        _, p_ks = stats.kstest(resid_clean, "norm", args=(mean, std))

        # BUG FIX: scipy.stats.anderson() has no 'method' keyword (that option
        # belongs to anderson_ksamp), so the previous call raised TypeError.
        # AndersonResult carries a statistic plus critical values, not a
        # p-value; keep NaN so the downstream report can fall back gracefully.
        ad_result = stats.anderson(resid_clean, dist="norm")
        ad_stat = ad_result.statistic
        ad_pvalue = getattr(ad_result, "pvalue", np.nan)

        # Skewness and kurtosis
        skewness = stats.skew(resid_clean)
        kurtosis = stats.kurtosis(resid_clean)

        # Quantile-Quantile test (visual); qq_r2 is the R^2 of the probability plot.
        _, (_, _, qq_r) = stats.probplot(resid_clean, dist="norm")
        qq_r2 = qq_r**2

        stats_dict[target_name] = {
            "mean": mean,
            "std": std,
            "p_ks": p_ks,
            "ad_stat": ad_stat,
            "ad_pvalue": ad_pvalue,
            "skewness": skewness,
            "kurtosis": kurtosis,
            "qq_r2": qq_r2,
            "n_outliers": np.sum(np.abs(resid_clean) > 3 * std),
            "n_samples": len(resid_clean),
        }

    return stats_dict


def plot_residual_diagnostics(residuals, stats_dict, output_dir):
    """Create comprehensive residual diagnostic plots."""
    # Imported lazily so this module stays importable on headless systems
    # when only the statistics are needed.
    import matplotlib.pyplot as plt

    _logger.info("Creating residual diagnostic plots...")

    target_names = list(residuals.keys())
    if not target_names:
        _logger.warning("No residual targets to plot")
        return

    # Create a 2xN grid: histogram + Q-Q plot for each target
    _, axes = plt.subplots(2, len(target_names), figsize=(5 * len(target_names), 10))

    if len(target_names) == 1:
        axes = axes.reshape(2, 1)

    for col_idx, target_name in enumerate(target_names):
        resid = residuals[target_name][~np.isnan(residuals[target_name])]
        stat = stats_dict[target_name]

        # Row 0: Histogram with Gaussian overlay
        ax_hist = axes[0, col_idx]
        ax_hist.hist(resid, bins=50, density=True, alpha=0.7, color="steelblue", edgecolor="black")

        # Overlay Gaussian
        x_range = np.linspace(resid.min(), resid.max(), 100)
        gaussian = stats.norm.pdf(x_range, stat["mean"], stat["std"])
        ax_hist.plot(x_range, gaussian, "r-", linewidth=2, label="Normal fit")

        ax_hist.axvline(stat["mean"], color="green", linestyle="--", linewidth=2, label="Mean")
        ax_hist.set_xlabel("Residual value")
        ax_hist.set_ylabel("Density")
        ax_hist.set_title(f"{target_name}\nmu={stat['mean']:.4f}, sigma={stat['std']:.4f}")
        ax_hist.legend(fontsize=8)
        ax_hist.grid(alpha=0.3)

        # Row 1: Q-Q plot
        ax_qq = axes[1, col_idx]
        stats.probplot(resid, dist="norm", plot=ax_qq)
        ax_qq.set_title(f"Q-Q Plot (R²={stat['qq_r2']:.4f})")
        ax_qq.grid(alpha=0.3)

    plt.tight_layout()
    output_path = Path(output_dir) / "residual_diagnostics.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    _logger.info(f"Saved residual diagnostics to {output_path}")
    plt.close()


def diagnose_residual_quality(residuals, stats_dict):
    """Provide detailed diagnosis of residual quality.

    Parameters
    ----------
    residuals : dict
        Residual arrays per target (currently unused; kept for interface stability).
    stats_dict : dict
        Per-target statistics from compute_normality_stats().
    """
    _logger.info("\n%s", "=" * 60)
    _logger.info("RESIDUAL NORMALITY & OUTLIER ANALYSIS")
    _logger.info("%s", "=" * 60)

    for target_name, stat in stats_dict.items():
        _logger.info(f"\n{target_name}:")
        _logger.info(f" Mean: {stat['mean']:.6f} (expect ~0)")
        _logger.info(f" Std: {stat['std']:.6f}")

        if np.abs(stat["mean"]) < stat["std"] * 0.1:
            _logger.info(" GOOD - Residuals centered at zero")
        elif np.abs(stat["mean"]) < stat["std"] * 0.2:
            _logger.info(" OK - Small systematic offset")
        else:
            _logger.info(" WARNING - Significant bias detected")

        _logger.info(f"\n Skewness: {stat['skewness']:.4f}")
        if np.abs(stat["skewness"]) < 0.2:
            _logger.info(" GOOD - Symmetric distribution")
        elif np.abs(stat["skewness"]) < 0.5:
            _logger.info(" OK - Mild asymmetry")
        else:
            _logger.info(" WARNING - Strong skew (model failing on certain events)")

        _logger.info(f"\n Kurtosis: {stat['kurtosis']:.4f}")
        if np.abs(stat["kurtosis"]) < 0.5:
            _logger.info(" GOOD - Gaussian-like tails")
        elif np.abs(stat["kurtosis"]) < 1.0:
            _logger.info(" OK - Slightly heavy/light tails")
        else:
            _logger.info(" WARNING - Heavy tails (outliers present)")

        _logger.info(f"\n Outliers (>3sigma): {stat['n_outliers']} events")
        outlier_pct = stat["n_outliers"] / stat["n_samples"] * 100
        if outlier_pct < 0.3:
            _logger.info(" GOOD - Minimal outliers")
        elif outlier_pct < 1.0:
            _logger.info(" OK - Few outliers")
        else:
            _logger.info(" WARNING - Excessive outliers")

        _logger.info(f"\n Kolmogorov-Smirnov test: p={stat['p_ks']:.4f}")
        if stat["p_ks"] > 0.05:
            _logger.info(" Gaussian hypothesis NOT rejected (p > 0.05)")
        else:
            _logger.info(" Distribution deviates from Gaussian (p < 0.05)")

        _logger.info(
            f"\n Anderson-Darling normality: stat={stat['ad_stat']:.4f}, "
            f"p-value={stat['ad_pvalue']:.4f}"
        )
        # scipy's one-sample Anderson-Darling test reports critical values
        # rather than a p-value, so NaN is the expected case here.
        if np.isnan(stat["ad_pvalue"]):
            _logger.info(" Anderson-Darling p-value unavailable (compare statistic to critical values)")
        elif stat["ad_pvalue"] > 0.05:
            _logger.info(" Anderson-Darling does not reject normality (p > 0.05)")
        else:
            _logger.info(" Anderson-Darling rejects normality (p < 0.05)")

        _logger.info(f"\n Q-Q R²: {stat['qq_r2']:.4f}")
        if stat["qq_r2"] > 0.98:
            _logger.info(" Excellent Gaussian fit")
        elif stat["qq_r2"] > 0.95:
            _logger.info(" Good Gaussian fit")
        else:
            _logger.info(" Fair fit, consider investigating tails")


def main():
    """Load cached residual normality stats or fall back to reconstructing the split."""
    # Lazy import: keeps the statistical helpers importable without the
    # eventdisplay_ml package installed.
    from eventdisplay_ml import diagnostic_utils

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_file", required=True, help="Path to trained model joblib file")
    parser.add_argument(
        "--input_file_list",
        default=None,
        help=(
            "Optional override for the training input file list. If omitted, the path stored "
            "in the model file is used."
        ),
    )
    parser.add_argument("--output_dir", default="diagnostics", help="Output directory")

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    _, stats_dict = diagnostic_utils.load_cached_residual_normality_stats(args.model_file)

    residuals = None
    if stats_dict is None:
        _logger.info("Cached residual normality statistics unavailable; rebuilding from test split")
        y_pred, y_true, target_names = load_predictions_and_targets(
            args.model_file,
            args.input_file_list,
        )
        residuals = compute_residuals(y_pred, y_true, target_names)
        stats_dict = compute_normality_stats(residuals)
        plot_residual_diagnostics(residuals, stats_dict, args.output_dir)
    else:
        _logger.info("Using cached residual normality statistics from the model file")
        _logger.info(
            "Note: Diagnostic plots skipped when using cached statistics; "
            "rerun without cache to regenerate plots"
        )

    # diagnose_residual_quality only reads stats_dict; residuals may be {}.
    diagnose_residual_quality(residuals or {}, stats_dict)


if __name__ == "__main__":
    main()
r"""SHAP Feature Importance: Show cached feature importances from training.

Displays the top 20 features for each reconstruction target (Xoff, Yoff, Energy)
using XGBoost native feature importances cached during training.

This script requires no test data - it reads directly from the cached importance
values stored in the model file during training.

Usage:
    python diagnostic_shap_summary.py \\
        --model_file stereo_model_Xoff_residual.joblib \\
        --output_dir diagnostics/
"""

import argparse
import logging
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

_logger = logging.getLogger(__name__)


def load_model_config(model_file):
    """Load model configuration with cached feature importances."""
    _logger.info(f"Loading model from {model_file}")
    model_dict = joblib.load(model_file)

    models = model_dict.get("models")
    # Reject files that do not carry the expected top-level structure.
    if not isinstance(models, dict) or not models:
        raise ValueError(
            "Invalid model file structure: expected a non-empty 'models' mapping in "
            f"{model_file!r}. The file should contain a top-level dictionary with a "
            "'models' key mapping target names to model configurations."
        )

    # The first entry carries the cached feature/importance metadata.
    return next(iter(models.values())), model_dict


def plot_feature_importance(features, importances, target_name, output_dir):
    """Create feature importance bar plot for a single target.

    Parameters
    ----------
    features : list
        Feature names.
    importances : array
        Importance values.
    target_name : str
        Name of the target (e.g., "Xoff_residual").
    output_dir : str
        Output directory for plot.
    """
    ranked = (
        pd.DataFrame({"Feature": features, "Importance": importances})
        .sort_values("Importance", ascending=True)
        .tail(20)
    )

    _, ax = plt.subplots(figsize=(10, 8))

    bar_colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(ranked)))
    ax.barh(ranked["Feature"], ranked["Importance"], color=bar_colors)
    ax.set_xlabel("XGBoost Importance Score", fontsize=12, fontweight="bold")
    ax.set_title(f"Top 20 Feature Importances: {target_name}", fontsize=14, fontweight="bold")
    ax.grid(axis="x", alpha=0.3)

    plt.tight_layout()
    output_path = Path(output_dir) / f"shap_importance_{target_name}.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    _logger.info(f"Saved {target_name} importance plot to {output_path}")
    plt.close()

    # Log the ten most important features, highest first.
    _logger.info(f"\n=== Top 10 Features for {target_name} ===")
    for row in ranked.tail(10).iloc[::-1].itertuples(index=False):
        _logger.info(f" {row.Feature:35s} {row.Importance:.6f}")


def main():
    """Load cached SHAP importance from model file and create plots for each target."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_file", required=True, help="Path to trained model joblib file")
    parser.add_argument("--output_dir", default="diagnostics", help="Output directory for plots")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    _logger.info("=== SHAP Feature Importance Summary ===")

    cfg, _ = load_model_config(args.model_file)
    shap_importance = cfg.get("shap_importance")
    features = cfg.get("features")

    if shap_importance is None:
        _logger.error("ERROR: No cached SHAP importance found in model file!")
        _logger.error("Make sure the model was trained with the updated code.")
        return

    if features is None:
        _logger.error("ERROR: No feature list found in model file!")
        return

    _logger.info(f"Loaded {len(features)} features from cache")
    _logger.info(f"Found per-target SHAP importance for: {list(shap_importance.keys())}")

    # One plot per target, using the importance values cached at training time.
    for target_name, importances in shap_importance.items():
        _logger.info(f"\nProcessing {target_name}...")
        plot_feature_importance(features, importances, target_name, args.output_dir)

    _logger.info(f"\nPlots saved to {args.output_dir}")


if __name__ == "__main__":
    main()
def plot_training_curves(evals_result, output_file=None):
    """
    Plot training and validation curves from XGBoost evaluation results.

    Parameters
    ----------
    evals_result : dict
        Dictionary containing evaluation results from XGBoost model.
        Expected format: {'validation_0': {'rmse': [...]}, 'validation_1': {'rmse': [...]}}
    output_file : str or Path, optional
        Path to save the output figure. If None, display interactively.
    """
    if not evals_result:
        _logger.warning("No evaluation results found in model.")
        return

    # Typically two tracked datasets: training and test.
    dataset_names = list(evals_result.keys())
    _logger.info(f"Found {len(dataset_names)} evaluation datasets: {dataset_names}")

    # All metrics tracked for the first dataset define the subplot layout.
    metrics = list(evals_result[dataset_names[0]].keys())
    _logger.info(f"Metrics tracked: {metrics}")

    _, axes = plt.subplots(len(metrics), 1, figsize=(10, 6 * len(metrics)), squeeze=False)
    axes = axes.flatten()

    colors = ["blue", "red", "green", "orange", "purple", "brown"]
    labels = {
        "validation_0": "Training",
        "validation_1": "Test",
        "train": "Training",
        "test": "Test",
    }

    for metric_idx, metric in enumerate(metrics):
        ax = axes[metric_idx]

        for dataset_idx, dataset_name in enumerate(dataset_names):
            if metric not in evals_result[dataset_name]:
                continue

            values = evals_result[dataset_name][metric]
            ax.plot(
                np.arange(1, len(values) + 1),
                values,
                label=labels.get(dataset_name, dataset_name),
                color=colors[dataset_idx % len(colors)],
                linewidth=2,
                alpha=0.8,
            )

        ax.set_xlabel("Boosting Round (Iteration)", fontsize=12)
        ax.set_ylabel(metric.upper(), fontsize=12)
        ax.set_title(f"Training Progress: {metric.upper()}", fontsize=14, fontweight="bold")
        ax.legend(loc="best", fontsize=10)
        ax.grid(True, alpha=0.3)

        # Log scale if values span multiple orders of magnitude.
        # NOTE(review): mirrors the original behavior of inspecting only the
        # series from the last dataset plotted — confirm this is intended.
        if len(values) > 0:
            value_range = np.max(values) / (np.min(values) + 1e-10)
            if value_range > 100:
                ax.set_yscale("log")
                _logger.info(f"Using log scale for {metric} (value range: {value_range:.1f})")

    plt.tight_layout()

    if output_file:
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        _logger.info(f"Figure saved to: {output_path}")
    else:
        plt.show()

    plt.close()


def main():
    """Plot XGBoost training evaluation."""
    parser = argparse.ArgumentParser(
        description=(
            "Plot XGBoost training evaluation metrics from trained model "
            "(stereo or classification)."
        )
    )
    parser.add_argument(
        "--model_file",
        required=True,
        type=str,
        help="Path to the trained model joblib file (e.g., dispdir_bdt.joblib).",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Path to save the output plot (PNG/PDF). If not provided, display interactively.",
    )
    args = parser.parse_args()

    model_path = Path(args.model_file)
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    _logger.info(f"Loading model from: {model_path}")
    model_configs = joblib.load(model_path)

    # Validate the stored structure before reaching into it.
    if "models" not in model_configs:
        raise ValueError("Model file does not contain 'models' key.")
    if "xgboost" not in model_configs["models"]:
        raise ValueError("Model file does not contain 'xgboost' model.")

    xgb_model = model_configs["models"]["xgboost"]["model"]

    if not hasattr(xgb_model, "evals_result"):
        raise AttributeError(
            "XGBoost model does not have 'evals_result' method. "
            "Model may not have been trained with eval_set parameter."
        )

    _logger.info(f"Model type: {type(xgb_model).__name__}")
    _logger.info(f"Number of boosting rounds: {xgb_model.get_booster().num_boosted_rounds()}")

    # Additional model info
    if "features" in model_configs:
        _logger.info(f"Number of features: {len(model_configs['features'])}")
    if "targets" in model_configs:
        _logger.info(f"Target variables: {model_configs['targets']}")

    plot_training_curves(xgb_model.evals_result(), args.output_file)

    _logger.info("Plotting completed successfully.")


if __name__ == "__main__":
    main()
""" diff --git a/tests/test_regression_apply.py b/tests/test_regression_apply.py new file mode 100644 index 0000000..3d73743 --- /dev/null +++ b/tests/test_regression_apply.py @@ -0,0 +1,322 @@ +"""Comprehensive tests for regression model apply, residual computation, and standardization.""" + +import numpy as np +import pandas as pd +import pytest + +from eventdisplay_ml import models + + +class DummyXGBRegressor: + """Simple XGBoost-like model that returns fixed scaled residuals.""" + + def __init__(self, preds_scaled): + self._preds_scaled = np.asarray(preds_scaled, dtype=np.float64) + + def predict(self, _x): + """Return fixed predictions regardless of input.""" + return self._preds_scaled + + +class TestApplyRegressionStandardizationInversion: + """Test standardization inversion in apply_regression_models.""" + + def test_apply_regression_inverts_standardization(self, monkeypatch): + """Verify that predicted residuals are correctly inverted from standardized space.""" + df_flat = pd.DataFrame( + { + "Xoff_weighted_bdt": [10.0, 20.0], + "Yoff_weighted_bdt": [30.0, 40.0], + "ErecS": [100.0, 1000.0], + } + ) + + # Model returns fixed scaled predictions (in standardized space: mean=0, std=1) + preds_scaled = np.array( + [ + [ + 0.0, + 1.0, + -1.0, + ], # Event 0: Xoff_residual_scaled=0, Yoff_residual_scaled=1, E_residual_scaled=-1 + [ + 1.0, + 0.0, + 2.0, + ], # Event 1: Xoff_residual_scaled=1, Yoff_residual_scaled=0, E_residual_scaled=2 + ] + ) + + model_configs = { + "models": { + "xgboost": {"model": DummyXGBRegressor(preds_scaled), "features": df_flat.columns} + }, + "target_mean": { + "Xoff_residual": 1.0, + "Yoff_residual": 2.0, + "E_residual": 0.5, + }, + "target_std": { + "Xoff_residual": 2.0, + "Yoff_residual": 3.0, + "E_residual": 0.1, + }, + } + + def _mock_flatten(*_args, **_kwargs): + return df_flat + + monkeypatch.setattr(models, "flatten_feature_data", _mock_flatten) + monkeypatch.setattr(models.data_processing, "print_variable_statistics", lambda 
*_: None) + + pred_xoff, pred_yoff, pred_erec_log = models.apply_regression_models( + pd.DataFrame({"dummy": [0, 1]}), model_configs + ) + + # Inverse transform: y = y_scaled * std + mean + # Event 0: Xoff = 0.0 * 2.0 + 1.0 = 1.0, then add baseline: 1.0 + 10.0 = 11.0 + # Event 1: Xoff = 1.0 * 2.0 + 1.0 = 3.0, then add baseline: 3.0 + 20.0 = 23.0 + expected_xoff = np.array([11.0, 23.0]) + + # Event 0: Yoff = 1.0 * 3.0 + 2.0 = 5.0, then add baseline: 5.0 + 30.0 = 35.0 + # Event 1: Yoff = 0.0 * 3.0 + 2.0 = 2.0, then add baseline: 2.0 + 40.0 = 42.0 + expected_yoff = np.array([35.0, 42.0]) + + # disp_erec_log = log10([100, 1000]) = [2, 3] + # Event 0: E_residual = -1.0 * 0.1 + 0.5 = 0.4, then add baseline: 0.4 + 2.0 = 2.4 + # Event 1: E_residual = 2.0 * 0.1 + 0.5 = 0.7, then add baseline: 0.7 + 3.0 = 3.7 + expected_erec_log = np.array([2.4, 3.7]) + + np.testing.assert_allclose(pred_xoff, expected_xoff, rtol=0, atol=1e-8) + np.testing.assert_allclose(pred_yoff, expected_yoff, rtol=0, atol=1e-8) + np.testing.assert_allclose(pred_erec_log, expected_erec_log, rtol=0, atol=1e-8) + + def test_apply_regression_missing_standardization_params(self, monkeypatch): + """Verify that missing target_mean/target_std raises clear error.""" + df_flat = pd.DataFrame( + { + "Xoff_weighted_bdt": [10.0], + "Yoff_weighted_bdt": [30.0], + "ErecS": [100.0], + } + ) + + model_configs = { + "models": { + "xgboost": { + "model": DummyXGBRegressor(np.array([[0.0, 1.0, -1.0]])), + "features": df_flat.columns, + } + }, + # Missing target_mean and target_std + } + + def _mock_flatten(*_args, **_kwargs): + return df_flat + + monkeypatch.setattr(models, "flatten_feature_data", _mock_flatten) + monkeypatch.setattr(models.data_processing, "print_variable_statistics", lambda *_: None) + + with pytest.raises(ValueError, match="Missing target standardization parameters"): + models.apply_regression_models(pd.DataFrame({"dummy": [0]}), model_configs) + + +class TestApplyRegressionErecSHandling: + """Test ErecS 
validation and log10 computation in apply.""" + + def test_apply_regression_handles_invalid_erecs(self, monkeypatch): + """Verify that invalid ErecS values (<=0 or NaN) are handled gracefully.""" + df_flat = pd.DataFrame( + { + "Xoff_weighted_bdt": [10.0, 20.0, 30.0], + "Yoff_weighted_bdt": [30.0, 40.0, 50.0], + "ErecS": [100.0, -5.0, np.nan], # 2nd event: negative, 3rd event: NaN + } + ) + + preds_scaled = np.array( + [ + [0.0, 1.0, -1.0], + [1.0, 0.0, 2.0], + [0.5, 0.5, 0.0], + ] + ) + + model_configs = { + "models": { + "xgboost": {"model": DummyXGBRegressor(preds_scaled), "features": df_flat.columns} + }, + "target_mean": { + "Xoff_residual": 1.0, + "Yoff_residual": 2.0, + "E_residual": 0.5, + }, + "target_std": { + "Xoff_residual": 2.0, + "Yoff_residual": 3.0, + "E_residual": 0.1, + }, + } + + def _mock_flatten(*_args, **_kwargs): + return df_flat + + monkeypatch.setattr(models, "flatten_feature_data", _mock_flatten) + monkeypatch.setattr(models.data_processing, "print_variable_statistics", lambda *_: None) + + pred_xoff, pred_yoff, pred_erec_log = models.apply_regression_models( + pd.DataFrame({"dummy": [0, 1, 2]}), model_configs + ) + + # Event 0: ErecS valid, log10(100) = 2, plus residual transform => 2.4 + assert np.isfinite(pred_erec_log[0]) + assert np.isclose(pred_erec_log[0], 2.4) + + # Events 1 and 2: ErecS invalid, should be NaN + assert np.isnan(pred_erec_log[1]), "Event with negative ErecS should produce NaN" + assert np.isnan(pred_erec_log[2]), "Event with NaN ErecS should produce NaN" + + # Xoff and Yoff should still be valid for all events + assert len(pred_xoff) == 3 + assert len(pred_yoff) == 3 + assert all(np.isfinite(pred_xoff)) + assert all(np.isfinite(pred_yoff)) + + def test_apply_regression_output_length_matches_input(self, monkeypatch): + """Verify that output arrays have same length as input, even with invalid ErecS.""" + n_events = 100 + rng = np.random.default_rng(42) + df_flat = pd.DataFrame( + { + "Xoff_weighted_bdt": 
rng.uniform(0, 10, n_events), + "Yoff_weighted_bdt": rng.uniform(0, 10, n_events), + "ErecS": np.where( + rng.uniform(0, 1, n_events) > 0.2, + rng.uniform(10, 110, n_events), # 80% valid + np.nan, # 20% NaN + ), + } + ) + + preds_scaled = rng.standard_normal((n_events, 3)) + + model_configs = { + "models": { + "xgboost": {"model": DummyXGBRegressor(preds_scaled), "features": df_flat.columns} + }, + "target_mean": { + "Xoff_residual": 0.0, + "Yoff_residual": 0.0, + "E_residual": 0.0, + }, + "target_std": { + "Xoff_residual": 1.0, + "Yoff_residual": 1.0, + "E_residual": 1.0, + }, + } + + def _mock_flatten(*_args, **_kwargs): + return df_flat + + monkeypatch.setattr(models, "flatten_feature_data", _mock_flatten) + monkeypatch.setattr(models.data_processing, "print_variable_statistics", lambda *_: None) + + pred_xoff, pred_yoff, pred_erec_log = models.apply_regression_models( + pd.DataFrame({"dummy": np.arange(n_events)}), model_configs + ) + + assert len(pred_xoff) == n_events, "Output length should match input length" + assert len(pred_yoff) == n_events, "Output length should match input length" + assert len(pred_erec_log) == n_events, "Output length should match input length" + + +class TestResidualComputation: + """Test residual computation during training (the basis for apply predictions).""" + + def test_residual_computation_from_mc_and_baseline(self): + """Verify residuals are computed correctly as MC_true - baseline.""" + mc_xoff = np.array([1.0, 2.0, 3.0]) + mc_yoff = np.array([4.0, 5.0, 6.0]) + mc_e0_log = np.array([0.0, 1.0, 2.0]) + + baseline_xoff = np.array([0.5, 1.5, 2.5]) + baseline_yoff = np.array([3.5, 4.5, 5.5]) + baseline_erec_log = np.array([-0.5, 0.5, 1.5]) + + # Residuals should be MC - baseline + xoff_residual = mc_xoff - baseline_xoff + yoff_residual = mc_yoff - baseline_yoff + e_residual = mc_e0_log - baseline_erec_log + + expected_xoff_residual = np.array([0.5, 0.5, 0.5]) + expected_yoff_residual = np.array([0.5, 0.5, 0.5]) + 
expected_e_residual = np.array([0.5, 0.5, 0.5]) + + np.testing.assert_allclose(xoff_residual, expected_xoff_residual) + np.testing.assert_allclose(yoff_residual, expected_yoff_residual) + np.testing.assert_allclose(e_residual, expected_e_residual) + + def test_residual_standardization(self): + """Verify residuals standardize correctly (mean=0, std=1).""" + residuals = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + mean = residuals.mean() + std = residuals.std() + + residuals_scaled = (residuals - mean) / std + + assert np.isclose(residuals_scaled.mean(), 0.0, atol=1e-10) + assert np.isclose(residuals_scaled.std(), 1.0, atol=1e-10) + + def test_residual_reconstruction_after_standardization(self): + """Verify that residuals can be reconstructed from standardized predictions.""" + original_residuals = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + mean = original_residuals.mean() + std = original_residuals.std() + + # Training standardize + scaled_residuals = (original_residuals - mean) / std + + # Apply model predicts scaled residuals, inverse transform + predicted_scaled = scaled_residuals # Assume perfect prediction + reconstructed_residuals = predicted_scaled * std + mean + + np.testing.assert_allclose(reconstructed_residuals, original_residuals, rtol=1e-10) + + +class TestFinalPredictionReconstruction: + """Test that final predictions correctly reconstruct from residuals + baselines.""" + + def test_final_direction_reconstruction(self): + """Verify direction predictions = baseline + residual.""" + baseline_xoff = np.array([1.0, 2.0, 3.0]) + baseline_yoff = np.array([4.0, 5.0, 6.0]) + pred_xoff_residual = np.array([0.5, 0.3, 0.2]) + pred_yoff_residual = np.array([0.1, 0.2, 0.3]) + + # Final prediction should be baseline + residual + final_xoff = baseline_xoff + pred_xoff_residual + final_yoff = baseline_yoff + pred_yoff_residual + + expected_xoff = np.array([1.5, 2.3, 3.2]) + expected_yoff = np.array([4.1, 5.2, 6.3]) + + np.testing.assert_allclose(final_xoff, expected_xoff) + 
np.testing.assert_allclose(final_yoff, expected_yoff) + + def test_final_energy_reconstruction(self): + """Verify energy predictions = baseline + residual (in log10 space).""" + baseline_erec_log = np.array([2.0, 3.0, 4.0]) # log10(100), log10(1000), log10(10000) + pred_erec_log_residual = np.array([0.1, -0.2, 0.3]) + + # Final log10 energy + final_erec_log = baseline_erec_log + pred_erec_log_residual + + # Convert to linear energy + final_erec = np.power(10.0, final_erec_log) + + expected_erec_log = np.array([2.1, 2.8, 4.3]) + expected_erec = np.array([10**2.1, 10**2.8, 10**4.3]) + + np.testing.assert_allclose(final_erec_log, expected_erec_log) + np.testing.assert_allclose(final_erec, expected_erec) diff --git a/tests/test_train_regression_standardization.py b/tests/test_train_regression_standardization.py new file mode 100644 index 0000000..2a389d3 --- /dev/null +++ b/tests/test_train_regression_standardization.py @@ -0,0 +1,372 @@ +"""Tests for target standardization and energy-bin weighting in train_regression().""" + +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +from eventdisplay_ml import diagnostic_utils, models + + +@pytest.fixture +def regression_training_df(): + """Create a training DataFrame with required columns for regression.""" + rng = np.random.default_rng(42) + n_rows = 100 + + return pd.DataFrame( + { + "Xoff_residual": rng.normal(0.5, 0.3, n_rows), + "Yoff_residual": rng.normal(1.0, 0.5, n_rows), + "E_residual": rng.normal(-0.2, 0.1, n_rows), + "ErecS": np.logspace(1, 2, n_rows), + "DispNImages": rng.choice([2, 3, 4], n_rows), + "Xoff_weighted_bdt": rng.normal(0, 0.5, n_rows), + "Yoff_weighted_bdt": rng.normal(0, 0.5, n_rows), + "mscw": rng.uniform(0, 1, n_rows), + "mscl": rng.uniform(0, 1, n_rows), + } + ) + + +@pytest.fixture +def regression_model_config(): + """Create a model configuration for regression training.""" + return { + "targets": ["Xoff_residual", "Yoff_residual", 
"E_residual"], + "train_test_fraction": 0.5, + "random_state": 42, + "models": { + "xgboost": { + "hyper_parameters": { + "n_estimators": 10, + "max_depth": 3, + "random_state": 42, + "early_stopping_rounds": 2, + "eval_metric": "rmse", + } + } + }, + } + + +class TestTargetStandardization: + """Tests for target standardization (mean and std) storage.""" + + def test_target_mean_std_computed_from_training_set( + self, regression_training_df, regression_model_config + ): + """Verify target_mean and target_std are computed from training data only.""" + df = regression_training_df + cfg = regression_model_config + + # Train the model + result = models.train_regression(df, cfg) + + # Check that target_mean and target_std are stored in config + assert "target_mean" in result, "target_mean not stored in model config" + assert "target_std" in result, "target_std not stored in model config" + + # Verify they are dictionaries with all target keys + assert isinstance(result["target_mean"], dict) + assert isinstance(result["target_std"], dict) + assert set(result["target_mean"].keys()) == set(cfg["targets"]) + assert set(result["target_std"].keys()) == set(cfg["targets"]) + + def test_target_mean_std_values_reasonable( + self, regression_training_df, regression_model_config + ): + """Verify target_mean and target_std have reasonable values.""" + df = regression_training_df.copy() + cfg = regression_model_config + + # Manually compute expected values from training set (50%) + # train_test_split with train_size=0.5 and random_state=42 + from sklearn.model_selection import train_test_split + + x_cols = [col for col in df.columns if col not in cfg["targets"]] + _, _, y_data_train, _ = train_test_split( + df[x_cols], + df[cfg["targets"]], + train_size=cfg["train_test_fraction"], + random_state=cfg["random_state"], + ) + + expected_mean = y_data_train.mean() + expected_std = y_data_train.std() + + result = models.train_regression(df, cfg) + + # Verify computed values match expected 
+ for target in cfg["targets"]: + assert np.isclose(result["target_mean"][target], expected_mean[target], rtol=1e-5), ( + f"{target} mean mismatch" + ) + assert np.isclose(result["target_std"][target], expected_std[target], rtol=1e-5), ( + f"{target} std mismatch" + ) + + def test_target_std_never_zero(self, regression_training_df, regression_model_config): + """Verify target_std values are not zero (to avoid division by zero).""" + df = regression_training_df + cfg = regression_model_config + + result = models.train_regression(df, cfg) + + for target in cfg["targets"]: + assert result["target_std"][target] > 0, f"{target} std should not be zero" + + +class TestEnergyBinWeighting: + """Tests for energy-bin weighting (especially zeroing low-count bins).""" + + def test_log_energy_bin_counts_returns_correct_structure(self, regression_training_df): + """Verify _log_energy_bin_counts() returns expected tuple structure.""" + df = regression_training_df + result = models._log_energy_bin_counts(df) + + assert result is not None, "Should return a tuple, not None" + bins, counts_dict, weights = result + + # Check tuple structure + assert isinstance(bins, np.ndarray), "bins should be ndarray" + assert isinstance(counts_dict, dict), "counts_dict should be dict" + assert isinstance(weights, np.ndarray), "weights should be ndarray" + + # Verify weight array has same length as input + assert len(weights) == len(df), "weights array length should match dataframe rows" + + def test_log_energy_bin_counts_zeroes_low_count_bins(self): + """Verify bins with < 10 events get zero weight.""" + # Create minimal dataframe with specific energy distribution + rng = np.random.default_rng(42) + n_rows = 30 + df = pd.DataFrame( + { + "ErecS": np.concatenate( + [ + np.full(15, 100.0), + np.full(10, 10.0), + np.full(5, 1000.0), + ] + ), + "E_residual": np.zeros(n_rows), + "DispNImages": rng.choice([2, 3], n_rows), + } + ) + + result = models._log_energy_bin_counts(df) + assert result is not None + 
+ _, counts_dict, weights = result + + # Find which bins have < 10 events + low_count_bins = {interval: count for interval, count in counts_dict.items() if count < 10} + + # Events in low-count bins should have zero energy weight + # (multiplicity weight might still apply, but energy weight should be 0) + if low_count_bins: + zero_weights = weights[weights == 0] + if len(zero_weights) > 0: + assert len(zero_weights) > 0, "Expected some zero weights for low-count bins" + + def test_log_energy_bin_counts_weight_normalization(self, regression_training_df): + """Verify combined weights are normalized to mean ~1.0.""" + df = regression_training_df + result = models._log_energy_bin_counts(df) + + _, _, weights = result + + # Check that weight array is normalized + # (mean should be ~1.0 after normalization) + weight_mean = np.mean(weights) + assert np.isclose(weight_mean, 1.0, rtol=0.01), ( + f"Weight mean should be ~1.0, got {weight_mean}" + ) + + def test_log_energy_bin_counts_handles_missing_columns(self): + """Verify graceful handling when E_residual/ErecS missing.""" + df = pd.DataFrame( + { + "DispNImages": [2, 3, 4], + "some_other_col": [1.0, 2.0, 3.0], + } + ) + + result = models._log_energy_bin_counts(df) + assert result is None, "Should return None when E_residual/ErecS missing" + + def test_energy_bin_weighting_in_training( + self, regression_training_df, regression_model_config + ): + """Verify energy-bin weights are applied during model training.""" + df = regression_training_df + cfg = regression_model_config + + # Mock the XGBRegressor to capture the sample_weight argument + with patch("xgboost.XGBRegressor") as mock_xgb: + mock_model = MagicMock() + mock_model.best_iteration = 5 + mock_model.best_score = 0.1 + mock_model.predict.return_value = np.zeros((len(df) // 2, 3)) + mock_xgb.return_value = mock_model + + # Mock evaluate_regression_model to return empty dict + with patch("eventdisplay_ml.models.evaluate_regression_model") as mock_eval: + 
mock_eval.return_value = {} + + models.train_regression(df, cfg) + + # Verify fit() was called with sample_weight + mock_model.fit.assert_called_once() + call_args = mock_model.fit.call_args + + # Check that sample_weight is not None + sample_weight = call_args.kwargs.get("sample_weight") + assert sample_weight is not None, "sample_weight should be passed to fit()" + assert len(sample_weight) == len(df) // 2 # Training set size + + +class TestTrainRegressionIntegration: + """Integration tests for train_regression() with standardization and weighting.""" + + def test_train_regression_complete_workflow( + self, regression_training_df, regression_model_config + ): + """Verify complete training workflow with standardization and weighting.""" + df = regression_training_df + cfg = regression_model_config + + result = models.train_regression(df, cfg) + + # Check critical outputs + assert result is not None + assert "target_mean" in result + assert "target_std" in result + assert "models" in result + assert "xgboost" in result["models"] + assert "model" in result["models"]["xgboost"] + assert "generalization_metrics" in result["models"]["xgboost"] + assert "shap_importance" in result["models"]["xgboost"] + + def test_generalization_metrics_cached_per_target( + self, regression_training_df, regression_model_config + ): + """Verify train/test RMSE summary is cached in the model config.""" + result = models.train_regression(regression_training_df, regression_model_config) + + metrics = result["models"]["xgboost"]["generalization_metrics"] + assert set(metrics) == set(regression_model_config["targets"]) + + for target in regression_model_config["targets"]: + assert set(metrics[target]) == {"rmse_train", "rmse_test", "gap_pct", "gen_ratio"} + assert np.isfinite(metrics[target]["rmse_train"]) + assert np.isfinite(metrics[target]["rmse_test"]) + + def test_generalization_metrics_match_training_predictions( + self, regression_training_df, regression_model_config + ): + """Verify 
cached generalization metrics match the model predictions used in training.""" + df = regression_training_df + cfg = regression_model_config + + with patch("xgboost.XGBRegressor") as mock_xgb: + mock_model = MagicMock() + mock_model.best_iteration = 5 + mock_model.best_score = 0.1 + + def _predict(x_values): + return np.zeros((len(x_values), len(cfg["targets"]))) + + mock_model.predict.side_effect = _predict + mock_xgb.return_value = mock_model + + with patch("eventdisplay_ml.models.evaluate_regression_model") as mock_eval: + mock_eval.return_value = {} + result = models.train_regression(df, cfg) + + from sklearn.model_selection import train_test_split + + x_cols = [col for col in df.columns if col not in cfg["targets"]] + _, _, y_train, y_test = train_test_split( + df[x_cols], + df[cfg["targets"]], + train_size=cfg["train_test_fraction"], + random_state=cfg["random_state"], + ) + + target_mean = np.array([result["target_mean"][target] for target in cfg["targets"]]) + y_train_pred = pd.DataFrame( + np.tile(target_mean, (len(y_train), 1)), + columns=cfg["targets"], + index=y_train.index, + ) + y_test_pred = pd.DataFrame( + np.tile(target_mean, (len(y_test), 1)), + columns=cfg["targets"], + index=y_test.index, + ) + + expected_metrics = diagnostic_utils.compute_generalization_metrics( + y_train, + y_train_pred, + y_test, + y_test_pred, + cfg["targets"], + ) + + assert result["models"]["xgboost"]["generalization_metrics"] == expected_metrics + + def test_scaled_predictions_unscaled_correctly( + self, regression_training_df, regression_model_config + ): + """Verify predictions are correctly unscaled using stored mean/std.""" + # This test verifies the inverse transformation logic + df = regression_training_df.copy() + cfg = regression_model_config + + result = models.train_regression(df, cfg) + + # Get the stored scalers + target_mean = result["target_mean"] + target_std = result["target_std"] + + # Simulate scaled prediction: y_scaled = 1.0 for all targets + 
y_pred_scaled = np.array([[1.0, 1.0, 1.0]]) + + # Manually unscale + y_pred_unscaled = y_pred_scaled * np.array( + [target_std[target] for target in cfg["targets"]] + ) + np.array([target_mean[target] for target in cfg["targets"]]) + + # Verify unscaling produces reasonable values + for i, target in enumerate(cfg["targets"]): + assert np.isfinite(y_pred_unscaled[0, i]), ( + f"{target} unscaled prediction should be finite" + ) + + def test_train_test_split_preserved_correctly( + self, regression_training_df, regression_model_config + ): + """Verify train/test split doesn't leak into weight computation.""" + df = regression_training_df + cfg = regression_model_config + + # Train with fixed random state multiple times + cfg1 = cfg.copy() + config1 = models.train_regression(df, cfg1) + + cfg2 = cfg.copy() + config2 = models.train_regression(df, cfg2) + + # With same random state, should get identical mean/std + for target in cfg["targets"]: + assert np.isclose( + config1["target_mean"][target], + config2["target_mean"][target], + ), "target mean should be identical with same random_state" + assert np.isclose( + config1["target_std"][target], + config2["target_std"][target], + ), "target std should be identical with same random_state"