diff --git a/.github/workflows/blog.yml b/.github/workflows/blog.yml deleted file mode 100644 index b797520a..00000000 --- a/.github/workflows/blog.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: Build and Deploy Blog - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -permissions: - contents: read - pages: write - id-token: write - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v6 - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: 'latest' - cache: 'npm' - cache-dependency-path: blog/package-lock.json - - - name: Install dependencies - run: cd blog && npm ci - - - name: Build - run: cd ./blog && npm run build - - - name: Upload build artifact - if: github.event_name == 'push' && github.ref == 'refs/heads/main' - uses: actions/upload-artifact@v6 - with: - name: blog-build - path: blog/public - retention-days: 1 - - deploy: - needs: build - if: github.event_name == 'push' && github.ref == 'refs/heads/main' - runs-on: ubuntu-latest - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - concurrency: - group: pages - cancel-in-progress: false - - steps: - - name: Download build artifact - uses: actions/download-artifact@v7 - with: - name: blog-build - path: blog/public - - - name: Setup Pages - uses: actions/configure-pages@v5 - - - name: Upload Pages artifact - uses: actions/upload-pages-artifact@v4 - with: - path: blog/public - retention-days: 1 - - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..a558b951 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,138 @@ +name: docs +permissions: + contents: write + pages: write + pull-requests: write + +on: + push: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/docs.yml + - '**.py' + - '**.ipynb' + - '**.html' + - '**.js' + - '**.md' + - uv.lock + - pyproject.toml + - mkdocs.yml + - '**.png' + - '**.svg' + pull_request: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/docs.yml + - '**.py' + - '**.ipynb' + - '**.js' + - '**.html' + - uv.lock + - pyproject.toml + - '**.md' + - mkdocs.yml + - '**.png' + - '**.svg' + +jobs: + build-docs: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + version: "0.9.11" + enable-cache: true + + - name: Set up Python + uses: actions/setup-python@v6.1.0 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --group docs + + - name: Build docs + run: uv run mkdocs build + + - name: Create .nojekyll file + run: touch site/.nojekyll + + - name: Upload artifact + uses: actions/upload-artifact@v6 + with: + name: docs-site + path: site/ + retention-days: 1 + build-blog: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 'latest' + cache: 'npm' + cache-dependency-path: blog/package-lock.json + + - name: Install dependencies + run: cd blog && npm ci + + - name: Build + run: cd ./blog && npm run build + + - name: Upload build artifact + uses: actions/upload-artifact@v6 + with: + name: blog-build + path: blog/public + retention-days: 1 + deploy: + needs: + - build-docs + - build-blog + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + concurrency: + group: pages + cancel-in-progress: false + steps: + - name: Checkout code + uses: actions/checkout@v6.0.1 + + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + + - name: Download docs artifact + uses: actions/download-artifact@v7 + with: + name: docs-site + path: site + - name: Ensure .nojekyll exists + run: touch site/.nojekyll + - name: Create a folder for blog + run: mkdir -p site/blog + - name: Download blog artifact + uses: actions/download-artifact@v7 + with: + name: blog-build + path: site/blog + - name: Deploy to Github pages + id: deployment + uses: JamesIves/github-pages-deploy-action@v4.8.0 + with: + branch: gh-pages + folder: site \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..3dc40099 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,57 @@ +name: publish package +permissions: + contents: read + pull-requests: write + +on: + push: + tags: + - "v*" + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - name: Install apt dependencies + run: | + sudo apt-get update + sudo apt-get install libcurl4-openssl-dev libssl-dev + - uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + # Install a specific version of uv. + version: "0.9.11" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Build package + run: uv build + + - name: Publish package + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + + release_github: + needs: deploy + runs-on: ubuntu-latest + permissions: + contents: write # To create a github release + steps: + - name: Create GitHub Release + id: create_release + uses: ncipollo/release-action@v1.21.0 + with: + artifacts: "dist/*" + generateReleaseNotes: true + \ No newline at end of file diff --git a/.gitignore b/.gitignore index c28b9fe1..6c08e923 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ *.pyc .venv/ blog/node_modules/ -.DS_Store diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..c8cfe395 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10 diff --git a/README.md b/README.md index 085daf4e..d30eec6f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additiv ## Overview -This repository provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial. +This package provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial. ### Key Features @@ -16,112 +16,14 @@ This repository provides a comprehensive framework for competing risks survival - **Multiple Training Modes**: Standard training, hyperparameter tuning, and nested cross-validation - **Baseline Comparisons**: DeepHit implementation for benchmarking against state-of-the-art methods -### Available Datasets - -The repository includes four well-established survival analysis datasets: - -1. **Framingham Heart Study**: Cardiovascular disease prediction with competing events (CVD vs. death) - - Features: Demographics, clinical measurements, lifestyle factors - - Events: Cardiovascular disease, death from other causes - -2. **PBC (Primary Biliary Cirrhosis)**: Liver disease progression study - - Features: Clinical laboratory values, demographic information - - Events: Death, liver transplantation - -3. **SUPPORT**: Study to understand prognoses and preferences for outcomes - - Features: Comprehensive clinical and demographic variables - - Events: Cancer death, non-cancer death - -4. **Synthetic Dataset**: Controlled simulation for method validation - - Features: Simulated clinical variables with known ground truth - - Events: Multiple competing risks with controllable hazard functions - -All datasets come with preprocessing pipelines that handle missing values, feature encoding, and proper train/test splitting to prevent data leakage. - -### Training Scripts - -The repository provides several specialized training scripts: - -- **`train.py`**: Standard model training with cross-validation and comprehensive evaluation -- **`train_nested_cv.py`**: Robust nested cross-validation for unbiased performance estimation -- **`tune_optuna.py`**: Hyperparameter optimization using Optuna's advanced algorithms -- **`train_deephit.py`**: DeepHit baseline implementation for comparative studies - -Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps. - ## Requirements + Python >=3.10 -## Repository Structure +## Install the package -``` -crisp_nam/ -├── crisp_nam/ # Main package -│ ├── metrics/ -│ │ ├── __init__.py -│ │ ├── calibration.py -│ │ ├── discrimination.py -│ │ └── ipcw.py -│ ├── models/ -│ │ ├── __init__.py -│ │ ├── crisp_nam_model.py -│ │ └── deephit_model.py -│ ├── utils/ -│ │ ├── __init__.py -│ │ ├── loss.py -│ │ ├── plotting.py -│ │ └── risk_cif.py -│ └── __init__.py -├── data_utils/ # Data utilities -│ ├── __init__.py -│ ├── load_datasets.py -│ └── survival_datasets.py -├── datasets/ # Dataset files and loaders -│ ├── metabric/ -│ │ ├── cleaned_features_final.csv -│ │ └── label.csv -│ ├── framingham_dataset.py -│ ├── framingham.csv -│ ├── pbc_dataset.py -│ ├── pbc2.csv -│ ├── support_dataset.py -│ ├── support2.csv -│ ├── SurvivalDataset.py -│ ├── synthetic_comprisk.csv -│ └── synthetic_dataset.py -├── results/ # Results and outputs -│ ├── best_params/ # Best parameters for dataset and model combinations -│ │ ├── best_params_framingham_deephit.yaml -│ │ ├── best_params_framingham.yaml -│ │ ├── best_params_pbc_deephit.yaml -│ │ ├── best_params_pbc.yaml -│ │ ├── best_params_support.yaml -│ │ ├── best_params_support2_deephit.yaml -│ │ ├── best_params_synthetic_deephit.yaml -│ │ └── best_params_synthetic.yaml -│ ├── logs/ # Nested CV results and logs -│ │ ├── nested_cv_best_params_*.yaml -│ │ ├── nested_cv_detailed_metrics_*.csv -│ │ ├── nested_cv_metrics_*.xlsx -│ │ ├── nested_cv_raw_metrics_*.json -│ │ └── nested_cv_summary_metrics_*.csv -│ └── plots/ # Generated plots -│ ├── nested_cv_feature_importance_risk_*_*.png -│ └── nested_cv_shape_functions_risk_*_*.png -├── training_scripts/ # Training scripts -│ ├── config.yaml -│ ├── model_utils.py -│ ├── train_deephit_cuda.py -│ ├── train_deephit.py -│ ├── train_nested_cv.py # Nested cross-validation script -│ ├── train.py -│ ├── tune_optuna_optimized.py -│ └── tune_optuna.py -└── utils/ # Legacy utils (duplicate of crisp_nam/utils) - ├── __init__.py - ├── loss.py - ├── plotting.py - └── risk_cif.py +```bash +pip install crisp-nam ``` ## Install from source @@ -145,135 +47,26 @@ via `uv` cd crisp-nam uv sync ``` +## Research details -## Running training scripts - -1. Modify training parameters in `training_scripts/train.py` - OR - Run either of following commands to see CLI arguments for passing training parameters: - - ```bash - python training_scripts/train.py --help - ``` - - ```bash - uv run training_scripts/train.py --help - ``` - -2. Run the training script - - via `python` - ```bash - source .venv/bin/activate - python training_scripts/train.py --dataset framingham - ``` - - via `uv` - ```bash - uv run training_scripts/train.py --dataset framingham - ``` - -## Running Nested Cross-Validation - -The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations. - -### Basic Usage - -```bash -# Using python -python training_scripts/train_nested_cv.py --dataset framingham - -# Using uv -uv run training_scripts/train_nested_cv.py --dataset framingham -``` +For more details regarding the research work, please refer to `datasets.md` and `training.md` within the project repository. -### Configuration Parameters - -All parameters can be passed via command line or specified in a YAML config file: - -#### Dataset Configuration -- `--dataset` (str): Dataset to use (choices: `framingham`, `support`, `pbc`, `synthetic`, default: `framingham`) -- `--scaling` (str): Data scaling method for continuous features (choices: `minmax`, `standard`, `none`, default: `standard`) - -#### Training Parameters -- `--num_epochs` (int): Number of training epochs (default: `250`) -- `--batch_size` (int): Batch size for training (default: `512`) -- `--patience` (int): Patience for early stopping (default: `10`) - -#### Cross-Validation Configuration -- `--outer_folds` (int): Number of outer CV folds (default: `5`) -- `--inner_folds` (int): Number of inner CV folds for hyperparameter tuning (default: `3`) -- `--n_trials` (int): Number of Optuna trials per inner fold (default: `20`) - -#### Event Weighting -- `--event_weighting` (str): Event weighting strategy (choices: `none`, `balanced`, `custom`, default: `none`) -- `--custom_event_weights` (str): Custom weights for events (comma-separated, default: `None`) - -#### Other Parameters -- `--seed` (int): Random seed for reproducibility (default: `42`) -- `--config` (str): Path to YAML config file (default: looks for `config.yaml`) +## Contributing -### Examples +Contributions are welcome! Please open issues or submit pull requests. -#### Basic nested CV with default parameters: -```bash -python training_scripts/train_nested_cv.py --dataset pbc -``` +## Citation -#### Customized nested CV with specific parameters: -```bash -python training_scripts/train_nested_cv.py \ - --dataset support \ - --outer_folds 10 \ - --inner_folds 5 \ - --n_trials 50 \ - --num_epochs 500 \ - --event_weighting balanced \ - --scaling minmax \ - --seed 123 +If you use our package, kindly acknowledge by citing our research. ``` - -#### Using a config file: -```bash -python training_scripts/train_nested_cv.py --config my_config.yaml +@inproceedings{ramachandram2025crispnam, + title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models}, + author={Ramachandram, Dhanesh and Raval, Ananya}, + booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain}, + year={2025} +} ``` -### Output Files - -The script generates several output files in the current directory: - -#### Performance Metrics -- `nested_cv_summary_metrics_{dataset}.csv`: Summary table with mean ± std metrics -- `nested_cv_detailed_metrics_{dataset}.csv`: Detailed results for each fold -- `nested_cv_metrics_{dataset}.xlsx`: Excel file with multiple sheets (Summary, Detailed, Metadata) -- `nested_cv_raw_metrics_{dataset}.json`: Raw metrics dictionary for reproducibility - -#### Model Configuration -- `nested_cv_best_params_{dataset}.yaml`: Aggregated best hyperparameters across all folds - -#### Visualizations -Results are saved to `results/plots/`: -- `nested_cv_feature_importance_risk_{risk}_{dataset}.png`: Feature importance plots -- `nested_cv_shape_functions_risk_{risk}_{dataset}.png`: Shape function plots for top features - -### Evaluation Metrics - -The script computes the following metrics at different time quantiles (25%, 50%, 75%): - -- **AUC (Area Under the ROC Curve)**: Time-dependent AUC for discrimination - - 0.5 = random, >0.7 = good, >0.8 = excellent -- **TDCI (Time-Dependent Concordance Index)**: Harrell's C-index adapted for competing risks - - 0.5 = random, >0.7 = good, >0.8 = excellent -- **Brier Score**: Calibration metric measuring prediction accuracy - - 0 = perfect, <0.25 = good, >0.25 = poor - -> [!NOTE] -> For `uv` installation, please visit follow instructions in their [official page](https://docs.astral.sh/uv/getting-started/installation/). - -## Contributing - -Contributions are welcome! Please open issues or submit pull requests. - ## License -This project is licensed under the MIT License. \ No newline at end of file +This project is licensed under the MIT License. diff --git a/crisp_nam/__init__.py b/crisp_nam/__init__.py index d67a0b62..3f10443c 100644 --- a/crisp_nam/__init__.py +++ b/crisp_nam/__init__.py @@ -1 +1,3 @@ -__all__ = [ 'utils', 'models', 'data_utils', 'metrics' ] \ No newline at end of file +"""CRISP-NAM paclage.""" + +__all__ = ["utils", "models", "data_utils", "metrics"] diff --git a/crisp_nam/metrics/__init__.py b/crisp_nam/metrics/__init__.py index b9079625..01a720b3 100644 --- a/crisp_nam/metrics/__init__.py +++ b/crisp_nam/metrics/__init__.py @@ -1,3 +1,14 @@ -from .calibration import * -from .discrimination import * -from .ipcw import * +"""Evaluation metrics used within crisp_nam package.""" + +from .calibration import brier_score, integrated_brier_score +from .discrimination import auc_td, cumulative_dynamic_auc, truncated_concordance_td +from .ipcw import estimate_ipcw + +__all__ = [ + "brier_score", + "integrated_brier_score", + "auc_td", + "cumulative_dynamic_auc", + "truncated_concordance_td", + "estimate_ipcw", +] diff --git a/crisp_nam/metrics/calibration.py b/crisp_nam/metrics/calibration.py index 3c035f38..26556604 100644 --- a/crisp_nam/metrics/calibration.py +++ b/crisp_nam/metrics/calibration.py @@ -1,26 +1,47 @@ +"""Calibration metrics for time-to-event models with competing risks. + +This module contains functions to compute the Brier score and +integrated Brier score for competing risks. +""" + import numpy as np + from .ipcw import estimate_ipcw + # A small constant to avoid division by zero epsilon = 1e-4 -def brier_score(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1): - """ - Compute the corrected Brier score for a given competing risk. - +def brier_score( + e_test: np.ndarray, + t_test: np.ndarray, + risk_predicted_test: np.ndarray, + times: np.ndarray, + t: float, + km: object | None = None, + primary_risk: int = 1, +) -> tuple[float, object]: + """Compute the corrected Brier score for a given competing risk. + This implementation is based on the work of Schoop et al. on quantifying the predictive accuracy of time-to-event models in the presence of competing risks. - - Parameters: - e_test (ndarray): Array of event indicators (0 = censored; positive integers for different events). + + Parameters + ---------- + e_test (ndarray): Array of event indicators (0 = censored; + positive integers for different events). t_test (ndarray): Array of event/censoring times. - risk_predicted_test (ndarray): Predicted risk matrix with shape (n_samples, n_times). - times (ndarray): Array of time points corresponding to columns in risk_predicted_test. + risk_predicted_test (ndarray): Predicted risk matrix with + shape (n_samples, n_times). + times (ndarray): Array of time points corresponding to columns + in risk_predicted_test. t (float): Time at which to evaluate the Brier score. - km (object, optional): Kaplan–Meier estimator or data to estimate the censoring distribution. + km (object, optional): Kaplan–Meier estimator or data to estimate + the censoring distribution. primary_risk (int, optional): The event label for which to compute the score. - - Returns: + + Returns + ------- brier (float): The corrected Brier score evaluated at time t. km (object): Updated Kaplan–Meier estimator (if applicable). """ @@ -39,47 +60,67 @@ def brier_score(e_test, t_test, risk_predicted_test, times, t, km=None, primary_ # Initialize weights for IPCW correction. weights = np.zeros_like(e_test, dtype=float) - # For subjects with events (or censoring) before t (excluding those censored exactly at 0 event label), use KM weights. + # For subjects with events (or censoring) before t + # (excluding those censored exactly at 0 event label), use KM weights. mask = (t_test <= t) & (e_test != 0) - weights[mask] = 1. / np.clip(km.survival_function_at_times(t_test[mask]), epsilon, None) - # For subjects still at risk at time t, assign constant weight based on KM at time t. - weights[t_test > t] = 1. / np.clip(km.survival_function_at_times(t), epsilon, None) + weights[mask] = 1.0 / np.clip( + km.survival_function_at_times(t_test[mask]), epsilon, None + ) + # For subjects still at risk at time t, assign constant weight based + # on KM at time t. + weights[t_test > t] = 1.0 / np.clip(km.survival_function_at_times(t), epsilon, None) brier = (weights * (truth - risk_predicted_test[:, index]) ** 2).mean() return brier, km -def integrated_brier_score(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1): +def integrated_brier_score( + e_test: np.ndarray, + t_test: np.ndarray, + risk_predicted_test: np.ndarray, + times: np.ndarray, + t_eval: np.ndarray | None = None, + km: object | None = None, + primary_risk: int = 1, +) -> tuple[float, object]: """ Compute the integrated Brier score for competing risks over a range of time points. - - The integrated Brier score is computed by numerically integrating the Brier score over the evaluation times. - - Parameters: + + The integrated Brier score is computed by numerically integrating the + Brier score over the evaluation times. + + Parameters + ---------- e_test (ndarray): Event indicators. t_test (ndarray): Event/censoring times. - risk_predicted_test (ndarray): Predicted risk matrix with shape (n_samples, n_times). + risk_predicted_test (ndarray): Predicted risk matrix + with shape (n_samples, n_times). times (ndarray): Array of time points corresponding to the predictions. - t_eval (ndarray, optional): Specific time points at which to compute the score. Defaults to using 'times'. + t_eval (ndarray, optional): Specific time points at which to + compute the score. Defaults to using 'times'. km (object, optional): Kaplan–Meier estimator for IPCW. primary_risk (int, optional): The event label for which to compute the score. - - Returns: + + Returns + ------- ibs (float): Integrated Brier score. km (object): Updated Kaplan–Meier estimator. """ km = estimate_ipcw(km) t_eval = times if t_eval is None else t_eval # Compute Brier scores at each time point. - brier_scores = [brier_score(e_test, t_test, risk_predicted_test, times, t_val, km, primary_risk)[0] - for t_val in t_eval] + brier_scores = [ + brier_score( + e_test, t_test, risk_predicted_test, times, t_val, km, primary_risk + )[0] + for t_val in t_eval + ] # Remove NaN values if any. t_eval = t_eval[~np.isnan(brier_scores)] brier_scores = np.array(brier_scores)[~np.isnan(brier_scores)] - + if t_eval.shape[0] < 2: raise ValueError("At least two time points must be provided for integration.") - + ibs = np.trapz(brier_scores, t_eval) / (t_eval[-1] - t_eval[0]) return ibs, km - diff --git a/crisp_nam/metrics/discrimination.py b/crisp_nam/metrics/discrimination.py index 05aa17a0..79bcd7ab 100644 --- a/crisp_nam/metrics/discrimination.py +++ b/crisp_nam/metrics/discrimination.py @@ -1,13 +1,32 @@ +"""Discrimination metrics for time-to-event models with competing risks. + +This module contains functions to compute the cumulative +and single time-dependent AUC and time-dependent C-index +for evaluating competing risks. +""" + import numpy as np -from .ipcw import estimate_ipcw + +from .ipcw import estimate_ipcw + epsilon = 1e-10 -def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1): + +def auc_td( + e_test: np.ndarray, + t_test: np.ndarray, + risk_predicted_test: np.ndarray, + times: np.ndarray, + t: float, + km: object | None = None, + primary_risk: int = 1, +) -> tuple[float, object]: """ Compute the time-dependent AUC for a given competing risk using predicted CIFs. - Parameters: + Parameters + ---------- e_test : ndarray of shape (n_samples,) Event indicator (0=censored, 1=event of interest, 2=competing event, etc.) t_test : ndarray of shape (n_samples,) @@ -23,7 +42,8 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk= primary_risk : int The event label to treat as the event of interest. - Returns: + Returns + ------- auc_value : float AUC estimate at time t (always between 0 and 1) km : Updated Kaplan-Meier estimator @@ -53,8 +73,8 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk= # Compute pairwise AUC: compare each (event, control) pair auc_numerator = 0.0 auc_denominator = 0.0 - for i, (score_i, w_i) in enumerate(zip(event_scores, weights_event)): - for j, (score_j, w_j) in enumerate(zip(control_scores, weights_control)): + for _i, (score_i, w_i) in enumerate(zip(event_scores, weights_event)): + for _j, (score_j, w_j) in enumerate(zip(control_scores, weights_control)): weight = w_i * w_j auc_denominator += weight if score_i > score_j: @@ -66,17 +86,27 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk= auc_value = auc_numerator / auc_denominator if auc_denominator > 0 else np.nan return auc_value, km -def cumulative_dynamic_auc(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1): - """ - Computes the cumulative dynamic AUC by numerically integrating the time-dependent AUC over a range of time points. - - Parameters: +def cumulative_dynamic_auc( + e_test: np.ndarray, + t_test: np.ndarray, + risk_predicted_test: np.ndarray, + times: np.ndarray, + t_eval: np.ndarray | None = None, + km: object | None = None, + primary_risk: int = 1, +) -> tuple[float, object]: + """Compute the cumulative dynamic AUC by numerically + integrating the time-dependent AUC over a range of time points. + + Parameters + ---------- e_test, t_test, risk_predicted_test, times, km, primary_risk: Same as in auc_td. t_eval: ndarray, optional Specific time points to evaluate. If None, uses times. - - Returns: + + Returns + ------- auc_integral: float The cumulative dynamic AUC. km: object @@ -84,18 +114,32 @@ def cumulative_dynamic_auc(e_test, t_test, risk_predicted_test, times, t_eval=No """ km = estimate_ipcw(km) t_eval = times if t_eval is None else t_eval - aucs = [auc_td(e_test, t_test, risk_predicted_test, times, t, km, primary_risk)[0] for t in t_eval] + aucs = [ + auc_td(e_test, t_test, risk_predicted_test, times, t, km, primary_risk)[0] + for t in t_eval + ] t_eval, aucs = t_eval[~np.isnan(aucs)], np.array(aucs)[~np.isnan(aucs)] if t_eval.shape[0] < 2: raise ValueError("At least two time points must be given") auc_integral = np.trapz(aucs, t_eval) / (t_eval[-1] - t_eval[0]) return auc_integral, km -def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1, tied_tol=1e-8): + +def truncated_concordance_td( + e_test: np.ndarray, + t_test: np.ndarray, + risk_predicted_test: np.ndarray, + times: np.ndarray, + t: float, + km: object | None = None, + primary_risk: int = 1, + tied_tol: float = 1e-8, +) -> tuple[float, object]: """ Compute the truncated time-dependent concordance index (C-index). - - Parameters: + + Parameters + ---------- e_test : ndarray Event indicator (0=censored, 1=event of interest, etc.) t_test : ndarray @@ -112,14 +156,15 @@ def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=N Risk of interest tied_tol : float Tolerance to assign 0.5 score for ties - - Returns: + + Returns + ------- c_index : float km : Updated km object """ epsilon = 1e-10 index = np.argmin(np.abs(times - t)) - + # IPCW if km is not None: km = estimate_ipcw(km) @@ -144,7 +189,7 @@ def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=N after_mask = t_test > t_i before_mask = (t_test <= t_i) & (e_test != primary_risk) & (e_test != 0) - weights_after = weights_event[after_mask] / (w_i ** 2) + weights_after = weights_event[after_mask] / (w_i**2) weights_before = weights_event[before_mask] / (w_i * weights_event[before_mask]) risks_after = risk_predicted_test[after_mask, index] diff --git a/crisp_nam/metrics/ipcw.py b/crisp_nam/metrics/ipcw.py index bb794484..71ce0b95 100644 --- a/crisp_nam/metrics/ipcw.py +++ b/crisp_nam/metrics/ipcw.py @@ -1,21 +1,25 @@ +"""IPCW estimation for time-to-event models with competing risks. + +This module provides a function to estimate the inverse probability of censoring weights (IPCW) using a +Kaplan-Meier estimator. +""" + from lifelines import KaplanMeierFitter -def estimate_ipcw(km): - """ - Estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator. - Parameters: - ----------- +def estimate_ipcw(km: tuple | KaplanMeierFitter) -> KaplanMeierFitter: + """Estimate the inverse probability of censoring weights (IPCW) + using a Kaplan-Meier estimator. + + Parameters + ---------- km : tuple or KaplanMeierFitter - If `km` is a tuple, it should contain two elements: - - e_train: array-like, event indicators (1 if the event occurred, 0 if censored). - - t_train: array-like, corresponding event or censoring times. - If `km` is already a fitted KaplanMeierFitter instance, it will be used directly. - Returns: - -------- + Returns + ------- kmf : KaplanMeierFitter - A KaplanMeierFitter instance fitted to the provided data or the input instance if already fitted. + A KaplanMeierFitter instance fitted to the provided data or + the input instance if already fitted. """ if isinstance(km, tuple): kmf = KaplanMeierFitter() @@ -27,5 +31,3 @@ def estimate_ipcw(km): else: kmf = km return kmf - - diff --git a/crisp_nam/models/__init__.py b/crisp_nam/models/__init__.py index cc2ca5d4..92966a24 100644 --- a/crisp_nam/models/__init__.py +++ b/crisp_nam/models/__init__.py @@ -1,2 +1,7 @@ -from .crisp_nam_model import * -from .deephit_model import * +"""Models available in the crisp_nam package.""" + +from .crisp_nam_model import CrispNamModel +from .deephit_model import DeepHit + + +__all__ = ["CrispNamModel", "DeepHit"] diff --git a/crisp_nam/models/crisp_nam_model.py b/crisp_nam/models/crisp_nam_model.py index c8b8200c..0e5601aa 100644 --- a/crisp_nam/models/crisp_nam_model.py +++ b/crisp_nam/models/crisp_nam_model.py @@ -1,60 +1,96 @@ +"""CrispNamModel for competing-risks survival analysis. + +PyTorch implementation of CrispNamModel for competing risks +survival analysis with L2 normalized projection weights. +""" + +from typing import ( + List, + Optional, + Sequence, + Tuple, +) + import torch import numpy as np -import torch.nn as nn import torch.nn.functional as F +from torch import nn + +class _FeatureNet(nn.Module): + """Neural network to model the effect of a single feature on hazard. -class FeatureNet(nn.Module): - """ - Neural network to model the effect of a single feature on hazard. This is the building block for NAM with optional batch normalization. """ - def __init__(self, hidden_sizes=[64, 64], dropout_rate=0.1, feature_dropout=0.0, - batch_norm=False): - super(FeatureNet, self).__init__() + + def __init__( + self, + hidden_sizes: Sequence[int] = (64, 64), + dropout_rate: float = 0.1, + feature_dropout: float = 0.0, + batch_norm: bool = False, + ) -> None: + """Initialize the FeatureNet.""" + super(_FeatureNet, self).__init__() self.batch_norm = batch_norm - layers = [] - - # Input layer + layers: List[nn.Module] = [] + layers.append(nn.Linear(1, hidden_sizes[0])) if batch_norm: layers.append(nn.BatchNorm1d(hidden_sizes[0])) layers.append(nn.Tanh()) layers.append(nn.Dropout(dropout_rate)) - + # Hidden layers for i in range(len(hidden_sizes) - 1): - layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1])) + layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i + 1])) if batch_norm: - layers.append(nn.BatchNorm1d(hidden_sizes[i+1])) + layers.append(nn.BatchNorm1d(hidden_sizes[i + 1])) layers.append(nn.Tanh()) layers.append(nn.Dropout(dropout_rate)) - + # Final representation layer layers.append(nn.Linear(hidden_sizes[-1], hidden_sizes[-1])) if batch_norm: layers.append(nn.BatchNorm1d(hidden_sizes[-1])) layers.append(nn.Tanh()) - + self.network = nn.Sequential(*layers) self.feature_dropout = feature_dropout - - def forward(self, x): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the network. + + Parameters + x: Input tensor of shape + + Returns + ------- + torch.Tensor + """ + # ensure float32 x = x.to(dtype=torch.float32) # Apply feature dropout during training if specified if self.training and self.feature_dropout > 0: mask = torch.rand_like(x) > self.feature_dropout x = x * mask.float() - + # Handle BatchNorm with single sample if self.batch_norm and x.size(0) == 1: return self._forward_singleton(x) - + return self.network(x) - - def _forward_singleton(self, x): - """ - Handle the case of a single sample (batch_size=1) - where BatchNorm1d would fail + + def _forward_singleton(self, x: torch.Tensor) -> torch.Tensor: + """Handle the case of a single sample (batch_size=1) + where BatchNorm1d would fail. + + Parameters + ----------- + x: Input tensor of shape (1, num_features) + + Returns + ------- + torch.Tensor """ was_training = self.training self.eval() @@ -64,12 +100,14 @@ def _forward_singleton(self, x): self.train() return result -class L2NormalizedLinear(nn.Module): - """ - Linear layer with L2 normalized weights (unit norm constraint) - """ - def __init__(self, in_features, out_features, bias=False, eps=1e-8): - super(L2NormalizedLinear, self).__init__() +class _L2NormalizedLinear(nn.Module): + """Linear layer with L2 normalized weights (unit norm constraint).""" + + def __init__( + self, in_features: int, out_features: int, bias: bool = False, eps: float = 1e-8 + ) -> None: + """Initialize the L2NormalizedLinear layer.""" + super(_L2NormalizedLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.eps = eps @@ -77,27 +115,48 @@ def __init__(self, in_features, out_features, bias=False, eps=1e-8): if bias: self.bias = nn.Parameter(torch.randn(out_features)) else: - self.register_parameter('bias', None) - - def forward(self, x): + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the linear transformation with L2 normalized weights. + + Parameters + ----------- + x: Input tensor of shape (batch_size, in_features) + + Returns + ------- + Output tensor + """ # L2 normalize weights to unit norm normalized_weight = F.normalize(self.weight, p=2, dim=1, eps=self.eps) return F.linear(x, normalized_weight, self.bias) - - def get_normalized_weights(self): - """Return the L2 normalized weights (useful for inspection)""" + + def get_normalized_weights(self) -> torch.Tensor: + """Return the L2 normalized weights (useful for inspection).""" with torch.no_grad(): return F.normalize(self.weight, p=2, dim=1, eps=self.eps) + class CrispNamModel(nn.Module): - """ - Competing risks CoxNAM with L2 normalized projection weights. + """Competing risks CoxNAM with L2 normalized projection weights. + Each feature contributes to each risk through a separate shape function. All projection weights are constrained to unit L2 norm. """ - def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64], - dropout_rate=0.1, feature_dropout=0.1, batch_norm=False, - normalize_projections=True, eps=1e-8): + + def __init__( + self, + num_features: int, + num_competing_risks: int, + hidden_sizes: Sequence[int] = (64, 64), + dropout_rate: float = 0.1, + feature_dropout: float = 0.1, + batch_norm: bool = False, + normalize_projections: bool = True, + eps: float = 1e-8, + ): + """Initialize the CrispNamModel.""" super(CrispNamModel, self).__init__() self.num_features = num_features self.num_competing_risks = num_competing_risks @@ -105,39 +164,51 @@ def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64], self.feature_dropout = feature_dropout self.normalize_projections = normalize_projections self.eps = eps - + # Create a FeatureNet for each input feature - self.feature_nets = nn.ModuleList([ - FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm) - for _ in range(num_features) - ]) - + self.feature_nets = nn.ModuleList( + [ + _FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm) + for _ in range(num_features) + ] + ) + # For each feature and risk type, create a projection layer if normalize_projections: - self.risk_projections = nn.ModuleList([ - nn.ModuleList([ - L2NormalizedLinear(hidden_sizes[-1], 1, bias=False, eps=eps) - for _ in range(num_competing_risks) - ]) - for _ in range(num_features) - ]) + self.risk_projections: nn.ModuleList = nn.ModuleList( + [ + nn.ModuleList( + [ + _L2NormalizedLinear(hidden_sizes[-1], 1, bias=False, eps=eps) + for _ in range(num_competing_risks) + ] + ) + for _ in range(num_features) + ] + ) else: # Fallback to standard linear layers - self.risk_projections = nn.ModuleList([ - nn.ModuleList([ - nn.Linear(hidden_sizes[-1], 1, bias=False) - for _ in range(num_competing_risks) - ]) - for _ in range(num_features) - ]) + self.risk_projections = nn.ModuleList( + [ + nn.ModuleList( + [ + nn.Linear(hidden_sizes[-1], 1, bias=False) + for _ in range(num_competing_risks) + ] + ) + for _ in range(num_features) + ] + ) - def forward(self, x): - """ - Forward pass to compute risk scores for all competing risks - - Args: + def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Forward pass to compute risk scores for all competing risks. + + Parameters + ----------- x: Tensor of shape (batch_size, num_features) - Returns: + + Returns + ------- risk_scores: List of (batch_size, 1) Tensors feature_outputs: List of (batch_size, hidden) Tensors """ @@ -159,337 +230,155 @@ def forward(self, x): # loop features for feat_idx, fnet in enumerate(self.feature_nets): # take one column and get repr - col = x[:, feat_idx].unsqueeze(1) # [batch,1] - repr = fnet(col) # [batch, hidden] + col = x[:, feat_idx].unsqueeze(1) # [batch,1] + repr = fnet(col) # [batch, hidden] feature_outputs.append(repr) + proj: Optional[nn.ModuleList | None] = None # project into each risk channel with L2 normalized weights for risk_idx, proj in enumerate(self.risk_projections[feat_idx]): - # proj automatically applies L2 normalization if normalize_projections=True + # proj automatically applies L2 normalization + # if normalize_projections=True combined[:, risk_idx] += proj(repr).view(-1) # split back into list of [batch,1] risk_scores = [ - combined[:, r].unsqueeze(1) - for r in range(self.num_competing_risks) + combined[:, r].unsqueeze(1) for r in range(self.num_competing_risks) ] return risk_scores, feature_outputs - - def get_shape_functions(self, x_values, feature_idx, risk_idx=None, normalize=True): - """ - Extract shape functions for a specific feature across all risks or a specific risk - - Args: - x_values: Feature values to evaluate (numpy array or tensor) - feature_idx: Index of the feature to get shape functions for - risk_idx: Optional; if provided, only returns the shape function for this risk - normalize: Whether to center the shape functions - - Returns: - Dictionary mapping risk names to shape function values - """ - self.eval() - - if not isinstance(x_values, torch.Tensor): - x_values = torch.FloatTensor(x_values) - - x_vals = x_values.view(-1, 1) - - with torch.no_grad(): - feature_repr = self.feature_nets[feature_idx](x_vals) - - shape_funcs = {} - - # If risk_idx is specified, only compute for that risk - risk_indices = [risk_idx] if risk_idx is not None else range(self.num_competing_risks) - - for j in risk_indices: - # Apply the L2 normalized projection to get shape function values - values = self.risk_projections[feature_idx][j](feature_repr).cpu().numpy().flatten() - - # Normalize if requested - if normalize: - values = values - np.mean(values) - - shape_funcs[f'risk_{j+1}'] = values - - return shape_funcs - - def get_projection_norms(self): - """ - Get the L2 norms of all projection weights (should be ~1.0 if normalized) - - Returns: + + def get_projection_norms(self) -> dict: + """Get the L2 norms of all projection weights (should be ~1.0 if normalized). + + Returns + ------- Dictionary of weight norms by feature and risk """ norms = {} - + for feat_idx in range(self.num_features): for risk_idx in range(self.num_competing_risks): proj = self.risk_projections[feat_idx][risk_idx] - - if hasattr(proj, 'weight'): + + if hasattr(proj, "weight"): weight_norm = proj.weight.norm(p=2, dim=1).item() - norms[f'feature_{feat_idx}_risk_{risk_idx}'] = weight_norm - + norms[f"feature_{feat_idx}_risk_{risk_idx}"] = weight_norm + return norms - - def get_normalized_projection_weights(self): - """ - Get the actual L2 normalized weights used in computation - - Returns: + + def get_normalized_projection_weights(self) -> dict: + """Get the actual L2 normalized weights used in computation. + + Returns + ------- Dictionary of normalized weights """ normalized_weights = {} - + for feat_idx in range(self.num_features): for risk_idx in range(self.num_competing_risks): proj = self.risk_projections[feat_idx][risk_idx] - - if hasattr(proj, 'get_normalized_weights'): + + if hasattr(proj, "get_normalized_weights"): # L2NormalizedLinear layer weights = proj.get_normalized_weights().detach().cpu().numpy() - elif hasattr(proj, 'weight'): + elif hasattr(proj, "weight"): # Standard linear layer - normalize manually - weights = F.normalize(proj.weight, p=2, dim=1).detach().cpu().numpy() + weights = ( + F.normalize(proj.weight, p=2, dim=1).detach().cpu().numpy() + ) else: weights = None - - normalized_weights[f'feature_{feat_idx}_risk_{risk_idx}'] = weights - + + normalized_weights[f"feature_{feat_idx}_risk_{risk_idx}"] = weights + return normalized_weights - - def calculate_feature_importance(self, x_data, feature_idx=None): - """ - Calculate feature importance based on the magnitude of risk-specific projection outputs - With L2 normalized weights, this gives a fair comparison across features - - Args: + + def calculate_feature_importance( + self, + x_data: Optional[torch.Tensor | np.ndarray], + feature_idx: Optional[int | None] = None, + ) -> dict: + """Calculate feature importance based on the magnitude of + risk-specific projection outputs. + + With L2 normalized weights, this gives a fair + comparison across features. + + Parameters + ----------- x_data: Input data tensor or numpy array - feature_idx: Optional; if provided, only calculate importance for this feature - - Returns: + feature_idx: Optional; if provided, only calculate + importance for this feature + + Returns + ------- Dictionary of feature importances by risk type """ + self.eval() device = next(self.parameters()).device - + # Convert to tensor if needed if not isinstance(x_data, torch.Tensor): x_data = torch.FloatTensor(x_data) x_data = x_data.to(device) - - feature_indices = [feature_idx] if feature_idx is not None else range(self.num_features) - importance = {f'risk_{j+1}': {} for j in range(self.num_competing_risks)} - + + feature_indices = ( + [feature_idx] if feature_idx is not None else range(self.num_features) + ) + importance: dict = { + f"risk_{j + 1}": {} for j in range(self.num_competing_risks) + } + for i in feature_indices: # Get feature values feature_values = x_data[:, i].view(-1, 1) - + with torch.no_grad(): # Get the feature representation feature_repr = self.feature_nets[i](feature_values) - + # Calculate importance for each risk (mean absolute value) - # With L2 normalized weights, this is directly comparable across features + # With L2 normalized weights, this is comparable across features for j in range(self.num_competing_risks): risk_specific_output = self.risk_projections[i][j](feature_repr) abs_values = torch.abs(risk_specific_output).cpu().numpy() - importance[f'risk_{j+1}'][f'feature_{i}'] = float(np.mean(abs_values)) - + importance[f"risk_{j + 1}"][f"feature_{i}"] = float( + np.mean(abs_values) + ) + return importance - - def predict_risk(self, x, baseline_hazards=None): - """ - Predict survival probability or cumulative incidence - - Args: - x: Input tensor of shape (batch_size, num_features) - baseline_hazards: Optional dict of baseline hazards for each risk - - Returns: - Dictionary of predictions for each competing risk - """ - self.eval() - - # Convert to tensor if needed - if not isinstance(x, torch.Tensor): - x = torch.FloatTensor(x) - - with torch.no_grad(): - risk_scores, _ = self(x) - - # Convert scores to hazard ratios - hazard_ratios = [torch.exp(score).cpu().numpy() for score in risk_scores] - - # If baseline hazards are provided, compute absolute risks - if baseline_hazards is not None: - predictions = {} - - for j in range(self.num_competing_risks): - risk_name = f'risk_{j+1}' - - # Baseline survival and hazard - baseline_surv = baseline_hazards.get(risk_name, {}).get('survival', None) - baseline_haz = baseline_hazards.get(risk_name, {}).get('hazard', None) - - if baseline_surv is not None: - # Compute survival probability: S(t|x) = S0(t)^exp(f(x)) - predictions[f'{risk_name}_survival'] = np.power( - baseline_surv.reshape(1, -1), - hazard_ratios[j].reshape(-1, 1) - ) - - if baseline_haz is not None: - # Compute cumulative hazard: H(t|x) = H0(t) * exp(f(x)) - predictions[f'{risk_name}_cumhazard'] = baseline_haz.reshape(1, -1) * hazard_ratios[j].reshape(-1, 1) - - return predictions - else: - # Without baseline hazards, just return hazard ratios - return {f'risk_{j+1}_hazard_ratio': hazard_ratios[j] for j in range(self.num_competing_risks)} - -# Alternative implementation using manual normalization in forward pass -class CrispNamModelManualL2(nn.Module): - """ - Alternative implementation with manual L2 normalization in forward pass - """ - def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64], - dropout_rate=0.1, feature_dropout=0.1, batch_norm=False, eps=1e-8): - super(CrispNamModelManualL2, self).__init__() - self.num_features = num_features - self.num_competing_risks = num_competing_risks - self.batch_norm = batch_norm - self.feature_dropout = feature_dropout - self.eps = eps - - # Create a FeatureNet for each input feature - self.feature_nets = nn.ModuleList([ - FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm) - for _ in range(num_features) - ]) - - # Standard linear layers - weights will be L2 normalized in forward pass - self.risk_projections = nn.ModuleList([ - nn.ModuleList([ - nn.Linear(hidden_sizes[-1], 1, bias=False) - for _ in range(num_competing_risks) - ]) - for _ in range(num_features) - ]) - - def forward(self, x): - """Forward pass with manual L2 weight normalization""" - x = x.to(dtype=torch.float32) - batch_size, _ = x.shape - device = x.device - if self.training and self.feature_dropout > 0: - mask = torch.empty_like(x).bernoulli_(1.0 - self.feature_dropout) - x = x * mask + # Utility functions for model analysis + def analyze_projection_weights(self) -> dict: + """Analyze the L2 norms and statistics of projection weights. - combined = torch.zeros(batch_size, self.num_competing_risks, device=device) - feature_outputs = [] + Parameters + ----------- + None - for feat_idx, fnet in enumerate(self.feature_nets): - col = x[:, feat_idx].unsqueeze(1) - repr = fnet(col) - feature_outputs.append(repr) + Returns + ------- + None + """ + print("Projection Weight Analysis:") + print("=" * 50) - for risk_idx, proj in enumerate(self.risk_projections[feat_idx]): - # L2 normalize weights manually - normalized_weight = F.normalize(proj.weight, p=2, dim=1, eps=self.eps) - # Apply normalized projection - output = F.linear(repr, normalized_weight, proj.bias) - combined[:, risk_idx] += output.view(-1) + # Get weight norms + norms = self.get_projection_norms() + norm_values = list(norms.values()) - risk_scores = [ - combined[:, r].unsqueeze(1) - for r in range(self.num_competing_risks) - ] + print("Weight L2 Norms (should be ~1.0):") + print(f" Mean: {np.mean(norm_values):.6f}") + print(f" Std: {np.std(norm_values):.6f}") + print(f" Min: {np.min(norm_values):.6f}") + print(f" Max: {np.max(norm_values):.6f}") - return risk_scores, feature_outputs + # Show some individual norms + print("\nSample individual norms:") + for _i, (name, norm) in enumerate(list(norms.items())[:6]): + print(f" {name}: {norm:.6f}") -# Utility functions for model analysis -def analyze_projection_weights(model): - """ - Analyze the L2 norms and statistics of projection weights - """ - print("Projection Weight Analysis:") - print("=" * 50) - - # Get weight norms - norms = model.get_projection_norms() - norm_values = list(norms.values()) - - print(f"Weight L2 Norms (should be ~1.0):") - print(f" Mean: {np.mean(norm_values):.6f}") - print(f" Std: {np.std(norm_values):.6f}") - print(f" Min: {np.min(norm_values):.6f}") - print(f" Max: {np.max(norm_values):.6f}") - - # Show some individual norms - print(f"\nSample individual norms:") - for i, (name, norm) in enumerate(list(norms.items())[:6]): - print(f" {name}: {norm:.6f}") - - return norms - -def compare_feature_importance_fairness(model, x_data): - """ - Compare feature importance when weights are L2 normalized vs not normalized - """ - print("\nFeature Importance Comparison:") - print("=" * 50) - - # Calculate importance with current model (L2 normalized) - importance_normalized = model.calculate_feature_importance(x_data) - - # Create equivalent model without normalization for comparison - model_unnorm = CrispNamModel( - model.num_features, - model.num_competing_risks, - normalize_projections=False - ) - - # Copy weights from normalized model - with torch.no_grad(): - for i in range(model.num_features): - for j in range(model.num_competing_risks): - model_unnorm.risk_projections[i][j].weight.copy_( - model.risk_projections[i][j].weight - ) - - importance_unnorm = model_unnorm.calculate_feature_importance(x_data) - - print("Importance comparison (Risk 1):") - for feat in range(min(5, model.num_features)): # Show first 5 features - norm_imp = importance_normalized['risk_1'].get(f'feature_{feat}', 0) - unnorm_imp = importance_unnorm['risk_1'].get(f'feature_{feat}', 0) - print(f" Feature {feat}: Normalized={norm_imp:.4f}, Unnormalized={unnorm_imp:.4f}") - -# Example usage and testing -if __name__ == "__main__": - # Create model with L2 normalized projections - model = CrispNamModel( - num_features=5, - num_competing_risks=3, - hidden_sizes=[32, 32], - normalize_projections=True - ) - - # Generate some test data - torch.manual_seed(42) - test_data = torch.randn(100, 5) - - # Test forward pass - risk_scores, feature_outputs = model(test_data) - print(f"Risk scores shapes: {[score.shape for score in risk_scores]}") - - # Analyze projection weights - analyze_projection_weights(model) - - # Compare feature importance - compare_feature_importance_fairness(model, test_data) \ No newline at end of file + return norms diff --git a/crisp_nam/models/deephit_model.py b/crisp_nam/models/deephit_model.py index cc3394f3..5d1d3a75 100644 --- a/crisp_nam/models/deephit_model.py +++ b/crisp_nam/models/deephit_model.py @@ -1,23 +1,48 @@ +"""PyTorch implementation of DeepHit for competing risks survival analysis.""" + +from typing import Callable, Optional + +import numpy as np import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn + class FCLayer(nn.Module): """Fully connected layer with optional batch norm, dropout, and activation.""" - def __init__(self, in_dim, out_dim, activation=nn.ReLU(), batch_norm=False, dropout_rate=0.0, - init_fn=nn.init.xavier_normal_): + + def __init__( + self, + in_dim: int, + out_dim: int, + activation: Optional[nn.Module] = None, + batch_norm: bool = False, + dropout_rate: float = 0.0, + init_fn: Optional[Callable | None] = nn.init.xavier_normal_, + ) -> None: + """Initialize the fully connected layer.""" super(FCLayer, self).__init__() self.fc = nn.Linear(in_dim, out_dim) self.batch_norm = nn.BatchNorm1d(out_dim) if batch_norm else None self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None - self.activation = activation - + self.activation = activation if activation else nn.ReLU() + # Initialize weights if init_fn: init_fn(self.fc.weight) nn.init.zeros_(self.fc.bias) - - def forward(self, x): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the layer. + + Parameters + ---------- + x: Tensor of shape (batch_size, in_dim) + + Returns + ------- + out: Tensor of shape (batch_size, out_dim) + """ x = self.fc(x) if self.batch_norm: x = self.batch_norm(x) @@ -30,222 +55,383 @@ def forward(self, x): class FCNet(nn.Module): """Multi-layer fully connected network.""" - def __init__(self, in_dim, num_layers, h_dim, activation=nn.ReLU(), - out_dim=None, out_activation=None, batch_norm=False, - dropout_rate=0.0, init_fn=nn.init.xavier_normal_): + + def __init__( + self, + in_dim: int, + num_layers: int, + h_dim: int, + activation: Optional[nn.Module] = None, + out_dim: Optional[int | None] = None, + out_activation: Optional[nn.Module | None] = None, + batch_norm: bool = False, + dropout_rate: float = 0.0, + init_fn: Optional[Callable | None] = nn.init.xavier_normal_, + ) -> None: + """Initialize the fully connected network.""" super(FCNet, self).__init__() - + layers = [] prev_dim = in_dim - + activation = activation if activation else nn.ReLU() + # Hidden layers for i in range(num_layers): curr_dim = out_dim if (i == num_layers - 1 and out_dim) else h_dim - curr_act = out_activation if (i == num_layers - 1 and out_activation) else activation - - layers.append(FCLayer( - prev_dim, curr_dim, activation=curr_act, - batch_norm=batch_norm, dropout_rate=dropout_rate, init_fn=init_fn - )) + curr_act: Optional[nn.Module | None] = ( + out_activation + if (i == num_layers - 1 and out_activation) + else activation + ) + + layers.append( + FCLayer( + prev_dim, + curr_dim, + activation=curr_act, + batch_norm=batch_norm, + dropout_rate=dropout_rate, + init_fn=init_fn, + ) + ) prev_dim = curr_dim - + self.network = nn.Sequential(*layers) - - def forward(self, x): + + def forward(self, x: torch.Tensor) -> Optional[nn.Module]: + """Forward pass through the network. + + Parameters + ----------- + x: Tensor of shape (batch_size, in_dim) + + Returns + ------- + out: Tensor of shape (batch_size, out_dim) + """ return self.network(x) class DeepHit(nn.Module): """PyTorch implementation of DeepHit for competing risks survival analysis.""" - - def __init__(self, input_dims, network_settings): + + def __init__(self, input_dims: dict, network_settings: dict): + """Initialize the DeepHit model.""" super(DeepHit, self).__init__() - + # Input dimensions - self.x_dim = input_dims['x_dim'] - self.num_Event = input_dims['num_Event'] - self.num_Category = input_dims['num_Category'] - + self.x_dim = input_dims["x_dim"] + self.num_Event = input_dims["num_Event"] + self.num_Category = input_dims["num_Category"] + # Network settings - self.h_dim_shared = network_settings['h_dim_shared'] - self.h_dim_CS = network_settings['h_dim_CS'] - self.num_layers_shared = network_settings['num_layers_shared'] - self.num_layers_CS = network_settings['num_layers_CS'] - + self.h_dim_shared = network_settings["h_dim_shared"] + self.h_dim_CS = network_settings["h_dim_CS"] + self.num_layers_shared = network_settings["num_layers_shared"] + self.num_layers_CS = network_settings["num_layers_CS"] + # Activation function - if network_settings['active_fn'] == 'relu': + if network_settings["active_fn"] == "relu": self.active_fn = nn.ReLU() - elif network_settings['active_fn'] == 'elu': + elif network_settings["active_fn"] == "elu": self.active_fn = nn.ELU() - elif network_settings['active_fn'] == 'tanh': + elif network_settings["active_fn"] == "tanh": self.active_fn = nn.Tanh() else: self.active_fn = nn.ReLU() - + # Regularization - self.keep_prob = network_settings.get('keep_prob', 0.5) + self.keep_prob = network_settings.get("keep_prob", 0.5) self.dropout_rate = 1.0 - self.keep_prob - + # Initialize networks self._build_network() - - def _build_network(self): + + def _build_network(self) -> None: + """Build the shared and cause-specific networks. + + Parameters + ---------- + None + + Returns + ------- + None + """ # Shared network self.shared_net = FCNet( in_dim=self.x_dim, num_layers=self.num_layers_shared, h_dim=self.h_dim_shared, activation=self.active_fn, - dropout_rate=self.dropout_rate + dropout_rate=self.dropout_rate, ) - + # Cause-specific networks - self.cs_nets = nn.ModuleList([ - FCNet( - in_dim=self.x_dim + self.h_dim_shared, # Concatenate input and shared output - num_layers=self.num_layers_CS, - h_dim=self.h_dim_CS, - activation=self.active_fn, - dropout_rate=self.dropout_rate - ) for _ in range(self.num_Event) - ]) - + self.cs_nets = nn.ModuleList( + [ + FCNet( + in_dim=self.x_dim + + self.h_dim_shared, # Concatenate input and shared output + num_layers=self.num_layers_CS, + h_dim=self.h_dim_CS, + activation=self.active_fn, + dropout_rate=self.dropout_rate, + ) + for _ in range(self.num_Event) + ] + ) + # Output layer - self.output_layer = nn.Linear(self.num_Event * self.h_dim_CS, self.num_Event * self.num_Category) - - def forward(self, x): + self.output_layer = nn.Linear( + self.num_Event * self.h_dim_CS, self.num_Event * self.num_Category + ) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]: + """Forward pass through the network. + + Parameters + ---------- + x: Tensor of shape (batch_size, num_Event, num_Category) + + Returns + ------- + risk_scores: List of (batch_size, 1) Tensors + feature_outputs: None + """ # Shared network shared_out = self.shared_net(x) - + # Concatenate input with shared output h = torch.cat([x, shared_out], dim=1) - + # Cause-specific networks cs_outputs = [] for cs_net in self.cs_nets: cs_out = cs_net(h) cs_outputs.append(cs_out) - + # Stack outputs - stacked_out = torch.stack(cs_outputs, dim=1) # [batch_size, num_Event, h_dim_CS] - reshaped_out = stacked_out.view(-1, self.num_Event * self.h_dim_CS) # [batch_size, num_Event * h_dim_CS] - + stacked_out = torch.stack( + cs_outputs, dim=1 + ) # [batch_size, num_Event, h_dim_CS] + reshaped_out = stacked_out.view( + -1, self.num_Event * self.h_dim_CS + ) # [batch_size, num_Event * h_dim_CS] + # Final output layer - logits = self.output_layer(F.dropout(reshaped_out, self.dropout_rate, self.training)) + logits = self.output_layer( + F.dropout(reshaped_out, self.dropout_rate, self.training) + ) out = F.softmax(logits.view(-1, self.num_Event * self.num_Category), dim=1) - + # Reshape to [batch_size, num_Event, num_Category] out = out.view(-1, self.num_Event, self.num_Category) - - # For compatibility with the training script, return both raw risks and shape functions + + # For compatibility with the training script, return both + # raw risks and shape functions # In this model, we don't have separate shape functions, so just return None return out, None - - def log_likelihood_loss(self, out, t, k, mask1, mask2): - """Log-likelihood loss (including log-likelihood of censored subjects)""" - batch_size = out.size(0) - + + def _log_likelihood_loss( + self, + out: torch.Tensor, + t: Optional[torch.Tensor | np.ndarray], + k: Optional[torch.Tensor | np.ndarray], + mask1: torch.Tensor, + mask2: torch.Tensor, + ) -> torch.Tensor: + """Log-likelihood loss (including log-likelihood of censored subjects). + + Parameters + ---------- + out: Torch.tensor + t: Torch.tensor or numpy array + k: Torch.tensor or numpy array + mask1: Torch.tensor + mask2: Torch.tensor + + Returns + ------- + loss: Torch.tensor + """ # Convert to PyTorch tensors if necessary if not isinstance(k, torch.Tensor): k = torch.tensor(k, device=out.device) if not isinstance(t, torch.Tensor): t = torch.tensor(t, device=out.device) - + # Indicator for uncensored subjects - I_1 = (k > 0).float().view(-1, 1) - + i_1 = (k > 0).float().view(-1, 1) + # For uncensored: log P(T=t, K=k|x) tmp1 = torch.sum(torch.sum(mask1 * out, dim=2), dim=1, keepdim=True) - tmp1 = I_1 * torch.log(tmp1 + 1e-8) - + tmp1 = i_1 * torch.log(tmp1 + 1e-8) + # For censored: log ∑ P(T>t|x) - tmp2 = torch.sum(torch.sum(mask2.unsqueeze(1) * out, dim=2), dim=1, keepdim=True) - tmp2 = (1.0 - I_1) * torch.log(tmp2 + 1e-8) - + tmp2 = torch.sum( + torch.sum(mask2.unsqueeze(1) * out, dim=2), dim=1, keepdim=True + ) + tmp2 = (1.0 - i_1) * torch.log(tmp2 + 1e-8) + return -torch.mean(tmp1 + tmp2) - - def ranking_loss(self, out, t, k, mask2): - """Ranking loss (calculated only for acceptable pairs)""" - batch_size = out.size(0) + + def _ranking_loss( + self, + out: torch.Tensor, + t: Optional[torch.Tensor | np.ndarray], + k: Optional[torch.Tensor | np.ndarray], + mask2: torch.Tensor, + ) -> torch.Tensor: + """Ranking loss (calculated only for acceptable pairs). + + Parameters + ---------- + out: Torch.tensor + t: Torch.tensor or numpy array + k: Torch.tensor or numpy array + mask2: Torch.tensor + + Returns + ------- + loss: Torch.tensor + """ sigma1 = 0.1 eta = [] - + # Convert to PyTorch tensors if necessary if not isinstance(k, torch.Tensor): k = torch.tensor(k, device=out.device) if not isinstance(t, torch.Tensor): t = torch.tensor(t, device=out.device) - + one_vector = torch.ones_like(t) - + for e in range(self.num_Event): - # Indicator for current event type - I_2 = (k == e+1).float() - I_2_diag = torch.diag(I_2.squeeze()) - + i_2 = (k == e + 1).float() + i_2_diag = torch.diag(i_2.squeeze()) + # Extract event-specific probabilities tmp_e = out[:, e, :] # [batch_size, num_Category] - + # Calculate risk scores - R = torch.matmul(tmp_e, mask2.transpose(0, 1)) # [batch_size, batch_size] - diag_R = torch.diag(R).unsqueeze(1) # [batch_size, 1] - R = torch.matmul(one_vector, diag_R.transpose(0, 1)) - R # [batch_size, batch_size] - R = R.transpose(0, 1) # Now R_ij = r_i(T_i) - r_j(T_i) - + r = torch.matmul(tmp_e, mask2.transpose(0, 1)) # [batch_size, batch_size] + diag_r = torch.diag(r).unsqueeze(1) # [batch_size, 1] + r = ( + torch.matmul(one_vector, diag_r.transpose(0, 1)) - r + ) # [batch_size, batch_size] + r = r.transpose(0, 1) # Now R_ij = r_i(T_i) - r_j(T_i) + # Time comparison matrix - T = F.relu(torch.sign(torch.matmul(one_vector, t.transpose(0, 1)) - - torch.matmul(t, one_vector.transpose(0, 1)))) - + time = F.relu( + torch.sign( + torch.matmul(one_vector, t.transpose(0, 1)) + - torch.matmul(t, one_vector.transpose(0, 1)) + ) + ) + # Filter by event occurrence - T = torch.matmul(I_2_diag, T) - + time = torch.matmul(i_2_diag, time) + # Calculate ranking loss for current event - tmp_eta = torch.mean(T * torch.exp(-R / sigma1), dim=1, keepdim=True) + tmp_eta = torch.mean(time * torch.exp(-r / sigma1), dim=1, keepdim=True) eta.append(tmp_eta) - + eta = torch.stack(eta, dim=1) # [batch_size, num_Event] - eta = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True) - - return torch.sum(eta) - - def calibration_loss(self, out, t, k, mask2): - """Calibration loss""" - batch_size = out.size(0) + eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True) + + return torch.sum(eta_mean) + + def _calibration_loss( + self, + out: torch.Tensor, + t: Optional[torch.Tensor | np.ndarray], + k: Optional[torch.Tensor | np.ndarray], + mask2: torch.Tensor, + ) -> torch.Tensor: + """Calibration loss. + + Parameters + ---------- + out: Torch.tensor + t: Torch.tensor or numpy array + k: Torch.tensor or numpy array + mask2: Torch.tensor + + Returns + ------- + loss: Torch.tensor + """ eta = [] - + # Convert to PyTorch tensors if necessary if not isinstance(k, torch.Tensor): k = torch.tensor(k, device=out.device) - + for e in range(self.num_Event): # Indicator for current event type - I_2 = (k == e+1).float() - + i_2 = (k == e + 1).float() + # Extract event-specific probabilities tmp_e = out[:, e, :] # [batch_size, num_Category] - + # Calculate calibration loss r = torch.sum(tmp_e * mask2, dim=1) - tmp_eta = torch.mean((r - I_2) ** 2, dim=0, keepdim=True) + tmp_eta = torch.mean((r - i_2) ** 2, dim=0, keepdim=True) eta.append(tmp_eta) - + eta = torch.stack(eta, dim=1) # [1, num_Event] - eta = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True) - - return torch.sum(eta) - - def compute_loss(self, out, t, k, mask1, mask2, alpha=1.0, beta=1.0, gamma=1.0): - """Compute total loss""" - loss1 = self.log_likelihood_loss(out, t, k, mask1, mask2) - loss2 = self.ranking_loss(out, t, k, mask2) - loss3 = self.calibration_loss(out, t, k, mask2) - + eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True) + + return torch.sum(eta_mean) + + def compute_loss( + self, + out: torch.Tensor, + t: Optional[torch.Tensor | np.ndarray], + k: Optional[torch.Tensor | np.ndarray], + mask1: Optional[torch.Tensor | np.ndarray], + mask2: torch.Tensor, + alpha: float = 1.0, + beta: float = 1.0, + gamma: float = 1.0, + ) -> torch.Tensor: + """Compute total loss. + + Parameters + ---------- + out: Torch.tensor + t: Torch.tensor or numpy array + k: Torch.tensor or numpy array + mask1: Torch.tensor + mask2: Torch.tensor + alpha: float, weight for log-likelihood loss + beta: float, weight for ranking loss + gamma: float, weight for calibration loss + + Returns + ------- + total_loss: Torch.tensor + """ + loss1 = self._log_likelihood_loss(out, t, k, mask1, mask2) + loss2 = self._ranking_loss(out, t, k, mask2) + loss3 = self._calibration_loss(out, t, k, mask2) + # L2 regularization is handled by optimizer (weight_decay) return alpha * loss1 + beta * loss2 + gamma * loss3 - - def predict(self, x): - """Predict risk scores for input x""" + + def predict(self, x: torch.Tensor) -> torch.Tensor: + """Predict risk scores for input x. + + Parameters + ---------- + x: Tensor of shape (batch_size, num_Event, num_Category) + + Returns + ------- + out: Tensor of shape (batch_size, num_Event, num_Category) + """ self.eval() with torch.no_grad(): out, _ = self.forward(x) - return out \ No newline at end of file + return out diff --git a/crisp_nam/utils/__init__.py b/crisp_nam/utils/__init__.py index 08275dbe..42bd9dea 100644 --- a/crisp_nam/utils/__init__.py +++ b/crisp_nam/utils/__init__.py @@ -1,3 +1,5 @@ +"""Utility functions for crisp-nam package.""" + from .loss import * from .plotting import * -from .risk_cif import * \ No newline at end of file +from .risk_cif import * diff --git a/crisp_nam/utils/loss.py b/crisp_nam/utils/loss.py index 9cae44fd..5055387e 100644 --- a/crisp_nam/utils/loss.py +++ b/crisp_nam/utils/loss.py @@ -1,143 +1,175 @@ +"""Loss functions for competing risks. + +This module implements weighted and un-weighted +negative log-likelihood loss, L2 penalty loss functions. +""" + import torch -def weighted_negative_log_likelihood_loss(risk_scores, times, events, - num_competing_risks, event_weights=None, - sample_weights=None, eps=1e-8) -> float: + +def weighted_negative_log_likelihood_loss( + risk_scores, + times, + events, + num_competing_risks, + event_weights=None, + sample_weights=None, + eps=1e-8, +) -> torch.Tensor: """ - Computes the weighted negative log-likelihood loss for competing risks Cox model. - - Args: + Compute the weighted negative log-likelihood loss for competing risks Cox model. + + Parameters + ---------- risk_scores: List of tensors with shape (batch_size, 1) for each competing risk times: Event/censoring times (batch_size,) events: Event indicators (0=censored, 1...K=event types) (batch_size,) num_competing_risks: Number of competing risks - event_weights: Tensor of weights for each competing risk type (size: num_competing_risks) + event_weights: Tensor of weights for each competing risk type + (size: num_competing_risks) sample_weights: Tensor of weights for each sample (size: batch_size) eps: Small constant for numerical stability - - Returns: + + Returns + ------- Weighted negative log partial likelihood loss """ device = times.device batch_size = times.shape[0] - + # Initialize loss loss = torch.tensor(0.0, device=device) - + # Set default weights if not provided if event_weights is None: event_weights = torch.ones(num_competing_risks, device=device) if sample_weights is None: sample_weights = torch.ones(batch_size, device=device) - + # Count number of events n_events = (events > 0).sum().item() if n_events == 0: return loss - + # Process each competing risk separately for k in range(1, num_competing_risks + 1): # Find samples with this event type - event_mask = (events == k) + event_mask = events == k n_events_k = event_mask.sum().item() - + if n_events_k == 0: continue - + # Get risk scores for this competing risk - risk_k = risk_scores[k-1].squeeze() - + risk_k = risk_scores[k - 1].squeeze() + # Get weight for this event type - event_weight = event_weights[k-1] - + event_weight = event_weights[k - 1] + # For each event of type k for i in range(batch_size): if event_mask[i]: # Find samples in risk set (samples with time >= event time) - risk_set = (times >= times[i]) - + risk_set = times >= times[i] + # Calculate log sum of exp of risk scores in risk set risk_set_scores = risk_k[risk_set] log_risk_sum = torch.logsumexp(risk_set_scores, dim=0) - + # Subtract individual risk score from log sum and apply weights individual_loss = log_risk_sum - risk_k[i] - weighted_individual_loss = individual_loss * event_weight * sample_weights[i] + weighted_individual_loss = ( + individual_loss * event_weight * sample_weights[i] + ) loss += weighted_individual_loss - + # Return average loss return loss / max(n_events, 1) -def negative_log_likelihood_loss(risk_scores, times, events, - num_competing_risks, eps=1e-8): + +def negative_log_likelihood_loss( + risk_scores: float, + times: torch.Tensor, + events: torch.Tensor, + num_competing_risks: int, + eps: float = 1e-8, +) -> torch.Tensor: """ - Computes the negative log-likelihood loss for competing risks Cox model. - - Args: + Compute the negative log-likelihood loss for competing risks Cox model. + + Parameters + ---------- risk_scores: List of tensors with shape (batch_size, 1) for each competing risk times: Event/censoring times (batch_size,) events: Event indicators (0=censored, 1...K=event types) (batch_size,) num_competing_risks: Number of competing risks eps: Small constant for numerical stability - - Returns: + + Returns + ------- Negative log partial likelihood loss """ device = times.device batch_size = times.shape[0] - + # Initialize loss loss = torch.tensor(0.0, device=device) - + # Count number of events n_events = (events > 0).sum().item() if n_events == 0: return loss - + # Process each competing risk separately for k in range(1, num_competing_risks + 1): # Find samples with this event type - event_mask = (events == k) + event_mask = events == k n_events_k = event_mask.sum().item() - + if n_events_k == 0: continue - + # Get risk scores for this competing risk - risk_k = risk_scores[k-1].squeeze() - + risk_k = risk_scores[k - 1].squeeze() + # For each event of type k for i in range(batch_size): if event_mask[i]: # Find samples in risk set (samples with time >= event time) - risk_set = (times >= times[i]) - + risk_set = times >= times[i] + # Calculate log sum of exp of risk scores in risk set risk_set_scores = risk_k[risk_set] log_risk_sum = torch.logsumexp(risk_set_scores, dim=0) - + # Subtract individual risk score from log sum loss += log_risk_sum - risk_k[i] - + # Return average loss return loss / max(n_events, 1) -def compute_l2_penalty(model, include_bias=False) -> int: + +def compute_l2_penalty( + model: torch.nn.Module, + include_bias: bool = False + ) -> torch.Tensor: """ - Compute L2 regularization penalty on model parameters - - Args: + Compute L2 regularization penalty on model parameters. + + Parameters + ---------- model: Neural network model include_bias: Whether to include bias terms in regularization - - Returns: + + Returns + ------- L2 penalty term """ l2_reg = 0.0 for name, param in model.named_parameters(): if param.requires_grad: # Skip bias parameters if specified - if not include_bias and 'bias' in name: + if not include_bias and "bias" in name: continue - l2_reg += torch.sum(param ** 2) - return l2_reg \ No newline at end of file + l2_reg += torch.sum(param**2) + return l2_reg diff --git a/crisp_nam/utils/plotting.py b/crisp_nam/utils/plotting.py index e92e1b44..184e7718 100644 --- a/crisp_nam/utils/plotting.py +++ b/crisp_nam/utils/plotting.py @@ -1,32 +1,61 @@ -import matplotlib.pyplot as plt -from typing import Union, List +"""Utility functions for plotting. -import torch +This module provides functions to visualize feature importance +and shape functions for both crisp-nam and deephit models. +""" + +from typing import List, Union + +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import torch -def plot_feature_importance(model: torch.nn.Module, - x_data: Union[np.ndarray, torch.Tensor], - feature_names = None, - n_top:int = 5, - n_bottom:int = 5, - risk_idx:int = 1, - figsize: tuple = (8, 6), - output_file:str = None, - color_positive:str = '#2196F3', - color_negative:str = '#F44336') -> tuple: - """ - Plot feature importance with both top positive and negative influences, + +def plot_feature_importance( + model: torch.nn.Module, + x_data: Union[np.ndarray, torch.Tensor], + feature_names=None, + n_top: int = 5, + n_bottom: int = 5, + risk_idx: int = 1, + figsize: tuple = (8, 6), + output_file: str = "", + color_positive: str = "#2196F3", + color_negative: str = "#F44336", +) -> tuple: + """Plot feature importance with both top positive and negative influences, handling both CPU and CUDA devices automatically. + + Parameters + ---------- + - model: A trained CoxNAM model (torch.nn.Module) + - x_data: Input data (numpy array or torch tensor) to compute contributions + - feature_names: Optional list of feature names (default: generic names) + - n_top: Number of top positive features to display + - n_bottom: Number of top negative features to display + - risk_idx: Index of the competing risk to analyze + - figsize: Size of the plot (width, height) + - output_file: Optional path to save the plot image (e.g., "feature_importance.png") + - color_positive: Color for positive contributions (default: blue) + - color_negative: Color for negative contributions (default: red) + + Returns + ------- + - fig: Matplotlib figure object + - ax: Matplotlib axes object + - top_pos: List of top positive feature names + - top_neg: List of top negative feature names """ + # determine model device device = next(model.parameters()).device model.eval() # prepare feature names - num_features = model.num_features + num_features: torch.Tensor = model.num_features if feature_names is None: - feature_names = [f"Feature {i+1}" for i in range(num_features)] + feature_names = [f"Feature {i + 1}" for i in range(num_features)] # convert x_data to tensor on the model device if not isinstance(x_data, torch.Tensor): @@ -45,77 +74,102 @@ def plot_feature_importance(model: torch.nn.Module, continue # forward through the feature net and projection - rep = model.feature_nets[i](vals) - proj = model.risk_projections[i][risk_idx0](rep) + rep : torch.nn.ModuleList = model.feature_nets[i](vals) + proj : torch.nn.ModuleList = model.risk_projections[i][risk_idx0](rep) # mean contribution as a Python float contrib = proj.mean().item() feature_contribs[feature_names[i]] = contrib # build a DataFrame for sorting - df = pd.DataFrame({ - 'feature': list(feature_contribs.keys()), - 'contribution': list(feature_contribs.values()) - }) - df['abs_contrib'] = df['contribution'].abs() - df = df.sort_values('abs_contrib', ascending=False) - - pos = df[df['contribution'] > 0].head(n_top).sort_values('contribution') - neg = df[df['contribution'] < 0].head(n_bottom).sort_values('contribution', ascending=False) - - top_pos = pos['feature'].tolist() - top_neg = neg['feature'].tolist() + df = pd.DataFrame( + { + "feature": list(feature_contribs.keys()), + "contribution": list(feature_contribs.values()), + } + ) + df["abs_contrib"] = df["contribution"].abs() + df = df.sort_values("abs_contrib", ascending=False) + + pos = df[df["contribution"] > 0].head(n_top).sort_values("contribution") + neg = ( + df[df["contribution"] < 0] + .head(n_bottom) + .sort_values("contribution", ascending=False) + ) + + top_pos = pos["feature"].tolist() + top_neg = neg["feature"].tolist() # plotting fig, ax = plt.subplots(figsize=figsize) - ax.barh(pos['feature'], pos['contribution'], color=color_positive, alpha=0.8) - ax.barh(neg['feature'], neg['contribution'], color=color_negative, alpha=0.8) - ax.axvline(0, color='black', linestyle='-', alpha=0.3) - - ax.set_xlabel('Contribution to Risk Score') - ax.set_title(f'Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}') - ax.grid(axis='x', linestyle='--', alpha=0.5) + ax.barh(pos["feature"], pos["contribution"], color=color_positive, alpha=0.8) + ax.barh(neg["feature"], neg["contribution"], color=color_negative, alpha=0.8) + ax.axvline(0, color="black", linestyle="-", alpha=0.3) + + ax.set_xlabel("Contribution to Risk Score") + ax.set_title( + f"Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}" + ) + ax.grid(axis="x", linestyle="--", alpha=0.5) plt.tight_layout() if output_file: - plt.savefig(output_file, bbox_inches='tight', dpi=300) + plt.savefig(output_file, bbox_inches="tight", dpi=300) return fig, ax, top_pos, top_neg -def plot_coxnam_shape_functions(model:torch.nn.Module, - X: Union[np.ndarray, torch.Tensor], - risk_to_plot:int = 1, - feature_names:List[str] = None, - top_features:List[str] = None, - ncols:int = 3, - figsize:tuple = (12, 8), - output_file:str = None) -> tuple: - """ - Plot shape functions for each feature in a CoxNAM model, +def plot_coxnam_shape_functions( + model: torch.nn.Module, + X: Union[np.ndarray, torch.Tensor], + risk_to_plot: int = 1, + feature_names: np.ndarray | None = None, + top_features: List[str] | None = None, + ncols: int = 3, + figsize: tuple = (12, 8), + output_file: str = "", +) -> tuple: + """Plot shape functions for each feature in a CoxNAM model, automatically handling CPU vs CUDA inputs. + + Parameters + ---------- + - model: A trained CoxNAM model (torch.nn.Module) + - X: Input data (numpy array or torch tensor) to compute shape functions + - risk_to_plot: Index of the competing risk to visualize + - feature_names: Optional list of feature names (default: generic names) + - top_features: Optional list of feature names to plot features) + - ncols: Number of columns in the subplot grid + - figsize: Size of the entire figure (width, height) + - output_file: Optional path to save the plot image (e.g., "shape_functions.png") + + Returns + ------- + - fig: Matplotlib figure object + - axes: List of Matplotlib axes objects for each plotted feature """ device = next(model.parameters()).device model.eval() risk_idx = risk_to_plot - 1 # ensure X is a numpy array - if isinstance(X, torch.Tensor): - X_np = X.cpu().numpy() - else: - X_np = np.array(X, dtype=float) + X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else np.array(X, dtype=float) # derive feature list num_features = model.num_features + print(f'{plot_coxnam_shape_functions.__name__}: top_features={top_features}') if feature_names is None: - feature_names = [f"Feature {i+1}" for i in range(num_features)] - if top_features: + feature_names = [f"Feature {i + 1}" for i in range(num_features)] + if top_features is not None : # map names back to indices idx_map = {name: i for i, name in enumerate(feature_names)} - selected = [(idx_map.get(name, None), name) for name in top_features] + selected = [(idx_map.get(name), name) for name in top_features] selected = [(i, name) for i, name in selected if i is not None] else: selected = list(zip(range(num_features), feature_names)) + print(f'{plot_coxnam_shape_functions.__name__}: num_features={num_features}, feature_names={feature_names}, top_features={top_features}') + print(f'{plot_coxnam_shape_functions.__name__}: selected={selected}') n_selected = len(selected) nrows = int(np.ceil(n_selected / ncols)) fig, axes = plt.subplots(nrows, ncols, figsize=(figsize)) @@ -125,7 +179,7 @@ def plot_coxnam_shape_functions(model:torch.nn.Module, for ax, (f_idx, fname) in zip(axes, selected): vals = X_np[:, f_idx] if vals.size == 0: - ax.text(0.5, 0.5, "no data", ha='center', va='center') + ax.text(0.5, 0.5, "no data", ha="center", va="center") continue # choose evaluation points @@ -138,26 +192,26 @@ def plot_coxnam_shape_functions(model:torch.nn.Module, t_pts = torch.tensor(pts, dtype=torch.float32, device=device).unsqueeze(1) # compute shape values - rep = model.feature_nets[f_idx](t_pts) - proj = model.risk_projections[f_idx][risk_idx](rep) + rep : torch.nn.ModuleList = model.feature_nets[f_idx](t_pts) + proj : torch.nn.ModuleList = model.risk_projections[f_idx][risk_idx](rep) shp = proj.squeeze(-1).cpu().numpy() # plot ax.plot(pts, shp, linewidth=2) - ax.axhline(0, linestyle='--', alpha=0.5) + ax.axhline(0, linestyle="--", alpha=0.5) ax.set_title(fname) - ax.set_xlabel('Value') - ax.set_ylabel('Contribution') + ax.set_xlabel("Value") + ax.set_ylabel("Contribution") # rug plot - ax.plot(vals, np.zeros_like(vals)-0.1, '|', alpha=0.3) + ax.plot(vals, np.zeros_like(vals) - 0.1, "|", alpha=0.3) # turn off any extra axes for ax in axes[n_selected:]: - ax.axis('off') + ax.axis("off") - fig.suptitle(f'Shape Functions for Risk {risk_to_plot}', fontsize=14) - plt.tight_layout(rect=[0,0,1,0.96]) + fig.suptitle(f"Shape Functions for Risk {risk_to_plot}", fontsize=14) + plt.tight_layout(rect=(0, 0, 1, 0.96)) if output_file: - plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.savefig(output_file, dpi=300, bbox_inches="tight") return fig, axes[:n_selected] diff --git a/crisp_nam/utils/risk_cif.py b/crisp_nam/utils/risk_cif.py index 9e8cd1b2..4c043890 100644 --- a/crisp_nam/utils/risk_cif.py +++ b/crisp_nam/utils/risk_cif.py @@ -1,33 +1,40 @@ -from typing import List, Any +"""Risk functions for evaluation. + +This module provides functions to compute cumulative incidence functions (CIFs) +and risk scores for competing risk models. +""" + +from typing import Any, List -import torch import numpy as np +import torch -def compute_baseline_cif(times:np.ndarray, - events:np.ndarray, - eval_times:List[Any], - event_type:np.ndarray) -> np.ndarray: + +def compute_baseline_cif( + times: np.ndarray, events: np.ndarray, eval_times: List[Any], event_type: np.ndarray +) -> np.ndarray: """ - Compute baseline cumulative incidence function for a specific event type - + Compute baseline cumulative incidence function for a specific event type. + Args: times: Numpy array of event times events: Numpy array of event indicators (0=censored, 1...K=event types) eval_times: Times at which to evaluate the CIF event_type: Event type to compute CIF for (1...K) - - Returns: + + Returns + ------- Numpy array of baseline CIF values at eval_times """ # Sort times and corresponding events sort_idx = np.argsort(times) sorted_times = times[sort_idx] sorted_events = events[sort_idx] - + # Initialize cumulative hazard n_samples = len(times) baseline_cif = np.zeros(len(eval_times)) - + # For each evaluation time for i, t in enumerate(eval_times): cif_t = 0.0 @@ -37,78 +44,31 @@ def compute_baseline_cif(times:np.ndarray, # Simple Aalen-Johansen estimator cif_t = event_count / n_samples baseline_cif[i] = cif_t - - return baseline_cif -def predict_absolute_risk(model: torch.nn.Module, - x:np.ndarray, - baseline_cifs:np.ndarray, - eval_times: np.ndarray, - device:str="cpu") -> np.ndarray: - """ - Predict absolute risk (cumulative incidence) at specified times - - Args: - model: Trained CoxNAM model - x: Feature matrix - baseline_cifs: Dictionary of baseline CIFs for each event type - eval_times: Times at which to evaluate the CIF - device: Device to run computations on - - Returns: - Numpy array of predicted absolute risks with shape (n_samples, n_risks, n_times) - """ - model.eval() - - # Convert to tensor if needed - if not isinstance(x, torch.Tensor): - x = torch.FloatTensor(x).to(device) - - with torch.no_grad(): - # Get risk scores - risk_scores, _ = model(x) - - # Convert to hazard ratios - hazard_ratios = [torch.exp(score).cpu().numpy() for score in risk_scores] - - # Initialize prediction array - n_samples = x.shape[0] - n_risks = model.num_competing_risks - n_times = len(eval_times) - abs_risks = np.zeros((n_samples, n_risks, n_times)) - - # Compute absolute risk for each sample, risk, and time - for k in range(n_risks): - # Skip if baseline CIF is not available - if k+1 not in baseline_cifs: - continue - - baseline_cif = baseline_cifs[k+1] - - for i in range(n_samples): - for j, t in enumerate(eval_times): - # Simple Fine-Gray model for cumulative incidence - abs_risks[i, k, j] = 1 - np.exp(-baseline_cif[j] * hazard_ratios[k][i]) - - return abs_risks + return baseline_cif -def predict_cif(model:torch.nn.Module, - x:np.ndarray, - baseline_cif:np.ndarray, - times:np.ndarray, - event_of_interest:int) -> np.ndarray: +def predict_cif( + model: torch.nn.Module, + x: np.ndarray, + baseline_cif: np.ndarray, + times: np.ndarray, + event_of_interest: int, +) -> np.ndarray: """ Predict cumulative incidence function for a specific competing risk. - Args: + Parameters + ---------- model: Trained model. x: Input tensor of shape (n_samples, n_features). - baseline_cif: Array of shape (len(times),) — estimated CIF for baseline (e.g. from compute_baseline_cif). + baseline_cif: Array of shape (len(times),) — + estimated CIF for baseline (e.g. from compute_baseline_cif). times: Time points at which CIF is evaluated. event_type: Integer, 0-based index of event of interest. - Returns: + Returns + ------- cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample. """ model.eval() @@ -117,56 +77,63 @@ def predict_cif(model:torch.nn.Module, f_j_x = logits[event_of_interest].squeeze(1).cpu().numpy() # (n_samples,) baseline_cif = np.asarray(baseline_cif).reshape(1, -1) # (1, T) - risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1) - - # Fine-Gray style CIF prediction under PH assumption - cif_pred = 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T) - - return cif_pred - -def predict_risk(model:np.ndarray, - x_input:np.ndarray, - device:str = 'cpu'): + risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1) + + # Return Fine-Gray style CIF prediction under PH assumption + return 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T) + + +def predict_risk( + model: torch.nn.Module, x_input: np.ndarray, device: str = "cpu" +) -> np.ndarray: """ Predicts relative risk scores for each competing risk. Args: model : Trained model. - x_input (np.ndarray or torch.Tensor): Input features of shape (n_samples, n_features). + x_input (np.ndarray or torch.Tensor): Input features of + shape (n_samples, n_features). device (str): Device to run the computation on. - Returns: + Returns + ------- np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores. """ model.eval() - + if isinstance(x_input, np.ndarray): x_tensor = torch.from_numpy(x_input).float().to(device) else: x_tensor = x_input.to(device).float() - + with torch.no_grad(): risk_outputs, _ = model(x_tensor) # List of [batch_size, 1] tensors risks = torch.cat(risk_outputs, dim=1) # Shape: [batch_size, num_risks] - return risks.cpu().numpy() + return risks.cpu().numpy() + -def predict_absolute_risk(model:torch.Tensor, - x_input:np.ndarray, - baseline_cifs:List[Any], - times:List[Any], - device:str = 'cpu') -> np.ndarray: +def predict_absolute_risk( + model: torch.nn.Module, + x_input: np.ndarray, + baseline_cifs: List[Any], + times: List[Any], + device: str = "cpu", +) -> np.ndarray: """ Predict absolute risk (CIF) for each competing event by given time points. - Args: + Parameters + ---------- model: Trained model. x_input (np.ndarray or Tensor): Input features, shape (n_samples, n_features). - baseline_cifs (dict): Mapping of event index to baseline CIF array of shape (n_times,). + baseline_cifs (dict): Mapping of event index to baseline CIF + array of shape (n_times,). times (np.ndarray): Time grid used for baseline_cifs. device: CPU or CUDA. - Returns: + Returns + ------- np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs. """ rel_risks = predict_risk(model, x_input, device) # shape (n_samples, num_events) @@ -179,5 +146,5 @@ def predict_absolute_risk(model:torch.Tensor, base_cif = np.clip(baseline_cifs[k], 1e-10, 0.9999) # avoid edge cases for i in range(n_samples): abs_risks[i, k, :] = 1 - np.power(1 - base_cif, np.exp(rel_risks[i, k])) - - return abs_risks \ No newline at end of file + + return abs_risks diff --git a/data_utils/load_datasets.py b/data_utils/load_datasets.py index 2263fa30..6e4837fa 100644 --- a/data_utils/load_datasets.py +++ b/data_utils/load_datasets.py @@ -1,16 +1,18 @@ -from typing import Tuple, List +from typing import List, Tuple -import pandas as pd import numpy as np +import pandas as pd from sklearn.impute import SimpleImputer + def load_framingham(sequential=False): """ Load and preprocess the Framingham dataset for competing risks analysis, with imputation but no scaling. Feature normalization must be done externally after splitting to avoid data leakage. - Returns: + Returns + ------- x (np.ndarray): Feature matrix with one-hot categorical + raw continuous features. t (np.ndarray): Time-to-event (with +1 offset). e (np.ndarray): Event indicator (0=censored, 1=CVD, 2=death). @@ -18,55 +20,62 @@ def load_framingham(sequential=False): n_continuous (int): Number of continuous features at the end of x. feature_ranges (None): Placeholder for backward compatibility. """ - file_path = "datasets/framingham.csv" data = pd.read_csv(file_path) if not sequential: - data = data.groupby("RANDID").first() - cat_cols = [ - 'SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS', - 'PREVCHD', 'PREVAP', 'PREVMI', 'PREVSTRK', 'PREVHYP', 'educ' + "SEX", + "CURSMOKE", + "DIABETES", + "BPMEDS", + "PREVCHD", + "PREVAP", + "PREVMI", + "PREVSTRK", + "PREVHYP", + "educ", ] - + # 'HDLC', 'LDLC' - removed to replicate nfg experiments. cont_cols = [ - 'TOTCHOL', 'AGE', - 'SYSBP', 'DIABP', 'CIGPDAY', 'BMI', - 'HEARTRTE', 'GLUCOSE' + "TOTCHOL", + "AGE", + "SYSBP", + "DIABP", + "CIGPDAY", + "BMI", + "HEARTRTE", + "GLUCOSE", ] - - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat = pd.DataFrame( cat_imputer.fit_transform(data[cat_cols]), columns=cat_cols, - index=data.index + index=data.index, ) x_cat = pd.get_dummies(x_cat, drop_first=True) - - cont_imputer = SimpleImputer(strategy='mean') + cont_imputer = SimpleImputer(strategy="mean") x_cont = cont_imputer.fit_transform(data[cont_cols]) - x = np.hstack([x_cat.values, x_cont]) feature_names = np.concatenate([x_cat.columns.values, cont_cols]) n_continuous = len(cont_cols) event = np.zeros(len(data), dtype=int) - time = (data['TIMEDTH'] - data['TIME']).values + time = (data["TIMEDTH"] - data["TIME"]).values # Primary CVD event (risk=1) - cvd_mask = data['CVD'] == 1 + cvd_mask = data["CVD"] == 1 event[cvd_mask] = 1 - time_cvd = (data['TIMECVD'] - data['TIME']).values + time_cvd = (data["TIMECVD"] - data["TIME"]).values time[cvd_mask] = time_cvd[cvd_mask] # Competing death event (risk=2), only if CVD did not occur - death_mask = (data['DEATH'] == 1) & ~cvd_mask + death_mask = (data["DEATH"] == 1) & ~cvd_mask event[death_mask] = 2 # Filter out invalid or zero times @@ -79,8 +88,8 @@ def load_framingham(sequential=False): assert not np.isnan(x).any(), "NaNs found in feature matrix" return x, t, e, feature_names, n_continuous, None - else: - raise NotImplementedError("Sequential mode not yet implemented.") + raise NotImplementedError("Sequential mode not yet implemented.") + def load_pbc2_dataset(): """ @@ -90,7 +99,8 @@ def load_pbc2_dataset(): with missing values imputed. The function also constructs the outcome variable for competing risks analysis and returns time-to-event data. - Returns: + Returns + ------- tuple: A tuple containing the following elements: - x (numpy.ndarray): Combined feature matrix with categorical features one-hot encoded and continuous features imputed. @@ -103,55 +113,56 @@ def load_pbc2_dataset(): - n_continuous (int): Number of continuous features. - feature_ranges (list of tuple): List of (min, max) ranges for each feature. """ - file_path = "datasets/pbc2.csv" data = pd.read_csv(file_path) - data = data.drop(columns=['id', 'sno.', 'year', 'status2'], axis=1) - + data = data.drop(columns=["id", "sno.", "year", "status2"], axis=1) event_type = np.where( - data['status'] == 'dead', 1, - np.where(data['status'] == 'transplanted', 2, 0) + data["status"] == "dead", 1, np.where(data["status"] == "transplanted", 2, 0) ) - cont_cols = [ - 'age', 'serBilir', 'serChol', 'albumin', 'alkaline', - 'SGOT', 'platelets', 'prothrombin', 'histologic' + "age", + "serBilir", + "serChol", + "albumin", + "alkaline", + "SGOT", + "platelets", + "prothrombin", + "histologic", ] - x_cont = data[cont_cols].replace('NA', np.nan).astype(float) + x_cont = data[cont_cols].replace("NA", np.nan).astype(float) - - mean_imputer = SimpleImputer(strategy='mean') + mean_imputer = SimpleImputer(strategy="mean") x_cont_imputed = mean_imputer.fit_transform(x_cont) - cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0), - np.nanmax(x_cont_imputed, axis=0))) + cont_feature_ranges = list( + zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0)) + ) - - cat_cols = ['sex', 'drug', 'ascites', 'hepatomegaly', 'spiders', 'edema'] - x_cat = data[cat_cols].fillna('missing') + cat_cols = ["sex", "drug", "ascites", "hepatomegaly", "spiders", "edema"] + x_cat = data[cat_cols].fillna("missing") - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat_imputed = cat_imputer.fit_transform(x_cat) x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols) x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True) - - x_cat_encoded = x_cat_encoded.loc[:, ~x_cat_encoded.columns.str.contains('_missing')] + x_cat_encoded = x_cat_encoded.loc[ + :, ~x_cat_encoded.columns.str.contains("_missing") + ] cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1] - x = np.hstack([x_cat_encoded.values, x_cont_imputed]) feature_names = np.concatenate([x_cat_encoded.columns, cont_cols]) n_continuous = len(cont_cols) feature_ranges = cat_feature_ranges + cont_feature_ranges - - t = data['years'].astype(float).values * 365.25 + t = data["years"].astype(float).values * 365.25 valid = ~np.isnan(t) return ( @@ -160,17 +171,20 @@ def load_pbc2_dataset(): event_type[valid], feature_names, n_continuous, - feature_ranges + feature_ranges, ) + def load_support_dataset(): """ Load and preprocess the SUPPORT dataset. This function reads the SUPPORT dataset from a CSV file, imputes missing values, encodes categorical features, and constructs the outcome variable for survival analysis. - It returns the processed features, time-to-event data, event types, feature names, + It returns the processed features, time-to-event data, event types, feature names, the number of continuous features, and feature ranges. - Returns: + + Returns + ------- tuple: A tuple containing: - x (numpy.ndarray): Combined array of processed categorical and continuous features. - t (numpy.ndarray): Time-to-event data with a +1 offset. @@ -178,7 +192,9 @@ def load_support_dataset(): - feature_names (numpy.ndarray): Array of feature names. - n_continuous (int): Number of continuous features. - feature_ranges (list): List of tuples representing the range (min, max) for each feature. - Notes: + + Notes + ----- - The dataset file "support2.csv" must be located in the same directory as this script. - Continuous features are imputed using the median strategy if missing values are present. - Categorical features are imputed using the most frequent strategy if missing values are present. @@ -186,77 +202,101 @@ def load_support_dataset(): - The outcome variable is constructed using the 'ca', 'dzgroup', and 'death' columns. - Time-to-event data ('d.time') is offset by +1 to avoid zero follow-up times. """ - file_path = "datasets/support2.csv" data = pd.read_csv(file_path) - - is_cancer = data['ca'].astype(str).str.lower().str.contains("meta") | \ - data['dzgroup'].astype(str).str.lower().str.contains("cancer") - event_type = np.where(data['death'] == 1, np.where(is_cancer, 1, 2), 0) + + is_cancer = data["ca"].astype(str).str.lower().str.contains("meta") | data[ + "dzgroup" + ].astype(str).str.lower().str.contains("cancer") + event_type = np.where(data["death"] == 1, np.where(is_cancer, 1, 2), 0) cont_cols = [ - 'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', 'pafi', 'alb', - 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', 'urine', - 'scoma', 'aps', 'sps', 'adls', 'adlsc', 'charges', 'totcst', 'totmcst', 'avtisst' + "age", + "num.co", + "meanbp", + "wblc", + "hrt", + "resp", + "temp", + "pafi", + "alb", + "bili", + "crea", + "sod", + "ph", + "glucose", + "bun", + "urine", + "scoma", + "aps", + "sps", + "adls", + "adlsc", + "charges", + "totcst", + "totmcst", + "avtisst", ] x_cont = data[cont_cols] - + if x_cont.isnull().values.any(): - simp_imputer = SimpleImputer(strategy='median') + simp_imputer = SimpleImputer(strategy="median") x_cont_imputed = simp_imputer.fit_transform(x_cont) else: x_cont_imputed = x_cont.values - cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0), - np.nanmax(x_cont_imputed, axis=0))) + cont_feature_ranges = list( + zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0)) + ) # -- Categorical features -- # Remove leakage fields 'ca', 'dzgroup', 'dzclass' - cat_cols = ['sex', 'income', 'race', 'dnr', 'dementia', 'diabetes'] + cat_cols = ["sex", "income", "race", "dnr", "dementia", "diabetes"] x_cat = data[cat_cols] if x_cat.isnull().values.any(): - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat_imputed = cat_imputer.fit_transform(x_cat) else: x_cat_imputed = x_cat.values x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols) - + x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True) cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1] - x = np.hstack([x_cat_encoded.values, x_cont_imputed]) feature_names = np.concatenate([x_cat_encoded.columns, cont_cols]) n_continuous = len(cont_cols) feature_ranges = cat_feature_ranges + cont_feature_ranges - - t = data['d.time'].values + t = data["d.time"].values valid = ~np.isnan(t) - + print("Completed imputation of missing values.") return ( x[valid], - t[valid] + 1, # Add +1 offset to follow-up times + t[valid] + 1, # Add +1 offset to follow-up times event_type[valid], feature_names, n_continuous, - feature_ranges + feature_ranges, ) -def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]]: +def load_synthetic_dataset() -> Tuple[ + np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple] +]: """ Loads a synthetic competing risks dataset from a CSV file. - + The CSV is expected to have a header with the following columns: - time: observed time - label: event indicator (0 for censored; >0 for event types) - true_time: (optional) true time (unused here) - true_label: (optional) true event label (unused here) - feature1, feature2, ..., featureN: feature values - - Returns: + + Returns + ------- X (np.ndarray): Feature matrix of shape (n_samples, n_features). T_obs (np.ndarray): Observed times of shape (n_samples,). e (np.ndarray): Event indicators of shape (n_samples,). @@ -264,22 +304,20 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s n_continuous (int): Total number of continuous features. feature_ranges (List[tuple]): List of (min, max) tuples for each feature. """ - file_path = "datasets/synthetic_comprisk.csv" df = pd.read_csv(file_path) - - T_obs = df["time"].values .astype(np.float32) + T_obs = df["time"].values.astype(np.float32) e = df["label"].values.astype(np.float32) - - + feature_columns = [col for col in df.columns if col.startswith("feature")] - X = df[feature_columns].values .astype(np.float32) - - + X = df[feature_columns].values.astype(np.float32) + feature_names = feature_columns n_continuous = X.shape[1] - - feature_ranges = [(float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)] - - return X, T_obs, e, feature_names, n_continuous, feature_ranges \ No newline at end of file + + feature_ranges = [ + (float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous) + ] + + return X, T_obs, e, feature_names, n_continuous, feature_ranges diff --git a/data_utils/survival_datasets.py b/data_utils/survival_datasets.py index d2eab1b9..41d8c8f9 100644 --- a/data_utils/survival_datasets.py +++ b/data_utils/survival_datasets.py @@ -1,6 +1,7 @@ import torch from torch.utils.data import Dataset + class SurvivalDataset(Dataset): def __init__(self, x, t, e): self.x = torch.tensor(x, dtype=torch.float32) @@ -13,22 +14,28 @@ def __len__(self): def __getitem__(self, idx): return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking + class SurvivalDatasetDeepHit(Dataset): """Dataset class for DeepHit model""" + def __init__(self, x, t, e, num_Category): self.x = torch.tensor(x, dtype=torch.float32) self.t = torch.tensor(t, dtype=torch.float32).view(-1, 1) self.e = torch.tensor(e, dtype=torch.float32).view(-1, 1) - + # Create discretized time if needed - self.t_discrete = torch.floor(self.t * num_Category / torch.max(self.t)).clamp(0, num_Category-1).long() - + self.t_discrete = ( + torch.floor(self.t * num_Category / torch.max(self.t)) + .clamp(0, num_Category - 1) + .long() + ) + # Create masks for loss calculation self.num_Category = num_Category self.num_Event = int(torch.max(self.e).item()) - + def __len__(self): return len(self.x) - + def __getitem__(self, idx): - return self.x[idx], self.t[idx], self.e[idx], self.t_discrete[idx] \ No newline at end of file + return self.x[idx], self.t[idx], self.e[idx], self.t_discrete[idx] diff --git a/datasets.md b/datasets.md new file mode 100644 index 00000000..47f95d3d --- /dev/null +++ b/datasets.md @@ -0,0 +1,40 @@ +## Datasets overview +The repository includes four well-established survival analysis datasets: + +1. **Framingham Heart Study**: Cardiovascular disease prediction with competing events (CVD vs. death) + - Features: Demographics, clinical measurements, lifestyle factors + - Events: Cardiovascular disease, death from other causes + +2. **PBC (Primary Biliary Cirrhosis)**: Liver disease progression study + - Features: Clinical laboratory values, demographic information + - Events: Death, liver transplantation + +3. **SUPPORT**: Study to understand prognoses and preferences for outcomes + - Features: Comprehensive clinical and demographic variables + - Events: Cancer death, non-cancer death + +4. **Synthetic Dataset**: Controlled simulation for method validation + - Features: Simulated clinical variables with known ground truth + - Events: Multiple competing risks with controllable hazard functions + +CSV files for all datasets are available in the repository within `crisp-nam/datasets` folder. + +## Data loading scripts +The repository contains preprocessing scripts within `datasets` folder that handle missing values, feature encoding, and proper train/test splitting to prevent data leakage for each dataset. + +- **`framingham_dataset.py`**: Preprocess and load Framingham dataset. +- **`pbc_dataset.py`**: Preprocess and load PBC dataset. +- **`support_dataset.py`**: Preprocess and load Support2 dataset. +- **`synthetic_dataset.py`**: Preprocess and load synthetically generated dataset. + +## Return format +Each script returns the following values for use within training scripts: +1. `x`: Feature matrix after preprocessing +2. `t`: Array of time to event values +3. `event_type`: Categorical data depicting event types +4. `feature_names`: Array of feature names +5. `n_continuous`: Number of continuous features +6. `feature_ranges`: List of (min, max) ranges for each feature. +``` +> [!NOTE] +> If introducing a new dataset, the above mentioned return format is needed to run the training scripts. \ No newline at end of file diff --git a/datasets/SurvivalDataset.py b/datasets/SurvivalDataset.py index c9fd9567..b6a2f01d 100644 --- a/datasets/SurvivalDataset.py +++ b/datasets/SurvivalDataset.py @@ -1,5 +1,6 @@ -from torch.utils.data import Dataset import torch +from torch.utils.data import Dataset + class SurvivalDataset(Dataset): def __init__(self, x, t, e): @@ -11,4 +12,4 @@ def __len__(self): return len(self.x) def __getitem__(self, idx): - return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking \ No newline at end of file + return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking diff --git a/datasets/framingham.csv b/datasets/framingham.csv index b33fe583..50a4f09d 100644 --- a/datasets/framingham.csv +++ b/datasets/framingham.csv @@ -11625,4 +11625,4 @@ RANDID,SEX,TOTCHOL,AGE,SYSBP,DIABP,CURSMOKE,CIGPDAY,BMI,DIABETES,BPMEDS,HEARTRTE 9998212,1,153,52,143,89,0,0,25.74,0,0,65,72,3,0,0,0,0,1,4538,3,30,123,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,0 9999312,2,196,39,133,86,1,30,20.91,0,0,85,80,3,0,0,0,0,0,0,1,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201 9999312,2,240,46,138,79,1,20,26.39,0,0,90,83,3,0,0,0,0,0,2390,2,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201 -9999312,2,,50,147,96,1,10,24.19,0,0,94,,3,0,0,0,0,1,4201,3,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201 \ No newline at end of file +9999312,2,,50,147,96,1,10,24.19,0,0,94,,3,0,0,0,0,1,4201,3,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201 diff --git a/datasets/framingham_dataset.py b/datasets/framingham_dataset.py index 9520c089..4cce629e 100644 --- a/datasets/framingham_dataset.py +++ b/datasets/framingham_dataset.py @@ -1,15 +1,18 @@ import os -import pandas as pd + import numpy as np +import pandas as pd from sklearn.impute import SimpleImputer + def load_framingham(competing=True, sequential=False): """ Load and preprocess the Framingham dataset for competing risks analysis, with imputation but no scaling. Feature normalization must be done externally after splitting to avoid data leakage. - Returns: + Returns + ------- x (np.ndarray): Feature matrix with one-hot categorical + raw continuous features. t (np.ndarray): Time-to-event (with +1 offset). e (np.ndarray): Event indicator (0=censored, 1=CVD, 2=death). @@ -17,58 +20,64 @@ def load_framingham(competing=True, sequential=False): n_continuous (int): Number of continuous features at the end of x. feature_ranges (None): Placeholder for backward compatibility. """ - file_path = os.path.join(os.path.dirname(__file__), "framingham.csv") data = pd.read_csv(file_path) if not sequential: - data = data.groupby("RANDID").first() - cat_cols = [ - 'SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS', - 'PREVCHD', 'PREVAP', 'PREVMI', 'PREVSTRK', 'PREVHYP', 'educ' + "SEX", + "CURSMOKE", + "DIABETES", + "BPMEDS", + "PREVCHD", + "PREVAP", + "PREVMI", + "PREVSTRK", + "PREVHYP", + "educ", ] - + # 'HDLC', 'LDLC' - removed to replicate nfg experiments. cont_cols = [ - 'TOTCHOL', 'AGE', - 'SYSBP', 'DIABP', 'CIGPDAY', 'BMI', - 'HEARTRTE', 'GLUCOSE' + "TOTCHOL", + "AGE", + "SYSBP", + "DIABP", + "CIGPDAY", + "BMI", + "HEARTRTE", + "GLUCOSE", ] - - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat = pd.DataFrame( cat_imputer.fit_transform(data[cat_cols]), columns=cat_cols, - index=data.index + index=data.index, ) x_cat = pd.get_dummies(x_cat, drop_first=True) - - cont_imputer = SimpleImputer(strategy='mean') + cont_imputer = SimpleImputer(strategy="mean") x_cont = cont_imputer.fit_transform(data[cont_cols]) - x = np.hstack([x_cat.values, x_cont]) feature_names = np.concatenate([x_cat.columns.values, cont_cols]) n_continuous = len(cont_cols) - event = np.zeros(len(data), dtype=int) - - time = (data['TIMEDTH'] - data['TIME']).values + + time = (data["TIMEDTH"] - data["TIME"]).values # Primary CVD event (risk=1) - cvd_mask = data['CVD'] == 1 + cvd_mask = data["CVD"] == 1 event[cvd_mask] = 1 - time_cvd = (data['TIMECVD'] - data['TIME']).values + time_cvd = (data["TIMECVD"] - data["TIME"]).values time[cvd_mask] = time_cvd[cvd_mask] # Competing death event (risk=2), only if CVD did not occur - death_mask = (data['DEATH'] == 1) & ~cvd_mask + death_mask = (data["DEATH"] == 1) & ~cvd_mask event[death_mask] = 2 # Filter out invalid or zero times @@ -81,10 +90,8 @@ def load_framingham(competing=True, sequential=False): assert not np.isnan(x).any(), "NaNs found in feature matrix" return x, t, e, feature_names, n_continuous, None - else: - raise NotImplementedError("Sequential mode not yet implemented.") + raise NotImplementedError("Sequential mode not yet implemented.") + if __name__ == "__main__": x, t, e, feature_names, n_continuous, feature_ranges = load_framingham() - - diff --git a/datasets/metabric/label.csv b/datasets/metabric/label.csv index 1ea535c2..c8147146 100644 --- a/datasets/metabric/label.csv +++ b/datasets/metabric/label.csv @@ -1,1982 +1,1982 @@ -event_time,label -2999 ,0 -1484 ,0 -3053 ,0 -1721 ,0 -1241 ,1 -234 ,1 -2947 ,0 -672 ,1 -2734 ,1 -2083 ,0 -1097 ,1 -1088 ,1 -2314 ,0 -2258 ,0 -2361 ,0 -424 ,1 -1594 ,0 -1784 ,0 -2005 ,0 -2310 ,0 -1900 ,0 -1718 ,0 -1846 ,0 -2482 ,0 -1420 ,0 -1543 ,0 -2673 ,0 -1583 ,0 -855 ,1 -2149 ,0 -2484 ,1 -2063 ,0 -1493 ,1 -2611 ,0 -1844 ,0 -242 ,1 -2713 ,0 -1333 ,0 -1481 ,0 -1866 ,0 -1281 ,1 -1433 ,0 -2554 ,0 -3148 ,0 -1175 ,1 -1384 ,0 -402 ,0 -2002 ,1 -2474 ,0 -247 ,0 -364 ,0 -872 ,1 -1694 ,0 -2569 ,0 -2527 ,0 -2368 ,0 -49 ,0 -1479 ,0 -1918 ,0 -1882 ,0 -1157 ,1 -1870 ,0 -1999 ,1 -2612 ,0 -388 ,1 -2398 ,0 -2271 ,0 -1405 ,0 -2596 ,0 -169 ,0 -372 ,0 -1630 ,1 -4562 ,1 -1841 ,0 -2580 ,0 -1553 ,1 -53 ,0 -1551 ,1 -2392 ,0 -1477 ,1 -1860 ,0 -2567 ,0 -1884 ,0 -1794 ,0 -59 ,0 -453 ,0 -1729 ,0 -2259 ,0 -2384 ,0 -1673 ,0 -1141 ,1 -325 ,0 -1428 ,0 -2282 ,0 -1293 ,1 -2283 ,0 -2215 ,0 -2801 ,1 -163 ,0 -1854 ,0 -130 ,0 -2363 ,0 -2171 ,0 -2363 ,0 -1368 ,1 -2755 ,0 -538 ,1 -1870 ,0 -2215 ,0 -2297 ,0 -1315 ,1 -939 ,1 -272 ,1 -459 ,1 -1877 ,0 -283 ,0 -2107 ,0 -2019 ,0 -2184 ,0 -2018 ,0 -2283 ,0 -2191 ,0 -2330 ,0 -2344 ,0 -2073 ,0 -1763 ,1 -729 ,1 -1919 ,0 -1585 ,0 -3359 ,0 -1995 ,0 -1881 ,0 -1299 ,1 -1927 ,0 -1841 ,0 -2179 ,0 -1195 ,1 -606 ,1 -3128 ,0 -2440 ,1 -1008 ,0 -4398 ,0 -1361 ,0 -2102 ,0 -325 ,1 -1958 ,0 -2653 ,0 -1366 ,0 -1324 ,0 -2173 ,1 -2827 ,1 -4277 ,1 -3861 ,0 -2083 ,0 -2567 ,1 -2194 ,1 -2198 ,0 -4028 ,0 -2302 ,0 -3259 ,0 -2319 ,0 -2130 ,0 -3628 ,0 -2965 ,0 -3948 ,0 -2752 ,1 -1966 ,0 -3821 ,0 -2707 ,1 -3511 ,0 -2476 ,0 -1293 ,1 -1701 ,0 -2720 ,1 -855 ,1 -1441 ,0 -8220 ,0 -3660 ,1 -3874 ,0 -2141 ,0 -175 ,1 -3807 ,0 -3406 ,0 -3819 ,0 -661 ,1 -101 ,0 -1992 ,0 -2842 ,1 -1184 ,1 -2147 ,1 -2767 ,0 -1176 ,1 -1483 ,1 -1917 ,0 -385 ,0 -3037 ,0 -3684 ,0 -1778 ,1 -2526 ,1 -1778 ,0 -3346 ,1 -1905 ,1 -1289 ,1 -3520 ,0 -1885 ,0 -1270 ,1 -1918 ,0 -2125 ,0 -3576 ,0 -3325 ,0 -133 ,1 -3625 ,0 -1828 ,0 -4550 ,1 -1918 ,0 -1221 ,1 -1968 ,0 -1922 ,0 -1516 ,1 -3198 ,0 -2037 ,0 -3598 ,0 -1772 ,0 -3687 ,0 -1048 ,0 -2042 ,1 -797 ,1 -3602 ,0 -1828 ,0 -800 ,1 -2925 ,0 -1848 ,0 -3760 ,1 -2734 ,1 -3317 ,0 -202 ,0 -613 ,1 -6167 ,1 -2338 ,0 -1848 ,0 -1382 ,1 -5860 ,1 -1656 ,1 -1255 ,1 -348 ,1 -4896 ,1 -5359 ,0 -1524 ,0 -179 ,0 -1256 ,0 -452 ,1 -1411 ,1 -2697 ,1 -3385 ,0 -2617 ,1 -1897 ,0 -1930 ,0 -1857 ,1 -1944 ,0 -1854 ,0 -1945 ,0 -1883 ,1 -810 ,1 -43 ,0 -1767 ,0 -260 ,0 -1584 ,0 -255 ,0 -857 ,1 -2010 ,0 -313 ,0 -2163 ,0 -1688 ,1 -1093 ,1 -2050 ,0 -1829 ,0 -335 ,0 -2166 ,1 -2504 ,0 -989 ,1 -625 ,1 -1927 ,0 -870 ,1 -1820 ,0 -1662 ,0 -1296 ,1 -2038 ,0 -674 ,1 -788 ,1 -1840 ,0 -1905 ,0 -1772 ,1 -1652 ,0 -2120 ,0 -2120 ,0 -1890 ,0 -1809 ,0 -1647 ,0 -2081 ,0 -1876 ,0 -188 ,1 -1970 ,0 -511 ,0 -1938 ,0 -1938 ,0 -1014 ,1 -1940 ,0 -295 ,1 -732 ,1 -2166 ,0 -1647 ,0 -1678 ,0 -1910 ,0 -180 ,0 -1655 ,0 -2566 ,0 -1365 ,1 -1959 ,0 -1885 ,1 -1851 ,0 -2178 ,0 -2143 ,0 -2037 ,0 -1903 ,0 -387 ,0 -1664 ,1 -1678 ,0 -1768 ,0 -877 ,1 -2040 ,0 -1502 ,0 -1884 ,0 -1294 ,1 -1405 ,1 -2006 ,0 -2006 ,0 -1068 ,1 -1838 ,0 -1835 ,0 -1859 ,0 -1179 ,1 -858 ,1 -1340 ,0 -1385 ,0 -2021 ,0 -2015 ,0 -2036 ,0 -1967 ,0 -764 ,0 -1121 ,0 -1527 ,0 -1440 ,0 -1561 ,0 -2354 ,0 -1871 ,0 -1092 ,1 -746 ,1 -2226 ,0 -1883 ,0 -1103 ,1 -1651 ,0 -1918 ,0 -2267 ,0 -1955 ,0 -684 ,0 -399 ,0 -1169 ,0 -1099 ,1 -1883 ,1 -475 ,0 -1847 ,0 -2225 ,0 -253 ,0 -1790 ,0 -1871 ,0 -499 ,0 -1930 ,0 -1821 ,0 -1351 ,1 -749 ,0 -1395 ,0 -1876 ,0 -1592 ,0 -1851 ,0 -1876 ,0 -1667 ,0 -402 ,0 -1482 ,0 -1857 ,0 -1637 ,0 -2964 ,0 -299 ,0 -2425 ,0 -2922 ,0 -2556 ,1 -3295 ,1 -2317 ,0 -1456 ,1 -4098 ,0 -3986 ,0 -3746 ,0 -3463 ,0 -3658 ,0 -3512 ,0 -723 ,1 -2565 ,1 -3367 ,0 -2273 ,0 -1458 ,1 -2082 ,0 -2380 ,1 -1235 ,1 -3598 ,0 -3961 ,0 -839 ,0 -394 ,0 -1559 ,1 -2042 ,1 -2572 ,1 -3335 ,1 -2935 ,1 -3602 ,0 -1138 ,1 -2352 ,0 -1785 ,1 -6705 ,0 -2344 ,0 -2382 ,0 -427 ,0 -1482 ,0 -930 ,1 -1514 ,0 -1204 ,0 -1577 ,0 -2218 ,0 -236 ,1 -466 ,1 -56 ,0 -273 ,0 -1525 ,0 -185 ,0 -1995 ,0 -1540 ,0 -1841 ,0 -1766 ,0 -978 ,0 -487 ,0 -1040 ,1 -1327 ,0 -1079 ,0 -1648 ,0 -1466 ,0 -2004 ,0 -1856 ,0 -880 ,1 -1415 ,0 -1704 ,0 -824 ,1 -1577 ,0 -463 ,0 -363 ,0 -1498 ,0 -626 ,0 -1464 ,0 -257 ,0 -875 ,0 -871 ,1 -922 ,0 -1561 ,0 -489 ,0 -501 ,0 -971 ,0 -1092 ,0 -1291 ,0 -1057 ,1 -23 ,0 -1121 ,0 -342 ,0 -1403 ,1 -965 ,0 -536 ,0 -1171 ,0 -358 ,1 -787 ,0 -922 ,0 -1393 ,0 -1112 ,0 -1227 ,0 -1010 ,0 -402 ,1 -1057 ,0 -77 ,1 -1065 ,0 -522 ,0 -867 ,0 -727 ,1 -1283 ,0 -1176 ,0 -564 ,1 -126 ,0 -819 ,0 -8 ,1 -560 ,1 -600 ,1 -760 ,0 -1661 ,0 -318 ,1 -730 ,1 -2083 ,0 -2957 ,1 -2533 ,0 -498 ,1 -2001 ,0 -433 ,1 -37 ,0 -1271 ,1 -2810 ,1 -2098 ,1 -3277 ,0 -3241 ,0 -1293 ,1 -929 ,1 -1655 ,0 -2266 ,0 -2582 ,0 -1916 ,1 -1215 ,0 -1791 ,1 -5222 ,0 -1437 ,1 -5196 ,0 -5490 ,0 -5292 ,0 -4902 ,0 -1948 ,1 -4703 ,0 -2673 ,1 -2673 ,1 -3855 ,0 -5352 ,0 -5056 ,0 -573 ,1 -5202 ,0 -1355 ,1 -1589 ,1 -1343 ,1 -2211 ,1 -5318 ,0 -4911 ,0 -3997 ,1 -4860 ,0 -5224 ,0 -4857 ,0 -4787 ,0 -4861 ,0 -3242 ,1 -4748 ,0 -1542 ,1 -4011 ,0 -5173 ,0 -747 ,1 -5145 ,0 -5069 ,1 -4521 ,0 -2933 ,1 -4362 ,0 -2457 ,1 -5025 ,0 -5213 ,0 -1050 ,1 -3541 ,1 -4408 ,0 -4832 ,0 -5087 ,0 -5221 ,0 -4800 ,0 -5108 ,0 -1345 ,1 -1111 ,1 -3950 ,1 -5173 ,0 -4717 ,0 -1213 ,1 -2048 ,1 -1837 ,0 -5086 ,0 -3323 ,1 -4866 ,0 -5064 ,0 -2631 ,1 -4535 ,0 -4910 ,0 -2576 ,1 -4751 ,0 -5580 ,0 -5059 ,0 -4811 ,0 -3175 ,0 -4899 ,0 -4703 ,0 -4860 ,0 -4489 ,0 -1651 ,1 -4671 ,0 -4430 ,0 -592 ,1 -5132 ,0 -4402 ,0 -4381 ,0 -4491 ,1 -1186 ,1 -4930 ,0 -1296 ,1 -283 ,1 -4770 ,0 -4311 ,0 -4245 ,0 -4936 ,0 -4247 ,1 -3626 ,0 -4944 ,0 -4436 ,0 -4846 ,0 -620 ,1 -4944 ,0 -3768 ,1 -4614 ,0 -4241 ,0 -958 ,1 -4388 ,0 -432 ,1 -423 ,1 -4740 ,0 -4430 ,0 -4850 ,0 -3252 ,0 -4858 ,1 -3038 ,1 -816 ,1 -3359 ,0 -4709 ,0 -4619 ,0 -904 ,1 -159 ,1 -4412 ,0 -4704 ,0 -4803 ,0 -288 ,1 -4857 ,0 -3062 ,0 -4514 ,0 -4628 ,0 -3318 ,1 -4749 ,0 -4733 ,0 -4426 ,0 -4778 ,0 -4506 ,0 -4398 ,0 -4817 ,0 -4388 ,0 -3772 ,0 -4423 ,0 -471 ,1 -3447 ,1 -4118 ,0 -4474 ,0 -4528 ,0 -4482 ,0 -4590 ,0 -4494 ,0 -4273 ,0 -1277 ,1 -1069 ,1 -4153 ,0 -2112 ,1 -4493 ,0 -2964 ,0 -2215 ,1 -4668 ,0 -4647 ,0 -961 ,1 -4422 ,0 -1338 ,1 -4391 ,0 -857 ,1 -4741 ,0 -4481 ,0 -4475 ,0 -2433 ,1 -4721 ,0 -4633 ,0 -4570 ,1 -4356 ,0 -862 ,1 -2261 ,1 -4694 ,0 -3021 ,0 -3940 ,1 -4381 ,0 -4230 ,0 -4468 ,0 -4570 ,0 -4108 ,0 -3901 ,0 -4584 ,0 -4056 ,0 -4125 ,0 -718 ,1 -4136 ,0 -4535 ,0 -2991 ,1 -1675 ,1 -4020 ,0 -1536 ,1 -1964 ,1 -680 ,1 -1246 ,1 -988 ,1 -3915 ,0 -4026 ,0 -3961 ,0 -4135 ,0 -3803 ,0 -4076 ,0 -3985 ,0 -4124 ,0 -1139 ,1 -4305 ,0 -4209 ,0 -4372 ,0 -4390 ,0 -3269 ,0 -1141 ,1 -803 ,1 -2046 ,1 -326 ,1 -4307 ,0 -4079 ,0 -4040 ,0 -2626 ,1 -3415 ,1 -4241 ,0 -4087 ,0 -696 ,1 -1669 ,1 -4080 ,0 -4131 ,0 -4178 ,0 -4243 ,0 -1503 ,1 -3835 ,0 -3781 ,0 -3104 ,0 -4027 ,0 -3964 ,0 -763 ,1 -3263 ,0 -4246 ,0 -3726 ,0 -3596 ,1 -985 ,1 -3716 ,0 -4180 ,0 -4135 ,0 -1164 ,1 -1884 ,0 -3506 ,0 -4037 ,0 -3737 ,0 -824 ,1 -3072 ,1 -3544 ,1 -465 ,1 -3898 ,0 -3721 ,0 -4108 ,0 -1788 ,1 -3460 ,0 -2582 ,1 -3410 ,0 -3389 ,1 -2587 ,1 -4046 ,0 -990 ,1 -3339 ,0 -3880 ,0 -746 ,1 -3775 ,0 -3319 ,0 -3695 ,0 -3296 ,0 -1164 ,1 -3736 ,0 -3289 ,0 -3872 ,0 -2260 ,1 -2782 ,1 -3586 ,0 -3793 ,0 -3794 ,1 -3766 ,0 -5353 ,0 -3324 ,0 -2168 ,0 -3331 ,0 -2976 ,0 -3327 ,0 -3509 ,0 -3005 ,0 -3368 ,0 -3026 ,1 -842 ,1 -1490 ,1 -3197 ,1 -7495 ,0 -4665 ,0 -1342 ,1 -351 ,1 -4978 ,0 -8321 ,1 -339 ,1 -4179 ,1 -3170 ,1 -3723 ,0 -456 ,1 -8941 ,1 -5160 ,1 -8725 ,0 -8441 ,1 -5291 ,0 -2034 ,1 -6513 ,1 -5244 ,1 -7745 ,1 -4962 ,1 -3435 ,1 -4435 ,1 -6529 ,1 -469 ,1 -1891 ,1 -1487 ,1 -5541 ,1 -4810 ,1 -9184 ,0 -9218 ,0 -2535 ,1 -1657 ,1 -1965 ,1 -3468 ,1 -1091 ,1 -7606 ,1 -9193 ,0 -6111 ,1 -3005 ,1 -5352 ,0 -4896 ,1 -1017 ,1 -2961 ,1 -1798 ,1 -3966 ,1 -7326 ,1 -7801 ,1 -1023 ,0 -5245 ,0 -2905 ,1 -8183 ,0 -6575 ,1 -1653 ,1 -5671 ,0 -2671 ,1 -7041 ,1 -7967 ,1 -8805 ,0 -5337 ,0 -514 ,1 -5894 ,1 -3799 ,1 -2617 ,1 -5978 ,1 -4936 ,0 -7067 ,1 -8507 ,0 -6239 ,1 -5208 ,1 -1011 ,1 -5692 ,0 -5958 ,1 -6786 ,1 -932 ,1 -2044 ,1 -3856 ,1 -461 ,1 -3421 ,1 -356 ,1 -4339 ,0 -3584 ,1 -1277 ,1 -8735 ,0 -1844 ,1 -4930 ,1 -4188 ,1 -2501 ,1 -6197 ,1 -7248 ,1 -501 ,1 -8303 ,0 -7943 ,1 -5127 ,1 -4443 ,1 -8555 ,0 -1569 ,1 -3794 ,1 -7954 ,0 -7673 ,0 -6313 ,1 -5920 ,1 -1235 ,1 -4207 ,0 -3581 ,1 -5735 ,0 -1688 ,1 -3684 ,1 -7177 ,0 -3652 ,0 -2704 ,1 -7390 ,0 -4257 ,0 -6852 ,0 -2680 ,0 -4290 ,1 -5599 ,1 -2875 ,1 -5046 ,0 -3288 ,1 -3570 ,1 -6018 ,0 -4457 ,1 -6808 ,0 -1244 ,1 -2700 ,1 -5056 ,1 -505 ,1 -909 ,1 -1444 ,1 -1929 ,1 -2970 ,1 -1196 ,1 -7027 ,0 -7180 ,0 -7110 ,0 -3347 ,1 -4088 ,0 -3584 ,1 -2034 ,1 -1041 ,1 -6628 ,1 -2610 ,1 -2432 ,1 -7099 ,0 -7049 ,0 -7441 ,0 -4186 ,1 -1299 ,1 -1453 ,1 -2223 ,1 -2559 ,1 -3971 ,0 -3218 ,1 -5949 ,0 -639 ,1 -6982 ,0 -3540 ,1 -5120 ,1 -3826 ,0 -3808 ,0 -1920 ,1 -6393 ,1 -1670 ,1 -6636 ,0 -6837 ,1 -4403 ,0 -1730 ,0 -944 ,1 -6860 ,0 -2510 ,1 -1057 ,1 -1287 ,1 -5950 ,1 -2447 ,1 -712 ,1 -2583 ,0 -6359 ,0 -4382 ,0 -2963 ,1 -840 ,1 -3939 ,1 -1346 ,1 -4496 ,0 -592 ,1 -6974 ,0 -3666 ,0 -5282 ,0 -4850 ,1 -1695 ,1 -2494 ,1 -6905 ,0 -4590 ,1 -1280 ,1 -6690 ,0 -1544 ,1 -1132 ,1 -3758 ,1 -822 ,1 -6358 ,0 -3744 ,1 -6828 ,0 -3063 ,0 -6660 ,0 -1414 ,1 -651 ,1 -4562 ,1 -6950 ,0 -6184 ,1 -3047 ,0 -6889 ,0 -4372 ,1 -3650 ,1 -2432 ,1 -2090 ,1 -5215 ,1 -1219 ,1 -2501 ,1 -1365 ,1 -3848 ,0 -2981 ,1 -6567 ,0 -3328 ,1 -3911 ,1 -1317 ,1 -3894 ,1 -6053 ,1 -5958 ,0 -2204 ,1 -2287 ,1 -1719 ,1 -2749 ,1 -2577 ,0 -5698 ,0 -1371 ,1 -1275 ,1 -5288 ,1 -4252 ,1 -1085 ,1 -703 ,1 -6023 ,1 -6990 ,0 -3509 ,0 -764 ,1 -3049 ,1 -1360 ,1 -2708 ,1 -7351 ,0 -1244 ,1 -3880 ,1 -4734 ,1 -6239 ,0 -5619 ,0 -6315 ,0 -6404 ,0 -6261 ,0 -879 ,1 -6118 ,0 -1194 ,1 -4713 ,0 -3530 ,1 -4518 ,1 -674 ,1 -2399 ,1 -6329 ,0 -6050 ,0 -6083 ,0 -6048 ,0 -1445 ,1 -1062 ,1 -6015 ,0 -5301 ,0 -6269 ,0 -1946 ,1 -5923 ,0 -3537 ,1 -5678 ,0 -6185 ,0 -3877 ,1 -5933 ,0 -1437 ,1 -943 ,1 -1500 ,1 -3083 ,1 -6208 ,0 -1200 ,1 -1753 ,0 -911 ,1 -5981 ,0 -2107 ,0 -5826 ,0 -2317 ,1 -2118 ,1 -1759 ,1 -3724 ,1 -5159 ,0 -604 ,1 -5806 ,0 -4518 ,1 -5532 ,0 -2079 ,1 -5930 ,1 -3527 ,1 -6080 ,0 -4743 ,0 -1494 ,0 -1479 ,1 -3921 ,1 -2361 ,0 -5943 ,1 -2664 ,1 -5283 ,1 -5966 ,0 -6077 ,0 -5953 ,0 -5423 ,1 -575 ,1 -2394 ,1 -5603 ,0 -3806 ,0 -4961 ,0 -2257 ,1 -5224 ,1 -2993 ,1 -1967 ,1 -5638 ,0 -2454 ,1 -5782 ,0 -5762 ,0 -2415 ,1 -5637 ,0 -5688 ,0 -5221 ,0 -4348 ,0 -3258 ,1 -1134 ,1 -2545 ,1 -3558 ,1 -5696 ,0 -1582 ,1 -3176 ,0 -1864 ,1 -2375 ,1 -3674 ,0 -3851 ,1 -4183 ,0 -5049 ,1 -2591 ,1 -5920 ,0 -5530 ,0 -1037 ,1 -4518 ,1 -5631 ,0 -1690 ,1 -5362 ,1 -21 ,0 -4503 ,1 -1183 ,0 -5048 ,0 -5715 ,0 -2745 ,0 -5544 ,0 -2841 ,0 -5475 ,0 -5731 ,0 -1507 ,1 -2245 ,0 -5908 ,0 -806 ,1 -2613 ,1 -5736 ,0 -1740 ,1 -5936 ,0 -2527 ,1 -3696 ,0 -5237 ,1 -4273 ,1 -1848 ,0 -5592 ,1 -597 ,1 -1030 ,1 -861 ,1 -3675 ,1 -5865 ,0 -5354 ,0 -2479 ,1 -3752 ,1 -657 ,1 -5269 ,0 -461 ,1 -2322 ,0 -2153 ,1 -5758 ,0 -5615 ,0 -1793 ,0 -834 ,1 -3447 ,1 -3771 ,1 -2709 ,1 -3726 ,1 -819 ,1 -3615 ,0 -5718 ,0 -5910 ,0 -5469 ,0 -1118 ,1 -5362 ,0 -1298 ,0 -1110 ,1 -2699 ,1 -4042 ,1 -3838 ,1 -2502 ,0 -4777 ,1 -608 ,0 -5820 ,0 -3325 ,1 -3478 ,1 -5125 ,0 -5529 ,0 -547 ,1 -5812 ,0 -2632 ,1 -5260 ,0 -4930 ,0 -5553 ,0 -2415 ,0 -2422 ,1 -5637 ,0 -1559 ,1 -5550 ,1 -5528 ,0 -664 ,1 -4536 ,1 -5616 ,0 -3379 ,1 -1957 ,1 -5530 ,0 -5503 ,0 -4028 ,1 -3556 ,1 -2680 ,1 -2324 ,1 -5438 ,0 -582 ,1 -3359 ,1 -1447 ,0 -4265 ,1 -1628 ,1 -5617 ,0 -2821 ,1 -5706 ,0 -5414 ,0 -1569 ,1 -2909 ,1 -5391 ,0 -5559 ,0 -5453 ,0 -4557 ,0 -1476 ,0 -4514 ,0 -5283 ,0 -1302 ,1 -3221 ,1 -3561 ,1 -5598 ,0 -5354 ,0 -3939 ,1 -3355 ,1 -4340 ,1 -2982 ,1 -1298 ,1 -5349 ,0 -2296 ,0 -5422 ,0 -1019 ,1 -5275 ,0 -497 ,1 -5131 ,0 -5203 ,0 -5488 ,0 -5000 ,0 -1324 ,1 -5157 ,0 -3115 ,1 -1153 ,1 -3724 ,0 -3374 ,1 -3062 ,1 -1607 ,0 -3242 ,1 -3032 ,1 -1385 ,1 -5102 ,0 -3494 ,1 -5153 ,0 -2509 ,1 -2132 ,1 -5584 ,0 -5596 ,0 -5025 ,0 -5602 ,0 -1826 ,1 -2279 ,1 -5908 ,0 -2999 ,1 -2063 ,1 -2336 ,1 -2379 ,0 -1960 ,1 -3720 ,1 -3852 ,1 -2610 ,1 -1108 ,1 -4606 ,0 -1453 ,1 -5240 ,0 -3394 ,0 -2268 ,1 -414 ,1 -2547 ,1 -501 ,1 -4970 ,0 -5222 ,0 -5085 ,0 -4483 ,1 -4792 ,0 -3143 ,1 -3524 ,0 -1730 ,1 -1137 ,1 -1396 ,1 -1291 ,1 -1609 ,1 -425 ,1 -3069 ,1 -4732 ,0 -5049 ,0 -5116 ,0 -2886 ,1 -4906 ,0 -426 ,1 -1372 ,1 -3436 ,1 -3724 ,1 -2965 ,1 -2562 ,1 -757 ,1 -1550 ,1 -5147 ,0 -5305 ,0 -5192 ,0 -5092 ,0 -5460 ,0 -5090 ,0 -3784 ,0 -1609 ,1 -5405 ,0 -865 ,1 -4834 ,1 -461 ,1 -3773 ,1 -691 ,1 -5059 ,0 -3004 ,1 -1179 ,1 -3579 ,1 -76 ,0 -3060 ,1 -368 ,1 -5137 ,0 -3487 ,1 -3410 ,1 -2135 ,1 -4464 ,1 -5082 ,0 -1369 ,1 -4408 ,1 -3743 ,1 -928 ,1 -5066 ,0 -4620 ,1 -5040 ,0 -4955 ,0 -4970 ,0 -5114 ,0 -5386 ,0 -3204 ,1 -2105 ,0 -3233 ,1 -2872 ,1 -2242 ,1 -3528 ,0 -2923 ,1 -3011 ,0 -3042 ,0 -5262 ,0 -2787 ,0 -780 ,1 -3850 ,0 -4735 ,1 -4218 ,1 -3212 ,0 -2260 ,0 -3018 ,0 -5245 ,0 -1382 ,1 -2506 ,1 -1484 ,1 -1039 ,1 -4672 ,1 -4514 ,0 -921 ,1 -476 ,1 -2955 ,1 -2421 ,0 -1355 ,1 -3511 ,0 -291 ,1 -5194 ,0 -4588 ,0 -4885 ,1 -4466 ,1 -3708 ,0 -4923 ,0 -485 ,1 -1719 ,1 -1520 ,1 -1854 ,1 -4828 ,1 -4891 ,0 -2724 ,1 -4490 ,0 -1744 ,1 -3496 ,1 -564 ,1 -4895 ,0 -2981 ,0 -1262 ,1 -4869 ,0 -587 ,1 -3777 ,1 -5089 ,0 -76 ,1 -3309 ,0 -1029 ,1 -1391 ,1 -530 ,1 -1273 ,1 -4973 ,0 -954 ,1 -3699 ,1 -5296 ,0 -1536 ,1 -1975 ,1 -701 ,1 -1717 ,1 -4662 ,0 -5187 ,0 -3001 ,0 -3712 ,0 -1799 ,0 -3583 ,0 -3093 ,1 -4437 ,0 -1099 ,1 -4771 ,0 -926 ,1 -4816 ,0 -489 ,1 -3744 ,0 -60 ,0 -4532 ,0 -4312 ,0 -444 ,1 -4035 ,0 -4820 ,0 -2396 ,1 -4741 ,0 -2373 ,1 -3604 ,1 -1453 ,1 -4989 ,0 -1033 ,1 -5036 ,0 -2957 ,0 -760 ,0 -2668 ,1 -771 ,0 -3542 ,0 -1041 ,1 -4635 ,1 -4773 ,0 -4405 ,0 -751 ,0 -565 ,1 -4802 ,0 -3089 ,1 -3530 ,1 -3461 ,0 -4737 ,0 -232 ,0 -4616 ,1 -4697 ,0 -3530 ,1 -4767 ,0 -1662 ,1 -2707 ,1 -2149 ,1 -4056 ,0 -3341 ,0 -2733 ,1 -1617 ,1 -2434 ,1 -4807 ,0 -535 ,0 -4648 ,1 -790 ,1 -2775 ,1 -4912 ,0 -982 ,1 -4521 ,0 -4215 ,0 -4669 ,0 -3663 ,0 -4369 ,0 -3711 ,0 -4267 ,0 -3493 ,1 -4778 ,0 -3846 ,1 -1262 ,1 -1911 ,0 -4134 ,1 -4080 ,0 -969 ,1 -2489 ,0 -1278 ,1 -3114 ,1 -1145 ,1 -1051 ,1 -1887 ,1 -632 ,1 -3249 ,1 -4551 ,0 -4006 ,0 -3625 ,0 -4213 ,0 -667 ,1 -4522 ,0 -924 ,1 -4152 ,0 -630 ,1 -3140 ,1 -4263 ,0 -1915 ,0 -4524 ,0 -3792 ,1 -4059 ,1 -1907 ,1 -2116 ,1 -2073 ,1 -4441 ,0 -6334 ,1 -4333 ,1 -5758 ,1 -4306 ,1 -2127 ,1 -4295 ,1 -406 ,1 -1448 ,1 -4690 ,0 -1562 ,1 -6172 ,0 -6749 ,0 -1492 ,1 -6933 ,0 -4361 ,1 -5683 ,0 -2194 ,1 -5861 ,0 -5214 ,0 -1136 ,1 -5046 ,1 -6201 ,0 -5023 ,0 -5605 ,1 -6032 ,0 -6060 ,0 -642 ,0 -680 ,1 -860 ,0 -1141 ,1 -6112 ,0 -468 ,1 -5656 ,1 -5478 ,1 -2150 ,0 -1722 ,1 -5470 ,1 -5108 ,1 -5294 ,0 -1623 ,1 -3075 ,1 -4921 ,1 -730 ,0 -69 ,1 -4717 ,0 -319 ,1 -302 ,1 -2325 ,1 -1410 ,0 -3391 ,0 -3700 ,1 -2990 ,0 -3737 ,1 -2005 ,0 -3610 ,0 -5215 ,0 -1603 ,0 -3469 ,1 -4617 ,1 -1301 ,0 -1854 ,1 -2358 ,1 -769 ,1 -2366 ,1 -4046 ,0 -3133 ,0 -1030 ,1 -2688 ,1 -1669 ,0 -2248 ,1 -2330 ,1 -2310 ,0 -3653 ,0 -1758 ,1 -476 ,1 -3752 ,0 -2660 ,1 -135 ,0 -3684 ,1 -4361 ,0 -2542 ,1 -2768 ,0 -994 ,1 -1754 ,1 -1648 ,1 -3595 ,0 -2374 ,1 -4794 ,0 -1992 ,0 -1133 ,1 -713 ,0 -2024 ,0 -739 ,1 -1279 ,1 -1574 ,1 -3335 ,0 -902 ,1 -3528 ,0 -1703 ,1 -4614 ,0 -3549 ,1 -775 ,0 -1269 ,1 -2088 ,0 -3409 ,0 -737 ,0 -4488 ,0 -4434 ,1 -3836 ,0 -1064 ,1 -3536 ,0 -1061 ,1 -3779 ,0 -3 ,1 -2244 ,1 -2177 ,1 -2550 ,1 -1704 ,0 -3344 ,0 -3156 ,1 -1608 ,0 -3622 ,0 -2637 ,0 -3723 ,0 -1586 ,1 -2254 ,0 -441 ,1 -2933 ,0 -1827 ,1 -2037 ,0 -2730 ,0 -2276 ,0 -3352 ,0 -672 ,1 -3150 ,0 -836 ,1 -105 ,1 -2772 ,1 -2650 ,1 -1486 ,1 -1056 ,1 -3009 ,1 -2204 ,0 -1915 ,1 -1528 ,0 -2658 ,0 -1761 ,0 -185 ,0 -1996 ,0 -2570 ,0 -877 ,1 -427 ,0 -2401 ,0 -2411 ,1 -548 ,1 -2022 ,0 -1869 ,0 -1889 ,0 -1737 ,0 -516 ,1 -1582 ,0 -4588 ,1 -6160 ,0 -1136 ,1 -2551 ,1 -501 ,1 -1816 ,0 -1940 ,0 -1738 ,0 -1625 ,0 -1514 ,0 -1360 ,0 -1778 ,0 -1758 ,0 -1295 ,1 -1643 ,1 -2221 ,0 -1785 ,0 -595 ,1 -1393 ,0 -2061 ,0 -2345 ,0 -1805 ,0 -2135 ,1 -2394 ,0 -2143 ,1 -2565 ,0 -803 ,1 -2108 ,0 -2407 ,0 -2008 ,0 -1103 ,1 -2694 ,0 -2306 ,0 -940 ,1 -2561 ,0 -2144 ,0 -2253 ,0 -962 ,1 -2132 ,0 -2592 ,0 -2431 ,0 -1476 ,1 -2155 ,0 -2843 ,1 -2655 ,0 -2746 ,0 -2605 ,0 -2847 ,1 -2404 ,0 -2654 ,0 -3061 ,0 -1954 ,1 -2578 ,1 -3184 ,0 -3664 ,0 -913 ,1 -274 ,1 -3510 ,0 -1315 ,1 -3675 ,0 -3580 ,0 -3056 ,0 -2816 ,0 -2876 ,0 -2598 ,0 -3236 ,0 -2693 ,0 -3150 ,0 -2954 ,0 -3438 ,0 -3242 ,0 -3062 ,0 -1570 ,1 -3477 ,0 -3245 ,0 -3349 ,0 -2951 ,0 -3038 ,0 -125 ,1 -2732 ,0 -2737 ,0 -3066 ,0 -3164 ,0 -3496 ,0 -3303 ,0 -3706 ,0 -152 ,1 -3706 ,0 -3567 ,0 -3293 ,0 -3308 ,0 -3672 ,0 -3064 ,0 -3401 ,0 -3469 ,0 -3720 ,0 -1955 ,1 -3597 ,0 -3538 ,0 -3728 ,0 -3123 ,1 -3467 ,0 -1125 ,1 -1708 ,1 -3526 ,0 -905 ,1 -1941 ,1 -3297 ,0 -3698 ,0 -3559 ,0 -1001 ,1 -4217 ,0 -3851 ,0 -2665 ,1 -897 ,1 -4130 ,0 -4188 ,0 -4060 ,0 -1808 ,1 -3089 ,1 -4042 ,0 -3948 ,0 -715 ,1 -802 ,1 -4120 ,0 -3520 ,1 -2724 ,1 -4157 ,0 -4377 ,0 -4134 ,0 -3079 ,1 -4150 ,0 -3884 ,0 -570 ,1 -1339 ,1 -3869 ,0 -4333 ,0 -4136 ,0 -4284 ,0 -3254 ,1 -4493 ,0 -4673 ,0 -1393 ,1 -658 ,1 -4442 ,0 -3646 ,0 -4206 ,0 -1579 ,1 -4374 ,1 -4415 ,0 -4231 ,0 -4515 ,0 -4662 ,0 -4414 ,1 -4479 ,0 -1906 ,1 -4223 ,0 -4554 ,0 -4726 ,0 -1484 ,1 -3751 ,1 -146 ,0 -4403 ,1 -4357 ,0 -4308 ,1 -2624 ,1 -1400 ,1 -2265 ,1 -3213 ,1 -4639 ,0 -502 ,1 -2280 ,1 -2928 ,1 -4903 ,0 -4981 ,0 -1672 ,1 -1378 ,1 -1575 ,1 -4960 ,0 -4608 ,0 -2712 ,1 -4963 ,0 -4956 ,0 -4548 ,0 -1650 ,1 -1601 ,1 -1896 ,1 -5212 ,0 -3348 ,1 -5478 ,0 -4883 ,0 -5328 ,0 -4736 ,0 -5190 ,0 -5124 ,0 -4804 ,0 -718 ,1 -698 ,1 -5133 ,0 -4464 ,1 -5525 ,0 -1760 ,1 -1530 ,1 -647 ,1 -1071 ,1 -1799 ,1 -5766 ,0 -1041 ,1 -452 ,1 -935 ,0 -3121 ,1 -4949 ,1 -1289 ,1 -2045 ,1 -5831 ,0 -5704 ,0 -635 ,1 -5255 ,0 -6246 ,0 -1332 ,1 -5560 ,0 -5611 ,0 -5866 ,0 -639 ,1 -5866 ,0 -1486 ,1 -866 ,1 -5998 ,0 -3270 ,1 -4759 ,1 -2909 ,1 -812 ,1 -3800 ,1 -205 ,1 -2354 ,1 -5977 ,0 -2482 ,1 -5906 ,0 -1342 ,1 -5279 ,1 -2587 ,1 -5893 ,0 +event_time,label +2999 ,0 +1484 ,0 +3053 ,0 +1721 ,0 +1241 ,1 +234 ,1 +2947 ,0 +672 ,1 +2734 ,1 +2083 ,0 +1097 ,1 +1088 ,1 +2314 ,0 +2258 ,0 +2361 ,0 +424 ,1 +1594 ,0 +1784 ,0 +2005 ,0 +2310 ,0 +1900 ,0 +1718 ,0 +1846 ,0 +2482 ,0 +1420 ,0 +1543 ,0 +2673 ,0 +1583 ,0 +855 ,1 +2149 ,0 +2484 ,1 +2063 ,0 +1493 ,1 +2611 ,0 +1844 ,0 +242 ,1 +2713 ,0 +1333 ,0 +1481 ,0 +1866 ,0 +1281 ,1 +1433 ,0 +2554 ,0 +3148 ,0 +1175 ,1 +1384 ,0 +402 ,0 +2002 ,1 +2474 ,0 +247 ,0 +364 ,0 +872 ,1 +1694 ,0 +2569 ,0 +2527 ,0 +2368 ,0 +49 ,0 +1479 ,0 +1918 ,0 +1882 ,0 +1157 ,1 +1870 ,0 +1999 ,1 +2612 ,0 +388 ,1 +2398 ,0 +2271 ,0 +1405 ,0 +2596 ,0 +169 ,0 +372 ,0 +1630 ,1 +4562 ,1 +1841 ,0 +2580 ,0 +1553 ,1 +53 ,0 +1551 ,1 +2392 ,0 +1477 ,1 +1860 ,0 +2567 ,0 +1884 ,0 +1794 ,0 +59 ,0 +453 ,0 +1729 ,0 +2259 ,0 +2384 ,0 +1673 ,0 +1141 ,1 +325 ,0 +1428 ,0 +2282 ,0 +1293 ,1 +2283 ,0 +2215 ,0 +2801 ,1 +163 ,0 +1854 ,0 +130 ,0 +2363 ,0 +2171 ,0 +2363 ,0 +1368 ,1 +2755 ,0 +538 ,1 +1870 ,0 +2215 ,0 +2297 ,0 +1315 ,1 +939 ,1 +272 ,1 +459 ,1 +1877 ,0 +283 ,0 +2107 ,0 +2019 ,0 +2184 ,0 +2018 ,0 +2283 ,0 +2191 ,0 +2330 ,0 +2344 ,0 +2073 ,0 +1763 ,1 +729 ,1 +1919 ,0 +1585 ,0 +3359 ,0 +1995 ,0 +1881 ,0 +1299 ,1 +1927 ,0 +1841 ,0 +2179 ,0 +1195 ,1 +606 ,1 +3128 ,0 +2440 ,1 +1008 ,0 +4398 ,0 +1361 ,0 +2102 ,0 +325 ,1 +1958 ,0 +2653 ,0 +1366 ,0 +1324 ,0 +2173 ,1 +2827 ,1 +4277 ,1 +3861 ,0 +2083 ,0 +2567 ,1 +2194 ,1 +2198 ,0 +4028 ,0 +2302 ,0 +3259 ,0 +2319 ,0 +2130 ,0 +3628 ,0 +2965 ,0 +3948 ,0 +2752 ,1 +1966 ,0 +3821 ,0 +2707 ,1 +3511 ,0 +2476 ,0 +1293 ,1 +1701 ,0 +2720 ,1 +855 ,1 +1441 ,0 +8220 ,0 +3660 ,1 +3874 ,0 +2141 ,0 +175 ,1 +3807 ,0 +3406 ,0 +3819 ,0 +661 ,1 +101 ,0 +1992 ,0 +2842 ,1 +1184 ,1 +2147 ,1 +2767 ,0 +1176 ,1 +1483 ,1 +1917 ,0 +385 ,0 +3037 ,0 +3684 ,0 +1778 ,1 +2526 ,1 +1778 ,0 +3346 ,1 +1905 ,1 +1289 ,1 +3520 ,0 +1885 ,0 +1270 ,1 +1918 ,0 +2125 ,0 +3576 ,0 +3325 ,0 +133 ,1 +3625 ,0 +1828 ,0 +4550 ,1 +1918 ,0 +1221 ,1 +1968 ,0 +1922 ,0 +1516 ,1 +3198 ,0 +2037 ,0 +3598 ,0 +1772 ,0 +3687 ,0 +1048 ,0 +2042 ,1 +797 ,1 +3602 ,0 +1828 ,0 +800 ,1 +2925 ,0 +1848 ,0 +3760 ,1 +2734 ,1 +3317 ,0 +202 ,0 +613 ,1 +6167 ,1 +2338 ,0 +1848 ,0 +1382 ,1 +5860 ,1 +1656 ,1 +1255 ,1 +348 ,1 +4896 ,1 +5359 ,0 +1524 ,0 +179 ,0 +1256 ,0 +452 ,1 +1411 ,1 +2697 ,1 +3385 ,0 +2617 ,1 +1897 ,0 +1930 ,0 +1857 ,1 +1944 ,0 +1854 ,0 +1945 ,0 +1883 ,1 +810 ,1 +43 ,0 +1767 ,0 +260 ,0 +1584 ,0 +255 ,0 +857 ,1 +2010 ,0 +313 ,0 +2163 ,0 +1688 ,1 +1093 ,1 +2050 ,0 +1829 ,0 +335 ,0 +2166 ,1 +2504 ,0 +989 ,1 +625 ,1 +1927 ,0 +870 ,1 +1820 ,0 +1662 ,0 +1296 ,1 +2038 ,0 +674 ,1 +788 ,1 +1840 ,0 +1905 ,0 +1772 ,1 +1652 ,0 +2120 ,0 +2120 ,0 +1890 ,0 +1809 ,0 +1647 ,0 +2081 ,0 +1876 ,0 +188 ,1 +1970 ,0 +511 ,0 +1938 ,0 +1938 ,0 +1014 ,1 +1940 ,0 +295 ,1 +732 ,1 +2166 ,0 +1647 ,0 +1678 ,0 +1910 ,0 +180 ,0 +1655 ,0 +2566 ,0 +1365 ,1 +1959 ,0 +1885 ,1 +1851 ,0 +2178 ,0 +2143 ,0 +2037 ,0 +1903 ,0 +387 ,0 +1664 ,1 +1678 ,0 +1768 ,0 +877 ,1 +2040 ,0 +1502 ,0 +1884 ,0 +1294 ,1 +1405 ,1 +2006 ,0 +2006 ,0 +1068 ,1 +1838 ,0 +1835 ,0 +1859 ,0 +1179 ,1 +858 ,1 +1340 ,0 +1385 ,0 +2021 ,0 +2015 ,0 +2036 ,0 +1967 ,0 +764 ,0 +1121 ,0 +1527 ,0 +1440 ,0 +1561 ,0 +2354 ,0 +1871 ,0 +1092 ,1 +746 ,1 +2226 ,0 +1883 ,0 +1103 ,1 +1651 ,0 +1918 ,0 +2267 ,0 +1955 ,0 +684 ,0 +399 ,0 +1169 ,0 +1099 ,1 +1883 ,1 +475 ,0 +1847 ,0 +2225 ,0 +253 ,0 +1790 ,0 +1871 ,0 +499 ,0 +1930 ,0 +1821 ,0 +1351 ,1 +749 ,0 +1395 ,0 +1876 ,0 +1592 ,0 +1851 ,0 +1876 ,0 +1667 ,0 +402 ,0 +1482 ,0 +1857 ,0 +1637 ,0 +2964 ,0 +299 ,0 +2425 ,0 +2922 ,0 +2556 ,1 +3295 ,1 +2317 ,0 +1456 ,1 +4098 ,0 +3986 ,0 +3746 ,0 +3463 ,0 +3658 ,0 +3512 ,0 +723 ,1 +2565 ,1 +3367 ,0 +2273 ,0 +1458 ,1 +2082 ,0 +2380 ,1 +1235 ,1 +3598 ,0 +3961 ,0 +839 ,0 +394 ,0 +1559 ,1 +2042 ,1 +2572 ,1 +3335 ,1 +2935 ,1 +3602 ,0 +1138 ,1 +2352 ,0 +1785 ,1 +6705 ,0 +2344 ,0 +2382 ,0 +427 ,0 +1482 ,0 +930 ,1 +1514 ,0 +1204 ,0 +1577 ,0 +2218 ,0 +236 ,1 +466 ,1 +56 ,0 +273 ,0 +1525 ,0 +185 ,0 +1995 ,0 +1540 ,0 +1841 ,0 +1766 ,0 +978 ,0 +487 ,0 +1040 ,1 +1327 ,0 +1079 ,0 +1648 ,0 +1466 ,0 +2004 ,0 +1856 ,0 +880 ,1 +1415 ,0 +1704 ,0 +824 ,1 +1577 ,0 +463 ,0 +363 ,0 +1498 ,0 +626 ,0 +1464 ,0 +257 ,0 +875 ,0 +871 ,1 +922 ,0 +1561 ,0 +489 ,0 +501 ,0 +971 ,0 +1092 ,0 +1291 ,0 +1057 ,1 +23 ,0 +1121 ,0 +342 ,0 +1403 ,1 +965 ,0 +536 ,0 +1171 ,0 +358 ,1 +787 ,0 +922 ,0 +1393 ,0 +1112 ,0 +1227 ,0 +1010 ,0 +402 ,1 +1057 ,0 +77 ,1 +1065 ,0 +522 ,0 +867 ,0 +727 ,1 +1283 ,0 +1176 ,0 +564 ,1 +126 ,0 +819 ,0 +8 ,1 +560 ,1 +600 ,1 +760 ,0 +1661 ,0 +318 ,1 +730 ,1 +2083 ,0 +2957 ,1 +2533 ,0 +498 ,1 +2001 ,0 +433 ,1 +37 ,0 +1271 ,1 +2810 ,1 +2098 ,1 +3277 ,0 +3241 ,0 +1293 ,1 +929 ,1 +1655 ,0 +2266 ,0 +2582 ,0 +1916 ,1 +1215 ,0 +1791 ,1 +5222 ,0 +1437 ,1 +5196 ,0 +5490 ,0 +5292 ,0 +4902 ,0 +1948 ,1 +4703 ,0 +2673 ,1 +2673 ,1 +3855 ,0 +5352 ,0 +5056 ,0 +573 ,1 +5202 ,0 +1355 ,1 +1589 ,1 +1343 ,1 +2211 ,1 +5318 ,0 +4911 ,0 +3997 ,1 +4860 ,0 +5224 ,0 +4857 ,0 +4787 ,0 +4861 ,0 +3242 ,1 +4748 ,0 +1542 ,1 +4011 ,0 +5173 ,0 +747 ,1 +5145 ,0 +5069 ,1 +4521 ,0 +2933 ,1 +4362 ,0 +2457 ,1 +5025 ,0 +5213 ,0 +1050 ,1 +3541 ,1 +4408 ,0 +4832 ,0 +5087 ,0 +5221 ,0 +4800 ,0 +5108 ,0 +1345 ,1 +1111 ,1 +3950 ,1 +5173 ,0 +4717 ,0 +1213 ,1 +2048 ,1 +1837 ,0 +5086 ,0 +3323 ,1 +4866 ,0 +5064 ,0 +2631 ,1 +4535 ,0 +4910 ,0 +2576 ,1 +4751 ,0 +5580 ,0 +5059 ,0 +4811 ,0 +3175 ,0 +4899 ,0 +4703 ,0 +4860 ,0 +4489 ,0 +1651 ,1 +4671 ,0 +4430 ,0 +592 ,1 +5132 ,0 +4402 ,0 +4381 ,0 +4491 ,1 +1186 ,1 +4930 ,0 +1296 ,1 +283 ,1 +4770 ,0 +4311 ,0 +4245 ,0 +4936 ,0 +4247 ,1 +3626 ,0 +4944 ,0 +4436 ,0 +4846 ,0 +620 ,1 +4944 ,0 +3768 ,1 +4614 ,0 +4241 ,0 +958 ,1 +4388 ,0 +432 ,1 +423 ,1 +4740 ,0 +4430 ,0 +4850 ,0 +3252 ,0 +4858 ,1 +3038 ,1 +816 ,1 +3359 ,0 +4709 ,0 +4619 ,0 +904 ,1 +159 ,1 +4412 ,0 +4704 ,0 +4803 ,0 +288 ,1 +4857 ,0 +3062 ,0 +4514 ,0 +4628 ,0 +3318 ,1 +4749 ,0 +4733 ,0 +4426 ,0 +4778 ,0 +4506 ,0 +4398 ,0 +4817 ,0 +4388 ,0 +3772 ,0 +4423 ,0 +471 ,1 +3447 ,1 +4118 ,0 +4474 ,0 +4528 ,0 +4482 ,0 +4590 ,0 +4494 ,0 +4273 ,0 +1277 ,1 +1069 ,1 +4153 ,0 +2112 ,1 +4493 ,0 +2964 ,0 +2215 ,1 +4668 ,0 +4647 ,0 +961 ,1 +4422 ,0 +1338 ,1 +4391 ,0 +857 ,1 +4741 ,0 +4481 ,0 +4475 ,0 +2433 ,1 +4721 ,0 +4633 ,0 +4570 ,1 +4356 ,0 +862 ,1 +2261 ,1 +4694 ,0 +3021 ,0 +3940 ,1 +4381 ,0 +4230 ,0 +4468 ,0 +4570 ,0 +4108 ,0 +3901 ,0 +4584 ,0 +4056 ,0 +4125 ,0 +718 ,1 +4136 ,0 +4535 ,0 +2991 ,1 +1675 ,1 +4020 ,0 +1536 ,1 +1964 ,1 +680 ,1 +1246 ,1 +988 ,1 +3915 ,0 +4026 ,0 +3961 ,0 +4135 ,0 +3803 ,0 +4076 ,0 +3985 ,0 +4124 ,0 +1139 ,1 +4305 ,0 +4209 ,0 +4372 ,0 +4390 ,0 +3269 ,0 +1141 ,1 +803 ,1 +2046 ,1 +326 ,1 +4307 ,0 +4079 ,0 +4040 ,0 +2626 ,1 +3415 ,1 +4241 ,0 +4087 ,0 +696 ,1 +1669 ,1 +4080 ,0 +4131 ,0 +4178 ,0 +4243 ,0 +1503 ,1 +3835 ,0 +3781 ,0 +3104 ,0 +4027 ,0 +3964 ,0 +763 ,1 +3263 ,0 +4246 ,0 +3726 ,0 +3596 ,1 +985 ,1 +3716 ,0 +4180 ,0 +4135 ,0 +1164 ,1 +1884 ,0 +3506 ,0 +4037 ,0 +3737 ,0 +824 ,1 +3072 ,1 +3544 ,1 +465 ,1 +3898 ,0 +3721 ,0 +4108 ,0 +1788 ,1 +3460 ,0 +2582 ,1 +3410 ,0 +3389 ,1 +2587 ,1 +4046 ,0 +990 ,1 +3339 ,0 +3880 ,0 +746 ,1 +3775 ,0 +3319 ,0 +3695 ,0 +3296 ,0 +1164 ,1 +3736 ,0 +3289 ,0 +3872 ,0 +2260 ,1 +2782 ,1 +3586 ,0 +3793 ,0 +3794 ,1 +3766 ,0 +5353 ,0 +3324 ,0 +2168 ,0 +3331 ,0 +2976 ,0 +3327 ,0 +3509 ,0 +3005 ,0 +3368 ,0 +3026 ,1 +842 ,1 +1490 ,1 +3197 ,1 +7495 ,0 +4665 ,0 +1342 ,1 +351 ,1 +4978 ,0 +8321 ,1 +339 ,1 +4179 ,1 +3170 ,1 +3723 ,0 +456 ,1 +8941 ,1 +5160 ,1 +8725 ,0 +8441 ,1 +5291 ,0 +2034 ,1 +6513 ,1 +5244 ,1 +7745 ,1 +4962 ,1 +3435 ,1 +4435 ,1 +6529 ,1 +469 ,1 +1891 ,1 +1487 ,1 +5541 ,1 +4810 ,1 +9184 ,0 +9218 ,0 +2535 ,1 +1657 ,1 +1965 ,1 +3468 ,1 +1091 ,1 +7606 ,1 +9193 ,0 +6111 ,1 +3005 ,1 +5352 ,0 +4896 ,1 +1017 ,1 +2961 ,1 +1798 ,1 +3966 ,1 +7326 ,1 +7801 ,1 +1023 ,0 +5245 ,0 +2905 ,1 +8183 ,0 +6575 ,1 +1653 ,1 +5671 ,0 +2671 ,1 +7041 ,1 +7967 ,1 +8805 ,0 +5337 ,0 +514 ,1 +5894 ,1 +3799 ,1 +2617 ,1 +5978 ,1 +4936 ,0 +7067 ,1 +8507 ,0 +6239 ,1 +5208 ,1 +1011 ,1 +5692 ,0 +5958 ,1 +6786 ,1 +932 ,1 +2044 ,1 +3856 ,1 +461 ,1 +3421 ,1 +356 ,1 +4339 ,0 +3584 ,1 +1277 ,1 +8735 ,0 +1844 ,1 +4930 ,1 +4188 ,1 +2501 ,1 +6197 ,1 +7248 ,1 +501 ,1 +8303 ,0 +7943 ,1 +5127 ,1 +4443 ,1 +8555 ,0 +1569 ,1 +3794 ,1 +7954 ,0 +7673 ,0 +6313 ,1 +5920 ,1 +1235 ,1 +4207 ,0 +3581 ,1 +5735 ,0 +1688 ,1 +3684 ,1 +7177 ,0 +3652 ,0 +2704 ,1 +7390 ,0 +4257 ,0 +6852 ,0 +2680 ,0 +4290 ,1 +5599 ,1 +2875 ,1 +5046 ,0 +3288 ,1 +3570 ,1 +6018 ,0 +4457 ,1 +6808 ,0 +1244 ,1 +2700 ,1 +5056 ,1 +505 ,1 +909 ,1 +1444 ,1 +1929 ,1 +2970 ,1 +1196 ,1 +7027 ,0 +7180 ,0 +7110 ,0 +3347 ,1 +4088 ,0 +3584 ,1 +2034 ,1 +1041 ,1 +6628 ,1 +2610 ,1 +2432 ,1 +7099 ,0 +7049 ,0 +7441 ,0 +4186 ,1 +1299 ,1 +1453 ,1 +2223 ,1 +2559 ,1 +3971 ,0 +3218 ,1 +5949 ,0 +639 ,1 +6982 ,0 +3540 ,1 +5120 ,1 +3826 ,0 +3808 ,0 +1920 ,1 +6393 ,1 +1670 ,1 +6636 ,0 +6837 ,1 +4403 ,0 +1730 ,0 +944 ,1 +6860 ,0 +2510 ,1 +1057 ,1 +1287 ,1 +5950 ,1 +2447 ,1 +712 ,1 +2583 ,0 +6359 ,0 +4382 ,0 +2963 ,1 +840 ,1 +3939 ,1 +1346 ,1 +4496 ,0 +592 ,1 +6974 ,0 +3666 ,0 +5282 ,0 +4850 ,1 +1695 ,1 +2494 ,1 +6905 ,0 +4590 ,1 +1280 ,1 +6690 ,0 +1544 ,1 +1132 ,1 +3758 ,1 +822 ,1 +6358 ,0 +3744 ,1 +6828 ,0 +3063 ,0 +6660 ,0 +1414 ,1 +651 ,1 +4562 ,1 +6950 ,0 +6184 ,1 +3047 ,0 +6889 ,0 +4372 ,1 +3650 ,1 +2432 ,1 +2090 ,1 +5215 ,1 +1219 ,1 +2501 ,1 +1365 ,1 +3848 ,0 +2981 ,1 +6567 ,0 +3328 ,1 +3911 ,1 +1317 ,1 +3894 ,1 +6053 ,1 +5958 ,0 +2204 ,1 +2287 ,1 +1719 ,1 +2749 ,1 +2577 ,0 +5698 ,0 +1371 ,1 +1275 ,1 +5288 ,1 +4252 ,1 +1085 ,1 +703 ,1 +6023 ,1 +6990 ,0 +3509 ,0 +764 ,1 +3049 ,1 +1360 ,1 +2708 ,1 +7351 ,0 +1244 ,1 +3880 ,1 +4734 ,1 +6239 ,0 +5619 ,0 +6315 ,0 +6404 ,0 +6261 ,0 +879 ,1 +6118 ,0 +1194 ,1 +4713 ,0 +3530 ,1 +4518 ,1 +674 ,1 +2399 ,1 +6329 ,0 +6050 ,0 +6083 ,0 +6048 ,0 +1445 ,1 +1062 ,1 +6015 ,0 +5301 ,0 +6269 ,0 +1946 ,1 +5923 ,0 +3537 ,1 +5678 ,0 +6185 ,0 +3877 ,1 +5933 ,0 +1437 ,1 +943 ,1 +1500 ,1 +3083 ,1 +6208 ,0 +1200 ,1 +1753 ,0 +911 ,1 +5981 ,0 +2107 ,0 +5826 ,0 +2317 ,1 +2118 ,1 +1759 ,1 +3724 ,1 +5159 ,0 +604 ,1 +5806 ,0 +4518 ,1 +5532 ,0 +2079 ,1 +5930 ,1 +3527 ,1 +6080 ,0 +4743 ,0 +1494 ,0 +1479 ,1 +3921 ,1 +2361 ,0 +5943 ,1 +2664 ,1 +5283 ,1 +5966 ,0 +6077 ,0 +5953 ,0 +5423 ,1 +575 ,1 +2394 ,1 +5603 ,0 +3806 ,0 +4961 ,0 +2257 ,1 +5224 ,1 +2993 ,1 +1967 ,1 +5638 ,0 +2454 ,1 +5782 ,0 +5762 ,0 +2415 ,1 +5637 ,0 +5688 ,0 +5221 ,0 +4348 ,0 +3258 ,1 +1134 ,1 +2545 ,1 +3558 ,1 +5696 ,0 +1582 ,1 +3176 ,0 +1864 ,1 +2375 ,1 +3674 ,0 +3851 ,1 +4183 ,0 +5049 ,1 +2591 ,1 +5920 ,0 +5530 ,0 +1037 ,1 +4518 ,1 +5631 ,0 +1690 ,1 +5362 ,1 +21 ,0 +4503 ,1 +1183 ,0 +5048 ,0 +5715 ,0 +2745 ,0 +5544 ,0 +2841 ,0 +5475 ,0 +5731 ,0 +1507 ,1 +2245 ,0 +5908 ,0 +806 ,1 +2613 ,1 +5736 ,0 +1740 ,1 +5936 ,0 +2527 ,1 +3696 ,0 +5237 ,1 +4273 ,1 +1848 ,0 +5592 ,1 +597 ,1 +1030 ,1 +861 ,1 +3675 ,1 +5865 ,0 +5354 ,0 +2479 ,1 +3752 ,1 +657 ,1 +5269 ,0 +461 ,1 +2322 ,0 +2153 ,1 +5758 ,0 +5615 ,0 +1793 ,0 +834 ,1 +3447 ,1 +3771 ,1 +2709 ,1 +3726 ,1 +819 ,1 +3615 ,0 +5718 ,0 +5910 ,0 +5469 ,0 +1118 ,1 +5362 ,0 +1298 ,0 +1110 ,1 +2699 ,1 +4042 ,1 +3838 ,1 +2502 ,0 +4777 ,1 +608 ,0 +5820 ,0 +3325 ,1 +3478 ,1 +5125 ,0 +5529 ,0 +547 ,1 +5812 ,0 +2632 ,1 +5260 ,0 +4930 ,0 +5553 ,0 +2415 ,0 +2422 ,1 +5637 ,0 +1559 ,1 +5550 ,1 +5528 ,0 +664 ,1 +4536 ,1 +5616 ,0 +3379 ,1 +1957 ,1 +5530 ,0 +5503 ,0 +4028 ,1 +3556 ,1 +2680 ,1 +2324 ,1 +5438 ,0 +582 ,1 +3359 ,1 +1447 ,0 +4265 ,1 +1628 ,1 +5617 ,0 +2821 ,1 +5706 ,0 +5414 ,0 +1569 ,1 +2909 ,1 +5391 ,0 +5559 ,0 +5453 ,0 +4557 ,0 +1476 ,0 +4514 ,0 +5283 ,0 +1302 ,1 +3221 ,1 +3561 ,1 +5598 ,0 +5354 ,0 +3939 ,1 +3355 ,1 +4340 ,1 +2982 ,1 +1298 ,1 +5349 ,0 +2296 ,0 +5422 ,0 +1019 ,1 +5275 ,0 +497 ,1 +5131 ,0 +5203 ,0 +5488 ,0 +5000 ,0 +1324 ,1 +5157 ,0 +3115 ,1 +1153 ,1 +3724 ,0 +3374 ,1 +3062 ,1 +1607 ,0 +3242 ,1 +3032 ,1 +1385 ,1 +5102 ,0 +3494 ,1 +5153 ,0 +2509 ,1 +2132 ,1 +5584 ,0 +5596 ,0 +5025 ,0 +5602 ,0 +1826 ,1 +2279 ,1 +5908 ,0 +2999 ,1 +2063 ,1 +2336 ,1 +2379 ,0 +1960 ,1 +3720 ,1 +3852 ,1 +2610 ,1 +1108 ,1 +4606 ,0 +1453 ,1 +5240 ,0 +3394 ,0 +2268 ,1 +414 ,1 +2547 ,1 +501 ,1 +4970 ,0 +5222 ,0 +5085 ,0 +4483 ,1 +4792 ,0 +3143 ,1 +3524 ,0 +1730 ,1 +1137 ,1 +1396 ,1 +1291 ,1 +1609 ,1 +425 ,1 +3069 ,1 +4732 ,0 +5049 ,0 +5116 ,0 +2886 ,1 +4906 ,0 +426 ,1 +1372 ,1 +3436 ,1 +3724 ,1 +2965 ,1 +2562 ,1 +757 ,1 +1550 ,1 +5147 ,0 +5305 ,0 +5192 ,0 +5092 ,0 +5460 ,0 +5090 ,0 +3784 ,0 +1609 ,1 +5405 ,0 +865 ,1 +4834 ,1 +461 ,1 +3773 ,1 +691 ,1 +5059 ,0 +3004 ,1 +1179 ,1 +3579 ,1 +76 ,0 +3060 ,1 +368 ,1 +5137 ,0 +3487 ,1 +3410 ,1 +2135 ,1 +4464 ,1 +5082 ,0 +1369 ,1 +4408 ,1 +3743 ,1 +928 ,1 +5066 ,0 +4620 ,1 +5040 ,0 +4955 ,0 +4970 ,0 +5114 ,0 +5386 ,0 +3204 ,1 +2105 ,0 +3233 ,1 +2872 ,1 +2242 ,1 +3528 ,0 +2923 ,1 +3011 ,0 +3042 ,0 +5262 ,0 +2787 ,0 +780 ,1 +3850 ,0 +4735 ,1 +4218 ,1 +3212 ,0 +2260 ,0 +3018 ,0 +5245 ,0 +1382 ,1 +2506 ,1 +1484 ,1 +1039 ,1 +4672 ,1 +4514 ,0 +921 ,1 +476 ,1 +2955 ,1 +2421 ,0 +1355 ,1 +3511 ,0 +291 ,1 +5194 ,0 +4588 ,0 +4885 ,1 +4466 ,1 +3708 ,0 +4923 ,0 +485 ,1 +1719 ,1 +1520 ,1 +1854 ,1 +4828 ,1 +4891 ,0 +2724 ,1 +4490 ,0 +1744 ,1 +3496 ,1 +564 ,1 +4895 ,0 +2981 ,0 +1262 ,1 +4869 ,0 +587 ,1 +3777 ,1 +5089 ,0 +76 ,1 +3309 ,0 +1029 ,1 +1391 ,1 +530 ,1 +1273 ,1 +4973 ,0 +954 ,1 +3699 ,1 +5296 ,0 +1536 ,1 +1975 ,1 +701 ,1 +1717 ,1 +4662 ,0 +5187 ,0 +3001 ,0 +3712 ,0 +1799 ,0 +3583 ,0 +3093 ,1 +4437 ,0 +1099 ,1 +4771 ,0 +926 ,1 +4816 ,0 +489 ,1 +3744 ,0 +60 ,0 +4532 ,0 +4312 ,0 +444 ,1 +4035 ,0 +4820 ,0 +2396 ,1 +4741 ,0 +2373 ,1 +3604 ,1 +1453 ,1 +4989 ,0 +1033 ,1 +5036 ,0 +2957 ,0 +760 ,0 +2668 ,1 +771 ,0 +3542 ,0 +1041 ,1 +4635 ,1 +4773 ,0 +4405 ,0 +751 ,0 +565 ,1 +4802 ,0 +3089 ,1 +3530 ,1 +3461 ,0 +4737 ,0 +232 ,0 +4616 ,1 +4697 ,0 +3530 ,1 +4767 ,0 +1662 ,1 +2707 ,1 +2149 ,1 +4056 ,0 +3341 ,0 +2733 ,1 +1617 ,1 +2434 ,1 +4807 ,0 +535 ,0 +4648 ,1 +790 ,1 +2775 ,1 +4912 ,0 +982 ,1 +4521 ,0 +4215 ,0 +4669 ,0 +3663 ,0 +4369 ,0 +3711 ,0 +4267 ,0 +3493 ,1 +4778 ,0 +3846 ,1 +1262 ,1 +1911 ,0 +4134 ,1 +4080 ,0 +969 ,1 +2489 ,0 +1278 ,1 +3114 ,1 +1145 ,1 +1051 ,1 +1887 ,1 +632 ,1 +3249 ,1 +4551 ,0 +4006 ,0 +3625 ,0 +4213 ,0 +667 ,1 +4522 ,0 +924 ,1 +4152 ,0 +630 ,1 +3140 ,1 +4263 ,0 +1915 ,0 +4524 ,0 +3792 ,1 +4059 ,1 +1907 ,1 +2116 ,1 +2073 ,1 +4441 ,0 +6334 ,1 +4333 ,1 +5758 ,1 +4306 ,1 +2127 ,1 +4295 ,1 +406 ,1 +1448 ,1 +4690 ,0 +1562 ,1 +6172 ,0 +6749 ,0 +1492 ,1 +6933 ,0 +4361 ,1 +5683 ,0 +2194 ,1 +5861 ,0 +5214 ,0 +1136 ,1 +5046 ,1 +6201 ,0 +5023 ,0 +5605 ,1 +6032 ,0 +6060 ,0 +642 ,0 +680 ,1 +860 ,0 +1141 ,1 +6112 ,0 +468 ,1 +5656 ,1 +5478 ,1 +2150 ,0 +1722 ,1 +5470 ,1 +5108 ,1 +5294 ,0 +1623 ,1 +3075 ,1 +4921 ,1 +730 ,0 +69 ,1 +4717 ,0 +319 ,1 +302 ,1 +2325 ,1 +1410 ,0 +3391 ,0 +3700 ,1 +2990 ,0 +3737 ,1 +2005 ,0 +3610 ,0 +5215 ,0 +1603 ,0 +3469 ,1 +4617 ,1 +1301 ,0 +1854 ,1 +2358 ,1 +769 ,1 +2366 ,1 +4046 ,0 +3133 ,0 +1030 ,1 +2688 ,1 +1669 ,0 +2248 ,1 +2330 ,1 +2310 ,0 +3653 ,0 +1758 ,1 +476 ,1 +3752 ,0 +2660 ,1 +135 ,0 +3684 ,1 +4361 ,0 +2542 ,1 +2768 ,0 +994 ,1 +1754 ,1 +1648 ,1 +3595 ,0 +2374 ,1 +4794 ,0 +1992 ,0 +1133 ,1 +713 ,0 +2024 ,0 +739 ,1 +1279 ,1 +1574 ,1 +3335 ,0 +902 ,1 +3528 ,0 +1703 ,1 +4614 ,0 +3549 ,1 +775 ,0 +1269 ,1 +2088 ,0 +3409 ,0 +737 ,0 +4488 ,0 +4434 ,1 +3836 ,0 +1064 ,1 +3536 ,0 +1061 ,1 +3779 ,0 +3 ,1 +2244 ,1 +2177 ,1 +2550 ,1 +1704 ,0 +3344 ,0 +3156 ,1 +1608 ,0 +3622 ,0 +2637 ,0 +3723 ,0 +1586 ,1 +2254 ,0 +441 ,1 +2933 ,0 +1827 ,1 +2037 ,0 +2730 ,0 +2276 ,0 +3352 ,0 +672 ,1 +3150 ,0 +836 ,1 +105 ,1 +2772 ,1 +2650 ,1 +1486 ,1 +1056 ,1 +3009 ,1 +2204 ,0 +1915 ,1 +1528 ,0 +2658 ,0 +1761 ,0 +185 ,0 +1996 ,0 +2570 ,0 +877 ,1 +427 ,0 +2401 ,0 +2411 ,1 +548 ,1 +2022 ,0 +1869 ,0 +1889 ,0 +1737 ,0 +516 ,1 +1582 ,0 +4588 ,1 +6160 ,0 +1136 ,1 +2551 ,1 +501 ,1 +1816 ,0 +1940 ,0 +1738 ,0 +1625 ,0 +1514 ,0 +1360 ,0 +1778 ,0 +1758 ,0 +1295 ,1 +1643 ,1 +2221 ,0 +1785 ,0 +595 ,1 +1393 ,0 +2061 ,0 +2345 ,0 +1805 ,0 +2135 ,1 +2394 ,0 +2143 ,1 +2565 ,0 +803 ,1 +2108 ,0 +2407 ,0 +2008 ,0 +1103 ,1 +2694 ,0 +2306 ,0 +940 ,1 +2561 ,0 +2144 ,0 +2253 ,0 +962 ,1 +2132 ,0 +2592 ,0 +2431 ,0 +1476 ,1 +2155 ,0 +2843 ,1 +2655 ,0 +2746 ,0 +2605 ,0 +2847 ,1 +2404 ,0 +2654 ,0 +3061 ,0 +1954 ,1 +2578 ,1 +3184 ,0 +3664 ,0 +913 ,1 +274 ,1 +3510 ,0 +1315 ,1 +3675 ,0 +3580 ,0 +3056 ,0 +2816 ,0 +2876 ,0 +2598 ,0 +3236 ,0 +2693 ,0 +3150 ,0 +2954 ,0 +3438 ,0 +3242 ,0 +3062 ,0 +1570 ,1 +3477 ,0 +3245 ,0 +3349 ,0 +2951 ,0 +3038 ,0 +125 ,1 +2732 ,0 +2737 ,0 +3066 ,0 +3164 ,0 +3496 ,0 +3303 ,0 +3706 ,0 +152 ,1 +3706 ,0 +3567 ,0 +3293 ,0 +3308 ,0 +3672 ,0 +3064 ,0 +3401 ,0 +3469 ,0 +3720 ,0 +1955 ,1 +3597 ,0 +3538 ,0 +3728 ,0 +3123 ,1 +3467 ,0 +1125 ,1 +1708 ,1 +3526 ,0 +905 ,1 +1941 ,1 +3297 ,0 +3698 ,0 +3559 ,0 +1001 ,1 +4217 ,0 +3851 ,0 +2665 ,1 +897 ,1 +4130 ,0 +4188 ,0 +4060 ,0 +1808 ,1 +3089 ,1 +4042 ,0 +3948 ,0 +715 ,1 +802 ,1 +4120 ,0 +3520 ,1 +2724 ,1 +4157 ,0 +4377 ,0 +4134 ,0 +3079 ,1 +4150 ,0 +3884 ,0 +570 ,1 +1339 ,1 +3869 ,0 +4333 ,0 +4136 ,0 +4284 ,0 +3254 ,1 +4493 ,0 +4673 ,0 +1393 ,1 +658 ,1 +4442 ,0 +3646 ,0 +4206 ,0 +1579 ,1 +4374 ,1 +4415 ,0 +4231 ,0 +4515 ,0 +4662 ,0 +4414 ,1 +4479 ,0 +1906 ,1 +4223 ,0 +4554 ,0 +4726 ,0 +1484 ,1 +3751 ,1 +146 ,0 +4403 ,1 +4357 ,0 +4308 ,1 +2624 ,1 +1400 ,1 +2265 ,1 +3213 ,1 +4639 ,0 +502 ,1 +2280 ,1 +2928 ,1 +4903 ,0 +4981 ,0 +1672 ,1 +1378 ,1 +1575 ,1 +4960 ,0 +4608 ,0 +2712 ,1 +4963 ,0 +4956 ,0 +4548 ,0 +1650 ,1 +1601 ,1 +1896 ,1 +5212 ,0 +3348 ,1 +5478 ,0 +4883 ,0 +5328 ,0 +4736 ,0 +5190 ,0 +5124 ,0 +4804 ,0 +718 ,1 +698 ,1 +5133 ,0 +4464 ,1 +5525 ,0 +1760 ,1 +1530 ,1 +647 ,1 +1071 ,1 +1799 ,1 +5766 ,0 +1041 ,1 +452 ,1 +935 ,0 +3121 ,1 +4949 ,1 +1289 ,1 +2045 ,1 +5831 ,0 +5704 ,0 +635 ,1 +5255 ,0 +6246 ,0 +1332 ,1 +5560 ,0 +5611 ,0 +5866 ,0 +639 ,1 +5866 ,0 +1486 ,1 +866 ,1 +5998 ,0 +3270 ,1 +4759 ,1 +2909 ,1 +812 ,1 +3800 ,1 +205 ,1 +2354 ,1 +5977 ,0 +2482 ,1 +5906 ,0 +1342 ,1 +5279 ,1 +2587 ,1 +5893 ,0 diff --git a/datasets/pbc2.csv b/datasets/pbc2.csv index 87bcc36a..11625db8 100644 --- a/datasets/pbc2.csv +++ b/datasets/pbc2.csv @@ -1943,4 +1943,4 @@ "1942","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",0.564012703975468,"No","No","No","No edema",5.5,NA,3.2,1678,124,189,10.9,2,0 "1943","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",1.06779104150695,"No","No","No","No edema",7.4,312,3.56,1767,166,148,11.7,2,0 "1944","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.12189245427664,"No","No","Yes","edema no diuretics",16.3,688,3.34,2460,173,138,13,2,0 -"1945","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.94327017851276,"No","No","Yes","edema no diuretics",23.4,741,3.42,3012,200,128,13.4,3,0 \ No newline at end of file +"1945","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.94327017851276,"No","No","Yes","edema no diuretics",23.4,741,3.42,3012,200,128,13.4,3,0 diff --git a/datasets/pbc_dataset.py b/datasets/pbc_dataset.py index e41a7347..66f5de03 100644 --- a/datasets/pbc_dataset.py +++ b/datasets/pbc_dataset.py @@ -1,6 +1,7 @@ import os -import pandas as pd + import numpy as np +import pandas as pd from sklearn.impute import SimpleImputer @@ -12,7 +13,8 @@ def load_pbc2_dataset(): with missing values imputed. The function also constructs the outcome variable for competing risks analysis and returns time-to-event data. - Returns: + Returns + ------- tuple: A tuple containing the following elements: - x (numpy.ndarray): Combined feature matrix with categorical features one-hot encoded and continuous features imputed. @@ -25,56 +27,56 @@ def load_pbc2_dataset(): - n_continuous (int): Number of continuous features. - feature_ranges (list of tuple): List of (min, max) ranges for each feature. """ - - file_path = os.path.join(os.path.dirname(__file__), "pbc2.csv") data = pd.read_csv(file_path) - data = data.drop(columns=['id', 'sno.', 'year', 'status2'], axis=1) - + data = data.drop(columns=["id", "sno.", "year", "status2"], axis=1) event_type = np.where( - data['status'] == 'dead', 1, - np.where(data['status'] == 'transplanted', 2, 0) + data["status"] == "dead", 1, np.where(data["status"] == "transplanted", 2, 0) ) - cont_cols = [ - 'age', 'serBilir', 'serChol', 'albumin', 'alkaline', - 'SGOT', 'platelets', 'prothrombin', 'histologic' + "age", + "serBilir", + "serChol", + "albumin", + "alkaline", + "SGOT", + "platelets", + "prothrombin", + "histologic", ] - x_cont = data[cont_cols].replace('NA', np.nan).astype(float) + x_cont = data[cont_cols].replace("NA", np.nan).astype(float) - - mean_imputer = SimpleImputer(strategy='mean') + mean_imputer = SimpleImputer(strategy="mean") x_cont_imputed = mean_imputer.fit_transform(x_cont) - cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0), - np.nanmax(x_cont_imputed, axis=0))) + cont_feature_ranges = list( + zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0)) + ) - - cat_cols = ['sex', 'drug', 'ascites', 'hepatomegaly', 'spiders', 'edema'] - x_cat = data[cat_cols].fillna('missing') + cat_cols = ["sex", "drug", "ascites", "hepatomegaly", "spiders", "edema"] + x_cat = data[cat_cols].fillna("missing") - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat_imputed = cat_imputer.fit_transform(x_cat) x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols) x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True) - - x_cat_encoded = x_cat_encoded.loc[:, ~x_cat_encoded.columns.str.contains('_missing')] + x_cat_encoded = x_cat_encoded.loc[ + :, ~x_cat_encoded.columns.str.contains("_missing") + ] cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1] - x = np.hstack([x_cat_encoded.values, x_cont_imputed]) feature_names = np.concatenate([x_cat_encoded.columns, cont_cols]) n_continuous = len(cont_cols) feature_ranges = cat_feature_ranges + cont_feature_ranges - - t = data['years'].astype(float).values * 365.25 + t = data["years"].astype(float).values * 365.25 valid = ~np.isnan(t) return ( @@ -83,10 +85,9 @@ def load_pbc2_dataset(): event_type[valid], feature_names, n_continuous, - feature_ranges + feature_ranges, ) if __name__ == "__main__": x, t, e, feature_names, n_continuous, feature_ranges = load_pbc2_dataset() - \ No newline at end of file diff --git a/datasets/support2.csv b/datasets/support2.csv index b090d1b2..07c9d5c2 100644 --- a/datasets/support2.csv +++ b/datasets/support2.csv @@ -9103,4 +9103,4 @@ sno,age,death,sex,hospdead,slos,d.time,dzgroup,dzclass,num.co,edu,income,scoma,c 9102, 55.15399,0,female,0, 29, 347,Coma,Coma,1,11,, 41, 35377.000, 23558.5000, 22131.04690,18.000000,white,25.7968750, 31,0.5539550780,0.4859619140, 1,0,0,no,0.500000000,5.000000e-01,no dnr, 29, 43.0,, 0.000000, 8,38.59375,218.50000,,, 5.89941406,135,7.289062, 190.000000, 49.000000, 0.000000,, 0,,0.0000000 9103, 70.38196,0,male,0, 8, 346,ARF/MOSF w/Sepsis,ARF/MOSF,1,,, 0, 46564.000, 31409.0156, 31131.25000,23.000000,white,22.6992188, 39,0.7419433590,0.6608886720, 18,0,0,no,0.899999619,7.999997e-01,no dnr, 8,111.0, 8.39843750, 83.000000,24,36.69531,180.00000,, 0.39996338, 2.69970703,139,7.379883, 189.000000, 60.000000,3900.000000,,,,2.5253906 9104, 47.01999,1,male,1, 7, 7,MOSF w/Malig,ARF/MOSF,1,13,, 0, 58439.000,,,35.500000,white,40.1953125, 51,0.1779785160,0.0919952393, 22,0,0,yes,0.089999974,8.999997e-02,dnr after sadm, 5, 99.0, 7.59960938,110.000000,24,36.39844,428.56250, 1.1999512, 0.39996338, 3.50000000,135,7.469727, 246.000000, 55.000000,,, 0,<2 mo. follow-up,0.0000000 -9105, 81.53894,1,female,0, 12, 198,ARF/MOSF w/Sepsis,ARF/MOSF,1, 8,$11-$25k, 0, 15604.000, 10605.7578, 12146.15620,13.500000,white,18.0976562, 7,0.8328857420,0.7769775390, 1,1,0,no,,,no dnr, 12, 75.0, 8.59960938, 69.000000,24,36.19531,230.40625, 4.5000000, 0.59997559, 1.19995117,137,7.289062, 187.000000, 15.000000,, 0,,no(M2 and SIP pres),0.4947510 \ No newline at end of file +9105, 81.53894,1,female,0, 12, 198,ARF/MOSF w/Sepsis,ARF/MOSF,1, 8,$11-$25k, 0, 15604.000, 10605.7578, 12146.15620,13.500000,white,18.0976562, 7,0.8328857420,0.7769775390, 1,1,0,no,,,no dnr, 12, 75.0, 8.59960938, 69.000000,24,36.19531,230.40625, 4.5000000, 0.59997559, 1.19995117,137,7.289062, 187.000000, 15.000000,, 0,,no(M2 and SIP pres),0.4947510 diff --git a/datasets/support_dataset.py b/datasets/support_dataset.py index 2146e92f..ce1aae67 100644 --- a/datasets/support_dataset.py +++ b/datasets/support_dataset.py @@ -1,10 +1,11 @@ import os -import pandas as pd + import numpy as np +import pandas as pd + # Enable and import IterativeImputer (MICE) from scikit-learn: from sklearn.experimental import enable_iterative_imputer # noqa -from sklearn.impute import IterativeImputer, SimpleImputer - +from sklearn.impute import SimpleImputer def load_support_dataset(): @@ -12,9 +13,11 @@ def load_support_dataset(): Load and preprocess the SUPPORT dataset. This function reads the SUPPORT dataset from a CSV file, imputes missing values, encodes categorical features, and constructs the outcome variable for survival analysis. - It returns the processed features, time-to-event data, event types, feature names, + It returns the processed features, time-to-event data, event types, feature names, the number of continuous features, and feature ranges. - Returns: + + Returns + ------- tuple: A tuple containing: - x (numpy.ndarray): Combined array of processed categorical and continuous features. - t (numpy.ndarray): Time-to-event data with a +1 offset. @@ -22,7 +25,9 @@ def load_support_dataset(): - feature_names (numpy.ndarray): Array of feature names. - n_continuous (int): Number of continuous features. - feature_ranges (list): List of tuples representing the range (min, max) for each feature. - Notes: + + Notes + ----- - The dataset file "support2.csv" must be located in the same directory as this script. - Continuous features are imputed using the median strategy if missing values are present. - Categorical features are imputed using the most frequent strategy if missing values are present. @@ -30,72 +35,87 @@ def load_support_dataset(): - The outcome variable is constructed using the 'ca', 'dzgroup', and 'death' columns. - Time-to-event data ('d.time') is offset by +1 to avoid zero follow-up times. """ - - file_path = os.path.join(os.path.dirname(__file__), "support2.csv") data = pd.read_csv(file_path) print(data.columns) - + is_cancer = data["ca"].astype(str).str.lower().str.contains("meta") | data[ + "dzgroup" + ].astype(str).str.lower().str.contains("cancer") + event_type = np.where(data["death"] == 1, np.where(is_cancer, 1, 2), 0) - - is_cancer = data['ca'].astype(str).str.lower().str.contains("meta") | \ - data['dzgroup'].astype(str).str.lower().str.contains("cancer") - event_type = np.where(data['death'] == 1, np.where(is_cancer, 1, 2), 0) - - - - cont_cols = [ - 'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', 'pafi', 'alb', - 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', 'urine', - 'scoma', 'aps', 'sps', 'adls', 'adlsc', 'charges', 'totcst', 'totmcst', 'avtisst' + "age", + "num.co", + "meanbp", + "wblc", + "hrt", + "resp", + "temp", + "pafi", + "alb", + "bili", + "crea", + "sod", + "ph", + "glucose", + "bun", + "urine", + "scoma", + "aps", + "sps", + "adls", + "adlsc", + "charges", + "totcst", + "totmcst", + "avtisst", ] x_cont = data[cont_cols] - + if x_cont.isnull().values.any(): - simp_imputer = SimpleImputer(strategy='most_frequent') #originally was median + simp_imputer = SimpleImputer(strategy="most_frequent") # originally was median x_cont_imputed = simp_imputer.fit_transform(x_cont) else: x_cont_imputed = x_cont.values - cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0), - np.nanmax(x_cont_imputed, axis=0))) + cont_feature_ranges = list( + zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0)) + ) # -- Categorical features -- # Remove leakage fields 'ca', 'dzgroup', 'dzclass' - cat_cols = ['sex', 'income', 'race', 'dnr', 'dementia', 'diabetes'] + cat_cols = ["sex", "income", "race", "dnr", "dementia", "diabetes"] x_cat = data[cat_cols] if x_cat.isnull().values.any(): - cat_imputer = SimpleImputer(strategy='most_frequent') + cat_imputer = SimpleImputer(strategy="most_frequent") x_cat_imputed = cat_imputer.fit_transform(x_cat) else: x_cat_imputed = x_cat.values x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols) - + x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True) cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1] - x = np.hstack([x_cat_encoded.values, x_cont_imputed]) feature_names = np.concatenate([x_cat_encoded.columns, cont_cols]) n_continuous = len(cont_cols) feature_ranges = cat_feature_ranges + cont_feature_ranges - - t = data['d.time'].values + t = data["d.time"].values valid = ~np.isnan(t) - + print("Completed imputation of missing values.") return ( x[valid], - t[valid] + 1, # Add +1 offset to follow-up times + t[valid] + 1, # Add +1 offset to follow-up times event_type[valid], feature_names, n_continuous, - feature_ranges + feature_ranges, ) + if __name__ == "__main__": x, t, e, feature_names, n_continuous, feature_ranges = load_support_dataset() print("Feature names:", feature_names) diff --git a/datasets/synthetic_dataset.py b/datasets/synthetic_dataset.py index 84d5cb17..abd62a1b 100644 --- a/datasets/synthetic_dataset.py +++ b/datasets/synthetic_dataset.py @@ -1,21 +1,26 @@ +import os +from typing import List, Tuple + import numpy as np import pandas as pd -from typing import Tuple, List -import os import torch -def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]]: + +def load_synthetic_dataset() -> Tuple[ + np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple] +]: """ Loads a synthetic competing risks dataset from a CSV file. - + The CSV is expected to have a header with the following columns: - time: observed time - label: event indicator (0 for censored; >0 for event types) - true_time: (optional) true time (unused here) - true_label: (optional) true event label (unused here) - feature1, feature2, ..., featureN: feature values - - Returns: + + Returns + ------- X (np.ndarray): Feature matrix of shape (n_samples, n_features). T_obs (np.ndarray): Observed times of shape (n_samples,). e (np.ndarray): Event indicators of shape (n_samples,). @@ -23,30 +28,24 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s n_continuous (int): Total number of continuous features. feature_ranges (List[tuple]): List of (min, max) tuples for each feature. """ - use_gpu_compatible_dtype = torch.cuda.is_available() - - - - + file_path = os.path.join(os.path.dirname(__file__), "synthetic_comprisk.csv") df = pd.read_csv(file_path) - - - T_obs = df["time"].values .astype(np.float32) + T_obs = df["time"].values.astype(np.float32) e = df["label"].values.astype(np.float32) - - + feature_columns = [col for col in df.columns if col.startswith("feature")] - X = df[feature_columns].values .astype(np.float32) - - + X = df[feature_columns].values.astype(np.float32) + feature_names = feature_columns n_continuous = X.shape[1] - - feature_ranges = [(float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)] - + + feature_ranges = [ + (float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous) + ] + return X, T_obs, e, feature_names, n_continuous, feature_ranges @@ -61,10 +60,3 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s # For a quick inspection, print the first 5 rows (features, time, event) print("First 5 rows:") print(np.hstack([X[:5], T_obs[:5, None], e[:5, None]])) - - - - - - - diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..e1360093 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,24 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +# NOTE: This is not generated from `sphinx-quickstart` and manually added +serve: + sphinx-autobuild $(SOURCEDIR) $(BUILDDIR) \ No newline at end of file diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 00000000..432868f3 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,47 @@ +# Python API Reference + +This section documents the Python API for CRISP-NAM. + +## Model + +::: crisp_nam.models.crisp_nam_model + options: + show_root_heading: true + members: + - CrispNamModel +::: crisp_nam.models.deephit_model + options: + show_root_heading: true + filters: + - "!^_" + members: + - DeepHit + + ## Metrics + +::: crisp_nam.metrics.calibration + options: + show_root_heading: true + members: true +::: crisp_nam.metrics.discrimination + options: + show_root_heading: true + members: true +::: crisp_nam.metrics.ipcw + options: + show_root_heading: true + members: true + + ## Utilities +::: crisp_nam.utils.plotting + options: + show_root_heading: true + members: true +::: crisp_nam.utils.risk_cif + options: + show_root_heading: true + members: true +::: crisp_nam.utils.loss + options: + show_root_heading: true + members: true \ No newline at end of file diff --git a/docs/assets/favicon-48x48.svg b/docs/assets/favicon-48x48.svg new file mode 100644 index 00000000..3cd92e57 --- /dev/null +++ b/docs/assets/favicon-48x48.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/docs/assets/favicon.ico b/docs/assets/favicon.ico new file mode 100644 index 00000000..30762370 Binary files /dev/null and b/docs/assets/favicon.ico differ diff --git a/docs/assets/launch.png b/docs/assets/launch.png new file mode 100644 index 00000000..c9a3e06f Binary files /dev/null and b/docs/assets/launch.png differ diff --git a/docs/assets/vector-logo.svg b/docs/assets/vector-logo.svg new file mode 100644 index 00000000..8dd76b5b --- /dev/null +++ b/docs/assets/vector-logo.svg @@ -0,0 +1,172 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/blog.md b/docs/blog.md new file mode 100644 index 00000000..8ceecd18 --- /dev/null +++ b/docs/blog.md @@ -0,0 +1,3 @@ +There is a blog explaining the architecture of CRISP-NAM and showcasing the results. + +Head [here](blog/index.html) to read it! \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..76614c1e --- /dev/null +++ b/docs/index.md @@ -0,0 +1,28 @@ +# CRISP-NAM: Competing Risks for Interpretable Survival Analysis using Neural Additive Models + +This repository contains research code for the paper: [CRISP-NAM: Competing Risks Interpretable Survival +Prediction with Neural Additive Models](https://ceur-ws.org/Vol-4059/paper5.pdf). +It includes the Python code for the following: + +- Models: `CRISP-NAM` and `DeepHIT` +- Data loading utilities for 4 datasets: Framingham, PBC, Support2, Synthetic +- Training scripts: Standard training, Hyperparameter optimization via Optuna, Nested cross validation +- Metrics: Loss and risk functions for survival analysis. +- Plotting: Feature importance and Shape functions for interpretability. + +## PyPI package +The core files of research: models, metrics and plotting utilities. + +## Installation +You can install the package via the following pip command: +```bash +pip install crisp_nam +``` + +## Citation +> @inproceedings{ramachandram2025crispnam,
+title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models},
+author={Ramachandram, Dhanesh and Raval, Ananya},
+booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain},
+year={2025}
+}
\ No newline at end of file diff --git a/docs/overrides/partials/copyright.html b/docs/overrides/partials/copyright.html new file mode 100644 index 00000000..776166c1 --- /dev/null +++ b/docs/overrides/partials/copyright.html @@ -0,0 +1,22 @@ + diff --git a/docs/overrides/partials/logo.html b/docs/overrides/partials/logo.html new file mode 100644 index 00000000..2ed1c763 --- /dev/null +++ b/docs/overrides/partials/logo.html @@ -0,0 +1,5 @@ +{% if config.theme.logo %} +logo +{% else %} +{{ config.site_name }} +{% endif %} diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css new file mode 100644 index 00000000..cda9da97 --- /dev/null +++ b/docs/stylesheets/extra.css @@ -0,0 +1,235 @@ +[data-md-color-primary="vector"] { + --md-primary-fg-color: #eb088a; + --md-primary-fg-color--light: #f252a5; + --md-primary-fg-color--dark: #b00068; + --md-primary-bg-color: hsla(0, 0%, 100%, 1); + --md-primary-bg-color--light: hsla(0, 0%, 100%, 0.7); +} + +[data-md-color-primary="black"] { + --md-primary-fg-color: #181818; + --md-primary-fg-color--light: #f252a5; + --md-primary-fg-color--dark: #b00068; + --md-primary-bg-color: #eb088a; +} + +[data-md-color-accent="vector-teal"] { + --md-accent-fg-color: #48c0d9; + --md-accent-fg-color--transparent: #526cfe1a; + --md-accent-bg-color: #fff; + --md-accent-bg-color--light: #ffffffb3; +} + +[data-md-color-scheme="slate"][data-md-color-primary="black"] { + --md-typeset-a-color: #eb088a; +} + +[data-md-color-scheme="default"] { + /* Default light mode styling */ +} + +[data-md-color-scheme="slate"] { + --md-typeset-a-color: #eb088a; + /* Dark mode styling */ +} + +/* Vector logo css styling to match overrides/partial/copyright.html */ +.md-footer-vector { + display: flex; + align-items: center; + padding: 0 0.6rem; +} + +.md-footer-vector img { + height: 24px; /* Reduce height to a fixed value */ + width: auto; /* Maintain aspect ratio */ + transition: opacity 0.25s; + opacity: 0.7; +} + +.md-footer-vector img:hover { + opacity: 1; +} + +/* Make the inner footer grid elements distribute evenly */ +.md-footer-meta__inner { + display: flex; + justify-content: space-between; + align-items: center; +} + +/* To make socials and Vector logo not stack when viewing on mobile */ +@media screen and (max-width: 76.234375em) { + .md-footer-meta__inner.md-grid { + flex-direction: row; + justify-content: space-between; + align-items: center; + } + + .md-copyright, + .md-social { + width: auto; + max-width: 49%; + } + + /* Prevent margin that causes stacking */ + .md-social { + margin: 0; + } +} + +/* Reduce margins for h2 when using grid cards */ +.grid.cards h2 { + margin-top: 0; /* Remove top margin completely in cards */ + margin-bottom: 0.5rem; /* Smaller bottom margin in cards */ +} + +.vector-icon { + color: #eb088a; + opacity: 0.7; + margin-right: 0.2em; +} + +/* Version selector styling - Material theme */ + +/* Version selector container */ +.md-version { + position: relative; + display: inline-block; + margin-left: 0.25rem; +} + +/* Current version button styling */ +.md-version__current { + display: inline-flex; + align-items: center; + font-size: 0.7rem; + font-weight: 600; + color: var(--md-primary-bg-color); + padding: 0.4rem 0.8rem; + margin: 0.4rem 0; + background-color: rgba(255, 255, 255, 0.1); + border-radius: 4px; + border: 1px solid rgba(255, 255, 255, 0.2); + cursor: pointer; + transition: all 0.15s ease-in-out; +} + +/* Hover effect for current version button */ +.md-version__current:hover { + background-color: rgba(255, 255, 255, 0.2); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); +} + +/* Down arrow for version dropdown */ +.md-version__current:after { + display: inline-block; + margin-left: 0.5rem; + content: ""; + vertical-align: middle; + border-top: 0.3em solid; + border-right: 0.3em solid transparent; + border-bottom: 0; + border-left: 0.3em solid transparent; +} + +/* Dropdown menu */ +.md-version__list { + position: absolute; + top: 100%; + left: 0; + z-index: 10; + min-width: 125%; + margin: 0.1rem 0 0; + padding: 0; + background-color: var(--md-primary-fg-color); + border-radius: 4px; + box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2); + opacity: 0; + visibility: hidden; + transform: translateY(-8px); + transition: all 0.2s ease; +} + +/* Show dropdown when parent is hovered */ +.md-version:hover .md-version__list { + opacity: 1; + visibility: visible; + transform: translateY(0); +} + +/* Version list items */ +.md-version__item { + list-style: none; + padding: 0; +} + +/* Version links */ +.md-version__link { + display: block; + padding: 0.5rem 1rem; + font-size: 0.75rem; + color: var(--md-primary-bg-color); + transition: background-color 0.15s; + text-decoration: none; +} + +/* Version link hover */ +.md-version__link:hover { + background-color: var(--md-primary-fg-color--dark); + text-decoration: none; +} + +/* Active version in dropdown */ +.md-version__link--active { + background-color: var(--md-accent-fg-color); + color: var(--md-accent-bg-color); + font-weight: 700; +} + +/* For the Material selector */ +.md-header__option { + display: flex; + align-items: center; +} + +/* Version selector in Material 9.x */ +.md-select { + position: relative; + margin-left: 0.5rem; +} + +.md-select__label { + font-size: 0.7rem; + font-weight: 600; + color: var(--md-primary-bg-color); + cursor: pointer; + padding: 0.4rem 0.8rem; + background-color: rgba(255, 255, 255, 0.1); + border-radius: 4px; + border: 1px solid rgba(255, 255, 255, 0.2); + transition: all 0.15s ease-in-out; +} + +.md-select__label:hover { + background-color: rgba(255, 255, 255, 0.2); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); +} + +/* Version selector in Material 9.2+ */ +.md-header__button.md-select { + display: inline-flex; + align-items: center; + margin: 0 0.8rem; +} + +/* For Material 9.x+ with specific version selector */ +.md-typeset .md-version-warn { + padding: 0.6rem 1rem; + margin: 1.5rem 0; + background-color: rgba(235, 8, 138, 0.1); + border-left: 4px solid #eb088a; + border-radius: 0.2rem; + color: var(--md-default-fg-color); + font-size: 0.8rem; +} diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 00000000..e406fcb6 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,72 @@ +Example code to show how to use two classes in the package. + +## Use of `CrispNamModel` model class + +```python +import torch + +from crisp_nam.models import CrispNamModel + +# Example usage and testing +if __name__ == "__main__": + + # Generate some test data + torch.manual_seed(42) + test_data = torch.randn(100, 5) + + # Create model with L2 normalized projections + model = CrispNamModel( + num_features=5, + num_competing_risks=3, + hidden_sizes=[32, 32], + normalize_projections=True, + ) + + # Forward pass + risk_predictions = model(test_data) + + print("First 5 risk predictions:", risk_predictions[:5]) + + #Calculate feature importance + feature_importance = model.calculate_feature_importance(test_data) + print("First 5 feature importance values:", feature_importance) + + # Analyze projection weights + print('Analyzing projection weights...') + projection_weights = model.analyze_projection_weights() +``` + +## Use of `DeepHIT` model class + +```python +import torch + +from crisp_nam.models.deephit_model import DeepHit + +# Example usage and testing +if __name__ == "__main__": + + # Generate some test data + torch.manual_seed(42) + test_data = torch.randn(100, 5) + + input_dims = { + 'x_dim': test_data.shape[1], + 'num_Event': 2, + 'num_Category': 100 + } + + network_settings = { + 'h_dim_shared': 128, + 'h_dim_CS': 32, + 'num_layers_shared': 1, + 'num_layers_CS': 2, + 'active_fn': 'tanh', + 'keep_prob': 1.0 - 0.3 #1.0 - dropout_rate + } + + model = DeepHit(input_dims, network_settings) + + out = model.predict(test_data) + print("Output shape:", out.shape) +``` diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..193ed034 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,99 @@ +extra_css: + - stylesheets/extra.css +extra: + generator: false + social: + - icon: fontawesome/brands/discord + link: 404.html + - icon: fontawesome/brands/github + link: https://github.com/VectorInstitute/crisp-nam + version: + provider: mike + default: latest +markdown_extensions: + - attr_list + - admonition + - md_in_html + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.details + - pymdownx.snippets + - pymdownx.superfences + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - toc: + permalink: true + - meta + - footnotes +nav: + - Home: index.md + - API Reference: api.md + - Usage: usage.md + - Blog: blog.md +plugins: + - search + - mike: + version_selector: true + css_dir: stylesheets + canonical_version: latest + alias_type: symlink + deploy_prefix: '' + - mkdocstrings: + default_handler: python + handlers: + python: + paths: [../crisp_nam] + options: + docstring_style: numpy + members_order: source + separate_signature: true + show_overloads: true + show_submodules: true + show_root_heading: false + show_root_full_path: true + show_root_toc_entry: false + show_symbol_type_heading: true + show_symbol_type_toc: true +repo_url: https://github.com/VectorInstitute/crisp-nam +repo_name: VectorInstitute/crisp-nam +site_name: CRISP-NAM Documentation +docs_dir: docs +theme: + name: material + custom_dir: docs/overrides + favicon: assets/favicon-48x48.svg + features: + - content.code.annotate + - content.code.copy + - navigation.footer + - navigation.indexes + - navigation.instant + - navigation.tabs + - navigation.tabs.sticky + - navigation.top + - search.suggest + - search.highlight + - toc.follow + icon: + repo: fontawesome/brands/github + logo: assets/vector-logo.svg + logo_footer: assets/vector-logo.svg + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: vector + accent: vector-teal + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + accent: vector-teal + toggle: + icon: material/brightness-4 + name: Switch to light mode \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1308ddbd..64364dbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "crisp-nam" -version = "0.1.0" +version = "0.1.1" description = "CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models" readme = "README.md" requires-python = ">=3.10" @@ -12,6 +12,7 @@ dependencies = [ "configargparse>=1.7", "lifelines>=0.30.0", "matplotlib>=3.10.1", + "mypy>=1.19.1", "openpyxl>=3.1.5", "optuna>=4.3.0", "pandas>=2.2.3", @@ -22,5 +23,104 @@ dependencies = [ "torch>=2.7.0", ] +[dependency-groups] +docs = [ + "mkdocs>=1.5.3", + "mkdocs-material>=9.5.12", + "mkdocstrings>=0.24.1", + "mkdocstrings-python>=1.8.0", + "pymdown-extensions>=10.7.1", + "mike>=2.0.0", + "click<=8.2.1", +] + +[project.urls] +Homepage = "https://github.com/VectorInstitute/crisp-nam" + +[tool.mypy] +follow_imports = "normal" +ignore_missing_imports = false +install_types = true +pretty = true +non_interactive = true +allow_untyped_defs = false +no_implicit_optional = true +check_untyped_defs = true +namespace_packages = true +explicit_package_bases = true +warn_unused_configs = true +allow_subclassing_any = false +allow_untyped_calls = false +allow_incomplete_defs = false +allow_untyped_decorators = false +warn_redundant_casts = true +warn_unused_ignores = true +implicit_reexport = false +strict_equality = true +extra_checks = true +mypy_path = "crisp_nam" + +[tool.ruff] +include = ["crisp_nam/*.py", "pyproject.toml", "*.ipynb"] +extend-exclude = ["*.csv", "*.json", "datasets/", "data_utils/", "training_scripts/"] +line-length = 88 + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "COM", # flake8-commas + "C4", # flake8-comprehensions + "RET", # flake8-return + "SIM", # flake8-simplify + "ICN", # flake8-import-conventions + "Q", # flake8-quotes + "RSE", # flake8-raise + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "W", # pycodestyle + "N", # pep8-naming + "ERA", # eradicate + "PL", # pylint +] + +fixable = ["A", "B", "COM", "C4", "RET", "SIM", "ICN", "Q", "RSE", "D", "E", "F", "I", "W", "N", "ERA", "PL"] +ignore = [ + "B905", # `zip()` without an explicit `strict=` parameter + "E501", # line too long + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "PLR2004", # Replace magic number with named constant + "PLR0913", # Too many arguments + "COM812", # Missing trailing comma + "ERA001", # Found commented-out code (too many false positives with math comments) + "A001", # Ignore variable `input` is shadowing a Python builtin (common for torch) + "A002", # Ignore variable `input` is shadowing a Python builtin in function (common for torch) + "D301", # r-strings for docstrings with backslashes +] + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.pep8-naming] +ignore-names = ["X*", "setUp"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + [tool.hatch.build.targets.sdist] -only-include = ["crisp_nam"] +only-include = ["crisp_nam"] \ No newline at end of file diff --git a/results/best_params/best_params_framingham_deephit.yaml b/results/best_params/best_params_framingham_deephit.yaml index 92e10146..27f8d8e9 100644 --- a/results/best_params/best_params_framingham_deephit.yaml +++ b/results/best_params/best_params_framingham_deephit.yaml @@ -22,4 +22,4 @@ active_fn: tanh # General parameters dropout_rate: 0.3 seed: 42 -n_folds: 5 \ No newline at end of file +n_folds: 5 diff --git a/results/best_params/best_params_pbc_deephit.yaml b/results/best_params/best_params_pbc_deephit.yaml index 7fd8f3b2..a14f4765 100644 --- a/results/best_params/best_params_pbc_deephit.yaml +++ b/results/best_params/best_params_pbc_deephit.yaml @@ -23,4 +23,4 @@ active_fn: tanh # General parameters dropout_rate: 0.3 seed: 42 -n_folds: 5 \ No newline at end of file +n_folds: 5 diff --git a/results/best_params/best_params_support2_deephit.yaml b/results/best_params/best_params_support2_deephit.yaml index 36844915..4d1df68d 100644 --- a/results/best_params/best_params_support2_deephit.yaml +++ b/results/best_params/best_params_support2_deephit.yaml @@ -27,4 +27,4 @@ active_fn: tanh # General parameters dropout_rate: 0.3 seed: 42 -n_folds: 5 \ No newline at end of file +n_folds: 5 diff --git a/results/best_params/best_params_synthetic_deephit.yaml b/results/best_params/best_params_synthetic_deephit.yaml index f794156c..4b19279b 100644 --- a/results/best_params/best_params_synthetic_deephit.yaml +++ b/results/best_params/best_params_synthetic_deephit.yaml @@ -1,19 +1,18 @@ -scaling: standard -num_epochs: 250 -batch_size: 256 -learning_rate: 1.0e-3 -l2_reg: 1.0e-5 -patience: 10 -alpha: 1.0 -beta: 0.0 -gamma: 0.0 -h_dim_shared: 128 -h_dim_CS: 32 -num_layers_shared: 1 -num_layers_CS: 2 -num_categories: 100 -active_fn: tanh -dropout_rate: 0.3 -seed: 42 -n_folds: 5 - +scaling: standard +num_epochs: 250 +batch_size: 256 +learning_rate: 1.0e-3 +l2_reg: 1.0e-5 +patience: 10 +alpha: 1.0 +beta: 0.0 +gamma: 0.0 +h_dim_shared: 128 +h_dim_CS: 32 +num_layers_shared: 1 +num_layers_CS: 2 +num_categories: 100 +active_fn: tanh +dropout_rate: 0.3 +seed: 42 +n_folds: 5 diff --git a/results/logs/nested_cv_raw_metrics_framingham.json b/results/logs/nested_cv_raw_metrics_framingham.json index 5ac18164..330d31e7 100644 --- a/results/logs/nested_cv_raw_metrics_framingham.json +++ b/results/logs/nested_cv_raw_metrics_framingham.json @@ -138,4 +138,4 @@ "scaling": "standard", "seed": 42 } -} \ No newline at end of file +} diff --git a/results/logs/nested_cv_raw_metrics_pbc.json b/results/logs/nested_cv_raw_metrics_pbc.json index 4153a978..3c7dd293 100644 --- a/results/logs/nested_cv_raw_metrics_pbc.json +++ b/results/logs/nested_cv_raw_metrics_pbc.json @@ -138,4 +138,4 @@ "scaling": "standard", "seed": 42 } -} \ No newline at end of file +} diff --git a/results/logs/nested_cv_raw_metrics_support.json b/results/logs/nested_cv_raw_metrics_support.json index 4cbca613..e6bde10b 100644 --- a/results/logs/nested_cv_raw_metrics_support.json +++ b/results/logs/nested_cv_raw_metrics_support.json @@ -138,4 +138,4 @@ "scaling": "standard", "seed": 42 } -} \ No newline at end of file +} diff --git a/results/logs/nested_cv_raw_metrics_synthetic.json b/results/logs/nested_cv_raw_metrics_synthetic.json index d982097a..f8b5e73c 100644 --- a/results/logs/nested_cv_raw_metrics_synthetic.json +++ b/results/logs/nested_cv_raw_metrics_synthetic.json @@ -138,4 +138,4 @@ "scaling": "standard", "seed": 42 } -} \ No newline at end of file +} diff --git a/training.md b/training.md new file mode 100644 index 00000000..2b13b81d --- /dev/null +++ b/training.md @@ -0,0 +1,214 @@ +# Training Instructions + +This file contains details information about training models within the `crisp_nam` repository. + +## Repository Structure +The following is the structure for the repository. Files within the `crisp_nam` folder are available as a python package. + +``` +crisp_nam/ +├── blog/ # Blog +├── crisp_nam/ # Main package +│ ├── metrics/ +│ │ ├── __init__.py +│ │ ├── calibration.py +│ │ ├── discrimination.py +│ │ └── ipcw.py +│ ├── models/ +│ │ ├── __init__.py +│ │ ├── crisp_nam_model.py +│ │ └── deephit_model.py +│ ├── utils/ +│ │ ├── __init__.py +│ │ ├── loss.py +│ │ ├── plotting.py +│ │ └── risk_cif.py +│ └── __init__.py +├── data_utils/ # Data utilities +│ ├── __init__.py +│ ├── load_datasets.py +│ └── survival_datasets.py +├── datasets/ # Dataset files and loaders +│ ├── metabric/ +│ │ ├── cleaned_features_final.csv +│ │ └── label.csv +│ ├── framingham_dataset.py +│ ├── framingham.csv +│ ├── pbc_dataset.py +│ ├── pbc2.csv +│ ├── support_dataset.py +│ ├── support2.csv +│ ├── SurvivalDataset.py +│ ├── synthetic_comprisk.csv +│ └── synthetic_dataset.py +├── docs/ # Documentation of Pypi package. +├── results/ # Results and outputs +│ ├── best_params/ # Best parameters for dataset and model combinations +│ │ ├── best_params_framingham_deephit.yaml +│ │ ├── best_params_framingham.yaml +│ │ ├── best_params_pbc_deephit.yaml +│ │ ├── best_params_pbc.yaml +│ │ ├── best_params_support.yaml +│ │ ├── best_params_support2_deephit.yaml +│ │ ├── best_params_synthetic_deephit.yaml +│ │ └── best_params_synthetic.yaml +│ ├── logs/ # Nested CV results and logs +│ │ ├── nested_cv_best_params_*.yaml +│ │ ├── nested_cv_detailed_metrics_*.csv +│ │ ├── nested_cv_metrics_*.xlsx +│ │ ├── nested_cv_raw_metrics_*.json +│ │ └── nested_cv_summary_metrics_*.csv +│ └── plots/ # Generated plots +│ ├── nested_cv_feature_importance_risk_*_*.png +│ └── nested_cv_shape_functions_risk_*_*.png +├── training_scripts/ # Training scripts +│ ├── config.yaml +│ ├── model_utils.py +│ ├── train_deephit_cuda.py +│ ├── train_deephit.py +│ ├── train_nested_cv.py # Nested cross-validation script +│ ├── train.py +│ ├── tune_optuna_optimized.py +│ └── tune_optuna.py +``` + +### Available Datasets +This repository includes 4 datasets: Framingham Heart Study, PBC, Support2 and Synthetic datasets. Detailed information is available in [datasets.md](./datasets.md). + +## Training Scripts + +The repository provides several specialized training scripts: + +- **`train.py`**: Standard model training with cross-validation and comprehensive evaluation +- **`train_nested_cv.py`**: Robust nested cross-validation for unbiased performance estimation +- **`tune_optuna.py`**: Hyperparameter optimization using Optuna's advanced algorithms +- **`tune_optuna_optimized.py`**: Hyperparameter optimization using Optuna on a GPU. +- **`train_deephit.py`**: DeepHit baseline implementation for comparative studies +- **`train_deephit_cuda.py`**: DeepHit baseline implementation optimized for running on a GPU. + +Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps. + +### Running training scripts + +1. Modify training parameters in `training_scripts/train.py` + OR + Run either of following commands to see CLI arguments for passing training parameters: + + ```bash + python training_scripts/train.py --help + ``` + + ```bash + uv run training_scripts/train.py --help + ``` + +2. Run the training script + + 1. via `python` + ```bash + source .venv/bin/activate + python training_scripts/train.py --dataset framingham + ``` + + 2. via `uv` + ```bash + uv run training_scripts/train.py --dataset framingham + ``` + +### Running Nested Cross-Validation + +The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations. + +via `python` +```bash +python training_scripts/train_nested_cv.py --dataset framingham +``` + +via `uv` +```bash +uv run training_scripts/train_nested_cv.py --dataset framingham +``` + +## Configuration Parameters + +All parameters can be passed via command line or specified in a YAML config file: + +1. Dataset Configuration +- `--dataset` (str): Dataset to use (choices: `framingham`, `support`, `pbc`, `synthetic`, default: `framingham`) +- `--scaling` (str): Data scaling method for continuous features (choices: `minmax`, `standard`, `none`, default: `standard`) + +2. Training Parameters +- `--num_epochs` (int): Number of training epochs (default: `250`) +- `--batch_size` (int): Batch size for training (default: `512`) +- `--patience` (int): Patience for early stopping (default: `10`) + +3. Cross-Validation Configuration +- `--outer_folds` (int): Number of outer CV folds (default: `5`) +- `--inner_folds` (int): Number of inner CV folds for hyperparameter tuning (default: `3`) +- `--n_trials` (int): Number of Optuna trials per inner fold (default: `20`) + +4. Event Weighting +- `--event_weighting` (str): Event weighting strategy (choices: `none`, `balanced`, `custom`, default: `none`) +- `--custom_event_weights` (str): Custom weights for events (comma-separated, default: `None`) + +5. Other Parameters +- `--seed` (int): Random seed for reproducibility (default: `42`) +- `--config` (str): Path to YAML config file (default: looks for `config.yaml`) + +### Examples + +1. **Basic nested CV with default parameters:** +```bash +python training_scripts/train_nested_cv.py --dataset pbc +``` + +2. **Customized nested CV with specific parameters:** +```bash +python training_scripts/train_nested_cv.py \ + --dataset support \ + --outer_folds 10 \ + --inner_folds 5 \ + --n_trials 50 \ + --num_epochs 500 \ + --event_weighting balanced \ + --scaling minmax \ + --seed 123 +``` + +3. **Using a config file:** +```bash +python training_scripts/train_nested_cv.py --config my_config.yaml +``` + +## Output Files + +The script generates several output files in the current directory: + +1. **Performance Metrics** +- `nested_cv_summary_metrics_{dataset}.csv`: Summary table with mean ± std metrics +- `nested_cv_detailed_metrics_{dataset}.csv`: Detailed results for each fold +- `nested_cv_metrics_{dataset}.xlsx`: Excel file with multiple sheets (Summary, Detailed, Metadata) +- `nested_cv_raw_metrics_{dataset}.json`: Raw metrics dictionary for reproducibility + +2. **Model Configuration** +- `nested_cv_best_params_{dataset}.yaml`: Aggregated best hyperparameters across all folds + +3. **Visualizations** +- `nested_cv_feature_importance_risk_{risk}_{dataset}.png`: Feature importance plots +- `nested_cv_shape_functions_risk_{risk}_{dataset}.png`: Shape function plots for top features + +Results are saved to `results/plots/`: + +## Evaluation Metrics + +The script computes the following metrics at different time quantiles (25%, 50%, 75%): + +1. **AUC (Area Under the ROC Curve)**: Time-dependent AUC for discrimination + - 0.5 = random, >0.7 = good, >0.8 = excellent +2. **TDCI (Time-Dependent Concordance Index)**: Harrell's C-index adapted for competing risks + - 0.5 = random, >0.7 = good, >0.8 = excellent +3. **Brier Score**: Calibration metric measuring prediction accuracy + - 0 = perfect, <0.25 = good, >0.25 = poor + +> [!NOTE] +> For `uv` installation, please visit follow instructions in their [official page](https://docs.astral.sh/uv/getting-started/installation/). \ No newline at end of file diff --git a/training_scripts/config.yaml b/training_scripts/config.yaml index d361646f..5fd9e691 100644 --- a/training_scripts/config.yaml +++ b/training_scripts/config.yaml @@ -17,4 +17,4 @@ batch_norm: False # Other parameters seed: 42 -n_folds: 2 \ No newline at end of file +n_folds: 2 diff --git a/training_scripts/model_utils.py b/training_scripts/model_utils.py index 98be4143..812f0e62 100644 --- a/training_scripts/model_utils.py +++ b/training_scripts/model_utils.py @@ -1,7 +1,8 @@ import random -import torch import numpy as np +import torch + def set_seed(seed=42): random.seed(seed) @@ -31,34 +32,35 @@ def step(self, val_loss): self.should_stop = True - # Utility functions to create masks for DeepHit def create_fc_mask1(k, t, num_Event, num_Category, device=None): """Create mask1 for loss calculation - for uncensored loss""" N = len(k) mask = torch.zeros((N, num_Event, num_Category), device=device) - + for i in range(N): if k[i] > 0: # Not censored event_idx = int(k[i] - 1) time_idx = int(t[i]) if time_idx < num_Category: mask[i, event_idx, time_idx] = 1.0 - + return mask + def create_fc_mask2(t, num_Category, device=None): """Create mask2 for loss calculation - for censored loss""" N = len(t) mask = torch.zeros((N, num_Category), device=device) - + for i in range(N): time_idx = int(t[i]) for j in range(time_idx, num_Category): mask[i, j] = 1.0 - + return mask + # Pre-create masks for DeepHit on GPU def create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device): """ @@ -67,15 +69,16 @@ def create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device): """ batch_size = e.size(0) mask1 = torch.zeros(batch_size, num_Event, num_Category, device=device) - + for i in range(batch_size): if e[i] > 0: # if not censored event_idx = int(e[i].item()) - 1 t_idx = int(t_disc[i].item()) mask1[i, event_idx, t_idx] = 1 - + return mask1 + def create_fc_mask2_gpu(t_disc, num_Category, device): """ Create second mask for DeepHit loss computation @@ -83,9 +86,9 @@ def create_fc_mask2_gpu(t_disc, num_Category, device): """ batch_size = t_disc.size(0) mask2 = torch.zeros(batch_size, num_Category, device=device) - + for i in range(batch_size): t_idx = int(t_disc[i].item()) mask2[i, t_idx:] = 1 - + return mask2 diff --git a/training_scripts/train.py b/training_scripts/train.py index 2192a6c3..34d626c6 100644 --- a/training_scripts/train.py +++ b/training_scripts/train.py @@ -1,79 +1,116 @@ -import configargparse from collections import defaultdict -import torch +import configargparse +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch.nn as nn -import torch.optim as optim -import matplotlib.pyplot as plt -from tabulate import tabulate -from torch.utils.data import DataLoader, Subset +import torch +from model_utils import EarlyStopping, set_seed from sklearn.model_selection import StratifiedKFold -from sklearn.preprocessing import StandardScaler, MinMaxScaler -from sksurv.util import Surv +from sklearn.preprocessing import MinMaxScaler, StandardScaler from sksurv.metrics import concordance_index_ipcw +from sksurv.util import Surv +from tabulate import tabulate +from torch import optim +from torch.utils.data import DataLoader +from crisp_nam.metrics import auc_td, brier_score from crisp_nam.models import CrispNamModel from crisp_nam.utils import ( - weighted_negative_log_likelihood_loss, + compute_baseline_cif, + compute_l2_penalty, negative_log_likelihood_loss, - compute_l2_penalty + plot_coxnam_shape_functions, + plot_feature_importance, + predict_absolute_risk, + weighted_negative_log_likelihood_loss, ) from data_utils import * -from model_utils import EarlyStopping, set_seed -from crisp_nam.metrics import brier_score, auc_td -from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif -from crisp_nam.utils import plot_coxnam_shape_functions, plot_feature_importance + def parse_args(): parser = configargparse.ArgumentParser( description="Training script for MultiTaskCoxNAM model", default_config_files=["config.yml"], - config_file_parser_class=configargparse.YAMLConfigFileParser + config_file_parser_class=configargparse.YAMLConfigFileParser, ) - - parser.add_argument("-c", "--config", is_config_file=True, - help="Path to config file") - - parser.add_argument("--dataset", type=str, default="framingham", - help="Dataset to use: (framingham, support, pbc, synthetic)") - - parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"], - help="Data scaling method for continuous features") - - parser.add_argument("--num_epochs", type=int, default=500, - help="Number of training epochs") - parser.add_argument("--batch_size", type=int, default=256, - help="Batch size for training") - parser.add_argument("--learning_rate", type=float, default=1e-3, - help="Learning rate for optimizer") - parser.add_argument("--l2_reg", type=float, default=1e-3, - help="L2 regularization weight") - parser.add_argument("--patience", type=int, default=10, - help="Patience for early stopping") - - parser.add_argument("--dropout_rate", type=float, default=0.5, - help="Dropout rate for model") - parser.add_argument("--feature_dropout", type=float, default=0.1, - help="Feature dropout rate") - parser.add_argument("--hidden_dimensions", type=str, default="64,64", - help="Hidden layer dimensions (comma-separated)") - parser.add_argument("--batch_norm", type=str, default="False", choices=["True", "False"], - help="Whether to use batch normalization") - - parser.add_argument("--seed", type=int, default=42, - help="Random seed for reproducibility") - parser.add_argument("--n_folds", type=int, default=5, - help="Number of folds for cross-validation") - - return parser.parse_args() + parser.add_argument( + "-c", "--config", is_config_file=True, help="Path to config file" + ) + parser.add_argument( + "--dataset", + type=str, + default="framingham", + help="Dataset to use: (framingham, support, pbc, synthetic)", + ) + + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + help="Data scaling method for continuous features", + ) + + parser.add_argument( + "--num_epochs", type=int, default=500, help="Number of training epochs" + ) + parser.add_argument( + "--batch_size", type=int, default=256, help="Batch size for training" + ) + parser.add_argument( + "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer" + ) + parser.add_argument( + "--l2_reg", type=float, default=1e-3, help="L2 regularization weight" + ) + parser.add_argument( + "--patience", type=int, default=10, help="Patience for early stopping" + ) + + parser.add_argument( + "--dropout_rate", type=float, default=0.5, help="Dropout rate for model" + ) + parser.add_argument( + "--feature_dropout", type=float, default=0.1, help="Feature dropout rate" + ) + parser.add_argument( + "--hidden_dimensions", + type=str, + default="64,64", + help="Hidden layer dimensions (comma-separated)", + ) + parser.add_argument( + "--batch_norm", + type=str, + default="False", + choices=["True", "False"], + help="Whether to use batch normalization", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + parser.add_argument( + "--n_folds", type=int, default=5, help="Number of folds for cross-validation" + ) -def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_rate=1e-3, - l2_reg=0.01, patience=10, event_weights=None, verbose=True): + return parser.parse_args() + + +def train_model( + model, + train_loader, + val_loader=None, + num_epochs=500, + learning_rate=1e-3, + l2_reg=0.01, + patience=10, + event_weights=None, + verbose=True, +): optimizer = optim.AdamW(model.parameters(), lr=learning_rate) early_stopper = EarlyStopping(patience=patience) device = next(model.parameters()).device @@ -87,12 +124,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r # Use weighted loss if event_weights is provided if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg total = loss + reg @@ -110,22 +153,30 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r for x, t, e, _ in val_loader: x, t, e = x.to(device), t.to(device), e.to(device) risk_scores, _ = model(x) - + # Use same loss function as in training if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg val_loss += (loss + reg).item() avg_val_loss = val_loss / len(val_loader) early_stopper.step(avg_val_loss) if verbose: - print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}") + print( + f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}" + ) if early_stopper.should_stop: if verbose: @@ -149,10 +200,10 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time abs_risks: Array of shape (n_samples, n_events, n_times) with predicted absolute risks. times: List of time points at which to evaluate. - Returns: + Returns + ------- metrics: dict of evaluation metrics by event and time. """ - survival_train = Surv.from_arrays(e_train != 0, t_train) survival_val = Surv.from_arrays(e_val != 0, t_val) @@ -162,12 +213,11 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time for k in range(n_events): for i, time in enumerate(times): risk_preds = abs_risks[:, k, i] - - + try: risk_preds_2d = np.zeros((len(risk_preds), len(times))) - risk_preds_2d[:, i] = risk_preds - + risk_preds_2d[:, i] = risk_preds + auc_score, _ = auc_td( e_val, t_val, @@ -175,16 +225,14 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, ) - metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score)) + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score)) except Exception as ex: - print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan) - try: - brier_score_val, _ = brier_score( e_val, t_val, @@ -192,30 +240,28 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, + ) + metrics[f"brier_event{k + 1}_t{time:.2f}"].append( + float(brier_score_val) ) - metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val)) except Exception as ex: - print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}") + print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}") print(f"Debug: surv_probs shape={risk_preds_2d.shape}") - metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan) + metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan) - try: tdci_result = concordance_index_ipcw( - survival_train, - survival_val, - estimate=risk_preds, - tau=time + survival_train, survival_val, estimate=risk_preds, tau=time ) if isinstance(tdci_result, tuple): tdci_score = tdci_result[0] else: - tdci_score = tdci_result - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score)) + tdci_score = tdci_result + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score)) except Exception as ex: - print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan) return metrics @@ -224,88 +270,88 @@ def display_metrics_table(metrics_dict, n_folds=5, quantiles=[0.25, 0.5, 0.75]): """ Display evaluation metrics summarized across folds for different time quantiles """ - - time_points_by_metric = {} event_types = set() - metric_types = ['auc', 'tdci', 'brier'] - + metric_types = ["auc", "tdci", "brier"] + for key, values in metrics_dict.items(): if any(metric in key for metric in metric_types): - parts = key.split('_') + parts = key.split("_") metric_type = parts[0] event_info = parts[1] - + # Extract event type - event_type = int(event_info.replace('event', '')) + event_type = int(event_info.replace("event", "")) event_types.add(event_type) - + # Extract time point - if len(parts) > 2 and parts[2].startswith('t'): - time_point = float(parts[2].replace('t', '')) - + if len(parts) > 2 and parts[2].startswith("t"): + time_point = float(parts[2].replace("t", "")) + # Initialize nested dictionaries if needed if (event_type, metric_type) not in time_points_by_metric: time_points_by_metric[(event_type, metric_type)] = {} - + # Store metrics by time point time_points_by_metric[(event_type, metric_type)][time_point] = values - + results = [] - + for event_type in sorted(event_types): - row = {'Risk': f"Type {event_type}"} - + row = {"Risk": f"Type {event_type}"} + for metric_type in metric_types: if (event_type, metric_type) not in time_points_by_metric: # Skip metrics that don't exist for this event type for q in quantiles: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" continue - + # Get all time points for this event/metric time_data = time_points_by_metric[(event_type, metric_type)] sorted_times = sorted(time_data.keys()) - + # For each quantile for q in quantiles: # Calculate the index for this quantile q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q))) - + # Get the corresponding time point for this quantile q_time = sorted_times[q_idx] - + # Get the metrics for this time point q_values = time_data[q_time] - + # Calculate and format statistics if q_values: value_array = np.array(q_values) mean_val = np.nanmean(value_array) std_val = np.nanstd(value_array) - row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}" + row[f"{metric_type.upper()}_q{q:.2f}"] = ( + f"{mean_val:.3f} ± {std_val:.3f}" + ) else: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" - + results.append(row) - + df = pd.DataFrame(results) - + # Define column order - columns = ['Risk'] - for metric in ['AUC', 'TDCI', 'BRIER']: + columns = ["Risk"] + for metric in ["AUC", "TDCI", "BRIER"]: for q in quantiles: columns.append(f"{metric}_q{q:.2f}") - + # Select columns in the right order (only those that exist) df = df[[col for col in columns if col in df.columns]] - + print("\nSummary Performance Metrics:") - print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False)) - + print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False)) + print("\nInterpretation:") print("- AUC: 0.5=random, >0.7=good, >0.8=excellent") - print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") + print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor") @@ -314,64 +360,104 @@ def parse_args(): parser = configargparse.ArgumentParser( description="Training script for MultiTaskCoxNAM model", default_config_files=["config.yaml"], - config_file_parser_class=configargparse.YAMLConfigFileParser + config_file_parser_class=configargparse.YAMLConfigFileParser, ) - + # Config file option - parser.add_argument("-c", "--config", is_config_file=True, - help="Path to config file") - + parser.add_argument( + "-c", "--config", is_config_file=True, help="Path to config file" + ) + # Dataset - parser.add_argument("--dataset", type=str, default="framingham", choices=["framingham", "support", "pbc", "synthetic"], - help="Dataset to use") - + parser.add_argument( + "--dataset", + type=str, + default="framingham", + choices=["framingham", "support", "pbc", "synthetic"], + help="Dataset to use", + ) + # Data scaling - parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"], - help="Data scaling method for continuous features") - + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + help="Data scaling method for continuous features", + ) + # Training parameters - parser.add_argument("--num_epochs", type=int, default=500, - help="Number of training epochs") - parser.add_argument("--batch_size", type=int, default=256, - help="Batch size for training") - parser.add_argument("--learning_rate", type=float, default=1e-3, - help="Learning rate for optimizer") - parser.add_argument("--l2_reg", type=float, default=1e-3, - help="L2 regularization weight") - parser.add_argument("--patience", type=int, default=10, - help="Patience for early stopping") - - - parser.add_argument("--dropout_rate", type=float, default=0.5, - help="Dropout rate for model") - parser.add_argument("--feature_dropout", type=float, default=0.1, - help="Feature dropout rate") - parser.add_argument("--hidden_dimensions", type=str, default="64,64", - help="Hidden layer dimensions (comma-separated)") - parser.add_argument("--batch_norm", type=str, default="False", choices=["True", "False"], - help="Whether to use batch normalization") - - - parser.add_argument("--event_weighting", type=str, default="none", - choices=["none", "balanced", "custom"], - help="Event weighting strategy (none, balanced, custom)") - parser.add_argument("--custom_event_weights", type=str, default=None, - help="Custom weights for events (comma-separated, e.g., '1.0,2.0')") - - # - parser.add_argument("--seed", type=int, default=42, - help="Random seed for reproducibility") - parser.add_argument("--n_folds", type=int, default=5, - help="Number of folds for cross-validation") - + parser.add_argument( + "--num_epochs", type=int, default=500, help="Number of training epochs" + ) + parser.add_argument( + "--batch_size", type=int, default=256, help="Batch size for training" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-3, + help="Learning rate for optimizer", + ) + parser.add_argument( + "--l2_reg", type=float, default=1e-3, help="L2 regularization weight" + ) + parser.add_argument( + "--patience", type=int, default=10, help="Patience for early stopping" + ) + + parser.add_argument( + "--dropout_rate", type=float, default=0.5, help="Dropout rate for model" + ) + parser.add_argument( + "--feature_dropout", type=float, default=0.1, help="Feature dropout rate" + ) + parser.add_argument( + "--hidden_dimensions", + type=str, + default="64,64", + help="Hidden layer dimensions (comma-separated)", + ) + parser.add_argument( + "--batch_norm", + type=str, + default="False", + choices=["True", "False"], + help="Whether to use batch normalization", + ) + + parser.add_argument( + "--event_weighting", + type=str, + default="none", + choices=["none", "balanced", "custom"], + help="Event weighting strategy (none, balanced, custom)", + ) + parser.add_argument( + "--custom_event_weights", + type=str, + default=None, + help="Custom weights for events (comma-separated, e.g., '1.0,2.0')", + ) + + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + parser.add_argument( + "--n_folds", + type=int, + default=5, + help="Number of folds for cross-validation", + ) + return parser.parse_args() - + args = parse_args() print(args) - + # Set random seed for reproducibility set_seed(args.seed) - + # Load the dataset if args.dataset.lower() == "framingham": x, t, e, feature_names, n_cont, _ = load_framingham() @@ -383,62 +469,62 @@ def parse_args(): x, t, e, feature_names, n_cont, _ = load_synthetic_dataset() else: raise ValueError(f"Dataset {args.dataset} not supported") - - + # Note: Scaling will be done inside cross-validation loop to prevent data leakage # Compute weights based on event distribution num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0) - device = ("cuda" if torch.cuda.is_available() else "cpu") - - + device = "cuda" if torch.cuda.is_available() else "cpu" + event_weights = None - + if args.event_weighting != "none": if args.event_weighting == "balanced": # Compute balanced weights (inverse of class frequencies) event_counts = np.zeros(num_competing_risks) for k in range(1, num_competing_risks + 1): - event_counts[k-1] = np.sum(e == k) - + event_counts[k - 1] = np.sum(e == k) event_counts = np.maximum(event_counts, 1) - + # Inverse frequency weighting event_weights = 1.0 / event_counts - + # Normalize weights to sum to num_competing_risks event_weights = event_weights * (num_competing_risks / event_weights.sum()) - + print(f"Computed balanced event weights: {event_weights}") - + elif args.event_weighting == "custom": if args.custom_event_weights is None: - raise ValueError("Custom event weights must be provided when using custom weighting") - + raise ValueError( + "Custom event weights must be provided when using custom weighting" + ) + custom_weights = [float(w) for w in args.custom_event_weights.split(",")] if len(custom_weights) != num_competing_risks: - raise ValueError(f"Expected {num_competing_risks} weights, got {len(custom_weights)}") - + raise ValueError( + f"Expected {num_competing_risks} weights, got {len(custom_weights)}" + ) + event_weights = np.array(custom_weights) print(f"Using custom event weights: {event_weights}") - + event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device) skf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed) all_metrics = defaultdict(list) quantiles = [0.25, 0.5, 0.75] - safe_max = 0.99 * np.max(t) eval_times = np.quantile(t[t <= safe_max], quantiles) - + print(f"Evaluation times: {eval_times}") - + # Parse hidden dimensions string to list of integers hidden_dimensions = [int(dim) for dim in args.hidden_dimensions.split(",")] - + # Convert batch_norm string to boolean batch_norm = args.batch_norm.lower() == "true" @@ -470,25 +556,31 @@ def parse_args(): # Recalculate weights based on training set for this fold fold_event_counts = np.zeros(num_competing_risks) for k in range(1, num_competing_risks + 1): - fold_event_counts[k-1] = np.sum(e_train == k) - + fold_event_counts[k - 1] = np.sum(e_train == k) + # Avoid division by zero fold_event_counts = np.maximum(fold_event_counts, 1) - + # Inverse frequency weighting fold_event_weights = 1.0 / fold_event_counts - + # Normalize weights to sum to num_competing_risks - fold_event_weights = fold_event_weights * (num_competing_risks / fold_event_weights.sum()) - fold_event_weights = torch.tensor(fold_event_weights, dtype=torch.float32, device=device) - - print(f"Fold {fold+1} event weights: {fold_event_weights.cpu().numpy()}") + fold_event_weights = fold_event_weights * ( + num_competing_risks / fold_event_weights.sum() + ) + fold_event_weights = torch.tensor( + fold_event_weights, dtype=torch.float32, device=device + ) + + print(f"Fold {fold + 1} event weights: {fold_event_weights.cpu().numpy()}") # Create datasets with properly normalized data train_dataset = SurvivalDataset(x_train, t_train, e_train) val_dataset = SurvivalDataset(x_val, t_val, e_val) - - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) + + train_loader = DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) val_loader = DataLoader(val_dataset, batch_size=args.batch_size) model = CrispNamModel( @@ -497,67 +589,72 @@ def parse_args(): hidden_sizes=hidden_dimensions, dropout_rate=args.dropout_rate, feature_dropout=args.feature_dropout, - batch_norm=batch_norm + batch_norm=batch_norm, ).to(device) - - train_model(model, train_loader, val_loader, - num_epochs=args.num_epochs, - learning_rate=args.learning_rate, - l2_reg=args.l2_reg, - patience=args.patience, - event_weights=fold_event_weights, - verbose=True) - + train_model( + model, + train_loader, + val_loader, + num_epochs=args.num_epochs, + learning_rate=args.learning_rate, + l2_reg=args.l2_reg, + patience=args.patience, + event_weights=fold_event_weights, + verbose=True, + ) + # Calculate baseline CIFs using the same eval times for all folds - baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) - for k in range(num_competing_risks)} - - abs_risks = predict_absolute_risk(model, x_val, baseline_cifs, eval_times, device=device) - - - fold_metrics = evaluate_model(model, x_val, t_val, e_val, - t_train, e_train, abs_risks, eval_times) - + baseline_cifs = { + k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) + for k in range(num_competing_risks) + } + + abs_risks = predict_absolute_risk( + model, x_val, baseline_cifs, eval_times, device=device + ) + + fold_metrics = evaluate_model( + model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times + ) + for k, v in fold_metrics.items(): all_metrics[k].extend(v) - display_metrics_table(all_metrics, n_folds=args.n_folds) - - #create figs subdirectory if not present + + # create figs subdirectory if not present import os + if not os.path.exists("figs"): os.makedirs("figs") - for risk in range(1, num_competing_risks + 1): - fig, _, top_positive, top_negative = plot_feature_importance( model=model, x_data=x, feature_names=feature_names, n_top=5, # Show top 5 positive contributors n_bottom=5, # Show top 5 negative contributors - risk_idx=risk, + risk_idx=risk, figsize=(6, 4), - output_file=f"figs/feature_importance_risk_new_{risk}_{args.dataset}.png" - ) - + output_file=f"figs/feature_importance_risk_new_{risk}_{args.dataset}.png", + ) + top_features = top_positive + top_negative - + fig, axes = plot_coxnam_shape_functions( model=model, - X=x, + X=x, risk_to_plot=risk, - feature_names=feature_names, - top_features=top_features, + feature_names=feature_names, + top_features=top_features, ncols=5, - figsize=(12,6), - output_file=f"figs/shape_functions_top_features_risk{risk}_{args.dataset}.png" + figsize=(12, 6), + output_file=f"figs/shape_functions_top_features_risk{risk}_{args.dataset}.png", ) plt.close(fig) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/training_scripts/train_deephit_cuda.py b/training_scripts/train_deephit_cuda.py index d670acb5..b861d7a3 100644 --- a/training_scripts/train_deephit_cuda.py +++ b/training_scripts/train_deephit_cuda.py @@ -1,96 +1,165 @@ import os -import configargparse from collections import defaultdict -import torch +import configargparse import numpy as np -import torch.optim as optim -from sksurv.util import Surv -from torch.utils.data import DataLoader, Subset -from sklearn.model_selection import StratifiedKFold -from sklearn.preprocessing import StandardScaler, MinMaxScaler -from sksurv.metrics import concordance_index_ipcw - -# Import the DeepHit model implementation -from data_utils import * +import torch from model_utils import ( - set_seed, EarlyStopping, create_fc_mask1_gpu, - create_fc_mask2_gpu + create_fc_mask2_gpu, + set_seed, ) +from sklearn.model_selection import StratifiedKFold +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sksurv.metrics import concordance_index_ipcw +from sksurv.util import Surv +from torch import optim +from torch.utils.data import DataLoader, Subset + +from crisp_nam.metrics import auc_td, brier_score from crisp_nam.models import DeepHit -from crisp_nam.metrics import brier_score, auc_td + +# Import the DeepHit model implementation +from data_utils import * + def parse_args(): parser = configargparse.ArgumentParser( description="Training script for DeepHit model", default_config_files=["config.yml"], - config_file_parser_class=configargparse.YAMLConfigFileParser + config_file_parser_class=configargparse.YAMLConfigFileParser, ) - - parser.add_argument("-c", "--config", is_config_file=True, - help="Path to config file") - + + parser.add_argument( + "-c", "--config", is_config_file=True, help="Path to config file" + ) + # Dataset - parser.add_argument("--dataset", type=str, default="framingham", - help="Dataset to use: (framingham, support, pbc, synthetic)") - + parser.add_argument( + "--dataset", + type=str, + default="framingham", + help="Dataset to use: (framingham, support, pbc, synthetic)", + ) + # Data scaling - parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"], - help="Data scaling method for continuous features") - + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + help="Data scaling method for continuous features", + ) + # Training parameters - parser.add_argument("--num_epochs", type=int, default=500, - help="Number of training epochs") - parser.add_argument("--batch_size", type=int, default=512, - help="Batch size for training") - parser.add_argument("--learning_rate", type=float, default=1e-3, - help="Learning rate for optimizer") - parser.add_argument("--l2_reg", type=float, default=1e-4, - help="L2 regularization weight") - parser.add_argument("--patience", type=int, default=10, - help="Patience for early stopping") - + parser.add_argument( + "--num_epochs", type=int, default=500, help="Number of training epochs" + ) + parser.add_argument( + "--batch_size", type=int, default=512, help="Batch size for training" + ) + parser.add_argument( + "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer" + ) + parser.add_argument( + "--l2_reg", type=float, default=1e-4, help="L2 regularization weight" + ) + parser.add_argument( + "--patience", type=int, default=10, help="Patience for early stopping" + ) + # DeepHit specific parameters - parser.add_argument("--alpha", type=float, default=1.0, - help="Weight for log-likelihood loss") - parser.add_argument("--beta", type=float, default=0.5, - help="Weight for ranking loss") - parser.add_argument("--gamma", type=float, default=0.5, - help="Weight for calibration loss") - parser.add_argument("--h_dim_shared", type=int, default=64, - help="Hidden dimension for shared network") - parser.add_argument("--h_dim_CS", type=int, default=16, - help="Hidden dimension for cause-specific networks") - parser.add_argument("--num_layers_shared", type=int, default=2, - help="Number of layers in shared network") - parser.add_argument("--num_layers_CS", type=int, default=2, - help="Number of layers in cause-specific networks") - parser.add_argument("--num_categories", type=int, default=100, - help="Number of time categories for discretization") - parser.add_argument("--active_fn", type=str, default="relu", choices=["relu", "elu", "tanh"], - help="Activation function") - + parser.add_argument( + "--alpha", type=float, default=1.0, help="Weight for log-likelihood loss" + ) + parser.add_argument( + "--beta", type=float, default=0.5, help="Weight for ranking loss" + ) + parser.add_argument( + "--gamma", type=float, default=0.5, help="Weight for calibration loss" + ) + parser.add_argument( + "--h_dim_shared", + type=int, + default=64, + help="Hidden dimension for shared network", + ) + parser.add_argument( + "--h_dim_CS", + type=int, + default=16, + help="Hidden dimension for cause-specific networks", + ) + parser.add_argument( + "--num_layers_shared", + type=int, + default=2, + help="Number of layers in shared network", + ) + parser.add_argument( + "--num_layers_CS", + type=int, + default=2, + help="Number of layers in cause-specific networks", + ) + parser.add_argument( + "--num_categories", + type=int, + default=100, + help="Number of time categories for discretization", + ) + parser.add_argument( + "--active_fn", + type=str, + default="relu", + choices=["relu", "elu", "tanh"], + help="Activation function", + ) + # General parameters - parser.add_argument("--dropout_rate", type=float, default=0.2, - help="Dropout rate for model") - parser.add_argument("--seed", type=int, default=42, - help="Random seed for reproducibility") - parser.add_argument("--n_folds", type=int, default=5, - help="Number of folds for cross-validation") - parser.add_argument("--num_workers", type=int, default=8, - help="Number of workers for data loading") - parser.add_argument("--eval_freq", type=int, default=10, - help="Evaluate every N epochs during training") - parser.add_argument("--use_amp", action="store_true", - help="Use Automatic Mixed Precision for training") - + parser.add_argument( + "--dropout_rate", type=float, default=0.2, help="Dropout rate for model" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + parser.add_argument( + "--n_folds", type=int, default=5, help="Number of folds for cross-validation" + ) + parser.add_argument( + "--num_workers", type=int, default=8, help="Number of workers for data loading" + ) + parser.add_argument( + "--eval_freq", + type=int, + default=10, + help="Evaluate every N epochs during training", + ) + parser.add_argument( + "--use_amp", + action="store_true", + help="Use Automatic Mixed Precision for training", + ) + return parser.parse_args() -def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.0, gamma=1.0, - num_epochs=500, learning_rate=1e-3, l2_reg=0.01, patience=10, - eval_freq=10, use_amp=False, verbose=True): + +def train_deephit_model( + model, + train_loader, + val_loader=None, + alpha=1.0, + beta=1.0, + gamma=1.0, + num_epochs=500, + learning_rate=1e-3, + l2_reg=0.01, + patience=10, + eval_freq=10, + use_amp=False, + verbose=True, +): """ Train the DeepHit model using the three loss components Optimized with AMP support, prefetching, and CUDA streams @@ -98,35 +167,35 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1. device = next(model.parameters()).device optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_reg) early_stopper = EarlyStopping(patience=patience) - + # Get model dimensions for mask creation num_Event = model.num_Event num_Category = model.num_Category - + # Initialize scaler for mixed precision training scaler = torch.cuda.amp.GradScaler() if use_amp else None - + # Enable asynchronous CUDA execution torch.backends.cudnn.benchmark = True - + # Increase batch size if it's too small (optional) # if train_loader.batch_size < 512 and device == 'cuda': # print("Warning: Small batch size. Consider increasing for better GPU utilization.") - + # Prefetch data using CUDA streams for async data loading - prefetch_stream = torch.cuda.Stream() if device == 'cuda' else None - + prefetch_stream = torch.cuda.Stream() if device == "cuda" else None + # Precomputed masks for common events (optimization) precomputed_masks2 = {} for t_val in range(num_Category): mask = torch.zeros(1, num_Category, device=device) mask[0, t_val:] = 1 precomputed_masks2[t_val] = mask - + for epoch in range(num_epochs): model.train() total_loss = 0.0 - + # Prefetch first batch if prefetch_stream is not None: try: @@ -135,34 +204,54 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1. prefetch_batch = next(batch_iter) # Prefetch to GPU asynchronously with torch.cuda.stream(prefetch_stream): - prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [t.to(device, non_blocking=True) for t in prefetch_batch] + prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [ + t.to(device, non_blocking=True) for t in prefetch_batch + ] # Precompute masks - prefetch_mask1 = create_fc_mask1_gpu(prefetch_e, prefetch_t_disc, num_Event, num_Category, device) - prefetch_mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in prefetch_t_disc], dim=0) + prefetch_mask1 = create_fc_mask1_gpu( + prefetch_e, prefetch_t_disc, num_Event, num_Category, device + ) + prefetch_mask2 = torch.cat( + [ + precomputed_masks2[int(t_val.item())] + for t_val in prefetch_t_disc + ], + dim=0, + ) except StopIteration: prefetch_batch = None else: prefetch_batch = None - + # Main training loop with prefetching batch_iter = iter(train_loader) more_batches = True - + while more_batches: - # Synchronize with prefetched data + # Synchronize with prefetched data if prefetch_stream is not None and prefetch_batch is not None: torch.cuda.current_stream().wait_stream(prefetch_stream) x, t, e, t_disc = prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc mask1, mask2 = prefetch_mask1, prefetch_mask2 - + # Prefetch next batch try: prefetch_batch = next(batch_iter) with torch.cuda.stream(prefetch_stream): - prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [t.to(device, non_blocking=True) for t in prefetch_batch] + prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [ + t.to(device, non_blocking=True) for t in prefetch_batch + ] # Precompute masks - prefetch_mask1 = create_fc_mask1_gpu(prefetch_e, prefetch_t_disc, num_Event, num_Category, device) - prefetch_mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in prefetch_t_disc], dim=0) + prefetch_mask1 = create_fc_mask1_gpu( + prefetch_e, prefetch_t_disc, num_Event, num_Category, device + ) + prefetch_mask2 = torch.cat( + [ + precomputed_masks2[int(t_val.item())] + for t_val in prefetch_t_disc + ], + dim=0, + ) except StopIteration: prefetch_batch = None more_batches = False @@ -172,20 +261,27 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1. batch = next(batch_iter) x, t, e, t_disc = [t.to(device, non_blocking=True) for t in batch] # Create masks for loss computation - mask1 = create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device) - mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in t_disc], dim=0) + mask1 = create_fc_mask1_gpu( + e, t_disc, num_Event, num_Category, device + ) + mask2 = torch.cat( + [precomputed_masks2[int(t_val.item())] for t_val in t_disc], + dim=0, + ) except StopIteration: more_batches = False continue - + # Automatic mixed precision if use_amp: with torch.cuda.amp.autocast(): # Forward pass out, _ = model(x) # Compute loss components - loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma) - + loss = model.compute_loss( + out, t, e, mask1, mask2, alpha, beta, gamma + ) + # Optimization step with gradient scaling optimizer.zero_grad(set_to_none=True) # More efficient than zero_grad() scaler.scale(loss).backward() @@ -196,42 +292,55 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1. out, _ = model(x) # Compute loss components loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma) - + # Regular optimization step optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() - + total_loss += loss.item() - + avg_train_loss = total_loss / len(train_loader) - + # Validation if provided, but only every eval_freq epochs to save time if val_loader and (epoch % eval_freq == 0 or epoch == num_epochs - 1): model.eval() val_loss = 0.0 with torch.no_grad(): for x, t, e, t_disc in val_loader: - x, t, e, t_disc = x.to(device, non_blocking=True), t.to(device, non_blocking=True), \ - e.to(device, non_blocking=True), t_disc.to(device, non_blocking=True) - + x, t, e, t_disc = ( + x.to(device, non_blocking=True), + t.to(device, non_blocking=True), + e.to(device, non_blocking=True), + t_disc.to(device, non_blocking=True), + ) + # Create masks for validation (using precomputed masks) - mask1 = create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device) - mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in t_disc], dim=0) - + mask1 = create_fc_mask1_gpu( + e, t_disc, num_Event, num_Category, device + ) + mask2 = torch.cat( + [precomputed_masks2[int(t_val.item())] for t_val in t_disc], + dim=0, + ) + # Forward pass out, _ = model(x) - + # Compute loss - loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma) + loss = model.compute_loss( + out, t, e, mask1, mask2, alpha, beta, gamma + ) val_loss += loss.item() - + avg_val_loss = val_loss / len(val_loader) early_stopper.step(avg_val_loss) - + if verbose and (epoch % eval_freq == 0 or epoch == num_epochs - 1): - print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}") - + print( + f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}" + ) + if early_stopper.should_stop: if verbose: print("Early stopping triggered.") @@ -246,100 +355,109 @@ def predict_absolute_risk_deephit(model, x_test, times, t_max=None, device="cpu" Optimized to keep more operations on GPU """ model.eval() - + # Convert x_test to tensor and move to device if not isinstance(x_test, torch.Tensor): x_test_tensor = torch.tensor(x_test, dtype=torch.float32, device=device) else: x_test_tensor = x_test.to(device) - + # Process in batches for large datasets batch_size = 1024 # Can be adjusted based on available GPU memory n_samples = x_test_tensor.shape[0] - + # Get model dimensions n_events = model.num_Event n_categories = model.num_Category n_times = len(times) - + # Set t_max if not provided if t_max is None: t_max = max(times) - + # Initialize the output tensor on GPU abs_risks = torch.zeros(n_samples, n_events, n_times, device=device) - + # Process in batches with torch.no_grad(): for i in range(0, n_samples, batch_size): end_idx = min(i + batch_size, n_samples) batch = x_test_tensor[i:end_idx] - + # Get model predictions preds, _ = model(batch) # shape: (batch_size, num_Event, num_Category) - + # For each evaluation time, calculate CIF for t_idx, t in enumerate(times): # Scale time to [0, num_categories-1] scaled_t = (t / t_max) * (n_categories - 1) - + # Find lower and upper bin indices lower_bin = int(np.floor(scaled_t)) - upper_bin = min(int(np.ceil(scaled_t)), n_categories-1) - + upper_bin = min(int(np.ceil(scaled_t)), n_categories - 1) + # Handle boundary cases if lower_bin == upper_bin or upper_bin >= n_categories: # Exactly at a bin boundary or at/beyond the last bin for k in range(n_events): # Get cumulative probability up to this bin - abs_risks[i:end_idx, k, t_idx] = torch.sum(preds[:, k, :lower_bin+1], dim=1) + abs_risks[i:end_idx, k, t_idx] = torch.sum( + preds[:, k, : lower_bin + 1], dim=1 + ) else: # Need to interpolate between bins weight_upper = scaled_t - lower_bin weight_lower = 1 - weight_upper - + for k in range(n_events): # Cumulative probability up to lower bin - cum_prob_lower = torch.sum(preds[:, k, :lower_bin+1], dim=1) - + cum_prob_lower = torch.sum(preds[:, k, : lower_bin + 1], dim=1) + # Cumulative probability up to upper bin - cum_prob_upper = torch.sum(preds[:, k, :upper_bin+1], dim=1) - + cum_prob_upper = torch.sum(preds[:, k, : upper_bin + 1], dim=1) + # Linear interpolation - abs_risks[i:end_idx, k, t_idx] = weight_lower * cum_prob_lower + weight_upper * cum_prob_upper - + abs_risks[i:end_idx, k, t_idx] = ( + weight_lower * cum_prob_lower + + weight_upper * cum_prob_upper + ) + # Move results back to CPU and convert to numpy only if needed return abs_risks.cpu().numpy() -def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, times, max_time, device="cpu"): +def evaluate_model( + model, x_val, t_val, e_val, t_train, e_train, times, max_time, device="cpu" +): """ Evaluate the DeepHit model using time-dependent AUC, Brier score, and td-C-index """ # Predict absolute risks for validation set - abs_risks_tensor = predict_absolute_risk_deephit(model, x_val, times, t_max=max_time, device=device) + abs_risks_tensor = predict_absolute_risk_deephit( + model, x_val, times, t_max=max_time, device=device + ) abs_risks = abs_risks_tensor # Now returns numpy array directly - + # Format data for scikit-survival functions survival_train = Surv.from_arrays(e_train != 0, t_train) survival_val = Surv.from_arrays(e_val != 0, t_val) - + metrics = defaultdict(list) n_events = abs_risks.shape[1] - + # Calculate metrics in parallel using joblib if available try: from joblib import Parallel, delayed - + def process_metric(k, i, time): result = {} risk_preds = abs_risks[:, k, i] - + try: # Reshape to format needed by custom function (n_samples, n_times) risk_preds_2d = np.zeros((len(risk_preds), len(times))) risk_preds_2d[:, i] = risk_preds - + auc_score, _ = auc_td( e_val, t_val, @@ -347,12 +465,12 @@ def process_metric(k, i, time): times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, ) - result[f"auc_event{k+1}_t{time:.2f}"] = float(auc_score) - except Exception as ex: - result[f"auc_event{k+1}_t{time:.2f}"] = np.nan - + result[f"auc_event{k + 1}_t{time:.2f}"] = float(auc_score) + except Exception: + result[f"auc_event{k + 1}_t{time:.2f}"] = np.nan + try: brier_score_val, _ = brier_score( e_val, @@ -361,51 +479,50 @@ def process_metric(k, i, time): times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, ) - result[f"brier_event{k+1}_t{time:.2f}"] = float(brier_score_val) - except Exception as ex: - result[f"brier_event{k+1}_t{time:.2f}"] = np.nan - + result[f"brier_event{k + 1}_t{time:.2f}"] = float(brier_score_val) + except Exception: + result[f"brier_event{k + 1}_t{time:.2f}"] = np.nan + try: tdci_result = concordance_index_ipcw( - survival_train, - survival_val, - estimate=risk_preds, - tau=time + survival_train, survival_val, estimate=risk_preds, tau=time ) if isinstance(tdci_result, tuple): tdci_score = tdci_result[0] else: tdci_score = tdci_result - result[f"tdci_event{k+1}_t{time:.2f}"] = float(tdci_score) - except Exception as ex: - result[f"tdci_event{k+1}_t{time:.2f}"] = np.nan - + result[f"tdci_event{k + 1}_t{time:.2f}"] = float(tdci_score) + except Exception: + result[f"tdci_event{k + 1}_t{time:.2f}"] = np.nan + return result - + # Collect all tasks tasks = [(k, i, time) for k in range(n_events) for i, time in enumerate(times)] - + # Run in parallel - results = Parallel(n_jobs=-1)(delayed(process_metric)(k, i, time) for k, i, time in tasks) - + results = Parallel(n_jobs=-1)( + delayed(process_metric)(k, i, time) for k, i, time in tasks + ) + # Combine results for result in results: for k, v in result.items(): metrics[k].append(v) - + except ImportError: # Fall back to sequential computation for k in range(n_events): for i, time in enumerate(times): risk_preds = abs_risks[:, k, i] - + try: # Reshape to format needed by custom function (n_samples, n_times) risk_preds_2d = np.zeros((len(risk_preds), len(times))) risk_preds_2d[:, i] = risk_preds - + auc_score, _ = auc_td( e_val, t_val, @@ -413,12 +530,12 @@ def process_metric(k, i, time): times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, ) - metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score)) - except Exception as ex: - metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan) - + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score)) + except Exception: + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan) + try: brier_score_val, _ = brier_score( e_val, @@ -427,27 +544,26 @@ def process_metric(k, i, time): times, time, km=(e_train, t_train), - primary_risk=k+1 + primary_risk=k + 1, ) - metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val)) - except Exception as ex: - metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan) - + metrics[f"brier_event{k + 1}_t{time:.2f}"].append( + float(brier_score_val) + ) + except Exception: + metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan) + try: tdci_result = concordance_index_ipcw( - survival_train, - survival_val, - estimate=risk_preds, - tau=time + survival_train, survival_val, estimate=risk_preds, tau=time ) if isinstance(tdci_result, tuple): tdci_score = tdci_result[0] else: tdci_score = tdci_result - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score)) - except Exception as ex: - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan) - + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score)) + except Exception: + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan) + return metrics @@ -457,97 +573,99 @@ def display_metrics_table(metrics_dict, n_folds=5, quantiles=[0.25, 0.5, 0.75]): """ import pandas as pd from tabulate import tabulate - + time_points_by_metric = {} event_types = set() - metric_types = ['auc', 'tdci', 'brier'] - + metric_types = ["auc", "tdci", "brier"] + for key, values in metrics_dict.items(): if any(metric in key for metric in metric_types): - parts = key.split('_') + parts = key.split("_") metric_type = parts[0] event_info = parts[1] - + # Extract event type - event_type = int(event_info.replace('event', '')) + event_type = int(event_info.replace("event", "")) event_types.add(event_type) - + # Extract time point - if len(parts) > 2 and parts[2].startswith('t'): - time_point = float(parts[2].replace('t', '')) - + if len(parts) > 2 and parts[2].startswith("t"): + time_point = float(parts[2].replace("t", "")) + # Initialize nested dictionaries if needed if (event_type, metric_type) not in time_points_by_metric: time_points_by_metric[(event_type, metric_type)] = {} - + # Store metrics by time point time_points_by_metric[(event_type, metric_type)][time_point] = values - + results = [] - + for event_type in sorted(event_types): - row = {'Risk': f"Type {event_type}"} - + row = {"Risk": f"Type {event_type}"} + for metric_type in metric_types: if (event_type, metric_type) not in time_points_by_metric: # Skip metrics that don't exist for this event type for q in quantiles: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" continue - + # Get all time points for this event/metric time_data = time_points_by_metric[(event_type, metric_type)] sorted_times = sorted(time_data.keys()) - + # For each quantile for q in quantiles: # Calculate the index for this quantile q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q))) - + # Get the corresponding time point for this quantile q_time = sorted_times[q_idx] - + # Get the metrics for this time point q_values = time_data[q_time] - + # Calculate and format statistics if q_values: value_array = np.array(q_values) mean_val = np.nanmean(value_array) std_val = np.nanstd(value_array) - row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}" + row[f"{metric_type.upper()}_q{q:.2f}"] = ( + f"{mean_val:.3f} ± {std_val:.3f}" + ) else: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" - + results.append(row) - + df = pd.DataFrame(results) - + # Define column order - columns = ['Risk'] - for metric in ['AUC', 'TDCI', 'BRIER']: + columns = ["Risk"] + for metric in ["AUC", "TDCI", "BRIER"]: for q in quantiles: columns.append(f"{metric}_q{q:.2f}") - + # Select columns in the right order (only those that exist) df = df[[col for col in columns if col in df.columns]] - + print("\nSummary Performance Metrics:") - print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False)) - + print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False)) + print("\nInterpretation:") print("- AUC: 0.5=random, >0.7=good, >0.8=excellent") - print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") + print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor") def main(): args = parse_args() print(args) - + # Set random seed for reproducibility set_seed(args.seed) - + # Configure CUDA for maximum performance if torch.cuda.is_available(): torch.backends.cudnn.benchmark = True @@ -558,16 +676,22 @@ def main(): # Print CUDA device info device_name = torch.cuda.get_device_name(0) compute_capability = torch.cuda.get_device_capability(0) - print(f"Using CUDA device: {device_name} with compute capability {compute_capability[0]}.{compute_capability[1]}") - + print( + f"Using CUDA device: {device_name} with compute capability {compute_capability[0]}.{compute_capability[1]}" + ) + # Set GPU memory usage strategy torch.cuda.empty_cache() - + # Manual memory management - if hasattr(torch.cuda, 'memory_stats'): - print(f"Initial GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") - print(f"Initial GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB") - + if hasattr(torch.cuda, "memory_stats"): + print( + f"Initial GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB" + ) + print( + f"Initial GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB" + ) + # Load the dataset if args.dataset.lower() == "framingham": x, t, e, feature_names, n_cont, _ = load_framingham() @@ -579,10 +703,10 @@ def main(): x, t, e, feature_names, n_cont, _ = load_synthetic_dataset() else: raise ValueError(f"Dataset {args.dataset} not supported") - + # Convert data to tensors immediately before scaling - device = ("cuda" if torch.cuda.is_available() else "cpu") - + device = "cuda" if torch.cuda.is_available() else "cpu" + # Scale features if needed (more efficiently) if args.scaling.lower() == "standard": scaler = StandardScaler() @@ -598,13 +722,13 @@ def main(): pass else: raise ValueError(f"Scaling method {args.scaling} not supported") - + # Get number of competing risks num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0) - + # Maximum time in the dataset max_time = np.max(t) - + # Convert data to tensors early and move to device if small enough # For large datasets, keep on CPU and transfer in batches dataset_size = x.nbytes / (1024 * 1024) # Size in MB @@ -616,120 +740,136 @@ def main(): e_tensor = torch.tensor(e, dtype=torch.long, device=device) print(f"Entire dataset moved to {device} (Size: {dataset_size:.2f} MB)") # Create dataset with GPU tensors - dataset = SurvivalDatasetDeepHit(x_tensor.cpu().numpy(), t_tensor.cpu().numpy(), - e_tensor.cpu().numpy(), args.num_categories) + dataset = SurvivalDatasetDeepHit( + x_tensor.cpu().numpy(), + t_tensor.cpu().numpy(), + e_tensor.cpu().numpy(), + args.num_categories, + ) except RuntimeError: # If OOM, fall back to CPU - print(f"Dataset too large for GPU memory, keeping on CPU (Size: {dataset_size:.2f} MB)") + print( + f"Dataset too large for GPU memory, keeping on CPU (Size: {dataset_size:.2f} MB)" + ) x_tensor = torch.tensor(x, dtype=torch.float32) t_tensor = torch.tensor(t, dtype=torch.float32) e_tensor = torch.tensor(e, dtype=torch.long) dataset = SurvivalDatasetDeepHit(x, t, e, args.num_categories) else: # Large dataset, keep on CPU - print(f"Large dataset detected ({dataset_size:.2f} MB), keeping on CPU and using efficient batch loading") + print( + f"Large dataset detected ({dataset_size:.2f} MB), keeping on CPU and using efficient batch loading" + ) dataset = SurvivalDatasetDeepHit(x, t, e, args.num_categories) - + # Setup cross-validation skf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed) - + # Metrics collection and evaluation times all_metrics = defaultdict(list) quantiles = [0.25, 0.5, 0.75] - + # Calculate global evaluation times safe_max = 0.99 * np.max(t) eval_times = np.quantile(t[t <= safe_max], quantiles) - + print(f"Evaluation times: {eval_times}") - + # Create directory for figures if it doesn't exist if not os.path.exists("figs"): os.makedirs("figs") - + # Cross-validation loop for fold, (train_idx, val_idx) in enumerate(skf.split(x, e)): print(f"\n=== Fold {fold + 1}/{args.n_folds} ===") - + # Calculate optimal batch size based on device if device == "cuda": # Dynamically adjust batch size based on available GPU memory - total_gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3) # in GB + total_gpu_mem = torch.cuda.get_device_properties(0).total_memory / ( + 1024**3 + ) # in GB # Use heuristic - larger batch sizes for larger GPUs - dynamic_batch_size = min(2048, max(512, int(args.batch_size * (total_gpu_mem / 8)))) + dynamic_batch_size = min( + 2048, max(512, int(args.batch_size * (total_gpu_mem / 8))) + ) if dynamic_batch_size != args.batch_size: - print(f"Adjusting batch size from {args.batch_size} to {dynamic_batch_size} based on GPU memory") + print( + f"Adjusting batch size from {args.batch_size} to {dynamic_batch_size} based on GPU memory" + ) batch_size = dynamic_batch_size else: batch_size = args.batch_size else: batch_size = args.batch_size - + # Create data loaders with optimized settings train_subset = Subset(dataset, train_idx) val_subset = Subset(dataset, val_idx) - + # Use different prefetch factors based on data size and complexity prefetch_factor = 2 - + train_loader = DataLoader( - train_subset, - batch_size=batch_size, + train_subset, + batch_size=batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True if device == "cuda" else False, persistent_workers=True if args.num_workers > 0 else False, prefetch_factor=prefetch_factor if args.num_workers > 0 else None, - drop_last=True # Drop last incomplete batch for better performance + drop_last=True, # Drop last incomplete batch for better performance ) - + val_loader = DataLoader( - val_subset, + val_subset, batch_size=batch_size, num_workers=args.num_workers, pin_memory=True if device == "cuda" else False, persistent_workers=True if args.num_workers > 0 else False, - prefetch_factor=prefetch_factor if args.num_workers > 0 else None + prefetch_factor=prefetch_factor if args.num_workers > 0 else None, ) - + # Initialize DeepHit model with optimized size input_dims = { - 'x_dim': x.shape[1], - 'num_Event': num_competing_risks, - 'num_Category': args.num_categories + "x_dim": x.shape[1], + "num_Event": num_competing_risks, + "num_Category": args.num_categories, } - + # Adjust network size based on available GPU memory if device == "cuda": h_dim_shared = args.h_dim_shared h_dim_CS = args.h_dim_CS - + # For GPUs with more memory, can use larger networks if total_gpu_mem > 16: # For high-end GPUs h_dim_shared = max(args.h_dim_shared, 128) h_dim_CS = max(args.h_dim_CS, 32) - print(f"Using larger network dimensions for high-memory GPU: {h_dim_shared}/{h_dim_CS}") + print( + f"Using larger network dimensions for high-memory GPU: {h_dim_shared}/{h_dim_CS}" + ) else: h_dim_shared = args.h_dim_shared h_dim_CS = args.h_dim_CS - + network_settings = { - 'h_dim_shared': h_dim_shared, - 'h_dim_CS': h_dim_CS, - 'num_layers_shared': args.num_layers_shared, - 'num_layers_CS': args.num_layers_CS, - 'active_fn': args.active_fn, - 'keep_prob': 1.0 - args.dropout_rate + "h_dim_shared": h_dim_shared, + "h_dim_CS": h_dim_CS, + "num_layers_shared": args.num_layers_shared, + "num_layers_CS": args.num_layers_CS, + "active_fn": args.active_fn, + "keep_prob": 1.0 - args.dropout_rate, } - + # Create model model = DeepHit(input_dims, network_settings).to(device) - + # Optional: print model summary for debugging if fold == 0: num_params = sum(p.numel() for p in model.parameters()) print(f"Model has {num_params:,} parameters") - + # Print estimated memory usage per batch # 4 bytes per float32 parameter, multiply by 4 for activations, gradients, optimizer state param_memory_mb = num_params * 4 * 4 / (1024 * 1024) @@ -741,40 +881,63 @@ def main(): print("Profiling first batch for performance analysis...") # Get a sample batch for profiling sample_batch = next(iter(train_loader)) - x_sample, t_sample, e_sample, t_disc_sample = [t.to(device, non_blocking=True) for t in sample_batch] - + x_sample, t_sample, e_sample, t_disc_sample = [ + t.to(device, non_blocking=True) for t in sample_batch + ] + # Simple profiling of forward and backward pass with torch.autograd.profiler.profile(use_cuda=True) as prof: # Create masks - mask1 = create_fc_mask1_gpu(e_sample, t_disc_sample, model.num_Event, model.num_Category, device) - mask2 = create_fc_mask2_gpu(t_disc_sample, model.num_Category, device) - + mask1 = create_fc_mask1_gpu( + e_sample, + t_disc_sample, + model.num_Event, + model.num_Category, + device, + ) + mask2 = create_fc_mask2_gpu( + t_disc_sample, model.num_Category, device + ) + # Forward pass out, _ = model(x_sample) - + # Compute loss - loss = model.compute_loss(out, t_sample, e_sample, mask1, mask2, args.alpha, args.beta, args.gamma) - + loss = model.compute_loss( + out, + t_sample, + e_sample, + mask1, + mask2, + args.alpha, + args.beta, + args.gamma, + ) + # Backward pass loss.backward() - + # Print profiling results - profile_sorted = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10) + profile_sorted = prof.key_averages().table( + sort_by="cuda_time_total", row_limit=10 + ) print("Profile of most expensive CUDA operations:") print(profile_sorted) - + # Identify bottlenecks cpu_pct = prof.self_cpu_time_total / prof.total_cuda_time_total * 100 if cpu_pct > 20: - print(f"WARNING: High CPU overhead ({cpu_pct:.1f}%). Consider optimizing CPU-GPU transfers.") - + print( + f"WARNING: High CPU overhead ({cpu_pct:.1f}%). Consider optimizing CPU-GPU transfers." + ) + except Exception as e: print(f"Profiling skipped due to error: {e}") - + # Train the model with optimizations train_deephit_model( - model, - train_loader, + model, + train_loader, val_loader, alpha=args.alpha, beta=args.beta, @@ -785,34 +948,34 @@ def main(): patience=args.patience, eval_freq=args.eval_freq, use_amp=args.use_amp, - verbose=True + verbose=True, ) - + # Evaluate model (less frequently during training) fold_metrics = evaluate_model( - model, - x[val_idx], - t[val_idx], + model, + x[val_idx], + t[val_idx], e[val_idx], - t[train_idx], - e[train_idx], - eval_times, + t[train_idx], + e[train_idx], + eval_times, max_time, - device + device, ) - + # Collect metrics for k, v in fold_metrics.items(): all_metrics[k].extend(v) - + # Free memory after each fold if device == "cuda": del model torch.cuda.empty_cache() - + # Display final metrics display_metrics_table(all_metrics, n_folds=args.n_folds) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/training_scripts/train_nested_cv.py b/training_scripts/train_nested_cv.py index 60ee0330..c277b556 100644 --- a/training_scripts/train_nested_cv.py +++ b/training_scripts/train_nested_cv.py @@ -1,86 +1,130 @@ -import yaml -import configargparse from collections import defaultdict -import optuna -import torch +import configargparse +import matplotlib.pyplot as plt import numpy as np +import optuna import pandas as pd -import torch.optim as optim -import matplotlib.pyplot as plt +import torch +import yaml +from model_utils import EarlyStopping, set_seed +from sklearn.model_selection import StratifiedKFold +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sksurv.metrics import concordance_index_ipcw +from sksurv.util import Surv from tabulate import tabulate +from torch import optim from torch.utils.data import DataLoader -from sklearn.model_selection import StratifiedKFold, train_test_split -from sklearn.preprocessing import StandardScaler, MinMaxScaler -from sksurv.util import Surv -from sksurv.metrics import concordance_index_ipcw +from crisp_nam.metrics import auc_td, brier_score from crisp_nam.models import CrispNamModel from crisp_nam.utils import ( - weighted_negative_log_likelihood_loss, + compute_baseline_cif, + compute_l2_penalty, negative_log_likelihood_loss, - compute_l2_penalty + plot_coxnam_shape_functions, + plot_feature_importance, + predict_absolute_risk, + weighted_negative_log_likelihood_loss, ) from data_utils import * -from model_utils import EarlyStopping, set_seed -from crisp_nam.metrics import brier_score, auc_td -from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif -from crisp_nam.utils import plot_coxnam_shape_functions, plot_feature_importance def parse_args(): parser = configargparse.ArgumentParser( description="Nested Cross-Validation for CrispNAM model", default_config_files=["config.yaml"], - config_file_parser_class=configargparse.YAMLConfigFileParser + config_file_parser_class=configargparse.YAMLConfigFileParser, ) - - parser.add_argument("-c", "--config", is_config_file=True, - help="Path to config file") - parser.add_argument("--dataset", type=str, default="framingham", - choices=["framingham", "support", "pbc", "synthetic"], - help="Dataset to use") + parser.add_argument( + "-c", "--config", is_config_file=True, help="Path to config file" + ) - parser.add_argument("--scaling", type=str, default="standard", - choices=["minmax", "standard", "none"], - help="Data scaling method for continuous features") + parser.add_argument( + "--dataset", + type=str, + default="framingham", + choices=["framingham", "support", "pbc", "synthetic"], + help="Dataset to use", + ) - parser.add_argument("--num_epochs", type=int, default=250, - help="Number of training epochs (reduced for nested CV)") - parser.add_argument("--batch_size", type=int, default=512, - help="Batch size for training") - parser.add_argument("--patience", type=int, default=10, - help="Patience for early stopping") + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + help="Data scaling method for continuous features", + ) + + parser.add_argument( + "--num_epochs", + type=int, + default=250, + help="Number of training epochs (reduced for nested CV)", + ) + parser.add_argument( + "--batch_size", type=int, default=512, help="Batch size for training" + ) + parser.add_argument( + "--patience", type=int, default=10, help="Patience for early stopping" + ) # Nested CV parameters - parser.add_argument("--outer_folds", type=int, default=5, - help="Number of outer CV folds") - parser.add_argument("--inner_folds", type=int, default=3, - help="Number of inner CV folds for hyperparameter tuning") - parser.add_argument("--n_trials", type=int, default=20, - help="Number of Optuna trials per inner fold") + parser.add_argument( + "--outer_folds", type=int, default=2, help="Number of outer CV folds" + ) + parser.add_argument( + "--inner_folds", + type=int, + default=2, + help="Number of inner CV folds for hyperparameter tuning", + ) + parser.add_argument( + "--n_trials", + type=int, + default=10, + help="Number of Optuna trials per inner fold", + ) # Event weighting - parser.add_argument("--event_weighting", type=str, default="none", - choices=["none", "balanced", "custom"], - help="Event weighting strategy") - parser.add_argument("--custom_event_weights", type=str, default=None, - help="Custom weights for events (comma-separated)") - - parser.add_argument("--seed", type=int, default=42, - help="Random seed for reproducibility") - + parser.add_argument( + "--event_weighting", + type=str, + default="none", + choices=["none", "balanced", "custom"], + help="Event weighting strategy", + ) + parser.add_argument( + "--custom_event_weights", + type=str, + default=None, + help="Custom weights for events (comma-separated)", + ) + + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + return parser.parse_args() -def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_rate=1e-3, - l2_reg=0.01, patience=10, event_weights=None, verbose=False): +def train_model( + model, + train_loader, + val_loader=None, + num_epochs=100, + learning_rate=1e-3, + l2_reg=0.01, + patience=10, + event_weights=None, + verbose=False, +): """Train model with early stopping""" optimizer = optim.AdamW(model.parameters(), lr=learning_rate) early_stopper = EarlyStopping(patience=patience) device = next(model.parameters()).device - best_val_loss = float('inf') + best_val_loss = float("inf") for epoch in range(num_epochs): model.train() @@ -90,12 +134,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r risk_scores, _ = model(x) if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg total = loss + reg @@ -113,25 +163,32 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r for x, t, e, _ in val_loader: x, t, e = x.to(device), t.to(device), e.to(device) risk_scores, _ = model(x) - + if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg val_loss += (loss + reg).item() avg_val_loss = val_loss / len(val_loader) - - if avg_val_loss < best_val_loss: - best_val_loss = avg_val_loss - + + best_val_loss = min(best_val_loss, avg_val_loss) + early_stopper.step(avg_val_loss) if verbose and epoch % 10 == 0: - print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}") + print( + f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}" + ) if early_stopper.should_stop: if verbose: @@ -139,37 +196,51 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r break elif verbose and epoch % 10 == 0: print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f}") - + return best_val_loss -def hyperparameter_optimization(x_train_inner, t_train_inner, e_train_inner, - x_val_inner, t_val_inner, e_val_inner, - num_competing_risks, device, args, event_weights, n_cont): +def hyperparameter_optimization( + x_train_inner, + t_train_inner, + e_train_inner, + x_val_inner, + t_val_inner, + e_val_inner, + num_competing_risks, + device, + args, + event_weights, + n_cont, +): """Run Optuna hyperparameter optimization on inner training data""" - + def objective(trial): # Define hyperparameters to optimize - learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True) - l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True) - dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.8) - feature_dropout = trial.suggest_float('feature_dropout', 0.0, 0.5) + learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True) + l2_reg = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True) + dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.8) + feature_dropout = trial.suggest_float("feature_dropout", 0.0, 0.5) # Hidden dimensions - n_layers = trial.suggest_int('n_layers', 1, 3) + n_layers = trial.suggest_int("n_layers", 1, 3) hidden_dimensions = [] for i in range(n_layers): - hidden_dimensions.append(trial.suggest_categorical(f'hidden_dim_{i}', [8,16, 32, 64, 128])) - - batch_norm = trial.suggest_categorical('batch_norm', [True, False]) - + hidden_dimensions.append( + trial.suggest_categorical(f"hidden_dim_{i}", [8, 16, 32, 64, 128]) + ) + + batch_norm = trial.suggest_categorical("batch_norm", [True, False]) + # Create data loaders train_dataset = SurvivalDataset(x_train_inner, t_train_inner, e_train_inner) val_dataset = SurvivalDataset(x_val_inner, t_val_inner, e_val_inner) - - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) + + train_loader = DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) val_loader = DataLoader(val_dataset, batch_size=args.batch_size) - + # Initialize model model = CrispNamModel( num_features=x_train_inner.shape[1], @@ -177,26 +248,28 @@ def objective(trial): hidden_sizes=hidden_dimensions, dropout_rate=dropout_rate, feature_dropout=feature_dropout, - batch_norm=batch_norm + batch_norm=batch_norm, ).to(device) - + # Train model best_val_loss = train_model( - model, train_loader, val_loader, - num_epochs=args.num_epochs, + model, + train_loader, + val_loader, + num_epochs=args.num_epochs, learning_rate=learning_rate, - l2_reg=l2_reg, + l2_reg=l2_reg, patience=args.patience, event_weights=event_weights, - verbose=False + verbose=False, ) - + return best_val_loss - + # Create study and optimize study = optuna.create_study(direction="minimize") study.optimize(objective, n_trials=args.n_trials, show_progress_bar=False) - + return study.best_params @@ -211,29 +284,41 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time for k in range(n_events): for i, time in enumerate(times): risk_preds = abs_risks[:, k, i] - + try: risk_preds_2d = np.zeros((len(risk_preds), len(times))) - risk_preds_2d[:, i] = risk_preds - + risk_preds_2d[:, i] = risk_preds + auc_score, _ = auc_td( - e_val, t_val, risk_preds_2d, times, time, - km=(e_train, t_train), primary_risk=k+1 + e_val, + t_val, + risk_preds_2d, + times, + time, + km=(e_train, t_train), + primary_risk=k + 1, ) - metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score)) + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score)) except Exception as ex: - print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan) try: brier_score_val, _ = brier_score( - e_val, t_val, risk_preds_2d, times, time, - km=(e_train, t_train), primary_risk=k+1 + e_val, + t_val, + risk_preds_2d, + times, + time, + km=(e_train, t_train), + primary_risk=k + 1, + ) + metrics[f"brier_event{k + 1}_t{time:.2f}"].append( + float(brier_score_val) ) - metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val)) except Exception as ex: - print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan) try: tdci_result = concordance_index_ipcw( @@ -242,112 +327,120 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time if isinstance(tdci_result, tuple): tdci_score = tdci_result[0] else: - tdci_score = tdci_result - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score)) + tdci_score = tdci_result + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score)) except Exception as ex: - print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan) return metrics -def display_metrics_table(metrics_dict, quantiles=[0.25, 0.5, 0.75], dataset_name="dataset"): +def display_metrics_table( + metrics_dict, quantiles=[0.25, 0.5, 0.75], dataset_name="dataset" +): """Display evaluation metrics table and return processed data""" time_points_by_metric = {} event_types = set() - metric_types = ['auc', 'tdci', 'brier'] - + metric_types = ["auc", "tdci", "brier"] + for key, values in metrics_dict.items(): if any(metric in key for metric in metric_types): - parts = key.split('_') + parts = key.split("_") metric_type = parts[0] event_info = parts[1] - - event_type = int(event_info.replace('event', '')) + + event_type = int(event_info.replace("event", "")) event_types.add(event_type) - - if len(parts) > 2 and parts[2].startswith('t'): - time_point = float(parts[2].replace('t', '')) - + + if len(parts) > 2 and parts[2].startswith("t"): + time_point = float(parts[2].replace("t", "")) + if (event_type, metric_type) not in time_points_by_metric: time_points_by_metric[(event_type, metric_type)] = {} - + time_points_by_metric[(event_type, metric_type)][time_point] = values - + # Create summary table summary_results = [] detailed_results = [] - + for event_type in sorted(event_types): - row = {'Risk': f"Type {event_type}"} - + row = {"Risk": f"Type {event_type}"} + for metric_type in metric_types: if (event_type, metric_type) not in time_points_by_metric: for q in quantiles: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" continue - + time_data = time_points_by_metric[(event_type, metric_type)] sorted_times = sorted(time_data.keys()) - + for q in quantiles: q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q))) q_time = sorted_times[q_idx] q_values = time_data[q_time] - + if q_values: value_array = np.array(q_values) mean_val = np.nanmean(value_array) std_val = np.nanstd(value_array) - row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}" - + row[f"{metric_type.upper()}_q{q:.2f}"] = ( + f"{mean_val:.3f} ± {std_val:.3f}" + ) + # Add to detailed results for separate CSV - detailed_results.append({ - 'Dataset': dataset_name, - 'Risk_Type': event_type, - 'Metric': metric_type.upper(), - 'Time_Quantile': f"q{q:.2f}", - 'Time_Value': q_time, - 'Mean': mean_val, - 'Std': std_val, - 'N_Folds': len(q_values), - 'Raw_Values': ';'.join(map(str, q_values)) - }) + detailed_results.append( + { + "Dataset": dataset_name, + "Risk_Type": event_type, + "Metric": metric_type.upper(), + "Time_Quantile": f"q{q:.2f}", + "Time_Value": q_time, + "Mean": mean_val, + "Std": std_val, + "N_Folds": len(q_values), + "Raw_Values": ";".join(map(str, q_values)), + } + ) else: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" - + summary_results.append(row) - + # Create DataFrames summary_df = pd.DataFrame(summary_results) detailed_df = pd.DataFrame(detailed_results) - - columns = ['Risk'] - for metric in ['AUC', 'TDCI', 'BRIER']: + + columns = ["Risk"] + for metric in ["AUC", "TDCI", "BRIER"]: for q in quantiles: columns.append(f"{metric}_q{q:.2f}") - + summary_df = summary_df[[col for col in columns if col in summary_df.columns]] - + print("\nNested CV Performance Metrics:") - print(tabulate(summary_df, headers='keys', tablefmt='pretty', showindex=False)) - + print(tabulate(summary_df, headers="keys", tablefmt="pretty", showindex=False)) + print("\nInterpretation:") print("- AUC: 0.5=random, >0.7=good, >0.8=excellent") - print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") + print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor") - + return summary_df, detailed_df def main(): args = parse_args() - print(f"Running Nested Cross-Validation with {args.outer_folds} outer folds and {args.inner_folds} inner folds") + print( + f"Running Nested Cross-Validation with {args.outer_folds} outer folds and {args.inner_folds} inner folds" + ) print(args) - + # Set random seed set_seed(args.seed) - + # Load dataset if args.dataset.lower() == "framingham": x, t, e, feature_names, n_cont, _ = load_framingham() @@ -359,142 +452,193 @@ def main(): x, t, e, feature_names, n_cont, _ = load_synthetic_dataset() else: raise ValueError(f"Dataset {args.dataset} not supported") - + num_competing_risks = len(np.unique(e)) - 1 - device = ("cuda" if torch.cuda.is_available() else "cpu") - + device = "cuda" if torch.cuda.is_available() else "cpu" + # Outer cross-validation loop - outer_cv = StratifiedKFold(n_splits=args.outer_folds, shuffle=True, random_state=args.seed) + outer_cv = StratifiedKFold( + n_splits=args.outer_folds, shuffle=True, random_state=args.seed + ) all_metrics = defaultdict(list) all_best_params = [] - + quantiles = [0.25, 0.5, 0.75] safe_max = 0.99 * np.max(t) eval_times = np.quantile(t[t <= safe_max], quantiles) - + print(f"Evaluation times: {eval_times}") - + for outer_fold, (train_idx, test_idx) in enumerate(outer_cv.split(x, e)): print(f"\n=== Outer Fold {outer_fold + 1}/{args.outer_folds} ===") - + # Split data for outer fold x_train_outer, x_test_outer = x[train_idx].copy(), x[test_idx].copy() t_train_outer, t_test_outer = t[train_idx], t[test_idx] e_train_outer, e_test_outer = e[train_idx], e[test_idx] - + # Apply scaling within outer fold if args.scaling.lower() == "standard": scaler = StandardScaler() - x_train_outer[:, -n_cont:] = scaler.fit_transform(x_train_outer[:, -n_cont:]) + x_train_outer[:, -n_cont:] = scaler.fit_transform( + x_train_outer[:, -n_cont:] + ) x_test_outer[:, -n_cont:] = scaler.transform(x_test_outer[:, -n_cont:]) elif args.scaling.lower() == "minmax": scaler = MinMaxScaler() - x_train_outer[:, -n_cont:] = scaler.fit_transform(x_train_outer[:, -n_cont:]) + x_train_outer[:, -n_cont:] = scaler.fit_transform( + x_train_outer[:, -n_cont:] + ) x_test_outer[:, -n_cont:] = scaler.transform(x_test_outer[:, -n_cont:]) - + # Inner cross-validation for hyperparameter tuning - inner_cv = StratifiedKFold(n_splits=args.inner_folds, shuffle=True, random_state=args.seed) + inner_cv = StratifiedKFold( + n_splits=args.inner_folds, shuffle=True, random_state=args.seed + ) inner_scores = [] inner_params = [] - - for inner_fold, (train_inner_idx, val_inner_idx) in enumerate(inner_cv.split(x_train_outer, e_train_outer)): + + for inner_fold, (train_inner_idx, val_inner_idx) in enumerate( + inner_cv.split(x_train_outer, e_train_outer) + ): print(f" Inner Fold {inner_fold + 1}/{args.inner_folds}") - + # Split inner training data x_train_inner = x_train_outer[train_inner_idx].copy() x_val_inner = x_train_outer[val_inner_idx].copy() - t_train_inner, t_val_inner = t_train_outer[train_inner_idx], t_train_outer[val_inner_idx] - e_train_inner, e_val_inner = e_train_outer[train_inner_idx], e_train_outer[val_inner_idx] - + t_train_inner, t_val_inner = ( + t_train_outer[train_inner_idx], + t_train_outer[val_inner_idx], + ) + e_train_inner, e_val_inner = ( + e_train_outer[train_inner_idx], + e_train_outer[val_inner_idx], + ) + # Calculate event weights for this inner fold event_weights = None if args.event_weighting == "balanced": event_counts = np.zeros(num_competing_risks) for k in range(1, num_competing_risks + 1): - event_counts[k-1] = np.sum(e_train_inner == k) + event_counts[k - 1] = np.sum(e_train_inner == k) event_counts = np.maximum(event_counts, 1) event_weights = 1.0 / event_counts - event_weights = event_weights * (num_competing_risks / event_weights.sum()) - event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device) - + event_weights = event_weights * ( + num_competing_risks / event_weights.sum() + ) + event_weights = torch.tensor( + event_weights, dtype=torch.float32, device=device + ) + # Hyperparameter optimization best_params = hyperparameter_optimization( - x_train_inner, t_train_inner, e_train_inner, - x_val_inner, t_val_inner, e_val_inner, - num_competing_risks, device, args, event_weights, n_cont + x_train_inner, + t_train_inner, + e_train_inner, + x_val_inner, + t_val_inner, + e_val_inner, + num_competing_risks, + device, + args, + event_weights, + n_cont, ) - + inner_params.append(best_params) - + # Select best hyperparameters (could use ensemble or average, here we take most common) # For simplicity, we'll use the first inner fold's best params best_params = inner_params[0] all_best_params.append(best_params) - - print(f" Selected hyperparameters for outer fold {outer_fold + 1}: {best_params}") - + + print( + f" Selected hyperparameters for outer fold {outer_fold + 1}: {best_params}" + ) + # Train final model on all outer training data with best hyperparameters - n_layers = best_params['n_layers'] - hidden_dimensions = [best_params[f'hidden_dim_{i}'] for i in range(n_layers)] - + n_layers = best_params["n_layers"] + hidden_dimensions = [best_params[f"hidden_dim_{i}"] for i in range(n_layers)] + # Calculate event weights for outer training data event_weights = None if args.event_weighting == "balanced": event_counts = np.zeros(num_competing_risks) for k in range(1, num_competing_risks + 1): - event_counts[k-1] = np.sum(e_train_outer == k) + event_counts[k - 1] = np.sum(e_train_outer == k) event_counts = np.maximum(event_counts, 1) event_weights = 1.0 / event_counts event_weights = event_weights * (num_competing_risks / event_weights.sum()) - event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device) - + event_weights = torch.tensor( + event_weights, dtype=torch.float32, device=device + ) + # Create final model final_model = CrispNamModel( num_features=x.shape[1], num_competing_risks=num_competing_risks, hidden_sizes=hidden_dimensions, - dropout_rate=best_params['dropout_rate'], - feature_dropout=best_params['feature_dropout'], - batch_norm=best_params['batch_norm'] + dropout_rate=best_params["dropout_rate"], + feature_dropout=best_params["feature_dropout"], + batch_norm=best_params["batch_norm"], ).to(device) - + # Train final model train_dataset = SurvivalDataset(x_train_outer, t_train_outer, e_train_outer) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) - + train_loader = DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True + ) + train_model( - final_model, train_loader, None, - num_epochs=args.num_epochs, - learning_rate=best_params['learning_rate'], - l2_reg=best_params['l2_reg'], + final_model, + train_loader, + None, + num_epochs=args.num_epochs, + learning_rate=best_params["learning_rate"], + l2_reg=best_params["l2_reg"], patience=args.patience, event_weights=event_weights, - verbose=True + verbose=True, ) - + # Evaluate on test data - baseline_cifs = {k: compute_baseline_cif(t_train_outer, e_train_outer, eval_times, k + 1) - for k in range(num_competing_risks)} - - abs_risks = predict_absolute_risk(final_model, x_test_outer, baseline_cifs, eval_times, device=device) - - fold_metrics = evaluate_model(final_model, x_test_outer, t_test_outer, e_test_outer, - t_train_outer, e_train_outer, abs_risks, eval_times) - + baseline_cifs = { + k: compute_baseline_cif(t_train_outer, e_train_outer, eval_times, k + 1) + for k in range(num_competing_risks) + } + + abs_risks = predict_absolute_risk( + final_model, x_test_outer, baseline_cifs, eval_times, device=device + ) + + fold_metrics = evaluate_model( + final_model, + x_test_outer, + t_test_outer, + e_test_outer, + t_train_outer, + e_train_outer, + abs_risks, + eval_times, + ) + for k, v in fold_metrics.items(): all_metrics[k].extend(v) - + # Display final results and get DataFrames for saving - summary_df, detailed_df = display_metrics_table(all_metrics, dataset_name=args.dataset) - + summary_df, detailed_df = display_metrics_table( + all_metrics, dataset_name=args.dataset + ) + # Generate shape plots using the last trained model (from the last outer fold) print("\n=== Generating Shape Function Plots ===") - + # Create figs subdirectory if not present import os + if not os.path.exists("figs"): os.makedirs("figs") - + # Use the entire dataset for plotting (with proper scaling) x_plot = x.copy() if args.scaling.lower() == "standard": @@ -503,11 +647,11 @@ def main(): elif args.scaling.lower() == "minmax": scaler = MinMaxScaler() x_plot[:, -n_cont:] = scaler.fit_transform(x_plot[:, -n_cont:]) - + # Generate plots for each risk type for risk in range(1, num_competing_risks + 1): print(f"Generating plots for risk type {risk}") - + # Generate feature importance plot fig, _, top_positive, top_negative = plot_feature_importance( model=final_model, @@ -515,131 +659,150 @@ def main(): feature_names=feature_names, n_top=5, # Show top 5 positive contributors n_bottom=5, # Show top 5 negative contributors - risk_idx=risk, + risk_idx=risk, figsize=(6, 4), - output_file=f"results/plots/nested_cv_feature_importance_risk_{risk}_{args.dataset}.png" + output_file=f"results/plots/nested_cv_feature_importance_risk_{risk}_{args.dataset}.png", ) - + # Get top features for shape function plots top_features = top_positive + top_negative - + # Generate shape function plots for top features + print("Ananya: Generating shape function plots for top features...") + print(f'{feature_names}') + print(f'{top_features}') fig, _ = plot_coxnam_shape_functions( model=final_model, - X=x_plot, + X=x_plot, risk_to_plot=risk, - feature_names=feature_names, - top_features=top_features, + feature_names=feature_names, + top_features=top_features, ncols=5, figsize=(12, 6), - output_file=f"results/plots/nested_cv_shape_functions_risk_{risk}_{args.dataset}.png" + output_file=f"results/plots/nested_cv_shape_functions_risk_{risk}_{args.dataset}.png", ) plt.close(fig) - - print(f"Shape function plots saved to results/plots/ directory") - + + print("Shape function plots saved to results/plots/ directory") + # Save metrics to files print("\n=== Saving Metrics to Files ===") - + # Save summary metrics (formatted table) summary_filename = f"nested_cv_summary_metrics_{args.dataset}.csv" summary_df.to_csv(summary_filename, index=False) print(f"Summary metrics saved to: {summary_filename}") - + # Save detailed metrics (all individual fold results) detailed_filename = f"nested_cv_detailed_metrics_{args.dataset}.csv" detailed_df.to_csv(detailed_filename, index=False) print(f"Detailed metrics saved to: {detailed_filename}") - + # Save to Excel with multiple sheets excel_filename = f"nested_cv_metrics_{args.dataset}.xlsx" - with pd.ExcelWriter(excel_filename, engine='openpyxl') as writer: - summary_df.to_excel(writer, sheet_name='Summary', index=False) - detailed_df.to_excel(writer, sheet_name='Detailed', index=False) - + with pd.ExcelWriter(excel_filename, engine="openpyxl") as writer: + summary_df.to_excel(writer, sheet_name="Summary", index=False) + detailed_df.to_excel(writer, sheet_name="Detailed", index=False) + # Create a metadata sheet - metadata_df = pd.DataFrame([ - {'Parameter': 'Dataset', 'Value': args.dataset}, - {'Parameter': 'Outer Folds', 'Value': args.outer_folds}, - {'Parameter': 'Inner Folds', 'Value': args.inner_folds}, - {'Parameter': 'Number of Trials', 'Value': args.n_trials}, - {'Parameter': 'Number of Epochs', 'Value': args.num_epochs}, - {'Parameter': 'Batch Size', 'Value': args.batch_size}, - {'Parameter': 'Event Weighting', 'Value': args.event_weighting}, - {'Parameter': 'Scaling', 'Value': args.scaling}, - {'Parameter': 'Random Seed', 'Value': args.seed} - ]) - metadata_df.to_excel(writer, sheet_name='Metadata', index=False) - + metadata_df = pd.DataFrame( + [ + {"Parameter": "Dataset", "Value": args.dataset}, + {"Parameter": "Outer Folds", "Value": args.outer_folds}, + {"Parameter": "Inner Folds", "Value": args.inner_folds}, + {"Parameter": "Number of Trials", "Value": args.n_trials}, + {"Parameter": "Number of Epochs", "Value": args.num_epochs}, + {"Parameter": "Batch Size", "Value": args.batch_size}, + {"Parameter": "Event Weighting", "Value": args.event_weighting}, + {"Parameter": "Scaling", "Value": args.scaling}, + {"Parameter": "Random Seed", "Value": args.seed}, + ] + ) + metadata_df.to_excel(writer, sheet_name="Metadata", index=False) + print(f"Excel file with multiple sheets saved to: {excel_filename}") - + # Save raw metrics dictionary as JSON for complete reproducibility import json - + # Convert numpy arrays to lists for JSON serialization serializable_metrics = {} for key, values in all_metrics.items(): - serializable_metrics[key] = [float(v) if not np.isnan(v) else None for v in values] - + serializable_metrics[key] = [ + float(v) if not np.isnan(v) else None for v in values + ] + json_filename = f"nested_cv_raw_metrics_{args.dataset}.json" - with open(json_filename, 'w') as f: - json.dump({ - 'metrics': serializable_metrics, - 'metadata': { - 'dataset': args.dataset, - 'outer_folds': args.outer_folds, - 'inner_folds': args.inner_folds, - 'n_trials': args.n_trials, - 'num_epochs': args.num_epochs, - 'batch_size': args.batch_size, - 'event_weighting': args.event_weighting, - 'scaling': args.scaling, - 'seed': args.seed - } - }, f, indent=2) - + with open(json_filename, "w") as f: + json.dump( + { + "metrics": serializable_metrics, + "metadata": { + "dataset": args.dataset, + "outer_folds": args.outer_folds, + "inner_folds": args.inner_folds, + "n_trials": args.n_trials, + "num_epochs": args.num_epochs, + "batch_size": args.batch_size, + "event_weighting": args.event_weighting, + "scaling": args.scaling, + "seed": args.seed, + }, + }, + f, + indent=2, + ) + print(f"Raw metrics (JSON) saved to: {json_filename}") - + # Save aggregated best parameters param_summary = {} for param_name in all_best_params[0].keys(): param_values = [params[param_name] for params in all_best_params] - + # Check if all values are numeric if all(isinstance(val, (int, float)) for val in param_values): param_summary[param_name] = np.mean(param_values) else: # For categorical parameters, take the most common param_summary[param_name] = max(set(param_values), key=param_values.count) - - print(f"\nAggregated best hyperparameters across all outer folds:") + + print("\nAggregated best hyperparameters across all outer folds:") for param, value in param_summary.items(): print(f"{param}: {value}") - + # Save to YAML config_dict = { "dataset": args.dataset, "scaling": args.scaling, "num_epochs": args.num_epochs, "batch_size": args.batch_size, - "learning_rate": param_summary['learning_rate'], - "l2_reg": param_summary['l2_reg'], + "learning_rate": param_summary["learning_rate"], + "l2_reg": param_summary["l2_reg"], "patience": args.patience, - "dropout_rate": param_summary['dropout_rate'], - "feature_dropout": param_summary['feature_dropout'], - "hidden_dimensions": ",".join(map(str, [param_summary[f'hidden_dim_{i}'] for i in range(int(param_summary['n_layers']))])), - "batch_norm": str(param_summary['batch_norm']), + "dropout_rate": param_summary["dropout_rate"], + "feature_dropout": param_summary["feature_dropout"], + "hidden_dimensions": ",".join( + map( + str, + [ + param_summary[f"hidden_dim_{i}"] + for i in range(int(param_summary["n_layers"])) + ], + ) + ), + "batch_norm": str(param_summary["batch_norm"]), "event_weighting": args.event_weighting, "seed": args.seed, - "n_folds": args.outer_folds + "n_folds": args.outer_folds, } - + output_file = f"nested_cv_best_params_{args.dataset}.yaml" - with open(output_file, 'w') as file: + with open(output_file, "w") as file: yaml.dump(config_dict, file, default_flow_style=False) - + print(f"\nBest hyperparameters saved to {output_file}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/training_scripts/tune_optuna.py b/training_scripts/tune_optuna.py index 3711de49..b6494f4d 100644 --- a/training_scripts/tune_optuna.py +++ b/training_scripts/tune_optuna.py @@ -1,74 +1,108 @@ -import yaml -import configargparse from collections import defaultdict +import configargparse +import numpy as np import optuna import torch -import numpy as np -import torch.optim as optim -from torch.utils.data import DataLoader +import yaml +from model_utils import EarlyStopping, set_seed from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler, MinMaxScaler +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sksurv.metrics import brier_score, concordance_index_ipcw, cumulative_dynamic_auc from sksurv.util import Surv -from sksurv.metrics import concordance_index_ipcw, cumulative_dynamic_auc, brier_score +from torch import optim +from torch.utils.data import DataLoader from crisp_nam.models import CrispNamModel from crisp_nam.utils import ( - weighted_negative_log_likelihood_loss, + compute_baseline_cif, + compute_l2_penalty, negative_log_likelihood_loss, - compute_l2_penalty + predict_absolute_risk, + weighted_negative_log_likelihood_loss, ) from data_utils import * -from model_utils import EarlyStopping, set_seed -from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif + def parse_args(): parser = configargparse.ArgumentParser( description="Training script for MultiTaskCoxNAM model with Optuna", default_config_files=["config.yaml"], - config_file_parser_class=configargparse.YAMLConfigFileParser + config_file_parser_class=configargparse.YAMLConfigFileParser, ) - - + # Dataset - parser.add_argument("--dataset", type=str, default="framingham", choices=["framingham", "support", "pbc", "synthetic"], - help="Dataset to use") - + parser.add_argument( + "--dataset", + type=str, + default="framingham", + choices=["framingham", "support", "pbc", "synthetic"], + help="Dataset to use", + ) + # Data scaling - parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"], - help="Data scaling method for continuous features") - + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + help="Data scaling method for continuous features", + ) + # Training parameters - parser.add_argument("--num_epochs", type=int, default=500, - help="Number of training epochs") - parser.add_argument("--batch_size", type=int, default=256, - help="Batch size for training") - parser.add_argument("--patience", type=int, default=10, - help="Patience for early stopping") - + parser.add_argument( + "--num_epochs", type=int, default=500, help="Number of training epochs" + ) + parser.add_argument( + "--batch_size", type=int, default=256, help="Batch size for training" + ) + parser.add_argument( + "--patience", type=int, default=10, help="Patience for early stopping" + ) + # Optuna parameters - parser.add_argument("--n_trials", type=int, default=50, - help="Number of Optuna trials") - + parser.add_argument( + "--n_trials", type=int, default=50, help="Number of Optuna trials" + ) + # Weight parameters - parser.add_argument("--event_weighting", type=str, default="none", - choices=["none", "balanced", "custom"], - help="Event weighting strategy (none, balanced, custom)") - parser.add_argument("--custom_event_weights", type=str, default=None, - help="Custom weights for events (comma-separated, e.g., '1.0,2.0')") - + parser.add_argument( + "--event_weighting", + type=str, + default="none", + choices=["none", "balanced", "custom"], + help="Event weighting strategy (none, balanced, custom)", + ) + parser.add_argument( + "--custom_event_weights", + type=str, + default=None, + help="Custom weights for events (comma-separated, e.g., '1.0,2.0')", + ) + # Other parameters - parser.add_argument("--seed", type=int, default=42, - help="Random seed for reproducibility") - + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + return parser.parse_args() -def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_rate=1e-3, - l2_reg=0.01, patience=10, event_weights=None, verbose=True): + +def train_model( + model, + train_loader, + val_loader=None, + num_epochs=500, + learning_rate=1e-3, + l2_reg=0.01, + patience=10, + event_weights=None, + verbose=True, +): optimizer = optim.AdamW(model.parameters(), lr=learning_rate) early_stopper = EarlyStopping(patience=patience) device = next(model.parameters()).device - best_val_loss = float('inf') + best_val_loss = float("inf") for epoch in range(num_epochs): model.train() @@ -79,12 +113,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r # Use weighted loss if event_weights is provided if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg total = loss + reg @@ -102,25 +142,32 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r for x, t, e, _ in val_loader: x, t, e = x.to(device), t.to(device), e.to(device) risk_scores, _ = model(x) - + # Use same loss function as in training if event_weights is not None: - loss = weighted_negative_log_likelihood_loss(risk_scores, t, e, - model.num_competing_risks, - event_weights=event_weights) + loss = weighted_negative_log_likelihood_loss( + risk_scores, + t, + e, + model.num_competing_risks, + event_weights=event_weights, + ) else: - loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks) - + loss = negative_log_likelihood_loss( + risk_scores, t, e, model.num_competing_risks + ) + reg = compute_l2_penalty(model) * l2_reg val_loss += (loss + reg).item() avg_val_loss = val_loss / len(val_loader) current_best = early_stopper.step(avg_val_loss) - - if avg_val_loss < best_val_loss: - best_val_loss = avg_val_loss + + best_val_loss = min(best_val_loss, avg_val_loss) if verbose: - print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}") + print( + f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}" + ) if early_stopper.should_stop: if verbose: @@ -128,7 +175,7 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r break elif verbose: print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f}") - + return best_val_loss @@ -151,17 +198,14 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time # ---- AUC ---- try: auc_score, _ = cumulative_dynamic_auc( - survival_train, - survival_val, - risk_preds, - times=[time] + survival_train, survival_val, risk_preds, times=[time] ) - metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score[0])) + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score[0])) avg_auc += float(auc_score[0]) count += 1 except Exception as ex: - print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan) # ---- Brier Score ---- try: @@ -169,153 +213,176 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time surv_probs = surv_probs.reshape(-1, 1) # Shape: (n_samples, n_times) _, brier_scores = brier_score( - survival_train, - survival_val, - surv_probs, - times=np.array([time]) + survival_train, survival_val, surv_probs, times=np.array([time]) + ) + metrics[f"brier_event{k + 1}_t{time:.2f}"].append( + float(brier_scores[0]) ) - metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_scores[0])) except Exception as ex: - print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan) + print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan) # ---- Time-dependent Concordance Index ---- try: tdci_result = concordance_index_ipcw( - survival_train, - survival_val, - estimate=risk_preds, - tau=time + survival_train, survival_val, estimate=risk_preds, tau=time ) if isinstance(tdci_result, tuple): tdci_score = tdci_result[0] else: tdci_score = tdci_result # fallback - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score)) + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score)) except Exception as ex: - print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}") - metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan) - + print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}") + metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan) + # Calculate average AUC if we have valid measurements if count > 0: avg_auc = avg_auc / count else: avg_auc = 0.0 - + return metrics, avg_auc def display_metrics_table(metrics_dict, quantiles=[0.25, 0.5, 0.75]): import pandas as pd from tabulate import tabulate - + time_points_by_metric = {} event_types = set() - metric_types = ['auc', 'tdci', 'brier'] - + metric_types = ["auc", "tdci", "brier"] + for key, values in metrics_dict.items(): if any(metric in key for metric in metric_types): - parts = key.split('_') + parts = key.split("_") metric_type = parts[0] event_info = parts[1] - + # Extract event type - event_type = int(event_info.replace('event', '')) + event_type = int(event_info.replace("event", "")) event_types.add(event_type) - + # Extract time point - if len(parts) > 2 and parts[2].startswith('t'): - time_point = float(parts[2].replace('t', '')) - + if len(parts) > 2 and parts[2].startswith("t"): + time_point = float(parts[2].replace("t", "")) + # Initialize nested dictionaries if needed if (event_type, metric_type) not in time_points_by_metric: time_points_by_metric[(event_type, metric_type)] = {} - + # Store metrics by time point time_points_by_metric[(event_type, metric_type)][time_point] = values - + results = [] - + for event_type in sorted(event_types): - row = {'Risk': f"Type {event_type}"} - + row = {"Risk": f"Type {event_type}"} + for metric_type in metric_types: if (event_type, metric_type) not in time_points_by_metric: # Skip metrics that don't exist for this event type for q in quantiles: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" continue - + # Get all time points for this event/metric time_data = time_points_by_metric[(event_type, metric_type)] sorted_times = sorted(time_data.keys()) - + # For each quantile for q in quantiles: # Calculate the index for this quantile q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q))) - + # Get the corresponding time point for this quantile q_time = sorted_times[q_idx] - + # Get the metrics for this time point q_values = time_data[q_time] - + # Calculate and format statistics if q_values: value_array = np.array(q_values) mean_val = np.nanmean(value_array) std_val = np.nanstd(value_array) - row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}" + row[f"{metric_type.upper()}_q{q:.2f}"] = ( + f"{mean_val:.3f} ± {std_val:.3f}" + ) else: row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A" - + results.append(row) - + df = pd.DataFrame(results) - + # Define column order - columns = ['Risk'] - for metric in ['AUC', 'TDCI', 'BRIER']: + columns = ["Risk"] + for metric in ["AUC", "TDCI", "BRIER"]: for q in quantiles: columns.append(f"{metric}_q{q:.2f}") - + # Select columns in the right order (only those that exist) df = df[[col for col in columns if col in df.columns]] - + print("\nSummary Performance Metrics:") - print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False)) - + print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False)) + print("\nInterpretation:") print("- AUC: 0.5=random, >0.7=good, >0.8=excellent") - print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") + print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent") print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor") -def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, feature_names, n_cont, - num_competing_risks, device, args, event_weights=None): - +def objective( + trial, + x, + t, + e, + x_train, + t_train, + e_train, + x_val, + t_val, + e_val, + feature_names, + n_cont, + num_competing_risks, + device, + args, + event_weights=None, +): # Define hyperparameters to optimize - learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True) - l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True) - dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.8) - feature_dropout = trial.suggest_float('feature_dropout', 0.0, 0.5) - + learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True) + l2_reg = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True) + dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.8) + feature_dropout = trial.suggest_float("feature_dropout", 0.0, 0.5) + # For hidden_dimensions - n_layers = trial.suggest_int('n_layers', 1, 3) + n_layers = trial.suggest_int("n_layers", 1, 3) hidden_dimensions = [] for i in range(n_layers): - hidden_dimensions.append(trial.suggest_categorical(f'hidden_dim_{i}', [32, 64, 128, 256])) - - batch_norm = trial.suggest_categorical('batch_norm', [True, False]) - + hidden_dimensions.append( + trial.suggest_categorical(f"hidden_dim_{i}", [32, 64, 128, 256]) + ) + + batch_norm = trial.suggest_categorical("batch_norm", [True, False]) + # Create data loaders train_dataset = SurvivalDataset(x_train, t_train, e_train) val_dataset = SurvivalDataset(x_val, t_val, e_val) - - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=args.batch_size,num_workers=4, pin_memory=True) - + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=4, + pin_memory=True, + shuffle=True, + ) + val_loader = DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True + ) + # Initialize model with the hyperparameters model = CrispNamModel( num_features=x.shape[1], @@ -323,40 +390,40 @@ def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, fe hidden_sizes=hidden_dimensions, dropout_rate=dropout_rate, feature_dropout=feature_dropout, - batch_norm=batch_norm + batch_norm=batch_norm, ).to(device) - + # Train model best_val_loss = train_model( - model, - train_loader, - val_loader, - num_epochs=args.num_epochs, + model, + train_loader, + val_loader, + num_epochs=args.num_epochs, learning_rate=learning_rate, - l2_reg=l2_reg, + l2_reg=l2_reg, patience=args.patience, event_weights=event_weights, - verbose=False # Set to False to reduce output during Optuna trials + verbose=False, # Set to False to reduce output during Optuna trials ) - + # We can also evaluate with metrics like AUC # If needed, uncomment this code to optimize for AUC instead of validation loss """ # Calculate evaluation times safe_max = 0.99 * np.max(t_train) eval_times = np.quantile(t_train[t_train <= safe_max], [0.25, 0.5, 0.75]) - + # Calculate baseline CIFs - baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) + baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) for k in range(num_competing_risks)} - + # Predict and evaluate abs_risks = predict_absolute_risk(model, x_val, baseline_cifs, eval_times, device=device) _, avg_auc = evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times) - + return avg_auc # Maximize AUC """ - + # For now, we'll optimize for validation loss (minimize) return best_val_loss # Minimize validation loss @@ -364,10 +431,10 @@ def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, fe def main(): args = parse_args() print(args) - + # Set random seed for reproducibility set_seed(args.seed) - + # Load the dataset if args.dataset.lower() == "framingham": x, t, e, feature_names, n_cont, _ = load_framingham() @@ -379,21 +446,18 @@ def main(): x, t, e, feature_names, n_cont, _ = load_synthetic_dataset() else: raise ValueError(f"Dataset {args.dataset} not supported") - + # Note: Scaling will be done after train/validation split to prevent data leakage # Compute number of competing risks num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0) - device = ("cuda" if torch.cuda.is_available() else "cpu") - - + device = "cuda" if torch.cuda.is_available() else "cpu" - # Split data into train and validation sets x_train, x_val, t_train, t_val, e_train, e_val = train_test_split( x, t, e, test_size=0.2, random_state=args.seed, stratify=e ) - + # Apply scaling after split to prevent data leakage if args.scaling.lower() == "standard": scaler = StandardScaler() @@ -407,148 +471,183 @@ def main(): pass else: raise ValueError(f"Scaling method {args.scaling} not supported") - + # Initialize event weights event_weights = None - + if args.event_weighting != "none": if args.event_weighting == "balanced": # Compute balanced weights (inverse of class frequencies) event_counts = np.zeros(num_competing_risks) for k in range(1, num_competing_risks + 1): - event_counts[k-1] = np.sum(e_train == k) - + event_counts[k - 1] = np.sum(e_train == k) + # Avoid division by zero event_counts = np.maximum(event_counts, 1) - + # Inverse frequency weighting event_weights = 1.0 / event_counts - + # Normalize weights to sum to num_competing_risks event_weights = event_weights * (num_competing_risks / event_weights.sum()) - + print(f"Computed balanced event weights: {event_weights}") - + elif args.event_weighting == "custom": if args.custom_event_weights is None: - raise ValueError("Custom event weights must be provided when using custom weighting") - + raise ValueError( + "Custom event weights must be provided when using custom weighting" + ) + custom_weights = [float(w) for w in args.custom_event_weights.split(",")] if len(custom_weights) != num_competing_risks: - raise ValueError(f"Expected {num_competing_risks} weights, got {len(custom_weights)}") - + raise ValueError( + f"Expected {num_competing_risks} weights, got {len(custom_weights)}" + ) + event_weights = np.array(custom_weights) print(f"Using custom event weights: {event_weights}") - + # Convert to torch tensor event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device) - + # Create Optuna study study = optuna.create_study(direction="minimize") # Minimize validation loss - + # Run optimization print("\n=== Starting Optuna hyperparameter optimization ===") - study.optimize(lambda trial: objective( - trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, - feature_names, n_cont, num_competing_risks, device, args, event_weights - ), n_trials=args.n_trials) - + study.optimize( + lambda trial: objective( + trial, + x, + t, + e, + x_train, + t_train, + e_train, + x_val, + t_val, + e_val, + feature_names, + n_cont, + num_competing_risks, + device, + args, + event_weights, + ), + n_trials=args.n_trials, + ) + # Get best parameters best_params = study.best_params print("\n=== Best Hyperparameters ===") for param, value in best_params.items(): print(f"{param}: {value}") - # Extract hidden dimensions from best params - n_layers = best_params.pop('n_layers') + n_layers = best_params.pop("n_layers") hidden_dimensions = [] for i in range(n_layers): - hidden_dimensions.append(best_params.pop(f'hidden_dim_{i}')) - + hidden_dimensions.append(best_params.pop(f"hidden_dim_{i}")) + # Save best hyperparameters to YAML file config_dict = { "dataset": args.dataset, "scaling": args.scaling, "num_epochs": args.num_epochs, "batch_size": args.batch_size, - "learning_rate": best_params['learning_rate'], - "l2_reg": best_params['l2_reg'], + "learning_rate": best_params["learning_rate"], + "l2_reg": best_params["l2_reg"], "patience": args.patience, - "dropout_rate": best_params['dropout_rate'], - "feature_dropout": best_params['feature_dropout'], - "hidden_dimensions": ",".join(map(str, hidden_dimensions)), # Format as comma-separated string - "batch_norm": str(best_params['batch_norm']), # Convert to string "True" or "False" + "dropout_rate": best_params["dropout_rate"], + "feature_dropout": best_params["feature_dropout"], + "hidden_dimensions": ",".join( + map(str, hidden_dimensions) + ), # Format as comma-separated string + "batch_norm": str( + best_params["batch_norm"] + ), # Convert to string "True" or "False" "event_weighting": args.event_weighting, "seed": args.seed, - "n_folds": 5 # Default for k-fold script + "n_folds": 5, # Default for k-fold script } - + # Add custom event weights if applicable if args.event_weighting == "custom" and args.custom_event_weights is not None: config_dict["custom_event_weights"] = args.custom_event_weights - + # Save to YAML file output_file = f"best_params_{args.dataset}.yaml" - with open(output_file, 'w') as file: + with open(output_file, "w") as file: yaml.dump(config_dict, file, default_flow_style=False) - + print(f"\nBest hyperparameters saved to {output_file}") - + # Train final model with best parameters print("\n=== Training final model with best parameters ===") - + # Create data loaders with all training data train_dataset = SurvivalDataset(x_train, t_train, e_train) val_dataset = SurvivalDataset(x_val, t_val, e_val) - - - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=args.batch_size,num_workers=4, pin_memory=True) - - + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=4, + pin_memory=True, + shuffle=True, + ) + val_loader = DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True + ) + # Initialize final model with best parameters final_model = CrispNamModel( num_features=x.shape[1], num_competing_risks=num_competing_risks, hidden_sizes=hidden_dimensions, - dropout_rate=best_params['dropout_rate'], - feature_dropout=best_params['feature_dropout'], - batch_norm=best_params['batch_norm'] + dropout_rate=best_params["dropout_rate"], + feature_dropout=best_params["feature_dropout"], + batch_norm=best_params["batch_norm"], ).to(device) - + # Train final model train_model( - final_model, - train_loader, - val_loader, - num_epochs=args.num_epochs, - learning_rate=best_params['learning_rate'], - l2_reg=best_params['l2_reg'], + final_model, + train_loader, + val_loader, + num_epochs=args.num_epochs, + learning_rate=best_params["learning_rate"], + l2_reg=best_params["l2_reg"], patience=args.patience, event_weights=event_weights, - verbose=True + verbose=True, ) - + # Evaluate final model # Calculate evaluation times safe_max = 0.99 * np.max(t_train) eval_times = np.quantile(t_train[t_train <= safe_max], [0.25, 0.5, 0.75]) - + print(f"Evaluation times: {eval_times}") - + # Calculate baseline CIFs - baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) - for k in range(num_competing_risks)} - + baseline_cifs = { + k: compute_baseline_cif(t_train, e_train, eval_times, k + 1) + for k in range(num_competing_risks) + } + # Predict and evaluate - abs_risks = predict_absolute_risk(final_model, x_val, baseline_cifs, eval_times, device=device) - final_metrics, _ = evaluate_model(final_model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times) - + abs_risks = predict_absolute_risk( + final_model, x_val, baseline_cifs, eval_times, device=device + ) + final_metrics, _ = evaluate_model( + final_model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times + ) + # Display metrics display_metrics_table(final_metrics) - - + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/training_scripts/tune_optuna_optimized.py b/training_scripts/tune_optuna_optimized.py index eb226813..7daabf4c 100644 --- a/training_scripts/tune_optuna_optimized.py +++ b/training_scripts/tune_optuna_optimized.py @@ -1,35 +1,29 @@ -import yaml -import configargparse from collections import defaultdict -import torch -import optuna +import configargparse import numpy as np -import torch.nn as nn +import optuna +import torch +import yaml +from model_utils import EarlyStopping, set_seed +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sksurv.metrics import brier_score, concordance_index_ipcw, cumulative_dynamic_auc from sksurv.util import Surv -import torch.optim as optim from tabulate import tabulate +from torch import nn, optim from torch.utils.data import DataLoader, TensorDataset -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler, MinMaxScaler -from sksurv.metrics import ( - cumulative_dynamic_auc, - brier_score, - concordance_index_ipcw -) -from data_utils import * -from model_utils import EarlyStopping, set_seed from crisp_nam.models import CrispNamModel from crisp_nam.utils import ( - weighted_negative_log_likelihood_loss, - negative_log_likelihood_loss, + compute_baseline_cif, compute_l2_penalty, -) -from crisp_nam.utils import ( + negative_log_likelihood_loss, predict_absolute_risk, - compute_baseline_cif + weighted_negative_log_likelihood_loss, ) +from data_utils import * + def train_model( model: nn.Module, @@ -45,7 +39,7 @@ def train_model( device = next(model.parameters()).device optimizer = optim.AdamW(model.parameters(), lr=lr) early_stop = EarlyStopping(patience) - scaler = torch.amp.GradScaler('cuda') + scaler = torch.amp.GradScaler("cuda") for epoch in range(1, num_epochs + 1): model.train() @@ -56,14 +50,16 @@ def train_model( eb = eb.to(device, non_blocking=True) optimizer.zero_grad() - with torch.amp.autocast('cuda'): + with torch.amp.autocast("cuda"): scores, _ = model(xb) if event_weights is not None: loss = weighted_negative_log_likelihood_loss( - scores, tb, eb, model.num_competing_risks, event_weights) + scores, tb, eb, model.num_competing_risks, event_weights + ) else: loss = negative_log_likelihood_loss( - scores, tb, eb, model.num_competing_risks) + scores, tb, eb, model.num_competing_risks + ) loss = loss + compute_l2_penalty(model) * l2_reg scaler.scale(loss).backward() @@ -81,26 +77,30 @@ def train_model( xb = xb.to(device, non_blocking=True) tb = tb.to(device, non_blocking=True) eb = eb.to(device, non_blocking=True) - with torch.amp.autocast('cuda'): + with torch.amp.autocast("cuda"): scores, _ = model(xb) if event_weights is not None: loss = weighted_negative_log_likelihood_loss( - scores, tb, eb, model.num_competing_risks, event_weights) + scores, tb, eb, model.num_competing_risks, event_weights + ) else: loss = negative_log_likelihood_loss( - scores, tb, eb, model.num_competing_risks) + scores, tb, eb, model.num_competing_risks + ) loss = loss + compute_l2_penalty(model) * l2_reg val_loss += loss.item() avg_val_loss = val_loss / len(val_loader) if verbose: - print(f"Epoch {epoch} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}") + print( + f"Epoch {epoch} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}" + ) if early_stop.step(avg_val_loss): - if verbose: print("Early stopping.") + if verbose: + print("Early stopping.") return avg_val_loss - else: - if verbose: - print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}") + elif verbose: + print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}") return avg_val_loss if val_loader is not None else avg_train_loss @@ -115,15 +115,16 @@ def objective( args, event_weights, ): - - lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True) - l2 = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True) - dropout = trial.suggest_float('dropout_rate', 0.0, 0.8) - feat_drop = trial.suggest_float('feature_dropout', 0.0, 0.5) - n_layers = trial.suggest_int('n_layers', 1, 3) - hidden_dims = [trial.suggest_categorical(f'hidden_dim_{i}', [32,64,128,256]) - for i in range(n_layers)] - batch_norm = trial.suggest_categorical('batch_norm', [True, False]) + lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True) + l2 = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True) + dropout = trial.suggest_float("dropout_rate", 0.0, 0.8) + feat_drop = trial.suggest_float("feature_dropout", 0.0, 0.5) + n_layers = trial.suggest_int("n_layers", 1, 3) + hidden_dims = [ + trial.suggest_categorical(f"hidden_dim_{i}", [32, 64, 128, 256]) + for i in range(n_layers) + ] + batch_norm = trial.suggest_categorical("batch_norm", [True, False]) loader_kwargs = dict( batch_size=args.batch_size, @@ -134,7 +135,7 @@ def objective( ) train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs) - val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs) + val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs) model = CrispNamModel( num_features=num_features, @@ -171,12 +172,12 @@ def evaluate_model( device: torch.device, ): surv_train = Surv.from_arrays(e_train != 0, t_train) - surv_val = Surv.from_arrays(e_val != 0, t_val) + surv_val = Surv.from_arrays(e_val != 0, t_val) abs_risk = predict_absolute_risk(model, X, baseline_cifs, eval_times, device=device) metrics = defaultdict(list) for k_idx, t0 in enumerate(eval_times, start=1): - preds = abs_risk[:, k_idx-1, :] + preds = abs_risk[:, k_idx - 1, :] for ti_idx, ti in enumerate(eval_times): r = preds[:, ti_idx] try: @@ -185,7 +186,9 @@ def evaluate_model( except: metrics[f"auc_event{k_idx}_t{ti:.2f}"].append(np.nan) try: - _, bsc = brier_score(surv_train, surv_val, 1 - r.reshape(-1,1), times=np.array([ti])) + _, bsc = brier_score( + surv_train, surv_val, 1 - r.reshape(-1, 1), times=np.array([ti]) + ) metrics[f"brier_event{k_idx}_t{ti:.2f}"].append(float(bsc[0])) except: metrics[f"brier_event{k_idx}_t{ti:.2f}"].append(np.nan) @@ -204,21 +207,34 @@ def parse_args(): default_config_files=["config.yml"], config_file_parser_class=configargparse.YAMLConfigFileParser, ) - parser.add_argument("--dataset", type=str, default="framingham", - choices=["framingham","support","pbc","synthetic"] ) - parser.add_argument("--scaling", type=str, default="standard", - choices=["minmax","standard","none"] ) + parser.add_argument( + "--dataset", + type=str, + default="framingham", + choices=["framingham", "support", "pbc", "synthetic"], + ) + parser.add_argument( + "--scaling", + type=str, + default="standard", + choices=["minmax", "standard", "none"], + ) parser.add_argument("--num_epochs", type=int, default=500) parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--patience", type=int, default=10) parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--n_trials", type=int, default=50) - parser.add_argument("--event_weighting", type=str, default="none", - choices=["none","balanced","custom"] ) + parser.add_argument( + "--event_weighting", + type=str, + default="none", + choices=["none", "balanced", "custom"], + ) parser.add_argument("--custom_event_weights", type=str, default=None) parser.add_argument("--seed", type=int, default=42) return parser.parse_args() + # ----------------------- Main ----------------------- # def main(): args = parse_args() @@ -226,74 +242,81 @@ def main(): set_seed(args.seed) # Load dataset - if args.dataset == 'framingham': + if args.dataset == "framingham": x, t, e, feature_names, n_cont, _ = load_framingham() - elif args.dataset == 'support': + elif args.dataset == "support": x, t, e, feature_names, n_cont, _ = load_support_dataset() - elif args.dataset == 'pbc': + elif args.dataset == "pbc": x, t, e, feature_names, n_cont, _ = load_pbc2_dataset() else: x, t, e, feature_names, n_cont, _ = load_synthetic_dataset() # Scale continuous features - if args.scaling == 'standard': + if args.scaling == "standard": x[:, -n_cont:] = StandardScaler().fit_transform(x[:, -n_cont:]) - elif args.scaling == 'minmax': + elif args.scaling == "minmax": x[:, -n_cont:] = MinMaxScaler().fit_transform(x[:, -n_cont:]) # Bulk convert to torch tensors on CPU - X = torch.from_numpy(x.astype('float32')) - T = torch.from_numpy(t.astype('float32')) - E = torch.from_numpy(e.astype('int64')) + X = torch.from_numpy(x.astype("float32")) + T = torch.from_numpy(t.astype("float32")) + E = torch.from_numpy(e.astype("int64")) # Train/validation split (fixed) idx_train, idx_val = train_test_split( np.arange(len(e)), test_size=0.2, random_state=args.seed, stratify=e ) train_ds = TensorDataset(X[idx_train], T[idx_train], E[idx_train]) - val_ds = TensorDataset(X[idx_val], T[idx_val], E[idx_val]) + val_ds = TensorDataset(X[idx_val], T[idx_val], E[idx_val]) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_competing_risks = len(np.unique(e)) - 1 # Event weighting event_weights = None - if args.event_weighting == 'balanced': - counts = np.array([(e[idx_train]==k).sum() for k in range(1, num_competing_risks+1)]) + if args.event_weighting == "balanced": + counts = np.array( + [(e[idx_train] == k).sum() for k in range(1, num_competing_risks + 1)] + ) counts = np.maximum(counts, 1) w = 1.0 / counts w *= num_competing_risks / w.sum() - event_weights = torch.from_numpy(w.astype('float32')).to(device) - elif args.event_weighting == 'custom': - w = np.array(list(map(float, args.custom_event_weights.split(',')))) - event_weights = torch.from_numpy(w.astype('float32')).to(device) + event_weights = torch.from_numpy(w.astype("float32")).to(device) + elif args.event_weighting == "custom": + w = np.array(list(map(float, args.custom_event_weights.split(",")))) + event_weights = torch.from_numpy(w.astype("float32")).to(device) # Optuna study - study = optuna.create_study(direction='minimize') + study = optuna.create_study(direction="minimize") study.optimize( lambda trial: objective( - trial, train_ds, val_ds, - x.shape[1], num_competing_risks, - device, args, event_weights + trial, + train_ds, + val_ds, + x.shape[1], + num_competing_risks, + device, + args, + event_weights, ), - n_trials=args.n_trials + n_trials=args.n_trials, ) best = study.best_params print("Best hyperparameters:", best) - with open(f"best_params_{args.dataset}.yaml", 'w') as f: + with open(f"best_params_{args.dataset}.yaml", "w") as f: yaml.dump(best, f) # Final model training with best params - n_layers = best['n_layers'] - hidden_dims = [best[f'hidden_dim_{i}'] for i in range(n_layers)] + n_layers = best["n_layers"] + hidden_dims = [best[f"hidden_dim_{i}"] for i in range(n_layers)] final_model = CrispNamModel( num_features=x.shape[1], num_competing_risks=num_competing_risks, hidden_sizes=hidden_dims, - dropout_rate=best['dropout_rate'], - feature_dropout=best['feature_dropout'], - batch_norm=best['batch_norm'], + dropout_rate=best["dropout_rate"], + feature_dropout=best["feature_dropout"], + batch_norm=best["batch_norm"], ).to(device) loader_kwargs = dict( @@ -304,15 +327,15 @@ def main(): prefetch_factor=2, ) final_train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs) - final_val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs) + final_val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs) train_model( final_model, final_train_loader, val_loader=final_val_loader, num_epochs=args.num_epochs, - lr=best['learning_rate'], - l2_reg=best['l2_reg'], + lr=best["learning_rate"], + l2_reg=best["l2_reg"], patience=args.patience, event_weights=event_weights, verbose=True, @@ -322,29 +345,36 @@ def main(): safe_max = 0.99 * np.max(t[idx_train]) eval_times = np.quantile(t[idx_train][t[idx_train] <= safe_max], [0.25, 0.5, 0.75]) baseline_cifs = { - k: compute_baseline_cif( - t[idx_train], e[idx_train], eval_times, k+1 - ) for k in range(num_competing_risks) + k: compute_baseline_cif(t[idx_train], e[idx_train], eval_times, k + 1) + for k in range(num_competing_risks) } metrics = evaluate_model( final_model, - x[idx_val], t[idx_val], e[idx_val], - t[idx_train], e[idx_train], - baseline_cifs, eval_times, device + x[idx_val], + t[idx_val], + e[idx_val], + t[idx_train], + e[idx_train], + baseline_cifs, + eval_times, + device, ) # Reporting rows = [] - for k in range(1, num_competing_risks+1): - row = {'Risk': f'Type {k}'} - for metric in ['auc', 'tdci', 'brier']: + for k in range(1, num_competing_risks + 1): + row = {"Risk": f"Type {k}"} + for metric in ["auc", "tdci", "brier"]: for q, ti in zip([0.25, 0.5, 0.75], eval_times): key = f"{metric}_event{k}_t{ti:.2f}" vals = np.array(metrics[key], dtype=float) - row[f"{metric.upper()}_q{q:.2f}"] = f"{np.nanmean(vals):.3f} ± {np.nanstd(vals):.3f}" + row[f"{metric.upper()}_q{q:.2f}"] = ( + f"{np.nanmean(vals):.3f} ± {np.nanstd(vals):.3f}" + ) rows.append(row) print("\n=== Performance ===") - print(tabulate(rows, headers='keys', tablefmt='pretty', showindex=False)) + print(tabulate(rows, headers="keys", tablefmt="pretty", showindex=False)) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/utils/loss.py b/utils/loss.py deleted file mode 100644 index 9cae44fd..00000000 --- a/utils/loss.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch - -def weighted_negative_log_likelihood_loss(risk_scores, times, events, - num_competing_risks, event_weights=None, - sample_weights=None, eps=1e-8) -> float: - """ - Computes the weighted negative log-likelihood loss for competing risks Cox model. - - Args: - risk_scores: List of tensors with shape (batch_size, 1) for each competing risk - times: Event/censoring times (batch_size,) - events: Event indicators (0=censored, 1...K=event types) (batch_size,) - num_competing_risks: Number of competing risks - event_weights: Tensor of weights for each competing risk type (size: num_competing_risks) - sample_weights: Tensor of weights for each sample (size: batch_size) - eps: Small constant for numerical stability - - Returns: - Weighted negative log partial likelihood loss - """ - device = times.device - batch_size = times.shape[0] - - # Initialize loss - loss = torch.tensor(0.0, device=device) - - # Set default weights if not provided - if event_weights is None: - event_weights = torch.ones(num_competing_risks, device=device) - if sample_weights is None: - sample_weights = torch.ones(batch_size, device=device) - - # Count number of events - n_events = (events > 0).sum().item() - if n_events == 0: - return loss - - # Process each competing risk separately - for k in range(1, num_competing_risks + 1): - # Find samples with this event type - event_mask = (events == k) - n_events_k = event_mask.sum().item() - - if n_events_k == 0: - continue - - # Get risk scores for this competing risk - risk_k = risk_scores[k-1].squeeze() - - # Get weight for this event type - event_weight = event_weights[k-1] - - # For each event of type k - for i in range(batch_size): - if event_mask[i]: - # Find samples in risk set (samples with time >= event time) - risk_set = (times >= times[i]) - - # Calculate log sum of exp of risk scores in risk set - risk_set_scores = risk_k[risk_set] - log_risk_sum = torch.logsumexp(risk_set_scores, dim=0) - - # Subtract individual risk score from log sum and apply weights - individual_loss = log_risk_sum - risk_k[i] - weighted_individual_loss = individual_loss * event_weight * sample_weights[i] - loss += weighted_individual_loss - - # Return average loss - return loss / max(n_events, 1) - -def negative_log_likelihood_loss(risk_scores, times, events, - num_competing_risks, eps=1e-8): - """ - Computes the negative log-likelihood loss for competing risks Cox model. - - Args: - risk_scores: List of tensors with shape (batch_size, 1) for each competing risk - times: Event/censoring times (batch_size,) - events: Event indicators (0=censored, 1...K=event types) (batch_size,) - num_competing_risks: Number of competing risks - eps: Small constant for numerical stability - - Returns: - Negative log partial likelihood loss - """ - device = times.device - batch_size = times.shape[0] - - # Initialize loss - loss = torch.tensor(0.0, device=device) - - # Count number of events - n_events = (events > 0).sum().item() - if n_events == 0: - return loss - - # Process each competing risk separately - for k in range(1, num_competing_risks + 1): - # Find samples with this event type - event_mask = (events == k) - n_events_k = event_mask.sum().item() - - if n_events_k == 0: - continue - - # Get risk scores for this competing risk - risk_k = risk_scores[k-1].squeeze() - - # For each event of type k - for i in range(batch_size): - if event_mask[i]: - # Find samples in risk set (samples with time >= event time) - risk_set = (times >= times[i]) - - # Calculate log sum of exp of risk scores in risk set - risk_set_scores = risk_k[risk_set] - log_risk_sum = torch.logsumexp(risk_set_scores, dim=0) - - # Subtract individual risk score from log sum - loss += log_risk_sum - risk_k[i] - - # Return average loss - return loss / max(n_events, 1) - -def compute_l2_penalty(model, include_bias=False) -> int: - """ - Compute L2 regularization penalty on model parameters - - Args: - model: Neural network model - include_bias: Whether to include bias terms in regularization - - Returns: - L2 penalty term - """ - l2_reg = 0.0 - for name, param in model.named_parameters(): - if param.requires_grad: - # Skip bias parameters if specified - if not include_bias and 'bias' in name: - continue - l2_reg += torch.sum(param ** 2) - return l2_reg \ No newline at end of file diff --git a/utils/plotting.py b/utils/plotting.py deleted file mode 100644 index 1702632e..00000000 --- a/utils/plotting.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -from typing import Union, List - -def plot_feature_importance(model: torch.nn.Module, - x_data: Union[np.ndarray, torch.Tensor], - feature_names = None, - n_top:int = 5, - n_bottom:int = 5, - risk_idx:int = 1, - figsize: tuple = (8, 6), - output_file:str = None, - color_positive:str = '#2196F3', - color_negative:str = '#F44336') -> tuple: - """ - Plot feature importance with both top positive and negative influences, - handling both CPU and CUDA devices automatically. - """ - # determine model device - device = next(model.parameters()).device - model.eval() - - # prepare feature names - num_features = model.num_features - if feature_names is None: - feature_names = [f"Feature {i+1}" for i in range(num_features)] - - # convert x_data to tensor on the model device - if not isinstance(x_data, torch.Tensor): - x = torch.tensor(x_data, dtype=torch.float32, device=device) - else: - x = x_data.to(device) - - feature_contribs = {} - risk_idx0 = risk_idx - 1 - - with torch.no_grad(): - for i in range(num_features): - vals = x[:, i].unsqueeze(1) # shape (N,1) - if torch.var(vals) <= 1e-8: - feature_contribs[feature_names[i]] = 0.0 - continue - - # forward through the feature net and projection - rep = model.feature_nets[i](vals) - proj = model.risk_projections[i][risk_idx0](rep) - # mean contribution as a Python float - contrib = proj.mean().item() - feature_contribs[feature_names[i]] = contrib - - # build a DataFrame for sorting - df = pd.DataFrame({ - 'feature': list(feature_contribs.keys()), - 'contribution': list(feature_contribs.values()) - }) - df['abs_contrib'] = df['contribution'].abs() - df = df.sort_values('abs_contrib', ascending=False) - - pos = df[df['contribution'] > 0].head(n_top).sort_values('contribution') - neg = df[df['contribution'] < 0].head(n_bottom).sort_values('contribution', ascending=False) - - top_pos = pos['feature'].tolist() - top_neg = neg['feature'].tolist() - - # plotting - fig, ax = plt.subplots(figsize=figsize) - ax.barh(pos['feature'], pos['contribution'], color=color_positive, alpha=0.8) - ax.barh(neg['feature'], neg['contribution'], color=color_negative, alpha=0.8) - ax.axvline(0, color='black', linestyle='-', alpha=0.3) - - ax.set_xlabel('Contribution to Risk Score') - ax.set_title(f'Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}') - ax.grid(axis='x', linestyle='--', alpha=0.5) - plt.tight_layout() - - if output_file: - plt.savefig(output_file, bbox_inches='tight', dpi=300) - - return fig, ax, top_pos, top_neg - - -def plot_coxnam_shape_functions(model:torch.nn.Module, - X: Union[np.ndarray, torch.Tensor], - risk_to_plot:int = 1, - feature_names:List[str] = None, - top_features:List[str] = None, - ncols:int = 3, - figsize:tuple = (12, 8), - output_file:str = None) -> tuple: - """ - Plot shape functions for each feature in a CoxNAM model, - automatically handling CPU vs CUDA inputs. - """ - device = next(model.parameters()).device - model.eval() - risk_idx = risk_to_plot - 1 - - # ensure X is a numpy array - if isinstance(X, torch.Tensor): - X_np = X.cpu().numpy() - else: - X_np = np.array(X, dtype=float) - - # derive feature list - num_features = model.num_features - if feature_names is None: - feature_names = [f"Feature {i+1}" for i in range(num_features)] - if top_features: - # map names back to indices - idx_map = {name: i for i, name in enumerate(feature_names)} - selected = [(idx_map.get(name, None), name) for name in top_features] - selected = [(i, name) for i, name in selected if i is not None] - else: - selected = list(zip(range(num_features), feature_names)) - - n_selected = len(selected) - nrows = int(np.ceil(n_selected / ncols)) - fig, axes = plt.subplots(nrows, ncols, figsize=(figsize)) - axes = np.array(axes).reshape(-1) - - with torch.no_grad(): - for ax, (f_idx, fname) in zip(axes, selected): - vals = X_np[:, f_idx] - if vals.size == 0: - ax.text(0.5, 0.5, "no data", ha='center', va='center') - continue - - # choose evaluation points - if np.issubdtype(vals.dtype, np.integer) or len(np.unique(vals)) <= 10: - pts = np.unique(vals) - else: - pts = np.linspace(vals.min(), vals.max(), 100) - - # convert to tensor on correct device - t_pts = torch.tensor(pts, dtype=torch.float32, device=device).unsqueeze(1) - - # compute shape values - rep = model.feature_nets[f_idx](t_pts) - proj = model.risk_projections[f_idx][risk_idx](rep) - shp = proj.squeeze(-1).cpu().numpy() - - # plot - ax.plot(pts, shp, linewidth=2) - ax.axhline(0, linestyle='--', alpha=0.5) - ax.set_title(fname) - ax.set_xlabel('Value') - ax.set_ylabel('Contribution') - # rug plot - ax.plot(vals, np.zeros_like(vals)-0.1, '|', alpha=0.3) - - # turn off any extra axes - for ax in axes[n_selected:]: - ax.axis('off') - - fig.suptitle(f'Shape Functions for Risk {risk_to_plot}', fontsize=14) - plt.tight_layout(rect=[0,0,1,0.96]) - if output_file: - plt.savefig(output_file, dpi=300, bbox_inches='tight') - - return fig, axes[:n_selected] diff --git a/utils/risk_cif.py b/utils/risk_cif.py deleted file mode 100644 index 0cd38d54..00000000 --- a/utils/risk_cif.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch -import numpy as np -from typing import List, Any - -def compute_baseline_cif(times:np.ndarray, - events:np.ndarray, - eval_times:List[Any], - event_type:np.ndarray) -> np.ndarray: - """ - Compute baseline cumulative incidence function for a specific event type - - Args: - times: Numpy array of event times - events: Numpy array of event indicators (0=censored, 1...K=event types) - eval_times: Times at which to evaluate the CIF - event_type: Event type to compute CIF for (1...K) - - Returns: - Numpy array of baseline CIF values at eval_times - """ - # Sort times and corresponding events - sort_idx = np.argsort(times) - sorted_times = times[sort_idx] - sorted_events = events[sort_idx] - - # Initialize cumulative hazard - n_samples = len(times) - baseline_cif = np.zeros(len(eval_times)) - - # For each evaluation time - for i, t in enumerate(eval_times): - cif_t = 0.0 - # Count number of events of the specified type before time t - event_count = np.sum((sorted_events == event_type) & (sorted_times <= t)) - if event_count > 0: - # Simple Aalen-Johansen estimator - cif_t = event_count / n_samples - baseline_cif[i] = cif_t - - return baseline_cif - - -def predict_cif(model:torch.Module, - x:np.ndarray, - baseline_cif:np.ndarray, - times:np.ndarray, - event_of_interest:int) -> np.ndarray: - """ - Predict cumulative incidence function for a specific competing risk. - - Args: - model: Trained model. - x: Input tensor of shape (n_samples, n_features). - baseline_cif: Array of shape (len(times),) — estimated CIF for baseline (e.g. from compute_baseline_cif). - times: Time points at which CIF is evaluated. - event_type: Integer, 0-based index of event of interest. - - Returns: - cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample. - """ - model.eval() - with torch.no_grad(): - logits, _ = model(x) # list of length num_risks - f_j_x = logits[event_of_interest].squeeze(1).cpu().numpy() # (n_samples,) - - baseline_cif = np.asarray(baseline_cif).reshape(1, -1) # (1, T) - risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1) - - # Fine-Gray style CIF prediction under PH assumption - cif_pred = 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T) - - return cif_pred - -def predict_risk(model:np.ndarray, - x_input:np.ndarray, - device:str = 'cpu'): - """ - Predicts relative risk scores for each competing risk. - - Args: - model : Trained model. - x_input (np.ndarray or torch.Tensor): Input features of shape (n_samples, n_features). - device (str): Device to run the computation on. - - Returns: - np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores. - """ - model.eval() - - if isinstance(x_input, np.ndarray): - x_tensor = torch.from_numpy(x_input).float().to(device) - else: - x_tensor = x_input.to(device).float() - - with torch.no_grad(): - risk_outputs, _ = model(x_tensor) # List of [batch_size, 1] tensors - risks = torch.cat(risk_outputs, dim=1) # Shape: [batch_size, num_risks] - - return risks.cpu().numpy() - -def predict_absolute_risk(model:torch.Tensor, - x_input:np.ndarray, - baseline_cifs:List[Any], - times:List[Any], - device:str = 'cpu') -> np.ndarray: - """ - Predict absolute risk (CIF) for each competing event by given time points. - - Args: - model: Trained model. - x_input (np.ndarray or Tensor): Input features, shape (n_samples, n_features). - baseline_cifs (dict): Mapping of event index to baseline CIF array of shape (n_times,). - times (np.ndarray): Time grid used for baseline_cifs. - device: CPU or CUDA. - - Returns: - np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs. - """ - rel_risks = predict_risk(model, x_input, device) # shape (n_samples, num_events) - n_samples, num_events = rel_risks.shape - n_times = len(times) - - abs_risks = np.zeros((n_samples, num_events, n_times)) - - for k in range(num_events): - base_cif = np.clip(baseline_cifs[k], 1e-10, 0.9999) # avoid edge cases - for i in range(n_samples): - abs_risks[i, k, :] = 1 - np.power(1 - base_cif, np.exp(rel_risks[i, k])) - - return abs_risks diff --git a/uv.lock b/uv.lock index 1927b861..a84729ed 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -45,6 +46,139 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/85/ae/7f2031ea76140444b2453fa139041e5afd4a09fc5300cfefeb1103291f80/autograd-gamma-0.5.0.tar.gz", hash = "sha256:f27abb7b8bb9cffc8badcbf59f3fe44a9db39e124ecacf1992b6d952934ac9c4", size = 3952 } +[[package]] +name = "babel" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/51899539b6ceeeb420d40ed3cd4b7a40519404f9baf3d4ac99dc413a834b/babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d", size = 9959554 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845 }, +] + +[[package]] +name = "backrefs" +version = "6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/86/e3/bb3a439d5cb255c4774724810ad8073830fac9c9dee123555820c1bcc806/backrefs-6.1.tar.gz", hash = "sha256:3bba1749aafe1db9b915f00e0dd166cba613b6f788ffd63060ac3485dc9be231", size = 7011962 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ee/c216d52f58ea75b5e1841022bbae24438b19834a29b163cb32aa3a2a7c6e/backrefs-6.1-py310-none-any.whl", hash = "sha256:2a2ccb96302337ce61ee4717ceacfbf26ba4efb1d55af86564b8bbaeda39cac1", size = 381059 }, + { url = "https://files.pythonhosted.org/packages/e6/9a/8da246d988ded941da96c7ed945d63e94a445637eaad985a0ed88787cb89/backrefs-6.1-py311-none-any.whl", hash = "sha256:e82bba3875ee4430f4de4b6db19429a27275d95a5f3773c57e9e18abc23fd2b7", size = 392854 }, + { url = "https://files.pythonhosted.org/packages/37/c9/fd117a6f9300c62bbc33bc337fd2b3c6bfe28b6e9701de336b52d7a797ad/backrefs-6.1-py312-none-any.whl", hash = "sha256:c64698c8d2269343d88947c0735cb4b78745bd3ba590e10313fbf3f78c34da5a", size = 398770 }, + { url = "https://files.pythonhosted.org/packages/eb/95/7118e935b0b0bd3f94dfec2d852fd4e4f4f9757bdb49850519acd245cd3a/backrefs-6.1-py313-none-any.whl", hash = "sha256:4c9d3dc1e2e558965202c012304f33d4e0e477e1c103663fd2c3cc9bb18b0d05", size = 400726 }, + { url = "https://files.pythonhosted.org/packages/1d/72/6296bad135bfafd3254ae3648cd152980a424bd6fed64a101af00cc7ba31/backrefs-6.1-py314-none-any.whl", hash = "sha256:13eafbc9ccd5222e9c1f0bec563e6d2a6d21514962f11e7fc79872fd56cbc853", size = 412584 }, + { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058 }, +] + +[[package]] +name = "certifi" +version = "2026.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900 }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/b8/6d51fc1d52cbd52cd4ccedd5b5b2f0f6a11bbf6765c782298b0f3e808541/charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d", size = 209709 }, + { url = "https://files.pythonhosted.org/packages/5c/af/1f9d7f7faafe2ddfb6f72a2e07a548a629c61ad510fe60f9630309908fef/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8", size = 148814 }, + { url = "https://files.pythonhosted.org/packages/79/3d/f2e3ac2bbc056ca0c204298ea4e3d9db9b4afe437812638759db2c976b5f/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad", size = 144467 }, + { url = "https://files.pythonhosted.org/packages/ec/85/1bf997003815e60d57de7bd972c57dc6950446a3e4ccac43bc3070721856/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8", size = 162280 }, + { url = "https://files.pythonhosted.org/packages/3e/8e/6aa1952f56b192f54921c436b87f2aaf7c7a7c3d0d1a765547d64fd83c13/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d", size = 159454 }, + { url = "https://files.pythonhosted.org/packages/36/3b/60cbd1f8e93aa25d1c669c649b7a655b0b5fb4c571858910ea9332678558/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313", size = 153609 }, + { url = "https://files.pythonhosted.org/packages/64/91/6a13396948b8fd3c4b4fd5bc74d045f5637d78c9675585e8e9fbe5636554/charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e", size = 151849 }, + { url = "https://files.pythonhosted.org/packages/b7/7a/59482e28b9981d105691e968c544cc0df3b7d6133152fb3dcdc8f135da7a/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93", size = 151586 }, + { url = "https://files.pythonhosted.org/packages/92/59/f64ef6a1c4bdd2baf892b04cd78792ed8684fbc48d4c2afe467d96b4df57/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0", size = 145290 }, + { url = "https://files.pythonhosted.org/packages/6b/63/3bf9f279ddfa641ffa1962b0db6a57a9c294361cc2f5fcac997049a00e9c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84", size = 163663 }, + { url = "https://files.pythonhosted.org/packages/ed/09/c9e38fc8fa9e0849b172b581fd9803bdf6e694041127933934184e19f8c3/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e", size = 151964 }, + { url = "https://files.pythonhosted.org/packages/d2/d1/d28b747e512d0da79d8b6a1ac18b7ab2ecfd81b2944c4c710e166d8dd09c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db", size = 161064 }, + { url = "https://files.pythonhosted.org/packages/bb/9a/31d62b611d901c3b9e5500c36aab0ff5eb442043fb3a1c254200d3d397d9/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6", size = 155015 }, + { url = "https://files.pythonhosted.org/packages/1f/f3/107e008fa2bff0c8b9319584174418e5e5285fef32f79d8ee6a430d0039c/charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f", size = 99792 }, + { url = "https://files.pythonhosted.org/packages/eb/66/e396e8a408843337d7315bab30dbf106c38966f1819f123257f5520f8a96/charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d", size = 107198 }, + { url = "https://files.pythonhosted.org/packages/b5/58/01b4f815bf0312704c267f2ccb6e5d42bcc7752340cd487bc9f8c3710597/charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69", size = 100262 }, + { url = "https://files.pythonhosted.org/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988 }, + { url = "https://files.pythonhosted.org/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324 }, + { url = "https://files.pythonhosted.org/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742 }, + { url = "https://files.pythonhosted.org/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863 }, + { url = "https://files.pythonhosted.org/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837 }, + { url = "https://files.pythonhosted.org/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550 }, + { url = "https://files.pythonhosted.org/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162 }, + { url = "https://files.pythonhosted.org/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019 }, + { url = "https://files.pythonhosted.org/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310 }, + { url = "https://files.pythonhosted.org/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022 }, + { url = "https://files.pythonhosted.org/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383 }, + { url = "https://files.pythonhosted.org/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098 }, + { url = "https://files.pythonhosted.org/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991 }, + { url = "https://files.pythonhosted.org/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456 }, + { url = "https://files.pythonhosted.org/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978 }, + { url = "https://files.pythonhosted.org/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969 }, + { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425 }, + { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162 }, + { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558 }, + { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497 }, + { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240 }, + { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471 }, + { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864 }, + { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647 }, + { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110 }, + { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839 }, + { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667 }, + { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535 }, + { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816 }, + { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694 }, + { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131 }, + { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390 }, + { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091 }, + { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936 }, + { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180 }, + { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346 }, + { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874 }, + { url = "https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076 }, + { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601 }, + { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376 }, + { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825 }, + { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583 }, + { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366 }, + { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300 }, + { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465 }, + { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404 }, + { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092 }, + { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408 }, + { url = "https://files.pythonhosted.org/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746 }, + { url = "https://files.pythonhosted.org/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889 }, + { url = "https://files.pythonhosted.org/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641 }, + { url = "https://files.pythonhosted.org/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779 }, + { url = "https://files.pythonhosted.org/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035 }, + { url = "https://files.pythonhosted.org/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542 }, + { url = "https://files.pythonhosted.org/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524 }, + { url = "https://files.pythonhosted.org/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395 }, + { url = "https://files.pythonhosted.org/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680 }, + { url = "https://files.pythonhosted.org/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045 }, + { url = "https://files.pythonhosted.org/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687 }, + { url = "https://files.pythonhosted.org/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014 }, + { url = "https://files.pythonhosted.org/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044 }, + { url = "https://files.pythonhosted.org/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940 }, + { url = "https://files.pythonhosted.org/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104 }, + { url = "https://files.pythonhosted.org/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743 }, + { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402 }, +] + +[[package]] +name = "click" +version = "8.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -151,6 +285,7 @@ dependencies = [ { name = "configargparse" }, { name = "lifelines" }, { name = "matplotlib" }, + { name = "mypy" }, { name = "openpyxl" }, { name = "optuna" }, { name = "pandas" }, @@ -161,11 +296,23 @@ dependencies = [ { name = "torch" }, ] +[package.dev-dependencies] +docs = [ + { name = "click" }, + { name = "mike" }, + { name = "mkdocs" }, + { name = "mkdocs-material" }, + { name = "mkdocstrings" }, + { name = "mkdocstrings-python" }, + { name = "pymdown-extensions" }, +] + [package.metadata] requires-dist = [ { name = "configargparse", specifier = ">=1.7" }, { name = "lifelines", specifier = ">=0.30.0" }, { name = "matplotlib", specifier = ">=3.10.1" }, + { name = "mypy", specifier = ">=1.19.1" }, { name = "openpyxl", specifier = ">=3.1.5" }, { name = "optuna", specifier = ">=4.3.0" }, { name = "pandas", specifier = ">=2.2.3" }, @@ -176,6 +323,17 @@ requires-dist = [ { name = "torch", specifier = ">=2.7.0" }, ] +[package.metadata.requires-dev] +docs = [ + { name = "click", specifier = "<=8.2.1" }, + { name = "mike", specifier = ">=2.0.0" }, + { name = "mkdocs", specifier = ">=1.5.3" }, + { name = "mkdocs-material", specifier = ">=9.5.12" }, + { name = "mkdocstrings", specifier = ">=0.24.1" }, + { name = "mkdocstrings-python", specifier = ">=1.8.0" }, + { name = "pymdown-extensions", specifier = ">=10.7.1" }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -293,6 +451,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bb/61/78c7b3851add1481b048b5fdc29067397a1784e2910592bc81bb3f608635/fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462", size = 199052 }, ] +[[package]] +name = "ghp-import" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/29/d40217cbe2f6b1359e00c6c307bb3fc876ba74068cbab3dde77f03ca0dc4/ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343", size = 10943 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 }, +] + [[package]] name = "greenlet" version = "3.2.3" @@ -344,6 +514,68 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/4f/aab73ecaa6b3086a4c89863d94cf26fa84cbff63f52ce9bc4342b3087a06/greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a", size = 301236 }, ] +[[package]] +name = "griffe" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "griffecli" }, + { name = "griffelib" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214 }, +] + +[[package]] +name = "griffecli" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "griffelib" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345 }, +] + +[[package]] +name = "griffelib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004 }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008 }, +] + +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865 }, +] + +[[package]] +name = "importlib-resources" +version = "6.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 }, +] + [[package]] name = "interface-meta" version = "1.3.0" @@ -461,6 +693,79 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 }, ] +[[package]] +name = "librt" +version = "0.7.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/24/5f3646ff414285e0f7708fa4e946b9bf538345a41d1c375c439467721a5e/librt-0.7.8.tar.gz", hash = "sha256:1a4ede613941d9c3470b0368be851df6bb78ab218635512d0370b27a277a0862", size = 148323 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/13/57b06758a13550c5f09563893b004f98e9537ee6ec67b7df85c3571c8832/librt-0.7.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b45306a1fc5f53c9330fbee134d8b3227fe5da2ab09813b892790400aa49352d", size = 56521 }, + { url = "https://files.pythonhosted.org/packages/c2/24/bbea34d1452a10612fb45ac8356f95351ba40c2517e429602160a49d1fd0/librt-0.7.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:864c4b7083eeee250ed55135d2127b260d7eb4b5e953a9e5df09c852e327961b", size = 58456 }, + { url = "https://files.pythonhosted.org/packages/04/72/a168808f92253ec3a810beb1eceebc465701197dbc7e865a1c9ceb3c22c7/librt-0.7.8-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6938cc2de153bc927ed8d71c7d2f2ae01b4e96359126c602721340eb7ce1a92d", size = 164392 }, + { url = "https://files.pythonhosted.org/packages/14/5c/4c0d406f1b02735c2e7af8ff1ff03a6577b1369b91aa934a9fa2cc42c7ce/librt-0.7.8-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66daa6ac5de4288a5bbfbe55b4caa7bf0cd26b3269c7a476ffe8ce45f837f87d", size = 172959 }, + { url = "https://files.pythonhosted.org/packages/82/5f/3e85351c523f73ad8d938989e9a58c7f59fb9c17f761b9981b43f0025ce7/librt-0.7.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4864045f49dc9c974dadb942ac56a74cd0479a2aafa51ce272c490a82322ea3c", size = 186717 }, + { url = "https://files.pythonhosted.org/packages/08/f8/18bfe092e402d00fe00d33aa1e01dda1bd583ca100b393b4373847eade6d/librt-0.7.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a36515b1328dc5b3ffce79fe204985ca8572525452eacabee2166f44bb387b2c", size = 184585 }, + { url = "https://files.pythonhosted.org/packages/4e/fc/f43972ff56fd790a9fa55028a52ccea1875100edbb856b705bd393b601e3/librt-0.7.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b7e7f140c5169798f90b80d6e607ed2ba5059784968a004107c88ad61fb3641d", size = 180497 }, + { url = "https://files.pythonhosted.org/packages/e1/3a/25e36030315a410d3ad0b7d0f19f5f188e88d1613d7d3fd8150523ea1093/librt-0.7.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ff71447cb778a4f772ddc4ce360e6ba9c95527ed84a52096bd1bbf9fee2ec7c0", size = 200052 }, + { url = "https://files.pythonhosted.org/packages/fc/b8/f3a5a1931ae2a6ad92bf6893b9ef44325b88641d58723529e2c2935e8abe/librt-0.7.8-cp310-cp310-win32.whl", hash = "sha256:047164e5f68b7a8ebdf9fae91a3c2161d3192418aadd61ddd3a86a56cbe3dc85", size = 43477 }, + { url = "https://files.pythonhosted.org/packages/fe/91/c4202779366bc19f871b4ad25db10fcfa1e313c7893feb942f32668e8597/librt-0.7.8-cp310-cp310-win_amd64.whl", hash = "sha256:d6f254d096d84156a46a84861183c183d30734e52383602443292644d895047c", size = 49806 }, + { url = "https://files.pythonhosted.org/packages/1b/a3/87ea9c1049f2c781177496ebee29430e4631f439b8553a4969c88747d5d8/librt-0.7.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ff3e9c11aa260c31493d4b3197d1e28dd07768594a4f92bec4506849d736248f", size = 56507 }, + { url = "https://files.pythonhosted.org/packages/5e/4a/23bcef149f37f771ad30203d561fcfd45b02bc54947b91f7a9ac34815747/librt-0.7.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddb52499d0b3ed4aa88746aaf6f36a08314677d5c346234c3987ddc506404eac", size = 58455 }, + { url = "https://files.pythonhosted.org/packages/22/6e/46eb9b85c1b9761e0f42b6e6311e1cc544843ac897457062b9d5d0b21df4/librt-0.7.8-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e9c0afebbe6ce177ae8edba0c7c4d626f2a0fc12c33bb993d163817c41a7a05c", size = 164956 }, + { url = "https://files.pythonhosted.org/packages/7a/3f/aa7c7f6829fb83989feb7ba9aa11c662b34b4bd4bd5b262f2876ba3db58d/librt-0.7.8-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:631599598e2c76ded400c0a8722dec09217c89ff64dc54b060f598ed68e7d2a8", size = 174364 }, + { url = "https://files.pythonhosted.org/packages/3f/2d/d57d154b40b11f2cb851c4df0d4c4456bacd9b1ccc4ecb593ddec56c1a8b/librt-0.7.8-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c1ba843ae20db09b9d5c80475376168feb2640ce91cd9906414f23cc267a1ff", size = 188034 }, + { url = "https://files.pythonhosted.org/packages/59/f9/36c4dad00925c16cd69d744b87f7001792691857d3b79187e7a673e812fb/librt-0.7.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b5b007bb22ea4b255d3ee39dfd06d12534de2fcc3438567d9f48cdaf67ae1ae3", size = 186295 }, + { url = "https://files.pythonhosted.org/packages/23/9b/8a9889d3df5efb67695a67785028ccd58e661c3018237b73ad081691d0cb/librt-0.7.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:dbd79caaf77a3f590cbe32dc2447f718772d6eea59656a7dcb9311161b10fa75", size = 181470 }, + { url = "https://files.pythonhosted.org/packages/43/64/54d6ef11afca01fef8af78c230726a9394759f2addfbf7afc5e3cc032a45/librt-0.7.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:87808a8d1e0bd62a01cafc41f0fd6818b5a5d0ca0d8a55326a81643cdda8f873", size = 201713 }, + { url = "https://files.pythonhosted.org/packages/2d/29/73e7ed2991330b28919387656f54109139b49e19cd72902f466bd44415fd/librt-0.7.8-cp311-cp311-win32.whl", hash = "sha256:31724b93baa91512bd0a376e7cf0b59d8b631ee17923b1218a65456fa9bda2e7", size = 43803 }, + { url = "https://files.pythonhosted.org/packages/3f/de/66766ff48ed02b4d78deea30392ae200bcbd99ae61ba2418b49fd50a4831/librt-0.7.8-cp311-cp311-win_amd64.whl", hash = "sha256:978e8b5f13e52cf23a9e80f3286d7546baa70bc4ef35b51d97a709d0b28e537c", size = 50080 }, + { url = "https://files.pythonhosted.org/packages/6f/e3/33450438ff3a8c581d4ed7f798a70b07c3206d298cf0b87d3806e72e3ed8/librt-0.7.8-cp311-cp311-win_arm64.whl", hash = "sha256:20e3946863d872f7cabf7f77c6c9d370b8b3d74333d3a32471c50d3a86c0a232", size = 43383 }, + { url = "https://files.pythonhosted.org/packages/56/04/79d8fcb43cae376c7adbab7b2b9f65e48432c9eced62ac96703bcc16e09b/librt-0.7.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9b6943885b2d49c48d0cff23b16be830ba46b0152d98f62de49e735c6e655a63", size = 57472 }, + { url = "https://files.pythonhosted.org/packages/b4/ba/60b96e93043d3d659da91752689023a73981336446ae82078cddf706249e/librt-0.7.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46ef1f4b9b6cc364b11eea0ecc0897314447a66029ee1e55859acb3dd8757c93", size = 58986 }, + { url = "https://files.pythonhosted.org/packages/7c/26/5215e4cdcc26e7be7eee21955a7e13cbf1f6d7d7311461a6014544596fac/librt-0.7.8-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:907ad09cfab21e3c86e8f1f87858f7049d1097f77196959c033612f532b4e592", size = 168422 }, + { url = "https://files.pythonhosted.org/packages/0f/84/e8d1bc86fa0159bfc24f3d798d92cafd3897e84c7fea7fe61b3220915d76/librt-0.7.8-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2991b6c3775383752b3ca0204842743256f3ad3deeb1d0adc227d56b78a9a850", size = 177478 }, + { url = "https://files.pythonhosted.org/packages/57/11/d0268c4b94717a18aa91df1100e767b010f87b7ae444dafaa5a2d80f33a6/librt-0.7.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03679b9856932b8c8f674e87aa3c55ea11c9274301f76ae8dc4d281bda55cf62", size = 192439 }, + { url = "https://files.pythonhosted.org/packages/8d/56/1e8e833b95fe684f80f8894ae4d8b7d36acc9203e60478fcae599120a975/librt-0.7.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3968762fec1b2ad34ce57458b6de25dbb4142713e9ca6279a0d352fa4e9f452b", size = 191483 }, + { url = "https://files.pythonhosted.org/packages/17/48/f11cf28a2cb6c31f282009e2208312aa84a5ee2732859f7856ee306176d5/librt-0.7.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bb7a7807523a31f03061288cc4ffc065d684c39db7644c676b47d89553c0d714", size = 185376 }, + { url = "https://files.pythonhosted.org/packages/b8/6a/d7c116c6da561b9155b184354a60a3d5cdbf08fc7f3678d09c95679d13d9/librt-0.7.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad64a14b1e56e702e19b24aae108f18ad1bf7777f3af5fcd39f87d0c5a814449", size = 206234 }, + { url = "https://files.pythonhosted.org/packages/61/de/1975200bb0285fc921c5981d9978ce6ce11ae6d797df815add94a5a848a3/librt-0.7.8-cp312-cp312-win32.whl", hash = "sha256:0241a6ed65e6666236ea78203a73d800dbed896cf12ae25d026d75dc1fcd1dac", size = 44057 }, + { url = "https://files.pythonhosted.org/packages/8e/cd/724f2d0b3461426730d4877754b65d39f06a41ac9d0a92d5c6840f72b9ae/librt-0.7.8-cp312-cp312-win_amd64.whl", hash = "sha256:6db5faf064b5bab9675c32a873436b31e01d66ca6984c6f7f92621656033a708", size = 50293 }, + { url = "https://files.pythonhosted.org/packages/bd/cf/7e899acd9ee5727ad8160fdcc9994954e79fab371c66535c60e13b968ffc/librt-0.7.8-cp312-cp312-win_arm64.whl", hash = "sha256:57175aa93f804d2c08d2edb7213e09276bd49097611aefc37e3fa38d1fb99ad0", size = 43574 }, + { url = "https://files.pythonhosted.org/packages/a1/fe/b1f9de2829cf7fc7649c1dcd202cfd873837c5cc2fc9e526b0e7f716c3d2/librt-0.7.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4c3995abbbb60b3c129490fa985dfe6cac11d88fc3c36eeb4fb1449efbbb04fc", size = 57500 }, + { url = "https://files.pythonhosted.org/packages/eb/d4/4a60fbe2e53b825f5d9a77325071d61cd8af8506255067bf0c8527530745/librt-0.7.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:44e0c2cbc9bebd074cf2cdbe472ca185e824be4e74b1c63a8e934cea674bebf2", size = 59019 }, + { url = "https://files.pythonhosted.org/packages/6a/37/61ff80341ba5159afa524445f2d984c30e2821f31f7c73cf166dcafa5564/librt-0.7.8-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:4d2f1e492cae964b3463a03dc77a7fe8742f7855d7258c7643f0ee32b6651dd3", size = 169015 }, + { url = "https://files.pythonhosted.org/packages/1c/86/13d4f2d6a93f181ebf2fc953868826653ede494559da8268023fe567fca3/librt-0.7.8-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:451e7ffcef8f785831fdb791bd69211f47e95dc4c6ddff68e589058806f044c6", size = 178161 }, + { url = "https://files.pythonhosted.org/packages/88/26/e24ef01305954fc4d771f1f09f3dd682f9eb610e1bec188ffb719374d26e/librt-0.7.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3469e1af9f1380e093ae06bedcbdd11e407ac0b303a56bbe9afb1d6824d4982d", size = 193015 }, + { url = "https://files.pythonhosted.org/packages/88/a0/92b6bd060e720d7a31ed474d046a69bd55334ec05e9c446d228c4b806ae3/librt-0.7.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f11b300027ce19a34f6d24ebb0a25fd0e24a9d53353225a5c1e6cadbf2916b2e", size = 192038 }, + { url = "https://files.pythonhosted.org/packages/06/bb/6f4c650253704279c3a214dad188101d1b5ea23be0606628bc6739456624/librt-0.7.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4adc73614f0d3c97874f02f2c7fd2a27854e7e24ad532ea6b965459c5b757eca", size = 186006 }, + { url = "https://files.pythonhosted.org/packages/dc/00/1c409618248d43240cadf45f3efb866837fa77e9a12a71481912135eb481/librt-0.7.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60c299e555f87e4c01b2eca085dfccda1dde87f5a604bb45c2906b8305819a93", size = 206888 }, + { url = "https://files.pythonhosted.org/packages/d9/83/b2cfe8e76ff5c1c77f8a53da3d5de62d04b5ebf7cf913e37f8bca43b5d07/librt-0.7.8-cp313-cp313-win32.whl", hash = "sha256:b09c52ed43a461994716082ee7d87618096851319bf695d57ec123f2ab708951", size = 44126 }, + { url = "https://files.pythonhosted.org/packages/a9/0b/c59d45de56a51bd2d3a401fc63449c0ac163e4ef7f523ea8b0c0dee86ec5/librt-0.7.8-cp313-cp313-win_amd64.whl", hash = "sha256:f8f4a901a3fa28969d6e4519deceab56c55a09d691ea7b12ca830e2fa3461e34", size = 50262 }, + { url = "https://files.pythonhosted.org/packages/fc/b9/973455cec0a1ec592395250c474164c4a58ebf3e0651ee920fef1a2623f1/librt-0.7.8-cp313-cp313-win_arm64.whl", hash = "sha256:43d4e71b50763fcdcf64725ac680d8cfa1706c928b844794a7aa0fa9ac8e5f09", size = 43600 }, + { url = "https://files.pythonhosted.org/packages/1a/73/fa8814c6ce2d49c3827829cadaa1589b0bf4391660bd4510899393a23ebc/librt-0.7.8-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:be927c3c94c74b05128089a955fba86501c3b544d1d300282cc1b4bd370cb418", size = 57049 }, + { url = "https://files.pythonhosted.org/packages/53/fe/f6c70956da23ea235fd2e3cc16f4f0b4ebdfd72252b02d1164dd58b4e6c3/librt-0.7.8-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7b0803e9008c62a7ef79058233db7ff6f37a9933b8f2573c05b07ddafa226611", size = 58689 }, + { url = "https://files.pythonhosted.org/packages/1f/4d/7a2481444ac5fba63050d9abe823e6bc16896f575bfc9c1e5068d516cdce/librt-0.7.8-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:79feb4d00b2a4e0e05c9c56df707934f41fcb5fe53fd9efb7549068d0495b758", size = 166808 }, + { url = "https://files.pythonhosted.org/packages/ac/3c/10901d9e18639f8953f57c8986796cfbf4c1c514844a41c9197cf87cb707/librt-0.7.8-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9122094e3f24aa759c38f46bd8863433820654927370250f460ae75488b66ea", size = 175614 }, + { url = "https://files.pythonhosted.org/packages/db/01/5cbdde0951a5090a80e5ba44e6357d375048123c572a23eecfb9326993a7/librt-0.7.8-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7e03bea66af33c95ce3addf87a9bf1fcad8d33e757bc479957ddbc0e4f7207ac", size = 189955 }, + { url = "https://files.pythonhosted.org/packages/6a/b4/e80528d2f4b7eaf1d437fcbd6fc6ba4cbeb3e2a0cb9ed5a79f47c7318706/librt-0.7.8-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f1ade7f31675db00b514b98f9ab9a7698c7282dad4be7492589109471852d398", size = 189370 }, + { url = "https://files.pythonhosted.org/packages/c1/ab/938368f8ce31a9787ecd4becb1e795954782e4312095daf8fd22420227c8/librt-0.7.8-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a14229ac62adcf1b90a15992f1ab9c69ae8b99ffb23cb64a90878a6e8a2f5b81", size = 183224 }, + { url = "https://files.pythonhosted.org/packages/3c/10/559c310e7a6e4014ac44867d359ef8238465fb499e7eb31b6bfe3e3f86f5/librt-0.7.8-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5bcaaf624fd24e6a0cb14beac37677f90793a96864c67c064a91458611446e83", size = 203541 }, + { url = "https://files.pythonhosted.org/packages/f8/db/a0db7acdb6290c215f343835c6efda5b491bb05c3ddc675af558f50fdba3/librt-0.7.8-cp314-cp314-win32.whl", hash = "sha256:7aa7d5457b6c542ecaed79cec4ad98534373c9757383973e638ccced0f11f46d", size = 40657 }, + { url = "https://files.pythonhosted.org/packages/72/e0/4f9bdc2a98a798511e81edcd6b54fe82767a715e05d1921115ac70717f6f/librt-0.7.8-cp314-cp314-win_amd64.whl", hash = "sha256:3d1322800771bee4a91f3b4bd4e49abc7d35e65166821086e5afd1e6c0d9be44", size = 46835 }, + { url = "https://files.pythonhosted.org/packages/f9/3d/59c6402e3dec2719655a41ad027a7371f8e2334aa794ed11533ad5f34969/librt-0.7.8-cp314-cp314-win_arm64.whl", hash = "sha256:5363427bc6a8c3b1719f8f3845ea53553d301382928a86e8fab7984426949bce", size = 39885 }, + { url = "https://files.pythonhosted.org/packages/4e/9c/2481d80950b83085fb14ba3c595db56330d21bbc7d88a19f20165f3538db/librt-0.7.8-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:ca916919793a77e4a98d4a1701e345d337ce53be4a16620f063191f7322ac80f", size = 59161 }, + { url = "https://files.pythonhosted.org/packages/96/79/108df2cfc4e672336765d54e3ff887294c1cc36ea4335c73588875775527/librt-0.7.8-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:54feb7b4f2f6706bb82325e836a01be805770443e2400f706e824e91f6441dde", size = 61008 }, + { url = "https://files.pythonhosted.org/packages/46/f2/30179898f9994a5637459d6e169b6abdc982012c0a4b2d4c26f50c06f911/librt-0.7.8-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:39a4c76fee41007070f872b648cc2f711f9abf9a13d0c7162478043377b52c8e", size = 187199 }, + { url = "https://files.pythonhosted.org/packages/b4/da/f7563db55cebdc884f518ba3791ad033becc25ff68eb70902b1747dc0d70/librt-0.7.8-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac9c8a458245c7de80bc1b9765b177055efff5803f08e548dd4bb9ab9a8d789b", size = 198317 }, + { url = "https://files.pythonhosted.org/packages/b3/6c/4289acf076ad371471fa86718c30ae353e690d3de6167f7db36f429272f1/librt-0.7.8-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95b67aa7eff150f075fda09d11f6bfb26edffd300f6ab1666759547581e8f666", size = 210334 }, + { url = "https://files.pythonhosted.org/packages/4a/7f/377521ac25b78ac0a5ff44127a0360ee6d5ddd3ce7327949876a30533daa/librt-0.7.8-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:535929b6eff670c593c34ff435d5440c3096f20fa72d63444608a5aef64dd581", size = 211031 }, + { url = "https://files.pythonhosted.org/packages/c5/b1/e1e96c3e20b23d00cf90f4aad48f0deb4cdfec2f0ed8380d0d85acf98bbf/librt-0.7.8-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:63937bd0f4d1cb56653dc7ae900d6c52c41f0015e25aaf9902481ee79943b33a", size = 204581 }, + { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731 }, + { url = "https://files.pythonhosted.org/packages/22/f0/07fb6ab5c39a4ca9af3e37554f9d42f25c464829254d72e4ebbd81da351c/librt-0.7.8-cp314-cp314t-win32.whl", hash = "sha256:171ca3a0a06c643bd0a2f62a8944e1902c94aa8e5da4db1ea9a8daf872685365", size = 41173 }, + { url = "https://files.pythonhosted.org/packages/24/d4/7e4be20993dc6a782639625bd2f97f3c66125c7aa80c82426956811cfccf/librt-0.7.8-cp314-cp314t-win_amd64.whl", hash = "sha256:445b7304145e24c60288a2f172b5ce2ca35c0f81605f5299f3fa567e189d2e32", size = 47668 }, + { url = "https://files.pythonhosted.org/packages/fc/85/69f92b2a7b3c0f88ffe107c86b952b397004b5b8ea5a81da3d9c04c04422/librt-0.7.8-cp314-cp314t-win_arm64.whl", hash = "sha256:8766ece9de08527deabcd7cb1b4f1a967a385d26e33e536d6d8913db6ef74f06", size = 40550 }, +] + [[package]] name = "lifelines" version = "0.30.0" @@ -492,6 +797,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509 }, ] +[[package]] +name = "markdown" +version = "3.10.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180 }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -603,6 +917,149 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/b9/59e120d24a2ec5fc2d30646adb2efb4621aab3c6d83d66fb2a7a182db032/matplotlib-3.10.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb73d8aa75a237457988f9765e4dfe1c0d2453c5ca4eabc897d4309672c8e014", size = 8594298 }, ] +[[package]] +name = "mergedeep" +version = "1.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/41/580bb4006e3ed0361b8151a01d324fb03f420815446c7def45d02f74c270/mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8", size = 4661 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354 }, +] + +[[package]] +name = "mike" +version = "2.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "importlib-resources" }, + { name = "jinja2" }, + { name = "mkdocs" }, + { name = "pyparsing" }, + { name = "pyyaml" }, + { name = "pyyaml-env-tag" }, + { name = "verspec" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/f7/2933f1a1fb0e0f077d5d6a92c6c7f8a54e6128241f116dff4df8b6050bbf/mike-2.1.3.tar.gz", hash = "sha256:abd79b8ea483fb0275b7972825d3082e5ae67a41820f8d8a0dc7a3f49944e810", size = 38119 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/1a/31b7cd6e4e7a02df4e076162e9783620777592bea9e4bb036389389af99d/mike-2.1.3-py3-none-any.whl", hash = "sha256:d90c64077e84f06272437b464735130d380703a76a5738b152932884c60c062a", size = 33754 }, +] + +[[package]] +name = "mkdocs" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "ghp-import" }, + { name = "jinja2" }, + { name = "markdown" }, + { name = "markupsafe" }, + { name = "mergedeep" }, + { name = "mkdocs-get-deps" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "pyyaml" }, + { name = "pyyaml-env-tag" }, + { name = "watchdog" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bc/c6/bbd4f061bd16b378247f12953ffcb04786a618ce5e904b8c5a01a0309061/mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2", size = 3889159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/5b/dbc6a8cddc9cfa9c4971d59fb12bb8d42e161b7e7f8cc89e49137c5b279c/mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e", size = 3864451 }, +] + +[[package]] +name = "mkdocs-autorefs" +version = "1.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown" }, + { name = "markupsafe" }, + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/c0/f641843de3f612a6b48253f39244165acff36657a91cc903633d456ae1ac/mkdocs_autorefs-1.4.4.tar.gz", hash = "sha256:d54a284f27a7346b9c38f1f852177940c222da508e66edc816a0fa55fc6da197", size = 56588 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/de/a3e710469772c6a89595fc52816da05c1e164b4c866a89e3cb82fb1b67c5/mkdocs_autorefs-1.4.4-py3-none-any.whl", hash = "sha256:834ef5408d827071ad1bc69e0f39704fa34c7fc05bc8e1c72b227dfdc5c76089", size = 25530 }, +] + +[[package]] +name = "mkdocs-get-deps" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mergedeep" }, + { name = "platformdirs" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/f5/ed29cd50067784976f25ed0ed6fcd3c2ce9eb90650aa3b2796ddf7b6870b/mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c", size = 10239 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/d4/029f984e8d3f3b6b726bd33cafc473b75e9e44c0f7e80a5b29abc466bdea/mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134", size = 9521 }, +] + +[[package]] +name = "mkdocs-material" +version = "9.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "babel" }, + { name = "backrefs" }, + { name = "colorama" }, + { name = "jinja2" }, + { name = "markdown" }, + { name = "mkdocs" }, + { name = "mkdocs-material-extensions" }, + { name = "paginate" }, + { name = "pygments" }, + { name = "pymdown-extensions" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166 }, +] + +[[package]] +name = "mkdocs-material-extensions" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/9b/9b4c96d6593b2a541e1cb8b34899a6d021d208bb357042823d4d2cabdbe7/mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443", size = 11847 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728 }, +] + +[[package]] +name = "mkdocstrings" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "markdown" }, + { name = "markupsafe" }, + { name = "mkdocs" }, + { name = "mkdocs-autorefs" }, + { name = "pymdown-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/46/62/0dfc5719514115bf1781f44b1d7f2a0923fcc01e9c5d7990e48a05c9ae5d/mkdocstrings-1.0.3.tar.gz", hash = "sha256:ab670f55040722b49bb45865b2e93b824450fb4aef638b00d7acb493a9020434", size = 100946 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/41/1cf02e3df279d2dd846a1bf235a928254eba9006dd22b4a14caa71aed0f7/mkdocstrings-1.0.3-py3-none-any.whl", hash = "sha256:0d66d18430c2201dc7fe85134277382baaa15e6b30979f3f3bdbabd6dbdb6046", size = 35523 }, +] + +[[package]] +name = "mkdocstrings-python" +version = "2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "griffe" }, + { name = "mkdocs-autorefs" }, + { name = "mkdocstrings" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/84/78243847ad9d5c21d30a2842720425b17e880d99dfe824dee11d6b2149b4/mkdocstrings_python-2.0.2.tar.gz", hash = "sha256:4a32ccfc4b8d29639864698e81cfeb04137bce76bb9f3c251040f55d4b6e1ad8", size = 199124 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/31/7ee938abbde2322e553a2cb5f604cdd1e4728e08bba39c7ee6fae9af840b/mkdocstrings_python-2.0.2-py3-none-any.whl", hash = "sha256:31241c0f43d85a69306d704d5725786015510ea3f3c4bdfdb5a5731d83cdc2b0", size = 104900 }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -612,6 +1069,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, ] +[[package]] +name = "mypy" +version = "1.19.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, + { name = "mypy-extensions" }, + { name = "pathspec" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/63/e499890d8e39b1ff2df4c0c6ce5d371b6844ee22b8250687a99fd2f657a8/mypy-1.19.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f05aa3d375b385734388e844bc01733bd33c644ab48e9684faa54e5389775ec", size = 13101333 }, + { url = "https://files.pythonhosted.org/packages/72/4b/095626fc136fba96effc4fd4a82b41d688ab92124f8c4f7564bffe5cf1b0/mypy-1.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:022ea7279374af1a5d78dfcab853fe6a536eebfda4b59deab53cd21f6cd9f00b", size = 12164102 }, + { url = "https://files.pythonhosted.org/packages/0c/5b/952928dd081bf88a83a5ccd49aaecfcd18fd0d2710c7ff07b8fb6f7032b9/mypy-1.19.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee4c11e460685c3e0c64a4c5de82ae143622410950d6be863303a1c4ba0e36d6", size = 12765799 }, + { url = "https://files.pythonhosted.org/packages/2a/0d/93c2e4a287f74ef11a66fb6d49c7a9f05e47b0a4399040e6719b57f500d2/mypy-1.19.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de759aafbae8763283b2ee5869c7255391fbc4de3ff171f8f030b5ec48381b74", size = 13522149 }, + { url = "https://files.pythonhosted.org/packages/7b/0e/33a294b56aaad2b338d203e3a1d8b453637ac36cb278b45005e0901cf148/mypy-1.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ab43590f9cd5108f41aacf9fca31841142c786827a74ab7cc8a2eacb634e09a1", size = 13810105 }, + { url = "https://files.pythonhosted.org/packages/0e/fd/3e82603a0cb66b67c5e7abababce6bf1a929ddf67bf445e652684af5c5a0/mypy-1.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:2899753e2f61e571b3971747e302d5f420c3fd09650e1951e99f823bc3089dac", size = 10057200 }, + { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539 }, + { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163 }, + { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629 }, + { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933 }, + { url = "https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754 }, + { url = "https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772 }, + { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053 }, + { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134 }, + { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616 }, + { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847 }, + { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976 }, + { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104 }, + { url = "https://files.pythonhosted.org/packages/de/9f/a6abae693f7a0c697dbb435aac52e958dc8da44e92e08ba88d2e42326176/mypy-1.19.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3157c7594ff2ef1634ee058aafc56a82db665c9438fd41b390f3bde1ab12250", size = 13201927 }, + { url = "https://files.pythonhosted.org/packages/9a/a4/45c35ccf6e1c65afc23a069f50e2c66f46bd3798cbe0d680c12d12935caa/mypy-1.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdb12f69bcc02700c2b47e070238f42cb87f18c0bc1fc4cdb4fb2bc5fd7a3b8b", size = 12206730 }, + { url = "https://files.pythonhosted.org/packages/05/bb/cdcf89678e26b187650512620eec8368fded4cfd99cfcb431e4cdfd19dec/mypy-1.19.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f859fb09d9583a985be9a493d5cfc5515b56b08f7447759a0c5deaf68d80506e", size = 12724581 }, + { url = "https://files.pythonhosted.org/packages/d1/32/dd260d52babf67bad8e6770f8e1102021877ce0edea106e72df5626bb0ec/mypy-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9a6538e0415310aad77cb94004ca6482330fece18036b5f360b62c45814c4ef", size = 13616252 }, + { url = "https://files.pythonhosted.org/packages/71/d0/5e60a9d2e3bd48432ae2b454b7ef2b62a960ab51292b1eda2a95edd78198/mypy-1.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:da4869fc5e7f62a88f3fe0b5c919d1d9f7ea3cef92d3689de2823fd27e40aa75", size = 13840848 }, + { url = "https://files.pythonhosted.org/packages/98/76/d32051fa65ecf6cc8c6610956473abdc9b4c43301107476ac03559507843/mypy-1.19.1-cp313-cp313-win_amd64.whl", hash = "sha256:016f2246209095e8eda7538944daa1d60e1e8134d98983b9fc1e92c1fc0cb8dd", size = 10135510 }, + { url = "https://files.pythonhosted.org/packages/de/eb/b83e75f4c820c4247a58580ef86fcd35165028f191e7e1ba57128c52782d/mypy-1.19.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:06e6170bd5836770e8104c8fdd58e5e725cfeb309f0a6c681a811f557e97eac1", size = 13199744 }, + { url = "https://files.pythonhosted.org/packages/94/28/52785ab7bfa165f87fcbb61547a93f98bb20e7f82f90f165a1f69bce7b3d/mypy-1.19.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:804bd67b8054a85447c8954215a906d6eff9cabeabe493fb6334b24f4bfff718", size = 12215815 }, + { url = "https://files.pythonhosted.org/packages/0a/c6/bdd60774a0dbfb05122e3e925f2e9e846c009e479dcec4821dad881f5b52/mypy-1.19.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:21761006a7f497cb0d4de3d8ef4ca70532256688b0523eee02baf9eec895e27b", size = 12740047 }, + { url = "https://files.pythonhosted.org/packages/32/2a/66ba933fe6c76bd40d1fe916a83f04fed253152f451a877520b3c4a5e41e/mypy-1.19.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28902ee51f12e0f19e1e16fbe2f8f06b6637f482c459dd393efddd0ec7f82045", size = 13601998 }, + { url = "https://files.pythonhosted.org/packages/e3/da/5055c63e377c5c2418760411fd6a63ee2b96cf95397259038756c042574f/mypy-1.19.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:481daf36a4c443332e2ae9c137dfee878fcea781a2e3f895d54bd3002a900957", size = 13807476 }, + { url = "https://files.pythonhosted.org/packages/cd/09/4ebd873390a063176f06b0dbf1f7783dd87bd120eae7727fa4ae4179b685/mypy-1.19.1-cp314-cp314-win_amd64.whl", hash = "sha256:8bb5c6f6d043655e055be9b542aa5f3bdd30e4f3589163e85f93f3640060509f", size = 10281872 }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239 }, +] + +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963 }, +] + [[package]] name = "networkx" version = "3.4.2" @@ -1015,6 +1527,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 }, ] +[[package]] +name = "paginate" +version = "0.5.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/46/68dde5b6bc00c1296ec6466ab27dddede6aec9af1b99090e1107091b3b84/paginate-0.5.7.tar.gz", hash = "sha256:22bd083ab41e1a8b4f3690544afb2c60c25e5c9a63a30fa2f483f6c60c8e5945", size = 19252 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746 }, +] + [[package]] name = "pandas" version = "2.3.0" @@ -1064,6 +1585,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/c2/646d2e93e0af70f4e5359d870a63584dacbc324b54d73e6b3267920ff117/pandas-2.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bb3be958022198531eb7ec2008cfc78c5b1eed51af8600c6c5d9160d89d8d249", size = 13231847 }, ] +[[package]] +name = "pathspec" +version = "1.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206 }, +] + [[package]] name = "pillow" version = "11.2.1" @@ -1141,6 +1671,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/21/2c/5e05f58658cf49b6667762cca03d6e7d85cededde2caf2ab37b81f80e574/pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044", size = 2674751 }, ] +[[package]] +name = "platformdirs" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/86/0248f086a84f01b37aaec0fa567b397df1a119f73c16f6c7a9aac73ea309/platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda", size = 21715 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731 }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217 }, +] + +[[package]] +name = "pymdown-extensions" +version = "10.20.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1e/6c/9e370934bfa30e889d12e61d0dae009991294f40055c238980066a7fbd83/pymdown_extensions-10.20.1.tar.gz", hash = "sha256:e7e39c865727338d434b55f1dd8da51febcffcaebd6e1a0b9c836243f660740a", size = 852860 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/6d/b6ee155462a0156b94312bdd82d2b92ea56e909740045a87ccb98bf52405/pymdown_extensions-10.20.1-py3-none-any.whl", hash = "sha256:24af7feacbca56504b313b7b418c4f5e1317bb5fea60f03d57be7fcc40912aa0", size = 268768 }, +] + [[package]] name = "pyparsing" version = "3.2.3" @@ -1215,6 +1776,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, ] +[[package]] +name = "pyyaml-env-tag" +version = "1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/2e/79c822141bfd05a853236b504869ebc6b70159afc570e1d5a20641782eaa/pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff", size = 5737 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722 }, +] + [[package]] name = "qdldl" version = "0.1.7.post5" @@ -1248,6 +1821,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/f7/abac03a09f6848cee6d5dd7a7a8bd1dfed68766ee77f9cbf3e9de596ad68/qdldl-0.1.7.post5-cp313-cp313-win_amd64.whl", hash = "sha256:cc9be378e7bec67d4c62b7fa27cafb4f77d3e5e059d753c3dce0a5ae1ef5fea0", size = 90735 }, ] +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738 }, +] + [[package]] name = "scikit-learn" version = "1.6.1" @@ -1608,6 +2196,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 }, ] +[[package]] +name = "urllib3" +version = "2.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584 }, +] + +[[package]] +name = "verspec" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/44/8126f9f0c44319b2efc65feaad589cadef4d77ece200ae3c9133d58464d0/verspec-0.1.0.tar.gz", hash = "sha256:c4504ca697b2056cdb4bfa7121461f5a0e81809255b41c03dda4ba823637c01e", size = 27123 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/ce/3b6fee91c85626eaf769d617f1be9d2e15c1cca027bbdeb2e0d751469355/verspec-0.1.0-py3-none-any.whl", hash = "sha256:741877d5633cc9464c45a469ae2a31e801e6dbbaa85b9675d481cda100f11c31", size = 19640 }, +] + +[[package]] +name = "watchdog" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/56/90994d789c61df619bfc5ce2ecdabd5eeff564e1eb47512bd01b5e019569/watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26", size = 96390 }, + { url = "https://files.pythonhosted.org/packages/55/46/9a67ee697342ddf3c6daa97e3a587a56d6c4052f881ed926a849fcf7371c/watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112", size = 88389 }, + { url = "https://files.pythonhosted.org/packages/44/65/91b0985747c52064d8701e1075eb96f8c40a79df889e59a399453adfb882/watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3", size = 89020 }, + { url = "https://files.pythonhosted.org/packages/e0/24/d9be5cd6642a6aa68352ded4b4b10fb0d7889cb7f45814fb92cecd35f101/watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c", size = 96393 }, + { url = "https://files.pythonhosted.org/packages/63/7a/6013b0d8dbc56adca7fdd4f0beed381c59f6752341b12fa0886fa7afc78b/watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2", size = 88392 }, + { url = "https://files.pythonhosted.org/packages/d1/40/b75381494851556de56281e053700e46bff5b37bf4c7267e858640af5a7f/watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c", size = 89019 }, + { url = "https://files.pythonhosted.org/packages/39/ea/3930d07dafc9e286ed356a679aa02d777c06e9bfd1164fa7c19c288a5483/watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948", size = 96471 }, + { url = "https://files.pythonhosted.org/packages/12/87/48361531f70b1f87928b045df868a9fd4e253d9ae087fa4cf3f7113be363/watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860", size = 88449 }, + { url = "https://files.pythonhosted.org/packages/5b/7e/8f322f5e600812e6f9a31b75d242631068ca8f4ef0582dd3ae6e72daecc8/watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0", size = 89054 }, + { url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480 }, + { url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451 }, + { url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057 }, + { url = "https://files.pythonhosted.org/packages/30/ad/d17b5d42e28a8b91f8ed01cb949da092827afb9995d4559fd448d0472763/watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881", size = 87902 }, + { url = "https://files.pythonhosted.org/packages/5c/ca/c3649991d140ff6ab67bfc85ab42b165ead119c9e12211e08089d763ece5/watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11", size = 88380 }, + { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079 }, + { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078 }, + { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076 }, + { url = "https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077 }, + { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078 }, + { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077 }, + { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078 }, + { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065 }, + { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070 }, + { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067 }, +] + [[package]] name = "wrapt" version = "1.17.2" @@ -1671,3 +2309,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/5e/1655cf481e079c1f22d0cabdd4e51733679932718dc23bf2db175f329b76/wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c", size = 40750 }, { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594 }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276 }, +]