diff --git a/.github/workflows/blog.yml b/.github/workflows/blog.yml
deleted file mode 100644
index b797520a..00000000
--- a/.github/workflows/blog.yml
+++ /dev/null
@@ -1,70 +0,0 @@
-name: Build and Deploy Blog
-
-on:
- push:
- branches: [ "main" ]
- pull_request:
- branches: [ "main" ]
-
-permissions:
- contents: read
- pages: write
- id-token: write
-
-jobs:
- build:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v6
-
- - name: Setup Node.js
- uses: actions/setup-node@v6
- with:
- node-version: 'latest'
- cache: 'npm'
- cache-dependency-path: blog/package-lock.json
-
- - name: Install dependencies
- run: cd blog && npm ci
-
- - name: Build
- run: cd ./blog && npm run build
-
- - name: Upload build artifact
- if: github.event_name == 'push' && github.ref == 'refs/heads/main'
- uses: actions/upload-artifact@v6
- with:
- name: blog-build
- path: blog/public
- retention-days: 1
-
- deploy:
- needs: build
- if: github.event_name == 'push' && github.ref == 'refs/heads/main'
- runs-on: ubuntu-latest
- environment:
- name: github-pages
- url: ${{ steps.deployment.outputs.page_url }}
- concurrency:
- group: pages
- cancel-in-progress: false
-
- steps:
- - name: Download build artifact
- uses: actions/download-artifact@v7
- with:
- name: blog-build
- path: blog/public
-
- - name: Setup Pages
- uses: actions/configure-pages@v5
-
- - name: Upload Pages artifact
- uses: actions/upload-pages-artifact@v4
- with:
- path: blog/public
- retention-days: 1
-
- - name: Deploy to GitHub Pages
- id: deployment
- uses: actions/deploy-pages@v4
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 00000000..a558b951
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,138 @@
+name: docs
+permissions:
+ contents: write
+ pages: write
+ pull-requests: write
+
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - .pre-commit-config.yaml
+ - .github/workflows/docs.yml
+ - '**.py'
+ - '**.ipynb'
+ - '**.html'
+ - '**.js'
+ - '**.md'
+ - uv.lock
+ - pyproject.toml
+ - mkdocs.yml
+ - '**.png'
+ - '**.svg'
+ pull_request:
+ branches:
+ - main
+ paths:
+ - .pre-commit-config.yaml
+ - .github/workflows/docs.yml
+ - '**.py'
+ - '**.ipynb'
+ - '**.js'
+ - '**.html'
+ - uv.lock
+ - pyproject.toml
+ - '**.md'
+ - mkdocs.yml
+ - '**.png'
+ - '**.svg'
+
+jobs:
+ build-docs:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6.0.1
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7.2.0
+ with:
+ version: "0.9.11"
+ enable-cache: true
+
+ - name: Set up Python
+ uses: actions/setup-python@v6.1.0
+ with:
+ python-version-file: ".python-version"
+
+ - name: Install the project
+ run: uv sync --all-extras --group docs
+
+ - name: Build docs
+ run: uv run mkdocs build
+
+ - name: Create .nojekyll file
+ run: touch site/.nojekyll
+
+ - name: Upload artifact
+ uses: actions/upload-artifact@v6
+ with:
+ name: docs-site
+ path: site/
+ retention-days: 1
+ build-blog:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: 'latest'
+ cache: 'npm'
+ cache-dependency-path: blog/package-lock.json
+
+ - name: Install dependencies
+ run: cd blog && npm ci
+
+ - name: Build
+ run: cd ./blog && npm run build
+
+ - name: Upload build artifact
+ uses: actions/upload-artifact@v6
+ with:
+ name: blog-build
+ path: blog/public
+ retention-days: 1
+ deploy:
+ needs:
+ - build-docs
+ - build-blog
+ if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ environment:
+ name: github-pages
+ url: ${{ steps.deployment.outputs.page_url }}
+ concurrency:
+ group: pages
+ cancel-in-progress: false
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v6.0.1
+
+ - name: Configure Git Credentials
+ run: |
+ git config user.name github-actions[bot]
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
+
+ - name: Download docs artifact
+ uses: actions/download-artifact@v7
+ with:
+ name: docs-site
+ path: site
+ - name: Ensure .nojekyll exists
+ run: touch site/.nojekyll
+ - name: Create a folder for blog
+ run: mkdir -p site/blog
+ - name: Download blog artifact
+ uses: actions/download-artifact@v7
+ with:
+ name: blog-build
+ path: site/blog
+ - name: Deploy to Github pages
+ id: deployment
+ uses: JamesIves/github-pages-deploy-action@v4.8.0
+ with:
+ branch: gh-pages
+ folder: site
\ No newline at end of file
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 00000000..3dc40099
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,57 @@
+name: publish package
+permissions:
+ contents: read
+ pull-requests: write
+
+on:
+ push:
+ tags:
+ - "v*"
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install libcurl4-openssl-dev libssl-dev
+ - uses: actions/checkout@v6.0.1
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7.2.0
+ with:
+ # Install a specific version of uv.
+ version: "0.9.11"
+ enable-cache: true
+
+ - name: "Set up Python"
+ uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548
+ with:
+ python-version-file: ".python-version"
+
+ - name: Install the project
+ run: uv sync --all-extras --dev
+
+ - name: Build package
+ run: uv build
+
+ - name: Publish package
+ uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e
+ with:
+ user: __token__
+ password: ${{ secrets.PYPI_API_TOKEN }}
+
+ release_github:
+ needs: deploy
+ runs-on: ubuntu-latest
+ permissions:
+ contents: write # To create a github release
+ steps:
+ - name: Create GitHub Release
+ id: create_release
+ uses: ncipollo/release-action@v1.21.0
+ with:
+ artifacts: "dist/*"
+ generateReleaseNotes: true
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index c28b9fe1..6c08e923 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,3 @@
*.pyc
.venv/
blog/node_modules/
-.DS_Store
diff --git a/.python-version b/.python-version
new file mode 100644
index 00000000..c8cfe395
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10
diff --git a/README.md b/README.md
index 085daf4e..d30eec6f 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additiv
## Overview
-This repository provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial.
+This package provides a comprehensive framework for competing risks survival analysis with interpretable neural additive models. CRISP-NAM combines the predictive power of deep learning with interpretability through feature-level shape functions, making it suitable for clinical and biomedical applications where understanding feature contributions is crucial.
### Key Features
@@ -16,112 +16,14 @@ This repository provides a comprehensive framework for competing risks survival
- **Multiple Training Modes**: Standard training, hyperparameter tuning, and nested cross-validation
- **Baseline Comparisons**: DeepHit implementation for benchmarking against state-of-the-art methods
-### Available Datasets
-
-The repository includes four well-established survival analysis datasets:
-
-1. **Framingham Heart Study**: Cardiovascular disease prediction with competing events (CVD vs. death)
- - Features: Demographics, clinical measurements, lifestyle factors
- - Events: Cardiovascular disease, death from other causes
-
-2. **PBC (Primary Biliary Cirrhosis)**: Liver disease progression study
- - Features: Clinical laboratory values, demographic information
- - Events: Death, liver transplantation
-
-3. **SUPPORT**: Study to understand prognoses and preferences for outcomes
- - Features: Comprehensive clinical and demographic variables
- - Events: Cancer death, non-cancer death
-
-4. **Synthetic Dataset**: Controlled simulation for method validation
- - Features: Simulated clinical variables with known ground truth
- - Events: Multiple competing risks with controllable hazard functions
-
-All datasets come with preprocessing pipelines that handle missing values, feature encoding, and proper train/test splitting to prevent data leakage.
-
-### Training Scripts
-
-The repository provides several specialized training scripts:
-
-- **`train.py`**: Standard model training with cross-validation and comprehensive evaluation
-- **`train_nested_cv.py`**: Robust nested cross-validation for unbiased performance estimation
-- **`tune_optuna.py`**: Hyperparameter optimization using Optuna's advanced algorithms
-- **`train_deephit.py`**: DeepHit baseline implementation for comparative studies
-
-Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps.
-
## Requirements
+
Python >=3.10
-## Repository Structure
+## Install the package
-```
-crisp_nam/
-├── crisp_nam/ # Main package
-│ ├── metrics/
-│ │ ├── __init__.py
-│ │ ├── calibration.py
-│ │ ├── discrimination.py
-│ │ └── ipcw.py
-│ ├── models/
-│ │ ├── __init__.py
-│ │ ├── crisp_nam_model.py
-│ │ └── deephit_model.py
-│ ├── utils/
-│ │ ├── __init__.py
-│ │ ├── loss.py
-│ │ ├── plotting.py
-│ │ └── risk_cif.py
-│ └── __init__.py
-├── data_utils/ # Data utilities
-│ ├── __init__.py
-│ ├── load_datasets.py
-│ └── survival_datasets.py
-├── datasets/ # Dataset files and loaders
-│ ├── metabric/
-│ │ ├── cleaned_features_final.csv
-│ │ └── label.csv
-│ ├── framingham_dataset.py
-│ ├── framingham.csv
-│ ├── pbc_dataset.py
-│ ├── pbc2.csv
-│ ├── support_dataset.py
-│ ├── support2.csv
-│ ├── SurvivalDataset.py
-│ ├── synthetic_comprisk.csv
-│ └── synthetic_dataset.py
-├── results/ # Results and outputs
-│ ├── best_params/ # Best parameters for dataset and model combinations
-│ │ ├── best_params_framingham_deephit.yaml
-│ │ ├── best_params_framingham.yaml
-│ │ ├── best_params_pbc_deephit.yaml
-│ │ ├── best_params_pbc.yaml
-│ │ ├── best_params_support.yaml
-│ │ ├── best_params_support2_deephit.yaml
-│ │ ├── best_params_synthetic_deephit.yaml
-│ │ └── best_params_synthetic.yaml
-│ ├── logs/ # Nested CV results and logs
-│ │ ├── nested_cv_best_params_*.yaml
-│ │ ├── nested_cv_detailed_metrics_*.csv
-│ │ ├── nested_cv_metrics_*.xlsx
-│ │ ├── nested_cv_raw_metrics_*.json
-│ │ └── nested_cv_summary_metrics_*.csv
-│ └── plots/ # Generated plots
-│ ├── nested_cv_feature_importance_risk_*_*.png
-│ └── nested_cv_shape_functions_risk_*_*.png
-├── training_scripts/ # Training scripts
-│ ├── config.yaml
-│ ├── model_utils.py
-│ ├── train_deephit_cuda.py
-│ ├── train_deephit.py
-│ ├── train_nested_cv.py # Nested cross-validation script
-│ ├── train.py
-│ ├── tune_optuna_optimized.py
-│ └── tune_optuna.py
-└── utils/ # Legacy utils (duplicate of crisp_nam/utils)
- ├── __init__.py
- ├── loss.py
- ├── plotting.py
- └── risk_cif.py
+```bash
+pip install crisp-nam
```
## Install from source
@@ -145,135 +47,26 @@ via `uv`
cd crisp-nam
uv sync
```
+## Research details
-## Running training scripts
-
-1. Modify training parameters in `training_scripts/train.py`
- OR
- Run either of following commands to see CLI arguments for passing training parameters:
-
- ```bash
- python training_scripts/train.py --help
- ```
-
- ```bash
- uv run training_scripts/train.py --help
- ```
-
-2. Run the training script
-
- via `python`
- ```bash
- source .venv/bin/activate
- python training_scripts/train.py --dataset framingham
- ```
-
- via `uv`
- ```bash
- uv run training_scripts/train.py --dataset framingham
- ```
-
-## Running Nested Cross-Validation
-
-The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations.
-
-### Basic Usage
-
-```bash
-# Using python
-python training_scripts/train_nested_cv.py --dataset framingham
-
-# Using uv
-uv run training_scripts/train_nested_cv.py --dataset framingham
-```
+For more details regarding the research work, please refer to `datasets.md` and `training.md` within the project repository.
-### Configuration Parameters
-
-All parameters can be passed via command line or specified in a YAML config file:
-
-#### Dataset Configuration
-- `--dataset` (str): Dataset to use (choices: `framingham`, `support`, `pbc`, `synthetic`, default: `framingham`)
-- `--scaling` (str): Data scaling method for continuous features (choices: `minmax`, `standard`, `none`, default: `standard`)
-
-#### Training Parameters
-- `--num_epochs` (int): Number of training epochs (default: `250`)
-- `--batch_size` (int): Batch size for training (default: `512`)
-- `--patience` (int): Patience for early stopping (default: `10`)
-
-#### Cross-Validation Configuration
-- `--outer_folds` (int): Number of outer CV folds (default: `5`)
-- `--inner_folds` (int): Number of inner CV folds for hyperparameter tuning (default: `3`)
-- `--n_trials` (int): Number of Optuna trials per inner fold (default: `20`)
-
-#### Event Weighting
-- `--event_weighting` (str): Event weighting strategy (choices: `none`, `balanced`, `custom`, default: `none`)
-- `--custom_event_weights` (str): Custom weights for events (comma-separated, default: `None`)
-
-#### Other Parameters
-- `--seed` (int): Random seed for reproducibility (default: `42`)
-- `--config` (str): Path to YAML config file (default: looks for `config.yaml`)
+## Contributing
-### Examples
+Contributions are welcome! Please open issues or submit pull requests.
-#### Basic nested CV with default parameters:
-```bash
-python training_scripts/train_nested_cv.py --dataset pbc
-```
+## Citation
-#### Customized nested CV with specific parameters:
-```bash
-python training_scripts/train_nested_cv.py \
- --dataset support \
- --outer_folds 10 \
- --inner_folds 5 \
- --n_trials 50 \
- --num_epochs 500 \
- --event_weighting balanced \
- --scaling minmax \
- --seed 123
+If you use our package, kindly acknowledge by citing our research.
```
-
-#### Using a config file:
-```bash
-python training_scripts/train_nested_cv.py --config my_config.yaml
+@inproceedings{ramachandram2025crispnam,
+ title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models},
+ author={Ramachandram, Dhanesh and Raval, Ananya},
+ booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain},
+ year={2025}
+}
```
-### Output Files
-
-The script generates several output files in the current directory:
-
-#### Performance Metrics
-- `nested_cv_summary_metrics_{dataset}.csv`: Summary table with mean ± std metrics
-- `nested_cv_detailed_metrics_{dataset}.csv`: Detailed results for each fold
-- `nested_cv_metrics_{dataset}.xlsx`: Excel file with multiple sheets (Summary, Detailed, Metadata)
-- `nested_cv_raw_metrics_{dataset}.json`: Raw metrics dictionary for reproducibility
-
-#### Model Configuration
-- `nested_cv_best_params_{dataset}.yaml`: Aggregated best hyperparameters across all folds
-
-#### Visualizations
-Results are saved to `results/plots/`:
-- `nested_cv_feature_importance_risk_{risk}_{dataset}.png`: Feature importance plots
-- `nested_cv_shape_functions_risk_{risk}_{dataset}.png`: Shape function plots for top features
-
-### Evaluation Metrics
-
-The script computes the following metrics at different time quantiles (25%, 50%, 75%):
-
-- **AUC (Area Under the ROC Curve)**: Time-dependent AUC for discrimination
- - 0.5 = random, >0.7 = good, >0.8 = excellent
-- **TDCI (Time-Dependent Concordance Index)**: Harrell's C-index adapted for competing risks
- - 0.5 = random, >0.7 = good, >0.8 = excellent
-- **Brier Score**: Calibration metric measuring prediction accuracy
- - 0 = perfect, <0.25 = good, >0.25 = poor
-
-> [!NOTE]
-> For `uv` installation, please visit follow instructions in their [official page](https://docs.astral.sh/uv/getting-started/installation/).
-
-## Contributing
-
-Contributions are welcome! Please open issues or submit pull requests.
-
## License
-This project is licensed under the MIT License.
\ No newline at end of file
+This project is licensed under the MIT License.
diff --git a/crisp_nam/__init__.py b/crisp_nam/__init__.py
index d67a0b62..3f10443c 100644
--- a/crisp_nam/__init__.py
+++ b/crisp_nam/__init__.py
@@ -1 +1,3 @@
-__all__ = [ 'utils', 'models', 'data_utils', 'metrics' ]
\ No newline at end of file
+"""CRISP-NAM paclage."""
+
+__all__ = ["utils", "models", "data_utils", "metrics"]
diff --git a/crisp_nam/metrics/__init__.py b/crisp_nam/metrics/__init__.py
index b9079625..01a720b3 100644
--- a/crisp_nam/metrics/__init__.py
+++ b/crisp_nam/metrics/__init__.py
@@ -1,3 +1,14 @@
-from .calibration import *
-from .discrimination import *
-from .ipcw import *
+"""Evaluation metrics used within crisp_nam package."""
+
+from .calibration import brier_score, integrated_brier_score
+from .discrimination import auc_td, cumulative_dynamic_auc, truncated_concordance_td
+from .ipcw import estimate_ipcw
+
+__all__ = [
+ "brier_score",
+ "integrated_brier_score",
+ "auc_td",
+ "cumulative_dynamic_auc",
+ "truncated_concordance_td",
+ "estimate_ipcw",
+]
diff --git a/crisp_nam/metrics/calibration.py b/crisp_nam/metrics/calibration.py
index 3c035f38..26556604 100644
--- a/crisp_nam/metrics/calibration.py
+++ b/crisp_nam/metrics/calibration.py
@@ -1,26 +1,47 @@
+"""Calibration metrics for time-to-event models with competing risks.
+
+This module contains functions to compute the Brier score and
+integrated Brier score for competing risks.
+"""
+
import numpy as np
+
from .ipcw import estimate_ipcw
+
# A small constant to avoid division by zero
epsilon = 1e-4
-def brier_score(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1):
- """
- Compute the corrected Brier score for a given competing risk.
-
+def brier_score(
+ e_test: np.ndarray,
+ t_test: np.ndarray,
+ risk_predicted_test: np.ndarray,
+ times: np.ndarray,
+ t: float,
+ km: object | None = None,
+ primary_risk: int = 1,
+) -> tuple[float, object]:
+ """Compute the corrected Brier score for a given competing risk.
+
This implementation is based on the work of Schoop et al. on quantifying the
predictive accuracy of time-to-event models in the presence of competing risks.
-
- Parameters:
- e_test (ndarray): Array of event indicators (0 = censored; positive integers for different events).
+
+ Parameters
+ ----------
+ e_test (ndarray): Array of event indicators (0 = censored;
+ positive integers for different events).
t_test (ndarray): Array of event/censoring times.
- risk_predicted_test (ndarray): Predicted risk matrix with shape (n_samples, n_times).
- times (ndarray): Array of time points corresponding to columns in risk_predicted_test.
+ risk_predicted_test (ndarray): Predicted risk matrix with
+ shape (n_samples, n_times).
+ times (ndarray): Array of time points corresponding to columns
+ in risk_predicted_test.
t (float): Time at which to evaluate the Brier score.
- km (object, optional): Kaplan–Meier estimator or data to estimate the censoring distribution.
+ km (object, optional): Kaplan–Meier estimator or data to estimate
+ the censoring distribution.
primary_risk (int, optional): The event label for which to compute the score.
-
- Returns:
+
+ Returns
+ -------
brier (float): The corrected Brier score evaluated at time t.
km (object): Updated Kaplan–Meier estimator (if applicable).
"""
@@ -39,47 +60,67 @@ def brier_score(e_test, t_test, risk_predicted_test, times, t, km=None, primary_
# Initialize weights for IPCW correction.
weights = np.zeros_like(e_test, dtype=float)
- # For subjects with events (or censoring) before t (excluding those censored exactly at 0 event label), use KM weights.
+ # For subjects with events (or censoring) before t
+ # (excluding those censored exactly at 0 event label), use KM weights.
mask = (t_test <= t) & (e_test != 0)
- weights[mask] = 1. / np.clip(km.survival_function_at_times(t_test[mask]), epsilon, None)
- # For subjects still at risk at time t, assign constant weight based on KM at time t.
- weights[t_test > t] = 1. / np.clip(km.survival_function_at_times(t), epsilon, None)
+ weights[mask] = 1.0 / np.clip(
+ km.survival_function_at_times(t_test[mask]), epsilon, None
+ )
+ # For subjects still at risk at time t, assign constant weight based
+ # on KM at time t.
+ weights[t_test > t] = 1.0 / np.clip(km.survival_function_at_times(t), epsilon, None)
brier = (weights * (truth - risk_predicted_test[:, index]) ** 2).mean()
return brier, km
-def integrated_brier_score(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1):
+def integrated_brier_score(
+ e_test: np.ndarray,
+ t_test: np.ndarray,
+ risk_predicted_test: np.ndarray,
+ times: np.ndarray,
+ t_eval: np.ndarray | None = None,
+ km: object | None = None,
+ primary_risk: int = 1,
+) -> tuple[float, object]:
"""
Compute the integrated Brier score for competing risks over a range of time points.
-
- The integrated Brier score is computed by numerically integrating the Brier score over the evaluation times.
-
- Parameters:
+
+ The integrated Brier score is computed by numerically integrating the
+ Brier score over the evaluation times.
+
+ Parameters
+ ----------
e_test (ndarray): Event indicators.
t_test (ndarray): Event/censoring times.
- risk_predicted_test (ndarray): Predicted risk matrix with shape (n_samples, n_times).
+ risk_predicted_test (ndarray): Predicted risk matrix
+ with shape (n_samples, n_times).
times (ndarray): Array of time points corresponding to the predictions.
- t_eval (ndarray, optional): Specific time points at which to compute the score. Defaults to using 'times'.
+ t_eval (ndarray, optional): Specific time points at which to
+ compute the score. Defaults to using 'times'.
km (object, optional): Kaplan–Meier estimator for IPCW.
primary_risk (int, optional): The event label for which to compute the score.
-
- Returns:
+
+ Returns
+ -------
ibs (float): Integrated Brier score.
km (object): Updated Kaplan–Meier estimator.
"""
km = estimate_ipcw(km)
t_eval = times if t_eval is None else t_eval
# Compute Brier scores at each time point.
- brier_scores = [brier_score(e_test, t_test, risk_predicted_test, times, t_val, km, primary_risk)[0]
- for t_val in t_eval]
+ brier_scores = [
+ brier_score(
+ e_test, t_test, risk_predicted_test, times, t_val, km, primary_risk
+ )[0]
+ for t_val in t_eval
+ ]
# Remove NaN values if any.
t_eval = t_eval[~np.isnan(brier_scores)]
brier_scores = np.array(brier_scores)[~np.isnan(brier_scores)]
-
+
if t_eval.shape[0] < 2:
raise ValueError("At least two time points must be provided for integration.")
-
+
ibs = np.trapz(brier_scores, t_eval) / (t_eval[-1] - t_eval[0])
return ibs, km
-
diff --git a/crisp_nam/metrics/discrimination.py b/crisp_nam/metrics/discrimination.py
index 05aa17a0..79bcd7ab 100644
--- a/crisp_nam/metrics/discrimination.py
+++ b/crisp_nam/metrics/discrimination.py
@@ -1,13 +1,32 @@
+"""Discrimination metrics for time-to-event models with competing risks.
+
+This module contains functions to compute the cumulative
+and single time-dependent AUC and time-dependent C-index
+for evaluating competing risks.
+"""
+
import numpy as np
-from .ipcw import estimate_ipcw
+
+from .ipcw import estimate_ipcw
+
epsilon = 1e-10
-def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1):
+
+def auc_td(
+ e_test: np.ndarray,
+ t_test: np.ndarray,
+ risk_predicted_test: np.ndarray,
+ times: np.ndarray,
+ t: float,
+ km: object | None = None,
+ primary_risk: int = 1,
+) -> tuple[float, object]:
"""
Compute the time-dependent AUC for a given competing risk using predicted CIFs.
- Parameters:
+ Parameters
+ ----------
e_test : ndarray of shape (n_samples,)
Event indicator (0=censored, 1=event of interest, 2=competing event, etc.)
t_test : ndarray of shape (n_samples,)
@@ -23,7 +42,8 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=
primary_risk : int
The event label to treat as the event of interest.
- Returns:
+ Returns
+ -------
auc_value : float
AUC estimate at time t (always between 0 and 1)
km : Updated Kaplan-Meier estimator
@@ -53,8 +73,8 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=
# Compute pairwise AUC: compare each (event, control) pair
auc_numerator = 0.0
auc_denominator = 0.0
- for i, (score_i, w_i) in enumerate(zip(event_scores, weights_event)):
- for j, (score_j, w_j) in enumerate(zip(control_scores, weights_control)):
+ for _i, (score_i, w_i) in enumerate(zip(event_scores, weights_event)):
+ for _j, (score_j, w_j) in enumerate(zip(control_scores, weights_control)):
weight = w_i * w_j
auc_denominator += weight
if score_i > score_j:
@@ -66,17 +86,27 @@ def auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=
auc_value = auc_numerator / auc_denominator if auc_denominator > 0 else np.nan
return auc_value, km
-def cumulative_dynamic_auc(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1):
- """
- Computes the cumulative dynamic AUC by numerically integrating the time-dependent AUC over a range of time points.
-
- Parameters:
+def cumulative_dynamic_auc(
+ e_test: np.ndarray,
+ t_test: np.ndarray,
+ risk_predicted_test: np.ndarray,
+ times: np.ndarray,
+ t_eval: np.ndarray | None = None,
+ km: object | None = None,
+ primary_risk: int = 1,
+) -> tuple[float, object]:
+ """Compute the cumulative dynamic AUC by numerically
+ integrating the time-dependent AUC over a range of time points.
+
+ Parameters
+ ----------
e_test, t_test, risk_predicted_test, times, km, primary_risk:
Same as in auc_td.
t_eval: ndarray, optional
Specific time points to evaluate. If None, uses times.
-
- Returns:
+
+ Returns
+ -------
auc_integral: float
The cumulative dynamic AUC.
km: object
@@ -84,18 +114,32 @@ def cumulative_dynamic_auc(e_test, t_test, risk_predicted_test, times, t_eval=No
"""
km = estimate_ipcw(km)
t_eval = times if t_eval is None else t_eval
- aucs = [auc_td(e_test, t_test, risk_predicted_test, times, t, km, primary_risk)[0] for t in t_eval]
+ aucs = [
+ auc_td(e_test, t_test, risk_predicted_test, times, t, km, primary_risk)[0]
+ for t in t_eval
+ ]
t_eval, aucs = t_eval[~np.isnan(aucs)], np.array(aucs)[~np.isnan(aucs)]
if t_eval.shape[0] < 2:
raise ValueError("At least two time points must be given")
auc_integral = np.trapz(aucs, t_eval) / (t_eval[-1] - t_eval[0])
return auc_integral, km
-def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1, tied_tol=1e-8):
+
+def truncated_concordance_td(
+ e_test: np.ndarray,
+ t_test: np.ndarray,
+ risk_predicted_test: np.ndarray,
+ times: np.ndarray,
+ t: float,
+ km: object | None = None,
+ primary_risk: int = 1,
+ tied_tol: float = 1e-8,
+) -> tuple[float, object]:
"""
Compute the truncated time-dependent concordance index (C-index).
-
- Parameters:
+
+ Parameters
+ ----------
e_test : ndarray
Event indicator (0=censored, 1=event of interest, etc.)
t_test : ndarray
@@ -112,14 +156,15 @@ def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=N
Risk of interest
tied_tol : float
Tolerance to assign 0.5 score for ties
-
- Returns:
+
+ Returns
+ -------
c_index : float
km : Updated km object
"""
epsilon = 1e-10
index = np.argmin(np.abs(times - t))
-
+
# IPCW
if km is not None:
km = estimate_ipcw(km)
@@ -144,7 +189,7 @@ def truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=N
after_mask = t_test > t_i
before_mask = (t_test <= t_i) & (e_test != primary_risk) & (e_test != 0)
- weights_after = weights_event[after_mask] / (w_i ** 2)
+ weights_after = weights_event[after_mask] / (w_i**2)
weights_before = weights_event[before_mask] / (w_i * weights_event[before_mask])
risks_after = risk_predicted_test[after_mask, index]
diff --git a/crisp_nam/metrics/ipcw.py b/crisp_nam/metrics/ipcw.py
index bb794484..71ce0b95 100644
--- a/crisp_nam/metrics/ipcw.py
+++ b/crisp_nam/metrics/ipcw.py
@@ -1,21 +1,25 @@
+"""IPCW estimation for time-to-event models with competing risks.
+
+This module provides a function to estimate the inverse probability of censoring weights (IPCW) using a
+Kaplan-Meier estimator.
+"""
+
from lifelines import KaplanMeierFitter
-def estimate_ipcw(km):
- """
- Estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator.
- Parameters:
- -----------
+def estimate_ipcw(km: tuple | KaplanMeierFitter) -> KaplanMeierFitter:
+ """Estimate the inverse probability of censoring weights (IPCW)
+ using a Kaplan-Meier estimator.
+
+ Parameters
+ ----------
km : tuple or KaplanMeierFitter
- If `km` is a tuple, it should contain two elements:
- - e_train: array-like, event indicators (1 if the event occurred, 0 if censored).
- - t_train: array-like, corresponding event or censoring times.
- If `km` is already a fitted KaplanMeierFitter instance, it will be used directly.
- Returns:
- --------
+ Returns
+ -------
kmf : KaplanMeierFitter
- A KaplanMeierFitter instance fitted to the provided data or the input instance if already fitted.
+ A KaplanMeierFitter instance fitted to the provided data or
+ the input instance if already fitted.
"""
if isinstance(km, tuple):
kmf = KaplanMeierFitter()
@@ -27,5 +31,3 @@ def estimate_ipcw(km):
else:
kmf = km
return kmf
-
-
diff --git a/crisp_nam/models/__init__.py b/crisp_nam/models/__init__.py
index cc2ca5d4..92966a24 100644
--- a/crisp_nam/models/__init__.py
+++ b/crisp_nam/models/__init__.py
@@ -1,2 +1,7 @@
-from .crisp_nam_model import *
-from .deephit_model import *
+"""Models available in the crisp_nam package."""
+
+from .crisp_nam_model import CrispNamModel
+from .deephit_model import DeepHit
+
+
+__all__ = ["CrispNamModel", "DeepHit"]
diff --git a/crisp_nam/models/crisp_nam_model.py b/crisp_nam/models/crisp_nam_model.py
index c8b8200c..0e5601aa 100644
--- a/crisp_nam/models/crisp_nam_model.py
+++ b/crisp_nam/models/crisp_nam_model.py
@@ -1,60 +1,96 @@
+"""CrispNamModel for competing-risks survival analysis.
+
+PyTorch implementation of CrispNamModel for competing risks
+survival analysis with L2 normalized projection weights.
+"""
+
+from typing import (
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+)
+
import torch
import numpy as np
-import torch.nn as nn
import torch.nn.functional as F
+from torch import nn
+
+class _FeatureNet(nn.Module):
+ """Neural network to model the effect of a single feature on hazard.
-class FeatureNet(nn.Module):
- """
- Neural network to model the effect of a single feature on hazard.
This is the building block for NAM with optional batch normalization.
"""
- def __init__(self, hidden_sizes=[64, 64], dropout_rate=0.1, feature_dropout=0.0,
- batch_norm=False):
- super(FeatureNet, self).__init__()
+
+ def __init__(
+ self,
+ hidden_sizes: Sequence[int] = (64, 64),
+ dropout_rate: float = 0.1,
+ feature_dropout: float = 0.0,
+ batch_norm: bool = False,
+ ) -> None:
+ """Initialize the FeatureNet."""
+ super(_FeatureNet, self).__init__()
self.batch_norm = batch_norm
- layers = []
-
- # Input layer
+ layers: List[nn.Module] = []
+
layers.append(nn.Linear(1, hidden_sizes[0]))
if batch_norm:
layers.append(nn.BatchNorm1d(hidden_sizes[0]))
layers.append(nn.Tanh())
layers.append(nn.Dropout(dropout_rate))
-
+
# Hidden layers
for i in range(len(hidden_sizes) - 1):
- layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
+ layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i + 1]))
if batch_norm:
- layers.append(nn.BatchNorm1d(hidden_sizes[i+1]))
+ layers.append(nn.BatchNorm1d(hidden_sizes[i + 1]))
layers.append(nn.Tanh())
layers.append(nn.Dropout(dropout_rate))
-
+
# Final representation layer
layers.append(nn.Linear(hidden_sizes[-1], hidden_sizes[-1]))
if batch_norm:
layers.append(nn.BatchNorm1d(hidden_sizes[-1]))
layers.append(nn.Tanh())
-
+
self.network = nn.Sequential(*layers)
self.feature_dropout = feature_dropout
-
- def forward(self, x):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the network.
+
+ Parameters
+ x: Input tensor of shape
+
+ Returns
+ -------
+ torch.Tensor
+ """
+ # ensure float32
x = x.to(dtype=torch.float32)
# Apply feature dropout during training if specified
if self.training and self.feature_dropout > 0:
mask = torch.rand_like(x) > self.feature_dropout
x = x * mask.float()
-
+
# Handle BatchNorm with single sample
if self.batch_norm and x.size(0) == 1:
return self._forward_singleton(x)
-
+
return self.network(x)
-
- def _forward_singleton(self, x):
- """
- Handle the case of a single sample (batch_size=1)
- where BatchNorm1d would fail
+
+ def _forward_singleton(self, x: torch.Tensor) -> torch.Tensor:
+ """Handle the case of a single sample (batch_size=1)
+ where BatchNorm1d would fail.
+
+ Parameters
+ -----------
+ x: Input tensor of shape (1, num_features)
+
+ Returns
+ -------
+ torch.Tensor
"""
was_training = self.training
self.eval()
@@ -64,12 +100,14 @@ def _forward_singleton(self, x):
self.train()
return result
-class L2NormalizedLinear(nn.Module):
- """
- Linear layer with L2 normalized weights (unit norm constraint)
- """
- def __init__(self, in_features, out_features, bias=False, eps=1e-8):
- super(L2NormalizedLinear, self).__init__()
+class _L2NormalizedLinear(nn.Module):
+ """Linear layer with L2 normalized weights (unit norm constraint)."""
+
+ def __init__(
+ self, in_features: int, out_features: int, bias: bool = False, eps: float = 1e-8
+ ) -> None:
+ """Initialize the L2NormalizedLinear layer."""
+ super(_L2NormalizedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.eps = eps
@@ -77,27 +115,48 @@ def __init__(self, in_features, out_features, bias=False, eps=1e-8):
if bias:
self.bias = nn.Parameter(torch.randn(out_features))
else:
- self.register_parameter('bias', None)
-
- def forward(self, x):
+ self.register_parameter("bias", None)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply the linear transformation with L2 normalized weights.
+
+ Parameters
+ -----------
+ x: Input tensor of shape (batch_size, in_features)
+
+ Returns
+ -------
+ Output tensor
+ """
# L2 normalize weights to unit norm
normalized_weight = F.normalize(self.weight, p=2, dim=1, eps=self.eps)
return F.linear(x, normalized_weight, self.bias)
-
- def get_normalized_weights(self):
- """Return the L2 normalized weights (useful for inspection)"""
+
+ def get_normalized_weights(self) -> torch.Tensor:
+ """Return the L2 normalized weights (useful for inspection)."""
with torch.no_grad():
return F.normalize(self.weight, p=2, dim=1, eps=self.eps)
+
class CrispNamModel(nn.Module):
- """
- Competing risks CoxNAM with L2 normalized projection weights.
+ """Competing risks CoxNAM with L2 normalized projection weights.
+
Each feature contributes to each risk through a separate shape function.
All projection weights are constrained to unit L2 norm.
"""
- def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64],
- dropout_rate=0.1, feature_dropout=0.1, batch_norm=False,
- normalize_projections=True, eps=1e-8):
+
+ def __init__(
+ self,
+ num_features: int,
+ num_competing_risks: int,
+ hidden_sizes: Sequence[int] = (64, 64),
+ dropout_rate: float = 0.1,
+ feature_dropout: float = 0.1,
+ batch_norm: bool = False,
+ normalize_projections: bool = True,
+ eps: float = 1e-8,
+ ):
+ """Initialize the CrispNamModel."""
super(CrispNamModel, self).__init__()
self.num_features = num_features
self.num_competing_risks = num_competing_risks
@@ -105,39 +164,51 @@ def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64],
self.feature_dropout = feature_dropout
self.normalize_projections = normalize_projections
self.eps = eps
-
+
# Create a FeatureNet for each input feature
- self.feature_nets = nn.ModuleList([
- FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm)
- for _ in range(num_features)
- ])
-
+ self.feature_nets = nn.ModuleList(
+ [
+ _FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm)
+ for _ in range(num_features)
+ ]
+ )
+
# For each feature and risk type, create a projection layer
if normalize_projections:
- self.risk_projections = nn.ModuleList([
- nn.ModuleList([
- L2NormalizedLinear(hidden_sizes[-1], 1, bias=False, eps=eps)
- for _ in range(num_competing_risks)
- ])
- for _ in range(num_features)
- ])
+ self.risk_projections: nn.ModuleList = nn.ModuleList(
+ [
+ nn.ModuleList(
+ [
+ _L2NormalizedLinear(hidden_sizes[-1], 1, bias=False, eps=eps)
+ for _ in range(num_competing_risks)
+ ]
+ )
+ for _ in range(num_features)
+ ]
+ )
else:
# Fallback to standard linear layers
- self.risk_projections = nn.ModuleList([
- nn.ModuleList([
- nn.Linear(hidden_sizes[-1], 1, bias=False)
- for _ in range(num_competing_risks)
- ])
- for _ in range(num_features)
- ])
+ self.risk_projections = nn.ModuleList(
+ [
+ nn.ModuleList(
+ [
+ nn.Linear(hidden_sizes[-1], 1, bias=False)
+ for _ in range(num_competing_risks)
+ ]
+ )
+ for _ in range(num_features)
+ ]
+ )
- def forward(self, x):
- """
- Forward pass to compute risk scores for all competing risks
-
- Args:
+ def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ """Forward pass to compute risk scores for all competing risks.
+
+ Parameters
+ -----------
x: Tensor of shape (batch_size, num_features)
- Returns:
+
+ Returns
+ -------
risk_scores: List of (batch_size, 1) Tensors
feature_outputs: List of (batch_size, hidden) Tensors
"""
@@ -159,337 +230,155 @@ def forward(self, x):
# loop features
for feat_idx, fnet in enumerate(self.feature_nets):
# take one column and get repr
- col = x[:, feat_idx].unsqueeze(1) # [batch,1]
- repr = fnet(col) # [batch, hidden]
+ col = x[:, feat_idx].unsqueeze(1) # [batch,1]
+ repr = fnet(col) # [batch, hidden]
feature_outputs.append(repr)
+ proj: Optional[nn.ModuleList | None] = None
# project into each risk channel with L2 normalized weights
for risk_idx, proj in enumerate(self.risk_projections[feat_idx]):
- # proj automatically applies L2 normalization if normalize_projections=True
+ # proj automatically applies L2 normalization
+ # if normalize_projections=True
combined[:, risk_idx] += proj(repr).view(-1)
# split back into list of [batch,1]
risk_scores = [
- combined[:, r].unsqueeze(1)
- for r in range(self.num_competing_risks)
+ combined[:, r].unsqueeze(1) for r in range(self.num_competing_risks)
]
return risk_scores, feature_outputs
-
- def get_shape_functions(self, x_values, feature_idx, risk_idx=None, normalize=True):
- """
- Extract shape functions for a specific feature across all risks or a specific risk
-
- Args:
- x_values: Feature values to evaluate (numpy array or tensor)
- feature_idx: Index of the feature to get shape functions for
- risk_idx: Optional; if provided, only returns the shape function for this risk
- normalize: Whether to center the shape functions
-
- Returns:
- Dictionary mapping risk names to shape function values
- """
- self.eval()
-
- if not isinstance(x_values, torch.Tensor):
- x_values = torch.FloatTensor(x_values)
-
- x_vals = x_values.view(-1, 1)
-
- with torch.no_grad():
- feature_repr = self.feature_nets[feature_idx](x_vals)
-
- shape_funcs = {}
-
- # If risk_idx is specified, only compute for that risk
- risk_indices = [risk_idx] if risk_idx is not None else range(self.num_competing_risks)
-
- for j in risk_indices:
- # Apply the L2 normalized projection to get shape function values
- values = self.risk_projections[feature_idx][j](feature_repr).cpu().numpy().flatten()
-
- # Normalize if requested
- if normalize:
- values = values - np.mean(values)
-
- shape_funcs[f'risk_{j+1}'] = values
-
- return shape_funcs
-
- def get_projection_norms(self):
- """
- Get the L2 norms of all projection weights (should be ~1.0 if normalized)
-
- Returns:
+
+ def get_projection_norms(self) -> dict:
+ """Get the L2 norms of all projection weights (should be ~1.0 if normalized).
+
+ Returns
+ -------
Dictionary of weight norms by feature and risk
"""
norms = {}
-
+
for feat_idx in range(self.num_features):
for risk_idx in range(self.num_competing_risks):
proj = self.risk_projections[feat_idx][risk_idx]
-
- if hasattr(proj, 'weight'):
+
+ if hasattr(proj, "weight"):
weight_norm = proj.weight.norm(p=2, dim=1).item()
- norms[f'feature_{feat_idx}_risk_{risk_idx}'] = weight_norm
-
+ norms[f"feature_{feat_idx}_risk_{risk_idx}"] = weight_norm
+
return norms
-
- def get_normalized_projection_weights(self):
- """
- Get the actual L2 normalized weights used in computation
-
- Returns:
+
+ def get_normalized_projection_weights(self) -> dict:
+ """Get the actual L2 normalized weights used in computation.
+
+ Returns
+ -------
Dictionary of normalized weights
"""
normalized_weights = {}
-
+
for feat_idx in range(self.num_features):
for risk_idx in range(self.num_competing_risks):
proj = self.risk_projections[feat_idx][risk_idx]
-
- if hasattr(proj, 'get_normalized_weights'):
+
+ if hasattr(proj, "get_normalized_weights"):
# L2NormalizedLinear layer
weights = proj.get_normalized_weights().detach().cpu().numpy()
- elif hasattr(proj, 'weight'):
+ elif hasattr(proj, "weight"):
# Standard linear layer - normalize manually
- weights = F.normalize(proj.weight, p=2, dim=1).detach().cpu().numpy()
+ weights = (
+ F.normalize(proj.weight, p=2, dim=1).detach().cpu().numpy()
+ )
else:
weights = None
-
- normalized_weights[f'feature_{feat_idx}_risk_{risk_idx}'] = weights
-
+
+ normalized_weights[f"feature_{feat_idx}_risk_{risk_idx}"] = weights
+
return normalized_weights
-
- def calculate_feature_importance(self, x_data, feature_idx=None):
- """
- Calculate feature importance based on the magnitude of risk-specific projection outputs
- With L2 normalized weights, this gives a fair comparison across features
-
- Args:
+
+ def calculate_feature_importance(
+ self,
+ x_data: Optional[torch.Tensor | np.ndarray],
+ feature_idx: Optional[int | None] = None,
+ ) -> dict:
+ """Calculate feature importance based on the magnitude of
+ risk-specific projection outputs.
+
+ With L2 normalized weights, this gives a fair
+ comparison across features.
+
+ Parameters
+ -----------
x_data: Input data tensor or numpy array
- feature_idx: Optional; if provided, only calculate importance for this feature
-
- Returns:
+ feature_idx: Optional; if provided, only calculate
+ importance for this feature
+
+ Returns
+ -------
Dictionary of feature importances by risk type
"""
+
self.eval()
device = next(self.parameters()).device
-
+
# Convert to tensor if needed
if not isinstance(x_data, torch.Tensor):
x_data = torch.FloatTensor(x_data)
x_data = x_data.to(device)
-
- feature_indices = [feature_idx] if feature_idx is not None else range(self.num_features)
- importance = {f'risk_{j+1}': {} for j in range(self.num_competing_risks)}
-
+
+ feature_indices = (
+ [feature_idx] if feature_idx is not None else range(self.num_features)
+ )
+ importance: dict = {
+ f"risk_{j + 1}": {} for j in range(self.num_competing_risks)
+ }
+
for i in feature_indices:
# Get feature values
feature_values = x_data[:, i].view(-1, 1)
-
+
with torch.no_grad():
# Get the feature representation
feature_repr = self.feature_nets[i](feature_values)
-
+
# Calculate importance for each risk (mean absolute value)
- # With L2 normalized weights, this is directly comparable across features
+ # With L2 normalized weights, this is comparable across features
for j in range(self.num_competing_risks):
risk_specific_output = self.risk_projections[i][j](feature_repr)
abs_values = torch.abs(risk_specific_output).cpu().numpy()
- importance[f'risk_{j+1}'][f'feature_{i}'] = float(np.mean(abs_values))
-
+ importance[f"risk_{j + 1}"][f"feature_{i}"] = float(
+ np.mean(abs_values)
+ )
+
return importance
-
- def predict_risk(self, x, baseline_hazards=None):
- """
- Predict survival probability or cumulative incidence
-
- Args:
- x: Input tensor of shape (batch_size, num_features)
- baseline_hazards: Optional dict of baseline hazards for each risk
-
- Returns:
- Dictionary of predictions for each competing risk
- """
- self.eval()
-
- # Convert to tensor if needed
- if not isinstance(x, torch.Tensor):
- x = torch.FloatTensor(x)
-
- with torch.no_grad():
- risk_scores, _ = self(x)
-
- # Convert scores to hazard ratios
- hazard_ratios = [torch.exp(score).cpu().numpy() for score in risk_scores]
-
- # If baseline hazards are provided, compute absolute risks
- if baseline_hazards is not None:
- predictions = {}
-
- for j in range(self.num_competing_risks):
- risk_name = f'risk_{j+1}'
-
- # Baseline survival and hazard
- baseline_surv = baseline_hazards.get(risk_name, {}).get('survival', None)
- baseline_haz = baseline_hazards.get(risk_name, {}).get('hazard', None)
-
- if baseline_surv is not None:
- # Compute survival probability: S(t|x) = S0(t)^exp(f(x))
- predictions[f'{risk_name}_survival'] = np.power(
- baseline_surv.reshape(1, -1),
- hazard_ratios[j].reshape(-1, 1)
- )
-
- if baseline_haz is not None:
- # Compute cumulative hazard: H(t|x) = H0(t) * exp(f(x))
- predictions[f'{risk_name}_cumhazard'] = baseline_haz.reshape(1, -1) * hazard_ratios[j].reshape(-1, 1)
-
- return predictions
- else:
- # Without baseline hazards, just return hazard ratios
- return {f'risk_{j+1}_hazard_ratio': hazard_ratios[j] for j in range(self.num_competing_risks)}
-
-# Alternative implementation using manual normalization in forward pass
-class CrispNamModelManualL2(nn.Module):
- """
- Alternative implementation with manual L2 normalization in forward pass
- """
- def __init__(self, num_features, num_competing_risks, hidden_sizes=[64, 64],
- dropout_rate=0.1, feature_dropout=0.1, batch_norm=False, eps=1e-8):
- super(CrispNamModelManualL2, self).__init__()
- self.num_features = num_features
- self.num_competing_risks = num_competing_risks
- self.batch_norm = batch_norm
- self.feature_dropout = feature_dropout
- self.eps = eps
-
- # Create a FeatureNet for each input feature
- self.feature_nets = nn.ModuleList([
- FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm)
- for _ in range(num_features)
- ])
-
- # Standard linear layers - weights will be L2 normalized in forward pass
- self.risk_projections = nn.ModuleList([
- nn.ModuleList([
- nn.Linear(hidden_sizes[-1], 1, bias=False)
- for _ in range(num_competing_risks)
- ])
- for _ in range(num_features)
- ])
-
- def forward(self, x):
- """Forward pass with manual L2 weight normalization"""
- x = x.to(dtype=torch.float32)
- batch_size, _ = x.shape
- device = x.device
- if self.training and self.feature_dropout > 0:
- mask = torch.empty_like(x).bernoulli_(1.0 - self.feature_dropout)
- x = x * mask
+ # Utility functions for model analysis
+ def analyze_projection_weights(self) -> dict:
+ """Analyze the L2 norms and statistics of projection weights.
- combined = torch.zeros(batch_size, self.num_competing_risks, device=device)
- feature_outputs = []
+ Parameters
+ -----------
+ None
- for feat_idx, fnet in enumerate(self.feature_nets):
- col = x[:, feat_idx].unsqueeze(1)
- repr = fnet(col)
- feature_outputs.append(repr)
+ Returns
+ -------
+ None
+ """
+ print("Projection Weight Analysis:")
+ print("=" * 50)
- for risk_idx, proj in enumerate(self.risk_projections[feat_idx]):
- # L2 normalize weights manually
- normalized_weight = F.normalize(proj.weight, p=2, dim=1, eps=self.eps)
- # Apply normalized projection
- output = F.linear(repr, normalized_weight, proj.bias)
- combined[:, risk_idx] += output.view(-1)
+ # Get weight norms
+ norms = self.get_projection_norms()
+ norm_values = list(norms.values())
- risk_scores = [
- combined[:, r].unsqueeze(1)
- for r in range(self.num_competing_risks)
- ]
+ print("Weight L2 Norms (should be ~1.0):")
+ print(f" Mean: {np.mean(norm_values):.6f}")
+ print(f" Std: {np.std(norm_values):.6f}")
+ print(f" Min: {np.min(norm_values):.6f}")
+ print(f" Max: {np.max(norm_values):.6f}")
- return risk_scores, feature_outputs
+ # Show some individual norms
+ print("\nSample individual norms:")
+ for _i, (name, norm) in enumerate(list(norms.items())[:6]):
+ print(f" {name}: {norm:.6f}")
-# Utility functions for model analysis
-def analyze_projection_weights(model):
- """
- Analyze the L2 norms and statistics of projection weights
- """
- print("Projection Weight Analysis:")
- print("=" * 50)
-
- # Get weight norms
- norms = model.get_projection_norms()
- norm_values = list(norms.values())
-
- print(f"Weight L2 Norms (should be ~1.0):")
- print(f" Mean: {np.mean(norm_values):.6f}")
- print(f" Std: {np.std(norm_values):.6f}")
- print(f" Min: {np.min(norm_values):.6f}")
- print(f" Max: {np.max(norm_values):.6f}")
-
- # Show some individual norms
- print(f"\nSample individual norms:")
- for i, (name, norm) in enumerate(list(norms.items())[:6]):
- print(f" {name}: {norm:.6f}")
-
- return norms
-
-def compare_feature_importance_fairness(model, x_data):
- """
- Compare feature importance when weights are L2 normalized vs not normalized
- """
- print("\nFeature Importance Comparison:")
- print("=" * 50)
-
- # Calculate importance with current model (L2 normalized)
- importance_normalized = model.calculate_feature_importance(x_data)
-
- # Create equivalent model without normalization for comparison
- model_unnorm = CrispNamModel(
- model.num_features,
- model.num_competing_risks,
- normalize_projections=False
- )
-
- # Copy weights from normalized model
- with torch.no_grad():
- for i in range(model.num_features):
- for j in range(model.num_competing_risks):
- model_unnorm.risk_projections[i][j].weight.copy_(
- model.risk_projections[i][j].weight
- )
-
- importance_unnorm = model_unnorm.calculate_feature_importance(x_data)
-
- print("Importance comparison (Risk 1):")
- for feat in range(min(5, model.num_features)): # Show first 5 features
- norm_imp = importance_normalized['risk_1'].get(f'feature_{feat}', 0)
- unnorm_imp = importance_unnorm['risk_1'].get(f'feature_{feat}', 0)
- print(f" Feature {feat}: Normalized={norm_imp:.4f}, Unnormalized={unnorm_imp:.4f}")
-
-# Example usage and testing
-if __name__ == "__main__":
- # Create model with L2 normalized projections
- model = CrispNamModel(
- num_features=5,
- num_competing_risks=3,
- hidden_sizes=[32, 32],
- normalize_projections=True
- )
-
- # Generate some test data
- torch.manual_seed(42)
- test_data = torch.randn(100, 5)
-
- # Test forward pass
- risk_scores, feature_outputs = model(test_data)
- print(f"Risk scores shapes: {[score.shape for score in risk_scores]}")
-
- # Analyze projection weights
- analyze_projection_weights(model)
-
- # Compare feature importance
- compare_feature_importance_fairness(model, test_data)
\ No newline at end of file
+ return norms
diff --git a/crisp_nam/models/deephit_model.py b/crisp_nam/models/deephit_model.py
index cc3394f3..5d1d3a75 100644
--- a/crisp_nam/models/deephit_model.py
+++ b/crisp_nam/models/deephit_model.py
@@ -1,23 +1,48 @@
+"""PyTorch implementation of DeepHit for competing risks survival analysis."""
+
+from typing import Callable, Optional
+
+import numpy as np
import torch
-import torch.nn as nn
import torch.nn.functional as F
+from torch import nn
+
class FCLayer(nn.Module):
"""Fully connected layer with optional batch norm, dropout, and activation."""
- def __init__(self, in_dim, out_dim, activation=nn.ReLU(), batch_norm=False, dropout_rate=0.0,
- init_fn=nn.init.xavier_normal_):
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ activation: Optional[nn.Module] = None,
+ batch_norm: bool = False,
+ dropout_rate: float = 0.0,
+ init_fn: Optional[Callable | None] = nn.init.xavier_normal_,
+ ) -> None:
+ """Initialize the fully connected layer."""
super(FCLayer, self).__init__()
self.fc = nn.Linear(in_dim, out_dim)
self.batch_norm = nn.BatchNorm1d(out_dim) if batch_norm else None
self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None
- self.activation = activation
-
+ self.activation = activation if activation else nn.ReLU()
+
# Initialize weights
if init_fn:
init_fn(self.fc.weight)
nn.init.zeros_(self.fc.bias)
-
- def forward(self, x):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the layer.
+
+ Parameters
+ ----------
+ x: Tensor of shape (batch_size, in_dim)
+
+ Returns
+ -------
+ out: Tensor of shape (batch_size, out_dim)
+ """
x = self.fc(x)
if self.batch_norm:
x = self.batch_norm(x)
@@ -30,222 +55,383 @@ def forward(self, x):
class FCNet(nn.Module):
"""Multi-layer fully connected network."""
- def __init__(self, in_dim, num_layers, h_dim, activation=nn.ReLU(),
- out_dim=None, out_activation=None, batch_norm=False,
- dropout_rate=0.0, init_fn=nn.init.xavier_normal_):
+
+ def __init__(
+ self,
+ in_dim: int,
+ num_layers: int,
+ h_dim: int,
+ activation: Optional[nn.Module] = None,
+ out_dim: Optional[int | None] = None,
+ out_activation: Optional[nn.Module | None] = None,
+ batch_norm: bool = False,
+ dropout_rate: float = 0.0,
+ init_fn: Optional[Callable | None] = nn.init.xavier_normal_,
+ ) -> None:
+ """Initialize the fully connected network."""
super(FCNet, self).__init__()
-
+
layers = []
prev_dim = in_dim
-
+ activation = activation if activation else nn.ReLU()
+
# Hidden layers
for i in range(num_layers):
curr_dim = out_dim if (i == num_layers - 1 and out_dim) else h_dim
- curr_act = out_activation if (i == num_layers - 1 and out_activation) else activation
-
- layers.append(FCLayer(
- prev_dim, curr_dim, activation=curr_act,
- batch_norm=batch_norm, dropout_rate=dropout_rate, init_fn=init_fn
- ))
+ curr_act: Optional[nn.Module | None] = (
+ out_activation
+ if (i == num_layers - 1 and out_activation)
+ else activation
+ )
+
+ layers.append(
+ FCLayer(
+ prev_dim,
+ curr_dim,
+ activation=curr_act,
+ batch_norm=batch_norm,
+ dropout_rate=dropout_rate,
+ init_fn=init_fn,
+ )
+ )
prev_dim = curr_dim
-
+
self.network = nn.Sequential(*layers)
-
- def forward(self, x):
+
+ def forward(self, x: torch.Tensor) -> Optional[nn.Module]:
+ """Forward pass through the network.
+
+ Parameters
+ -----------
+ x: Tensor of shape (batch_size, in_dim)
+
+ Returns
+ -------
+ out: Tensor of shape (batch_size, out_dim)
+ """
return self.network(x)
class DeepHit(nn.Module):
"""PyTorch implementation of DeepHit for competing risks survival analysis."""
-
- def __init__(self, input_dims, network_settings):
+
+ def __init__(self, input_dims: dict, network_settings: dict):
+ """Initialize the DeepHit model."""
super(DeepHit, self).__init__()
-
+
# Input dimensions
- self.x_dim = input_dims['x_dim']
- self.num_Event = input_dims['num_Event']
- self.num_Category = input_dims['num_Category']
-
+ self.x_dim = input_dims["x_dim"]
+ self.num_Event = input_dims["num_Event"]
+ self.num_Category = input_dims["num_Category"]
+
# Network settings
- self.h_dim_shared = network_settings['h_dim_shared']
- self.h_dim_CS = network_settings['h_dim_CS']
- self.num_layers_shared = network_settings['num_layers_shared']
- self.num_layers_CS = network_settings['num_layers_CS']
-
+ self.h_dim_shared = network_settings["h_dim_shared"]
+ self.h_dim_CS = network_settings["h_dim_CS"]
+ self.num_layers_shared = network_settings["num_layers_shared"]
+ self.num_layers_CS = network_settings["num_layers_CS"]
+
# Activation function
- if network_settings['active_fn'] == 'relu':
+ if network_settings["active_fn"] == "relu":
self.active_fn = nn.ReLU()
- elif network_settings['active_fn'] == 'elu':
+ elif network_settings["active_fn"] == "elu":
self.active_fn = nn.ELU()
- elif network_settings['active_fn'] == 'tanh':
+ elif network_settings["active_fn"] == "tanh":
self.active_fn = nn.Tanh()
else:
self.active_fn = nn.ReLU()
-
+
# Regularization
- self.keep_prob = network_settings.get('keep_prob', 0.5)
+ self.keep_prob = network_settings.get("keep_prob", 0.5)
self.dropout_rate = 1.0 - self.keep_prob
-
+
# Initialize networks
self._build_network()
-
- def _build_network(self):
+
+ def _build_network(self) -> None:
+ """Build the shared and cause-specific networks.
+
+ Parameters
+ ----------
+ None
+
+ Returns
+ -------
+ None
+ """
# Shared network
self.shared_net = FCNet(
in_dim=self.x_dim,
num_layers=self.num_layers_shared,
h_dim=self.h_dim_shared,
activation=self.active_fn,
- dropout_rate=self.dropout_rate
+ dropout_rate=self.dropout_rate,
)
-
+
# Cause-specific networks
- self.cs_nets = nn.ModuleList([
- FCNet(
- in_dim=self.x_dim + self.h_dim_shared, # Concatenate input and shared output
- num_layers=self.num_layers_CS,
- h_dim=self.h_dim_CS,
- activation=self.active_fn,
- dropout_rate=self.dropout_rate
- ) for _ in range(self.num_Event)
- ])
-
+ self.cs_nets = nn.ModuleList(
+ [
+ FCNet(
+ in_dim=self.x_dim
+ + self.h_dim_shared, # Concatenate input and shared output
+ num_layers=self.num_layers_CS,
+ h_dim=self.h_dim_CS,
+ activation=self.active_fn,
+ dropout_rate=self.dropout_rate,
+ )
+ for _ in range(self.num_Event)
+ ]
+ )
+
# Output layer
- self.output_layer = nn.Linear(self.num_Event * self.h_dim_CS, self.num_Event * self.num_Category)
-
- def forward(self, x):
+ self.output_layer = nn.Linear(
+ self.num_Event * self.h_dim_CS, self.num_Event * self.num_Category
+ )
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
+ """Forward pass through the network.
+
+ Parameters
+ ----------
+ x: Tensor of shape (batch_size, num_Event, num_Category)
+
+ Returns
+ -------
+ risk_scores: List of (batch_size, 1) Tensors
+ feature_outputs: None
+ """
# Shared network
shared_out = self.shared_net(x)
-
+
# Concatenate input with shared output
h = torch.cat([x, shared_out], dim=1)
-
+
# Cause-specific networks
cs_outputs = []
for cs_net in self.cs_nets:
cs_out = cs_net(h)
cs_outputs.append(cs_out)
-
+
# Stack outputs
- stacked_out = torch.stack(cs_outputs, dim=1) # [batch_size, num_Event, h_dim_CS]
- reshaped_out = stacked_out.view(-1, self.num_Event * self.h_dim_CS) # [batch_size, num_Event * h_dim_CS]
-
+ stacked_out = torch.stack(
+ cs_outputs, dim=1
+ ) # [batch_size, num_Event, h_dim_CS]
+ reshaped_out = stacked_out.view(
+ -1, self.num_Event * self.h_dim_CS
+ ) # [batch_size, num_Event * h_dim_CS]
+
# Final output layer
- logits = self.output_layer(F.dropout(reshaped_out, self.dropout_rate, self.training))
+ logits = self.output_layer(
+ F.dropout(reshaped_out, self.dropout_rate, self.training)
+ )
out = F.softmax(logits.view(-1, self.num_Event * self.num_Category), dim=1)
-
+
# Reshape to [batch_size, num_Event, num_Category]
out = out.view(-1, self.num_Event, self.num_Category)
-
- # For compatibility with the training script, return both raw risks and shape functions
+
+ # For compatibility with the training script, return both
+ # raw risks and shape functions
# In this model, we don't have separate shape functions, so just return None
return out, None
-
- def log_likelihood_loss(self, out, t, k, mask1, mask2):
- """Log-likelihood loss (including log-likelihood of censored subjects)"""
- batch_size = out.size(0)
-
+
+ def _log_likelihood_loss(
+ self,
+ out: torch.Tensor,
+ t: Optional[torch.Tensor | np.ndarray],
+ k: Optional[torch.Tensor | np.ndarray],
+ mask1: torch.Tensor,
+ mask2: torch.Tensor,
+ ) -> torch.Tensor:
+ """Log-likelihood loss (including log-likelihood of censored subjects).
+
+ Parameters
+ ----------
+ out: Torch.tensor
+ t: Torch.tensor or numpy array
+ k: Torch.tensor or numpy array
+ mask1: Torch.tensor
+ mask2: Torch.tensor
+
+ Returns
+ -------
+ loss: Torch.tensor
+ """
# Convert to PyTorch tensors if necessary
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=out.device)
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, device=out.device)
-
+
# Indicator for uncensored subjects
- I_1 = (k > 0).float().view(-1, 1)
-
+ i_1 = (k > 0).float().view(-1, 1)
+
# For uncensored: log P(T=t, K=k|x)
tmp1 = torch.sum(torch.sum(mask1 * out, dim=2), dim=1, keepdim=True)
- tmp1 = I_1 * torch.log(tmp1 + 1e-8)
-
+ tmp1 = i_1 * torch.log(tmp1 + 1e-8)
+
# For censored: log ∑ P(T>t|x)
- tmp2 = torch.sum(torch.sum(mask2.unsqueeze(1) * out, dim=2), dim=1, keepdim=True)
- tmp2 = (1.0 - I_1) * torch.log(tmp2 + 1e-8)
-
+ tmp2 = torch.sum(
+ torch.sum(mask2.unsqueeze(1) * out, dim=2), dim=1, keepdim=True
+ )
+ tmp2 = (1.0 - i_1) * torch.log(tmp2 + 1e-8)
+
return -torch.mean(tmp1 + tmp2)
-
- def ranking_loss(self, out, t, k, mask2):
- """Ranking loss (calculated only for acceptable pairs)"""
- batch_size = out.size(0)
+
+ def _ranking_loss(
+ self,
+ out: torch.Tensor,
+ t: Optional[torch.Tensor | np.ndarray],
+ k: Optional[torch.Tensor | np.ndarray],
+ mask2: torch.Tensor,
+ ) -> torch.Tensor:
+ """Ranking loss (calculated only for acceptable pairs).
+
+ Parameters
+ ----------
+ out: Torch.tensor
+ t: Torch.tensor or numpy array
+ k: Torch.tensor or numpy array
+ mask2: Torch.tensor
+
+ Returns
+ -------
+ loss: Torch.tensor
+ """
sigma1 = 0.1
eta = []
-
+
# Convert to PyTorch tensors if necessary
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=out.device)
if not isinstance(t, torch.Tensor):
t = torch.tensor(t, device=out.device)
-
+
one_vector = torch.ones_like(t)
-
+
for e in range(self.num_Event):
- # Indicator for current event type
- I_2 = (k == e+1).float()
- I_2_diag = torch.diag(I_2.squeeze())
-
+ i_2 = (k == e + 1).float()
+ i_2_diag = torch.diag(i_2.squeeze())
+
# Extract event-specific probabilities
tmp_e = out[:, e, :] # [batch_size, num_Category]
-
+
# Calculate risk scores
- R = torch.matmul(tmp_e, mask2.transpose(0, 1)) # [batch_size, batch_size]
- diag_R = torch.diag(R).unsqueeze(1) # [batch_size, 1]
- R = torch.matmul(one_vector, diag_R.transpose(0, 1)) - R # [batch_size, batch_size]
- R = R.transpose(0, 1) # Now R_ij = r_i(T_i) - r_j(T_i)
-
+ r = torch.matmul(tmp_e, mask2.transpose(0, 1)) # [batch_size, batch_size]
+ diag_r = torch.diag(r).unsqueeze(1) # [batch_size, 1]
+ r = (
+ torch.matmul(one_vector, diag_r.transpose(0, 1)) - r
+ ) # [batch_size, batch_size]
+ r = r.transpose(0, 1) # Now R_ij = r_i(T_i) - r_j(T_i)
+
# Time comparison matrix
- T = F.relu(torch.sign(torch.matmul(one_vector, t.transpose(0, 1)) -
- torch.matmul(t, one_vector.transpose(0, 1))))
-
+ time = F.relu(
+ torch.sign(
+ torch.matmul(one_vector, t.transpose(0, 1))
+ - torch.matmul(t, one_vector.transpose(0, 1))
+ )
+ )
+
# Filter by event occurrence
- T = torch.matmul(I_2_diag, T)
-
+ time = torch.matmul(i_2_diag, time)
+
# Calculate ranking loss for current event
- tmp_eta = torch.mean(T * torch.exp(-R / sigma1), dim=1, keepdim=True)
+ tmp_eta = torch.mean(time * torch.exp(-r / sigma1), dim=1, keepdim=True)
eta.append(tmp_eta)
-
+
eta = torch.stack(eta, dim=1) # [batch_size, num_Event]
- eta = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)
-
- return torch.sum(eta)
-
- def calibration_loss(self, out, t, k, mask2):
- """Calibration loss"""
- batch_size = out.size(0)
+ eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)
+
+ return torch.sum(eta_mean)
+
+ def _calibration_loss(
+ self,
+ out: torch.Tensor,
+ t: Optional[torch.Tensor | np.ndarray],
+ k: Optional[torch.Tensor | np.ndarray],
+ mask2: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calibration loss.
+
+ Parameters
+ ----------
+ out: Torch.tensor
+ t: Torch.tensor or numpy array
+ k: Torch.tensor or numpy array
+ mask2: Torch.tensor
+
+ Returns
+ -------
+ loss: Torch.tensor
+ """
eta = []
-
+
# Convert to PyTorch tensors if necessary
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=out.device)
-
+
for e in range(self.num_Event):
# Indicator for current event type
- I_2 = (k == e+1).float()
-
+ i_2 = (k == e + 1).float()
+
# Extract event-specific probabilities
tmp_e = out[:, e, :] # [batch_size, num_Category]
-
+
# Calculate calibration loss
r = torch.sum(tmp_e * mask2, dim=1)
- tmp_eta = torch.mean((r - I_2) ** 2, dim=0, keepdim=True)
+ tmp_eta = torch.mean((r - i_2) ** 2, dim=0, keepdim=True)
eta.append(tmp_eta)
-
+
eta = torch.stack(eta, dim=1) # [1, num_Event]
- eta = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)
-
- return torch.sum(eta)
-
- def compute_loss(self, out, t, k, mask1, mask2, alpha=1.0, beta=1.0, gamma=1.0):
- """Compute total loss"""
- loss1 = self.log_likelihood_loss(out, t, k, mask1, mask2)
- loss2 = self.ranking_loss(out, t, k, mask2)
- loss3 = self.calibration_loss(out, t, k, mask2)
-
+ eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)
+
+ return torch.sum(eta_mean)
+
+ def compute_loss(
+ self,
+ out: torch.Tensor,
+ t: Optional[torch.Tensor | np.ndarray],
+ k: Optional[torch.Tensor | np.ndarray],
+ mask1: Optional[torch.Tensor | np.ndarray],
+ mask2: torch.Tensor,
+ alpha: float = 1.0,
+ beta: float = 1.0,
+ gamma: float = 1.0,
+ ) -> torch.Tensor:
+ """Compute total loss.
+
+ Parameters
+ ----------
+ out: Torch.tensor
+ t: Torch.tensor or numpy array
+ k: Torch.tensor or numpy array
+ mask1: Torch.tensor
+ mask2: Torch.tensor
+ alpha: float, weight for log-likelihood loss
+ beta: float, weight for ranking loss
+ gamma: float, weight for calibration loss
+
+ Returns
+ -------
+ total_loss: Torch.tensor
+ """
+ loss1 = self._log_likelihood_loss(out, t, k, mask1, mask2)
+ loss2 = self._ranking_loss(out, t, k, mask2)
+ loss3 = self._calibration_loss(out, t, k, mask2)
+
# L2 regularization is handled by optimizer (weight_decay)
return alpha * loss1 + beta * loss2 + gamma * loss3
-
- def predict(self, x):
- """Predict risk scores for input x"""
+
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
+ """Predict risk scores for input x.
+
+ Parameters
+ ----------
+ x: Tensor of shape (batch_size, num_Event, num_Category)
+
+ Returns
+ -------
+ out: Tensor of shape (batch_size, num_Event, num_Category)
+ """
self.eval()
with torch.no_grad():
out, _ = self.forward(x)
- return out
\ No newline at end of file
+ return out
diff --git a/crisp_nam/utils/__init__.py b/crisp_nam/utils/__init__.py
index 08275dbe..42bd9dea 100644
--- a/crisp_nam/utils/__init__.py
+++ b/crisp_nam/utils/__init__.py
@@ -1,3 +1,5 @@
+"""Utility functions for crisp-nam package."""
+
from .loss import *
from .plotting import *
-from .risk_cif import *
\ No newline at end of file
+from .risk_cif import *
diff --git a/crisp_nam/utils/loss.py b/crisp_nam/utils/loss.py
index 9cae44fd..5055387e 100644
--- a/crisp_nam/utils/loss.py
+++ b/crisp_nam/utils/loss.py
@@ -1,143 +1,175 @@
+"""Loss functions for competing risks.
+
+This module implements weighted and un-weighted
+negative log-likelihood loss, L2 penalty loss functions.
+"""
+
import torch
-def weighted_negative_log_likelihood_loss(risk_scores, times, events,
- num_competing_risks, event_weights=None,
- sample_weights=None, eps=1e-8) -> float:
+
+def weighted_negative_log_likelihood_loss(
+ risk_scores,
+ times,
+ events,
+ num_competing_risks,
+ event_weights=None,
+ sample_weights=None,
+ eps=1e-8,
+) -> torch.Tensor:
"""
- Computes the weighted negative log-likelihood loss for competing risks Cox model.
-
- Args:
+ Compute the weighted negative log-likelihood loss for competing risks Cox model.
+
+ Parameters
+ ----------
risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
times: Event/censoring times (batch_size,)
events: Event indicators (0=censored, 1...K=event types) (batch_size,)
num_competing_risks: Number of competing risks
- event_weights: Tensor of weights for each competing risk type (size: num_competing_risks)
+ event_weights: Tensor of weights for each competing risk type
+ (size: num_competing_risks)
sample_weights: Tensor of weights for each sample (size: batch_size)
eps: Small constant for numerical stability
-
- Returns:
+
+ Returns
+ -------
Weighted negative log partial likelihood loss
"""
device = times.device
batch_size = times.shape[0]
-
+
# Initialize loss
loss = torch.tensor(0.0, device=device)
-
+
# Set default weights if not provided
if event_weights is None:
event_weights = torch.ones(num_competing_risks, device=device)
if sample_weights is None:
sample_weights = torch.ones(batch_size, device=device)
-
+
# Count number of events
n_events = (events > 0).sum().item()
if n_events == 0:
return loss
-
+
# Process each competing risk separately
for k in range(1, num_competing_risks + 1):
# Find samples with this event type
- event_mask = (events == k)
+ event_mask = events == k
n_events_k = event_mask.sum().item()
-
+
if n_events_k == 0:
continue
-
+
# Get risk scores for this competing risk
- risk_k = risk_scores[k-1].squeeze()
-
+ risk_k = risk_scores[k - 1].squeeze()
+
# Get weight for this event type
- event_weight = event_weights[k-1]
-
+ event_weight = event_weights[k - 1]
+
# For each event of type k
for i in range(batch_size):
if event_mask[i]:
# Find samples in risk set (samples with time >= event time)
- risk_set = (times >= times[i])
-
+ risk_set = times >= times[i]
+
# Calculate log sum of exp of risk scores in risk set
risk_set_scores = risk_k[risk_set]
log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)
-
+
# Subtract individual risk score from log sum and apply weights
individual_loss = log_risk_sum - risk_k[i]
- weighted_individual_loss = individual_loss * event_weight * sample_weights[i]
+ weighted_individual_loss = (
+ individual_loss * event_weight * sample_weights[i]
+ )
loss += weighted_individual_loss
-
+
# Return average loss
return loss / max(n_events, 1)
-def negative_log_likelihood_loss(risk_scores, times, events,
- num_competing_risks, eps=1e-8):
+
+def negative_log_likelihood_loss(
+ risk_scores: float,
+ times: torch.Tensor,
+ events: torch.Tensor,
+ num_competing_risks: int,
+ eps: float = 1e-8,
+) -> torch.Tensor:
"""
- Computes the negative log-likelihood loss for competing risks Cox model.
-
- Args:
+ Compute the negative log-likelihood loss for competing risks Cox model.
+
+ Parameters
+ ----------
risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
times: Event/censoring times (batch_size,)
events: Event indicators (0=censored, 1...K=event types) (batch_size,)
num_competing_risks: Number of competing risks
eps: Small constant for numerical stability
-
- Returns:
+
+ Returns
+ -------
Negative log partial likelihood loss
"""
device = times.device
batch_size = times.shape[0]
-
+
# Initialize loss
loss = torch.tensor(0.0, device=device)
-
+
# Count number of events
n_events = (events > 0).sum().item()
if n_events == 0:
return loss
-
+
# Process each competing risk separately
for k in range(1, num_competing_risks + 1):
# Find samples with this event type
- event_mask = (events == k)
+ event_mask = events == k
n_events_k = event_mask.sum().item()
-
+
if n_events_k == 0:
continue
-
+
# Get risk scores for this competing risk
- risk_k = risk_scores[k-1].squeeze()
-
+ risk_k = risk_scores[k - 1].squeeze()
+
# For each event of type k
for i in range(batch_size):
if event_mask[i]:
# Find samples in risk set (samples with time >= event time)
- risk_set = (times >= times[i])
-
+ risk_set = times >= times[i]
+
# Calculate log sum of exp of risk scores in risk set
risk_set_scores = risk_k[risk_set]
log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)
-
+
# Subtract individual risk score from log sum
loss += log_risk_sum - risk_k[i]
-
+
# Return average loss
return loss / max(n_events, 1)
-def compute_l2_penalty(model, include_bias=False) -> int:
+
+def compute_l2_penalty(
+ model: torch.nn.Module,
+ include_bias: bool = False
+ ) -> torch.Tensor:
"""
- Compute L2 regularization penalty on model parameters
-
- Args:
+ Compute L2 regularization penalty on model parameters.
+
+ Parameters
+ ----------
model: Neural network model
include_bias: Whether to include bias terms in regularization
-
- Returns:
+
+ Returns
+ -------
L2 penalty term
"""
l2_reg = 0.0
for name, param in model.named_parameters():
if param.requires_grad:
# Skip bias parameters if specified
- if not include_bias and 'bias' in name:
+ if not include_bias and "bias" in name:
continue
- l2_reg += torch.sum(param ** 2)
- return l2_reg
\ No newline at end of file
+ l2_reg += torch.sum(param**2)
+ return l2_reg
diff --git a/crisp_nam/utils/plotting.py b/crisp_nam/utils/plotting.py
index e92e1b44..184e7718 100644
--- a/crisp_nam/utils/plotting.py
+++ b/crisp_nam/utils/plotting.py
@@ -1,32 +1,61 @@
-import matplotlib.pyplot as plt
-from typing import Union, List
+"""Utility functions for plotting.
-import torch
+This module provides functions to visualize feature importance
+and shape functions for both crisp-nam and deephit models.
+"""
+
+from typing import List, Union
+
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
+import torch
-def plot_feature_importance(model: torch.nn.Module,
- x_data: Union[np.ndarray, torch.Tensor],
- feature_names = None,
- n_top:int = 5,
- n_bottom:int = 5,
- risk_idx:int = 1,
- figsize: tuple = (8, 6),
- output_file:str = None,
- color_positive:str = '#2196F3',
- color_negative:str = '#F44336') -> tuple:
- """
- Plot feature importance with both top positive and negative influences,
+
+def plot_feature_importance(
+ model: torch.nn.Module,
+ x_data: Union[np.ndarray, torch.Tensor],
+ feature_names=None,
+ n_top: int = 5,
+ n_bottom: int = 5,
+ risk_idx: int = 1,
+ figsize: tuple = (8, 6),
+ output_file: str = "",
+ color_positive: str = "#2196F3",
+ color_negative: str = "#F44336",
+) -> tuple:
+ """Plot feature importance with both top positive and negative influences,
handling both CPU and CUDA devices automatically.
+
+ Parameters
+ ----------
+ - model: A trained CoxNAM model (torch.nn.Module)
+ - x_data: Input data (numpy array or torch tensor) to compute contributions
+ - feature_names: Optional list of feature names (default: generic names)
+ - n_top: Number of top positive features to display
+ - n_bottom: Number of top negative features to display
+ - risk_idx: Index of the competing risk to analyze
+ - figsize: Size of the plot (width, height)
+ - output_file: Optional path to save the plot image (e.g., "feature_importance.png")
+ - color_positive: Color for positive contributions (default: blue)
+ - color_negative: Color for negative contributions (default: red)
+
+ Returns
+ -------
+ - fig: Matplotlib figure object
+ - ax: Matplotlib axes object
+ - top_pos: List of top positive feature names
+ - top_neg: List of top negative feature names
"""
+
# determine model device
device = next(model.parameters()).device
model.eval()
# prepare feature names
- num_features = model.num_features
+ num_features: torch.Tensor = model.num_features
if feature_names is None:
- feature_names = [f"Feature {i+1}" for i in range(num_features)]
+ feature_names = [f"Feature {i + 1}" for i in range(num_features)]
# convert x_data to tensor on the model device
if not isinstance(x_data, torch.Tensor):
@@ -45,77 +74,102 @@ def plot_feature_importance(model: torch.nn.Module,
continue
# forward through the feature net and projection
- rep = model.feature_nets[i](vals)
- proj = model.risk_projections[i][risk_idx0](rep)
+ rep : torch.nn.ModuleList = model.feature_nets[i](vals)
+ proj : torch.nn.ModuleList = model.risk_projections[i][risk_idx0](rep)
# mean contribution as a Python float
contrib = proj.mean().item()
feature_contribs[feature_names[i]] = contrib
# build a DataFrame for sorting
- df = pd.DataFrame({
- 'feature': list(feature_contribs.keys()),
- 'contribution': list(feature_contribs.values())
- })
- df['abs_contrib'] = df['contribution'].abs()
- df = df.sort_values('abs_contrib', ascending=False)
-
- pos = df[df['contribution'] > 0].head(n_top).sort_values('contribution')
- neg = df[df['contribution'] < 0].head(n_bottom).sort_values('contribution', ascending=False)
-
- top_pos = pos['feature'].tolist()
- top_neg = neg['feature'].tolist()
+ df = pd.DataFrame(
+ {
+ "feature": list(feature_contribs.keys()),
+ "contribution": list(feature_contribs.values()),
+ }
+ )
+ df["abs_contrib"] = df["contribution"].abs()
+ df = df.sort_values("abs_contrib", ascending=False)
+
+ pos = df[df["contribution"] > 0].head(n_top).sort_values("contribution")
+ neg = (
+ df[df["contribution"] < 0]
+ .head(n_bottom)
+ .sort_values("contribution", ascending=False)
+ )
+
+ top_pos = pos["feature"].tolist()
+ top_neg = neg["feature"].tolist()
# plotting
fig, ax = plt.subplots(figsize=figsize)
- ax.barh(pos['feature'], pos['contribution'], color=color_positive, alpha=0.8)
- ax.barh(neg['feature'], neg['contribution'], color=color_negative, alpha=0.8)
- ax.axvline(0, color='black', linestyle='-', alpha=0.3)
-
- ax.set_xlabel('Contribution to Risk Score')
- ax.set_title(f'Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}')
- ax.grid(axis='x', linestyle='--', alpha=0.5)
+ ax.barh(pos["feature"], pos["contribution"], color=color_positive, alpha=0.8)
+ ax.barh(neg["feature"], neg["contribution"], color=color_negative, alpha=0.8)
+ ax.axvline(0, color="black", linestyle="-", alpha=0.3)
+
+ ax.set_xlabel("Contribution to Risk Score")
+ ax.set_title(
+ f"Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}"
+ )
+ ax.grid(axis="x", linestyle="--", alpha=0.5)
plt.tight_layout()
if output_file:
- plt.savefig(output_file, bbox_inches='tight', dpi=300)
+ plt.savefig(output_file, bbox_inches="tight", dpi=300)
return fig, ax, top_pos, top_neg
-def plot_coxnam_shape_functions(model:torch.nn.Module,
- X: Union[np.ndarray, torch.Tensor],
- risk_to_plot:int = 1,
- feature_names:List[str] = None,
- top_features:List[str] = None,
- ncols:int = 3,
- figsize:tuple = (12, 8),
- output_file:str = None) -> tuple:
- """
- Plot shape functions for each feature in a CoxNAM model,
+def plot_coxnam_shape_functions(
+ model: torch.nn.Module,
+ X: Union[np.ndarray, torch.Tensor],
+ risk_to_plot: int = 1,
+ feature_names: np.ndarray | None = None,
+ top_features: List[str] | None = None,
+ ncols: int = 3,
+ figsize: tuple = (12, 8),
+ output_file: str = "",
+) -> tuple:
+ """Plot shape functions for each feature in a CoxNAM model,
automatically handling CPU vs CUDA inputs.
+
+ Parameters
+ ----------
+ - model: A trained CoxNAM model (torch.nn.Module)
+ - X: Input data (numpy array or torch tensor) to compute shape functions
+ - risk_to_plot: Index of the competing risk to visualize
+ - feature_names: Optional list of feature names (default: generic names)
+ - top_features: Optional list of feature names to plot features)
+ - ncols: Number of columns in the subplot grid
+ - figsize: Size of the entire figure (width, height)
+ - output_file: Optional path to save the plot image (e.g., "shape_functions.png")
+
+ Returns
+ -------
+ - fig: Matplotlib figure object
+ - axes: List of Matplotlib axes objects for each plotted feature
"""
device = next(model.parameters()).device
model.eval()
risk_idx = risk_to_plot - 1
# ensure X is a numpy array
- if isinstance(X, torch.Tensor):
- X_np = X.cpu().numpy()
- else:
- X_np = np.array(X, dtype=float)
+ X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else np.array(X, dtype=float)
# derive feature list
num_features = model.num_features
+ print(f'{plot_coxnam_shape_functions.__name__}: top_features={top_features}')
if feature_names is None:
- feature_names = [f"Feature {i+1}" for i in range(num_features)]
- if top_features:
+ feature_names = [f"Feature {i + 1}" for i in range(num_features)]
+ if top_features is not None :
# map names back to indices
idx_map = {name: i for i, name in enumerate(feature_names)}
- selected = [(idx_map.get(name, None), name) for name in top_features]
+ selected = [(idx_map.get(name), name) for name in top_features]
selected = [(i, name) for i, name in selected if i is not None]
else:
selected = list(zip(range(num_features), feature_names))
+ print(f'{plot_coxnam_shape_functions.__name__}: num_features={num_features}, feature_names={feature_names}, top_features={top_features}')
+ print(f'{plot_coxnam_shape_functions.__name__}: selected={selected}')
n_selected = len(selected)
nrows = int(np.ceil(n_selected / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(figsize))
@@ -125,7 +179,7 @@ def plot_coxnam_shape_functions(model:torch.nn.Module,
for ax, (f_idx, fname) in zip(axes, selected):
vals = X_np[:, f_idx]
if vals.size == 0:
- ax.text(0.5, 0.5, "no data", ha='center', va='center')
+ ax.text(0.5, 0.5, "no data", ha="center", va="center")
continue
# choose evaluation points
@@ -138,26 +192,26 @@ def plot_coxnam_shape_functions(model:torch.nn.Module,
t_pts = torch.tensor(pts, dtype=torch.float32, device=device).unsqueeze(1)
# compute shape values
- rep = model.feature_nets[f_idx](t_pts)
- proj = model.risk_projections[f_idx][risk_idx](rep)
+ rep : torch.nn.ModuleList = model.feature_nets[f_idx](t_pts)
+ proj : torch.nn.ModuleList = model.risk_projections[f_idx][risk_idx](rep)
shp = proj.squeeze(-1).cpu().numpy()
# plot
ax.plot(pts, shp, linewidth=2)
- ax.axhline(0, linestyle='--', alpha=0.5)
+ ax.axhline(0, linestyle="--", alpha=0.5)
ax.set_title(fname)
- ax.set_xlabel('Value')
- ax.set_ylabel('Contribution')
+ ax.set_xlabel("Value")
+ ax.set_ylabel("Contribution")
# rug plot
- ax.plot(vals, np.zeros_like(vals)-0.1, '|', alpha=0.3)
+ ax.plot(vals, np.zeros_like(vals) - 0.1, "|", alpha=0.3)
# turn off any extra axes
for ax in axes[n_selected:]:
- ax.axis('off')
+ ax.axis("off")
- fig.suptitle(f'Shape Functions for Risk {risk_to_plot}', fontsize=14)
- plt.tight_layout(rect=[0,0,1,0.96])
+ fig.suptitle(f"Shape Functions for Risk {risk_to_plot}", fontsize=14)
+ plt.tight_layout(rect=(0, 0, 1, 0.96))
if output_file:
- plt.savefig(output_file, dpi=300, bbox_inches='tight')
+ plt.savefig(output_file, dpi=300, bbox_inches="tight")
return fig, axes[:n_selected]
diff --git a/crisp_nam/utils/risk_cif.py b/crisp_nam/utils/risk_cif.py
index 9e8cd1b2..4c043890 100644
--- a/crisp_nam/utils/risk_cif.py
+++ b/crisp_nam/utils/risk_cif.py
@@ -1,33 +1,40 @@
-from typing import List, Any
+"""Risk functions for evaluation.
+
+This module provides functions to compute cumulative incidence functions (CIFs)
+and risk scores for competing risk models.
+"""
+
+from typing import Any, List
-import torch
import numpy as np
+import torch
-def compute_baseline_cif(times:np.ndarray,
- events:np.ndarray,
- eval_times:List[Any],
- event_type:np.ndarray) -> np.ndarray:
+
+def compute_baseline_cif(
+ times: np.ndarray, events: np.ndarray, eval_times: List[Any], event_type: np.ndarray
+) -> np.ndarray:
"""
- Compute baseline cumulative incidence function for a specific event type
-
+ Compute baseline cumulative incidence function for a specific event type.
+
Args:
times: Numpy array of event times
events: Numpy array of event indicators (0=censored, 1...K=event types)
eval_times: Times at which to evaluate the CIF
event_type: Event type to compute CIF for (1...K)
-
- Returns:
+
+ Returns
+ -------
Numpy array of baseline CIF values at eval_times
"""
# Sort times and corresponding events
sort_idx = np.argsort(times)
sorted_times = times[sort_idx]
sorted_events = events[sort_idx]
-
+
# Initialize cumulative hazard
n_samples = len(times)
baseline_cif = np.zeros(len(eval_times))
-
+
# For each evaluation time
for i, t in enumerate(eval_times):
cif_t = 0.0
@@ -37,78 +44,31 @@ def compute_baseline_cif(times:np.ndarray,
# Simple Aalen-Johansen estimator
cif_t = event_count / n_samples
baseline_cif[i] = cif_t
-
- return baseline_cif
-def predict_absolute_risk(model: torch.nn.Module,
- x:np.ndarray,
- baseline_cifs:np.ndarray,
- eval_times: np.ndarray,
- device:str="cpu") -> np.ndarray:
- """
- Predict absolute risk (cumulative incidence) at specified times
-
- Args:
- model: Trained CoxNAM model
- x: Feature matrix
- baseline_cifs: Dictionary of baseline CIFs for each event type
- eval_times: Times at which to evaluate the CIF
- device: Device to run computations on
-
- Returns:
- Numpy array of predicted absolute risks with shape (n_samples, n_risks, n_times)
- """
- model.eval()
-
- # Convert to tensor if needed
- if not isinstance(x, torch.Tensor):
- x = torch.FloatTensor(x).to(device)
-
- with torch.no_grad():
- # Get risk scores
- risk_scores, _ = model(x)
-
- # Convert to hazard ratios
- hazard_ratios = [torch.exp(score).cpu().numpy() for score in risk_scores]
-
- # Initialize prediction array
- n_samples = x.shape[0]
- n_risks = model.num_competing_risks
- n_times = len(eval_times)
- abs_risks = np.zeros((n_samples, n_risks, n_times))
-
- # Compute absolute risk for each sample, risk, and time
- for k in range(n_risks):
- # Skip if baseline CIF is not available
- if k+1 not in baseline_cifs:
- continue
-
- baseline_cif = baseline_cifs[k+1]
-
- for i in range(n_samples):
- for j, t in enumerate(eval_times):
- # Simple Fine-Gray model for cumulative incidence
- abs_risks[i, k, j] = 1 - np.exp(-baseline_cif[j] * hazard_ratios[k][i])
-
- return abs_risks
+ return baseline_cif
-def predict_cif(model:torch.nn.Module,
- x:np.ndarray,
- baseline_cif:np.ndarray,
- times:np.ndarray,
- event_of_interest:int) -> np.ndarray:
+def predict_cif(
+ model: torch.nn.Module,
+ x: np.ndarray,
+ baseline_cif: np.ndarray,
+ times: np.ndarray,
+ event_of_interest: int,
+) -> np.ndarray:
"""
Predict cumulative incidence function for a specific competing risk.
- Args:
+ Parameters
+ ----------
model: Trained model.
x: Input tensor of shape (n_samples, n_features).
- baseline_cif: Array of shape (len(times),) — estimated CIF for baseline (e.g. from compute_baseline_cif).
+ baseline_cif: Array of shape (len(times),) —
+ estimated CIF for baseline (e.g. from compute_baseline_cif).
times: Time points at which CIF is evaluated.
event_type: Integer, 0-based index of event of interest.
- Returns:
+ Returns
+ -------
cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample.
"""
model.eval()
@@ -117,56 +77,63 @@ def predict_cif(model:torch.nn.Module,
f_j_x = logits[event_of_interest].squeeze(1).cpu().numpy() # (n_samples,)
baseline_cif = np.asarray(baseline_cif).reshape(1, -1) # (1, T)
- risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1)
-
- # Fine-Gray style CIF prediction under PH assumption
- cif_pred = 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T)
-
- return cif_pred
-
-def predict_risk(model:np.ndarray,
- x_input:np.ndarray,
- device:str = 'cpu'):
+ risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1)
+
+ # Return Fine-Gray style CIF prediction under PH assumption
+ return 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T)
+
+
+def predict_risk(
+ model: torch.nn.Module, x_input: np.ndarray, device: str = "cpu"
+) -> np.ndarray:
"""
Predicts relative risk scores for each competing risk.
Args:
model : Trained model.
- x_input (np.ndarray or torch.Tensor): Input features of shape (n_samples, n_features).
+ x_input (np.ndarray or torch.Tensor): Input features of
+ shape (n_samples, n_features).
device (str): Device to run the computation on.
- Returns:
+ Returns
+ -------
np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores.
"""
model.eval()
-
+
if isinstance(x_input, np.ndarray):
x_tensor = torch.from_numpy(x_input).float().to(device)
else:
x_tensor = x_input.to(device).float()
-
+
with torch.no_grad():
risk_outputs, _ = model(x_tensor) # List of [batch_size, 1] tensors
risks = torch.cat(risk_outputs, dim=1) # Shape: [batch_size, num_risks]
- return risks.cpu().numpy()
+ return risks.cpu().numpy()
+
-def predict_absolute_risk(model:torch.Tensor,
- x_input:np.ndarray,
- baseline_cifs:List[Any],
- times:List[Any],
- device:str = 'cpu') -> np.ndarray:
+def predict_absolute_risk(
+ model: torch.nn.Module,
+ x_input: np.ndarray,
+ baseline_cifs: List[Any],
+ times: List[Any],
+ device: str = "cpu",
+) -> np.ndarray:
"""
Predict absolute risk (CIF) for each competing event by given time points.
- Args:
+ Parameters
+ ----------
model: Trained model.
x_input (np.ndarray or Tensor): Input features, shape (n_samples, n_features).
- baseline_cifs (dict): Mapping of event index to baseline CIF array of shape (n_times,).
+ baseline_cifs (dict): Mapping of event index to baseline CIF
+ array of shape (n_times,).
times (np.ndarray): Time grid used for baseline_cifs.
device: CPU or CUDA.
- Returns:
+ Returns
+ -------
np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs.
"""
rel_risks = predict_risk(model, x_input, device) # shape (n_samples, num_events)
@@ -179,5 +146,5 @@ def predict_absolute_risk(model:torch.Tensor,
base_cif = np.clip(baseline_cifs[k], 1e-10, 0.9999) # avoid edge cases
for i in range(n_samples):
abs_risks[i, k, :] = 1 - np.power(1 - base_cif, np.exp(rel_risks[i, k]))
-
- return abs_risks
\ No newline at end of file
+
+ return abs_risks
diff --git a/data_utils/load_datasets.py b/data_utils/load_datasets.py
index 2263fa30..6e4837fa 100644
--- a/data_utils/load_datasets.py
+++ b/data_utils/load_datasets.py
@@ -1,16 +1,18 @@
-from typing import Tuple, List
+from typing import List, Tuple
-import pandas as pd
import numpy as np
+import pandas as pd
from sklearn.impute import SimpleImputer
+
def load_framingham(sequential=False):
"""
Load and preprocess the Framingham dataset for competing risks analysis,
with imputation but no scaling. Feature normalization must be done externally
after splitting to avoid data leakage.
- Returns:
+ Returns
+ -------
x (np.ndarray): Feature matrix with one-hot categorical + raw continuous features.
t (np.ndarray): Time-to-event (with +1 offset).
e (np.ndarray): Event indicator (0=censored, 1=CVD, 2=death).
@@ -18,55 +20,62 @@ def load_framingham(sequential=False):
n_continuous (int): Number of continuous features at the end of x.
feature_ranges (None): Placeholder for backward compatibility.
"""
-
file_path = "datasets/framingham.csv"
data = pd.read_csv(file_path)
if not sequential:
-
data = data.groupby("RANDID").first()
-
cat_cols = [
- 'SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS',
- 'PREVCHD', 'PREVAP', 'PREVMI', 'PREVSTRK', 'PREVHYP', 'educ'
+ "SEX",
+ "CURSMOKE",
+ "DIABETES",
+ "BPMEDS",
+ "PREVCHD",
+ "PREVAP",
+ "PREVMI",
+ "PREVSTRK",
+ "PREVHYP",
+ "educ",
]
-
+
# 'HDLC', 'LDLC' - removed to replicate nfg experiments.
cont_cols = [
- 'TOTCHOL', 'AGE',
- 'SYSBP', 'DIABP', 'CIGPDAY', 'BMI',
- 'HEARTRTE', 'GLUCOSE'
+ "TOTCHOL",
+ "AGE",
+ "SYSBP",
+ "DIABP",
+ "CIGPDAY",
+ "BMI",
+ "HEARTRTE",
+ "GLUCOSE",
]
-
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat = pd.DataFrame(
cat_imputer.fit_transform(data[cat_cols]),
columns=cat_cols,
- index=data.index
+ index=data.index,
)
x_cat = pd.get_dummies(x_cat, drop_first=True)
-
- cont_imputer = SimpleImputer(strategy='mean')
+ cont_imputer = SimpleImputer(strategy="mean")
x_cont = cont_imputer.fit_transform(data[cont_cols])
-
x = np.hstack([x_cat.values, x_cont])
feature_names = np.concatenate([x_cat.columns.values, cont_cols])
n_continuous = len(cont_cols)
event = np.zeros(len(data), dtype=int)
- time = (data['TIMEDTH'] - data['TIME']).values
+ time = (data["TIMEDTH"] - data["TIME"]).values
# Primary CVD event (risk=1)
- cvd_mask = data['CVD'] == 1
+ cvd_mask = data["CVD"] == 1
event[cvd_mask] = 1
- time_cvd = (data['TIMECVD'] - data['TIME']).values
+ time_cvd = (data["TIMECVD"] - data["TIME"]).values
time[cvd_mask] = time_cvd[cvd_mask]
# Competing death event (risk=2), only if CVD did not occur
- death_mask = (data['DEATH'] == 1) & ~cvd_mask
+ death_mask = (data["DEATH"] == 1) & ~cvd_mask
event[death_mask] = 2
# Filter out invalid or zero times
@@ -79,8 +88,8 @@ def load_framingham(sequential=False):
assert not np.isnan(x).any(), "NaNs found in feature matrix"
return x, t, e, feature_names, n_continuous, None
- else:
- raise NotImplementedError("Sequential mode not yet implemented.")
+ raise NotImplementedError("Sequential mode not yet implemented.")
+
def load_pbc2_dataset():
"""
@@ -90,7 +99,8 @@ def load_pbc2_dataset():
with missing values imputed. The function also constructs the outcome variable
for competing risks analysis and returns time-to-event data.
- Returns:
+ Returns
+ -------
tuple: A tuple containing the following elements:
- x (numpy.ndarray): Combined feature matrix with categorical features
one-hot encoded and continuous features imputed.
@@ -103,55 +113,56 @@ def load_pbc2_dataset():
- n_continuous (int): Number of continuous features.
- feature_ranges (list of tuple): List of (min, max) ranges for each feature.
"""
-
file_path = "datasets/pbc2.csv"
data = pd.read_csv(file_path)
- data = data.drop(columns=['id', 'sno.', 'year', 'status2'], axis=1)
-
+ data = data.drop(columns=["id", "sno.", "year", "status2"], axis=1)
event_type = np.where(
- data['status'] == 'dead', 1,
- np.where(data['status'] == 'transplanted', 2, 0)
+ data["status"] == "dead", 1, np.where(data["status"] == "transplanted", 2, 0)
)
-
cont_cols = [
- 'age', 'serBilir', 'serChol', 'albumin', 'alkaline',
- 'SGOT', 'platelets', 'prothrombin', 'histologic'
+ "age",
+ "serBilir",
+ "serChol",
+ "albumin",
+ "alkaline",
+ "SGOT",
+ "platelets",
+ "prothrombin",
+ "histologic",
]
- x_cont = data[cont_cols].replace('NA', np.nan).astype(float)
+ x_cont = data[cont_cols].replace("NA", np.nan).astype(float)
-
- mean_imputer = SimpleImputer(strategy='mean')
+ mean_imputer = SimpleImputer(strategy="mean")
x_cont_imputed = mean_imputer.fit_transform(x_cont)
- cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0),
- np.nanmax(x_cont_imputed, axis=0)))
+ cont_feature_ranges = list(
+ zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0))
+ )
-
- cat_cols = ['sex', 'drug', 'ascites', 'hepatomegaly', 'spiders', 'edema']
- x_cat = data[cat_cols].fillna('missing')
+ cat_cols = ["sex", "drug", "ascites", "hepatomegaly", "spiders", "edema"]
+ x_cat = data[cat_cols].fillna("missing")
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat_imputed = cat_imputer.fit_transform(x_cat)
x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols)
x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True)
-
- x_cat_encoded = x_cat_encoded.loc[:, ~x_cat_encoded.columns.str.contains('_missing')]
+ x_cat_encoded = x_cat_encoded.loc[
+ :, ~x_cat_encoded.columns.str.contains("_missing")
+ ]
cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1]
-
x = np.hstack([x_cat_encoded.values, x_cont_imputed])
feature_names = np.concatenate([x_cat_encoded.columns, cont_cols])
n_continuous = len(cont_cols)
feature_ranges = cat_feature_ranges + cont_feature_ranges
-
- t = data['years'].astype(float).values * 365.25
+ t = data["years"].astype(float).values * 365.25
valid = ~np.isnan(t)
return (
@@ -160,17 +171,20 @@ def load_pbc2_dataset():
event_type[valid],
feature_names,
n_continuous,
- feature_ranges
+ feature_ranges,
)
+
def load_support_dataset():
"""
Load and preprocess the SUPPORT dataset.
This function reads the SUPPORT dataset from a CSV file, imputes missing values,
encodes categorical features, and constructs the outcome variable for survival analysis.
- It returns the processed features, time-to-event data, event types, feature names,
+ It returns the processed features, time-to-event data, event types, feature names,
the number of continuous features, and feature ranges.
- Returns:
+
+ Returns
+ -------
tuple: A tuple containing:
- x (numpy.ndarray): Combined array of processed categorical and continuous features.
- t (numpy.ndarray): Time-to-event data with a +1 offset.
@@ -178,7 +192,9 @@ def load_support_dataset():
- feature_names (numpy.ndarray): Array of feature names.
- n_continuous (int): Number of continuous features.
- feature_ranges (list): List of tuples representing the range (min, max) for each feature.
- Notes:
+
+ Notes
+ -----
- The dataset file "support2.csv" must be located in the same directory as this script.
- Continuous features are imputed using the median strategy if missing values are present.
- Categorical features are imputed using the most frequent strategy if missing values are present.
@@ -186,77 +202,101 @@ def load_support_dataset():
- The outcome variable is constructed using the 'ca', 'dzgroup', and 'death' columns.
- Time-to-event data ('d.time') is offset by +1 to avoid zero follow-up times.
"""
-
file_path = "datasets/support2.csv"
data = pd.read_csv(file_path)
-
- is_cancer = data['ca'].astype(str).str.lower().str.contains("meta") | \
- data['dzgroup'].astype(str).str.lower().str.contains("cancer")
- event_type = np.where(data['death'] == 1, np.where(is_cancer, 1, 2), 0)
+
+ is_cancer = data["ca"].astype(str).str.lower().str.contains("meta") | data[
+ "dzgroup"
+ ].astype(str).str.lower().str.contains("cancer")
+ event_type = np.where(data["death"] == 1, np.where(is_cancer, 1, 2), 0)
cont_cols = [
- 'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', 'pafi', 'alb',
- 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', 'urine',
- 'scoma', 'aps', 'sps', 'adls', 'adlsc', 'charges', 'totcst', 'totmcst', 'avtisst'
+ "age",
+ "num.co",
+ "meanbp",
+ "wblc",
+ "hrt",
+ "resp",
+ "temp",
+ "pafi",
+ "alb",
+ "bili",
+ "crea",
+ "sod",
+ "ph",
+ "glucose",
+ "bun",
+ "urine",
+ "scoma",
+ "aps",
+ "sps",
+ "adls",
+ "adlsc",
+ "charges",
+ "totcst",
+ "totmcst",
+ "avtisst",
]
x_cont = data[cont_cols]
-
+
if x_cont.isnull().values.any():
- simp_imputer = SimpleImputer(strategy='median')
+ simp_imputer = SimpleImputer(strategy="median")
x_cont_imputed = simp_imputer.fit_transform(x_cont)
else:
x_cont_imputed = x_cont.values
- cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0),
- np.nanmax(x_cont_imputed, axis=0)))
+ cont_feature_ranges = list(
+ zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0))
+ )
# -- Categorical features --
# Remove leakage fields 'ca', 'dzgroup', 'dzclass'
- cat_cols = ['sex', 'income', 'race', 'dnr', 'dementia', 'diabetes']
+ cat_cols = ["sex", "income", "race", "dnr", "dementia", "diabetes"]
x_cat = data[cat_cols]
if x_cat.isnull().values.any():
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat_imputed = cat_imputer.fit_transform(x_cat)
else:
x_cat_imputed = x_cat.values
x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols)
-
+
x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True)
cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1]
-
x = np.hstack([x_cat_encoded.values, x_cont_imputed])
feature_names = np.concatenate([x_cat_encoded.columns, cont_cols])
n_continuous = len(cont_cols)
feature_ranges = cat_feature_ranges + cont_feature_ranges
-
- t = data['d.time'].values
+ t = data["d.time"].values
valid = ~np.isnan(t)
-
+
print("Completed imputation of missing values.")
return (
x[valid],
- t[valid] + 1, # Add +1 offset to follow-up times
+ t[valid] + 1, # Add +1 offset to follow-up times
event_type[valid],
feature_names,
n_continuous,
- feature_ranges
+ feature_ranges,
)
-def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]]:
+def load_synthetic_dataset() -> Tuple[
+ np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]
+]:
"""
Loads a synthetic competing risks dataset from a CSV file.
-
+
The CSV is expected to have a header with the following columns:
- time: observed time
- label: event indicator (0 for censored; >0 for event types)
- true_time: (optional) true time (unused here)
- true_label: (optional) true event label (unused here)
- feature1, feature2, ..., featureN: feature values
-
- Returns:
+
+ Returns
+ -------
X (np.ndarray): Feature matrix of shape (n_samples, n_features).
T_obs (np.ndarray): Observed times of shape (n_samples,).
e (np.ndarray): Event indicators of shape (n_samples,).
@@ -264,22 +304,20 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s
n_continuous (int): Total number of continuous features.
feature_ranges (List[tuple]): List of (min, max) tuples for each feature.
"""
-
file_path = "datasets/synthetic_comprisk.csv"
df = pd.read_csv(file_path)
-
- T_obs = df["time"].values .astype(np.float32)
+ T_obs = df["time"].values.astype(np.float32)
e = df["label"].values.astype(np.float32)
-
-
+
feature_columns = [col for col in df.columns if col.startswith("feature")]
- X = df[feature_columns].values .astype(np.float32)
-
-
+ X = df[feature_columns].values.astype(np.float32)
+
feature_names = feature_columns
n_continuous = X.shape[1]
-
- feature_ranges = [(float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)]
-
- return X, T_obs, e, feature_names, n_continuous, feature_ranges
\ No newline at end of file
+
+ feature_ranges = [
+ (float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)
+ ]
+
+ return X, T_obs, e, feature_names, n_continuous, feature_ranges
diff --git a/data_utils/survival_datasets.py b/data_utils/survival_datasets.py
index d2eab1b9..41d8c8f9 100644
--- a/data_utils/survival_datasets.py
+++ b/data_utils/survival_datasets.py
@@ -1,6 +1,7 @@
import torch
from torch.utils.data import Dataset
+
class SurvivalDataset(Dataset):
def __init__(self, x, t, e):
self.x = torch.tensor(x, dtype=torch.float32)
@@ -13,22 +14,28 @@ def __len__(self):
def __getitem__(self, idx):
return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking
+
class SurvivalDatasetDeepHit(Dataset):
"""Dataset class for DeepHit model"""
+
def __init__(self, x, t, e, num_Category):
self.x = torch.tensor(x, dtype=torch.float32)
self.t = torch.tensor(t, dtype=torch.float32).view(-1, 1)
self.e = torch.tensor(e, dtype=torch.float32).view(-1, 1)
-
+
# Create discretized time if needed
- self.t_discrete = torch.floor(self.t * num_Category / torch.max(self.t)).clamp(0, num_Category-1).long()
-
+ self.t_discrete = (
+ torch.floor(self.t * num_Category / torch.max(self.t))
+ .clamp(0, num_Category - 1)
+ .long()
+ )
+
# Create masks for loss calculation
self.num_Category = num_Category
self.num_Event = int(torch.max(self.e).item())
-
+
def __len__(self):
return len(self.x)
-
+
def __getitem__(self, idx):
- return self.x[idx], self.t[idx], self.e[idx], self.t_discrete[idx]
\ No newline at end of file
+ return self.x[idx], self.t[idx], self.e[idx], self.t_discrete[idx]
diff --git a/datasets.md b/datasets.md
new file mode 100644
index 00000000..47f95d3d
--- /dev/null
+++ b/datasets.md
@@ -0,0 +1,40 @@
+## Datasets overview
+The repository includes four well-established survival analysis datasets:
+
+1. **Framingham Heart Study**: Cardiovascular disease prediction with competing events (CVD vs. death)
+ - Features: Demographics, clinical measurements, lifestyle factors
+ - Events: Cardiovascular disease, death from other causes
+
+2. **PBC (Primary Biliary Cirrhosis)**: Liver disease progression study
+ - Features: Clinical laboratory values, demographic information
+ - Events: Death, liver transplantation
+
+3. **SUPPORT**: Study to understand prognoses and preferences for outcomes
+ - Features: Comprehensive clinical and demographic variables
+ - Events: Cancer death, non-cancer death
+
+4. **Synthetic Dataset**: Controlled simulation for method validation
+ - Features: Simulated clinical variables with known ground truth
+ - Events: Multiple competing risks with controllable hazard functions
+
+CSV files for all datasets are available in the repository within `crisp-nam/datasets` folder.
+
+## Data loading scripts
+The repository contains preprocessing scripts within `datasets` folder that handle missing values, feature encoding, and proper train/test splitting to prevent data leakage for each dataset.
+
+- **`framingham_dataset.py`**: Preprocess and load Framingham dataset.
+- **`pbc_dataset.py`**: Preprocess and load PBC dataset.
+- **`support_dataset.py`**: Preprocess and load Support2 dataset.
+- **`synthetic_dataset.py`**: Preprocess and load synthetically generated dataset.
+
+## Return format
+Each script returns the following values for use within training scripts:
+1. `x`: Feature matrix after preprocessing
+2. `t`: Array of time to event values
+3. `event_type`: Categorical data depicting event types
+4. `feature_names`: Array of feature names
+5. `n_continuous`: Number of continuous features
+6. `feature_ranges`: List of (min, max) ranges for each feature.
+```
+> [!NOTE]
+> If introducing a new dataset, the above mentioned return format is needed to run the training scripts.
\ No newline at end of file
diff --git a/datasets/SurvivalDataset.py b/datasets/SurvivalDataset.py
index c9fd9567..b6a2f01d 100644
--- a/datasets/SurvivalDataset.py
+++ b/datasets/SurvivalDataset.py
@@ -1,5 +1,6 @@
-from torch.utils.data import Dataset
import torch
+from torch.utils.data import Dataset
+
class SurvivalDataset(Dataset):
def __init__(self, x, t, e):
@@ -11,4 +12,4 @@ def __len__(self):
return len(self.x)
def __getitem__(self, idx):
- return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking
\ No newline at end of file
+ return self.x[idx], self.t[idx], self.e[idx], idx # idx for tracking
diff --git a/datasets/framingham.csv b/datasets/framingham.csv
index b33fe583..50a4f09d 100644
--- a/datasets/framingham.csv
+++ b/datasets/framingham.csv
@@ -11625,4 +11625,4 @@ RANDID,SEX,TOTCHOL,AGE,SYSBP,DIABP,CURSMOKE,CIGPDAY,BMI,DIABETES,BPMEDS,HEARTRTE
9998212,1,153,52,143,89,0,0,25.74,0,0,65,72,3,0,0,0,0,1,4538,3,30,123,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,0
9999312,2,196,39,133,86,1,30,20.91,0,0,85,80,3,0,0,0,0,0,0,1,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201
9999312,2,240,46,138,79,1,20,26.39,0,0,90,83,3,0,0,0,0,0,2390,2,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201
-9999312,2,,50,147,96,1,10,24.19,0,0,94,,3,0,0,0,0,1,4201,3,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201
\ No newline at end of file
+9999312,2,,50,147,96,1,10,24.19,0,0,94,,3,0,0,0,0,1,4201,3,,,0,0,0,0,0,0,0,1,8766,8766,8766,8766,8766,8766,8766,4201
diff --git a/datasets/framingham_dataset.py b/datasets/framingham_dataset.py
index 9520c089..4cce629e 100644
--- a/datasets/framingham_dataset.py
+++ b/datasets/framingham_dataset.py
@@ -1,15 +1,18 @@
import os
-import pandas as pd
+
import numpy as np
+import pandas as pd
from sklearn.impute import SimpleImputer
+
def load_framingham(competing=True, sequential=False):
"""
Load and preprocess the Framingham dataset for competing risks analysis,
with imputation but no scaling. Feature normalization must be done externally
after splitting to avoid data leakage.
- Returns:
+ Returns
+ -------
x (np.ndarray): Feature matrix with one-hot categorical + raw continuous features.
t (np.ndarray): Time-to-event (with +1 offset).
e (np.ndarray): Event indicator (0=censored, 1=CVD, 2=death).
@@ -17,58 +20,64 @@ def load_framingham(competing=True, sequential=False):
n_continuous (int): Number of continuous features at the end of x.
feature_ranges (None): Placeholder for backward compatibility.
"""
-
file_path = os.path.join(os.path.dirname(__file__), "framingham.csv")
data = pd.read_csv(file_path)
if not sequential:
-
data = data.groupby("RANDID").first()
-
cat_cols = [
- 'SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS',
- 'PREVCHD', 'PREVAP', 'PREVMI', 'PREVSTRK', 'PREVHYP', 'educ'
+ "SEX",
+ "CURSMOKE",
+ "DIABETES",
+ "BPMEDS",
+ "PREVCHD",
+ "PREVAP",
+ "PREVMI",
+ "PREVSTRK",
+ "PREVHYP",
+ "educ",
]
-
+
# 'HDLC', 'LDLC' - removed to replicate nfg experiments.
cont_cols = [
- 'TOTCHOL', 'AGE',
- 'SYSBP', 'DIABP', 'CIGPDAY', 'BMI',
- 'HEARTRTE', 'GLUCOSE'
+ "TOTCHOL",
+ "AGE",
+ "SYSBP",
+ "DIABP",
+ "CIGPDAY",
+ "BMI",
+ "HEARTRTE",
+ "GLUCOSE",
]
-
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat = pd.DataFrame(
cat_imputer.fit_transform(data[cat_cols]),
columns=cat_cols,
- index=data.index
+ index=data.index,
)
x_cat = pd.get_dummies(x_cat, drop_first=True)
-
- cont_imputer = SimpleImputer(strategy='mean')
+ cont_imputer = SimpleImputer(strategy="mean")
x_cont = cont_imputer.fit_transform(data[cont_cols])
-
x = np.hstack([x_cat.values, x_cont])
feature_names = np.concatenate([x_cat.columns.values, cont_cols])
n_continuous = len(cont_cols)
-
event = np.zeros(len(data), dtype=int)
-
- time = (data['TIMEDTH'] - data['TIME']).values
+
+ time = (data["TIMEDTH"] - data["TIME"]).values
# Primary CVD event (risk=1)
- cvd_mask = data['CVD'] == 1
+ cvd_mask = data["CVD"] == 1
event[cvd_mask] = 1
- time_cvd = (data['TIMECVD'] - data['TIME']).values
+ time_cvd = (data["TIMECVD"] - data["TIME"]).values
time[cvd_mask] = time_cvd[cvd_mask]
# Competing death event (risk=2), only if CVD did not occur
- death_mask = (data['DEATH'] == 1) & ~cvd_mask
+ death_mask = (data["DEATH"] == 1) & ~cvd_mask
event[death_mask] = 2
# Filter out invalid or zero times
@@ -81,10 +90,8 @@ def load_framingham(competing=True, sequential=False):
assert not np.isnan(x).any(), "NaNs found in feature matrix"
return x, t, e, feature_names, n_continuous, None
- else:
- raise NotImplementedError("Sequential mode not yet implemented.")
+ raise NotImplementedError("Sequential mode not yet implemented.")
+
if __name__ == "__main__":
x, t, e, feature_names, n_continuous, feature_ranges = load_framingham()
-
-
diff --git a/datasets/metabric/label.csv b/datasets/metabric/label.csv
index 1ea535c2..c8147146 100644
--- a/datasets/metabric/label.csv
+++ b/datasets/metabric/label.csv
@@ -1,1982 +1,1982 @@
-event_time,label
-2999 ,0
-1484 ,0
-3053 ,0
-1721 ,0
-1241 ,1
-234 ,1
-2947 ,0
-672 ,1
-2734 ,1
-2083 ,0
-1097 ,1
-1088 ,1
-2314 ,0
-2258 ,0
-2361 ,0
-424 ,1
-1594 ,0
-1784 ,0
-2005 ,0
-2310 ,0
-1900 ,0
-1718 ,0
-1846 ,0
-2482 ,0
-1420 ,0
-1543 ,0
-2673 ,0
-1583 ,0
-855 ,1
-2149 ,0
-2484 ,1
-2063 ,0
-1493 ,1
-2611 ,0
-1844 ,0
-242 ,1
-2713 ,0
-1333 ,0
-1481 ,0
-1866 ,0
-1281 ,1
-1433 ,0
-2554 ,0
-3148 ,0
-1175 ,1
-1384 ,0
-402 ,0
-2002 ,1
-2474 ,0
-247 ,0
-364 ,0
-872 ,1
-1694 ,0
-2569 ,0
-2527 ,0
-2368 ,0
-49 ,0
-1479 ,0
-1918 ,0
-1882 ,0
-1157 ,1
-1870 ,0
-1999 ,1
-2612 ,0
-388 ,1
-2398 ,0
-2271 ,0
-1405 ,0
-2596 ,0
-169 ,0
-372 ,0
-1630 ,1
-4562 ,1
-1841 ,0
-2580 ,0
-1553 ,1
-53 ,0
-1551 ,1
-2392 ,0
-1477 ,1
-1860 ,0
-2567 ,0
-1884 ,0
-1794 ,0
-59 ,0
-453 ,0
-1729 ,0
-2259 ,0
-2384 ,0
-1673 ,0
-1141 ,1
-325 ,0
-1428 ,0
-2282 ,0
-1293 ,1
-2283 ,0
-2215 ,0
-2801 ,1
-163 ,0
-1854 ,0
-130 ,0
-2363 ,0
-2171 ,0
-2363 ,0
-1368 ,1
-2755 ,0
-538 ,1
-1870 ,0
-2215 ,0
-2297 ,0
-1315 ,1
-939 ,1
-272 ,1
-459 ,1
-1877 ,0
-283 ,0
-2107 ,0
-2019 ,0
-2184 ,0
-2018 ,0
-2283 ,0
-2191 ,0
-2330 ,0
-2344 ,0
-2073 ,0
-1763 ,1
-729 ,1
-1919 ,0
-1585 ,0
-3359 ,0
-1995 ,0
-1881 ,0
-1299 ,1
-1927 ,0
-1841 ,0
-2179 ,0
-1195 ,1
-606 ,1
-3128 ,0
-2440 ,1
-1008 ,0
-4398 ,0
-1361 ,0
-2102 ,0
-325 ,1
-1958 ,0
-2653 ,0
-1366 ,0
-1324 ,0
-2173 ,1
-2827 ,1
-4277 ,1
-3861 ,0
-2083 ,0
-2567 ,1
-2194 ,1
-2198 ,0
-4028 ,0
-2302 ,0
-3259 ,0
-2319 ,0
-2130 ,0
-3628 ,0
-2965 ,0
-3948 ,0
-2752 ,1
-1966 ,0
-3821 ,0
-2707 ,1
-3511 ,0
-2476 ,0
-1293 ,1
-1701 ,0
-2720 ,1
-855 ,1
-1441 ,0
-8220 ,0
-3660 ,1
-3874 ,0
-2141 ,0
-175 ,1
-3807 ,0
-3406 ,0
-3819 ,0
-661 ,1
-101 ,0
-1992 ,0
-2842 ,1
-1184 ,1
-2147 ,1
-2767 ,0
-1176 ,1
-1483 ,1
-1917 ,0
-385 ,0
-3037 ,0
-3684 ,0
-1778 ,1
-2526 ,1
-1778 ,0
-3346 ,1
-1905 ,1
-1289 ,1
-3520 ,0
-1885 ,0
-1270 ,1
-1918 ,0
-2125 ,0
-3576 ,0
-3325 ,0
-133 ,1
-3625 ,0
-1828 ,0
-4550 ,1
-1918 ,0
-1221 ,1
-1968 ,0
-1922 ,0
-1516 ,1
-3198 ,0
-2037 ,0
-3598 ,0
-1772 ,0
-3687 ,0
-1048 ,0
-2042 ,1
-797 ,1
-3602 ,0
-1828 ,0
-800 ,1
-2925 ,0
-1848 ,0
-3760 ,1
-2734 ,1
-3317 ,0
-202 ,0
-613 ,1
-6167 ,1
-2338 ,0
-1848 ,0
-1382 ,1
-5860 ,1
-1656 ,1
-1255 ,1
-348 ,1
-4896 ,1
-5359 ,0
-1524 ,0
-179 ,0
-1256 ,0
-452 ,1
-1411 ,1
-2697 ,1
-3385 ,0
-2617 ,1
-1897 ,0
-1930 ,0
-1857 ,1
-1944 ,0
-1854 ,0
-1945 ,0
-1883 ,1
-810 ,1
-43 ,0
-1767 ,0
-260 ,0
-1584 ,0
-255 ,0
-857 ,1
-2010 ,0
-313 ,0
-2163 ,0
-1688 ,1
-1093 ,1
-2050 ,0
-1829 ,0
-335 ,0
-2166 ,1
-2504 ,0
-989 ,1
-625 ,1
-1927 ,0
-870 ,1
-1820 ,0
-1662 ,0
-1296 ,1
-2038 ,0
-674 ,1
-788 ,1
-1840 ,0
-1905 ,0
-1772 ,1
-1652 ,0
-2120 ,0
-2120 ,0
-1890 ,0
-1809 ,0
-1647 ,0
-2081 ,0
-1876 ,0
-188 ,1
-1970 ,0
-511 ,0
-1938 ,0
-1938 ,0
-1014 ,1
-1940 ,0
-295 ,1
-732 ,1
-2166 ,0
-1647 ,0
-1678 ,0
-1910 ,0
-180 ,0
-1655 ,0
-2566 ,0
-1365 ,1
-1959 ,0
-1885 ,1
-1851 ,0
-2178 ,0
-2143 ,0
-2037 ,0
-1903 ,0
-387 ,0
-1664 ,1
-1678 ,0
-1768 ,0
-877 ,1
-2040 ,0
-1502 ,0
-1884 ,0
-1294 ,1
-1405 ,1
-2006 ,0
-2006 ,0
-1068 ,1
-1838 ,0
-1835 ,0
-1859 ,0
-1179 ,1
-858 ,1
-1340 ,0
-1385 ,0
-2021 ,0
-2015 ,0
-2036 ,0
-1967 ,0
-764 ,0
-1121 ,0
-1527 ,0
-1440 ,0
-1561 ,0
-2354 ,0
-1871 ,0
-1092 ,1
-746 ,1
-2226 ,0
-1883 ,0
-1103 ,1
-1651 ,0
-1918 ,0
-2267 ,0
-1955 ,0
-684 ,0
-399 ,0
-1169 ,0
-1099 ,1
-1883 ,1
-475 ,0
-1847 ,0
-2225 ,0
-253 ,0
-1790 ,0
-1871 ,0
-499 ,0
-1930 ,0
-1821 ,0
-1351 ,1
-749 ,0
-1395 ,0
-1876 ,0
-1592 ,0
-1851 ,0
-1876 ,0
-1667 ,0
-402 ,0
-1482 ,0
-1857 ,0
-1637 ,0
-2964 ,0
-299 ,0
-2425 ,0
-2922 ,0
-2556 ,1
-3295 ,1
-2317 ,0
-1456 ,1
-4098 ,0
-3986 ,0
-3746 ,0
-3463 ,0
-3658 ,0
-3512 ,0
-723 ,1
-2565 ,1
-3367 ,0
-2273 ,0
-1458 ,1
-2082 ,0
-2380 ,1
-1235 ,1
-3598 ,0
-3961 ,0
-839 ,0
-394 ,0
-1559 ,1
-2042 ,1
-2572 ,1
-3335 ,1
-2935 ,1
-3602 ,0
-1138 ,1
-2352 ,0
-1785 ,1
-6705 ,0
-2344 ,0
-2382 ,0
-427 ,0
-1482 ,0
-930 ,1
-1514 ,0
-1204 ,0
-1577 ,0
-2218 ,0
-236 ,1
-466 ,1
-56 ,0
-273 ,0
-1525 ,0
-185 ,0
-1995 ,0
-1540 ,0
-1841 ,0
-1766 ,0
-978 ,0
-487 ,0
-1040 ,1
-1327 ,0
-1079 ,0
-1648 ,0
-1466 ,0
-2004 ,0
-1856 ,0
-880 ,1
-1415 ,0
-1704 ,0
-824 ,1
-1577 ,0
-463 ,0
-363 ,0
-1498 ,0
-626 ,0
-1464 ,0
-257 ,0
-875 ,0
-871 ,1
-922 ,0
-1561 ,0
-489 ,0
-501 ,0
-971 ,0
-1092 ,0
-1291 ,0
-1057 ,1
-23 ,0
-1121 ,0
-342 ,0
-1403 ,1
-965 ,0
-536 ,0
-1171 ,0
-358 ,1
-787 ,0
-922 ,0
-1393 ,0
-1112 ,0
-1227 ,0
-1010 ,0
-402 ,1
-1057 ,0
-77 ,1
-1065 ,0
-522 ,0
-867 ,0
-727 ,1
-1283 ,0
-1176 ,0
-564 ,1
-126 ,0
-819 ,0
-8 ,1
-560 ,1
-600 ,1
-760 ,0
-1661 ,0
-318 ,1
-730 ,1
-2083 ,0
-2957 ,1
-2533 ,0
-498 ,1
-2001 ,0
-433 ,1
-37 ,0
-1271 ,1
-2810 ,1
-2098 ,1
-3277 ,0
-3241 ,0
-1293 ,1
-929 ,1
-1655 ,0
-2266 ,0
-2582 ,0
-1916 ,1
-1215 ,0
-1791 ,1
-5222 ,0
-1437 ,1
-5196 ,0
-5490 ,0
-5292 ,0
-4902 ,0
-1948 ,1
-4703 ,0
-2673 ,1
-2673 ,1
-3855 ,0
-5352 ,0
-5056 ,0
-573 ,1
-5202 ,0
-1355 ,1
-1589 ,1
-1343 ,1
-2211 ,1
-5318 ,0
-4911 ,0
-3997 ,1
-4860 ,0
-5224 ,0
-4857 ,0
-4787 ,0
-4861 ,0
-3242 ,1
-4748 ,0
-1542 ,1
-4011 ,0
-5173 ,0
-747 ,1
-5145 ,0
-5069 ,1
-4521 ,0
-2933 ,1
-4362 ,0
-2457 ,1
-5025 ,0
-5213 ,0
-1050 ,1
-3541 ,1
-4408 ,0
-4832 ,0
-5087 ,0
-5221 ,0
-4800 ,0
-5108 ,0
-1345 ,1
-1111 ,1
-3950 ,1
-5173 ,0
-4717 ,0
-1213 ,1
-2048 ,1
-1837 ,0
-5086 ,0
-3323 ,1
-4866 ,0
-5064 ,0
-2631 ,1
-4535 ,0
-4910 ,0
-2576 ,1
-4751 ,0
-5580 ,0
-5059 ,0
-4811 ,0
-3175 ,0
-4899 ,0
-4703 ,0
-4860 ,0
-4489 ,0
-1651 ,1
-4671 ,0
-4430 ,0
-592 ,1
-5132 ,0
-4402 ,0
-4381 ,0
-4491 ,1
-1186 ,1
-4930 ,0
-1296 ,1
-283 ,1
-4770 ,0
-4311 ,0
-4245 ,0
-4936 ,0
-4247 ,1
-3626 ,0
-4944 ,0
-4436 ,0
-4846 ,0
-620 ,1
-4944 ,0
-3768 ,1
-4614 ,0
-4241 ,0
-958 ,1
-4388 ,0
-432 ,1
-423 ,1
-4740 ,0
-4430 ,0
-4850 ,0
-3252 ,0
-4858 ,1
-3038 ,1
-816 ,1
-3359 ,0
-4709 ,0
-4619 ,0
-904 ,1
-159 ,1
-4412 ,0
-4704 ,0
-4803 ,0
-288 ,1
-4857 ,0
-3062 ,0
-4514 ,0
-4628 ,0
-3318 ,1
-4749 ,0
-4733 ,0
-4426 ,0
-4778 ,0
-4506 ,0
-4398 ,0
-4817 ,0
-4388 ,0
-3772 ,0
-4423 ,0
-471 ,1
-3447 ,1
-4118 ,0
-4474 ,0
-4528 ,0
-4482 ,0
-4590 ,0
-4494 ,0
-4273 ,0
-1277 ,1
-1069 ,1
-4153 ,0
-2112 ,1
-4493 ,0
-2964 ,0
-2215 ,1
-4668 ,0
-4647 ,0
-961 ,1
-4422 ,0
-1338 ,1
-4391 ,0
-857 ,1
-4741 ,0
-4481 ,0
-4475 ,0
-2433 ,1
-4721 ,0
-4633 ,0
-4570 ,1
-4356 ,0
-862 ,1
-2261 ,1
-4694 ,0
-3021 ,0
-3940 ,1
-4381 ,0
-4230 ,0
-4468 ,0
-4570 ,0
-4108 ,0
-3901 ,0
-4584 ,0
-4056 ,0
-4125 ,0
-718 ,1
-4136 ,0
-4535 ,0
-2991 ,1
-1675 ,1
-4020 ,0
-1536 ,1
-1964 ,1
-680 ,1
-1246 ,1
-988 ,1
-3915 ,0
-4026 ,0
-3961 ,0
-4135 ,0
-3803 ,0
-4076 ,0
-3985 ,0
-4124 ,0
-1139 ,1
-4305 ,0
-4209 ,0
-4372 ,0
-4390 ,0
-3269 ,0
-1141 ,1
-803 ,1
-2046 ,1
-326 ,1
-4307 ,0
-4079 ,0
-4040 ,0
-2626 ,1
-3415 ,1
-4241 ,0
-4087 ,0
-696 ,1
-1669 ,1
-4080 ,0
-4131 ,0
-4178 ,0
-4243 ,0
-1503 ,1
-3835 ,0
-3781 ,0
-3104 ,0
-4027 ,0
-3964 ,0
-763 ,1
-3263 ,0
-4246 ,0
-3726 ,0
-3596 ,1
-985 ,1
-3716 ,0
-4180 ,0
-4135 ,0
-1164 ,1
-1884 ,0
-3506 ,0
-4037 ,0
-3737 ,0
-824 ,1
-3072 ,1
-3544 ,1
-465 ,1
-3898 ,0
-3721 ,0
-4108 ,0
-1788 ,1
-3460 ,0
-2582 ,1
-3410 ,0
-3389 ,1
-2587 ,1
-4046 ,0
-990 ,1
-3339 ,0
-3880 ,0
-746 ,1
-3775 ,0
-3319 ,0
-3695 ,0
-3296 ,0
-1164 ,1
-3736 ,0
-3289 ,0
-3872 ,0
-2260 ,1
-2782 ,1
-3586 ,0
-3793 ,0
-3794 ,1
-3766 ,0
-5353 ,0
-3324 ,0
-2168 ,0
-3331 ,0
-2976 ,0
-3327 ,0
-3509 ,0
-3005 ,0
-3368 ,0
-3026 ,1
-842 ,1
-1490 ,1
-3197 ,1
-7495 ,0
-4665 ,0
-1342 ,1
-351 ,1
-4978 ,0
-8321 ,1
-339 ,1
-4179 ,1
-3170 ,1
-3723 ,0
-456 ,1
-8941 ,1
-5160 ,1
-8725 ,0
-8441 ,1
-5291 ,0
-2034 ,1
-6513 ,1
-5244 ,1
-7745 ,1
-4962 ,1
-3435 ,1
-4435 ,1
-6529 ,1
-469 ,1
-1891 ,1
-1487 ,1
-5541 ,1
-4810 ,1
-9184 ,0
-9218 ,0
-2535 ,1
-1657 ,1
-1965 ,1
-3468 ,1
-1091 ,1
-7606 ,1
-9193 ,0
-6111 ,1
-3005 ,1
-5352 ,0
-4896 ,1
-1017 ,1
-2961 ,1
-1798 ,1
-3966 ,1
-7326 ,1
-7801 ,1
-1023 ,0
-5245 ,0
-2905 ,1
-8183 ,0
-6575 ,1
-1653 ,1
-5671 ,0
-2671 ,1
-7041 ,1
-7967 ,1
-8805 ,0
-5337 ,0
-514 ,1
-5894 ,1
-3799 ,1
-2617 ,1
-5978 ,1
-4936 ,0
-7067 ,1
-8507 ,0
-6239 ,1
-5208 ,1
-1011 ,1
-5692 ,0
-5958 ,1
-6786 ,1
-932 ,1
-2044 ,1
-3856 ,1
-461 ,1
-3421 ,1
-356 ,1
-4339 ,0
-3584 ,1
-1277 ,1
-8735 ,0
-1844 ,1
-4930 ,1
-4188 ,1
-2501 ,1
-6197 ,1
-7248 ,1
-501 ,1
-8303 ,0
-7943 ,1
-5127 ,1
-4443 ,1
-8555 ,0
-1569 ,1
-3794 ,1
-7954 ,0
-7673 ,0
-6313 ,1
-5920 ,1
-1235 ,1
-4207 ,0
-3581 ,1
-5735 ,0
-1688 ,1
-3684 ,1
-7177 ,0
-3652 ,0
-2704 ,1
-7390 ,0
-4257 ,0
-6852 ,0
-2680 ,0
-4290 ,1
-5599 ,1
-2875 ,1
-5046 ,0
-3288 ,1
-3570 ,1
-6018 ,0
-4457 ,1
-6808 ,0
-1244 ,1
-2700 ,1
-5056 ,1
-505 ,1
-909 ,1
-1444 ,1
-1929 ,1
-2970 ,1
-1196 ,1
-7027 ,0
-7180 ,0
-7110 ,0
-3347 ,1
-4088 ,0
-3584 ,1
-2034 ,1
-1041 ,1
-6628 ,1
-2610 ,1
-2432 ,1
-7099 ,0
-7049 ,0
-7441 ,0
-4186 ,1
-1299 ,1
-1453 ,1
-2223 ,1
-2559 ,1
-3971 ,0
-3218 ,1
-5949 ,0
-639 ,1
-6982 ,0
-3540 ,1
-5120 ,1
-3826 ,0
-3808 ,0
-1920 ,1
-6393 ,1
-1670 ,1
-6636 ,0
-6837 ,1
-4403 ,0
-1730 ,0
-944 ,1
-6860 ,0
-2510 ,1
-1057 ,1
-1287 ,1
-5950 ,1
-2447 ,1
-712 ,1
-2583 ,0
-6359 ,0
-4382 ,0
-2963 ,1
-840 ,1
-3939 ,1
-1346 ,1
-4496 ,0
-592 ,1
-6974 ,0
-3666 ,0
-5282 ,0
-4850 ,1
-1695 ,1
-2494 ,1
-6905 ,0
-4590 ,1
-1280 ,1
-6690 ,0
-1544 ,1
-1132 ,1
-3758 ,1
-822 ,1
-6358 ,0
-3744 ,1
-6828 ,0
-3063 ,0
-6660 ,0
-1414 ,1
-651 ,1
-4562 ,1
-6950 ,0
-6184 ,1
-3047 ,0
-6889 ,0
-4372 ,1
-3650 ,1
-2432 ,1
-2090 ,1
-5215 ,1
-1219 ,1
-2501 ,1
-1365 ,1
-3848 ,0
-2981 ,1
-6567 ,0
-3328 ,1
-3911 ,1
-1317 ,1
-3894 ,1
-6053 ,1
-5958 ,0
-2204 ,1
-2287 ,1
-1719 ,1
-2749 ,1
-2577 ,0
-5698 ,0
-1371 ,1
-1275 ,1
-5288 ,1
-4252 ,1
-1085 ,1
-703 ,1
-6023 ,1
-6990 ,0
-3509 ,0
-764 ,1
-3049 ,1
-1360 ,1
-2708 ,1
-7351 ,0
-1244 ,1
-3880 ,1
-4734 ,1
-6239 ,0
-5619 ,0
-6315 ,0
-6404 ,0
-6261 ,0
-879 ,1
-6118 ,0
-1194 ,1
-4713 ,0
-3530 ,1
-4518 ,1
-674 ,1
-2399 ,1
-6329 ,0
-6050 ,0
-6083 ,0
-6048 ,0
-1445 ,1
-1062 ,1
-6015 ,0
-5301 ,0
-6269 ,0
-1946 ,1
-5923 ,0
-3537 ,1
-5678 ,0
-6185 ,0
-3877 ,1
-5933 ,0
-1437 ,1
-943 ,1
-1500 ,1
-3083 ,1
-6208 ,0
-1200 ,1
-1753 ,0
-911 ,1
-5981 ,0
-2107 ,0
-5826 ,0
-2317 ,1
-2118 ,1
-1759 ,1
-3724 ,1
-5159 ,0
-604 ,1
-5806 ,0
-4518 ,1
-5532 ,0
-2079 ,1
-5930 ,1
-3527 ,1
-6080 ,0
-4743 ,0
-1494 ,0
-1479 ,1
-3921 ,1
-2361 ,0
-5943 ,1
-2664 ,1
-5283 ,1
-5966 ,0
-6077 ,0
-5953 ,0
-5423 ,1
-575 ,1
-2394 ,1
-5603 ,0
-3806 ,0
-4961 ,0
-2257 ,1
-5224 ,1
-2993 ,1
-1967 ,1
-5638 ,0
-2454 ,1
-5782 ,0
-5762 ,0
-2415 ,1
-5637 ,0
-5688 ,0
-5221 ,0
-4348 ,0
-3258 ,1
-1134 ,1
-2545 ,1
-3558 ,1
-5696 ,0
-1582 ,1
-3176 ,0
-1864 ,1
-2375 ,1
-3674 ,0
-3851 ,1
-4183 ,0
-5049 ,1
-2591 ,1
-5920 ,0
-5530 ,0
-1037 ,1
-4518 ,1
-5631 ,0
-1690 ,1
-5362 ,1
-21 ,0
-4503 ,1
-1183 ,0
-5048 ,0
-5715 ,0
-2745 ,0
-5544 ,0
-2841 ,0
-5475 ,0
-5731 ,0
-1507 ,1
-2245 ,0
-5908 ,0
-806 ,1
-2613 ,1
-5736 ,0
-1740 ,1
-5936 ,0
-2527 ,1
-3696 ,0
-5237 ,1
-4273 ,1
-1848 ,0
-5592 ,1
-597 ,1
-1030 ,1
-861 ,1
-3675 ,1
-5865 ,0
-5354 ,0
-2479 ,1
-3752 ,1
-657 ,1
-5269 ,0
-461 ,1
-2322 ,0
-2153 ,1
-5758 ,0
-5615 ,0
-1793 ,0
-834 ,1
-3447 ,1
-3771 ,1
-2709 ,1
-3726 ,1
-819 ,1
-3615 ,0
-5718 ,0
-5910 ,0
-5469 ,0
-1118 ,1
-5362 ,0
-1298 ,0
-1110 ,1
-2699 ,1
-4042 ,1
-3838 ,1
-2502 ,0
-4777 ,1
-608 ,0
-5820 ,0
-3325 ,1
-3478 ,1
-5125 ,0
-5529 ,0
-547 ,1
-5812 ,0
-2632 ,1
-5260 ,0
-4930 ,0
-5553 ,0
-2415 ,0
-2422 ,1
-5637 ,0
-1559 ,1
-5550 ,1
-5528 ,0
-664 ,1
-4536 ,1
-5616 ,0
-3379 ,1
-1957 ,1
-5530 ,0
-5503 ,0
-4028 ,1
-3556 ,1
-2680 ,1
-2324 ,1
-5438 ,0
-582 ,1
-3359 ,1
-1447 ,0
-4265 ,1
-1628 ,1
-5617 ,0
-2821 ,1
-5706 ,0
-5414 ,0
-1569 ,1
-2909 ,1
-5391 ,0
-5559 ,0
-5453 ,0
-4557 ,0
-1476 ,0
-4514 ,0
-5283 ,0
-1302 ,1
-3221 ,1
-3561 ,1
-5598 ,0
-5354 ,0
-3939 ,1
-3355 ,1
-4340 ,1
-2982 ,1
-1298 ,1
-5349 ,0
-2296 ,0
-5422 ,0
-1019 ,1
-5275 ,0
-497 ,1
-5131 ,0
-5203 ,0
-5488 ,0
-5000 ,0
-1324 ,1
-5157 ,0
-3115 ,1
-1153 ,1
-3724 ,0
-3374 ,1
-3062 ,1
-1607 ,0
-3242 ,1
-3032 ,1
-1385 ,1
-5102 ,0
-3494 ,1
-5153 ,0
-2509 ,1
-2132 ,1
-5584 ,0
-5596 ,0
-5025 ,0
-5602 ,0
-1826 ,1
-2279 ,1
-5908 ,0
-2999 ,1
-2063 ,1
-2336 ,1
-2379 ,0
-1960 ,1
-3720 ,1
-3852 ,1
-2610 ,1
-1108 ,1
-4606 ,0
-1453 ,1
-5240 ,0
-3394 ,0
-2268 ,1
-414 ,1
-2547 ,1
-501 ,1
-4970 ,0
-5222 ,0
-5085 ,0
-4483 ,1
-4792 ,0
-3143 ,1
-3524 ,0
-1730 ,1
-1137 ,1
-1396 ,1
-1291 ,1
-1609 ,1
-425 ,1
-3069 ,1
-4732 ,0
-5049 ,0
-5116 ,0
-2886 ,1
-4906 ,0
-426 ,1
-1372 ,1
-3436 ,1
-3724 ,1
-2965 ,1
-2562 ,1
-757 ,1
-1550 ,1
-5147 ,0
-5305 ,0
-5192 ,0
-5092 ,0
-5460 ,0
-5090 ,0
-3784 ,0
-1609 ,1
-5405 ,0
-865 ,1
-4834 ,1
-461 ,1
-3773 ,1
-691 ,1
-5059 ,0
-3004 ,1
-1179 ,1
-3579 ,1
-76 ,0
-3060 ,1
-368 ,1
-5137 ,0
-3487 ,1
-3410 ,1
-2135 ,1
-4464 ,1
-5082 ,0
-1369 ,1
-4408 ,1
-3743 ,1
-928 ,1
-5066 ,0
-4620 ,1
-5040 ,0
-4955 ,0
-4970 ,0
-5114 ,0
-5386 ,0
-3204 ,1
-2105 ,0
-3233 ,1
-2872 ,1
-2242 ,1
-3528 ,0
-2923 ,1
-3011 ,0
-3042 ,0
-5262 ,0
-2787 ,0
-780 ,1
-3850 ,0
-4735 ,1
-4218 ,1
-3212 ,0
-2260 ,0
-3018 ,0
-5245 ,0
-1382 ,1
-2506 ,1
-1484 ,1
-1039 ,1
-4672 ,1
-4514 ,0
-921 ,1
-476 ,1
-2955 ,1
-2421 ,0
-1355 ,1
-3511 ,0
-291 ,1
-5194 ,0
-4588 ,0
-4885 ,1
-4466 ,1
-3708 ,0
-4923 ,0
-485 ,1
-1719 ,1
-1520 ,1
-1854 ,1
-4828 ,1
-4891 ,0
-2724 ,1
-4490 ,0
-1744 ,1
-3496 ,1
-564 ,1
-4895 ,0
-2981 ,0
-1262 ,1
-4869 ,0
-587 ,1
-3777 ,1
-5089 ,0
-76 ,1
-3309 ,0
-1029 ,1
-1391 ,1
-530 ,1
-1273 ,1
-4973 ,0
-954 ,1
-3699 ,1
-5296 ,0
-1536 ,1
-1975 ,1
-701 ,1
-1717 ,1
-4662 ,0
-5187 ,0
-3001 ,0
-3712 ,0
-1799 ,0
-3583 ,0
-3093 ,1
-4437 ,0
-1099 ,1
-4771 ,0
-926 ,1
-4816 ,0
-489 ,1
-3744 ,0
-60 ,0
-4532 ,0
-4312 ,0
-444 ,1
-4035 ,0
-4820 ,0
-2396 ,1
-4741 ,0
-2373 ,1
-3604 ,1
-1453 ,1
-4989 ,0
-1033 ,1
-5036 ,0
-2957 ,0
-760 ,0
-2668 ,1
-771 ,0
-3542 ,0
-1041 ,1
-4635 ,1
-4773 ,0
-4405 ,0
-751 ,0
-565 ,1
-4802 ,0
-3089 ,1
-3530 ,1
-3461 ,0
-4737 ,0
-232 ,0
-4616 ,1
-4697 ,0
-3530 ,1
-4767 ,0
-1662 ,1
-2707 ,1
-2149 ,1
-4056 ,0
-3341 ,0
-2733 ,1
-1617 ,1
-2434 ,1
-4807 ,0
-535 ,0
-4648 ,1
-790 ,1
-2775 ,1
-4912 ,0
-982 ,1
-4521 ,0
-4215 ,0
-4669 ,0
-3663 ,0
-4369 ,0
-3711 ,0
-4267 ,0
-3493 ,1
-4778 ,0
-3846 ,1
-1262 ,1
-1911 ,0
-4134 ,1
-4080 ,0
-969 ,1
-2489 ,0
-1278 ,1
-3114 ,1
-1145 ,1
-1051 ,1
-1887 ,1
-632 ,1
-3249 ,1
-4551 ,0
-4006 ,0
-3625 ,0
-4213 ,0
-667 ,1
-4522 ,0
-924 ,1
-4152 ,0
-630 ,1
-3140 ,1
-4263 ,0
-1915 ,0
-4524 ,0
-3792 ,1
-4059 ,1
-1907 ,1
-2116 ,1
-2073 ,1
-4441 ,0
-6334 ,1
-4333 ,1
-5758 ,1
-4306 ,1
-2127 ,1
-4295 ,1
-406 ,1
-1448 ,1
-4690 ,0
-1562 ,1
-6172 ,0
-6749 ,0
-1492 ,1
-6933 ,0
-4361 ,1
-5683 ,0
-2194 ,1
-5861 ,0
-5214 ,0
-1136 ,1
-5046 ,1
-6201 ,0
-5023 ,0
-5605 ,1
-6032 ,0
-6060 ,0
-642 ,0
-680 ,1
-860 ,0
-1141 ,1
-6112 ,0
-468 ,1
-5656 ,1
-5478 ,1
-2150 ,0
-1722 ,1
-5470 ,1
-5108 ,1
-5294 ,0
-1623 ,1
-3075 ,1
-4921 ,1
-730 ,0
-69 ,1
-4717 ,0
-319 ,1
-302 ,1
-2325 ,1
-1410 ,0
-3391 ,0
-3700 ,1
-2990 ,0
-3737 ,1
-2005 ,0
-3610 ,0
-5215 ,0
-1603 ,0
-3469 ,1
-4617 ,1
-1301 ,0
-1854 ,1
-2358 ,1
-769 ,1
-2366 ,1
-4046 ,0
-3133 ,0
-1030 ,1
-2688 ,1
-1669 ,0
-2248 ,1
-2330 ,1
-2310 ,0
-3653 ,0
-1758 ,1
-476 ,1
-3752 ,0
-2660 ,1
-135 ,0
-3684 ,1
-4361 ,0
-2542 ,1
-2768 ,0
-994 ,1
-1754 ,1
-1648 ,1
-3595 ,0
-2374 ,1
-4794 ,0
-1992 ,0
-1133 ,1
-713 ,0
-2024 ,0
-739 ,1
-1279 ,1
-1574 ,1
-3335 ,0
-902 ,1
-3528 ,0
-1703 ,1
-4614 ,0
-3549 ,1
-775 ,0
-1269 ,1
-2088 ,0
-3409 ,0
-737 ,0
-4488 ,0
-4434 ,1
-3836 ,0
-1064 ,1
-3536 ,0
-1061 ,1
-3779 ,0
-3 ,1
-2244 ,1
-2177 ,1
-2550 ,1
-1704 ,0
-3344 ,0
-3156 ,1
-1608 ,0
-3622 ,0
-2637 ,0
-3723 ,0
-1586 ,1
-2254 ,0
-441 ,1
-2933 ,0
-1827 ,1
-2037 ,0
-2730 ,0
-2276 ,0
-3352 ,0
-672 ,1
-3150 ,0
-836 ,1
-105 ,1
-2772 ,1
-2650 ,1
-1486 ,1
-1056 ,1
-3009 ,1
-2204 ,0
-1915 ,1
-1528 ,0
-2658 ,0
-1761 ,0
-185 ,0
-1996 ,0
-2570 ,0
-877 ,1
-427 ,0
-2401 ,0
-2411 ,1
-548 ,1
-2022 ,0
-1869 ,0
-1889 ,0
-1737 ,0
-516 ,1
-1582 ,0
-4588 ,1
-6160 ,0
-1136 ,1
-2551 ,1
-501 ,1
-1816 ,0
-1940 ,0
-1738 ,0
-1625 ,0
-1514 ,0
-1360 ,0
-1778 ,0
-1758 ,0
-1295 ,1
-1643 ,1
-2221 ,0
-1785 ,0
-595 ,1
-1393 ,0
-2061 ,0
-2345 ,0
-1805 ,0
-2135 ,1
-2394 ,0
-2143 ,1
-2565 ,0
-803 ,1
-2108 ,0
-2407 ,0
-2008 ,0
-1103 ,1
-2694 ,0
-2306 ,0
-940 ,1
-2561 ,0
-2144 ,0
-2253 ,0
-962 ,1
-2132 ,0
-2592 ,0
-2431 ,0
-1476 ,1
-2155 ,0
-2843 ,1
-2655 ,0
-2746 ,0
-2605 ,0
-2847 ,1
-2404 ,0
-2654 ,0
-3061 ,0
-1954 ,1
-2578 ,1
-3184 ,0
-3664 ,0
-913 ,1
-274 ,1
-3510 ,0
-1315 ,1
-3675 ,0
-3580 ,0
-3056 ,0
-2816 ,0
-2876 ,0
-2598 ,0
-3236 ,0
-2693 ,0
-3150 ,0
-2954 ,0
-3438 ,0
-3242 ,0
-3062 ,0
-1570 ,1
-3477 ,0
-3245 ,0
-3349 ,0
-2951 ,0
-3038 ,0
-125 ,1
-2732 ,0
-2737 ,0
-3066 ,0
-3164 ,0
-3496 ,0
-3303 ,0
-3706 ,0
-152 ,1
-3706 ,0
-3567 ,0
-3293 ,0
-3308 ,0
-3672 ,0
-3064 ,0
-3401 ,0
-3469 ,0
-3720 ,0
-1955 ,1
-3597 ,0
-3538 ,0
-3728 ,0
-3123 ,1
-3467 ,0
-1125 ,1
-1708 ,1
-3526 ,0
-905 ,1
-1941 ,1
-3297 ,0
-3698 ,0
-3559 ,0
-1001 ,1
-4217 ,0
-3851 ,0
-2665 ,1
-897 ,1
-4130 ,0
-4188 ,0
-4060 ,0
-1808 ,1
-3089 ,1
-4042 ,0
-3948 ,0
-715 ,1
-802 ,1
-4120 ,0
-3520 ,1
-2724 ,1
-4157 ,0
-4377 ,0
-4134 ,0
-3079 ,1
-4150 ,0
-3884 ,0
-570 ,1
-1339 ,1
-3869 ,0
-4333 ,0
-4136 ,0
-4284 ,0
-3254 ,1
-4493 ,0
-4673 ,0
-1393 ,1
-658 ,1
-4442 ,0
-3646 ,0
-4206 ,0
-1579 ,1
-4374 ,1
-4415 ,0
-4231 ,0
-4515 ,0
-4662 ,0
-4414 ,1
-4479 ,0
-1906 ,1
-4223 ,0
-4554 ,0
-4726 ,0
-1484 ,1
-3751 ,1
-146 ,0
-4403 ,1
-4357 ,0
-4308 ,1
-2624 ,1
-1400 ,1
-2265 ,1
-3213 ,1
-4639 ,0
-502 ,1
-2280 ,1
-2928 ,1
-4903 ,0
-4981 ,0
-1672 ,1
-1378 ,1
-1575 ,1
-4960 ,0
-4608 ,0
-2712 ,1
-4963 ,0
-4956 ,0
-4548 ,0
-1650 ,1
-1601 ,1
-1896 ,1
-5212 ,0
-3348 ,1
-5478 ,0
-4883 ,0
-5328 ,0
-4736 ,0
-5190 ,0
-5124 ,0
-4804 ,0
-718 ,1
-698 ,1
-5133 ,0
-4464 ,1
-5525 ,0
-1760 ,1
-1530 ,1
-647 ,1
-1071 ,1
-1799 ,1
-5766 ,0
-1041 ,1
-452 ,1
-935 ,0
-3121 ,1
-4949 ,1
-1289 ,1
-2045 ,1
-5831 ,0
-5704 ,0
-635 ,1
-5255 ,0
-6246 ,0
-1332 ,1
-5560 ,0
-5611 ,0
-5866 ,0
-639 ,1
-5866 ,0
-1486 ,1
-866 ,1
-5998 ,0
-3270 ,1
-4759 ,1
-2909 ,1
-812 ,1
-3800 ,1
-205 ,1
-2354 ,1
-5977 ,0
-2482 ,1
-5906 ,0
-1342 ,1
-5279 ,1
-2587 ,1
-5893 ,0
+event_time,label
+2999 ,0
+1484 ,0
+3053 ,0
+1721 ,0
+1241 ,1
+234 ,1
+2947 ,0
+672 ,1
+2734 ,1
+2083 ,0
+1097 ,1
+1088 ,1
+2314 ,0
+2258 ,0
+2361 ,0
+424 ,1
+1594 ,0
+1784 ,0
+2005 ,0
+2310 ,0
+1900 ,0
+1718 ,0
+1846 ,0
+2482 ,0
+1420 ,0
+1543 ,0
+2673 ,0
+1583 ,0
+855 ,1
+2149 ,0
+2484 ,1
+2063 ,0
+1493 ,1
+2611 ,0
+1844 ,0
+242 ,1
+2713 ,0
+1333 ,0
+1481 ,0
+1866 ,0
+1281 ,1
+1433 ,0
+2554 ,0
+3148 ,0
+1175 ,1
+1384 ,0
+402 ,0
+2002 ,1
+2474 ,0
+247 ,0
+364 ,0
+872 ,1
+1694 ,0
+2569 ,0
+2527 ,0
+2368 ,0
+49 ,0
+1479 ,0
+1918 ,0
+1882 ,0
+1157 ,1
+1870 ,0
+1999 ,1
+2612 ,0
+388 ,1
+2398 ,0
+2271 ,0
+1405 ,0
+2596 ,0
+169 ,0
+372 ,0
+1630 ,1
+4562 ,1
+1841 ,0
+2580 ,0
+1553 ,1
+53 ,0
+1551 ,1
+2392 ,0
+1477 ,1
+1860 ,0
+2567 ,0
+1884 ,0
+1794 ,0
+59 ,0
+453 ,0
+1729 ,0
+2259 ,0
+2384 ,0
+1673 ,0
+1141 ,1
+325 ,0
+1428 ,0
+2282 ,0
+1293 ,1
+2283 ,0
+2215 ,0
+2801 ,1
+163 ,0
+1854 ,0
+130 ,0
+2363 ,0
+2171 ,0
+2363 ,0
+1368 ,1
+2755 ,0
+538 ,1
+1870 ,0
+2215 ,0
+2297 ,0
+1315 ,1
+939 ,1
+272 ,1
+459 ,1
+1877 ,0
+283 ,0
+2107 ,0
+2019 ,0
+2184 ,0
+2018 ,0
+2283 ,0
+2191 ,0
+2330 ,0
+2344 ,0
+2073 ,0
+1763 ,1
+729 ,1
+1919 ,0
+1585 ,0
+3359 ,0
+1995 ,0
+1881 ,0
+1299 ,1
+1927 ,0
+1841 ,0
+2179 ,0
+1195 ,1
+606 ,1
+3128 ,0
+2440 ,1
+1008 ,0
+4398 ,0
+1361 ,0
+2102 ,0
+325 ,1
+1958 ,0
+2653 ,0
+1366 ,0
+1324 ,0
+2173 ,1
+2827 ,1
+4277 ,1
+3861 ,0
+2083 ,0
+2567 ,1
+2194 ,1
+2198 ,0
+4028 ,0
+2302 ,0
+3259 ,0
+2319 ,0
+2130 ,0
+3628 ,0
+2965 ,0
+3948 ,0
+2752 ,1
+1966 ,0
+3821 ,0
+2707 ,1
+3511 ,0
+2476 ,0
+1293 ,1
+1701 ,0
+2720 ,1
+855 ,1
+1441 ,0
+8220 ,0
+3660 ,1
+3874 ,0
+2141 ,0
+175 ,1
+3807 ,0
+3406 ,0
+3819 ,0
+661 ,1
+101 ,0
+1992 ,0
+2842 ,1
+1184 ,1
+2147 ,1
+2767 ,0
+1176 ,1
+1483 ,1
+1917 ,0
+385 ,0
+3037 ,0
+3684 ,0
+1778 ,1
+2526 ,1
+1778 ,0
+3346 ,1
+1905 ,1
+1289 ,1
+3520 ,0
+1885 ,0
+1270 ,1
+1918 ,0
+2125 ,0
+3576 ,0
+3325 ,0
+133 ,1
+3625 ,0
+1828 ,0
+4550 ,1
+1918 ,0
+1221 ,1
+1968 ,0
+1922 ,0
+1516 ,1
+3198 ,0
+2037 ,0
+3598 ,0
+1772 ,0
+3687 ,0
+1048 ,0
+2042 ,1
+797 ,1
+3602 ,0
+1828 ,0
+800 ,1
+2925 ,0
+1848 ,0
+3760 ,1
+2734 ,1
+3317 ,0
+202 ,0
+613 ,1
+6167 ,1
+2338 ,0
+1848 ,0
+1382 ,1
+5860 ,1
+1656 ,1
+1255 ,1
+348 ,1
+4896 ,1
+5359 ,0
+1524 ,0
+179 ,0
+1256 ,0
+452 ,1
+1411 ,1
+2697 ,1
+3385 ,0
+2617 ,1
+1897 ,0
+1930 ,0
+1857 ,1
+1944 ,0
+1854 ,0
+1945 ,0
+1883 ,1
+810 ,1
+43 ,0
+1767 ,0
+260 ,0
+1584 ,0
+255 ,0
+857 ,1
+2010 ,0
+313 ,0
+2163 ,0
+1688 ,1
+1093 ,1
+2050 ,0
+1829 ,0
+335 ,0
+2166 ,1
+2504 ,0
+989 ,1
+625 ,1
+1927 ,0
+870 ,1
+1820 ,0
+1662 ,0
+1296 ,1
+2038 ,0
+674 ,1
+788 ,1
+1840 ,0
+1905 ,0
+1772 ,1
+1652 ,0
+2120 ,0
+2120 ,0
+1890 ,0
+1809 ,0
+1647 ,0
+2081 ,0
+1876 ,0
+188 ,1
+1970 ,0
+511 ,0
+1938 ,0
+1938 ,0
+1014 ,1
+1940 ,0
+295 ,1
+732 ,1
+2166 ,0
+1647 ,0
+1678 ,0
+1910 ,0
+180 ,0
+1655 ,0
+2566 ,0
+1365 ,1
+1959 ,0
+1885 ,1
+1851 ,0
+2178 ,0
+2143 ,0
+2037 ,0
+1903 ,0
+387 ,0
+1664 ,1
+1678 ,0
+1768 ,0
+877 ,1
+2040 ,0
+1502 ,0
+1884 ,0
+1294 ,1
+1405 ,1
+2006 ,0
+2006 ,0
+1068 ,1
+1838 ,0
+1835 ,0
+1859 ,0
+1179 ,1
+858 ,1
+1340 ,0
+1385 ,0
+2021 ,0
+2015 ,0
+2036 ,0
+1967 ,0
+764 ,0
+1121 ,0
+1527 ,0
+1440 ,0
+1561 ,0
+2354 ,0
+1871 ,0
+1092 ,1
+746 ,1
+2226 ,0
+1883 ,0
+1103 ,1
+1651 ,0
+1918 ,0
+2267 ,0
+1955 ,0
+684 ,0
+399 ,0
+1169 ,0
+1099 ,1
+1883 ,1
+475 ,0
+1847 ,0
+2225 ,0
+253 ,0
+1790 ,0
+1871 ,0
+499 ,0
+1930 ,0
+1821 ,0
+1351 ,1
+749 ,0
+1395 ,0
+1876 ,0
+1592 ,0
+1851 ,0
+1876 ,0
+1667 ,0
+402 ,0
+1482 ,0
+1857 ,0
+1637 ,0
+2964 ,0
+299 ,0
+2425 ,0
+2922 ,0
+2556 ,1
+3295 ,1
+2317 ,0
+1456 ,1
+4098 ,0
+3986 ,0
+3746 ,0
+3463 ,0
+3658 ,0
+3512 ,0
+723 ,1
+2565 ,1
+3367 ,0
+2273 ,0
+1458 ,1
+2082 ,0
+2380 ,1
+1235 ,1
+3598 ,0
+3961 ,0
+839 ,0
+394 ,0
+1559 ,1
+2042 ,1
+2572 ,1
+3335 ,1
+2935 ,1
+3602 ,0
+1138 ,1
+2352 ,0
+1785 ,1
+6705 ,0
+2344 ,0
+2382 ,0
+427 ,0
+1482 ,0
+930 ,1
+1514 ,0
+1204 ,0
+1577 ,0
+2218 ,0
+236 ,1
+466 ,1
+56 ,0
+273 ,0
+1525 ,0
+185 ,0
+1995 ,0
+1540 ,0
+1841 ,0
+1766 ,0
+978 ,0
+487 ,0
+1040 ,1
+1327 ,0
+1079 ,0
+1648 ,0
+1466 ,0
+2004 ,0
+1856 ,0
+880 ,1
+1415 ,0
+1704 ,0
+824 ,1
+1577 ,0
+463 ,0
+363 ,0
+1498 ,0
+626 ,0
+1464 ,0
+257 ,0
+875 ,0
+871 ,1
+922 ,0
+1561 ,0
+489 ,0
+501 ,0
+971 ,0
+1092 ,0
+1291 ,0
+1057 ,1
+23 ,0
+1121 ,0
+342 ,0
+1403 ,1
+965 ,0
+536 ,0
+1171 ,0
+358 ,1
+787 ,0
+922 ,0
+1393 ,0
+1112 ,0
+1227 ,0
+1010 ,0
+402 ,1
+1057 ,0
+77 ,1
+1065 ,0
+522 ,0
+867 ,0
+727 ,1
+1283 ,0
+1176 ,0
+564 ,1
+126 ,0
+819 ,0
+8 ,1
+560 ,1
+600 ,1
+760 ,0
+1661 ,0
+318 ,1
+730 ,1
+2083 ,0
+2957 ,1
+2533 ,0
+498 ,1
+2001 ,0
+433 ,1
+37 ,0
+1271 ,1
+2810 ,1
+2098 ,1
+3277 ,0
+3241 ,0
+1293 ,1
+929 ,1
+1655 ,0
+2266 ,0
+2582 ,0
+1916 ,1
+1215 ,0
+1791 ,1
+5222 ,0
+1437 ,1
+5196 ,0
+5490 ,0
+5292 ,0
+4902 ,0
+1948 ,1
+4703 ,0
+2673 ,1
+2673 ,1
+3855 ,0
+5352 ,0
+5056 ,0
+573 ,1
+5202 ,0
+1355 ,1
+1589 ,1
+1343 ,1
+2211 ,1
+5318 ,0
+4911 ,0
+3997 ,1
+4860 ,0
+5224 ,0
+4857 ,0
+4787 ,0
+4861 ,0
+3242 ,1
+4748 ,0
+1542 ,1
+4011 ,0
+5173 ,0
+747 ,1
+5145 ,0
+5069 ,1
+4521 ,0
+2933 ,1
+4362 ,0
+2457 ,1
+5025 ,0
+5213 ,0
+1050 ,1
+3541 ,1
+4408 ,0
+4832 ,0
+5087 ,0
+5221 ,0
+4800 ,0
+5108 ,0
+1345 ,1
+1111 ,1
+3950 ,1
+5173 ,0
+4717 ,0
+1213 ,1
+2048 ,1
+1837 ,0
+5086 ,0
+3323 ,1
+4866 ,0
+5064 ,0
+2631 ,1
+4535 ,0
+4910 ,0
+2576 ,1
+4751 ,0
+5580 ,0
+5059 ,0
+4811 ,0
+3175 ,0
+4899 ,0
+4703 ,0
+4860 ,0
+4489 ,0
+1651 ,1
+4671 ,0
+4430 ,0
+592 ,1
+5132 ,0
+4402 ,0
+4381 ,0
+4491 ,1
+1186 ,1
+4930 ,0
+1296 ,1
+283 ,1
+4770 ,0
+4311 ,0
+4245 ,0
+4936 ,0
+4247 ,1
+3626 ,0
+4944 ,0
+4436 ,0
+4846 ,0
+620 ,1
+4944 ,0
+3768 ,1
+4614 ,0
+4241 ,0
+958 ,1
+4388 ,0
+432 ,1
+423 ,1
+4740 ,0
+4430 ,0
+4850 ,0
+3252 ,0
+4858 ,1
+3038 ,1
+816 ,1
+3359 ,0
+4709 ,0
+4619 ,0
+904 ,1
+159 ,1
+4412 ,0
+4704 ,0
+4803 ,0
+288 ,1
+4857 ,0
+3062 ,0
+4514 ,0
+4628 ,0
+3318 ,1
+4749 ,0
+4733 ,0
+4426 ,0
+4778 ,0
+4506 ,0
+4398 ,0
+4817 ,0
+4388 ,0
+3772 ,0
+4423 ,0
+471 ,1
+3447 ,1
+4118 ,0
+4474 ,0
+4528 ,0
+4482 ,0
+4590 ,0
+4494 ,0
+4273 ,0
+1277 ,1
+1069 ,1
+4153 ,0
+2112 ,1
+4493 ,0
+2964 ,0
+2215 ,1
+4668 ,0
+4647 ,0
+961 ,1
+4422 ,0
+1338 ,1
+4391 ,0
+857 ,1
+4741 ,0
+4481 ,0
+4475 ,0
+2433 ,1
+4721 ,0
+4633 ,0
+4570 ,1
+4356 ,0
+862 ,1
+2261 ,1
+4694 ,0
+3021 ,0
+3940 ,1
+4381 ,0
+4230 ,0
+4468 ,0
+4570 ,0
+4108 ,0
+3901 ,0
+4584 ,0
+4056 ,0
+4125 ,0
+718 ,1
+4136 ,0
+4535 ,0
+2991 ,1
+1675 ,1
+4020 ,0
+1536 ,1
+1964 ,1
+680 ,1
+1246 ,1
+988 ,1
+3915 ,0
+4026 ,0
+3961 ,0
+4135 ,0
+3803 ,0
+4076 ,0
+3985 ,0
+4124 ,0
+1139 ,1
+4305 ,0
+4209 ,0
+4372 ,0
+4390 ,0
+3269 ,0
+1141 ,1
+803 ,1
+2046 ,1
+326 ,1
+4307 ,0
+4079 ,0
+4040 ,0
+2626 ,1
+3415 ,1
+4241 ,0
+4087 ,0
+696 ,1
+1669 ,1
+4080 ,0
+4131 ,0
+4178 ,0
+4243 ,0
+1503 ,1
+3835 ,0
+3781 ,0
+3104 ,0
+4027 ,0
+3964 ,0
+763 ,1
+3263 ,0
+4246 ,0
+3726 ,0
+3596 ,1
+985 ,1
+3716 ,0
+4180 ,0
+4135 ,0
+1164 ,1
+1884 ,0
+3506 ,0
+4037 ,0
+3737 ,0
+824 ,1
+3072 ,1
+3544 ,1
+465 ,1
+3898 ,0
+3721 ,0
+4108 ,0
+1788 ,1
+3460 ,0
+2582 ,1
+3410 ,0
+3389 ,1
+2587 ,1
+4046 ,0
+990 ,1
+3339 ,0
+3880 ,0
+746 ,1
+3775 ,0
+3319 ,0
+3695 ,0
+3296 ,0
+1164 ,1
+3736 ,0
+3289 ,0
+3872 ,0
+2260 ,1
+2782 ,1
+3586 ,0
+3793 ,0
+3794 ,1
+3766 ,0
+5353 ,0
+3324 ,0
+2168 ,0
+3331 ,0
+2976 ,0
+3327 ,0
+3509 ,0
+3005 ,0
+3368 ,0
+3026 ,1
+842 ,1
+1490 ,1
+3197 ,1
+7495 ,0
+4665 ,0
+1342 ,1
+351 ,1
+4978 ,0
+8321 ,1
+339 ,1
+4179 ,1
+3170 ,1
+3723 ,0
+456 ,1
+8941 ,1
+5160 ,1
+8725 ,0
+8441 ,1
+5291 ,0
+2034 ,1
+6513 ,1
+5244 ,1
+7745 ,1
+4962 ,1
+3435 ,1
+4435 ,1
+6529 ,1
+469 ,1
+1891 ,1
+1487 ,1
+5541 ,1
+4810 ,1
+9184 ,0
+9218 ,0
+2535 ,1
+1657 ,1
+1965 ,1
+3468 ,1
+1091 ,1
+7606 ,1
+9193 ,0
+6111 ,1
+3005 ,1
+5352 ,0
+4896 ,1
+1017 ,1
+2961 ,1
+1798 ,1
+3966 ,1
+7326 ,1
+7801 ,1
+1023 ,0
+5245 ,0
+2905 ,1
+8183 ,0
+6575 ,1
+1653 ,1
+5671 ,0
+2671 ,1
+7041 ,1
+7967 ,1
+8805 ,0
+5337 ,0
+514 ,1
+5894 ,1
+3799 ,1
+2617 ,1
+5978 ,1
+4936 ,0
+7067 ,1
+8507 ,0
+6239 ,1
+5208 ,1
+1011 ,1
+5692 ,0
+5958 ,1
+6786 ,1
+932 ,1
+2044 ,1
+3856 ,1
+461 ,1
+3421 ,1
+356 ,1
+4339 ,0
+3584 ,1
+1277 ,1
+8735 ,0
+1844 ,1
+4930 ,1
+4188 ,1
+2501 ,1
+6197 ,1
+7248 ,1
+501 ,1
+8303 ,0
+7943 ,1
+5127 ,1
+4443 ,1
+8555 ,0
+1569 ,1
+3794 ,1
+7954 ,0
+7673 ,0
+6313 ,1
+5920 ,1
+1235 ,1
+4207 ,0
+3581 ,1
+5735 ,0
+1688 ,1
+3684 ,1
+7177 ,0
+3652 ,0
+2704 ,1
+7390 ,0
+4257 ,0
+6852 ,0
+2680 ,0
+4290 ,1
+5599 ,1
+2875 ,1
+5046 ,0
+3288 ,1
+3570 ,1
+6018 ,0
+4457 ,1
+6808 ,0
+1244 ,1
+2700 ,1
+5056 ,1
+505 ,1
+909 ,1
+1444 ,1
+1929 ,1
+2970 ,1
+1196 ,1
+7027 ,0
+7180 ,0
+7110 ,0
+3347 ,1
+4088 ,0
+3584 ,1
+2034 ,1
+1041 ,1
+6628 ,1
+2610 ,1
+2432 ,1
+7099 ,0
+7049 ,0
+7441 ,0
+4186 ,1
+1299 ,1
+1453 ,1
+2223 ,1
+2559 ,1
+3971 ,0
+3218 ,1
+5949 ,0
+639 ,1
+6982 ,0
+3540 ,1
+5120 ,1
+3826 ,0
+3808 ,0
+1920 ,1
+6393 ,1
+1670 ,1
+6636 ,0
+6837 ,1
+4403 ,0
+1730 ,0
+944 ,1
+6860 ,0
+2510 ,1
+1057 ,1
+1287 ,1
+5950 ,1
+2447 ,1
+712 ,1
+2583 ,0
+6359 ,0
+4382 ,0
+2963 ,1
+840 ,1
+3939 ,1
+1346 ,1
+4496 ,0
+592 ,1
+6974 ,0
+3666 ,0
+5282 ,0
+4850 ,1
+1695 ,1
+2494 ,1
+6905 ,0
+4590 ,1
+1280 ,1
+6690 ,0
+1544 ,1
+1132 ,1
+3758 ,1
+822 ,1
+6358 ,0
+3744 ,1
+6828 ,0
+3063 ,0
+6660 ,0
+1414 ,1
+651 ,1
+4562 ,1
+6950 ,0
+6184 ,1
+3047 ,0
+6889 ,0
+4372 ,1
+3650 ,1
+2432 ,1
+2090 ,1
+5215 ,1
+1219 ,1
+2501 ,1
+1365 ,1
+3848 ,0
+2981 ,1
+6567 ,0
+3328 ,1
+3911 ,1
+1317 ,1
+3894 ,1
+6053 ,1
+5958 ,0
+2204 ,1
+2287 ,1
+1719 ,1
+2749 ,1
+2577 ,0
+5698 ,0
+1371 ,1
+1275 ,1
+5288 ,1
+4252 ,1
+1085 ,1
+703 ,1
+6023 ,1
+6990 ,0
+3509 ,0
+764 ,1
+3049 ,1
+1360 ,1
+2708 ,1
+7351 ,0
+1244 ,1
+3880 ,1
+4734 ,1
+6239 ,0
+5619 ,0
+6315 ,0
+6404 ,0
+6261 ,0
+879 ,1
+6118 ,0
+1194 ,1
+4713 ,0
+3530 ,1
+4518 ,1
+674 ,1
+2399 ,1
+6329 ,0
+6050 ,0
+6083 ,0
+6048 ,0
+1445 ,1
+1062 ,1
+6015 ,0
+5301 ,0
+6269 ,0
+1946 ,1
+5923 ,0
+3537 ,1
+5678 ,0
+6185 ,0
+3877 ,1
+5933 ,0
+1437 ,1
+943 ,1
+1500 ,1
+3083 ,1
+6208 ,0
+1200 ,1
+1753 ,0
+911 ,1
+5981 ,0
+2107 ,0
+5826 ,0
+2317 ,1
+2118 ,1
+1759 ,1
+3724 ,1
+5159 ,0
+604 ,1
+5806 ,0
+4518 ,1
+5532 ,0
+2079 ,1
+5930 ,1
+3527 ,1
+6080 ,0
+4743 ,0
+1494 ,0
+1479 ,1
+3921 ,1
+2361 ,0
+5943 ,1
+2664 ,1
+5283 ,1
+5966 ,0
+6077 ,0
+5953 ,0
+5423 ,1
+575 ,1
+2394 ,1
+5603 ,0
+3806 ,0
+4961 ,0
+2257 ,1
+5224 ,1
+2993 ,1
+1967 ,1
+5638 ,0
+2454 ,1
+5782 ,0
+5762 ,0
+2415 ,1
+5637 ,0
+5688 ,0
+5221 ,0
+4348 ,0
+3258 ,1
+1134 ,1
+2545 ,1
+3558 ,1
+5696 ,0
+1582 ,1
+3176 ,0
+1864 ,1
+2375 ,1
+3674 ,0
+3851 ,1
+4183 ,0
+5049 ,1
+2591 ,1
+5920 ,0
+5530 ,0
+1037 ,1
+4518 ,1
+5631 ,0
+1690 ,1
+5362 ,1
+21 ,0
+4503 ,1
+1183 ,0
+5048 ,0
+5715 ,0
+2745 ,0
+5544 ,0
+2841 ,0
+5475 ,0
+5731 ,0
+1507 ,1
+2245 ,0
+5908 ,0
+806 ,1
+2613 ,1
+5736 ,0
+1740 ,1
+5936 ,0
+2527 ,1
+3696 ,0
+5237 ,1
+4273 ,1
+1848 ,0
+5592 ,1
+597 ,1
+1030 ,1
+861 ,1
+3675 ,1
+5865 ,0
+5354 ,0
+2479 ,1
+3752 ,1
+657 ,1
+5269 ,0
+461 ,1
+2322 ,0
+2153 ,1
+5758 ,0
+5615 ,0
+1793 ,0
+834 ,1
+3447 ,1
+3771 ,1
+2709 ,1
+3726 ,1
+819 ,1
+3615 ,0
+5718 ,0
+5910 ,0
+5469 ,0
+1118 ,1
+5362 ,0
+1298 ,0
+1110 ,1
+2699 ,1
+4042 ,1
+3838 ,1
+2502 ,0
+4777 ,1
+608 ,0
+5820 ,0
+3325 ,1
+3478 ,1
+5125 ,0
+5529 ,0
+547 ,1
+5812 ,0
+2632 ,1
+5260 ,0
+4930 ,0
+5553 ,0
+2415 ,0
+2422 ,1
+5637 ,0
+1559 ,1
+5550 ,1
+5528 ,0
+664 ,1
+4536 ,1
+5616 ,0
+3379 ,1
+1957 ,1
+5530 ,0
+5503 ,0
+4028 ,1
+3556 ,1
+2680 ,1
+2324 ,1
+5438 ,0
+582 ,1
+3359 ,1
+1447 ,0
+4265 ,1
+1628 ,1
+5617 ,0
+2821 ,1
+5706 ,0
+5414 ,0
+1569 ,1
+2909 ,1
+5391 ,0
+5559 ,0
+5453 ,0
+4557 ,0
+1476 ,0
+4514 ,0
+5283 ,0
+1302 ,1
+3221 ,1
+3561 ,1
+5598 ,0
+5354 ,0
+3939 ,1
+3355 ,1
+4340 ,1
+2982 ,1
+1298 ,1
+5349 ,0
+2296 ,0
+5422 ,0
+1019 ,1
+5275 ,0
+497 ,1
+5131 ,0
+5203 ,0
+5488 ,0
+5000 ,0
+1324 ,1
+5157 ,0
+3115 ,1
+1153 ,1
+3724 ,0
+3374 ,1
+3062 ,1
+1607 ,0
+3242 ,1
+3032 ,1
+1385 ,1
+5102 ,0
+3494 ,1
+5153 ,0
+2509 ,1
+2132 ,1
+5584 ,0
+5596 ,0
+5025 ,0
+5602 ,0
+1826 ,1
+2279 ,1
+5908 ,0
+2999 ,1
+2063 ,1
+2336 ,1
+2379 ,0
+1960 ,1
+3720 ,1
+3852 ,1
+2610 ,1
+1108 ,1
+4606 ,0
+1453 ,1
+5240 ,0
+3394 ,0
+2268 ,1
+414 ,1
+2547 ,1
+501 ,1
+4970 ,0
+5222 ,0
+5085 ,0
+4483 ,1
+4792 ,0
+3143 ,1
+3524 ,0
+1730 ,1
+1137 ,1
+1396 ,1
+1291 ,1
+1609 ,1
+425 ,1
+3069 ,1
+4732 ,0
+5049 ,0
+5116 ,0
+2886 ,1
+4906 ,0
+426 ,1
+1372 ,1
+3436 ,1
+3724 ,1
+2965 ,1
+2562 ,1
+757 ,1
+1550 ,1
+5147 ,0
+5305 ,0
+5192 ,0
+5092 ,0
+5460 ,0
+5090 ,0
+3784 ,0
+1609 ,1
+5405 ,0
+865 ,1
+4834 ,1
+461 ,1
+3773 ,1
+691 ,1
+5059 ,0
+3004 ,1
+1179 ,1
+3579 ,1
+76 ,0
+3060 ,1
+368 ,1
+5137 ,0
+3487 ,1
+3410 ,1
+2135 ,1
+4464 ,1
+5082 ,0
+1369 ,1
+4408 ,1
+3743 ,1
+928 ,1
+5066 ,0
+4620 ,1
+5040 ,0
+4955 ,0
+4970 ,0
+5114 ,0
+5386 ,0
+3204 ,1
+2105 ,0
+3233 ,1
+2872 ,1
+2242 ,1
+3528 ,0
+2923 ,1
+3011 ,0
+3042 ,0
+5262 ,0
+2787 ,0
+780 ,1
+3850 ,0
+4735 ,1
+4218 ,1
+3212 ,0
+2260 ,0
+3018 ,0
+5245 ,0
+1382 ,1
+2506 ,1
+1484 ,1
+1039 ,1
+4672 ,1
+4514 ,0
+921 ,1
+476 ,1
+2955 ,1
+2421 ,0
+1355 ,1
+3511 ,0
+291 ,1
+5194 ,0
+4588 ,0
+4885 ,1
+4466 ,1
+3708 ,0
+4923 ,0
+485 ,1
+1719 ,1
+1520 ,1
+1854 ,1
+4828 ,1
+4891 ,0
+2724 ,1
+4490 ,0
+1744 ,1
+3496 ,1
+564 ,1
+4895 ,0
+2981 ,0
+1262 ,1
+4869 ,0
+587 ,1
+3777 ,1
+5089 ,0
+76 ,1
+3309 ,0
+1029 ,1
+1391 ,1
+530 ,1
+1273 ,1
+4973 ,0
+954 ,1
+3699 ,1
+5296 ,0
+1536 ,1
+1975 ,1
+701 ,1
+1717 ,1
+4662 ,0
+5187 ,0
+3001 ,0
+3712 ,0
+1799 ,0
+3583 ,0
+3093 ,1
+4437 ,0
+1099 ,1
+4771 ,0
+926 ,1
+4816 ,0
+489 ,1
+3744 ,0
+60 ,0
+4532 ,0
+4312 ,0
+444 ,1
+4035 ,0
+4820 ,0
+2396 ,1
+4741 ,0
+2373 ,1
+3604 ,1
+1453 ,1
+4989 ,0
+1033 ,1
+5036 ,0
+2957 ,0
+760 ,0
+2668 ,1
+771 ,0
+3542 ,0
+1041 ,1
+4635 ,1
+4773 ,0
+4405 ,0
+751 ,0
+565 ,1
+4802 ,0
+3089 ,1
+3530 ,1
+3461 ,0
+4737 ,0
+232 ,0
+4616 ,1
+4697 ,0
+3530 ,1
+4767 ,0
+1662 ,1
+2707 ,1
+2149 ,1
+4056 ,0
+3341 ,0
+2733 ,1
+1617 ,1
+2434 ,1
+4807 ,0
+535 ,0
+4648 ,1
+790 ,1
+2775 ,1
+4912 ,0
+982 ,1
+4521 ,0
+4215 ,0
+4669 ,0
+3663 ,0
+4369 ,0
+3711 ,0
+4267 ,0
+3493 ,1
+4778 ,0
+3846 ,1
+1262 ,1
+1911 ,0
+4134 ,1
+4080 ,0
+969 ,1
+2489 ,0
+1278 ,1
+3114 ,1
+1145 ,1
+1051 ,1
+1887 ,1
+632 ,1
+3249 ,1
+4551 ,0
+4006 ,0
+3625 ,0
+4213 ,0
+667 ,1
+4522 ,0
+924 ,1
+4152 ,0
+630 ,1
+3140 ,1
+4263 ,0
+1915 ,0
+4524 ,0
+3792 ,1
+4059 ,1
+1907 ,1
+2116 ,1
+2073 ,1
+4441 ,0
+6334 ,1
+4333 ,1
+5758 ,1
+4306 ,1
+2127 ,1
+4295 ,1
+406 ,1
+1448 ,1
+4690 ,0
+1562 ,1
+6172 ,0
+6749 ,0
+1492 ,1
+6933 ,0
+4361 ,1
+5683 ,0
+2194 ,1
+5861 ,0
+5214 ,0
+1136 ,1
+5046 ,1
+6201 ,0
+5023 ,0
+5605 ,1
+6032 ,0
+6060 ,0
+642 ,0
+680 ,1
+860 ,0
+1141 ,1
+6112 ,0
+468 ,1
+5656 ,1
+5478 ,1
+2150 ,0
+1722 ,1
+5470 ,1
+5108 ,1
+5294 ,0
+1623 ,1
+3075 ,1
+4921 ,1
+730 ,0
+69 ,1
+4717 ,0
+319 ,1
+302 ,1
+2325 ,1
+1410 ,0
+3391 ,0
+3700 ,1
+2990 ,0
+3737 ,1
+2005 ,0
+3610 ,0
+5215 ,0
+1603 ,0
+3469 ,1
+4617 ,1
+1301 ,0
+1854 ,1
+2358 ,1
+769 ,1
+2366 ,1
+4046 ,0
+3133 ,0
+1030 ,1
+2688 ,1
+1669 ,0
+2248 ,1
+2330 ,1
+2310 ,0
+3653 ,0
+1758 ,1
+476 ,1
+3752 ,0
+2660 ,1
+135 ,0
+3684 ,1
+4361 ,0
+2542 ,1
+2768 ,0
+994 ,1
+1754 ,1
+1648 ,1
+3595 ,0
+2374 ,1
+4794 ,0
+1992 ,0
+1133 ,1
+713 ,0
+2024 ,0
+739 ,1
+1279 ,1
+1574 ,1
+3335 ,0
+902 ,1
+3528 ,0
+1703 ,1
+4614 ,0
+3549 ,1
+775 ,0
+1269 ,1
+2088 ,0
+3409 ,0
+737 ,0
+4488 ,0
+4434 ,1
+3836 ,0
+1064 ,1
+3536 ,0
+1061 ,1
+3779 ,0
+3 ,1
+2244 ,1
+2177 ,1
+2550 ,1
+1704 ,0
+3344 ,0
+3156 ,1
+1608 ,0
+3622 ,0
+2637 ,0
+3723 ,0
+1586 ,1
+2254 ,0
+441 ,1
+2933 ,0
+1827 ,1
+2037 ,0
+2730 ,0
+2276 ,0
+3352 ,0
+672 ,1
+3150 ,0
+836 ,1
+105 ,1
+2772 ,1
+2650 ,1
+1486 ,1
+1056 ,1
+3009 ,1
+2204 ,0
+1915 ,1
+1528 ,0
+2658 ,0
+1761 ,0
+185 ,0
+1996 ,0
+2570 ,0
+877 ,1
+427 ,0
+2401 ,0
+2411 ,1
+548 ,1
+2022 ,0
+1869 ,0
+1889 ,0
+1737 ,0
+516 ,1
+1582 ,0
+4588 ,1
+6160 ,0
+1136 ,1
+2551 ,1
+501 ,1
+1816 ,0
+1940 ,0
+1738 ,0
+1625 ,0
+1514 ,0
+1360 ,0
+1778 ,0
+1758 ,0
+1295 ,1
+1643 ,1
+2221 ,0
+1785 ,0
+595 ,1
+1393 ,0
+2061 ,0
+2345 ,0
+1805 ,0
+2135 ,1
+2394 ,0
+2143 ,1
+2565 ,0
+803 ,1
+2108 ,0
+2407 ,0
+2008 ,0
+1103 ,1
+2694 ,0
+2306 ,0
+940 ,1
+2561 ,0
+2144 ,0
+2253 ,0
+962 ,1
+2132 ,0
+2592 ,0
+2431 ,0
+1476 ,1
+2155 ,0
+2843 ,1
+2655 ,0
+2746 ,0
+2605 ,0
+2847 ,1
+2404 ,0
+2654 ,0
+3061 ,0
+1954 ,1
+2578 ,1
+3184 ,0
+3664 ,0
+913 ,1
+274 ,1
+3510 ,0
+1315 ,1
+3675 ,0
+3580 ,0
+3056 ,0
+2816 ,0
+2876 ,0
+2598 ,0
+3236 ,0
+2693 ,0
+3150 ,0
+2954 ,0
+3438 ,0
+3242 ,0
+3062 ,0
+1570 ,1
+3477 ,0
+3245 ,0
+3349 ,0
+2951 ,0
+3038 ,0
+125 ,1
+2732 ,0
+2737 ,0
+3066 ,0
+3164 ,0
+3496 ,0
+3303 ,0
+3706 ,0
+152 ,1
+3706 ,0
+3567 ,0
+3293 ,0
+3308 ,0
+3672 ,0
+3064 ,0
+3401 ,0
+3469 ,0
+3720 ,0
+1955 ,1
+3597 ,0
+3538 ,0
+3728 ,0
+3123 ,1
+3467 ,0
+1125 ,1
+1708 ,1
+3526 ,0
+905 ,1
+1941 ,1
+3297 ,0
+3698 ,0
+3559 ,0
+1001 ,1
+4217 ,0
+3851 ,0
+2665 ,1
+897 ,1
+4130 ,0
+4188 ,0
+4060 ,0
+1808 ,1
+3089 ,1
+4042 ,0
+3948 ,0
+715 ,1
+802 ,1
+4120 ,0
+3520 ,1
+2724 ,1
+4157 ,0
+4377 ,0
+4134 ,0
+3079 ,1
+4150 ,0
+3884 ,0
+570 ,1
+1339 ,1
+3869 ,0
+4333 ,0
+4136 ,0
+4284 ,0
+3254 ,1
+4493 ,0
+4673 ,0
+1393 ,1
+658 ,1
+4442 ,0
+3646 ,0
+4206 ,0
+1579 ,1
+4374 ,1
+4415 ,0
+4231 ,0
+4515 ,0
+4662 ,0
+4414 ,1
+4479 ,0
+1906 ,1
+4223 ,0
+4554 ,0
+4726 ,0
+1484 ,1
+3751 ,1
+146 ,0
+4403 ,1
+4357 ,0
+4308 ,1
+2624 ,1
+1400 ,1
+2265 ,1
+3213 ,1
+4639 ,0
+502 ,1
+2280 ,1
+2928 ,1
+4903 ,0
+4981 ,0
+1672 ,1
+1378 ,1
+1575 ,1
+4960 ,0
+4608 ,0
+2712 ,1
+4963 ,0
+4956 ,0
+4548 ,0
+1650 ,1
+1601 ,1
+1896 ,1
+5212 ,0
+3348 ,1
+5478 ,0
+4883 ,0
+5328 ,0
+4736 ,0
+5190 ,0
+5124 ,0
+4804 ,0
+718 ,1
+698 ,1
+5133 ,0
+4464 ,1
+5525 ,0
+1760 ,1
+1530 ,1
+647 ,1
+1071 ,1
+1799 ,1
+5766 ,0
+1041 ,1
+452 ,1
+935 ,0
+3121 ,1
+4949 ,1
+1289 ,1
+2045 ,1
+5831 ,0
+5704 ,0
+635 ,1
+5255 ,0
+6246 ,0
+1332 ,1
+5560 ,0
+5611 ,0
+5866 ,0
+639 ,1
+5866 ,0
+1486 ,1
+866 ,1
+5998 ,0
+3270 ,1
+4759 ,1
+2909 ,1
+812 ,1
+3800 ,1
+205 ,1
+2354 ,1
+5977 ,0
+2482 ,1
+5906 ,0
+1342 ,1
+5279 ,1
+2587 ,1
+5893 ,0
diff --git a/datasets/pbc2.csv b/datasets/pbc2.csv
index 87bcc36a..11625db8 100644
--- a/datasets/pbc2.csv
+++ b/datasets/pbc2.csv
@@ -1943,4 +1943,4 @@
"1942","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",0.564012703975468,"No","No","No","No edema",5.5,NA,3.2,1678,124,189,10.9,2,0
"1943","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",1.06779104150695,"No","No","No","No edema",7.4,312,3.56,1767,166,148,11.7,2,0
"1944","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.12189245427664,"No","No","Yes","edema no diuretics",16.3,688,3.34,2460,173,138,13,2,0
-"1945","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.94327017851276,"No","No","Yes","edema no diuretics",23.4,741,3.42,3012,200,128,13.4,3,0
\ No newline at end of file
+"1945","312",3.98915781404008,"alive","placebo",33.1535428759172,"female",2.94327017851276,"No","No","Yes","edema no diuretics",23.4,741,3.42,3012,200,128,13.4,3,0
diff --git a/datasets/pbc_dataset.py b/datasets/pbc_dataset.py
index e41a7347..66f5de03 100644
--- a/datasets/pbc_dataset.py
+++ b/datasets/pbc_dataset.py
@@ -1,6 +1,7 @@
import os
-import pandas as pd
+
import numpy as np
+import pandas as pd
from sklearn.impute import SimpleImputer
@@ -12,7 +13,8 @@ def load_pbc2_dataset():
with missing values imputed. The function also constructs the outcome variable
for competing risks analysis and returns time-to-event data.
- Returns:
+ Returns
+ -------
tuple: A tuple containing the following elements:
- x (numpy.ndarray): Combined feature matrix with categorical features
one-hot encoded and continuous features imputed.
@@ -25,56 +27,56 @@ def load_pbc2_dataset():
- n_continuous (int): Number of continuous features.
- feature_ranges (list of tuple): List of (min, max) ranges for each feature.
"""
-
-
file_path = os.path.join(os.path.dirname(__file__), "pbc2.csv")
data = pd.read_csv(file_path)
- data = data.drop(columns=['id', 'sno.', 'year', 'status2'], axis=1)
-
+ data = data.drop(columns=["id", "sno.", "year", "status2"], axis=1)
event_type = np.where(
- data['status'] == 'dead', 1,
- np.where(data['status'] == 'transplanted', 2, 0)
+ data["status"] == "dead", 1, np.where(data["status"] == "transplanted", 2, 0)
)
-
cont_cols = [
- 'age', 'serBilir', 'serChol', 'albumin', 'alkaline',
- 'SGOT', 'platelets', 'prothrombin', 'histologic'
+ "age",
+ "serBilir",
+ "serChol",
+ "albumin",
+ "alkaline",
+ "SGOT",
+ "platelets",
+ "prothrombin",
+ "histologic",
]
- x_cont = data[cont_cols].replace('NA', np.nan).astype(float)
+ x_cont = data[cont_cols].replace("NA", np.nan).astype(float)
-
- mean_imputer = SimpleImputer(strategy='mean')
+ mean_imputer = SimpleImputer(strategy="mean")
x_cont_imputed = mean_imputer.fit_transform(x_cont)
- cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0),
- np.nanmax(x_cont_imputed, axis=0)))
+ cont_feature_ranges = list(
+ zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0))
+ )
-
- cat_cols = ['sex', 'drug', 'ascites', 'hepatomegaly', 'spiders', 'edema']
- x_cat = data[cat_cols].fillna('missing')
+ cat_cols = ["sex", "drug", "ascites", "hepatomegaly", "spiders", "edema"]
+ x_cat = data[cat_cols].fillna("missing")
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat_imputed = cat_imputer.fit_transform(x_cat)
x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols)
x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True)
-
- x_cat_encoded = x_cat_encoded.loc[:, ~x_cat_encoded.columns.str.contains('_missing')]
+ x_cat_encoded = x_cat_encoded.loc[
+ :, ~x_cat_encoded.columns.str.contains("_missing")
+ ]
cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1]
-
x = np.hstack([x_cat_encoded.values, x_cont_imputed])
feature_names = np.concatenate([x_cat_encoded.columns, cont_cols])
n_continuous = len(cont_cols)
feature_ranges = cat_feature_ranges + cont_feature_ranges
-
- t = data['years'].astype(float).values * 365.25
+ t = data["years"].astype(float).values * 365.25
valid = ~np.isnan(t)
return (
@@ -83,10 +85,9 @@ def load_pbc2_dataset():
event_type[valid],
feature_names,
n_continuous,
- feature_ranges
+ feature_ranges,
)
if __name__ == "__main__":
x, t, e, feature_names, n_continuous, feature_ranges = load_pbc2_dataset()
-
\ No newline at end of file
diff --git a/datasets/support2.csv b/datasets/support2.csv
index b090d1b2..07c9d5c2 100644
--- a/datasets/support2.csv
+++ b/datasets/support2.csv
@@ -9103,4 +9103,4 @@ sno,age,death,sex,hospdead,slos,d.time,dzgroup,dzclass,num.co,edu,income,scoma,c
9102, 55.15399,0,female,0, 29, 347,Coma,Coma,1,11,, 41, 35377.000, 23558.5000, 22131.04690,18.000000,white,25.7968750, 31,0.5539550780,0.4859619140, 1,0,0,no,0.500000000,5.000000e-01,no dnr, 29, 43.0,, 0.000000, 8,38.59375,218.50000,,, 5.89941406,135,7.289062, 190.000000, 49.000000, 0.000000,, 0,,0.0000000
9103, 70.38196,0,male,0, 8, 346,ARF/MOSF w/Sepsis,ARF/MOSF,1,,, 0, 46564.000, 31409.0156, 31131.25000,23.000000,white,22.6992188, 39,0.7419433590,0.6608886720, 18,0,0,no,0.899999619,7.999997e-01,no dnr, 8,111.0, 8.39843750, 83.000000,24,36.69531,180.00000,, 0.39996338, 2.69970703,139,7.379883, 189.000000, 60.000000,3900.000000,,,,2.5253906
9104, 47.01999,1,male,1, 7, 7,MOSF w/Malig,ARF/MOSF,1,13,, 0, 58439.000,,,35.500000,white,40.1953125, 51,0.1779785160,0.0919952393, 22,0,0,yes,0.089999974,8.999997e-02,dnr after sadm, 5, 99.0, 7.59960938,110.000000,24,36.39844,428.56250, 1.1999512, 0.39996338, 3.50000000,135,7.469727, 246.000000, 55.000000,,, 0,<2 mo. follow-up,0.0000000
-9105, 81.53894,1,female,0, 12, 198,ARF/MOSF w/Sepsis,ARF/MOSF,1, 8,$11-$25k, 0, 15604.000, 10605.7578, 12146.15620,13.500000,white,18.0976562, 7,0.8328857420,0.7769775390, 1,1,0,no,,,no dnr, 12, 75.0, 8.59960938, 69.000000,24,36.19531,230.40625, 4.5000000, 0.59997559, 1.19995117,137,7.289062, 187.000000, 15.000000,, 0,,no(M2 and SIP pres),0.4947510
\ No newline at end of file
+9105, 81.53894,1,female,0, 12, 198,ARF/MOSF w/Sepsis,ARF/MOSF,1, 8,$11-$25k, 0, 15604.000, 10605.7578, 12146.15620,13.500000,white,18.0976562, 7,0.8328857420,0.7769775390, 1,1,0,no,,,no dnr, 12, 75.0, 8.59960938, 69.000000,24,36.19531,230.40625, 4.5000000, 0.59997559, 1.19995117,137,7.289062, 187.000000, 15.000000,, 0,,no(M2 and SIP pres),0.4947510
diff --git a/datasets/support_dataset.py b/datasets/support_dataset.py
index 2146e92f..ce1aae67 100644
--- a/datasets/support_dataset.py
+++ b/datasets/support_dataset.py
@@ -1,10 +1,11 @@
import os
-import pandas as pd
+
import numpy as np
+import pandas as pd
+
# Enable and import IterativeImputer (MICE) from scikit-learn:
from sklearn.experimental import enable_iterative_imputer # noqa
-from sklearn.impute import IterativeImputer, SimpleImputer
-
+from sklearn.impute import SimpleImputer
def load_support_dataset():
@@ -12,9 +13,11 @@ def load_support_dataset():
Load and preprocess the SUPPORT dataset.
This function reads the SUPPORT dataset from a CSV file, imputes missing values,
encodes categorical features, and constructs the outcome variable for survival analysis.
- It returns the processed features, time-to-event data, event types, feature names,
+ It returns the processed features, time-to-event data, event types, feature names,
the number of continuous features, and feature ranges.
- Returns:
+
+ Returns
+ -------
tuple: A tuple containing:
- x (numpy.ndarray): Combined array of processed categorical and continuous features.
- t (numpy.ndarray): Time-to-event data with a +1 offset.
@@ -22,7 +25,9 @@ def load_support_dataset():
- feature_names (numpy.ndarray): Array of feature names.
- n_continuous (int): Number of continuous features.
- feature_ranges (list): List of tuples representing the range (min, max) for each feature.
- Notes:
+
+ Notes
+ -----
- The dataset file "support2.csv" must be located in the same directory as this script.
- Continuous features are imputed using the median strategy if missing values are present.
- Categorical features are imputed using the most frequent strategy if missing values are present.
@@ -30,72 +35,87 @@ def load_support_dataset():
- The outcome variable is constructed using the 'ca', 'dzgroup', and 'death' columns.
- Time-to-event data ('d.time') is offset by +1 to avoid zero follow-up times.
"""
-
-
file_path = os.path.join(os.path.dirname(__file__), "support2.csv")
data = pd.read_csv(file_path)
print(data.columns)
-
+ is_cancer = data["ca"].astype(str).str.lower().str.contains("meta") | data[
+ "dzgroup"
+ ].astype(str).str.lower().str.contains("cancer")
+ event_type = np.where(data["death"] == 1, np.where(is_cancer, 1, 2), 0)
-
- is_cancer = data['ca'].astype(str).str.lower().str.contains("meta") | \
- data['dzgroup'].astype(str).str.lower().str.contains("cancer")
- event_type = np.where(data['death'] == 1, np.where(is_cancer, 1, 2), 0)
-
-
-
-
cont_cols = [
- 'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', 'pafi', 'alb',
- 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', 'urine',
- 'scoma', 'aps', 'sps', 'adls', 'adlsc', 'charges', 'totcst', 'totmcst', 'avtisst'
+ "age",
+ "num.co",
+ "meanbp",
+ "wblc",
+ "hrt",
+ "resp",
+ "temp",
+ "pafi",
+ "alb",
+ "bili",
+ "crea",
+ "sod",
+ "ph",
+ "glucose",
+ "bun",
+ "urine",
+ "scoma",
+ "aps",
+ "sps",
+ "adls",
+ "adlsc",
+ "charges",
+ "totcst",
+ "totmcst",
+ "avtisst",
]
x_cont = data[cont_cols]
-
+
if x_cont.isnull().values.any():
- simp_imputer = SimpleImputer(strategy='most_frequent') #originally was median
+ simp_imputer = SimpleImputer(strategy="most_frequent") # originally was median
x_cont_imputed = simp_imputer.fit_transform(x_cont)
else:
x_cont_imputed = x_cont.values
- cont_feature_ranges = list(zip(np.nanmin(x_cont_imputed, axis=0),
- np.nanmax(x_cont_imputed, axis=0)))
+ cont_feature_ranges = list(
+ zip(np.nanmin(x_cont_imputed, axis=0), np.nanmax(x_cont_imputed, axis=0))
+ )
# -- Categorical features --
# Remove leakage fields 'ca', 'dzgroup', 'dzclass'
- cat_cols = ['sex', 'income', 'race', 'dnr', 'dementia', 'diabetes']
+ cat_cols = ["sex", "income", "race", "dnr", "dementia", "diabetes"]
x_cat = data[cat_cols]
if x_cat.isnull().values.any():
- cat_imputer = SimpleImputer(strategy='most_frequent')
+ cat_imputer = SimpleImputer(strategy="most_frequent")
x_cat_imputed = cat_imputer.fit_transform(x_cat)
else:
x_cat_imputed = x_cat.values
x_cat_df = pd.DataFrame(x_cat_imputed, columns=cat_cols)
-
+
x_cat_encoded = pd.get_dummies(x_cat_df, drop_first=True)
cat_feature_ranges = [(0.0, 1.0)] * x_cat_encoded.shape[1]
-
x = np.hstack([x_cat_encoded.values, x_cont_imputed])
feature_names = np.concatenate([x_cat_encoded.columns, cont_cols])
n_continuous = len(cont_cols)
feature_ranges = cat_feature_ranges + cont_feature_ranges
-
- t = data['d.time'].values
+ t = data["d.time"].values
valid = ~np.isnan(t)
-
+
print("Completed imputation of missing values.")
return (
x[valid],
- t[valid] + 1, # Add +1 offset to follow-up times
+ t[valid] + 1, # Add +1 offset to follow-up times
event_type[valid],
feature_names,
n_continuous,
- feature_ranges
+ feature_ranges,
)
+
if __name__ == "__main__":
x, t, e, feature_names, n_continuous, feature_ranges = load_support_dataset()
print("Feature names:", feature_names)
diff --git a/datasets/synthetic_dataset.py b/datasets/synthetic_dataset.py
index 84d5cb17..abd62a1b 100644
--- a/datasets/synthetic_dataset.py
+++ b/datasets/synthetic_dataset.py
@@ -1,21 +1,26 @@
+import os
+from typing import List, Tuple
+
import numpy as np
import pandas as pd
-from typing import Tuple, List
-import os
import torch
-def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]]:
+
+def load_synthetic_dataset() -> Tuple[
+ np.ndarray, np.ndarray, np.ndarray, List[str], int, List[tuple]
+]:
"""
Loads a synthetic competing risks dataset from a CSV file.
-
+
The CSV is expected to have a header with the following columns:
- time: observed time
- label: event indicator (0 for censored; >0 for event types)
- true_time: (optional) true time (unused here)
- true_label: (optional) true event label (unused here)
- feature1, feature2, ..., featureN: feature values
-
- Returns:
+
+ Returns
+ -------
X (np.ndarray): Feature matrix of shape (n_samples, n_features).
T_obs (np.ndarray): Observed times of shape (n_samples,).
e (np.ndarray): Event indicators of shape (n_samples,).
@@ -23,30 +28,24 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s
n_continuous (int): Total number of continuous features.
feature_ranges (List[tuple]): List of (min, max) tuples for each feature.
"""
-
use_gpu_compatible_dtype = torch.cuda.is_available()
-
-
-
-
+
file_path = os.path.join(os.path.dirname(__file__), "synthetic_comprisk.csv")
df = pd.read_csv(file_path)
-
-
- T_obs = df["time"].values .astype(np.float32)
+ T_obs = df["time"].values.astype(np.float32)
e = df["label"].values.astype(np.float32)
-
-
+
feature_columns = [col for col in df.columns if col.startswith("feature")]
- X = df[feature_columns].values .astype(np.float32)
-
-
+ X = df[feature_columns].values.astype(np.float32)
+
feature_names = feature_columns
n_continuous = X.shape[1]
-
- feature_ranges = [(float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)]
-
+
+ feature_ranges = [
+ (float(X[:, i].min()), float(X[:, i].max())) for i in range(n_continuous)
+ ]
+
return X, T_obs, e, feature_names, n_continuous, feature_ranges
@@ -61,10 +60,3 @@ def load_synthetic_dataset() -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[s
# For a quick inspection, print the first 5 rows (features, time, event)
print("First 5 rows:")
print(np.hstack([X[:5], T_obs[:5, None], e[:5, None]]))
-
-
-
-
-
-
-
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..e1360093
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,24 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+# NOTE: This is not generated from `sphinx-quickstart` and manually added
+serve:
+ sphinx-autobuild $(SOURCEDIR) $(BUILDDIR)
\ No newline at end of file
diff --git a/docs/api.md b/docs/api.md
new file mode 100644
index 00000000..432868f3
--- /dev/null
+++ b/docs/api.md
@@ -0,0 +1,47 @@
+# Python API Reference
+
+This section documents the Python API for CRISP-NAM.
+
+## Model
+
+::: crisp_nam.models.crisp_nam_model
+ options:
+ show_root_heading: true
+ members:
+ - CrispNamModel
+::: crisp_nam.models.deephit_model
+ options:
+ show_root_heading: true
+ filters:
+ - "!^_"
+ members:
+ - DeepHit
+
+ ## Metrics
+
+::: crisp_nam.metrics.calibration
+ options:
+ show_root_heading: true
+ members: true
+::: crisp_nam.metrics.discrimination
+ options:
+ show_root_heading: true
+ members: true
+::: crisp_nam.metrics.ipcw
+ options:
+ show_root_heading: true
+ members: true
+
+ ## Utilities
+::: crisp_nam.utils.plotting
+ options:
+ show_root_heading: true
+ members: true
+::: crisp_nam.utils.risk_cif
+ options:
+ show_root_heading: true
+ members: true
+::: crisp_nam.utils.loss
+ options:
+ show_root_heading: true
+ members: true
\ No newline at end of file
diff --git a/docs/assets/favicon-48x48.svg b/docs/assets/favicon-48x48.svg
new file mode 100644
index 00000000..3cd92e57
--- /dev/null
+++ b/docs/assets/favicon-48x48.svg
@@ -0,0 +1,9 @@
+
+
diff --git a/docs/assets/favicon.ico b/docs/assets/favicon.ico
new file mode 100644
index 00000000..30762370
Binary files /dev/null and b/docs/assets/favicon.ico differ
diff --git a/docs/assets/launch.png b/docs/assets/launch.png
new file mode 100644
index 00000000..c9a3e06f
Binary files /dev/null and b/docs/assets/launch.png differ
diff --git a/docs/assets/vector-logo.svg b/docs/assets/vector-logo.svg
new file mode 100644
index 00000000..8dd76b5b
--- /dev/null
+++ b/docs/assets/vector-logo.svg
@@ -0,0 +1,172 @@
+
+
+
+
diff --git a/docs/blog.md b/docs/blog.md
new file mode 100644
index 00000000..8ceecd18
--- /dev/null
+++ b/docs/blog.md
@@ -0,0 +1,3 @@
+There is a blog explaining the architecture of CRISP-NAM and showcasing the results.
+
+Head [here](blog/index.html) to read it!
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 00000000..76614c1e
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,28 @@
+# CRISP-NAM: Competing Risks for Interpretable Survival Analysis using Neural Additive Models
+
+This repository contains research code for the paper: [CRISP-NAM: Competing Risks Interpretable Survival
+Prediction with Neural Additive Models](https://ceur-ws.org/Vol-4059/paper5.pdf).
+It includes the Python code for the following:
+
+- Models: `CRISP-NAM` and `DeepHIT`
+- Data loading utilities for 4 datasets: Framingham, PBC, Support2, Synthetic
+- Training scripts: Standard training, Hyperparameter optimization via Optuna, Nested cross validation
+- Metrics: Loss and risk functions for survival analysis.
+- Plotting: Feature importance and Shape functions for interpretability.
+
+## PyPI package
+The core files of research: models, metrics and plotting utilities.
+
+## Installation
+You can install the package via the following pip command:
+```bash
+pip install crisp_nam
+```
+
+## Citation
+> @inproceedings{ramachandram2025crispnam,
+title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models},
+author={Ramachandram, Dhanesh and Raval, Ananya},
+booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain},
+year={2025}
+}
\ No newline at end of file
diff --git a/docs/overrides/partials/copyright.html b/docs/overrides/partials/copyright.html
new file mode 100644
index 00000000..776166c1
--- /dev/null
+++ b/docs/overrides/partials/copyright.html
@@ -0,0 +1,22 @@
+
+
+ {% if config.copyright %}
+
{{ config.copyright }}
+ {% endif %} {% if not config.extra.generator == false %} Made with
+
+ Material for MkDocs
+
+ {% endif %}
+
diff --git a/docs/overrides/partials/logo.html b/docs/overrides/partials/logo.html
new file mode 100644
index 00000000..2ed1c763
--- /dev/null
+++ b/docs/overrides/partials/logo.html
@@ -0,0 +1,5 @@
+{% if config.theme.logo %}
+
+{% else %}
+
+{% endif %}
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
new file mode 100644
index 00000000..cda9da97
--- /dev/null
+++ b/docs/stylesheets/extra.css
@@ -0,0 +1,235 @@
+[data-md-color-primary="vector"] {
+ --md-primary-fg-color: #eb088a;
+ --md-primary-fg-color--light: #f252a5;
+ --md-primary-fg-color--dark: #b00068;
+ --md-primary-bg-color: hsla(0, 0%, 100%, 1);
+ --md-primary-bg-color--light: hsla(0, 0%, 100%, 0.7);
+}
+
+[data-md-color-primary="black"] {
+ --md-primary-fg-color: #181818;
+ --md-primary-fg-color--light: #f252a5;
+ --md-primary-fg-color--dark: #b00068;
+ --md-primary-bg-color: #eb088a;
+}
+
+[data-md-color-accent="vector-teal"] {
+ --md-accent-fg-color: #48c0d9;
+ --md-accent-fg-color--transparent: #526cfe1a;
+ --md-accent-bg-color: #fff;
+ --md-accent-bg-color--light: #ffffffb3;
+}
+
+[data-md-color-scheme="slate"][data-md-color-primary="black"] {
+ --md-typeset-a-color: #eb088a;
+}
+
+[data-md-color-scheme="default"] {
+ /* Default light mode styling */
+}
+
+[data-md-color-scheme="slate"] {
+ --md-typeset-a-color: #eb088a;
+ /* Dark mode styling */
+}
+
+/* Vector logo css styling to match overrides/partial/copyright.html */
+.md-footer-vector {
+ display: flex;
+ align-items: center;
+ padding: 0 0.6rem;
+}
+
+.md-footer-vector img {
+ height: 24px; /* Reduce height to a fixed value */
+ width: auto; /* Maintain aspect ratio */
+ transition: opacity 0.25s;
+ opacity: 0.7;
+}
+
+.md-footer-vector img:hover {
+ opacity: 1;
+}
+
+/* Make the inner footer grid elements distribute evenly */
+.md-footer-meta__inner {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+
+/* To make socials and Vector logo not stack when viewing on mobile */
+@media screen and (max-width: 76.234375em) {
+ .md-footer-meta__inner.md-grid {
+ flex-direction: row;
+ justify-content: space-between;
+ align-items: center;
+ }
+
+ .md-copyright,
+ .md-social {
+ width: auto;
+ max-width: 49%;
+ }
+
+ /* Prevent margin that causes stacking */
+ .md-social {
+ margin: 0;
+ }
+}
+
+/* Reduce margins for h2 when using grid cards */
+.grid.cards h2 {
+ margin-top: 0; /* Remove top margin completely in cards */
+ margin-bottom: 0.5rem; /* Smaller bottom margin in cards */
+}
+
+.vector-icon {
+ color: #eb088a;
+ opacity: 0.7;
+ margin-right: 0.2em;
+}
+
+/* Version selector styling - Material theme */
+
+/* Version selector container */
+.md-version {
+ position: relative;
+ display: inline-block;
+ margin-left: 0.25rem;
+}
+
+/* Current version button styling */
+.md-version__current {
+ display: inline-flex;
+ align-items: center;
+ font-size: 0.7rem;
+ font-weight: 600;
+ color: var(--md-primary-bg-color);
+ padding: 0.4rem 0.8rem;
+ margin: 0.4rem 0;
+ background-color: rgba(255, 255, 255, 0.1);
+ border-radius: 4px;
+ border: 1px solid rgba(255, 255, 255, 0.2);
+ cursor: pointer;
+ transition: all 0.15s ease-in-out;
+}
+
+/* Hover effect for current version button */
+.md-version__current:hover {
+ background-color: rgba(255, 255, 255, 0.2);
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
+}
+
+/* Down arrow for version dropdown */
+.md-version__current:after {
+ display: inline-block;
+ margin-left: 0.5rem;
+ content: "";
+ vertical-align: middle;
+ border-top: 0.3em solid;
+ border-right: 0.3em solid transparent;
+ border-bottom: 0;
+ border-left: 0.3em solid transparent;
+}
+
+/* Dropdown menu */
+.md-version__list {
+ position: absolute;
+ top: 100%;
+ left: 0;
+ z-index: 10;
+ min-width: 125%;
+ margin: 0.1rem 0 0;
+ padding: 0;
+ background-color: var(--md-primary-fg-color);
+ border-radius: 4px;
+ box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2);
+ opacity: 0;
+ visibility: hidden;
+ transform: translateY(-8px);
+ transition: all 0.2s ease;
+}
+
+/* Show dropdown when parent is hovered */
+.md-version:hover .md-version__list {
+ opacity: 1;
+ visibility: visible;
+ transform: translateY(0);
+}
+
+/* Version list items */
+.md-version__item {
+ list-style: none;
+ padding: 0;
+}
+
+/* Version links */
+.md-version__link {
+ display: block;
+ padding: 0.5rem 1rem;
+ font-size: 0.75rem;
+ color: var(--md-primary-bg-color);
+ transition: background-color 0.15s;
+ text-decoration: none;
+}
+
+/* Version link hover */
+.md-version__link:hover {
+ background-color: var(--md-primary-fg-color--dark);
+ text-decoration: none;
+}
+
+/* Active version in dropdown */
+.md-version__link--active {
+ background-color: var(--md-accent-fg-color);
+ color: var(--md-accent-bg-color);
+ font-weight: 700;
+}
+
+/* For the Material selector */
+.md-header__option {
+ display: flex;
+ align-items: center;
+}
+
+/* Version selector in Material 9.x */
+.md-select {
+ position: relative;
+ margin-left: 0.5rem;
+}
+
+.md-select__label {
+ font-size: 0.7rem;
+ font-weight: 600;
+ color: var(--md-primary-bg-color);
+ cursor: pointer;
+ padding: 0.4rem 0.8rem;
+ background-color: rgba(255, 255, 255, 0.1);
+ border-radius: 4px;
+ border: 1px solid rgba(255, 255, 255, 0.2);
+ transition: all 0.15s ease-in-out;
+}
+
+.md-select__label:hover {
+ background-color: rgba(255, 255, 255, 0.2);
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
+}
+
+/* Version selector in Material 9.2+ */
+.md-header__button.md-select {
+ display: inline-flex;
+ align-items: center;
+ margin: 0 0.8rem;
+}
+
+/* For Material 9.x+ with specific version selector */
+.md-typeset .md-version-warn {
+ padding: 0.6rem 1rem;
+ margin: 1.5rem 0;
+ background-color: rgba(235, 8, 138, 0.1);
+ border-left: 4px solid #eb088a;
+ border-radius: 0.2rem;
+ color: var(--md-default-fg-color);
+ font-size: 0.8rem;
+}
diff --git a/docs/usage.md b/docs/usage.md
new file mode 100644
index 00000000..e406fcb6
--- /dev/null
+++ b/docs/usage.md
@@ -0,0 +1,72 @@
+Example code to show how to use two classes in the package.
+
+## Use of `CrispNamModel` model class
+
+```python
+import torch
+
+from crisp_nam.models import CrispNamModel
+
+# Example usage and testing
+if __name__ == "__main__":
+
+ # Generate some test data
+ torch.manual_seed(42)
+ test_data = torch.randn(100, 5)
+
+ # Create model with L2 normalized projections
+ model = CrispNamModel(
+ num_features=5,
+ num_competing_risks=3,
+ hidden_sizes=[32, 32],
+ normalize_projections=True,
+ )
+
+ # Forward pass
+ risk_predictions = model(test_data)
+
+ print("First 5 risk predictions:", risk_predictions[:5])
+
+ #Calculate feature importance
+ feature_importance = model.calculate_feature_importance(test_data)
+ print("First 5 feature importance values:", feature_importance)
+
+ # Analyze projection weights
+ print('Analyzing projection weights...')
+ projection_weights = model.analyze_projection_weights()
+```
+
+## Use of `DeepHIT` model class
+
+```python
+import torch
+
+from crisp_nam.models.deephit_model import DeepHit
+
+# Example usage and testing
+if __name__ == "__main__":
+
+ # Generate some test data
+ torch.manual_seed(42)
+ test_data = torch.randn(100, 5)
+
+ input_dims = {
+ 'x_dim': test_data.shape[1],
+ 'num_Event': 2,
+ 'num_Category': 100
+ }
+
+ network_settings = {
+ 'h_dim_shared': 128,
+ 'h_dim_CS': 32,
+ 'num_layers_shared': 1,
+ 'num_layers_CS': 2,
+ 'active_fn': 'tanh',
+ 'keep_prob': 1.0 - 0.3 #1.0 - dropout_rate
+ }
+
+ model = DeepHit(input_dims, network_settings)
+
+ out = model.predict(test_data)
+ print("Output shape:", out.shape)
+```
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 00000000..193ed034
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,99 @@
+extra_css:
+ - stylesheets/extra.css
+extra:
+ generator: false
+ social:
+ - icon: fontawesome/brands/discord
+ link: 404.html
+ - icon: fontawesome/brands/github
+ link: https://github.com/VectorInstitute/crisp-nam
+ version:
+ provider: mike
+ default: latest
+markdown_extensions:
+ - attr_list
+ - admonition
+ - md_in_html
+ - pymdownx.highlight:
+ anchor_linenums: true
+ line_spans: __span
+ pygments_lang_class: true
+ - pymdownx.inlinehilite
+ - pymdownx.details
+ - pymdownx.snippets
+ - pymdownx.superfences
+ - pymdownx.emoji:
+ emoji_index: !!python/name:material.extensions.emoji.twemoji
+ emoji_generator: !!python/name:material.extensions.emoji.to_svg
+ - toc:
+ permalink: true
+ - meta
+ - footnotes
+nav:
+ - Home: index.md
+ - API Reference: api.md
+ - Usage: usage.md
+ - Blog: blog.md
+plugins:
+ - search
+ - mike:
+ version_selector: true
+ css_dir: stylesheets
+ canonical_version: latest
+ alias_type: symlink
+ deploy_prefix: ''
+ - mkdocstrings:
+ default_handler: python
+ handlers:
+ python:
+ paths: [../crisp_nam]
+ options:
+ docstring_style: numpy
+ members_order: source
+ separate_signature: true
+ show_overloads: true
+ show_submodules: true
+ show_root_heading: false
+ show_root_full_path: true
+ show_root_toc_entry: false
+ show_symbol_type_heading: true
+ show_symbol_type_toc: true
+repo_url: https://github.com/VectorInstitute/crisp-nam
+repo_name: VectorInstitute/crisp-nam
+site_name: CRISP-NAM Documentation
+docs_dir: docs
+theme:
+ name: material
+ custom_dir: docs/overrides
+ favicon: assets/favicon-48x48.svg
+ features:
+ - content.code.annotate
+ - content.code.copy
+ - navigation.footer
+ - navigation.indexes
+ - navigation.instant
+ - navigation.tabs
+ - navigation.tabs.sticky
+ - navigation.top
+ - search.suggest
+ - search.highlight
+ - toc.follow
+ icon:
+ repo: fontawesome/brands/github
+ logo: assets/vector-logo.svg
+ logo_footer: assets/vector-logo.svg
+ palette:
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ primary: vector
+ accent: vector-teal
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ primary: black
+ accent: vector-teal
+ toggle:
+ icon: material/brightness-4
+ name: Switch to light mode
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 1308ddbd..64364dbe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "crisp-nam"
-version = "0.1.0"
+version = "0.1.1"
description = "CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models"
readme = "README.md"
requires-python = ">=3.10"
@@ -12,6 +12,7 @@ dependencies = [
"configargparse>=1.7",
"lifelines>=0.30.0",
"matplotlib>=3.10.1",
+ "mypy>=1.19.1",
"openpyxl>=3.1.5",
"optuna>=4.3.0",
"pandas>=2.2.3",
@@ -22,5 +23,104 @@ dependencies = [
"torch>=2.7.0",
]
+[dependency-groups]
+docs = [
+ "mkdocs>=1.5.3",
+ "mkdocs-material>=9.5.12",
+ "mkdocstrings>=0.24.1",
+ "mkdocstrings-python>=1.8.0",
+ "pymdown-extensions>=10.7.1",
+ "mike>=2.0.0",
+ "click<=8.2.1",
+]
+
+[project.urls]
+Homepage = "https://github.com/VectorInstitute/crisp-nam"
+
+[tool.mypy]
+follow_imports = "normal"
+ignore_missing_imports = false
+install_types = true
+pretty = true
+non_interactive = true
+allow_untyped_defs = false
+no_implicit_optional = true
+check_untyped_defs = true
+namespace_packages = true
+explicit_package_bases = true
+warn_unused_configs = true
+allow_subclassing_any = false
+allow_untyped_calls = false
+allow_incomplete_defs = false
+allow_untyped_decorators = false
+warn_redundant_casts = true
+warn_unused_ignores = true
+implicit_reexport = false
+strict_equality = true
+extra_checks = true
+mypy_path = "crisp_nam"
+
+[tool.ruff]
+include = ["crisp_nam/*.py", "pyproject.toml", "*.ipynb"]
+extend-exclude = ["*.csv", "*.json", "datasets/", "data_utils/", "training_scripts/"]
+line-length = 88
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+docstring-code-format = true
+
+[tool.ruff.lint]
+select = [
+ "A", # flake8-builtins
+ "B", # flake8-bugbear
+ "COM", # flake8-commas
+ "C4", # flake8-comprehensions
+ "RET", # flake8-return
+ "SIM", # flake8-simplify
+ "ICN", # flake8-import-conventions
+ "Q", # flake8-quotes
+ "RSE", # flake8-raise
+ "D", # pydocstyle
+ "E", # pycodestyle
+ "F", # pyflakes
+ "I", # isort
+ "W", # pycodestyle
+ "N", # pep8-naming
+ "ERA", # eradicate
+ "PL", # pylint
+]
+
+fixable = ["A", "B", "COM", "C4", "RET", "SIM", "ICN", "Q", "RSE", "D", "E", "F", "I", "W", "N", "ERA", "PL"]
+ignore = [
+ "B905", # `zip()` without an explicit `strict=` parameter
+ "E501", # line too long
+ "D203", # 1 blank line required before class docstring
+ "D213", # Multi-line docstring summary should start at the second line
+ "PLR2004", # Replace magic number with named constant
+ "PLR0913", # Too many arguments
+ "COM812", # Missing trailing comma
+ "ERA001", # Found commented-out code (too many false positives with math comments)
+ "A001", # Ignore variable `input` is shadowing a Python builtin (common for torch)
+ "A002", # Ignore variable `input` is shadowing a Python builtin in function (common for torch)
+ "D301", # r-strings for docstrings with backslashes
+]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.pep8-naming]
+ignore-names = ["X*", "setUp"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+[tool.ruff.lint.pycodestyle]
+max-doc-length = 88
+
[tool.hatch.build.targets.sdist]
-only-include = ["crisp_nam"]
+only-include = ["crisp_nam"]
\ No newline at end of file
diff --git a/results/best_params/best_params_framingham_deephit.yaml b/results/best_params/best_params_framingham_deephit.yaml
index 92e10146..27f8d8e9 100644
--- a/results/best_params/best_params_framingham_deephit.yaml
+++ b/results/best_params/best_params_framingham_deephit.yaml
@@ -22,4 +22,4 @@ active_fn: tanh
# General parameters
dropout_rate: 0.3
seed: 42
-n_folds: 5
\ No newline at end of file
+n_folds: 5
diff --git a/results/best_params/best_params_pbc_deephit.yaml b/results/best_params/best_params_pbc_deephit.yaml
index 7fd8f3b2..a14f4765 100644
--- a/results/best_params/best_params_pbc_deephit.yaml
+++ b/results/best_params/best_params_pbc_deephit.yaml
@@ -23,4 +23,4 @@ active_fn: tanh
# General parameters
dropout_rate: 0.3
seed: 42
-n_folds: 5
\ No newline at end of file
+n_folds: 5
diff --git a/results/best_params/best_params_support2_deephit.yaml b/results/best_params/best_params_support2_deephit.yaml
index 36844915..4d1df68d 100644
--- a/results/best_params/best_params_support2_deephit.yaml
+++ b/results/best_params/best_params_support2_deephit.yaml
@@ -27,4 +27,4 @@ active_fn: tanh
# General parameters
dropout_rate: 0.3
seed: 42
-n_folds: 5
\ No newline at end of file
+n_folds: 5
diff --git a/results/best_params/best_params_synthetic_deephit.yaml b/results/best_params/best_params_synthetic_deephit.yaml
index f794156c..4b19279b 100644
--- a/results/best_params/best_params_synthetic_deephit.yaml
+++ b/results/best_params/best_params_synthetic_deephit.yaml
@@ -1,19 +1,18 @@
-scaling: standard
-num_epochs: 250
-batch_size: 256
-learning_rate: 1.0e-3
-l2_reg: 1.0e-5
-patience: 10
-alpha: 1.0
-beta: 0.0
-gamma: 0.0
-h_dim_shared: 128
-h_dim_CS: 32
-num_layers_shared: 1
-num_layers_CS: 2
-num_categories: 100
-active_fn: tanh
-dropout_rate: 0.3
-seed: 42
-n_folds: 5
-
+scaling: standard
+num_epochs: 250
+batch_size: 256
+learning_rate: 1.0e-3
+l2_reg: 1.0e-5
+patience: 10
+alpha: 1.0
+beta: 0.0
+gamma: 0.0
+h_dim_shared: 128
+h_dim_CS: 32
+num_layers_shared: 1
+num_layers_CS: 2
+num_categories: 100
+active_fn: tanh
+dropout_rate: 0.3
+seed: 42
+n_folds: 5
diff --git a/results/logs/nested_cv_raw_metrics_framingham.json b/results/logs/nested_cv_raw_metrics_framingham.json
index 5ac18164..330d31e7 100644
--- a/results/logs/nested_cv_raw_metrics_framingham.json
+++ b/results/logs/nested_cv_raw_metrics_framingham.json
@@ -138,4 +138,4 @@
"scaling": "standard",
"seed": 42
}
-}
\ No newline at end of file
+}
diff --git a/results/logs/nested_cv_raw_metrics_pbc.json b/results/logs/nested_cv_raw_metrics_pbc.json
index 4153a978..3c7dd293 100644
--- a/results/logs/nested_cv_raw_metrics_pbc.json
+++ b/results/logs/nested_cv_raw_metrics_pbc.json
@@ -138,4 +138,4 @@
"scaling": "standard",
"seed": 42
}
-}
\ No newline at end of file
+}
diff --git a/results/logs/nested_cv_raw_metrics_support.json b/results/logs/nested_cv_raw_metrics_support.json
index 4cbca613..e6bde10b 100644
--- a/results/logs/nested_cv_raw_metrics_support.json
+++ b/results/logs/nested_cv_raw_metrics_support.json
@@ -138,4 +138,4 @@
"scaling": "standard",
"seed": 42
}
-}
\ No newline at end of file
+}
diff --git a/results/logs/nested_cv_raw_metrics_synthetic.json b/results/logs/nested_cv_raw_metrics_synthetic.json
index d982097a..f8b5e73c 100644
--- a/results/logs/nested_cv_raw_metrics_synthetic.json
+++ b/results/logs/nested_cv_raw_metrics_synthetic.json
@@ -138,4 +138,4 @@
"scaling": "standard",
"seed": 42
}
-}
\ No newline at end of file
+}
diff --git a/training.md b/training.md
new file mode 100644
index 00000000..2b13b81d
--- /dev/null
+++ b/training.md
@@ -0,0 +1,214 @@
+# Training Instructions
+
+This file contains details information about training models within the `crisp_nam` repository.
+
+## Repository Structure
+The following is the structure for the repository. Files within the `crisp_nam` folder are available as a python package.
+
+```
+crisp_nam/
+├── blog/ # Blog
+├── crisp_nam/ # Main package
+│ ├── metrics/
+│ │ ├── __init__.py
+│ │ ├── calibration.py
+│ │ ├── discrimination.py
+│ │ └── ipcw.py
+│ ├── models/
+│ │ ├── __init__.py
+│ │ ├── crisp_nam_model.py
+│ │ └── deephit_model.py
+│ ├── utils/
+│ │ ├── __init__.py
+│ │ ├── loss.py
+│ │ ├── plotting.py
+│ │ └── risk_cif.py
+│ └── __init__.py
+├── data_utils/ # Data utilities
+│ ├── __init__.py
+│ ├── load_datasets.py
+│ └── survival_datasets.py
+├── datasets/ # Dataset files and loaders
+│ ├── metabric/
+│ │ ├── cleaned_features_final.csv
+│ │ └── label.csv
+│ ├── framingham_dataset.py
+│ ├── framingham.csv
+│ ├── pbc_dataset.py
+│ ├── pbc2.csv
+│ ├── support_dataset.py
+│ ├── support2.csv
+│ ├── SurvivalDataset.py
+│ ├── synthetic_comprisk.csv
+│ └── synthetic_dataset.py
+├── docs/ # Documentation of Pypi package.
+├── results/ # Results and outputs
+│ ├── best_params/ # Best parameters for dataset and model combinations
+│ │ ├── best_params_framingham_deephit.yaml
+│ │ ├── best_params_framingham.yaml
+│ │ ├── best_params_pbc_deephit.yaml
+│ │ ├── best_params_pbc.yaml
+│ │ ├── best_params_support.yaml
+│ │ ├── best_params_support2_deephit.yaml
+│ │ ├── best_params_synthetic_deephit.yaml
+│ │ └── best_params_synthetic.yaml
+│ ├── logs/ # Nested CV results and logs
+│ │ ├── nested_cv_best_params_*.yaml
+│ │ ├── nested_cv_detailed_metrics_*.csv
+│ │ ├── nested_cv_metrics_*.xlsx
+│ │ ├── nested_cv_raw_metrics_*.json
+│ │ └── nested_cv_summary_metrics_*.csv
+│ └── plots/ # Generated plots
+│ ├── nested_cv_feature_importance_risk_*_*.png
+│ └── nested_cv_shape_functions_risk_*_*.png
+├── training_scripts/ # Training scripts
+│ ├── config.yaml
+│ ├── model_utils.py
+│ ├── train_deephit_cuda.py
+│ ├── train_deephit.py
+│ ├── train_nested_cv.py # Nested cross-validation script
+│ ├── train.py
+│ ├── tune_optuna_optimized.py
+│ └── tune_optuna.py
+```
+
+### Available Datasets
+This repository includes 4 datasets: Framingham Heart Study, PBC, Support2 and Synthetic datasets. Detailed information is available in [datasets.md](./datasets.md).
+
+## Training Scripts
+
+The repository provides several specialized training scripts:
+
+- **`train.py`**: Standard model training with cross-validation and comprehensive evaluation
+- **`train_nested_cv.py`**: Robust nested cross-validation for unbiased performance estimation
+- **`tune_optuna.py`**: Hyperparameter optimization using Optuna's advanced algorithms
+- **`tune_optuna_optimized.py`**: Hyperparameter optimization using Optuna on a GPU.
+- **`train_deephit.py`**: DeepHit baseline implementation for comparative studies
+- **`train_deephit_cuda.py`**: DeepHit baseline implementation optimized for running on a GPU.
+
+Each script supports extensive configuration through command-line arguments and YAML config files, enabling reproducible experiments and easy parameter sweeps.
+
+### Running training scripts
+
+1. Modify training parameters in `training_scripts/train.py`
+ OR
+ Run either of following commands to see CLI arguments for passing training parameters:
+
+ ```bash
+ python training_scripts/train.py --help
+ ```
+
+ ```bash
+ uv run training_scripts/train.py --help
+ ```
+
+2. Run the training script
+
+ 1. via `python`
+ ```bash
+ source .venv/bin/activate
+ python training_scripts/train.py --dataset framingham
+ ```
+
+ 2. via `uv`
+ ```bash
+ uv run training_scripts/train.py --dataset framingham
+ ```
+
+### Running Nested Cross-Validation
+
+The nested cross-validation script performs robust model evaluation with hyperparameter optimization using inner and outer cross-validation loops. It automatically generates performance metrics, feature importance plots, and shape function visualizations.
+
+via `python`
+```bash
+python training_scripts/train_nested_cv.py --dataset framingham
+```
+
+via `uv`
+```bash
+uv run training_scripts/train_nested_cv.py --dataset framingham
+```
+
+## Configuration Parameters
+
+All parameters can be passed via command line or specified in a YAML config file:
+
+1. Dataset Configuration
+- `--dataset` (str): Dataset to use (choices: `framingham`, `support`, `pbc`, `synthetic`, default: `framingham`)
+- `--scaling` (str): Data scaling method for continuous features (choices: `minmax`, `standard`, `none`, default: `standard`)
+
+2. Training Parameters
+- `--num_epochs` (int): Number of training epochs (default: `250`)
+- `--batch_size` (int): Batch size for training (default: `512`)
+- `--patience` (int): Patience for early stopping (default: `10`)
+
+3. Cross-Validation Configuration
+- `--outer_folds` (int): Number of outer CV folds (default: `5`)
+- `--inner_folds` (int): Number of inner CV folds for hyperparameter tuning (default: `3`)
+- `--n_trials` (int): Number of Optuna trials per inner fold (default: `20`)
+
+4. Event Weighting
+- `--event_weighting` (str): Event weighting strategy (choices: `none`, `balanced`, `custom`, default: `none`)
+- `--custom_event_weights` (str): Custom weights for events (comma-separated, default: `None`)
+
+5. Other Parameters
+- `--seed` (int): Random seed for reproducibility (default: `42`)
+- `--config` (str): Path to YAML config file (default: looks for `config.yaml`)
+
+### Examples
+
+1. **Basic nested CV with default parameters:**
+```bash
+python training_scripts/train_nested_cv.py --dataset pbc
+```
+
+2. **Customized nested CV with specific parameters:**
+```bash
+python training_scripts/train_nested_cv.py \
+ --dataset support \
+ --outer_folds 10 \
+ --inner_folds 5 \
+ --n_trials 50 \
+ --num_epochs 500 \
+ --event_weighting balanced \
+ --scaling minmax \
+ --seed 123
+```
+
+3. **Using a config file:**
+```bash
+python training_scripts/train_nested_cv.py --config my_config.yaml
+```
+
+## Output Files
+
+The script generates several output files in the current directory:
+
+1. **Performance Metrics**
+- `nested_cv_summary_metrics_{dataset}.csv`: Summary table with mean ± std metrics
+- `nested_cv_detailed_metrics_{dataset}.csv`: Detailed results for each fold
+- `nested_cv_metrics_{dataset}.xlsx`: Excel file with multiple sheets (Summary, Detailed, Metadata)
+- `nested_cv_raw_metrics_{dataset}.json`: Raw metrics dictionary for reproducibility
+
+2. **Model Configuration**
+- `nested_cv_best_params_{dataset}.yaml`: Aggregated best hyperparameters across all folds
+
+3. **Visualizations**
+- `nested_cv_feature_importance_risk_{risk}_{dataset}.png`: Feature importance plots
+- `nested_cv_shape_functions_risk_{risk}_{dataset}.png`: Shape function plots for top features
+
+Results are saved to `results/plots/`:
+
+## Evaluation Metrics
+
+The script computes the following metrics at different time quantiles (25%, 50%, 75%):
+
+1. **AUC (Area Under the ROC Curve)**: Time-dependent AUC for discrimination
+ - 0.5 = random, >0.7 = good, >0.8 = excellent
+2. **TDCI (Time-Dependent Concordance Index)**: Harrell's C-index adapted for competing risks
+ - 0.5 = random, >0.7 = good, >0.8 = excellent
+3. **Brier Score**: Calibration metric measuring prediction accuracy
+ - 0 = perfect, <0.25 = good, >0.25 = poor
+
+> [!NOTE]
+> For `uv` installation, please visit follow instructions in their [official page](https://docs.astral.sh/uv/getting-started/installation/).
\ No newline at end of file
diff --git a/training_scripts/config.yaml b/training_scripts/config.yaml
index d361646f..5fd9e691 100644
--- a/training_scripts/config.yaml
+++ b/training_scripts/config.yaml
@@ -17,4 +17,4 @@ batch_norm: False
# Other parameters
seed: 42
-n_folds: 2
\ No newline at end of file
+n_folds: 2
diff --git a/training_scripts/model_utils.py b/training_scripts/model_utils.py
index 98be4143..812f0e62 100644
--- a/training_scripts/model_utils.py
+++ b/training_scripts/model_utils.py
@@ -1,7 +1,8 @@
import random
-import torch
import numpy as np
+import torch
+
def set_seed(seed=42):
random.seed(seed)
@@ -31,34 +32,35 @@ def step(self, val_loss):
self.should_stop = True
-
# Utility functions to create masks for DeepHit
def create_fc_mask1(k, t, num_Event, num_Category, device=None):
"""Create mask1 for loss calculation - for uncensored loss"""
N = len(k)
mask = torch.zeros((N, num_Event, num_Category), device=device)
-
+
for i in range(N):
if k[i] > 0: # Not censored
event_idx = int(k[i] - 1)
time_idx = int(t[i])
if time_idx < num_Category:
mask[i, event_idx, time_idx] = 1.0
-
+
return mask
+
def create_fc_mask2(t, num_Category, device=None):
"""Create mask2 for loss calculation - for censored loss"""
N = len(t)
mask = torch.zeros((N, num_Category), device=device)
-
+
for i in range(N):
time_idx = int(t[i])
for j in range(time_idx, num_Category):
mask[i, j] = 1.0
-
+
return mask
+
# Pre-create masks for DeepHit on GPU
def create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device):
"""
@@ -67,15 +69,16 @@ def create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device):
"""
batch_size = e.size(0)
mask1 = torch.zeros(batch_size, num_Event, num_Category, device=device)
-
+
for i in range(batch_size):
if e[i] > 0: # if not censored
event_idx = int(e[i].item()) - 1
t_idx = int(t_disc[i].item())
mask1[i, event_idx, t_idx] = 1
-
+
return mask1
+
def create_fc_mask2_gpu(t_disc, num_Category, device):
"""
Create second mask for DeepHit loss computation
@@ -83,9 +86,9 @@ def create_fc_mask2_gpu(t_disc, num_Category, device):
"""
batch_size = t_disc.size(0)
mask2 = torch.zeros(batch_size, num_Category, device=device)
-
+
for i in range(batch_size):
t_idx = int(t_disc[i].item())
mask2[i, t_idx:] = 1
-
+
return mask2
diff --git a/training_scripts/train.py b/training_scripts/train.py
index 2192a6c3..34d626c6 100644
--- a/training_scripts/train.py
+++ b/training_scripts/train.py
@@ -1,79 +1,116 @@
-import configargparse
from collections import defaultdict
-import torch
+import configargparse
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import torch.nn as nn
-import torch.optim as optim
-import matplotlib.pyplot as plt
-from tabulate import tabulate
-from torch.utils.data import DataLoader, Subset
+import torch
+from model_utils import EarlyStopping, set_seed
from sklearn.model_selection import StratifiedKFold
-from sklearn.preprocessing import StandardScaler, MinMaxScaler
-from sksurv.util import Surv
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sksurv.metrics import concordance_index_ipcw
+from sksurv.util import Surv
+from tabulate import tabulate
+from torch import optim
+from torch.utils.data import DataLoader
+from crisp_nam.metrics import auc_td, brier_score
from crisp_nam.models import CrispNamModel
from crisp_nam.utils import (
- weighted_negative_log_likelihood_loss,
+ compute_baseline_cif,
+ compute_l2_penalty,
negative_log_likelihood_loss,
- compute_l2_penalty
+ plot_coxnam_shape_functions,
+ plot_feature_importance,
+ predict_absolute_risk,
+ weighted_negative_log_likelihood_loss,
)
from data_utils import *
-from model_utils import EarlyStopping, set_seed
-from crisp_nam.metrics import brier_score, auc_td
-from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif
-from crisp_nam.utils import plot_coxnam_shape_functions, plot_feature_importance
+
def parse_args():
parser = configargparse.ArgumentParser(
description="Training script for MultiTaskCoxNAM model",
default_config_files=["config.yml"],
- config_file_parser_class=configargparse.YAMLConfigFileParser
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
)
-
- parser.add_argument("-c", "--config", is_config_file=True,
- help="Path to config file")
-
- parser.add_argument("--dataset", type=str, default="framingham",
- help="Dataset to use: (framingham, support, pbc, synthetic)")
-
- parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"],
- help="Data scaling method for continuous features")
-
- parser.add_argument("--num_epochs", type=int, default=500,
- help="Number of training epochs")
- parser.add_argument("--batch_size", type=int, default=256,
- help="Batch size for training")
- parser.add_argument("--learning_rate", type=float, default=1e-3,
- help="Learning rate for optimizer")
- parser.add_argument("--l2_reg", type=float, default=1e-3,
- help="L2 regularization weight")
- parser.add_argument("--patience", type=int, default=10,
- help="Patience for early stopping")
-
- parser.add_argument("--dropout_rate", type=float, default=0.5,
- help="Dropout rate for model")
- parser.add_argument("--feature_dropout", type=float, default=0.1,
- help="Feature dropout rate")
- parser.add_argument("--hidden_dimensions", type=str, default="64,64",
- help="Hidden layer dimensions (comma-separated)")
- parser.add_argument("--batch_norm", type=str, default="False", choices=["True", "False"],
- help="Whether to use batch normalization")
-
- parser.add_argument("--seed", type=int, default=42,
- help="Random seed for reproducibility")
- parser.add_argument("--n_folds", type=int, default=5,
- help="Number of folds for cross-validation")
-
- return parser.parse_args()
+ parser.add_argument(
+ "-c", "--config", is_config_file=True, help="Path to config file"
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ help="Dataset to use: (framingham, support, pbc, synthetic)",
+ )
+
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ help="Data scaling method for continuous features",
+ )
+
+ parser.add_argument(
+ "--num_epochs", type=int, default=500, help="Number of training epochs"
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=256, help="Batch size for training"
+ )
+ parser.add_argument(
+ "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
+ )
+ parser.add_argument(
+ "--l2_reg", type=float, default=1e-3, help="L2 regularization weight"
+ )
+ parser.add_argument(
+ "--patience", type=int, default=10, help="Patience for early stopping"
+ )
+
+ parser.add_argument(
+ "--dropout_rate", type=float, default=0.5, help="Dropout rate for model"
+ )
+ parser.add_argument(
+ "--feature_dropout", type=float, default=0.1, help="Feature dropout rate"
+ )
+ parser.add_argument(
+ "--hidden_dimensions",
+ type=str,
+ default="64,64",
+ help="Hidden layer dimensions (comma-separated)",
+ )
+ parser.add_argument(
+ "--batch_norm",
+ type=str,
+ default="False",
+ choices=["True", "False"],
+ help="Whether to use batch normalization",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for reproducibility"
+ )
+ parser.add_argument(
+ "--n_folds", type=int, default=5, help="Number of folds for cross-validation"
+ )
-def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_rate=1e-3,
- l2_reg=0.01, patience=10, event_weights=None, verbose=True):
+ return parser.parse_args()
+
+
+def train_model(
+ model,
+ train_loader,
+ val_loader=None,
+ num_epochs=500,
+ learning_rate=1e-3,
+ l2_reg=0.01,
+ patience=10,
+ event_weights=None,
+ verbose=True,
+):
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
early_stopper = EarlyStopping(patience=patience)
device = next(model.parameters()).device
@@ -87,12 +124,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r
# Use weighted loss if event_weights is provided
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
total = loss + reg
@@ -110,22 +153,30 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r
for x, t, e, _ in val_loader:
x, t, e = x.to(device), t.to(device), e.to(device)
risk_scores, _ = model(x)
-
+
# Use same loss function as in training
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
val_loss += (loss + reg).item()
avg_val_loss = val_loss / len(val_loader)
early_stopper.step(avg_val_loss)
if verbose:
- print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
+ print(
+ f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
+ )
if early_stopper.should_stop:
if verbose:
@@ -149,10 +200,10 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
abs_risks: Array of shape (n_samples, n_events, n_times) with predicted absolute risks.
times: List of time points at which to evaluate.
- Returns:
+ Returns
+ -------
metrics: dict of evaluation metrics by event and time.
"""
-
survival_train = Surv.from_arrays(e_train != 0, t_train)
survival_val = Surv.from_arrays(e_val != 0, t_val)
@@ -162,12 +213,11 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
for k in range(n_events):
for i, time in enumerate(times):
risk_preds = abs_risks[:, k, i]
-
-
+
try:
risk_preds_2d = np.zeros((len(risk_preds), len(times)))
- risk_preds_2d[:, i] = risk_preds
-
+ risk_preds_2d[:, i] = risk_preds
+
auc_score, _ = auc_td(
e_val,
t_val,
@@ -175,16 +225,14 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
)
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score))
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score))
except Exception as ex:
- print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan)
-
try:
-
brier_score_val, _ = brier_score(
e_val,
t_val,
@@ -192,30 +240,28 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
+ )
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(
+ float(brier_score_val)
)
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val))
except Exception as ex:
- print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}")
+ print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}")
print(f"Debug: surv_probs shape={risk_preds_2d.shape}")
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan)
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan)
-
try:
tdci_result = concordance_index_ipcw(
- survival_train,
- survival_val,
- estimate=risk_preds,
- tau=time
+ survival_train, survival_val, estimate=risk_preds, tau=time
)
if isinstance(tdci_result, tuple):
tdci_score = tdci_result[0]
else:
- tdci_score = tdci_result
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score))
+ tdci_score = tdci_result
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score))
except Exception as ex:
- print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan)
return metrics
@@ -224,88 +270,88 @@ def display_metrics_table(metrics_dict, n_folds=5, quantiles=[0.25, 0.5, 0.75]):
"""
Display evaluation metrics summarized across folds for different time quantiles
"""
-
-
time_points_by_metric = {}
event_types = set()
- metric_types = ['auc', 'tdci', 'brier']
-
+ metric_types = ["auc", "tdci", "brier"]
+
for key, values in metrics_dict.items():
if any(metric in key for metric in metric_types):
- parts = key.split('_')
+ parts = key.split("_")
metric_type = parts[0]
event_info = parts[1]
-
+
# Extract event type
- event_type = int(event_info.replace('event', ''))
+ event_type = int(event_info.replace("event", ""))
event_types.add(event_type)
-
+
# Extract time point
- if len(parts) > 2 and parts[2].startswith('t'):
- time_point = float(parts[2].replace('t', ''))
-
+ if len(parts) > 2 and parts[2].startswith("t"):
+ time_point = float(parts[2].replace("t", ""))
+
# Initialize nested dictionaries if needed
if (event_type, metric_type) not in time_points_by_metric:
time_points_by_metric[(event_type, metric_type)] = {}
-
+
# Store metrics by time point
time_points_by_metric[(event_type, metric_type)][time_point] = values
-
+
results = []
-
+
for event_type in sorted(event_types):
- row = {'Risk': f"Type {event_type}"}
-
+ row = {"Risk": f"Type {event_type}"}
+
for metric_type in metric_types:
if (event_type, metric_type) not in time_points_by_metric:
# Skip metrics that don't exist for this event type
for q in quantiles:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
continue
-
+
# Get all time points for this event/metric
time_data = time_points_by_metric[(event_type, metric_type)]
sorted_times = sorted(time_data.keys())
-
+
# For each quantile
for q in quantiles:
# Calculate the index for this quantile
q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q)))
-
+
# Get the corresponding time point for this quantile
q_time = sorted_times[q_idx]
-
+
# Get the metrics for this time point
q_values = time_data[q_time]
-
+
# Calculate and format statistics
if q_values:
value_array = np.array(q_values)
mean_val = np.nanmean(value_array)
std_val = np.nanstd(value_array)
- row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}"
+ row[f"{metric_type.upper()}_q{q:.2f}"] = (
+ f"{mean_val:.3f} ± {std_val:.3f}"
+ )
else:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
-
+
results.append(row)
-
+
df = pd.DataFrame(results)
-
+
# Define column order
- columns = ['Risk']
- for metric in ['AUC', 'TDCI', 'BRIER']:
+ columns = ["Risk"]
+ for metric in ["AUC", "TDCI", "BRIER"]:
for q in quantiles:
columns.append(f"{metric}_q{q:.2f}")
-
+
# Select columns in the right order (only those that exist)
df = df[[col for col in columns if col in df.columns]]
-
+
print("\nSummary Performance Metrics:")
- print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False))
-
+ print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False))
+
print("\nInterpretation:")
print("- AUC: 0.5=random, >0.7=good, >0.8=excellent")
- print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
+ print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor")
@@ -314,64 +360,104 @@ def parse_args():
parser = configargparse.ArgumentParser(
description="Training script for MultiTaskCoxNAM model",
default_config_files=["config.yaml"],
- config_file_parser_class=configargparse.YAMLConfigFileParser
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
)
-
+
# Config file option
- parser.add_argument("-c", "--config", is_config_file=True,
- help="Path to config file")
-
+ parser.add_argument(
+ "-c", "--config", is_config_file=True, help="Path to config file"
+ )
+
# Dataset
- parser.add_argument("--dataset", type=str, default="framingham", choices=["framingham", "support", "pbc", "synthetic"],
- help="Dataset to use")
-
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ choices=["framingham", "support", "pbc", "synthetic"],
+ help="Dataset to use",
+ )
+
# Data scaling
- parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"],
- help="Data scaling method for continuous features")
-
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ help="Data scaling method for continuous features",
+ )
+
# Training parameters
- parser.add_argument("--num_epochs", type=int, default=500,
- help="Number of training epochs")
- parser.add_argument("--batch_size", type=int, default=256,
- help="Batch size for training")
- parser.add_argument("--learning_rate", type=float, default=1e-3,
- help="Learning rate for optimizer")
- parser.add_argument("--l2_reg", type=float, default=1e-3,
- help="L2 regularization weight")
- parser.add_argument("--patience", type=int, default=10,
- help="Patience for early stopping")
-
-
- parser.add_argument("--dropout_rate", type=float, default=0.5,
- help="Dropout rate for model")
- parser.add_argument("--feature_dropout", type=float, default=0.1,
- help="Feature dropout rate")
- parser.add_argument("--hidden_dimensions", type=str, default="64,64",
- help="Hidden layer dimensions (comma-separated)")
- parser.add_argument("--batch_norm", type=str, default="False", choices=["True", "False"],
- help="Whether to use batch normalization")
-
-
- parser.add_argument("--event_weighting", type=str, default="none",
- choices=["none", "balanced", "custom"],
- help="Event weighting strategy (none, balanced, custom)")
- parser.add_argument("--custom_event_weights", type=str, default=None,
- help="Custom weights for events (comma-separated, e.g., '1.0,2.0')")
-
- #
- parser.add_argument("--seed", type=int, default=42,
- help="Random seed for reproducibility")
- parser.add_argument("--n_folds", type=int, default=5,
- help="Number of folds for cross-validation")
-
+ parser.add_argument(
+ "--num_epochs", type=int, default=500, help="Number of training epochs"
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=256, help="Batch size for training"
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-3,
+ help="Learning rate for optimizer",
+ )
+ parser.add_argument(
+ "--l2_reg", type=float, default=1e-3, help="L2 regularization weight"
+ )
+ parser.add_argument(
+ "--patience", type=int, default=10, help="Patience for early stopping"
+ )
+
+ parser.add_argument(
+ "--dropout_rate", type=float, default=0.5, help="Dropout rate for model"
+ )
+ parser.add_argument(
+ "--feature_dropout", type=float, default=0.1, help="Feature dropout rate"
+ )
+ parser.add_argument(
+ "--hidden_dimensions",
+ type=str,
+ default="64,64",
+ help="Hidden layer dimensions (comma-separated)",
+ )
+ parser.add_argument(
+ "--batch_norm",
+ type=str,
+ default="False",
+ choices=["True", "False"],
+ help="Whether to use batch normalization",
+ )
+
+ parser.add_argument(
+ "--event_weighting",
+ type=str,
+ default="none",
+ choices=["none", "balanced", "custom"],
+ help="Event weighting strategy (none, balanced, custom)",
+ )
+ parser.add_argument(
+ "--custom_event_weights",
+ type=str,
+ default=None,
+ help="Custom weights for events (comma-separated, e.g., '1.0,2.0')",
+ )
+
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for reproducibility"
+ )
+ parser.add_argument(
+ "--n_folds",
+ type=int,
+ default=5,
+ help="Number of folds for cross-validation",
+ )
+
return parser.parse_args()
-
+
args = parse_args()
print(args)
-
+
# Set random seed for reproducibility
set_seed(args.seed)
-
+
# Load the dataset
if args.dataset.lower() == "framingham":
x, t, e, feature_names, n_cont, _ = load_framingham()
@@ -383,62 +469,62 @@ def parse_args():
x, t, e, feature_names, n_cont, _ = load_synthetic_dataset()
else:
raise ValueError(f"Dataset {args.dataset} not supported")
-
-
+
# Note: Scaling will be done inside cross-validation loop to prevent data leakage
# Compute weights based on event distribution
num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0)
- device = ("cuda" if torch.cuda.is_available() else "cpu")
-
-
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
event_weights = None
-
+
if args.event_weighting != "none":
if args.event_weighting == "balanced":
# Compute balanced weights (inverse of class frequencies)
event_counts = np.zeros(num_competing_risks)
for k in range(1, num_competing_risks + 1):
- event_counts[k-1] = np.sum(e == k)
-
+ event_counts[k - 1] = np.sum(e == k)
event_counts = np.maximum(event_counts, 1)
-
+
# Inverse frequency weighting
event_weights = 1.0 / event_counts
-
+
# Normalize weights to sum to num_competing_risks
event_weights = event_weights * (num_competing_risks / event_weights.sum())
-
+
print(f"Computed balanced event weights: {event_weights}")
-
+
elif args.event_weighting == "custom":
if args.custom_event_weights is None:
- raise ValueError("Custom event weights must be provided when using custom weighting")
-
+ raise ValueError(
+ "Custom event weights must be provided when using custom weighting"
+ )
+
custom_weights = [float(w) for w in args.custom_event_weights.split(",")]
if len(custom_weights) != num_competing_risks:
- raise ValueError(f"Expected {num_competing_risks} weights, got {len(custom_weights)}")
-
+ raise ValueError(
+ f"Expected {num_competing_risks} weights, got {len(custom_weights)}"
+ )
+
event_weights = np.array(custom_weights)
print(f"Using custom event weights: {event_weights}")
-
+
event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device)
skf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed)
all_metrics = defaultdict(list)
quantiles = [0.25, 0.5, 0.75]
-
safe_max = 0.99 * np.max(t)
eval_times = np.quantile(t[t <= safe_max], quantiles)
-
+
print(f"Evaluation times: {eval_times}")
-
+
# Parse hidden dimensions string to list of integers
hidden_dimensions = [int(dim) for dim in args.hidden_dimensions.split(",")]
-
+
# Convert batch_norm string to boolean
batch_norm = args.batch_norm.lower() == "true"
@@ -470,25 +556,31 @@ def parse_args():
# Recalculate weights based on training set for this fold
fold_event_counts = np.zeros(num_competing_risks)
for k in range(1, num_competing_risks + 1):
- fold_event_counts[k-1] = np.sum(e_train == k)
-
+ fold_event_counts[k - 1] = np.sum(e_train == k)
+
# Avoid division by zero
fold_event_counts = np.maximum(fold_event_counts, 1)
-
+
# Inverse frequency weighting
fold_event_weights = 1.0 / fold_event_counts
-
+
# Normalize weights to sum to num_competing_risks
- fold_event_weights = fold_event_weights * (num_competing_risks / fold_event_weights.sum())
- fold_event_weights = torch.tensor(fold_event_weights, dtype=torch.float32, device=device)
-
- print(f"Fold {fold+1} event weights: {fold_event_weights.cpu().numpy()}")
+ fold_event_weights = fold_event_weights * (
+ num_competing_risks / fold_event_weights.sum()
+ )
+ fold_event_weights = torch.tensor(
+ fold_event_weights, dtype=torch.float32, device=device
+ )
+
+ print(f"Fold {fold + 1} event weights: {fold_event_weights.cpu().numpy()}")
# Create datasets with properly normalized data
train_dataset = SurvivalDataset(x_train, t_train, e_train)
val_dataset = SurvivalDataset(x_val, t_val, e_val)
-
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True
+ )
val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
model = CrispNamModel(
@@ -497,67 +589,72 @@ def parse_args():
hidden_sizes=hidden_dimensions,
dropout_rate=args.dropout_rate,
feature_dropout=args.feature_dropout,
- batch_norm=batch_norm
+ batch_norm=batch_norm,
).to(device)
-
- train_model(model, train_loader, val_loader,
- num_epochs=args.num_epochs,
- learning_rate=args.learning_rate,
- l2_reg=args.l2_reg,
- patience=args.patience,
- event_weights=fold_event_weights,
- verbose=True)
-
+ train_model(
+ model,
+ train_loader,
+ val_loader,
+ num_epochs=args.num_epochs,
+ learning_rate=args.learning_rate,
+ l2_reg=args.l2_reg,
+ patience=args.patience,
+ event_weights=fold_event_weights,
+ verbose=True,
+ )
+
# Calculate baseline CIFs using the same eval times for all folds
- baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
- for k in range(num_competing_risks)}
-
- abs_risks = predict_absolute_risk(model, x_val, baseline_cifs, eval_times, device=device)
-
-
- fold_metrics = evaluate_model(model, x_val, t_val, e_val,
- t_train, e_train, abs_risks, eval_times)
-
+ baseline_cifs = {
+ k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
+ for k in range(num_competing_risks)
+ }
+
+ abs_risks = predict_absolute_risk(
+ model, x_val, baseline_cifs, eval_times, device=device
+ )
+
+ fold_metrics = evaluate_model(
+ model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times
+ )
+
for k, v in fold_metrics.items():
all_metrics[k].extend(v)
-
display_metrics_table(all_metrics, n_folds=args.n_folds)
-
- #create figs subdirectory if not present
+
+ # create figs subdirectory if not present
import os
+
if not os.path.exists("figs"):
os.makedirs("figs")
-
for risk in range(1, num_competing_risks + 1):
-
fig, _, top_positive, top_negative = plot_feature_importance(
model=model,
x_data=x,
feature_names=feature_names,
n_top=5, # Show top 5 positive contributors
n_bottom=5, # Show top 5 negative contributors
- risk_idx=risk,
+ risk_idx=risk,
figsize=(6, 4),
- output_file=f"figs/feature_importance_risk_new_{risk}_{args.dataset}.png"
- )
-
+ output_file=f"figs/feature_importance_risk_new_{risk}_{args.dataset}.png",
+ )
+
top_features = top_positive + top_negative
-
+
fig, axes = plot_coxnam_shape_functions(
model=model,
- X=x,
+ X=x,
risk_to_plot=risk,
- feature_names=feature_names,
- top_features=top_features,
+ feature_names=feature_names,
+ top_features=top_features,
ncols=5,
- figsize=(12,6),
- output_file=f"figs/shape_functions_top_features_risk{risk}_{args.dataset}.png"
+ figsize=(12, 6),
+ output_file=f"figs/shape_functions_top_features_risk{risk}_{args.dataset}.png",
)
plt.close(fig)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/training_scripts/train_deephit_cuda.py b/training_scripts/train_deephit_cuda.py
index d670acb5..b861d7a3 100644
--- a/training_scripts/train_deephit_cuda.py
+++ b/training_scripts/train_deephit_cuda.py
@@ -1,96 +1,165 @@
import os
-import configargparse
from collections import defaultdict
-import torch
+import configargparse
import numpy as np
-import torch.optim as optim
-from sksurv.util import Surv
-from torch.utils.data import DataLoader, Subset
-from sklearn.model_selection import StratifiedKFold
-from sklearn.preprocessing import StandardScaler, MinMaxScaler
-from sksurv.metrics import concordance_index_ipcw
-
-# Import the DeepHit model implementation
-from data_utils import *
+import torch
from model_utils import (
- set_seed,
EarlyStopping,
create_fc_mask1_gpu,
- create_fc_mask2_gpu
+ create_fc_mask2_gpu,
+ set_seed,
)
+from sklearn.model_selection import StratifiedKFold
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
+from sksurv.metrics import concordance_index_ipcw
+from sksurv.util import Surv
+from torch import optim
+from torch.utils.data import DataLoader, Subset
+
+from crisp_nam.metrics import auc_td, brier_score
from crisp_nam.models import DeepHit
-from crisp_nam.metrics import brier_score, auc_td
+
+# Import the DeepHit model implementation
+from data_utils import *
+
def parse_args():
parser = configargparse.ArgumentParser(
description="Training script for DeepHit model",
default_config_files=["config.yml"],
- config_file_parser_class=configargparse.YAMLConfigFileParser
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
)
-
- parser.add_argument("-c", "--config", is_config_file=True,
- help="Path to config file")
-
+
+ parser.add_argument(
+ "-c", "--config", is_config_file=True, help="Path to config file"
+ )
+
# Dataset
- parser.add_argument("--dataset", type=str, default="framingham",
- help="Dataset to use: (framingham, support, pbc, synthetic)")
-
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ help="Dataset to use: (framingham, support, pbc, synthetic)",
+ )
+
# Data scaling
- parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"],
- help="Data scaling method for continuous features")
-
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ help="Data scaling method for continuous features",
+ )
+
# Training parameters
- parser.add_argument("--num_epochs", type=int, default=500,
- help="Number of training epochs")
- parser.add_argument("--batch_size", type=int, default=512,
- help="Batch size for training")
- parser.add_argument("--learning_rate", type=float, default=1e-3,
- help="Learning rate for optimizer")
- parser.add_argument("--l2_reg", type=float, default=1e-4,
- help="L2 regularization weight")
- parser.add_argument("--patience", type=int, default=10,
- help="Patience for early stopping")
-
+ parser.add_argument(
+ "--num_epochs", type=int, default=500, help="Number of training epochs"
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=512, help="Batch size for training"
+ )
+ parser.add_argument(
+ "--learning_rate", type=float, default=1e-3, help="Learning rate for optimizer"
+ )
+ parser.add_argument(
+ "--l2_reg", type=float, default=1e-4, help="L2 regularization weight"
+ )
+ parser.add_argument(
+ "--patience", type=int, default=10, help="Patience for early stopping"
+ )
+
# DeepHit specific parameters
- parser.add_argument("--alpha", type=float, default=1.0,
- help="Weight for log-likelihood loss")
- parser.add_argument("--beta", type=float, default=0.5,
- help="Weight for ranking loss")
- parser.add_argument("--gamma", type=float, default=0.5,
- help="Weight for calibration loss")
- parser.add_argument("--h_dim_shared", type=int, default=64,
- help="Hidden dimension for shared network")
- parser.add_argument("--h_dim_CS", type=int, default=16,
- help="Hidden dimension for cause-specific networks")
- parser.add_argument("--num_layers_shared", type=int, default=2,
- help="Number of layers in shared network")
- parser.add_argument("--num_layers_CS", type=int, default=2,
- help="Number of layers in cause-specific networks")
- parser.add_argument("--num_categories", type=int, default=100,
- help="Number of time categories for discretization")
- parser.add_argument("--active_fn", type=str, default="relu", choices=["relu", "elu", "tanh"],
- help="Activation function")
-
+ parser.add_argument(
+ "--alpha", type=float, default=1.0, help="Weight for log-likelihood loss"
+ )
+ parser.add_argument(
+ "--beta", type=float, default=0.5, help="Weight for ranking loss"
+ )
+ parser.add_argument(
+ "--gamma", type=float, default=0.5, help="Weight for calibration loss"
+ )
+ parser.add_argument(
+ "--h_dim_shared",
+ type=int,
+ default=64,
+ help="Hidden dimension for shared network",
+ )
+ parser.add_argument(
+ "--h_dim_CS",
+ type=int,
+ default=16,
+ help="Hidden dimension for cause-specific networks",
+ )
+ parser.add_argument(
+ "--num_layers_shared",
+ type=int,
+ default=2,
+ help="Number of layers in shared network",
+ )
+ parser.add_argument(
+ "--num_layers_CS",
+ type=int,
+ default=2,
+ help="Number of layers in cause-specific networks",
+ )
+ parser.add_argument(
+ "--num_categories",
+ type=int,
+ default=100,
+ help="Number of time categories for discretization",
+ )
+ parser.add_argument(
+ "--active_fn",
+ type=str,
+ default="relu",
+ choices=["relu", "elu", "tanh"],
+ help="Activation function",
+ )
+
# General parameters
- parser.add_argument("--dropout_rate", type=float, default=0.2,
- help="Dropout rate for model")
- parser.add_argument("--seed", type=int, default=42,
- help="Random seed for reproducibility")
- parser.add_argument("--n_folds", type=int, default=5,
- help="Number of folds for cross-validation")
- parser.add_argument("--num_workers", type=int, default=8,
- help="Number of workers for data loading")
- parser.add_argument("--eval_freq", type=int, default=10,
- help="Evaluate every N epochs during training")
- parser.add_argument("--use_amp", action="store_true",
- help="Use Automatic Mixed Precision for training")
-
+ parser.add_argument(
+ "--dropout_rate", type=float, default=0.2, help="Dropout rate for model"
+ )
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for reproducibility"
+ )
+ parser.add_argument(
+ "--n_folds", type=int, default=5, help="Number of folds for cross-validation"
+ )
+ parser.add_argument(
+ "--num_workers", type=int, default=8, help="Number of workers for data loading"
+ )
+ parser.add_argument(
+ "--eval_freq",
+ type=int,
+ default=10,
+ help="Evaluate every N epochs during training",
+ )
+ parser.add_argument(
+ "--use_amp",
+ action="store_true",
+ help="Use Automatic Mixed Precision for training",
+ )
+
return parser.parse_args()
-def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.0, gamma=1.0,
- num_epochs=500, learning_rate=1e-3, l2_reg=0.01, patience=10,
- eval_freq=10, use_amp=False, verbose=True):
+
+def train_deephit_model(
+ model,
+ train_loader,
+ val_loader=None,
+ alpha=1.0,
+ beta=1.0,
+ gamma=1.0,
+ num_epochs=500,
+ learning_rate=1e-3,
+ l2_reg=0.01,
+ patience=10,
+ eval_freq=10,
+ use_amp=False,
+ verbose=True,
+):
"""
Train the DeepHit model using the three loss components
Optimized with AMP support, prefetching, and CUDA streams
@@ -98,35 +167,35 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.
device = next(model.parameters()).device
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=l2_reg)
early_stopper = EarlyStopping(patience=patience)
-
+
# Get model dimensions for mask creation
num_Event = model.num_Event
num_Category = model.num_Category
-
+
# Initialize scaler for mixed precision training
scaler = torch.cuda.amp.GradScaler() if use_amp else None
-
+
# Enable asynchronous CUDA execution
torch.backends.cudnn.benchmark = True
-
+
# Increase batch size if it's too small (optional)
# if train_loader.batch_size < 512 and device == 'cuda':
# print("Warning: Small batch size. Consider increasing for better GPU utilization.")
-
+
# Prefetch data using CUDA streams for async data loading
- prefetch_stream = torch.cuda.Stream() if device == 'cuda' else None
-
+ prefetch_stream = torch.cuda.Stream() if device == "cuda" else None
+
# Precomputed masks for common events (optimization)
precomputed_masks2 = {}
for t_val in range(num_Category):
mask = torch.zeros(1, num_Category, device=device)
mask[0, t_val:] = 1
precomputed_masks2[t_val] = mask
-
+
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
-
+
# Prefetch first batch
if prefetch_stream is not None:
try:
@@ -135,34 +204,54 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.
prefetch_batch = next(batch_iter)
# Prefetch to GPU asynchronously
with torch.cuda.stream(prefetch_stream):
- prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [t.to(device, non_blocking=True) for t in prefetch_batch]
+ prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [
+ t.to(device, non_blocking=True) for t in prefetch_batch
+ ]
# Precompute masks
- prefetch_mask1 = create_fc_mask1_gpu(prefetch_e, prefetch_t_disc, num_Event, num_Category, device)
- prefetch_mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in prefetch_t_disc], dim=0)
+ prefetch_mask1 = create_fc_mask1_gpu(
+ prefetch_e, prefetch_t_disc, num_Event, num_Category, device
+ )
+ prefetch_mask2 = torch.cat(
+ [
+ precomputed_masks2[int(t_val.item())]
+ for t_val in prefetch_t_disc
+ ],
+ dim=0,
+ )
except StopIteration:
prefetch_batch = None
else:
prefetch_batch = None
-
+
# Main training loop with prefetching
batch_iter = iter(train_loader)
more_batches = True
-
+
while more_batches:
- # Synchronize with prefetched data
+ # Synchronize with prefetched data
if prefetch_stream is not None and prefetch_batch is not None:
torch.cuda.current_stream().wait_stream(prefetch_stream)
x, t, e, t_disc = prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc
mask1, mask2 = prefetch_mask1, prefetch_mask2
-
+
# Prefetch next batch
try:
prefetch_batch = next(batch_iter)
with torch.cuda.stream(prefetch_stream):
- prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [t.to(device, non_blocking=True) for t in prefetch_batch]
+ prefetch_x, prefetch_t, prefetch_e, prefetch_t_disc = [
+ t.to(device, non_blocking=True) for t in prefetch_batch
+ ]
# Precompute masks
- prefetch_mask1 = create_fc_mask1_gpu(prefetch_e, prefetch_t_disc, num_Event, num_Category, device)
- prefetch_mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in prefetch_t_disc], dim=0)
+ prefetch_mask1 = create_fc_mask1_gpu(
+ prefetch_e, prefetch_t_disc, num_Event, num_Category, device
+ )
+ prefetch_mask2 = torch.cat(
+ [
+ precomputed_masks2[int(t_val.item())]
+ for t_val in prefetch_t_disc
+ ],
+ dim=0,
+ )
except StopIteration:
prefetch_batch = None
more_batches = False
@@ -172,20 +261,27 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.
batch = next(batch_iter)
x, t, e, t_disc = [t.to(device, non_blocking=True) for t in batch]
# Create masks for loss computation
- mask1 = create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device)
- mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in t_disc], dim=0)
+ mask1 = create_fc_mask1_gpu(
+ e, t_disc, num_Event, num_Category, device
+ )
+ mask2 = torch.cat(
+ [precomputed_masks2[int(t_val.item())] for t_val in t_disc],
+ dim=0,
+ )
except StopIteration:
more_batches = False
continue
-
+
# Automatic mixed precision
if use_amp:
with torch.cuda.amp.autocast():
# Forward pass
out, _ = model(x)
# Compute loss components
- loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma)
-
+ loss = model.compute_loss(
+ out, t, e, mask1, mask2, alpha, beta, gamma
+ )
+
# Optimization step with gradient scaling
optimizer.zero_grad(set_to_none=True) # More efficient than zero_grad()
scaler.scale(loss).backward()
@@ -196,42 +292,55 @@ def train_deephit_model(model, train_loader, val_loader=None, alpha=1.0, beta=1.
out, _ = model(x)
# Compute loss components
loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma)
-
+
# Regular optimization step
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
-
+
total_loss += loss.item()
-
+
avg_train_loss = total_loss / len(train_loader)
-
+
# Validation if provided, but only every eval_freq epochs to save time
if val_loader and (epoch % eval_freq == 0 or epoch == num_epochs - 1):
model.eval()
val_loss = 0.0
with torch.no_grad():
for x, t, e, t_disc in val_loader:
- x, t, e, t_disc = x.to(device, non_blocking=True), t.to(device, non_blocking=True), \
- e.to(device, non_blocking=True), t_disc.to(device, non_blocking=True)
-
+ x, t, e, t_disc = (
+ x.to(device, non_blocking=True),
+ t.to(device, non_blocking=True),
+ e.to(device, non_blocking=True),
+ t_disc.to(device, non_blocking=True),
+ )
+
# Create masks for validation (using precomputed masks)
- mask1 = create_fc_mask1_gpu(e, t_disc, num_Event, num_Category, device)
- mask2 = torch.cat([precomputed_masks2[int(t_val.item())] for t_val in t_disc], dim=0)
-
+ mask1 = create_fc_mask1_gpu(
+ e, t_disc, num_Event, num_Category, device
+ )
+ mask2 = torch.cat(
+ [precomputed_masks2[int(t_val.item())] for t_val in t_disc],
+ dim=0,
+ )
+
# Forward pass
out, _ = model(x)
-
+
# Compute loss
- loss = model.compute_loss(out, t, e, mask1, mask2, alpha, beta, gamma)
+ loss = model.compute_loss(
+ out, t, e, mask1, mask2, alpha, beta, gamma
+ )
val_loss += loss.item()
-
+
avg_val_loss = val_loss / len(val_loader)
early_stopper.step(avg_val_loss)
-
+
if verbose and (epoch % eval_freq == 0 or epoch == num_epochs - 1):
- print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
-
+ print(
+ f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
+ )
+
if early_stopper.should_stop:
if verbose:
print("Early stopping triggered.")
@@ -246,100 +355,109 @@ def predict_absolute_risk_deephit(model, x_test, times, t_max=None, device="cpu"
Optimized to keep more operations on GPU
"""
model.eval()
-
+
# Convert x_test to tensor and move to device
if not isinstance(x_test, torch.Tensor):
x_test_tensor = torch.tensor(x_test, dtype=torch.float32, device=device)
else:
x_test_tensor = x_test.to(device)
-
+
# Process in batches for large datasets
batch_size = 1024 # Can be adjusted based on available GPU memory
n_samples = x_test_tensor.shape[0]
-
+
# Get model dimensions
n_events = model.num_Event
n_categories = model.num_Category
n_times = len(times)
-
+
# Set t_max if not provided
if t_max is None:
t_max = max(times)
-
+
# Initialize the output tensor on GPU
abs_risks = torch.zeros(n_samples, n_events, n_times, device=device)
-
+
# Process in batches
with torch.no_grad():
for i in range(0, n_samples, batch_size):
end_idx = min(i + batch_size, n_samples)
batch = x_test_tensor[i:end_idx]
-
+
# Get model predictions
preds, _ = model(batch) # shape: (batch_size, num_Event, num_Category)
-
+
# For each evaluation time, calculate CIF
for t_idx, t in enumerate(times):
# Scale time to [0, num_categories-1]
scaled_t = (t / t_max) * (n_categories - 1)
-
+
# Find lower and upper bin indices
lower_bin = int(np.floor(scaled_t))
- upper_bin = min(int(np.ceil(scaled_t)), n_categories-1)
-
+ upper_bin = min(int(np.ceil(scaled_t)), n_categories - 1)
+
# Handle boundary cases
if lower_bin == upper_bin or upper_bin >= n_categories:
# Exactly at a bin boundary or at/beyond the last bin
for k in range(n_events):
# Get cumulative probability up to this bin
- abs_risks[i:end_idx, k, t_idx] = torch.sum(preds[:, k, :lower_bin+1], dim=1)
+ abs_risks[i:end_idx, k, t_idx] = torch.sum(
+ preds[:, k, : lower_bin + 1], dim=1
+ )
else:
# Need to interpolate between bins
weight_upper = scaled_t - lower_bin
weight_lower = 1 - weight_upper
-
+
for k in range(n_events):
# Cumulative probability up to lower bin
- cum_prob_lower = torch.sum(preds[:, k, :lower_bin+1], dim=1)
-
+ cum_prob_lower = torch.sum(preds[:, k, : lower_bin + 1], dim=1)
+
# Cumulative probability up to upper bin
- cum_prob_upper = torch.sum(preds[:, k, :upper_bin+1], dim=1)
-
+ cum_prob_upper = torch.sum(preds[:, k, : upper_bin + 1], dim=1)
+
# Linear interpolation
- abs_risks[i:end_idx, k, t_idx] = weight_lower * cum_prob_lower + weight_upper * cum_prob_upper
-
+ abs_risks[i:end_idx, k, t_idx] = (
+ weight_lower * cum_prob_lower
+ + weight_upper * cum_prob_upper
+ )
+
# Move results back to CPU and convert to numpy only if needed
return abs_risks.cpu().numpy()
-def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, times, max_time, device="cpu"):
+def evaluate_model(
+ model, x_val, t_val, e_val, t_train, e_train, times, max_time, device="cpu"
+):
"""
Evaluate the DeepHit model using time-dependent AUC, Brier score, and td-C-index
"""
# Predict absolute risks for validation set
- abs_risks_tensor = predict_absolute_risk_deephit(model, x_val, times, t_max=max_time, device=device)
+ abs_risks_tensor = predict_absolute_risk_deephit(
+ model, x_val, times, t_max=max_time, device=device
+ )
abs_risks = abs_risks_tensor # Now returns numpy array directly
-
+
# Format data for scikit-survival functions
survival_train = Surv.from_arrays(e_train != 0, t_train)
survival_val = Surv.from_arrays(e_val != 0, t_val)
-
+
metrics = defaultdict(list)
n_events = abs_risks.shape[1]
-
+
# Calculate metrics in parallel using joblib if available
try:
from joblib import Parallel, delayed
-
+
def process_metric(k, i, time):
result = {}
risk_preds = abs_risks[:, k, i]
-
+
try:
# Reshape to format needed by custom function (n_samples, n_times)
risk_preds_2d = np.zeros((len(risk_preds), len(times)))
risk_preds_2d[:, i] = risk_preds
-
+
auc_score, _ = auc_td(
e_val,
t_val,
@@ -347,12 +465,12 @@ def process_metric(k, i, time):
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
)
- result[f"auc_event{k+1}_t{time:.2f}"] = float(auc_score)
- except Exception as ex:
- result[f"auc_event{k+1}_t{time:.2f}"] = np.nan
-
+ result[f"auc_event{k + 1}_t{time:.2f}"] = float(auc_score)
+ except Exception:
+ result[f"auc_event{k + 1}_t{time:.2f}"] = np.nan
+
try:
brier_score_val, _ = brier_score(
e_val,
@@ -361,51 +479,50 @@ def process_metric(k, i, time):
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
)
- result[f"brier_event{k+1}_t{time:.2f}"] = float(brier_score_val)
- except Exception as ex:
- result[f"brier_event{k+1}_t{time:.2f}"] = np.nan
-
+ result[f"brier_event{k + 1}_t{time:.2f}"] = float(brier_score_val)
+ except Exception:
+ result[f"brier_event{k + 1}_t{time:.2f}"] = np.nan
+
try:
tdci_result = concordance_index_ipcw(
- survival_train,
- survival_val,
- estimate=risk_preds,
- tau=time
+ survival_train, survival_val, estimate=risk_preds, tau=time
)
if isinstance(tdci_result, tuple):
tdci_score = tdci_result[0]
else:
tdci_score = tdci_result
- result[f"tdci_event{k+1}_t{time:.2f}"] = float(tdci_score)
- except Exception as ex:
- result[f"tdci_event{k+1}_t{time:.2f}"] = np.nan
-
+ result[f"tdci_event{k + 1}_t{time:.2f}"] = float(tdci_score)
+ except Exception:
+ result[f"tdci_event{k + 1}_t{time:.2f}"] = np.nan
+
return result
-
+
# Collect all tasks
tasks = [(k, i, time) for k in range(n_events) for i, time in enumerate(times)]
-
+
# Run in parallel
- results = Parallel(n_jobs=-1)(delayed(process_metric)(k, i, time) for k, i, time in tasks)
-
+ results = Parallel(n_jobs=-1)(
+ delayed(process_metric)(k, i, time) for k, i, time in tasks
+ )
+
# Combine results
for result in results:
for k, v in result.items():
metrics[k].append(v)
-
+
except ImportError:
# Fall back to sequential computation
for k in range(n_events):
for i, time in enumerate(times):
risk_preds = abs_risks[:, k, i]
-
+
try:
# Reshape to format needed by custom function (n_samples, n_times)
risk_preds_2d = np.zeros((len(risk_preds), len(times)))
risk_preds_2d[:, i] = risk_preds
-
+
auc_score, _ = auc_td(
e_val,
t_val,
@@ -413,12 +530,12 @@ def process_metric(k, i, time):
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
)
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score))
- except Exception as ex:
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan)
-
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score))
+ except Exception:
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan)
+
try:
brier_score_val, _ = brier_score(
e_val,
@@ -427,27 +544,26 @@ def process_metric(k, i, time):
times,
time,
km=(e_train, t_train),
- primary_risk=k+1
+ primary_risk=k + 1,
)
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val))
- except Exception as ex:
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan)
-
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(
+ float(brier_score_val)
+ )
+ except Exception:
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan)
+
try:
tdci_result = concordance_index_ipcw(
- survival_train,
- survival_val,
- estimate=risk_preds,
- tau=time
+ survival_train, survival_val, estimate=risk_preds, tau=time
)
if isinstance(tdci_result, tuple):
tdci_score = tdci_result[0]
else:
tdci_score = tdci_result
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score))
- except Exception as ex:
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan)
-
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score))
+ except Exception:
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan)
+
return metrics
@@ -457,97 +573,99 @@ def display_metrics_table(metrics_dict, n_folds=5, quantiles=[0.25, 0.5, 0.75]):
"""
import pandas as pd
from tabulate import tabulate
-
+
time_points_by_metric = {}
event_types = set()
- metric_types = ['auc', 'tdci', 'brier']
-
+ metric_types = ["auc", "tdci", "brier"]
+
for key, values in metrics_dict.items():
if any(metric in key for metric in metric_types):
- parts = key.split('_')
+ parts = key.split("_")
metric_type = parts[0]
event_info = parts[1]
-
+
# Extract event type
- event_type = int(event_info.replace('event', ''))
+ event_type = int(event_info.replace("event", ""))
event_types.add(event_type)
-
+
# Extract time point
- if len(parts) > 2 and parts[2].startswith('t'):
- time_point = float(parts[2].replace('t', ''))
-
+ if len(parts) > 2 and parts[2].startswith("t"):
+ time_point = float(parts[2].replace("t", ""))
+
# Initialize nested dictionaries if needed
if (event_type, metric_type) not in time_points_by_metric:
time_points_by_metric[(event_type, metric_type)] = {}
-
+
# Store metrics by time point
time_points_by_metric[(event_type, metric_type)][time_point] = values
-
+
results = []
-
+
for event_type in sorted(event_types):
- row = {'Risk': f"Type {event_type}"}
-
+ row = {"Risk": f"Type {event_type}"}
+
for metric_type in metric_types:
if (event_type, metric_type) not in time_points_by_metric:
# Skip metrics that don't exist for this event type
for q in quantiles:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
continue
-
+
# Get all time points for this event/metric
time_data = time_points_by_metric[(event_type, metric_type)]
sorted_times = sorted(time_data.keys())
-
+
# For each quantile
for q in quantiles:
# Calculate the index for this quantile
q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q)))
-
+
# Get the corresponding time point for this quantile
q_time = sorted_times[q_idx]
-
+
# Get the metrics for this time point
q_values = time_data[q_time]
-
+
# Calculate and format statistics
if q_values:
value_array = np.array(q_values)
mean_val = np.nanmean(value_array)
std_val = np.nanstd(value_array)
- row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}"
+ row[f"{metric_type.upper()}_q{q:.2f}"] = (
+ f"{mean_val:.3f} ± {std_val:.3f}"
+ )
else:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
-
+
results.append(row)
-
+
df = pd.DataFrame(results)
-
+
# Define column order
- columns = ['Risk']
- for metric in ['AUC', 'TDCI', 'BRIER']:
+ columns = ["Risk"]
+ for metric in ["AUC", "TDCI", "BRIER"]:
for q in quantiles:
columns.append(f"{metric}_q{q:.2f}")
-
+
# Select columns in the right order (only those that exist)
df = df[[col for col in columns if col in df.columns]]
-
+
print("\nSummary Performance Metrics:")
- print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False))
-
+ print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False))
+
print("\nInterpretation:")
print("- AUC: 0.5=random, >0.7=good, >0.8=excellent")
- print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
+ print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor")
def main():
args = parse_args()
print(args)
-
+
# Set random seed for reproducibility
set_seed(args.seed)
-
+
# Configure CUDA for maximum performance
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
@@ -558,16 +676,22 @@ def main():
# Print CUDA device info
device_name = torch.cuda.get_device_name(0)
compute_capability = torch.cuda.get_device_capability(0)
- print(f"Using CUDA device: {device_name} with compute capability {compute_capability[0]}.{compute_capability[1]}")
-
+ print(
+ f"Using CUDA device: {device_name} with compute capability {compute_capability[0]}.{compute_capability[1]}"
+ )
+
# Set GPU memory usage strategy
torch.cuda.empty_cache()
-
+
# Manual memory management
- if hasattr(torch.cuda, 'memory_stats'):
- print(f"Initial GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
- print(f"Initial GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
-
+ if hasattr(torch.cuda, "memory_stats"):
+ print(
+ f"Initial GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
+ )
+ print(
+ f"Initial GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB"
+ )
+
# Load the dataset
if args.dataset.lower() == "framingham":
x, t, e, feature_names, n_cont, _ = load_framingham()
@@ -579,10 +703,10 @@ def main():
x, t, e, feature_names, n_cont, _ = load_synthetic_dataset()
else:
raise ValueError(f"Dataset {args.dataset} not supported")
-
+
# Convert data to tensors immediately before scaling
- device = ("cuda" if torch.cuda.is_available() else "cpu")
-
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
# Scale features if needed (more efficiently)
if args.scaling.lower() == "standard":
scaler = StandardScaler()
@@ -598,13 +722,13 @@ def main():
pass
else:
raise ValueError(f"Scaling method {args.scaling} not supported")
-
+
# Get number of competing risks
num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0)
-
+
# Maximum time in the dataset
max_time = np.max(t)
-
+
# Convert data to tensors early and move to device if small enough
# For large datasets, keep on CPU and transfer in batches
dataset_size = x.nbytes / (1024 * 1024) # Size in MB
@@ -616,120 +740,136 @@ def main():
e_tensor = torch.tensor(e, dtype=torch.long, device=device)
print(f"Entire dataset moved to {device} (Size: {dataset_size:.2f} MB)")
# Create dataset with GPU tensors
- dataset = SurvivalDatasetDeepHit(x_tensor.cpu().numpy(), t_tensor.cpu().numpy(),
- e_tensor.cpu().numpy(), args.num_categories)
+ dataset = SurvivalDatasetDeepHit(
+ x_tensor.cpu().numpy(),
+ t_tensor.cpu().numpy(),
+ e_tensor.cpu().numpy(),
+ args.num_categories,
+ )
except RuntimeError:
# If OOM, fall back to CPU
- print(f"Dataset too large for GPU memory, keeping on CPU (Size: {dataset_size:.2f} MB)")
+ print(
+ f"Dataset too large for GPU memory, keeping on CPU (Size: {dataset_size:.2f} MB)"
+ )
x_tensor = torch.tensor(x, dtype=torch.float32)
t_tensor = torch.tensor(t, dtype=torch.float32)
e_tensor = torch.tensor(e, dtype=torch.long)
dataset = SurvivalDatasetDeepHit(x, t, e, args.num_categories)
else:
# Large dataset, keep on CPU
- print(f"Large dataset detected ({dataset_size:.2f} MB), keeping on CPU and using efficient batch loading")
+ print(
+ f"Large dataset detected ({dataset_size:.2f} MB), keeping on CPU and using efficient batch loading"
+ )
dataset = SurvivalDatasetDeepHit(x, t, e, args.num_categories)
-
+
# Setup cross-validation
skf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed)
-
+
# Metrics collection and evaluation times
all_metrics = defaultdict(list)
quantiles = [0.25, 0.5, 0.75]
-
+
# Calculate global evaluation times
safe_max = 0.99 * np.max(t)
eval_times = np.quantile(t[t <= safe_max], quantiles)
-
+
print(f"Evaluation times: {eval_times}")
-
+
# Create directory for figures if it doesn't exist
if not os.path.exists("figs"):
os.makedirs("figs")
-
+
# Cross-validation loop
for fold, (train_idx, val_idx) in enumerate(skf.split(x, e)):
print(f"\n=== Fold {fold + 1}/{args.n_folds} ===")
-
+
# Calculate optimal batch size based on device
if device == "cuda":
# Dynamically adjust batch size based on available GPU memory
- total_gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3) # in GB
+ total_gpu_mem = torch.cuda.get_device_properties(0).total_memory / (
+ 1024**3
+ ) # in GB
# Use heuristic - larger batch sizes for larger GPUs
- dynamic_batch_size = min(2048, max(512, int(args.batch_size * (total_gpu_mem / 8))))
+ dynamic_batch_size = min(
+ 2048, max(512, int(args.batch_size * (total_gpu_mem / 8)))
+ )
if dynamic_batch_size != args.batch_size:
- print(f"Adjusting batch size from {args.batch_size} to {dynamic_batch_size} based on GPU memory")
+ print(
+ f"Adjusting batch size from {args.batch_size} to {dynamic_batch_size} based on GPU memory"
+ )
batch_size = dynamic_batch_size
else:
batch_size = args.batch_size
else:
batch_size = args.batch_size
-
+
# Create data loaders with optimized settings
train_subset = Subset(dataset, train_idx)
val_subset = Subset(dataset, val_idx)
-
+
# Use different prefetch factors based on data size and complexity
prefetch_factor = 2
-
+
train_loader = DataLoader(
- train_subset,
- batch_size=batch_size,
+ train_subset,
+ batch_size=batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True if device == "cuda" else False,
persistent_workers=True if args.num_workers > 0 else False,
prefetch_factor=prefetch_factor if args.num_workers > 0 else None,
- drop_last=True # Drop last incomplete batch for better performance
+ drop_last=True, # Drop last incomplete batch for better performance
)
-
+
val_loader = DataLoader(
- val_subset,
+ val_subset,
batch_size=batch_size,
num_workers=args.num_workers,
pin_memory=True if device == "cuda" else False,
persistent_workers=True if args.num_workers > 0 else False,
- prefetch_factor=prefetch_factor if args.num_workers > 0 else None
+ prefetch_factor=prefetch_factor if args.num_workers > 0 else None,
)
-
+
# Initialize DeepHit model with optimized size
input_dims = {
- 'x_dim': x.shape[1],
- 'num_Event': num_competing_risks,
- 'num_Category': args.num_categories
+ "x_dim": x.shape[1],
+ "num_Event": num_competing_risks,
+ "num_Category": args.num_categories,
}
-
+
# Adjust network size based on available GPU memory
if device == "cuda":
h_dim_shared = args.h_dim_shared
h_dim_CS = args.h_dim_CS
-
+
# For GPUs with more memory, can use larger networks
if total_gpu_mem > 16: # For high-end GPUs
h_dim_shared = max(args.h_dim_shared, 128)
h_dim_CS = max(args.h_dim_CS, 32)
- print(f"Using larger network dimensions for high-memory GPU: {h_dim_shared}/{h_dim_CS}")
+ print(
+ f"Using larger network dimensions for high-memory GPU: {h_dim_shared}/{h_dim_CS}"
+ )
else:
h_dim_shared = args.h_dim_shared
h_dim_CS = args.h_dim_CS
-
+
network_settings = {
- 'h_dim_shared': h_dim_shared,
- 'h_dim_CS': h_dim_CS,
- 'num_layers_shared': args.num_layers_shared,
- 'num_layers_CS': args.num_layers_CS,
- 'active_fn': args.active_fn,
- 'keep_prob': 1.0 - args.dropout_rate
+ "h_dim_shared": h_dim_shared,
+ "h_dim_CS": h_dim_CS,
+ "num_layers_shared": args.num_layers_shared,
+ "num_layers_CS": args.num_layers_CS,
+ "active_fn": args.active_fn,
+ "keep_prob": 1.0 - args.dropout_rate,
}
-
+
# Create model
model = DeepHit(input_dims, network_settings).to(device)
-
+
# Optional: print model summary for debugging
if fold == 0:
num_params = sum(p.numel() for p in model.parameters())
print(f"Model has {num_params:,} parameters")
-
+
# Print estimated memory usage per batch
# 4 bytes per float32 parameter, multiply by 4 for activations, gradients, optimizer state
param_memory_mb = num_params * 4 * 4 / (1024 * 1024)
@@ -741,40 +881,63 @@ def main():
print("Profiling first batch for performance analysis...")
# Get a sample batch for profiling
sample_batch = next(iter(train_loader))
- x_sample, t_sample, e_sample, t_disc_sample = [t.to(device, non_blocking=True) for t in sample_batch]
-
+ x_sample, t_sample, e_sample, t_disc_sample = [
+ t.to(device, non_blocking=True) for t in sample_batch
+ ]
+
# Simple profiling of forward and backward pass
with torch.autograd.profiler.profile(use_cuda=True) as prof:
# Create masks
- mask1 = create_fc_mask1_gpu(e_sample, t_disc_sample, model.num_Event, model.num_Category, device)
- mask2 = create_fc_mask2_gpu(t_disc_sample, model.num_Category, device)
-
+ mask1 = create_fc_mask1_gpu(
+ e_sample,
+ t_disc_sample,
+ model.num_Event,
+ model.num_Category,
+ device,
+ )
+ mask2 = create_fc_mask2_gpu(
+ t_disc_sample, model.num_Category, device
+ )
+
# Forward pass
out, _ = model(x_sample)
-
+
# Compute loss
- loss = model.compute_loss(out, t_sample, e_sample, mask1, mask2, args.alpha, args.beta, args.gamma)
-
+ loss = model.compute_loss(
+ out,
+ t_sample,
+ e_sample,
+ mask1,
+ mask2,
+ args.alpha,
+ args.beta,
+ args.gamma,
+ )
+
# Backward pass
loss.backward()
-
+
# Print profiling results
- profile_sorted = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)
+ profile_sorted = prof.key_averages().table(
+ sort_by="cuda_time_total", row_limit=10
+ )
print("Profile of most expensive CUDA operations:")
print(profile_sorted)
-
+
# Identify bottlenecks
cpu_pct = prof.self_cpu_time_total / prof.total_cuda_time_total * 100
if cpu_pct > 20:
- print(f"WARNING: High CPU overhead ({cpu_pct:.1f}%). Consider optimizing CPU-GPU transfers.")
-
+ print(
+ f"WARNING: High CPU overhead ({cpu_pct:.1f}%). Consider optimizing CPU-GPU transfers."
+ )
+
except Exception as e:
print(f"Profiling skipped due to error: {e}")
-
+
# Train the model with optimizations
train_deephit_model(
- model,
- train_loader,
+ model,
+ train_loader,
val_loader,
alpha=args.alpha,
beta=args.beta,
@@ -785,34 +948,34 @@ def main():
patience=args.patience,
eval_freq=args.eval_freq,
use_amp=args.use_amp,
- verbose=True
+ verbose=True,
)
-
+
# Evaluate model (less frequently during training)
fold_metrics = evaluate_model(
- model,
- x[val_idx],
- t[val_idx],
+ model,
+ x[val_idx],
+ t[val_idx],
e[val_idx],
- t[train_idx],
- e[train_idx],
- eval_times,
+ t[train_idx],
+ e[train_idx],
+ eval_times,
max_time,
- device
+ device,
)
-
+
# Collect metrics
for k, v in fold_metrics.items():
all_metrics[k].extend(v)
-
+
# Free memory after each fold
if device == "cuda":
del model
torch.cuda.empty_cache()
-
+
# Display final metrics
display_metrics_table(all_metrics, n_folds=args.n_folds)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/training_scripts/train_nested_cv.py b/training_scripts/train_nested_cv.py
index 60ee0330..c277b556 100644
--- a/training_scripts/train_nested_cv.py
+++ b/training_scripts/train_nested_cv.py
@@ -1,86 +1,130 @@
-import yaml
-import configargparse
from collections import defaultdict
-import optuna
-import torch
+import configargparse
+import matplotlib.pyplot as plt
import numpy as np
+import optuna
import pandas as pd
-import torch.optim as optim
-import matplotlib.pyplot as plt
+import torch
+import yaml
+from model_utils import EarlyStopping, set_seed
+from sklearn.model_selection import StratifiedKFold
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
+from sksurv.metrics import concordance_index_ipcw
+from sksurv.util import Surv
from tabulate import tabulate
+from torch import optim
from torch.utils.data import DataLoader
-from sklearn.model_selection import StratifiedKFold, train_test_split
-from sklearn.preprocessing import StandardScaler, MinMaxScaler
-from sksurv.util import Surv
-from sksurv.metrics import concordance_index_ipcw
+from crisp_nam.metrics import auc_td, brier_score
from crisp_nam.models import CrispNamModel
from crisp_nam.utils import (
- weighted_negative_log_likelihood_loss,
+ compute_baseline_cif,
+ compute_l2_penalty,
negative_log_likelihood_loss,
- compute_l2_penalty
+ plot_coxnam_shape_functions,
+ plot_feature_importance,
+ predict_absolute_risk,
+ weighted_negative_log_likelihood_loss,
)
from data_utils import *
-from model_utils import EarlyStopping, set_seed
-from crisp_nam.metrics import brier_score, auc_td
-from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif
-from crisp_nam.utils import plot_coxnam_shape_functions, plot_feature_importance
def parse_args():
parser = configargparse.ArgumentParser(
description="Nested Cross-Validation for CrispNAM model",
default_config_files=["config.yaml"],
- config_file_parser_class=configargparse.YAMLConfigFileParser
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
)
-
- parser.add_argument("-c", "--config", is_config_file=True,
- help="Path to config file")
- parser.add_argument("--dataset", type=str, default="framingham",
- choices=["framingham", "support", "pbc", "synthetic"],
- help="Dataset to use")
+ parser.add_argument(
+ "-c", "--config", is_config_file=True, help="Path to config file"
+ )
- parser.add_argument("--scaling", type=str, default="standard",
- choices=["minmax", "standard", "none"],
- help="Data scaling method for continuous features")
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ choices=["framingham", "support", "pbc", "synthetic"],
+ help="Dataset to use",
+ )
- parser.add_argument("--num_epochs", type=int, default=250,
- help="Number of training epochs (reduced for nested CV)")
- parser.add_argument("--batch_size", type=int, default=512,
- help="Batch size for training")
- parser.add_argument("--patience", type=int, default=10,
- help="Patience for early stopping")
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ help="Data scaling method for continuous features",
+ )
+
+ parser.add_argument(
+ "--num_epochs",
+ type=int,
+ default=250,
+ help="Number of training epochs (reduced for nested CV)",
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=512, help="Batch size for training"
+ )
+ parser.add_argument(
+ "--patience", type=int, default=10, help="Patience for early stopping"
+ )
# Nested CV parameters
- parser.add_argument("--outer_folds", type=int, default=5,
- help="Number of outer CV folds")
- parser.add_argument("--inner_folds", type=int, default=3,
- help="Number of inner CV folds for hyperparameter tuning")
- parser.add_argument("--n_trials", type=int, default=20,
- help="Number of Optuna trials per inner fold")
+ parser.add_argument(
+ "--outer_folds", type=int, default=2, help="Number of outer CV folds"
+ )
+ parser.add_argument(
+ "--inner_folds",
+ type=int,
+ default=2,
+ help="Number of inner CV folds for hyperparameter tuning",
+ )
+ parser.add_argument(
+ "--n_trials",
+ type=int,
+ default=10,
+ help="Number of Optuna trials per inner fold",
+ )
# Event weighting
- parser.add_argument("--event_weighting", type=str, default="none",
- choices=["none", "balanced", "custom"],
- help="Event weighting strategy")
- parser.add_argument("--custom_event_weights", type=str, default=None,
- help="Custom weights for events (comma-separated)")
-
- parser.add_argument("--seed", type=int, default=42,
- help="Random seed for reproducibility")
-
+ parser.add_argument(
+ "--event_weighting",
+ type=str,
+ default="none",
+ choices=["none", "balanced", "custom"],
+ help="Event weighting strategy",
+ )
+ parser.add_argument(
+ "--custom_event_weights",
+ type=str,
+ default=None,
+ help="Custom weights for events (comma-separated)",
+ )
+
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for reproducibility"
+ )
+
return parser.parse_args()
-def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_rate=1e-3,
- l2_reg=0.01, patience=10, event_weights=None, verbose=False):
+def train_model(
+ model,
+ train_loader,
+ val_loader=None,
+ num_epochs=100,
+ learning_rate=1e-3,
+ l2_reg=0.01,
+ patience=10,
+ event_weights=None,
+ verbose=False,
+):
"""Train model with early stopping"""
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
early_stopper = EarlyStopping(patience=patience)
device = next(model.parameters()).device
- best_val_loss = float('inf')
+ best_val_loss = float("inf")
for epoch in range(num_epochs):
model.train()
@@ -90,12 +134,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r
risk_scores, _ = model(x)
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
total = loss + reg
@@ -113,25 +163,32 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r
for x, t, e, _ in val_loader:
x, t, e = x.to(device), t.to(device), e.to(device)
risk_scores, _ = model(x)
-
+
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
val_loss += (loss + reg).item()
avg_val_loss = val_loss / len(val_loader)
-
- if avg_val_loss < best_val_loss:
- best_val_loss = avg_val_loss
-
+
+ best_val_loss = min(best_val_loss, avg_val_loss)
+
early_stopper.step(avg_val_loss)
if verbose and epoch % 10 == 0:
- print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
+ print(
+ f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
+ )
if early_stopper.should_stop:
if verbose:
@@ -139,37 +196,51 @@ def train_model(model, train_loader, val_loader=None, num_epochs=100, learning_r
break
elif verbose and epoch % 10 == 0:
print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f}")
-
+
return best_val_loss
-def hyperparameter_optimization(x_train_inner, t_train_inner, e_train_inner,
- x_val_inner, t_val_inner, e_val_inner,
- num_competing_risks, device, args, event_weights, n_cont):
+def hyperparameter_optimization(
+ x_train_inner,
+ t_train_inner,
+ e_train_inner,
+ x_val_inner,
+ t_val_inner,
+ e_val_inner,
+ num_competing_risks,
+ device,
+ args,
+ event_weights,
+ n_cont,
+):
"""Run Optuna hyperparameter optimization on inner training data"""
-
+
def objective(trial):
# Define hyperparameters to optimize
- learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
- l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True)
- dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.8)
- feature_dropout = trial.suggest_float('feature_dropout', 0.0, 0.5)
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
+ l2_reg = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True)
+ dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.8)
+ feature_dropout = trial.suggest_float("feature_dropout", 0.0, 0.5)
# Hidden dimensions
- n_layers = trial.suggest_int('n_layers', 1, 3)
+ n_layers = trial.suggest_int("n_layers", 1, 3)
hidden_dimensions = []
for i in range(n_layers):
- hidden_dimensions.append(trial.suggest_categorical(f'hidden_dim_{i}', [8,16, 32, 64, 128]))
-
- batch_norm = trial.suggest_categorical('batch_norm', [True, False])
-
+ hidden_dimensions.append(
+ trial.suggest_categorical(f"hidden_dim_{i}", [8, 16, 32, 64, 128])
+ )
+
+ batch_norm = trial.suggest_categorical("batch_norm", [True, False])
+
# Create data loaders
train_dataset = SurvivalDataset(x_train_inner, t_train_inner, e_train_inner)
val_dataset = SurvivalDataset(x_val_inner, t_val_inner, e_val_inner)
-
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True
+ )
val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
-
+
# Initialize model
model = CrispNamModel(
num_features=x_train_inner.shape[1],
@@ -177,26 +248,28 @@ def objective(trial):
hidden_sizes=hidden_dimensions,
dropout_rate=dropout_rate,
feature_dropout=feature_dropout,
- batch_norm=batch_norm
+ batch_norm=batch_norm,
).to(device)
-
+
# Train model
best_val_loss = train_model(
- model, train_loader, val_loader,
- num_epochs=args.num_epochs,
+ model,
+ train_loader,
+ val_loader,
+ num_epochs=args.num_epochs,
learning_rate=learning_rate,
- l2_reg=l2_reg,
+ l2_reg=l2_reg,
patience=args.patience,
event_weights=event_weights,
- verbose=False
+ verbose=False,
)
-
+
return best_val_loss
-
+
# Create study and optimize
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=args.n_trials, show_progress_bar=False)
-
+
return study.best_params
@@ -211,29 +284,41 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
for k in range(n_events):
for i, time in enumerate(times):
risk_preds = abs_risks[:, k, i]
-
+
try:
risk_preds_2d = np.zeros((len(risk_preds), len(times)))
- risk_preds_2d[:, i] = risk_preds
-
+ risk_preds_2d[:, i] = risk_preds
+
auc_score, _ = auc_td(
- e_val, t_val, risk_preds_2d, times, time,
- km=(e_train, t_train), primary_risk=k+1
+ e_val,
+ t_val,
+ risk_preds_2d,
+ times,
+ time,
+ km=(e_train, t_train),
+ primary_risk=k + 1,
)
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score))
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score))
except Exception as ex:
- print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan)
try:
brier_score_val, _ = brier_score(
- e_val, t_val, risk_preds_2d, times, time,
- km=(e_train, t_train), primary_risk=k+1
+ e_val,
+ t_val,
+ risk_preds_2d,
+ times,
+ time,
+ km=(e_train, t_train),
+ primary_risk=k + 1,
+ )
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(
+ float(brier_score_val)
)
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_score_val))
except Exception as ex:
- print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan)
try:
tdci_result = concordance_index_ipcw(
@@ -242,112 +327,120 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
if isinstance(tdci_result, tuple):
tdci_score = tdci_result[0]
else:
- tdci_score = tdci_result
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score))
+ tdci_score = tdci_result
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score))
except Exception as ex:
- print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan)
return metrics
-def display_metrics_table(metrics_dict, quantiles=[0.25, 0.5, 0.75], dataset_name="dataset"):
+def display_metrics_table(
+ metrics_dict, quantiles=[0.25, 0.5, 0.75], dataset_name="dataset"
+):
"""Display evaluation metrics table and return processed data"""
time_points_by_metric = {}
event_types = set()
- metric_types = ['auc', 'tdci', 'brier']
-
+ metric_types = ["auc", "tdci", "brier"]
+
for key, values in metrics_dict.items():
if any(metric in key for metric in metric_types):
- parts = key.split('_')
+ parts = key.split("_")
metric_type = parts[0]
event_info = parts[1]
-
- event_type = int(event_info.replace('event', ''))
+
+ event_type = int(event_info.replace("event", ""))
event_types.add(event_type)
-
- if len(parts) > 2 and parts[2].startswith('t'):
- time_point = float(parts[2].replace('t', ''))
-
+
+ if len(parts) > 2 and parts[2].startswith("t"):
+ time_point = float(parts[2].replace("t", ""))
+
if (event_type, metric_type) not in time_points_by_metric:
time_points_by_metric[(event_type, metric_type)] = {}
-
+
time_points_by_metric[(event_type, metric_type)][time_point] = values
-
+
# Create summary table
summary_results = []
detailed_results = []
-
+
for event_type in sorted(event_types):
- row = {'Risk': f"Type {event_type}"}
-
+ row = {"Risk": f"Type {event_type}"}
+
for metric_type in metric_types:
if (event_type, metric_type) not in time_points_by_metric:
for q in quantiles:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
continue
-
+
time_data = time_points_by_metric[(event_type, metric_type)]
sorted_times = sorted(time_data.keys())
-
+
for q in quantiles:
q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q)))
q_time = sorted_times[q_idx]
q_values = time_data[q_time]
-
+
if q_values:
value_array = np.array(q_values)
mean_val = np.nanmean(value_array)
std_val = np.nanstd(value_array)
- row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}"
-
+ row[f"{metric_type.upper()}_q{q:.2f}"] = (
+ f"{mean_val:.3f} ± {std_val:.3f}"
+ )
+
# Add to detailed results for separate CSV
- detailed_results.append({
- 'Dataset': dataset_name,
- 'Risk_Type': event_type,
- 'Metric': metric_type.upper(),
- 'Time_Quantile': f"q{q:.2f}",
- 'Time_Value': q_time,
- 'Mean': mean_val,
- 'Std': std_val,
- 'N_Folds': len(q_values),
- 'Raw_Values': ';'.join(map(str, q_values))
- })
+ detailed_results.append(
+ {
+ "Dataset": dataset_name,
+ "Risk_Type": event_type,
+ "Metric": metric_type.upper(),
+ "Time_Quantile": f"q{q:.2f}",
+ "Time_Value": q_time,
+ "Mean": mean_val,
+ "Std": std_val,
+ "N_Folds": len(q_values),
+ "Raw_Values": ";".join(map(str, q_values)),
+ }
+ )
else:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
-
+
summary_results.append(row)
-
+
# Create DataFrames
summary_df = pd.DataFrame(summary_results)
detailed_df = pd.DataFrame(detailed_results)
-
- columns = ['Risk']
- for metric in ['AUC', 'TDCI', 'BRIER']:
+
+ columns = ["Risk"]
+ for metric in ["AUC", "TDCI", "BRIER"]:
for q in quantiles:
columns.append(f"{metric}_q{q:.2f}")
-
+
summary_df = summary_df[[col for col in columns if col in summary_df.columns]]
-
+
print("\nNested CV Performance Metrics:")
- print(tabulate(summary_df, headers='keys', tablefmt='pretty', showindex=False))
-
+ print(tabulate(summary_df, headers="keys", tablefmt="pretty", showindex=False))
+
print("\nInterpretation:")
print("- AUC: 0.5=random, >0.7=good, >0.8=excellent")
- print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
+ print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor")
-
+
return summary_df, detailed_df
def main():
args = parse_args()
- print(f"Running Nested Cross-Validation with {args.outer_folds} outer folds and {args.inner_folds} inner folds")
+ print(
+ f"Running Nested Cross-Validation with {args.outer_folds} outer folds and {args.inner_folds} inner folds"
+ )
print(args)
-
+
# Set random seed
set_seed(args.seed)
-
+
# Load dataset
if args.dataset.lower() == "framingham":
x, t, e, feature_names, n_cont, _ = load_framingham()
@@ -359,142 +452,193 @@ def main():
x, t, e, feature_names, n_cont, _ = load_synthetic_dataset()
else:
raise ValueError(f"Dataset {args.dataset} not supported")
-
+
num_competing_risks = len(np.unique(e)) - 1
- device = ("cuda" if torch.cuda.is_available() else "cpu")
-
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
# Outer cross-validation loop
- outer_cv = StratifiedKFold(n_splits=args.outer_folds, shuffle=True, random_state=args.seed)
+ outer_cv = StratifiedKFold(
+ n_splits=args.outer_folds, shuffle=True, random_state=args.seed
+ )
all_metrics = defaultdict(list)
all_best_params = []
-
+
quantiles = [0.25, 0.5, 0.75]
safe_max = 0.99 * np.max(t)
eval_times = np.quantile(t[t <= safe_max], quantiles)
-
+
print(f"Evaluation times: {eval_times}")
-
+
for outer_fold, (train_idx, test_idx) in enumerate(outer_cv.split(x, e)):
print(f"\n=== Outer Fold {outer_fold + 1}/{args.outer_folds} ===")
-
+
# Split data for outer fold
x_train_outer, x_test_outer = x[train_idx].copy(), x[test_idx].copy()
t_train_outer, t_test_outer = t[train_idx], t[test_idx]
e_train_outer, e_test_outer = e[train_idx], e[test_idx]
-
+
# Apply scaling within outer fold
if args.scaling.lower() == "standard":
scaler = StandardScaler()
- x_train_outer[:, -n_cont:] = scaler.fit_transform(x_train_outer[:, -n_cont:])
+ x_train_outer[:, -n_cont:] = scaler.fit_transform(
+ x_train_outer[:, -n_cont:]
+ )
x_test_outer[:, -n_cont:] = scaler.transform(x_test_outer[:, -n_cont:])
elif args.scaling.lower() == "minmax":
scaler = MinMaxScaler()
- x_train_outer[:, -n_cont:] = scaler.fit_transform(x_train_outer[:, -n_cont:])
+ x_train_outer[:, -n_cont:] = scaler.fit_transform(
+ x_train_outer[:, -n_cont:]
+ )
x_test_outer[:, -n_cont:] = scaler.transform(x_test_outer[:, -n_cont:])
-
+
# Inner cross-validation for hyperparameter tuning
- inner_cv = StratifiedKFold(n_splits=args.inner_folds, shuffle=True, random_state=args.seed)
+ inner_cv = StratifiedKFold(
+ n_splits=args.inner_folds, shuffle=True, random_state=args.seed
+ )
inner_scores = []
inner_params = []
-
- for inner_fold, (train_inner_idx, val_inner_idx) in enumerate(inner_cv.split(x_train_outer, e_train_outer)):
+
+ for inner_fold, (train_inner_idx, val_inner_idx) in enumerate(
+ inner_cv.split(x_train_outer, e_train_outer)
+ ):
print(f" Inner Fold {inner_fold + 1}/{args.inner_folds}")
-
+
# Split inner training data
x_train_inner = x_train_outer[train_inner_idx].copy()
x_val_inner = x_train_outer[val_inner_idx].copy()
- t_train_inner, t_val_inner = t_train_outer[train_inner_idx], t_train_outer[val_inner_idx]
- e_train_inner, e_val_inner = e_train_outer[train_inner_idx], e_train_outer[val_inner_idx]
-
+ t_train_inner, t_val_inner = (
+ t_train_outer[train_inner_idx],
+ t_train_outer[val_inner_idx],
+ )
+ e_train_inner, e_val_inner = (
+ e_train_outer[train_inner_idx],
+ e_train_outer[val_inner_idx],
+ )
+
# Calculate event weights for this inner fold
event_weights = None
if args.event_weighting == "balanced":
event_counts = np.zeros(num_competing_risks)
for k in range(1, num_competing_risks + 1):
- event_counts[k-1] = np.sum(e_train_inner == k)
+ event_counts[k - 1] = np.sum(e_train_inner == k)
event_counts = np.maximum(event_counts, 1)
event_weights = 1.0 / event_counts
- event_weights = event_weights * (num_competing_risks / event_weights.sum())
- event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device)
-
+ event_weights = event_weights * (
+ num_competing_risks / event_weights.sum()
+ )
+ event_weights = torch.tensor(
+ event_weights, dtype=torch.float32, device=device
+ )
+
# Hyperparameter optimization
best_params = hyperparameter_optimization(
- x_train_inner, t_train_inner, e_train_inner,
- x_val_inner, t_val_inner, e_val_inner,
- num_competing_risks, device, args, event_weights, n_cont
+ x_train_inner,
+ t_train_inner,
+ e_train_inner,
+ x_val_inner,
+ t_val_inner,
+ e_val_inner,
+ num_competing_risks,
+ device,
+ args,
+ event_weights,
+ n_cont,
)
-
+
inner_params.append(best_params)
-
+
# Select best hyperparameters (could use ensemble or average, here we take most common)
# For simplicity, we'll use the first inner fold's best params
best_params = inner_params[0]
all_best_params.append(best_params)
-
- print(f" Selected hyperparameters for outer fold {outer_fold + 1}: {best_params}")
-
+
+ print(
+ f" Selected hyperparameters for outer fold {outer_fold + 1}: {best_params}"
+ )
+
# Train final model on all outer training data with best hyperparameters
- n_layers = best_params['n_layers']
- hidden_dimensions = [best_params[f'hidden_dim_{i}'] for i in range(n_layers)]
-
+ n_layers = best_params["n_layers"]
+ hidden_dimensions = [best_params[f"hidden_dim_{i}"] for i in range(n_layers)]
+
# Calculate event weights for outer training data
event_weights = None
if args.event_weighting == "balanced":
event_counts = np.zeros(num_competing_risks)
for k in range(1, num_competing_risks + 1):
- event_counts[k-1] = np.sum(e_train_outer == k)
+ event_counts[k - 1] = np.sum(e_train_outer == k)
event_counts = np.maximum(event_counts, 1)
event_weights = 1.0 / event_counts
event_weights = event_weights * (num_competing_risks / event_weights.sum())
- event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device)
-
+ event_weights = torch.tensor(
+ event_weights, dtype=torch.float32, device=device
+ )
+
# Create final model
final_model = CrispNamModel(
num_features=x.shape[1],
num_competing_risks=num_competing_risks,
hidden_sizes=hidden_dimensions,
- dropout_rate=best_params['dropout_rate'],
- feature_dropout=best_params['feature_dropout'],
- batch_norm=best_params['batch_norm']
+ dropout_rate=best_params["dropout_rate"],
+ feature_dropout=best_params["feature_dropout"],
+ batch_norm=best_params["batch_norm"],
).to(device)
-
+
# Train final model
train_dataset = SurvivalDataset(x_train_outer, t_train_outer, e_train_outer)
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
-
+ train_loader = DataLoader(
+ train_dataset, batch_size=args.batch_size, shuffle=True
+ )
+
train_model(
- final_model, train_loader, None,
- num_epochs=args.num_epochs,
- learning_rate=best_params['learning_rate'],
- l2_reg=best_params['l2_reg'],
+ final_model,
+ train_loader,
+ None,
+ num_epochs=args.num_epochs,
+ learning_rate=best_params["learning_rate"],
+ l2_reg=best_params["l2_reg"],
patience=args.patience,
event_weights=event_weights,
- verbose=True
+ verbose=True,
)
-
+
# Evaluate on test data
- baseline_cifs = {k: compute_baseline_cif(t_train_outer, e_train_outer, eval_times, k + 1)
- for k in range(num_competing_risks)}
-
- abs_risks = predict_absolute_risk(final_model, x_test_outer, baseline_cifs, eval_times, device=device)
-
- fold_metrics = evaluate_model(final_model, x_test_outer, t_test_outer, e_test_outer,
- t_train_outer, e_train_outer, abs_risks, eval_times)
-
+ baseline_cifs = {
+ k: compute_baseline_cif(t_train_outer, e_train_outer, eval_times, k + 1)
+ for k in range(num_competing_risks)
+ }
+
+ abs_risks = predict_absolute_risk(
+ final_model, x_test_outer, baseline_cifs, eval_times, device=device
+ )
+
+ fold_metrics = evaluate_model(
+ final_model,
+ x_test_outer,
+ t_test_outer,
+ e_test_outer,
+ t_train_outer,
+ e_train_outer,
+ abs_risks,
+ eval_times,
+ )
+
for k, v in fold_metrics.items():
all_metrics[k].extend(v)
-
+
# Display final results and get DataFrames for saving
- summary_df, detailed_df = display_metrics_table(all_metrics, dataset_name=args.dataset)
-
+ summary_df, detailed_df = display_metrics_table(
+ all_metrics, dataset_name=args.dataset
+ )
+
# Generate shape plots using the last trained model (from the last outer fold)
print("\n=== Generating Shape Function Plots ===")
-
+
# Create figs subdirectory if not present
import os
+
if not os.path.exists("figs"):
os.makedirs("figs")
-
+
# Use the entire dataset for plotting (with proper scaling)
x_plot = x.copy()
if args.scaling.lower() == "standard":
@@ -503,11 +647,11 @@ def main():
elif args.scaling.lower() == "minmax":
scaler = MinMaxScaler()
x_plot[:, -n_cont:] = scaler.fit_transform(x_plot[:, -n_cont:])
-
+
# Generate plots for each risk type
for risk in range(1, num_competing_risks + 1):
print(f"Generating plots for risk type {risk}")
-
+
# Generate feature importance plot
fig, _, top_positive, top_negative = plot_feature_importance(
model=final_model,
@@ -515,131 +659,150 @@ def main():
feature_names=feature_names,
n_top=5, # Show top 5 positive contributors
n_bottom=5, # Show top 5 negative contributors
- risk_idx=risk,
+ risk_idx=risk,
figsize=(6, 4),
- output_file=f"results/plots/nested_cv_feature_importance_risk_{risk}_{args.dataset}.png"
+ output_file=f"results/plots/nested_cv_feature_importance_risk_{risk}_{args.dataset}.png",
)
-
+
# Get top features for shape function plots
top_features = top_positive + top_negative
-
+
# Generate shape function plots for top features
+ print("Ananya: Generating shape function plots for top features...")
+ print(f'{feature_names}')
+ print(f'{top_features}')
fig, _ = plot_coxnam_shape_functions(
model=final_model,
- X=x_plot,
+ X=x_plot,
risk_to_plot=risk,
- feature_names=feature_names,
- top_features=top_features,
+ feature_names=feature_names,
+ top_features=top_features,
ncols=5,
figsize=(12, 6),
- output_file=f"results/plots/nested_cv_shape_functions_risk_{risk}_{args.dataset}.png"
+ output_file=f"results/plots/nested_cv_shape_functions_risk_{risk}_{args.dataset}.png",
)
plt.close(fig)
-
- print(f"Shape function plots saved to results/plots/ directory")
-
+
+ print("Shape function plots saved to results/plots/ directory")
+
# Save metrics to files
print("\n=== Saving Metrics to Files ===")
-
+
# Save summary metrics (formatted table)
summary_filename = f"nested_cv_summary_metrics_{args.dataset}.csv"
summary_df.to_csv(summary_filename, index=False)
print(f"Summary metrics saved to: {summary_filename}")
-
+
# Save detailed metrics (all individual fold results)
detailed_filename = f"nested_cv_detailed_metrics_{args.dataset}.csv"
detailed_df.to_csv(detailed_filename, index=False)
print(f"Detailed metrics saved to: {detailed_filename}")
-
+
# Save to Excel with multiple sheets
excel_filename = f"nested_cv_metrics_{args.dataset}.xlsx"
- with pd.ExcelWriter(excel_filename, engine='openpyxl') as writer:
- summary_df.to_excel(writer, sheet_name='Summary', index=False)
- detailed_df.to_excel(writer, sheet_name='Detailed', index=False)
-
+ with pd.ExcelWriter(excel_filename, engine="openpyxl") as writer:
+ summary_df.to_excel(writer, sheet_name="Summary", index=False)
+ detailed_df.to_excel(writer, sheet_name="Detailed", index=False)
+
# Create a metadata sheet
- metadata_df = pd.DataFrame([
- {'Parameter': 'Dataset', 'Value': args.dataset},
- {'Parameter': 'Outer Folds', 'Value': args.outer_folds},
- {'Parameter': 'Inner Folds', 'Value': args.inner_folds},
- {'Parameter': 'Number of Trials', 'Value': args.n_trials},
- {'Parameter': 'Number of Epochs', 'Value': args.num_epochs},
- {'Parameter': 'Batch Size', 'Value': args.batch_size},
- {'Parameter': 'Event Weighting', 'Value': args.event_weighting},
- {'Parameter': 'Scaling', 'Value': args.scaling},
- {'Parameter': 'Random Seed', 'Value': args.seed}
- ])
- metadata_df.to_excel(writer, sheet_name='Metadata', index=False)
-
+ metadata_df = pd.DataFrame(
+ [
+ {"Parameter": "Dataset", "Value": args.dataset},
+ {"Parameter": "Outer Folds", "Value": args.outer_folds},
+ {"Parameter": "Inner Folds", "Value": args.inner_folds},
+ {"Parameter": "Number of Trials", "Value": args.n_trials},
+ {"Parameter": "Number of Epochs", "Value": args.num_epochs},
+ {"Parameter": "Batch Size", "Value": args.batch_size},
+ {"Parameter": "Event Weighting", "Value": args.event_weighting},
+ {"Parameter": "Scaling", "Value": args.scaling},
+ {"Parameter": "Random Seed", "Value": args.seed},
+ ]
+ )
+ metadata_df.to_excel(writer, sheet_name="Metadata", index=False)
+
print(f"Excel file with multiple sheets saved to: {excel_filename}")
-
+
# Save raw metrics dictionary as JSON for complete reproducibility
import json
-
+
# Convert numpy arrays to lists for JSON serialization
serializable_metrics = {}
for key, values in all_metrics.items():
- serializable_metrics[key] = [float(v) if not np.isnan(v) else None for v in values]
-
+ serializable_metrics[key] = [
+ float(v) if not np.isnan(v) else None for v in values
+ ]
+
json_filename = f"nested_cv_raw_metrics_{args.dataset}.json"
- with open(json_filename, 'w') as f:
- json.dump({
- 'metrics': serializable_metrics,
- 'metadata': {
- 'dataset': args.dataset,
- 'outer_folds': args.outer_folds,
- 'inner_folds': args.inner_folds,
- 'n_trials': args.n_trials,
- 'num_epochs': args.num_epochs,
- 'batch_size': args.batch_size,
- 'event_weighting': args.event_weighting,
- 'scaling': args.scaling,
- 'seed': args.seed
- }
- }, f, indent=2)
-
+ with open(json_filename, "w") as f:
+ json.dump(
+ {
+ "metrics": serializable_metrics,
+ "metadata": {
+ "dataset": args.dataset,
+ "outer_folds": args.outer_folds,
+ "inner_folds": args.inner_folds,
+ "n_trials": args.n_trials,
+ "num_epochs": args.num_epochs,
+ "batch_size": args.batch_size,
+ "event_weighting": args.event_weighting,
+ "scaling": args.scaling,
+ "seed": args.seed,
+ },
+ },
+ f,
+ indent=2,
+ )
+
print(f"Raw metrics (JSON) saved to: {json_filename}")
-
+
# Save aggregated best parameters
param_summary = {}
for param_name in all_best_params[0].keys():
param_values = [params[param_name] for params in all_best_params]
-
+
# Check if all values are numeric
if all(isinstance(val, (int, float)) for val in param_values):
param_summary[param_name] = np.mean(param_values)
else:
# For categorical parameters, take the most common
param_summary[param_name] = max(set(param_values), key=param_values.count)
-
- print(f"\nAggregated best hyperparameters across all outer folds:")
+
+ print("\nAggregated best hyperparameters across all outer folds:")
for param, value in param_summary.items():
print(f"{param}: {value}")
-
+
# Save to YAML
config_dict = {
"dataset": args.dataset,
"scaling": args.scaling,
"num_epochs": args.num_epochs,
"batch_size": args.batch_size,
- "learning_rate": param_summary['learning_rate'],
- "l2_reg": param_summary['l2_reg'],
+ "learning_rate": param_summary["learning_rate"],
+ "l2_reg": param_summary["l2_reg"],
"patience": args.patience,
- "dropout_rate": param_summary['dropout_rate'],
- "feature_dropout": param_summary['feature_dropout'],
- "hidden_dimensions": ",".join(map(str, [param_summary[f'hidden_dim_{i}'] for i in range(int(param_summary['n_layers']))])),
- "batch_norm": str(param_summary['batch_norm']),
+ "dropout_rate": param_summary["dropout_rate"],
+ "feature_dropout": param_summary["feature_dropout"],
+ "hidden_dimensions": ",".join(
+ map(
+ str,
+ [
+ param_summary[f"hidden_dim_{i}"]
+ for i in range(int(param_summary["n_layers"]))
+ ],
+ )
+ ),
+ "batch_norm": str(param_summary["batch_norm"]),
"event_weighting": args.event_weighting,
"seed": args.seed,
- "n_folds": args.outer_folds
+ "n_folds": args.outer_folds,
}
-
+
output_file = f"nested_cv_best_params_{args.dataset}.yaml"
- with open(output_file, 'w') as file:
+ with open(output_file, "w") as file:
yaml.dump(config_dict, file, default_flow_style=False)
-
+
print(f"\nBest hyperparameters saved to {output_file}")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/training_scripts/tune_optuna.py b/training_scripts/tune_optuna.py
index 3711de49..b6494f4d 100644
--- a/training_scripts/tune_optuna.py
+++ b/training_scripts/tune_optuna.py
@@ -1,74 +1,108 @@
-import yaml
-import configargparse
from collections import defaultdict
+import configargparse
+import numpy as np
import optuna
import torch
-import numpy as np
-import torch.optim as optim
-from torch.utils.data import DataLoader
+import yaml
+from model_utils import EarlyStopping, set_seed
from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import StandardScaler, MinMaxScaler
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
+from sksurv.metrics import brier_score, concordance_index_ipcw, cumulative_dynamic_auc
from sksurv.util import Surv
-from sksurv.metrics import concordance_index_ipcw, cumulative_dynamic_auc, brier_score
+from torch import optim
+from torch.utils.data import DataLoader
from crisp_nam.models import CrispNamModel
from crisp_nam.utils import (
- weighted_negative_log_likelihood_loss,
+ compute_baseline_cif,
+ compute_l2_penalty,
negative_log_likelihood_loss,
- compute_l2_penalty
+ predict_absolute_risk,
+ weighted_negative_log_likelihood_loss,
)
from data_utils import *
-from model_utils import EarlyStopping, set_seed
-from crisp_nam.utils import predict_absolute_risk, compute_baseline_cif
+
def parse_args():
parser = configargparse.ArgumentParser(
description="Training script for MultiTaskCoxNAM model with Optuna",
default_config_files=["config.yaml"],
- config_file_parser_class=configargparse.YAMLConfigFileParser
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
)
-
-
+
# Dataset
- parser.add_argument("--dataset", type=str, default="framingham", choices=["framingham", "support", "pbc", "synthetic"],
- help="Dataset to use")
-
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ choices=["framingham", "support", "pbc", "synthetic"],
+ help="Dataset to use",
+ )
+
# Data scaling
- parser.add_argument("--scaling", type=str, default="standard", choices=["minmax", "standard", "none"],
- help="Data scaling method for continuous features")
-
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ help="Data scaling method for continuous features",
+ )
+
# Training parameters
- parser.add_argument("--num_epochs", type=int, default=500,
- help="Number of training epochs")
- parser.add_argument("--batch_size", type=int, default=256,
- help="Batch size for training")
- parser.add_argument("--patience", type=int, default=10,
- help="Patience for early stopping")
-
+ parser.add_argument(
+ "--num_epochs", type=int, default=500, help="Number of training epochs"
+ )
+ parser.add_argument(
+ "--batch_size", type=int, default=256, help="Batch size for training"
+ )
+ parser.add_argument(
+ "--patience", type=int, default=10, help="Patience for early stopping"
+ )
+
# Optuna parameters
- parser.add_argument("--n_trials", type=int, default=50,
- help="Number of Optuna trials")
-
+ parser.add_argument(
+ "--n_trials", type=int, default=50, help="Number of Optuna trials"
+ )
+
# Weight parameters
- parser.add_argument("--event_weighting", type=str, default="none",
- choices=["none", "balanced", "custom"],
- help="Event weighting strategy (none, balanced, custom)")
- parser.add_argument("--custom_event_weights", type=str, default=None,
- help="Custom weights for events (comma-separated, e.g., '1.0,2.0')")
-
+ parser.add_argument(
+ "--event_weighting",
+ type=str,
+ default="none",
+ choices=["none", "balanced", "custom"],
+ help="Event weighting strategy (none, balanced, custom)",
+ )
+ parser.add_argument(
+ "--custom_event_weights",
+ type=str,
+ default=None,
+ help="Custom weights for events (comma-separated, e.g., '1.0,2.0')",
+ )
+
# Other parameters
- parser.add_argument("--seed", type=int, default=42,
- help="Random seed for reproducibility")
-
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for reproducibility"
+ )
+
return parser.parse_args()
-def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_rate=1e-3,
- l2_reg=0.01, patience=10, event_weights=None, verbose=True):
+
+def train_model(
+ model,
+ train_loader,
+ val_loader=None,
+ num_epochs=500,
+ learning_rate=1e-3,
+ l2_reg=0.01,
+ patience=10,
+ event_weights=None,
+ verbose=True,
+):
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
early_stopper = EarlyStopping(patience=patience)
device = next(model.parameters()).device
- best_val_loss = float('inf')
+ best_val_loss = float("inf")
for epoch in range(num_epochs):
model.train()
@@ -79,12 +113,18 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r
# Use weighted loss if event_weights is provided
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
total = loss + reg
@@ -102,25 +142,32 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r
for x, t, e, _ in val_loader:
x, t, e = x.to(device), t.to(device), e.to(device)
risk_scores, _ = model(x)
-
+
# Use same loss function as in training
if event_weights is not None:
- loss = weighted_negative_log_likelihood_loss(risk_scores, t, e,
- model.num_competing_risks,
- event_weights=event_weights)
+ loss = weighted_negative_log_likelihood_loss(
+ risk_scores,
+ t,
+ e,
+ model.num_competing_risks,
+ event_weights=event_weights,
+ )
else:
- loss = negative_log_likelihood_loss(risk_scores, t, e, model.num_competing_risks)
-
+ loss = negative_log_likelihood_loss(
+ risk_scores, t, e, model.num_competing_risks
+ )
+
reg = compute_l2_penalty(model) * l2_reg
val_loss += (loss + reg).item()
avg_val_loss = val_loss / len(val_loader)
current_best = early_stopper.step(avg_val_loss)
-
- if avg_val_loss < best_val_loss:
- best_val_loss = avg_val_loss
+
+ best_val_loss = min(best_val_loss, avg_val_loss)
if verbose:
- print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
+ print(
+ f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}"
+ )
if early_stopper.should_stop:
if verbose:
@@ -128,7 +175,7 @@ def train_model(model, train_loader, val_loader=None, num_epochs=500, learning_r
break
elif verbose:
print(f"Epoch {epoch + 1} | Train Loss: {avg_train_loss:.4f}")
-
+
return best_val_loss
@@ -151,17 +198,14 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
# ---- AUC ----
try:
auc_score, _ = cumulative_dynamic_auc(
- survival_train,
- survival_val,
- risk_preds,
- times=[time]
+ survival_train, survival_val, risk_preds, times=[time]
)
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(float(auc_score[0]))
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(float(auc_score[0]))
avg_auc += float(auc_score[0])
count += 1
except Exception as ex:
- print(f"[Warning] AUC failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"auc_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] AUC failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"auc_event{k + 1}_t{time:.2f}"].append(np.nan)
# ---- Brier Score ----
try:
@@ -169,153 +213,176 @@ def evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, time
surv_probs = surv_probs.reshape(-1, 1) # Shape: (n_samples, n_times)
_, brier_scores = brier_score(
- survival_train,
- survival_val,
- surv_probs,
- times=np.array([time])
+ survival_train, survival_val, surv_probs, times=np.array([time])
+ )
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(
+ float(brier_scores[0])
)
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(float(brier_scores[0]))
except Exception as ex:
- print(f"[Warning] Brier failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"brier_event{k+1}_t{time:.2f}"].append(np.nan)
+ print(f"[Warning] Brier failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"brier_event{k + 1}_t{time:.2f}"].append(np.nan)
# ---- Time-dependent Concordance Index ----
try:
tdci_result = concordance_index_ipcw(
- survival_train,
- survival_val,
- estimate=risk_preds,
- tau=time
+ survival_train, survival_val, estimate=risk_preds, tau=time
)
if isinstance(tdci_result, tuple):
tdci_score = tdci_result[0]
else:
tdci_score = tdci_result # fallback
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(float(tdci_score))
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(float(tdci_score))
except Exception as ex:
- print(f"[Warning] td-CI failed at t={time:.2f}, event={k+1}: {ex}")
- metrics[f"tdci_event{k+1}_t{time:.2f}"].append(np.nan)
-
+ print(f"[Warning] td-CI failed at t={time:.2f}, event={k + 1}: {ex}")
+ metrics[f"tdci_event{k + 1}_t{time:.2f}"].append(np.nan)
+
# Calculate average AUC if we have valid measurements
if count > 0:
avg_auc = avg_auc / count
else:
avg_auc = 0.0
-
+
return metrics, avg_auc
def display_metrics_table(metrics_dict, quantiles=[0.25, 0.5, 0.75]):
import pandas as pd
from tabulate import tabulate
-
+
time_points_by_metric = {}
event_types = set()
- metric_types = ['auc', 'tdci', 'brier']
-
+ metric_types = ["auc", "tdci", "brier"]
+
for key, values in metrics_dict.items():
if any(metric in key for metric in metric_types):
- parts = key.split('_')
+ parts = key.split("_")
metric_type = parts[0]
event_info = parts[1]
-
+
# Extract event type
- event_type = int(event_info.replace('event', ''))
+ event_type = int(event_info.replace("event", ""))
event_types.add(event_type)
-
+
# Extract time point
- if len(parts) > 2 and parts[2].startswith('t'):
- time_point = float(parts[2].replace('t', ''))
-
+ if len(parts) > 2 and parts[2].startswith("t"):
+ time_point = float(parts[2].replace("t", ""))
+
# Initialize nested dictionaries if needed
if (event_type, metric_type) not in time_points_by_metric:
time_points_by_metric[(event_type, metric_type)] = {}
-
+
# Store metrics by time point
time_points_by_metric[(event_type, metric_type)][time_point] = values
-
+
results = []
-
+
for event_type in sorted(event_types):
- row = {'Risk': f"Type {event_type}"}
-
+ row = {"Risk": f"Type {event_type}"}
+
for metric_type in metric_types:
if (event_type, metric_type) not in time_points_by_metric:
# Skip metrics that don't exist for this event type
for q in quantiles:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
continue
-
+
# Get all time points for this event/metric
time_data = time_points_by_metric[(event_type, metric_type)]
sorted_times = sorted(time_data.keys())
-
+
# For each quantile
for q in quantiles:
# Calculate the index for this quantile
q_idx = max(0, min(len(sorted_times) - 1, int(len(sorted_times) * q)))
-
+
# Get the corresponding time point for this quantile
q_time = sorted_times[q_idx]
-
+
# Get the metrics for this time point
q_values = time_data[q_time]
-
+
# Calculate and format statistics
if q_values:
value_array = np.array(q_values)
mean_val = np.nanmean(value_array)
std_val = np.nanstd(value_array)
- row[f"{metric_type.upper()}_q{q:.2f}"] = f"{mean_val:.3f} ± {std_val:.3f}"
+ row[f"{metric_type.upper()}_q{q:.2f}"] = (
+ f"{mean_val:.3f} ± {std_val:.3f}"
+ )
else:
row[f"{metric_type.upper()}_q{q:.2f}"] = "N/A"
-
+
results.append(row)
-
+
df = pd.DataFrame(results)
-
+
# Define column order
- columns = ['Risk']
- for metric in ['AUC', 'TDCI', 'BRIER']:
+ columns = ["Risk"]
+ for metric in ["AUC", "TDCI", "BRIER"]:
for q in quantiles:
columns.append(f"{metric}_q{q:.2f}")
-
+
# Select columns in the right order (only those that exist)
df = df[[col for col in columns if col in df.columns]]
-
+
print("\nSummary Performance Metrics:")
- print(tabulate(df, headers='keys', tablefmt='pretty', showindex=False))
-
+ print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False))
+
print("\nInterpretation:")
print("- AUC: 0.5=random, >0.7=good, >0.8=excellent")
- print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
+ print("- TDCI (Time-Dependent C-Index): 0.5=random, >0.7=good, >0.8=excellent")
print("- Brier Score: 0=perfect, <0.25=good, >0.25=poor")
-def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, feature_names, n_cont,
- num_competing_risks, device, args, event_weights=None):
-
+def objective(
+ trial,
+ x,
+ t,
+ e,
+ x_train,
+ t_train,
+ e_train,
+ x_val,
+ t_val,
+ e_val,
+ feature_names,
+ n_cont,
+ num_competing_risks,
+ device,
+ args,
+ event_weights=None,
+):
# Define hyperparameters to optimize
- learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
- l2_reg = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True)
- dropout_rate = trial.suggest_float('dropout_rate', 0.0, 0.8)
- feature_dropout = trial.suggest_float('feature_dropout', 0.0, 0.5)
-
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
+ l2_reg = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True)
+ dropout_rate = trial.suggest_float("dropout_rate", 0.0, 0.8)
+ feature_dropout = trial.suggest_float("feature_dropout", 0.0, 0.5)
+
# For hidden_dimensions
- n_layers = trial.suggest_int('n_layers', 1, 3)
+ n_layers = trial.suggest_int("n_layers", 1, 3)
hidden_dimensions = []
for i in range(n_layers):
- hidden_dimensions.append(trial.suggest_categorical(f'hidden_dim_{i}', [32, 64, 128, 256]))
-
- batch_norm = trial.suggest_categorical('batch_norm', [True, False])
-
+ hidden_dimensions.append(
+ trial.suggest_categorical(f"hidden_dim_{i}", [32, 64, 128, 256])
+ )
+
+ batch_norm = trial.suggest_categorical("batch_norm", [True, False])
+
# Create data loaders
train_dataset = SurvivalDataset(x_train, t_train, e_train)
val_dataset = SurvivalDataset(x_val, t_val, e_val)
-
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=True)
- val_loader = DataLoader(val_dataset, batch_size=args.batch_size,num_workers=4, pin_memory=True)
-
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ num_workers=4,
+ pin_memory=True,
+ shuffle=True,
+ )
+ val_loader = DataLoader(
+ val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True
+ )
+
# Initialize model with the hyperparameters
model = CrispNamModel(
num_features=x.shape[1],
@@ -323,40 +390,40 @@ def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, fe
hidden_sizes=hidden_dimensions,
dropout_rate=dropout_rate,
feature_dropout=feature_dropout,
- batch_norm=batch_norm
+ batch_norm=batch_norm,
).to(device)
-
+
# Train model
best_val_loss = train_model(
- model,
- train_loader,
- val_loader,
- num_epochs=args.num_epochs,
+ model,
+ train_loader,
+ val_loader,
+ num_epochs=args.num_epochs,
learning_rate=learning_rate,
- l2_reg=l2_reg,
+ l2_reg=l2_reg,
patience=args.patience,
event_weights=event_weights,
- verbose=False # Set to False to reduce output during Optuna trials
+ verbose=False, # Set to False to reduce output during Optuna trials
)
-
+
# We can also evaluate with metrics like AUC
# If needed, uncomment this code to optimize for AUC instead of validation loss
"""
# Calculate evaluation times
safe_max = 0.99 * np.max(t_train)
eval_times = np.quantile(t_train[t_train <= safe_max], [0.25, 0.5, 0.75])
-
+
# Calculate baseline CIFs
- baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
+ baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
for k in range(num_competing_risks)}
-
+
# Predict and evaluate
abs_risks = predict_absolute_risk(model, x_val, baseline_cifs, eval_times, device=device)
_, avg_auc = evaluate_model(model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times)
-
+
return avg_auc # Maximize AUC
"""
-
+
# For now, we'll optimize for validation loss (minimize)
return best_val_loss # Minimize validation loss
@@ -364,10 +431,10 @@ def objective(trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val, fe
def main():
args = parse_args()
print(args)
-
+
# Set random seed for reproducibility
set_seed(args.seed)
-
+
# Load the dataset
if args.dataset.lower() == "framingham":
x, t, e, feature_names, n_cont, _ = load_framingham()
@@ -379,21 +446,18 @@ def main():
x, t, e, feature_names, n_cont, _ = load_synthetic_dataset()
else:
raise ValueError(f"Dataset {args.dataset} not supported")
-
+
# Note: Scaling will be done after train/validation split to prevent data leakage
# Compute number of competing risks
num_competing_risks = len(np.unique(e)) - 1 # Excluding censoring (0)
- device = ("cuda" if torch.cuda.is_available() else "cpu")
-
-
+ device = "cuda" if torch.cuda.is_available() else "cpu"
-
# Split data into train and validation sets
x_train, x_val, t_train, t_val, e_train, e_val = train_test_split(
x, t, e, test_size=0.2, random_state=args.seed, stratify=e
)
-
+
# Apply scaling after split to prevent data leakage
if args.scaling.lower() == "standard":
scaler = StandardScaler()
@@ -407,148 +471,183 @@ def main():
pass
else:
raise ValueError(f"Scaling method {args.scaling} not supported")
-
+
# Initialize event weights
event_weights = None
-
+
if args.event_weighting != "none":
if args.event_weighting == "balanced":
# Compute balanced weights (inverse of class frequencies)
event_counts = np.zeros(num_competing_risks)
for k in range(1, num_competing_risks + 1):
- event_counts[k-1] = np.sum(e_train == k)
-
+ event_counts[k - 1] = np.sum(e_train == k)
+
# Avoid division by zero
event_counts = np.maximum(event_counts, 1)
-
+
# Inverse frequency weighting
event_weights = 1.0 / event_counts
-
+
# Normalize weights to sum to num_competing_risks
event_weights = event_weights * (num_competing_risks / event_weights.sum())
-
+
print(f"Computed balanced event weights: {event_weights}")
-
+
elif args.event_weighting == "custom":
if args.custom_event_weights is None:
- raise ValueError("Custom event weights must be provided when using custom weighting")
-
+ raise ValueError(
+ "Custom event weights must be provided when using custom weighting"
+ )
+
custom_weights = [float(w) for w in args.custom_event_weights.split(",")]
if len(custom_weights) != num_competing_risks:
- raise ValueError(f"Expected {num_competing_risks} weights, got {len(custom_weights)}")
-
+ raise ValueError(
+ f"Expected {num_competing_risks} weights, got {len(custom_weights)}"
+ )
+
event_weights = np.array(custom_weights)
print(f"Using custom event weights: {event_weights}")
-
+
# Convert to torch tensor
event_weights = torch.tensor(event_weights, dtype=torch.float32, device=device)
-
+
# Create Optuna study
study = optuna.create_study(direction="minimize") # Minimize validation loss
-
+
# Run optimization
print("\n=== Starting Optuna hyperparameter optimization ===")
- study.optimize(lambda trial: objective(
- trial, x, t, e, x_train, t_train, e_train, x_val, t_val, e_val,
- feature_names, n_cont, num_competing_risks, device, args, event_weights
- ), n_trials=args.n_trials)
-
+ study.optimize(
+ lambda trial: objective(
+ trial,
+ x,
+ t,
+ e,
+ x_train,
+ t_train,
+ e_train,
+ x_val,
+ t_val,
+ e_val,
+ feature_names,
+ n_cont,
+ num_competing_risks,
+ device,
+ args,
+ event_weights,
+ ),
+ n_trials=args.n_trials,
+ )
+
# Get best parameters
best_params = study.best_params
print("\n=== Best Hyperparameters ===")
for param, value in best_params.items():
print(f"{param}: {value}")
-
# Extract hidden dimensions from best params
- n_layers = best_params.pop('n_layers')
+ n_layers = best_params.pop("n_layers")
hidden_dimensions = []
for i in range(n_layers):
- hidden_dimensions.append(best_params.pop(f'hidden_dim_{i}'))
-
+ hidden_dimensions.append(best_params.pop(f"hidden_dim_{i}"))
+
# Save best hyperparameters to YAML file
config_dict = {
"dataset": args.dataset,
"scaling": args.scaling,
"num_epochs": args.num_epochs,
"batch_size": args.batch_size,
- "learning_rate": best_params['learning_rate'],
- "l2_reg": best_params['l2_reg'],
+ "learning_rate": best_params["learning_rate"],
+ "l2_reg": best_params["l2_reg"],
"patience": args.patience,
- "dropout_rate": best_params['dropout_rate'],
- "feature_dropout": best_params['feature_dropout'],
- "hidden_dimensions": ",".join(map(str, hidden_dimensions)), # Format as comma-separated string
- "batch_norm": str(best_params['batch_norm']), # Convert to string "True" or "False"
+ "dropout_rate": best_params["dropout_rate"],
+ "feature_dropout": best_params["feature_dropout"],
+ "hidden_dimensions": ",".join(
+ map(str, hidden_dimensions)
+ ), # Format as comma-separated string
+ "batch_norm": str(
+ best_params["batch_norm"]
+ ), # Convert to string "True" or "False"
"event_weighting": args.event_weighting,
"seed": args.seed,
- "n_folds": 5 # Default for k-fold script
+ "n_folds": 5, # Default for k-fold script
}
-
+
# Add custom event weights if applicable
if args.event_weighting == "custom" and args.custom_event_weights is not None:
config_dict["custom_event_weights"] = args.custom_event_weights
-
+
# Save to YAML file
output_file = f"best_params_{args.dataset}.yaml"
- with open(output_file, 'w') as file:
+ with open(output_file, "w") as file:
yaml.dump(config_dict, file, default_flow_style=False)
-
+
print(f"\nBest hyperparameters saved to {output_file}")
-
+
# Train final model with best parameters
print("\n=== Training final model with best parameters ===")
-
+
# Create data loaders with all training data
train_dataset = SurvivalDataset(x_train, t_train, e_train)
val_dataset = SurvivalDataset(x_val, t_val, e_val)
-
-
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=True)
- val_loader = DataLoader(val_dataset, batch_size=args.batch_size,num_workers=4, pin_memory=True)
-
-
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ num_workers=4,
+ pin_memory=True,
+ shuffle=True,
+ )
+ val_loader = DataLoader(
+ val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True
+ )
+
# Initialize final model with best parameters
final_model = CrispNamModel(
num_features=x.shape[1],
num_competing_risks=num_competing_risks,
hidden_sizes=hidden_dimensions,
- dropout_rate=best_params['dropout_rate'],
- feature_dropout=best_params['feature_dropout'],
- batch_norm=best_params['batch_norm']
+ dropout_rate=best_params["dropout_rate"],
+ feature_dropout=best_params["feature_dropout"],
+ batch_norm=best_params["batch_norm"],
).to(device)
-
+
# Train final model
train_model(
- final_model,
- train_loader,
- val_loader,
- num_epochs=args.num_epochs,
- learning_rate=best_params['learning_rate'],
- l2_reg=best_params['l2_reg'],
+ final_model,
+ train_loader,
+ val_loader,
+ num_epochs=args.num_epochs,
+ learning_rate=best_params["learning_rate"],
+ l2_reg=best_params["l2_reg"],
patience=args.patience,
event_weights=event_weights,
- verbose=True
+ verbose=True,
)
-
+
# Evaluate final model
# Calculate evaluation times
safe_max = 0.99 * np.max(t_train)
eval_times = np.quantile(t_train[t_train <= safe_max], [0.25, 0.5, 0.75])
-
+
print(f"Evaluation times: {eval_times}")
-
+
# Calculate baseline CIFs
- baseline_cifs = {k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
- for k in range(num_competing_risks)}
-
+ baseline_cifs = {
+ k: compute_baseline_cif(t_train, e_train, eval_times, k + 1)
+ for k in range(num_competing_risks)
+ }
+
# Predict and evaluate
- abs_risks = predict_absolute_risk(final_model, x_val, baseline_cifs, eval_times, device=device)
- final_metrics, _ = evaluate_model(final_model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times)
-
+ abs_risks = predict_absolute_risk(
+ final_model, x_val, baseline_cifs, eval_times, device=device
+ )
+ final_metrics, _ = evaluate_model(
+ final_model, x_val, t_val, e_val, t_train, e_train, abs_risks, eval_times
+ )
+
# Display metrics
display_metrics_table(final_metrics)
-
-
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/training_scripts/tune_optuna_optimized.py b/training_scripts/tune_optuna_optimized.py
index eb226813..7daabf4c 100644
--- a/training_scripts/tune_optuna_optimized.py
+++ b/training_scripts/tune_optuna_optimized.py
@@ -1,35 +1,29 @@
-import yaml
-import configargparse
from collections import defaultdict
-import torch
-import optuna
+import configargparse
import numpy as np
-import torch.nn as nn
+import optuna
+import torch
+import yaml
+from model_utils import EarlyStopping, set_seed
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
+from sksurv.metrics import brier_score, concordance_index_ipcw, cumulative_dynamic_auc
from sksurv.util import Surv
-import torch.optim as optim
from tabulate import tabulate
+from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import StandardScaler, MinMaxScaler
-from sksurv.metrics import (
- cumulative_dynamic_auc,
- brier_score,
- concordance_index_ipcw
-)
-from data_utils import *
-from model_utils import EarlyStopping, set_seed
from crisp_nam.models import CrispNamModel
from crisp_nam.utils import (
- weighted_negative_log_likelihood_loss,
- negative_log_likelihood_loss,
+ compute_baseline_cif,
compute_l2_penalty,
-)
-from crisp_nam.utils import (
+ negative_log_likelihood_loss,
predict_absolute_risk,
- compute_baseline_cif
+ weighted_negative_log_likelihood_loss,
)
+from data_utils import *
+
def train_model(
model: nn.Module,
@@ -45,7 +39,7 @@ def train_model(
device = next(model.parameters()).device
optimizer = optim.AdamW(model.parameters(), lr=lr)
early_stop = EarlyStopping(patience)
- scaler = torch.amp.GradScaler('cuda')
+ scaler = torch.amp.GradScaler("cuda")
for epoch in range(1, num_epochs + 1):
model.train()
@@ -56,14 +50,16 @@ def train_model(
eb = eb.to(device, non_blocking=True)
optimizer.zero_grad()
- with torch.amp.autocast('cuda'):
+ with torch.amp.autocast("cuda"):
scores, _ = model(xb)
if event_weights is not None:
loss = weighted_negative_log_likelihood_loss(
- scores, tb, eb, model.num_competing_risks, event_weights)
+ scores, tb, eb, model.num_competing_risks, event_weights
+ )
else:
loss = negative_log_likelihood_loss(
- scores, tb, eb, model.num_competing_risks)
+ scores, tb, eb, model.num_competing_risks
+ )
loss = loss + compute_l2_penalty(model) * l2_reg
scaler.scale(loss).backward()
@@ -81,26 +77,30 @@ def train_model(
xb = xb.to(device, non_blocking=True)
tb = tb.to(device, non_blocking=True)
eb = eb.to(device, non_blocking=True)
- with torch.amp.autocast('cuda'):
+ with torch.amp.autocast("cuda"):
scores, _ = model(xb)
if event_weights is not None:
loss = weighted_negative_log_likelihood_loss(
- scores, tb, eb, model.num_competing_risks, event_weights)
+ scores, tb, eb, model.num_competing_risks, event_weights
+ )
else:
loss = negative_log_likelihood_loss(
- scores, tb, eb, model.num_competing_risks)
+ scores, tb, eb, model.num_competing_risks
+ )
loss = loss + compute_l2_penalty(model) * l2_reg
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
if verbose:
- print(f"Epoch {epoch} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
+ print(
+ f"Epoch {epoch} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}"
+ )
if early_stop.step(avg_val_loss):
- if verbose: print("Early stopping.")
+ if verbose:
+ print("Early stopping.")
return avg_val_loss
- else:
- if verbose:
- print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")
+ elif verbose:
+ print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")
return avg_val_loss if val_loader is not None else avg_train_loss
@@ -115,15 +115,16 @@ def objective(
args,
event_weights,
):
-
- lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
- l2 = trial.suggest_float('l2_reg', 1e-5, 1e-1, log=True)
- dropout = trial.suggest_float('dropout_rate', 0.0, 0.8)
- feat_drop = trial.suggest_float('feature_dropout', 0.0, 0.5)
- n_layers = trial.suggest_int('n_layers', 1, 3)
- hidden_dims = [trial.suggest_categorical(f'hidden_dim_{i}', [32,64,128,256])
- for i in range(n_layers)]
- batch_norm = trial.suggest_categorical('batch_norm', [True, False])
+ lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
+ l2 = trial.suggest_float("l2_reg", 1e-5, 1e-1, log=True)
+ dropout = trial.suggest_float("dropout_rate", 0.0, 0.8)
+ feat_drop = trial.suggest_float("feature_dropout", 0.0, 0.5)
+ n_layers = trial.suggest_int("n_layers", 1, 3)
+ hidden_dims = [
+ trial.suggest_categorical(f"hidden_dim_{i}", [32, 64, 128, 256])
+ for i in range(n_layers)
+ ]
+ batch_norm = trial.suggest_categorical("batch_norm", [True, False])
loader_kwargs = dict(
batch_size=args.batch_size,
@@ -134,7 +135,7 @@ def objective(
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
- val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
+ val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
model = CrispNamModel(
num_features=num_features,
@@ -171,12 +172,12 @@ def evaluate_model(
device: torch.device,
):
surv_train = Surv.from_arrays(e_train != 0, t_train)
- surv_val = Surv.from_arrays(e_val != 0, t_val)
+ surv_val = Surv.from_arrays(e_val != 0, t_val)
abs_risk = predict_absolute_risk(model, X, baseline_cifs, eval_times, device=device)
metrics = defaultdict(list)
for k_idx, t0 in enumerate(eval_times, start=1):
- preds = abs_risk[:, k_idx-1, :]
+ preds = abs_risk[:, k_idx - 1, :]
for ti_idx, ti in enumerate(eval_times):
r = preds[:, ti_idx]
try:
@@ -185,7 +186,9 @@ def evaluate_model(
except:
metrics[f"auc_event{k_idx}_t{ti:.2f}"].append(np.nan)
try:
- _, bsc = brier_score(surv_train, surv_val, 1 - r.reshape(-1,1), times=np.array([ti]))
+ _, bsc = brier_score(
+ surv_train, surv_val, 1 - r.reshape(-1, 1), times=np.array([ti])
+ )
metrics[f"brier_event{k_idx}_t{ti:.2f}"].append(float(bsc[0]))
except:
metrics[f"brier_event{k_idx}_t{ti:.2f}"].append(np.nan)
@@ -204,21 +207,34 @@ def parse_args():
default_config_files=["config.yml"],
config_file_parser_class=configargparse.YAMLConfigFileParser,
)
- parser.add_argument("--dataset", type=str, default="framingham",
- choices=["framingham","support","pbc","synthetic"] )
- parser.add_argument("--scaling", type=str, default="standard",
- choices=["minmax","standard","none"] )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="framingham",
+ choices=["framingham", "support", "pbc", "synthetic"],
+ )
+ parser.add_argument(
+ "--scaling",
+ type=str,
+ default="standard",
+ choices=["minmax", "standard", "none"],
+ )
parser.add_argument("--num_epochs", type=int, default=500)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--patience", type=int, default=10)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--n_trials", type=int, default=50)
- parser.add_argument("--event_weighting", type=str, default="none",
- choices=["none","balanced","custom"] )
+ parser.add_argument(
+ "--event_weighting",
+ type=str,
+ default="none",
+ choices=["none", "balanced", "custom"],
+ )
parser.add_argument("--custom_event_weights", type=str, default=None)
parser.add_argument("--seed", type=int, default=42)
return parser.parse_args()
+
# ----------------------- Main ----------------------- #
def main():
args = parse_args()
@@ -226,74 +242,81 @@ def main():
set_seed(args.seed)
# Load dataset
- if args.dataset == 'framingham':
+ if args.dataset == "framingham":
x, t, e, feature_names, n_cont, _ = load_framingham()
- elif args.dataset == 'support':
+ elif args.dataset == "support":
x, t, e, feature_names, n_cont, _ = load_support_dataset()
- elif args.dataset == 'pbc':
+ elif args.dataset == "pbc":
x, t, e, feature_names, n_cont, _ = load_pbc2_dataset()
else:
x, t, e, feature_names, n_cont, _ = load_synthetic_dataset()
# Scale continuous features
- if args.scaling == 'standard':
+ if args.scaling == "standard":
x[:, -n_cont:] = StandardScaler().fit_transform(x[:, -n_cont:])
- elif args.scaling == 'minmax':
+ elif args.scaling == "minmax":
x[:, -n_cont:] = MinMaxScaler().fit_transform(x[:, -n_cont:])
# Bulk convert to torch tensors on CPU
- X = torch.from_numpy(x.astype('float32'))
- T = torch.from_numpy(t.astype('float32'))
- E = torch.from_numpy(e.astype('int64'))
+ X = torch.from_numpy(x.astype("float32"))
+ T = torch.from_numpy(t.astype("float32"))
+ E = torch.from_numpy(e.astype("int64"))
# Train/validation split (fixed)
idx_train, idx_val = train_test_split(
np.arange(len(e)), test_size=0.2, random_state=args.seed, stratify=e
)
train_ds = TensorDataset(X[idx_train], T[idx_train], E[idx_train])
- val_ds = TensorDataset(X[idx_val], T[idx_val], E[idx_val])
+ val_ds = TensorDataset(X[idx_val], T[idx_val], E[idx_val])
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_competing_risks = len(np.unique(e)) - 1
# Event weighting
event_weights = None
- if args.event_weighting == 'balanced':
- counts = np.array([(e[idx_train]==k).sum() for k in range(1, num_competing_risks+1)])
+ if args.event_weighting == "balanced":
+ counts = np.array(
+ [(e[idx_train] == k).sum() for k in range(1, num_competing_risks + 1)]
+ )
counts = np.maximum(counts, 1)
w = 1.0 / counts
w *= num_competing_risks / w.sum()
- event_weights = torch.from_numpy(w.astype('float32')).to(device)
- elif args.event_weighting == 'custom':
- w = np.array(list(map(float, args.custom_event_weights.split(','))))
- event_weights = torch.from_numpy(w.astype('float32')).to(device)
+ event_weights = torch.from_numpy(w.astype("float32")).to(device)
+ elif args.event_weighting == "custom":
+ w = np.array(list(map(float, args.custom_event_weights.split(","))))
+ event_weights = torch.from_numpy(w.astype("float32")).to(device)
# Optuna study
- study = optuna.create_study(direction='minimize')
+ study = optuna.create_study(direction="minimize")
study.optimize(
lambda trial: objective(
- trial, train_ds, val_ds,
- x.shape[1], num_competing_risks,
- device, args, event_weights
+ trial,
+ train_ds,
+ val_ds,
+ x.shape[1],
+ num_competing_risks,
+ device,
+ args,
+ event_weights,
),
- n_trials=args.n_trials
+ n_trials=args.n_trials,
)
best = study.best_params
print("Best hyperparameters:", best)
- with open(f"best_params_{args.dataset}.yaml", 'w') as f:
+ with open(f"best_params_{args.dataset}.yaml", "w") as f:
yaml.dump(best, f)
# Final model training with best params
- n_layers = best['n_layers']
- hidden_dims = [best[f'hidden_dim_{i}'] for i in range(n_layers)]
+ n_layers = best["n_layers"]
+ hidden_dims = [best[f"hidden_dim_{i}"] for i in range(n_layers)]
final_model = CrispNamModel(
num_features=x.shape[1],
num_competing_risks=num_competing_risks,
hidden_sizes=hidden_dims,
- dropout_rate=best['dropout_rate'],
- feature_dropout=best['feature_dropout'],
- batch_norm=best['batch_norm'],
+ dropout_rate=best["dropout_rate"],
+ feature_dropout=best["feature_dropout"],
+ batch_norm=best["batch_norm"],
).to(device)
loader_kwargs = dict(
@@ -304,15 +327,15 @@ def main():
prefetch_factor=2,
)
final_train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
- final_val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
+ final_val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
train_model(
final_model,
final_train_loader,
val_loader=final_val_loader,
num_epochs=args.num_epochs,
- lr=best['learning_rate'],
- l2_reg=best['l2_reg'],
+ lr=best["learning_rate"],
+ l2_reg=best["l2_reg"],
patience=args.patience,
event_weights=event_weights,
verbose=True,
@@ -322,29 +345,36 @@ def main():
safe_max = 0.99 * np.max(t[idx_train])
eval_times = np.quantile(t[idx_train][t[idx_train] <= safe_max], [0.25, 0.5, 0.75])
baseline_cifs = {
- k: compute_baseline_cif(
- t[idx_train], e[idx_train], eval_times, k+1
- ) for k in range(num_competing_risks)
+ k: compute_baseline_cif(t[idx_train], e[idx_train], eval_times, k + 1)
+ for k in range(num_competing_risks)
}
metrics = evaluate_model(
final_model,
- x[idx_val], t[idx_val], e[idx_val],
- t[idx_train], e[idx_train],
- baseline_cifs, eval_times, device
+ x[idx_val],
+ t[idx_val],
+ e[idx_val],
+ t[idx_train],
+ e[idx_train],
+ baseline_cifs,
+ eval_times,
+ device,
)
# Reporting
rows = []
- for k in range(1, num_competing_risks+1):
- row = {'Risk': f'Type {k}'}
- for metric in ['auc', 'tdci', 'brier']:
+ for k in range(1, num_competing_risks + 1):
+ row = {"Risk": f"Type {k}"}
+ for metric in ["auc", "tdci", "brier"]:
for q, ti in zip([0.25, 0.5, 0.75], eval_times):
key = f"{metric}_event{k}_t{ti:.2f}"
vals = np.array(metrics[key], dtype=float)
- row[f"{metric.upper()}_q{q:.2f}"] = f"{np.nanmean(vals):.3f} ± {np.nanstd(vals):.3f}"
+ row[f"{metric.upper()}_q{q:.2f}"] = (
+ f"{np.nanmean(vals):.3f} ± {np.nanstd(vals):.3f}"
+ )
rows.append(row)
print("\n=== Performance ===")
- print(tabulate(rows, headers='keys', tablefmt='pretty', showindex=False))
+ print(tabulate(rows, headers="keys", tablefmt="pretty", showindex=False))
+
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/utils/__init__.py b/utils/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/utils/loss.py b/utils/loss.py
deleted file mode 100644
index 9cae44fd..00000000
--- a/utils/loss.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import torch
-
-def weighted_negative_log_likelihood_loss(risk_scores, times, events,
- num_competing_risks, event_weights=None,
- sample_weights=None, eps=1e-8) -> float:
- """
- Computes the weighted negative log-likelihood loss for competing risks Cox model.
-
- Args:
- risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
- times: Event/censoring times (batch_size,)
- events: Event indicators (0=censored, 1...K=event types) (batch_size,)
- num_competing_risks: Number of competing risks
- event_weights: Tensor of weights for each competing risk type (size: num_competing_risks)
- sample_weights: Tensor of weights for each sample (size: batch_size)
- eps: Small constant for numerical stability
-
- Returns:
- Weighted negative log partial likelihood loss
- """
- device = times.device
- batch_size = times.shape[0]
-
- # Initialize loss
- loss = torch.tensor(0.0, device=device)
-
- # Set default weights if not provided
- if event_weights is None:
- event_weights = torch.ones(num_competing_risks, device=device)
- if sample_weights is None:
- sample_weights = torch.ones(batch_size, device=device)
-
- # Count number of events
- n_events = (events > 0).sum().item()
- if n_events == 0:
- return loss
-
- # Process each competing risk separately
- for k in range(1, num_competing_risks + 1):
- # Find samples with this event type
- event_mask = (events == k)
- n_events_k = event_mask.sum().item()
-
- if n_events_k == 0:
- continue
-
- # Get risk scores for this competing risk
- risk_k = risk_scores[k-1].squeeze()
-
- # Get weight for this event type
- event_weight = event_weights[k-1]
-
- # For each event of type k
- for i in range(batch_size):
- if event_mask[i]:
- # Find samples in risk set (samples with time >= event time)
- risk_set = (times >= times[i])
-
- # Calculate log sum of exp of risk scores in risk set
- risk_set_scores = risk_k[risk_set]
- log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)
-
- # Subtract individual risk score from log sum and apply weights
- individual_loss = log_risk_sum - risk_k[i]
- weighted_individual_loss = individual_loss * event_weight * sample_weights[i]
- loss += weighted_individual_loss
-
- # Return average loss
- return loss / max(n_events, 1)
-
-def negative_log_likelihood_loss(risk_scores, times, events,
- num_competing_risks, eps=1e-8):
- """
- Computes the negative log-likelihood loss for competing risks Cox model.
-
- Args:
- risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
- times: Event/censoring times (batch_size,)
- events: Event indicators (0=censored, 1...K=event types) (batch_size,)
- num_competing_risks: Number of competing risks
- eps: Small constant for numerical stability
-
- Returns:
- Negative log partial likelihood loss
- """
- device = times.device
- batch_size = times.shape[0]
-
- # Initialize loss
- loss = torch.tensor(0.0, device=device)
-
- # Count number of events
- n_events = (events > 0).sum().item()
- if n_events == 0:
- return loss
-
- # Process each competing risk separately
- for k in range(1, num_competing_risks + 1):
- # Find samples with this event type
- event_mask = (events == k)
- n_events_k = event_mask.sum().item()
-
- if n_events_k == 0:
- continue
-
- # Get risk scores for this competing risk
- risk_k = risk_scores[k-1].squeeze()
-
- # For each event of type k
- for i in range(batch_size):
- if event_mask[i]:
- # Find samples in risk set (samples with time >= event time)
- risk_set = (times >= times[i])
-
- # Calculate log sum of exp of risk scores in risk set
- risk_set_scores = risk_k[risk_set]
- log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)
-
- # Subtract individual risk score from log sum
- loss += log_risk_sum - risk_k[i]
-
- # Return average loss
- return loss / max(n_events, 1)
-
-def compute_l2_penalty(model, include_bias=False) -> int:
- """
- Compute L2 regularization penalty on model parameters
-
- Args:
- model: Neural network model
- include_bias: Whether to include bias terms in regularization
-
- Returns:
- L2 penalty term
- """
- l2_reg = 0.0
- for name, param in model.named_parameters():
- if param.requires_grad:
- # Skip bias parameters if specified
- if not include_bias and 'bias' in name:
- continue
- l2_reg += torch.sum(param ** 2)
- return l2_reg
\ No newline at end of file
diff --git a/utils/plotting.py b/utils/plotting.py
deleted file mode 100644
index 1702632e..00000000
--- a/utils/plotting.py
+++ /dev/null
@@ -1,162 +0,0 @@
-import torch
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-from typing import Union, List
-
-def plot_feature_importance(model: torch.nn.Module,
- x_data: Union[np.ndarray, torch.Tensor],
- feature_names = None,
- n_top:int = 5,
- n_bottom:int = 5,
- risk_idx:int = 1,
- figsize: tuple = (8, 6),
- output_file:str = None,
- color_positive:str = '#2196F3',
- color_negative:str = '#F44336') -> tuple:
- """
- Plot feature importance with both top positive and negative influences,
- handling both CPU and CUDA devices automatically.
- """
- # determine model device
- device = next(model.parameters()).device
- model.eval()
-
- # prepare feature names
- num_features = model.num_features
- if feature_names is None:
- feature_names = [f"Feature {i+1}" for i in range(num_features)]
-
- # convert x_data to tensor on the model device
- if not isinstance(x_data, torch.Tensor):
- x = torch.tensor(x_data, dtype=torch.float32, device=device)
- else:
- x = x_data.to(device)
-
- feature_contribs = {}
- risk_idx0 = risk_idx - 1
-
- with torch.no_grad():
- for i in range(num_features):
- vals = x[:, i].unsqueeze(1) # shape (N,1)
- if torch.var(vals) <= 1e-8:
- feature_contribs[feature_names[i]] = 0.0
- continue
-
- # forward through the feature net and projection
- rep = model.feature_nets[i](vals)
- proj = model.risk_projections[i][risk_idx0](rep)
- # mean contribution as a Python float
- contrib = proj.mean().item()
- feature_contribs[feature_names[i]] = contrib
-
- # build a DataFrame for sorting
- df = pd.DataFrame({
- 'feature': list(feature_contribs.keys()),
- 'contribution': list(feature_contribs.values())
- })
- df['abs_contrib'] = df['contribution'].abs()
- df = df.sort_values('abs_contrib', ascending=False)
-
- pos = df[df['contribution'] > 0].head(n_top).sort_values('contribution')
- neg = df[df['contribution'] < 0].head(n_bottom).sort_values('contribution', ascending=False)
-
- top_pos = pos['feature'].tolist()
- top_neg = neg['feature'].tolist()
-
- # plotting
- fig, ax = plt.subplots(figsize=figsize)
- ax.barh(pos['feature'], pos['contribution'], color=color_positive, alpha=0.8)
- ax.barh(neg['feature'], neg['contribution'], color=color_negative, alpha=0.8)
- ax.axvline(0, color='black', linestyle='-', alpha=0.3)
-
- ax.set_xlabel('Contribution to Risk Score')
- ax.set_title(f'Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}')
- ax.grid(axis='x', linestyle='--', alpha=0.5)
- plt.tight_layout()
-
- if output_file:
- plt.savefig(output_file, bbox_inches='tight', dpi=300)
-
- return fig, ax, top_pos, top_neg
-
-
-def plot_coxnam_shape_functions(model:torch.nn.Module,
- X: Union[np.ndarray, torch.Tensor],
- risk_to_plot:int = 1,
- feature_names:List[str] = None,
- top_features:List[str] = None,
- ncols:int = 3,
- figsize:tuple = (12, 8),
- output_file:str = None) -> tuple:
- """
- Plot shape functions for each feature in a CoxNAM model,
- automatically handling CPU vs CUDA inputs.
- """
- device = next(model.parameters()).device
- model.eval()
- risk_idx = risk_to_plot - 1
-
- # ensure X is a numpy array
- if isinstance(X, torch.Tensor):
- X_np = X.cpu().numpy()
- else:
- X_np = np.array(X, dtype=float)
-
- # derive feature list
- num_features = model.num_features
- if feature_names is None:
- feature_names = [f"Feature {i+1}" for i in range(num_features)]
- if top_features:
- # map names back to indices
- idx_map = {name: i for i, name in enumerate(feature_names)}
- selected = [(idx_map.get(name, None), name) for name in top_features]
- selected = [(i, name) for i, name in selected if i is not None]
- else:
- selected = list(zip(range(num_features), feature_names))
-
- n_selected = len(selected)
- nrows = int(np.ceil(n_selected / ncols))
- fig, axes = plt.subplots(nrows, ncols, figsize=(figsize))
- axes = np.array(axes).reshape(-1)
-
- with torch.no_grad():
- for ax, (f_idx, fname) in zip(axes, selected):
- vals = X_np[:, f_idx]
- if vals.size == 0:
- ax.text(0.5, 0.5, "no data", ha='center', va='center')
- continue
-
- # choose evaluation points
- if np.issubdtype(vals.dtype, np.integer) or len(np.unique(vals)) <= 10:
- pts = np.unique(vals)
- else:
- pts = np.linspace(vals.min(), vals.max(), 100)
-
- # convert to tensor on correct device
- t_pts = torch.tensor(pts, dtype=torch.float32, device=device).unsqueeze(1)
-
- # compute shape values
- rep = model.feature_nets[f_idx](t_pts)
- proj = model.risk_projections[f_idx][risk_idx](rep)
- shp = proj.squeeze(-1).cpu().numpy()
-
- # plot
- ax.plot(pts, shp, linewidth=2)
- ax.axhline(0, linestyle='--', alpha=0.5)
- ax.set_title(fname)
- ax.set_xlabel('Value')
- ax.set_ylabel('Contribution')
- # rug plot
- ax.plot(vals, np.zeros_like(vals)-0.1, '|', alpha=0.3)
-
- # turn off any extra axes
- for ax in axes[n_selected:]:
- ax.axis('off')
-
- fig.suptitle(f'Shape Functions for Risk {risk_to_plot}', fontsize=14)
- plt.tight_layout(rect=[0,0,1,0.96])
- if output_file:
- plt.savefig(output_file, dpi=300, bbox_inches='tight')
-
- return fig, axes[:n_selected]
diff --git a/utils/risk_cif.py b/utils/risk_cif.py
deleted file mode 100644
index 0cd38d54..00000000
--- a/utils/risk_cif.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import torch
-import numpy as np
-from typing import List, Any
-
-def compute_baseline_cif(times:np.ndarray,
- events:np.ndarray,
- eval_times:List[Any],
- event_type:np.ndarray) -> np.ndarray:
- """
- Compute baseline cumulative incidence function for a specific event type
-
- Args:
- times: Numpy array of event times
- events: Numpy array of event indicators (0=censored, 1...K=event types)
- eval_times: Times at which to evaluate the CIF
- event_type: Event type to compute CIF for (1...K)
-
- Returns:
- Numpy array of baseline CIF values at eval_times
- """
- # Sort times and corresponding events
- sort_idx = np.argsort(times)
- sorted_times = times[sort_idx]
- sorted_events = events[sort_idx]
-
- # Initialize cumulative hazard
- n_samples = len(times)
- baseline_cif = np.zeros(len(eval_times))
-
- # For each evaluation time
- for i, t in enumerate(eval_times):
- cif_t = 0.0
- # Count number of events of the specified type before time t
- event_count = np.sum((sorted_events == event_type) & (sorted_times <= t))
- if event_count > 0:
- # Simple Aalen-Johansen estimator
- cif_t = event_count / n_samples
- baseline_cif[i] = cif_t
-
- return baseline_cif
-
-
-def predict_cif(model:torch.Module,
- x:np.ndarray,
- baseline_cif:np.ndarray,
- times:np.ndarray,
- event_of_interest:int) -> np.ndarray:
- """
- Predict cumulative incidence function for a specific competing risk.
-
- Args:
- model: Trained model.
- x: Input tensor of shape (n_samples, n_features).
- baseline_cif: Array of shape (len(times),) — estimated CIF for baseline (e.g. from compute_baseline_cif).
- times: Time points at which CIF is evaluated.
- event_type: Integer, 0-based index of event of interest.
-
- Returns:
- cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample.
- """
- model.eval()
- with torch.no_grad():
- logits, _ = model(x) # list of length num_risks
- f_j_x = logits[event_of_interest].squeeze(1).cpu().numpy() # (n_samples,)
-
- baseline_cif = np.asarray(baseline_cif).reshape(1, -1) # (1, T)
- risk_scores = np.exp(f_j_x).reshape(-1, 1) # (N, 1)
-
- # Fine-Gray style CIF prediction under PH assumption
- cif_pred = 1.0 - np.power(1.0 - baseline_cif, risk_scores) # shape (N, T)
-
- return cif_pred
-
-def predict_risk(model:np.ndarray,
- x_input:np.ndarray,
- device:str = 'cpu'):
- """
- Predicts relative risk scores for each competing risk.
-
- Args:
- model : Trained model.
- x_input (np.ndarray or torch.Tensor): Input features of shape (n_samples, n_features).
- device (str): Device to run the computation on.
-
- Returns:
- np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores.
- """
- model.eval()
-
- if isinstance(x_input, np.ndarray):
- x_tensor = torch.from_numpy(x_input).float().to(device)
- else:
- x_tensor = x_input.to(device).float()
-
- with torch.no_grad():
- risk_outputs, _ = model(x_tensor) # List of [batch_size, 1] tensors
- risks = torch.cat(risk_outputs, dim=1) # Shape: [batch_size, num_risks]
-
- return risks.cpu().numpy()
-
-def predict_absolute_risk(model:torch.Tensor,
- x_input:np.ndarray,
- baseline_cifs:List[Any],
- times:List[Any],
- device:str = 'cpu') -> np.ndarray:
- """
- Predict absolute risk (CIF) for each competing event by given time points.
-
- Args:
- model: Trained model.
- x_input (np.ndarray or Tensor): Input features, shape (n_samples, n_features).
- baseline_cifs (dict): Mapping of event index to baseline CIF array of shape (n_times,).
- times (np.ndarray): Time grid used for baseline_cifs.
- device: CPU or CUDA.
-
- Returns:
- np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs.
- """
- rel_risks = predict_risk(model, x_input, device) # shape (n_samples, num_events)
- n_samples, num_events = rel_risks.shape
- n_times = len(times)
-
- abs_risks = np.zeros((n_samples, num_events, n_times))
-
- for k in range(num_events):
- base_cif = np.clip(baseline_cifs[k], 1e-10, 0.9999) # avoid edge cases
- for i in range(n_samples):
- abs_risks[i, k, :] = 1 - np.power(1 - base_cif, np.exp(rel_risks[i, k]))
-
- return abs_risks
diff --git a/uv.lock b/uv.lock
index 1927b861..a84729ed 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
version = 1
+revision = 1
requires-python = ">=3.10"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -45,6 +46,139 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/85/ae/7f2031ea76140444b2453fa139041e5afd4a09fc5300cfefeb1103291f80/autograd-gamma-0.5.0.tar.gz", hash = "sha256:f27abb7b8bb9cffc8badcbf59f3fe44a9db39e124ecacf1992b6d952934ac9c4", size = 3952 }
+[[package]]
+name = "babel"
+version = "2.18.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7d/b2/51899539b6ceeeb420d40ed3cd4b7a40519404f9baf3d4ac99dc413a834b/babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d", size = 9959554 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/77/f5/21d2de20e8b8b0408f0681956ca2c69f1320a3848ac50e6e7f39c6159675/babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35", size = 10196845 },
+]
+
+[[package]]
+name = "backrefs"
+version = "6.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/86/e3/bb3a439d5cb255c4774724810ad8073830fac9c9dee123555820c1bcc806/backrefs-6.1.tar.gz", hash = "sha256:3bba1749aafe1db9b915f00e0dd166cba613b6f788ffd63060ac3485dc9be231", size = 7011962 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3b/ee/c216d52f58ea75b5e1841022bbae24438b19834a29b163cb32aa3a2a7c6e/backrefs-6.1-py310-none-any.whl", hash = "sha256:2a2ccb96302337ce61ee4717ceacfbf26ba4efb1d55af86564b8bbaeda39cac1", size = 381059 },
+ { url = "https://files.pythonhosted.org/packages/e6/9a/8da246d988ded941da96c7ed945d63e94a445637eaad985a0ed88787cb89/backrefs-6.1-py311-none-any.whl", hash = "sha256:e82bba3875ee4430f4de4b6db19429a27275d95a5f3773c57e9e18abc23fd2b7", size = 392854 },
+ { url = "https://files.pythonhosted.org/packages/37/c9/fd117a6f9300c62bbc33bc337fd2b3c6bfe28b6e9701de336b52d7a797ad/backrefs-6.1-py312-none-any.whl", hash = "sha256:c64698c8d2269343d88947c0735cb4b78745bd3ba590e10313fbf3f78c34da5a", size = 398770 },
+ { url = "https://files.pythonhosted.org/packages/eb/95/7118e935b0b0bd3f94dfec2d852fd4e4f4f9757bdb49850519acd245cd3a/backrefs-6.1-py313-none-any.whl", hash = "sha256:4c9d3dc1e2e558965202c012304f33d4e0e477e1c103663fd2c3cc9bb18b0d05", size = 400726 },
+ { url = "https://files.pythonhosted.org/packages/1d/72/6296bad135bfafd3254ae3648cd152980a424bd6fed64a101af00cc7ba31/backrefs-6.1-py314-none-any.whl", hash = "sha256:13eafbc9ccd5222e9c1f0bec563e6d2a6d21514962f11e7fc79872fd56cbc853", size = 412584 },
+ { url = "https://files.pythonhosted.org/packages/02/e3/a4fa1946722c4c7b063cc25043a12d9ce9b4323777f89643be74cef2993c/backrefs-6.1-py39-none-any.whl", hash = "sha256:a9e99b8a4867852cad177a6430e31b0f6e495d65f8c6c134b68c14c3c95bf4b0", size = 381058 },
+]
+
+[[package]]
+name = "certifi"
+version = "2026.1.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900 },
+]
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1f/b8/6d51fc1d52cbd52cd4ccedd5b5b2f0f6a11bbf6765c782298b0f3e808541/charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d", size = 209709 },
+ { url = "https://files.pythonhosted.org/packages/5c/af/1f9d7f7faafe2ddfb6f72a2e07a548a629c61ad510fe60f9630309908fef/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8", size = 148814 },
+ { url = "https://files.pythonhosted.org/packages/79/3d/f2e3ac2bbc056ca0c204298ea4e3d9db9b4afe437812638759db2c976b5f/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad", size = 144467 },
+ { url = "https://files.pythonhosted.org/packages/ec/85/1bf997003815e60d57de7bd972c57dc6950446a3e4ccac43bc3070721856/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8", size = 162280 },
+ { url = "https://files.pythonhosted.org/packages/3e/8e/6aa1952f56b192f54921c436b87f2aaf7c7a7c3d0d1a765547d64fd83c13/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d", size = 159454 },
+ { url = "https://files.pythonhosted.org/packages/36/3b/60cbd1f8e93aa25d1c669c649b7a655b0b5fb4c571858910ea9332678558/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313", size = 153609 },
+ { url = "https://files.pythonhosted.org/packages/64/91/6a13396948b8fd3c4b4fd5bc74d045f5637d78c9675585e8e9fbe5636554/charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e", size = 151849 },
+ { url = "https://files.pythonhosted.org/packages/b7/7a/59482e28b9981d105691e968c544cc0df3b7d6133152fb3dcdc8f135da7a/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93", size = 151586 },
+ { url = "https://files.pythonhosted.org/packages/92/59/f64ef6a1c4bdd2baf892b04cd78792ed8684fbc48d4c2afe467d96b4df57/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0", size = 145290 },
+ { url = "https://files.pythonhosted.org/packages/6b/63/3bf9f279ddfa641ffa1962b0db6a57a9c294361cc2f5fcac997049a00e9c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84", size = 163663 },
+ { url = "https://files.pythonhosted.org/packages/ed/09/c9e38fc8fa9e0849b172b581fd9803bdf6e694041127933934184e19f8c3/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e", size = 151964 },
+ { url = "https://files.pythonhosted.org/packages/d2/d1/d28b747e512d0da79d8b6a1ac18b7ab2ecfd81b2944c4c710e166d8dd09c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db", size = 161064 },
+ { url = "https://files.pythonhosted.org/packages/bb/9a/31d62b611d901c3b9e5500c36aab0ff5eb442043fb3a1c254200d3d397d9/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6", size = 155015 },
+ { url = "https://files.pythonhosted.org/packages/1f/f3/107e008fa2bff0c8b9319584174418e5e5285fef32f79d8ee6a430d0039c/charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f", size = 99792 },
+ { url = "https://files.pythonhosted.org/packages/eb/66/e396e8a408843337d7315bab30dbf106c38966f1819f123257f5520f8a96/charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d", size = 107198 },
+ { url = "https://files.pythonhosted.org/packages/b5/58/01b4f815bf0312704c267f2ccb6e5d42bcc7752340cd487bc9f8c3710597/charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69", size = 100262 },
+ { url = "https://files.pythonhosted.org/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988 },
+ { url = "https://files.pythonhosted.org/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324 },
+ { url = "https://files.pythonhosted.org/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742 },
+ { url = "https://files.pythonhosted.org/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863 },
+ { url = "https://files.pythonhosted.org/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837 },
+ { url = "https://files.pythonhosted.org/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550 },
+ { url = "https://files.pythonhosted.org/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162 },
+ { url = "https://files.pythonhosted.org/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019 },
+ { url = "https://files.pythonhosted.org/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310 },
+ { url = "https://files.pythonhosted.org/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022 },
+ { url = "https://files.pythonhosted.org/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383 },
+ { url = "https://files.pythonhosted.org/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098 },
+ { url = "https://files.pythonhosted.org/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991 },
+ { url = "https://files.pythonhosted.org/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456 },
+ { url = "https://files.pythonhosted.org/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978 },
+ { url = "https://files.pythonhosted.org/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969 },
+ { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425 },
+ { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162 },
+ { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558 },
+ { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497 },
+ { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240 },
+ { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471 },
+ { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864 },
+ { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647 },
+ { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110 },
+ { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839 },
+ { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667 },
+ { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535 },
+ { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816 },
+ { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694 },
+ { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131 },
+ { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390 },
+ { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091 },
+ { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936 },
+ { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180 },
+ { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346 },
+ { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874 },
+ { url = "https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076 },
+ { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601 },
+ { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376 },
+ { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825 },
+ { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583 },
+ { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366 },
+ { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300 },
+ { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465 },
+ { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404 },
+ { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092 },
+ { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408 },
+ { url = "https://files.pythonhosted.org/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746 },
+ { url = "https://files.pythonhosted.org/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889 },
+ { url = "https://files.pythonhosted.org/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641 },
+ { url = "https://files.pythonhosted.org/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779 },
+ { url = "https://files.pythonhosted.org/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035 },
+ { url = "https://files.pythonhosted.org/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542 },
+ { url = "https://files.pythonhosted.org/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524 },
+ { url = "https://files.pythonhosted.org/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395 },
+ { url = "https://files.pythonhosted.org/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680 },
+ { url = "https://files.pythonhosted.org/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045 },
+ { url = "https://files.pythonhosted.org/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687 },
+ { url = "https://files.pythonhosted.org/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014 },
+ { url = "https://files.pythonhosted.org/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044 },
+ { url = "https://files.pythonhosted.org/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940 },
+ { url = "https://files.pythonhosted.org/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104 },
+ { url = "https://files.pythonhosted.org/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743 },
+ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402 },
+]
+
+[[package]]
+name = "click"
+version = "8.2.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215 },
+]
+
[[package]]
name = "colorama"
version = "0.4.6"
@@ -151,6 +285,7 @@ dependencies = [
{ name = "configargparse" },
{ name = "lifelines" },
{ name = "matplotlib" },
+ { name = "mypy" },
{ name = "openpyxl" },
{ name = "optuna" },
{ name = "pandas" },
@@ -161,11 +296,23 @@ dependencies = [
{ name = "torch" },
]
+[package.dev-dependencies]
+docs = [
+ { name = "click" },
+ { name = "mike" },
+ { name = "mkdocs" },
+ { name = "mkdocs-material" },
+ { name = "mkdocstrings" },
+ { name = "mkdocstrings-python" },
+ { name = "pymdown-extensions" },
+]
+
[package.metadata]
requires-dist = [
{ name = "configargparse", specifier = ">=1.7" },
{ name = "lifelines", specifier = ">=0.30.0" },
{ name = "matplotlib", specifier = ">=3.10.1" },
+ { name = "mypy", specifier = ">=1.19.1" },
{ name = "openpyxl", specifier = ">=3.1.5" },
{ name = "optuna", specifier = ">=4.3.0" },
{ name = "pandas", specifier = ">=2.2.3" },
@@ -176,6 +323,17 @@ requires-dist = [
{ name = "torch", specifier = ">=2.7.0" },
]
+[package.metadata.requires-dev]
+docs = [
+ { name = "click", specifier = "<=8.2.1" },
+ { name = "mike", specifier = ">=2.0.0" },
+ { name = "mkdocs", specifier = ">=1.5.3" },
+ { name = "mkdocs-material", specifier = ">=9.5.12" },
+ { name = "mkdocstrings", specifier = ">=0.24.1" },
+ { name = "mkdocstrings-python", specifier = ">=1.8.0" },
+ { name = "pymdown-extensions", specifier = ">=10.7.1" },
+]
+
[[package]]
name = "cycler"
version = "0.12.1"
@@ -293,6 +451,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/61/78c7b3851add1481b048b5fdc29067397a1784e2910592bc81bb3f608635/fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462", size = 199052 },
]
+[[package]]
+name = "ghp-import"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "python-dateutil" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d9/29/d40217cbe2f6b1359e00c6c307bb3fc876ba74068cbab3dde77f03ca0dc4/ghp-import-2.1.0.tar.gz", hash = "sha256:9c535c4c61193c2df8871222567d7fd7e5014d835f97dc7b7439069e2413d343", size = 10943 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 },
+]
+
[[package]]
name = "greenlet"
version = "3.2.3"
@@ -344,6 +514,68 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/4f/aab73ecaa6b3086a4c89863d94cf26fa84cbff63f52ce9bc4342b3087a06/greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a", size = 301236 },
]
+[[package]]
+name = "griffe"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "griffecli" },
+ { name = "griffelib" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214 },
+]
+
+[[package]]
+name = "griffecli"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama" },
+ { name = "griffelib" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345 },
+]
+
+[[package]]
+name = "griffelib"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004 },
+]
+
+[[package]]
+name = "idna"
+version = "3.11"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008 },
+]
+
+[[package]]
+name = "importlib-metadata"
+version = "8.7.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "zipp" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865 },
+]
+
+[[package]]
+name = "importlib-resources"
+version = "6.5.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 },
+]
+
[[package]]
name = "interface-meta"
version = "1.3.0"
@@ -461,6 +693,79 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 },
]
+[[package]]
+name = "librt"
+version = "0.7.8"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/24/5f3646ff414285e0f7708fa4e946b9bf538345a41d1c375c439467721a5e/librt-0.7.8.tar.gz", hash = "sha256:1a4ede613941d9c3470b0368be851df6bb78ab218635512d0370b27a277a0862", size = 148323 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/44/13/57b06758a13550c5f09563893b004f98e9537ee6ec67b7df85c3571c8832/librt-0.7.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b45306a1fc5f53c9330fbee134d8b3227fe5da2ab09813b892790400aa49352d", size = 56521 },
+ { url = "https://files.pythonhosted.org/packages/c2/24/bbea34d1452a10612fb45ac8356f95351ba40c2517e429602160a49d1fd0/librt-0.7.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:864c4b7083eeee250ed55135d2127b260d7eb4b5e953a9e5df09c852e327961b", size = 58456 },
+ { url = "https://files.pythonhosted.org/packages/04/72/a168808f92253ec3a810beb1eceebc465701197dbc7e865a1c9ceb3c22c7/librt-0.7.8-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6938cc2de153bc927ed8d71c7d2f2ae01b4e96359126c602721340eb7ce1a92d", size = 164392 },
+ { url = "https://files.pythonhosted.org/packages/14/5c/4c0d406f1b02735c2e7af8ff1ff03a6577b1369b91aa934a9fa2cc42c7ce/librt-0.7.8-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66daa6ac5de4288a5bbfbe55b4caa7bf0cd26b3269c7a476ffe8ce45f837f87d", size = 172959 },
+ { url = "https://files.pythonhosted.org/packages/82/5f/3e85351c523f73ad8d938989e9a58c7f59fb9c17f761b9981b43f0025ce7/librt-0.7.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4864045f49dc9c974dadb942ac56a74cd0479a2aafa51ce272c490a82322ea3c", size = 186717 },
+ { url = "https://files.pythonhosted.org/packages/08/f8/18bfe092e402d00fe00d33aa1e01dda1bd583ca100b393b4373847eade6d/librt-0.7.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a36515b1328dc5b3ffce79fe204985ca8572525452eacabee2166f44bb387b2c", size = 184585 },
+ { url = "https://files.pythonhosted.org/packages/4e/fc/f43972ff56fd790a9fa55028a52ccea1875100edbb856b705bd393b601e3/librt-0.7.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b7e7f140c5169798f90b80d6e607ed2ba5059784968a004107c88ad61fb3641d", size = 180497 },
+ { url = "https://files.pythonhosted.org/packages/e1/3a/25e36030315a410d3ad0b7d0f19f5f188e88d1613d7d3fd8150523ea1093/librt-0.7.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ff71447cb778a4f772ddc4ce360e6ba9c95527ed84a52096bd1bbf9fee2ec7c0", size = 200052 },
+ { url = "https://files.pythonhosted.org/packages/fc/b8/f3a5a1931ae2a6ad92bf6893b9ef44325b88641d58723529e2c2935e8abe/librt-0.7.8-cp310-cp310-win32.whl", hash = "sha256:047164e5f68b7a8ebdf9fae91a3c2161d3192418aadd61ddd3a86a56cbe3dc85", size = 43477 },
+ { url = "https://files.pythonhosted.org/packages/fe/91/c4202779366bc19f871b4ad25db10fcfa1e313c7893feb942f32668e8597/librt-0.7.8-cp310-cp310-win_amd64.whl", hash = "sha256:d6f254d096d84156a46a84861183c183d30734e52383602443292644d895047c", size = 49806 },
+ { url = "https://files.pythonhosted.org/packages/1b/a3/87ea9c1049f2c781177496ebee29430e4631f439b8553a4969c88747d5d8/librt-0.7.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ff3e9c11aa260c31493d4b3197d1e28dd07768594a4f92bec4506849d736248f", size = 56507 },
+ { url = "https://files.pythonhosted.org/packages/5e/4a/23bcef149f37f771ad30203d561fcfd45b02bc54947b91f7a9ac34815747/librt-0.7.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddb52499d0b3ed4aa88746aaf6f36a08314677d5c346234c3987ddc506404eac", size = 58455 },
+ { url = "https://files.pythonhosted.org/packages/22/6e/46eb9b85c1b9761e0f42b6e6311e1cc544843ac897457062b9d5d0b21df4/librt-0.7.8-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e9c0afebbe6ce177ae8edba0c7c4d626f2a0fc12c33bb993d163817c41a7a05c", size = 164956 },
+ { url = "https://files.pythonhosted.org/packages/7a/3f/aa7c7f6829fb83989feb7ba9aa11c662b34b4bd4bd5b262f2876ba3db58d/librt-0.7.8-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:631599598e2c76ded400c0a8722dec09217c89ff64dc54b060f598ed68e7d2a8", size = 174364 },
+ { url = "https://files.pythonhosted.org/packages/3f/2d/d57d154b40b11f2cb851c4df0d4c4456bacd9b1ccc4ecb593ddec56c1a8b/librt-0.7.8-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c1ba843ae20db09b9d5c80475376168feb2640ce91cd9906414f23cc267a1ff", size = 188034 },
+ { url = "https://files.pythonhosted.org/packages/59/f9/36c4dad00925c16cd69d744b87f7001792691857d3b79187e7a673e812fb/librt-0.7.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b5b007bb22ea4b255d3ee39dfd06d12534de2fcc3438567d9f48cdaf67ae1ae3", size = 186295 },
+ { url = "https://files.pythonhosted.org/packages/23/9b/8a9889d3df5efb67695a67785028ccd58e661c3018237b73ad081691d0cb/librt-0.7.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:dbd79caaf77a3f590cbe32dc2447f718772d6eea59656a7dcb9311161b10fa75", size = 181470 },
+ { url = "https://files.pythonhosted.org/packages/43/64/54d6ef11afca01fef8af78c230726a9394759f2addfbf7afc5e3cc032a45/librt-0.7.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:87808a8d1e0bd62a01cafc41f0fd6818b5a5d0ca0d8a55326a81643cdda8f873", size = 201713 },
+ { url = "https://files.pythonhosted.org/packages/2d/29/73e7ed2991330b28919387656f54109139b49e19cd72902f466bd44415fd/librt-0.7.8-cp311-cp311-win32.whl", hash = "sha256:31724b93baa91512bd0a376e7cf0b59d8b631ee17923b1218a65456fa9bda2e7", size = 43803 },
+ { url = "https://files.pythonhosted.org/packages/3f/de/66766ff48ed02b4d78deea30392ae200bcbd99ae61ba2418b49fd50a4831/librt-0.7.8-cp311-cp311-win_amd64.whl", hash = "sha256:978e8b5f13e52cf23a9e80f3286d7546baa70bc4ef35b51d97a709d0b28e537c", size = 50080 },
+ { url = "https://files.pythonhosted.org/packages/6f/e3/33450438ff3a8c581d4ed7f798a70b07c3206d298cf0b87d3806e72e3ed8/librt-0.7.8-cp311-cp311-win_arm64.whl", hash = "sha256:20e3946863d872f7cabf7f77c6c9d370b8b3d74333d3a32471c50d3a86c0a232", size = 43383 },
+ { url = "https://files.pythonhosted.org/packages/56/04/79d8fcb43cae376c7adbab7b2b9f65e48432c9eced62ac96703bcc16e09b/librt-0.7.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9b6943885b2d49c48d0cff23b16be830ba46b0152d98f62de49e735c6e655a63", size = 57472 },
+ { url = "https://files.pythonhosted.org/packages/b4/ba/60b96e93043d3d659da91752689023a73981336446ae82078cddf706249e/librt-0.7.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46ef1f4b9b6cc364b11eea0ecc0897314447a66029ee1e55859acb3dd8757c93", size = 58986 },
+ { url = "https://files.pythonhosted.org/packages/7c/26/5215e4cdcc26e7be7eee21955a7e13cbf1f6d7d7311461a6014544596fac/librt-0.7.8-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:907ad09cfab21e3c86e8f1f87858f7049d1097f77196959c033612f532b4e592", size = 168422 },
+ { url = "https://files.pythonhosted.org/packages/0f/84/e8d1bc86fa0159bfc24f3d798d92cafd3897e84c7fea7fe61b3220915d76/librt-0.7.8-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2991b6c3775383752b3ca0204842743256f3ad3deeb1d0adc227d56b78a9a850", size = 177478 },
+ { url = "https://files.pythonhosted.org/packages/57/11/d0268c4b94717a18aa91df1100e767b010f87b7ae444dafaa5a2d80f33a6/librt-0.7.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03679b9856932b8c8f674e87aa3c55ea11c9274301f76ae8dc4d281bda55cf62", size = 192439 },
+ { url = "https://files.pythonhosted.org/packages/8d/56/1e8e833b95fe684f80f8894ae4d8b7d36acc9203e60478fcae599120a975/librt-0.7.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3968762fec1b2ad34ce57458b6de25dbb4142713e9ca6279a0d352fa4e9f452b", size = 191483 },
+ { url = "https://files.pythonhosted.org/packages/17/48/f11cf28a2cb6c31f282009e2208312aa84a5ee2732859f7856ee306176d5/librt-0.7.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bb7a7807523a31f03061288cc4ffc065d684c39db7644c676b47d89553c0d714", size = 185376 },
+ { url = "https://files.pythonhosted.org/packages/b8/6a/d7c116c6da561b9155b184354a60a3d5cdbf08fc7f3678d09c95679d13d9/librt-0.7.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad64a14b1e56e702e19b24aae108f18ad1bf7777f3af5fcd39f87d0c5a814449", size = 206234 },
+ { url = "https://files.pythonhosted.org/packages/61/de/1975200bb0285fc921c5981d9978ce6ce11ae6d797df815add94a5a848a3/librt-0.7.8-cp312-cp312-win32.whl", hash = "sha256:0241a6ed65e6666236ea78203a73d800dbed896cf12ae25d026d75dc1fcd1dac", size = 44057 },
+ { url = "https://files.pythonhosted.org/packages/8e/cd/724f2d0b3461426730d4877754b65d39f06a41ac9d0a92d5c6840f72b9ae/librt-0.7.8-cp312-cp312-win_amd64.whl", hash = "sha256:6db5faf064b5bab9675c32a873436b31e01d66ca6984c6f7f92621656033a708", size = 50293 },
+ { url = "https://files.pythonhosted.org/packages/bd/cf/7e899acd9ee5727ad8160fdcc9994954e79fab371c66535c60e13b968ffc/librt-0.7.8-cp312-cp312-win_arm64.whl", hash = "sha256:57175aa93f804d2c08d2edb7213e09276bd49097611aefc37e3fa38d1fb99ad0", size = 43574 },
+ { url = "https://files.pythonhosted.org/packages/a1/fe/b1f9de2829cf7fc7649c1dcd202cfd873837c5cc2fc9e526b0e7f716c3d2/librt-0.7.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4c3995abbbb60b3c129490fa985dfe6cac11d88fc3c36eeb4fb1449efbbb04fc", size = 57500 },
+ { url = "https://files.pythonhosted.org/packages/eb/d4/4a60fbe2e53b825f5d9a77325071d61cd8af8506255067bf0c8527530745/librt-0.7.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:44e0c2cbc9bebd074cf2cdbe472ca185e824be4e74b1c63a8e934cea674bebf2", size = 59019 },
+ { url = "https://files.pythonhosted.org/packages/6a/37/61ff80341ba5159afa524445f2d984c30e2821f31f7c73cf166dcafa5564/librt-0.7.8-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:4d2f1e492cae964b3463a03dc77a7fe8742f7855d7258c7643f0ee32b6651dd3", size = 169015 },
+ { url = "https://files.pythonhosted.org/packages/1c/86/13d4f2d6a93f181ebf2fc953868826653ede494559da8268023fe567fca3/librt-0.7.8-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:451e7ffcef8f785831fdb791bd69211f47e95dc4c6ddff68e589058806f044c6", size = 178161 },
+ { url = "https://files.pythonhosted.org/packages/88/26/e24ef01305954fc4d771f1f09f3dd682f9eb610e1bec188ffb719374d26e/librt-0.7.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3469e1af9f1380e093ae06bedcbdd11e407ac0b303a56bbe9afb1d6824d4982d", size = 193015 },
+ { url = "https://files.pythonhosted.org/packages/88/a0/92b6bd060e720d7a31ed474d046a69bd55334ec05e9c446d228c4b806ae3/librt-0.7.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f11b300027ce19a34f6d24ebb0a25fd0e24a9d53353225a5c1e6cadbf2916b2e", size = 192038 },
+ { url = "https://files.pythonhosted.org/packages/06/bb/6f4c650253704279c3a214dad188101d1b5ea23be0606628bc6739456624/librt-0.7.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4adc73614f0d3c97874f02f2c7fd2a27854e7e24ad532ea6b965459c5b757eca", size = 186006 },
+ { url = "https://files.pythonhosted.org/packages/dc/00/1c409618248d43240cadf45f3efb866837fa77e9a12a71481912135eb481/librt-0.7.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60c299e555f87e4c01b2eca085dfccda1dde87f5a604bb45c2906b8305819a93", size = 206888 },
+ { url = "https://files.pythonhosted.org/packages/d9/83/b2cfe8e76ff5c1c77f8a53da3d5de62d04b5ebf7cf913e37f8bca43b5d07/librt-0.7.8-cp313-cp313-win32.whl", hash = "sha256:b09c52ed43a461994716082ee7d87618096851319bf695d57ec123f2ab708951", size = 44126 },
+ { url = "https://files.pythonhosted.org/packages/a9/0b/c59d45de56a51bd2d3a401fc63449c0ac163e4ef7f523ea8b0c0dee86ec5/librt-0.7.8-cp313-cp313-win_amd64.whl", hash = "sha256:f8f4a901a3fa28969d6e4519deceab56c55a09d691ea7b12ca830e2fa3461e34", size = 50262 },
+ { url = "https://files.pythonhosted.org/packages/fc/b9/973455cec0a1ec592395250c474164c4a58ebf3e0651ee920fef1a2623f1/librt-0.7.8-cp313-cp313-win_arm64.whl", hash = "sha256:43d4e71b50763fcdcf64725ac680d8cfa1706c928b844794a7aa0fa9ac8e5f09", size = 43600 },
+ { url = "https://files.pythonhosted.org/packages/1a/73/fa8814c6ce2d49c3827829cadaa1589b0bf4391660bd4510899393a23ebc/librt-0.7.8-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:be927c3c94c74b05128089a955fba86501c3b544d1d300282cc1b4bd370cb418", size = 57049 },
+ { url = "https://files.pythonhosted.org/packages/53/fe/f6c70956da23ea235fd2e3cc16f4f0b4ebdfd72252b02d1164dd58b4e6c3/librt-0.7.8-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7b0803e9008c62a7ef79058233db7ff6f37a9933b8f2573c05b07ddafa226611", size = 58689 },
+ { url = "https://files.pythonhosted.org/packages/1f/4d/7a2481444ac5fba63050d9abe823e6bc16896f575bfc9c1e5068d516cdce/librt-0.7.8-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:79feb4d00b2a4e0e05c9c56df707934f41fcb5fe53fd9efb7549068d0495b758", size = 166808 },
+ { url = "https://files.pythonhosted.org/packages/ac/3c/10901d9e18639f8953f57c8986796cfbf4c1c514844a41c9197cf87cb707/librt-0.7.8-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9122094e3f24aa759c38f46bd8863433820654927370250f460ae75488b66ea", size = 175614 },
+ { url = "https://files.pythonhosted.org/packages/db/01/5cbdde0951a5090a80e5ba44e6357d375048123c572a23eecfb9326993a7/librt-0.7.8-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7e03bea66af33c95ce3addf87a9bf1fcad8d33e757bc479957ddbc0e4f7207ac", size = 189955 },
+ { url = "https://files.pythonhosted.org/packages/6a/b4/e80528d2f4b7eaf1d437fcbd6fc6ba4cbeb3e2a0cb9ed5a79f47c7318706/librt-0.7.8-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f1ade7f31675db00b514b98f9ab9a7698c7282dad4be7492589109471852d398", size = 189370 },
+ { url = "https://files.pythonhosted.org/packages/c1/ab/938368f8ce31a9787ecd4becb1e795954782e4312095daf8fd22420227c8/librt-0.7.8-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a14229ac62adcf1b90a15992f1ab9c69ae8b99ffb23cb64a90878a6e8a2f5b81", size = 183224 },
+ { url = "https://files.pythonhosted.org/packages/3c/10/559c310e7a6e4014ac44867d359ef8238465fb499e7eb31b6bfe3e3f86f5/librt-0.7.8-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5bcaaf624fd24e6a0cb14beac37677f90793a96864c67c064a91458611446e83", size = 203541 },
+ { url = "https://files.pythonhosted.org/packages/f8/db/a0db7acdb6290c215f343835c6efda5b491bb05c3ddc675af558f50fdba3/librt-0.7.8-cp314-cp314-win32.whl", hash = "sha256:7aa7d5457b6c542ecaed79cec4ad98534373c9757383973e638ccced0f11f46d", size = 40657 },
+ { url = "https://files.pythonhosted.org/packages/72/e0/4f9bdc2a98a798511e81edcd6b54fe82767a715e05d1921115ac70717f6f/librt-0.7.8-cp314-cp314-win_amd64.whl", hash = "sha256:3d1322800771bee4a91f3b4bd4e49abc7d35e65166821086e5afd1e6c0d9be44", size = 46835 },
+ { url = "https://files.pythonhosted.org/packages/f9/3d/59c6402e3dec2719655a41ad027a7371f8e2334aa794ed11533ad5f34969/librt-0.7.8-cp314-cp314-win_arm64.whl", hash = "sha256:5363427bc6a8c3b1719f8f3845ea53553d301382928a86e8fab7984426949bce", size = 39885 },
+ { url = "https://files.pythonhosted.org/packages/4e/9c/2481d80950b83085fb14ba3c595db56330d21bbc7d88a19f20165f3538db/librt-0.7.8-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:ca916919793a77e4a98d4a1701e345d337ce53be4a16620f063191f7322ac80f", size = 59161 },
+ { url = "https://files.pythonhosted.org/packages/96/79/108df2cfc4e672336765d54e3ff887294c1cc36ea4335c73588875775527/librt-0.7.8-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:54feb7b4f2f6706bb82325e836a01be805770443e2400f706e824e91f6441dde", size = 61008 },
+ { url = "https://files.pythonhosted.org/packages/46/f2/30179898f9994a5637459d6e169b6abdc982012c0a4b2d4c26f50c06f911/librt-0.7.8-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:39a4c76fee41007070f872b648cc2f711f9abf9a13d0c7162478043377b52c8e", size = 187199 },
+ { url = "https://files.pythonhosted.org/packages/b4/da/f7563db55cebdc884f518ba3791ad033becc25ff68eb70902b1747dc0d70/librt-0.7.8-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac9c8a458245c7de80bc1b9765b177055efff5803f08e548dd4bb9ab9a8d789b", size = 198317 },
+ { url = "https://files.pythonhosted.org/packages/b3/6c/4289acf076ad371471fa86718c30ae353e690d3de6167f7db36f429272f1/librt-0.7.8-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95b67aa7eff150f075fda09d11f6bfb26edffd300f6ab1666759547581e8f666", size = 210334 },
+ { url = "https://files.pythonhosted.org/packages/4a/7f/377521ac25b78ac0a5ff44127a0360ee6d5ddd3ce7327949876a30533daa/librt-0.7.8-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:535929b6eff670c593c34ff435d5440c3096f20fa72d63444608a5aef64dd581", size = 211031 },
+ { url = "https://files.pythonhosted.org/packages/c5/b1/e1e96c3e20b23d00cf90f4aad48f0deb4cdfec2f0ed8380d0d85acf98bbf/librt-0.7.8-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:63937bd0f4d1cb56653dc7ae900d6c52c41f0015e25aaf9902481ee79943b33a", size = 204581 },
+ { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731 },
+ { url = "https://files.pythonhosted.org/packages/22/f0/07fb6ab5c39a4ca9af3e37554f9d42f25c464829254d72e4ebbd81da351c/librt-0.7.8-cp314-cp314t-win32.whl", hash = "sha256:171ca3a0a06c643bd0a2f62a8944e1902c94aa8e5da4db1ea9a8daf872685365", size = 41173 },
+ { url = "https://files.pythonhosted.org/packages/24/d4/7e4be20993dc6a782639625bd2f97f3c66125c7aa80c82426956811cfccf/librt-0.7.8-cp314-cp314t-win_amd64.whl", hash = "sha256:445b7304145e24c60288a2f172b5ce2ca35c0f81605f5299f3fa567e189d2e32", size = 47668 },
+ { url = "https://files.pythonhosted.org/packages/fc/85/69f92b2a7b3c0f88ffe107c86b952b397004b5b8ea5a81da3d9c04c04422/librt-0.7.8-cp314-cp314t-win_arm64.whl", hash = "sha256:8766ece9de08527deabcd7cb1b4f1a967a385d26e33e536d6d8913db6ef74f06", size = 40550 },
+]
+
[[package]]
name = "lifelines"
version = "0.30.0"
@@ -492,6 +797,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509 },
]
+[[package]]
+name = "markdown"
+version = "3.10.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180 },
+]
+
[[package]]
name = "markupsafe"
version = "3.0.2"
@@ -603,6 +917,149 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/b9/59e120d24a2ec5fc2d30646adb2efb4621aab3c6d83d66fb2a7a182db032/matplotlib-3.10.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb73d8aa75a237457988f9765e4dfe1c0d2453c5ca4eabc897d4309672c8e014", size = 8594298 },
]
+[[package]]
+name = "mergedeep"
+version = "1.3.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3a/41/580bb4006e3ed0361b8151a01d324fb03f420815446c7def45d02f74c270/mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8", size = 4661 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354 },
+]
+
+[[package]]
+name = "mike"
+version = "2.1.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "importlib-metadata" },
+ { name = "importlib-resources" },
+ { name = "jinja2" },
+ { name = "mkdocs" },
+ { name = "pyparsing" },
+ { name = "pyyaml" },
+ { name = "pyyaml-env-tag" },
+ { name = "verspec" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ab/f7/2933f1a1fb0e0f077d5d6a92c6c7f8a54e6128241f116dff4df8b6050bbf/mike-2.1.3.tar.gz", hash = "sha256:abd79b8ea483fb0275b7972825d3082e5ae67a41820f8d8a0dc7a3f49944e810", size = 38119 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/fd/1a/31b7cd6e4e7a02df4e076162e9783620777592bea9e4bb036389389af99d/mike-2.1.3-py3-none-any.whl", hash = "sha256:d90c64077e84f06272437b464735130d380703a76a5738b152932884c60c062a", size = 33754 },
+]
+
+[[package]]
+name = "mkdocs"
+version = "1.6.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "click" },
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "ghp-import" },
+ { name = "jinja2" },
+ { name = "markdown" },
+ { name = "markupsafe" },
+ { name = "mergedeep" },
+ { name = "mkdocs-get-deps" },
+ { name = "packaging" },
+ { name = "pathspec" },
+ { name = "pyyaml" },
+ { name = "pyyaml-env-tag" },
+ { name = "watchdog" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/bc/c6/bbd4f061bd16b378247f12953ffcb04786a618ce5e904b8c5a01a0309061/mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2", size = 3889159 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/22/5b/dbc6a8cddc9cfa9c4971d59fb12bb8d42e161b7e7f8cc89e49137c5b279c/mkdocs-1.6.1-py3-none-any.whl", hash = "sha256:db91759624d1647f3f34aa0c3f327dd2601beae39a366d6e064c03468d35c20e", size = 3864451 },
+]
+
+[[package]]
+name = "mkdocs-autorefs"
+version = "1.4.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markdown" },
+ { name = "markupsafe" },
+ { name = "mkdocs" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/52/c0/f641843de3f612a6b48253f39244165acff36657a91cc903633d456ae1ac/mkdocs_autorefs-1.4.4.tar.gz", hash = "sha256:d54a284f27a7346b9c38f1f852177940c222da508e66edc816a0fa55fc6da197", size = 56588 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/28/de/a3e710469772c6a89595fc52816da05c1e164b4c866a89e3cb82fb1b67c5/mkdocs_autorefs-1.4.4-py3-none-any.whl", hash = "sha256:834ef5408d827071ad1bc69e0f39704fa34c7fc05bc8e1c72b227dfdc5c76089", size = 25530 },
+]
+
+[[package]]
+name = "mkdocs-get-deps"
+version = "0.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "mergedeep" },
+ { name = "platformdirs" },
+ { name = "pyyaml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/98/f5/ed29cd50067784976f25ed0ed6fcd3c2ce9eb90650aa3b2796ddf7b6870b/mkdocs_get_deps-0.2.0.tar.gz", hash = "sha256:162b3d129c7fad9b19abfdcb9c1458a651628e4b1dea628ac68790fb3061c60c", size = 10239 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/9f/d4/029f984e8d3f3b6b726bd33cafc473b75e9e44c0f7e80a5b29abc466bdea/mkdocs_get_deps-0.2.0-py3-none-any.whl", hash = "sha256:2bf11d0b133e77a0dd036abeeb06dec8775e46efa526dc70667d8863eefc6134", size = 9521 },
+]
+
+[[package]]
+name = "mkdocs-material"
+version = "9.7.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "babel" },
+ { name = "backrefs" },
+ { name = "colorama" },
+ { name = "jinja2" },
+ { name = "markdown" },
+ { name = "mkdocs" },
+ { name = "mkdocs-material-extensions" },
+ { name = "paginate" },
+ { name = "pygments" },
+ { name = "pymdown-extensions" },
+ { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/27/e2/2ffc356cd72f1473d07c7719d82a8f2cbd261666828614ecb95b12169f41/mkdocs_material-9.7.1.tar.gz", hash = "sha256:89601b8f2c3e6c6ee0a918cc3566cb201d40bf37c3cd3c2067e26fadb8cce2b8", size = 4094392 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3e/32/ed071cb721aca8c227718cffcf7bd539620e9799bbf2619e90c757bfd030/mkdocs_material-9.7.1-py3-none-any.whl", hash = "sha256:3f6100937d7d731f87f1e3e3b021c97f7239666b9ba1151ab476cabb96c60d5c", size = 9297166 },
+]
+
+[[package]]
+name = "mkdocs-material-extensions"
+version = "1.3.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/79/9b/9b4c96d6593b2a541e1cb8b34899a6d021d208bb357042823d4d2cabdbe7/mkdocs_material_extensions-1.3.1.tar.gz", hash = "sha256:10c9511cea88f568257f960358a467d12b970e1f7b2c0e5fb2bb48cab1928443", size = 11847 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728 },
+]
+
+[[package]]
+name = "mkdocstrings"
+version = "1.0.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "jinja2" },
+ { name = "markdown" },
+ { name = "markupsafe" },
+ { name = "mkdocs" },
+ { name = "mkdocs-autorefs" },
+ { name = "pymdown-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/46/62/0dfc5719514115bf1781f44b1d7f2a0923fcc01e9c5d7990e48a05c9ae5d/mkdocstrings-1.0.3.tar.gz", hash = "sha256:ab670f55040722b49bb45865b2e93b824450fb4aef638b00d7acb493a9020434", size = 100946 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/41/1cf02e3df279d2dd846a1bf235a928254eba9006dd22b4a14caa71aed0f7/mkdocstrings-1.0.3-py3-none-any.whl", hash = "sha256:0d66d18430c2201dc7fe85134277382baaa15e6b30979f3f3bdbabd6dbdb6046", size = 35523 },
+]
+
+[[package]]
+name = "mkdocstrings-python"
+version = "2.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "griffe" },
+ { name = "mkdocs-autorefs" },
+ { name = "mkdocstrings" },
+ { name = "typing-extensions", marker = "python_full_version < '3.11'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/25/84/78243847ad9d5c21d30a2842720425b17e880d99dfe824dee11d6b2149b4/mkdocstrings_python-2.0.2.tar.gz", hash = "sha256:4a32ccfc4b8d29639864698e81cfeb04137bce76bb9f3c251040f55d4b6e1ad8", size = 199124 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/f3/31/7ee938abbde2322e553a2cb5f604cdd1e4728e08bba39c7ee6fae9af840b/mkdocstrings_python-2.0.2-py3-none-any.whl", hash = "sha256:31241c0f43d85a69306d704d5725786015510ea3f3c4bdfdb5a5731d83cdc2b0", size = 104900 },
+]
+
[[package]]
name = "mpmath"
version = "1.3.0"
@@ -612,6 +1069,61 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 },
]
+[[package]]
+name = "mypy"
+version = "1.19.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "librt", marker = "platform_python_implementation != 'PyPy'" },
+ { name = "mypy-extensions" },
+ { name = "pathspec" },
+ { name = "tomli", marker = "python_full_version < '3.11'" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2f/63/e499890d8e39b1ff2df4c0c6ce5d371b6844ee22b8250687a99fd2f657a8/mypy-1.19.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f05aa3d375b385734388e844bc01733bd33c644ab48e9684faa54e5389775ec", size = 13101333 },
+ { url = "https://files.pythonhosted.org/packages/72/4b/095626fc136fba96effc4fd4a82b41d688ab92124f8c4f7564bffe5cf1b0/mypy-1.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:022ea7279374af1a5d78dfcab853fe6a536eebfda4b59deab53cd21f6cd9f00b", size = 12164102 },
+ { url = "https://files.pythonhosted.org/packages/0c/5b/952928dd081bf88a83a5ccd49aaecfcd18fd0d2710c7ff07b8fb6f7032b9/mypy-1.19.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee4c11e460685c3e0c64a4c5de82ae143622410950d6be863303a1c4ba0e36d6", size = 12765799 },
+ { url = "https://files.pythonhosted.org/packages/2a/0d/93c2e4a287f74ef11a66fb6d49c7a9f05e47b0a4399040e6719b57f500d2/mypy-1.19.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de759aafbae8763283b2ee5869c7255391fbc4de3ff171f8f030b5ec48381b74", size = 13522149 },
+ { url = "https://files.pythonhosted.org/packages/7b/0e/33a294b56aaad2b338d203e3a1d8b453637ac36cb278b45005e0901cf148/mypy-1.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ab43590f9cd5108f41aacf9fca31841142c786827a74ab7cc8a2eacb634e09a1", size = 13810105 },
+ { url = "https://files.pythonhosted.org/packages/0e/fd/3e82603a0cb66b67c5e7abababce6bf1a929ddf67bf445e652684af5c5a0/mypy-1.19.1-cp310-cp310-win_amd64.whl", hash = "sha256:2899753e2f61e571b3971747e302d5f420c3fd09650e1951e99f823bc3089dac", size = 10057200 },
+ { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539 },
+ { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163 },
+ { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629 },
+ { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933 },
+ { url = "https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754 },
+ { url = "https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772 },
+ { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053 },
+ { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134 },
+ { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616 },
+ { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847 },
+ { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976 },
+ { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104 },
+ { url = "https://files.pythonhosted.org/packages/de/9f/a6abae693f7a0c697dbb435aac52e958dc8da44e92e08ba88d2e42326176/mypy-1.19.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e3157c7594ff2ef1634ee058aafc56a82db665c9438fd41b390f3bde1ab12250", size = 13201927 },
+ { url = "https://files.pythonhosted.org/packages/9a/a4/45c35ccf6e1c65afc23a069f50e2c66f46bd3798cbe0d680c12d12935caa/mypy-1.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdb12f69bcc02700c2b47e070238f42cb87f18c0bc1fc4cdb4fb2bc5fd7a3b8b", size = 12206730 },
+ { url = "https://files.pythonhosted.org/packages/05/bb/cdcf89678e26b187650512620eec8368fded4cfd99cfcb431e4cdfd19dec/mypy-1.19.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f859fb09d9583a985be9a493d5cfc5515b56b08f7447759a0c5deaf68d80506e", size = 12724581 },
+ { url = "https://files.pythonhosted.org/packages/d1/32/dd260d52babf67bad8e6770f8e1102021877ce0edea106e72df5626bb0ec/mypy-1.19.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9a6538e0415310aad77cb94004ca6482330fece18036b5f360b62c45814c4ef", size = 13616252 },
+ { url = "https://files.pythonhosted.org/packages/71/d0/5e60a9d2e3bd48432ae2b454b7ef2b62a960ab51292b1eda2a95edd78198/mypy-1.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:da4869fc5e7f62a88f3fe0b5c919d1d9f7ea3cef92d3689de2823fd27e40aa75", size = 13840848 },
+ { url = "https://files.pythonhosted.org/packages/98/76/d32051fa65ecf6cc8c6610956473abdc9b4c43301107476ac03559507843/mypy-1.19.1-cp313-cp313-win_amd64.whl", hash = "sha256:016f2246209095e8eda7538944daa1d60e1e8134d98983b9fc1e92c1fc0cb8dd", size = 10135510 },
+ { url = "https://files.pythonhosted.org/packages/de/eb/b83e75f4c820c4247a58580ef86fcd35165028f191e7e1ba57128c52782d/mypy-1.19.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:06e6170bd5836770e8104c8fdd58e5e725cfeb309f0a6c681a811f557e97eac1", size = 13199744 },
+ { url = "https://files.pythonhosted.org/packages/94/28/52785ab7bfa165f87fcbb61547a93f98bb20e7f82f90f165a1f69bce7b3d/mypy-1.19.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:804bd67b8054a85447c8954215a906d6eff9cabeabe493fb6334b24f4bfff718", size = 12215815 },
+ { url = "https://files.pythonhosted.org/packages/0a/c6/bdd60774a0dbfb05122e3e925f2e9e846c009e479dcec4821dad881f5b52/mypy-1.19.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:21761006a7f497cb0d4de3d8ef4ca70532256688b0523eee02baf9eec895e27b", size = 12740047 },
+ { url = "https://files.pythonhosted.org/packages/32/2a/66ba933fe6c76bd40d1fe916a83f04fed253152f451a877520b3c4a5e41e/mypy-1.19.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28902ee51f12e0f19e1e16fbe2f8f06b6637f482c459dd393efddd0ec7f82045", size = 13601998 },
+ { url = "https://files.pythonhosted.org/packages/e3/da/5055c63e377c5c2418760411fd6a63ee2b96cf95397259038756c042574f/mypy-1.19.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:481daf36a4c443332e2ae9c137dfee878fcea781a2e3f895d54bd3002a900957", size = 13807476 },
+ { url = "https://files.pythonhosted.org/packages/cd/09/4ebd873390a063176f06b0dbf1f7783dd87bd120eae7727fa4ae4179b685/mypy-1.19.1-cp314-cp314-win_amd64.whl", hash = "sha256:8bb5c6f6d043655e055be9b542aa5f3bdd30e4f3589163e85f93f3640060509f", size = 10281872 },
+ { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239 },
+]
+
+[[package]]
+name = "mypy-extensions"
+version = "1.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963 },
+]
+
[[package]]
name = "networkx"
version = "3.4.2"
@@ -1015,6 +1527,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 },
]
+[[package]]
+name = "paginate"
+version = "0.5.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/ec/46/68dde5b6bc00c1296ec6466ab27dddede6aec9af1b99090e1107091b3b84/paginate-0.5.7.tar.gz", hash = "sha256:22bd083ab41e1a8b4f3690544afb2c60c25e5c9a63a30fa2f483f6c60c8e5945", size = 19252 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746 },
+]
+
[[package]]
name = "pandas"
version = "2.3.0"
@@ -1064,6 +1585,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/39/c2/646d2e93e0af70f4e5359d870a63584dacbc324b54d73e6b3267920ff117/pandas-2.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bb3be958022198531eb7ec2008cfc78c5b1eed51af8600c6c5d9160d89d8d249", size = 13231847 },
]
+[[package]]
+name = "pathspec"
+version = "1.0.4"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206 },
+]
+
[[package]]
name = "pillow"
version = "11.2.1"
@@ -1141,6 +1671,37 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/21/2c/5e05f58658cf49b6667762cca03d6e7d85cededde2caf2ab37b81f80e574/pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044", size = 2674751 },
]
+[[package]]
+name = "platformdirs"
+version = "4.5.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/cf/86/0248f086a84f01b37aaec0fa567b397df1a119f73c16f6c7a9aac73ea309/platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda", size = 21715 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731 },
+]
+
+[[package]]
+name = "pygments"
+version = "2.19.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217 },
+]
+
+[[package]]
+name = "pymdown-extensions"
+version = "10.20.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "markdown" },
+ { name = "pyyaml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/1e/6c/9e370934bfa30e889d12e61d0dae009991294f40055c238980066a7fbd83/pymdown_extensions-10.20.1.tar.gz", hash = "sha256:e7e39c865727338d434b55f1dd8da51febcffcaebd6e1a0b9c836243f660740a", size = 852860 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/40/6d/b6ee155462a0156b94312bdd82d2b92ea56e909740045a87ccb98bf52405/pymdown_extensions-10.20.1-py3-none-any.whl", hash = "sha256:24af7feacbca56504b313b7b418c4f5e1317bb5fea60f03d57be7fcc40912aa0", size = 268768 },
+]
+
[[package]]
name = "pyparsing"
version = "3.2.3"
@@ -1215,6 +1776,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
]
+[[package]]
+name = "pyyaml-env-tag"
+version = "1.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pyyaml" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/eb/2e/79c822141bfd05a853236b504869ebc6b70159afc570e1d5a20641782eaa/pyyaml_env_tag-1.1.tar.gz", hash = "sha256:2eb38b75a2d21ee0475d6d97ec19c63287a7e140231e4214969d0eac923cd7ff", size = 5737 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722 },
+]
+
[[package]]
name = "qdldl"
version = "0.1.7.post5"
@@ -1248,6 +1821,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/08/f7/abac03a09f6848cee6d5dd7a7a8bd1dfed68766ee77f9cbf3e9de596ad68/qdldl-0.1.7.post5-cp313-cp313-win_amd64.whl", hash = "sha256:cc9be378e7bec67d4c62b7fa27cafb4f77d3e5e059d753c3dce0a5ae1ef5fea0", size = 90735 },
]
+[[package]]
+name = "requests"
+version = "2.32.5"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "certifi" },
+ { name = "charset-normalizer" },
+ { name = "idna" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738 },
+]
+
[[package]]
name = "scikit-learn"
version = "1.6.1"
@@ -1608,6 +2196,56 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 },
]
+[[package]]
+name = "urllib3"
+version = "2.6.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584 },
+]
+
+[[package]]
+name = "verspec"
+version = "0.1.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/44/8126f9f0c44319b2efc65feaad589cadef4d77ece200ae3c9133d58464d0/verspec-0.1.0.tar.gz", hash = "sha256:c4504ca697b2056cdb4bfa7121461f5a0e81809255b41c03dda4ba823637c01e", size = 27123 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/a4/ce/3b6fee91c85626eaf769d617f1be9d2e15c1cca027bbdeb2e0d751469355/verspec-0.1.0-py3-none-any.whl", hash = "sha256:741877d5633cc9464c45a469ae2a31e801e6dbbaa85b9675d481cda100f11c31", size = 19640 },
+]
+
+[[package]]
+name = "watchdog"
+version = "6.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0c/56/90994d789c61df619bfc5ce2ecdabd5eeff564e1eb47512bd01b5e019569/watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26", size = 96390 },
+ { url = "https://files.pythonhosted.org/packages/55/46/9a67ee697342ddf3c6daa97e3a587a56d6c4052f881ed926a849fcf7371c/watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112", size = 88389 },
+ { url = "https://files.pythonhosted.org/packages/44/65/91b0985747c52064d8701e1075eb96f8c40a79df889e59a399453adfb882/watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3", size = 89020 },
+ { url = "https://files.pythonhosted.org/packages/e0/24/d9be5cd6642a6aa68352ded4b4b10fb0d7889cb7f45814fb92cecd35f101/watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c", size = 96393 },
+ { url = "https://files.pythonhosted.org/packages/63/7a/6013b0d8dbc56adca7fdd4f0beed381c59f6752341b12fa0886fa7afc78b/watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2", size = 88392 },
+ { url = "https://files.pythonhosted.org/packages/d1/40/b75381494851556de56281e053700e46bff5b37bf4c7267e858640af5a7f/watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c", size = 89019 },
+ { url = "https://files.pythonhosted.org/packages/39/ea/3930d07dafc9e286ed356a679aa02d777c06e9bfd1164fa7c19c288a5483/watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948", size = 96471 },
+ { url = "https://files.pythonhosted.org/packages/12/87/48361531f70b1f87928b045df868a9fd4e253d9ae087fa4cf3f7113be363/watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860", size = 88449 },
+ { url = "https://files.pythonhosted.org/packages/5b/7e/8f322f5e600812e6f9a31b75d242631068ca8f4ef0582dd3ae6e72daecc8/watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0", size = 89054 },
+ { url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480 },
+ { url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451 },
+ { url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057 },
+ { url = "https://files.pythonhosted.org/packages/30/ad/d17b5d42e28a8b91f8ed01cb949da092827afb9995d4559fd448d0472763/watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881", size = 87902 },
+ { url = "https://files.pythonhosted.org/packages/5c/ca/c3649991d140ff6ab67bfc85ab42b165ead119c9e12211e08089d763ece5/watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11", size = 88380 },
+ { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079 },
+ { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078 },
+ { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076 },
+ { url = "https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077 },
+ { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078 },
+ { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077 },
+ { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078 },
+ { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065 },
+ { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070 },
+ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067 },
+]
+
[[package]]
name = "wrapt"
version = "1.17.2"
@@ -1671,3 +2309,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/09/5e/1655cf481e079c1f22d0cabdd4e51733679932718dc23bf2db175f329b76/wrapt-1.17.2-cp313-cp313t-win_amd64.whl", hash = "sha256:eaf675418ed6b3b31c7a989fd007fa7c3be66ce14e5c3b27336383604c9da85c", size = 40750 },
{ url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594 },
]
+
+[[package]]
+name = "zipp"
+version = "3.23.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276 },
+]