This repository contains the official implementation of wav2sleep💤, which has been accepted to Machine Learning for Health (ML4H) 2024.
wav2sleep is a unified model for sleep staging from sets of physiological signals, including cardio-respiratory (ECG, PPG, respiratory) and neural (EOG) modalities. It can be jointly trained across heterogeneous datasets, where the availability of input signals can vary. At test time, the model can be applied to any subset of the modalities used during training.
After jointly training on over 10,000 overnight recordings from publicly available polysomnography datasets, including SHHS and MESA, wav2sleep outperforms existing sleep stage classification models across a range of input signal combinations.
To find out more, check out our paper: https://arxiv.org/abs/2411.04644
February 2026
- EOG Model Release: We've released wav2sleep-eog, a 5-class sleep staging model using EOG signals (Wake, N1, N2, N3, REM).
- Python API: New load_model() and predict_on_folder() functions make it easy to run inference on your own data:
from wav2sleep import load_model, predict_on_folder
model = load_model("hf://joncarter/wav2sleep-eog")
predict_on_folder("/path/to/edfs", "/path/to/outputs", model=model)

Implementation of Wav2Sleep and baseline models in PyTorch, using Lightning for training.
Figure: wav2sleep architecture for sets of signals.
Scripts for transforming EDF files and sleep stage annotations from all 7 datasets used into efficient, columnar parquet files for model training. This can be parallelised over multiple CPU cores or an entire cluster using Ray, e.g.:
uv run preprocessing/1_ingest.py --folder /path/to/shhs --output-folder /path/to/processed/datasets --max-parallel 16

- Set-up and Installation
- Training and Evaluation
- Inference
- Hugging Face Hub
- Visualising Results
- Citation
- License
Install directly from GitHub:
pip install git+https://github.com/joncarter1/wav2sleep.git

Or clone and install for development:
git clone https://github.com/joncarter1/wav2sleep
cd wav2sleep
uv sync

Our work uses datasets managed by the National Sleep Research Resource (NSRR). To reproduce our results, you will need to apply for access to the following datasets:
- SHHS
- MESA
- WSC
- CCSHS
- CFS
- CHAT
- MROS
Once approved, these can be downloaded with the NSRR Ruby gem, e.g. nsrr download shhs. More details can be found on the NSRR website.
We provide high-performance scripts to process each downloaded dataset and split it into training, validation and test partitions. Instructions on how to do this can be found here.
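As a rough illustration of what a reproducible partitioning step can look like, the sketch below assigns each subject to a partition from a hash of its identifier. This is not the repository's splitting logic: the function name `assign_partition`, the split fractions, and the hashing approach are all assumptions made for illustration.

```python
import hashlib


def assign_partition(subject_id: str, val_fraction: float = 0.1, test_fraction: float = 0.1) -> str:
    """Deterministically assign a subject to train/val/test (hypothetical helper)."""
    digest = hashlib.sha256(subject_id.encode("utf-8")).digest()
    # Map the first 8 bytes of the hash to a number in [0, 1).
    position = int.from_bytes(digest[:8], "big") / 2**64
    if position < test_fraction:
        return "test"
    if position < test_fraction + val_fraction:
        return "val"
    return "train"


# The same ID always lands in the same partition, across runs and machines.
print(assign_partition("example-subject-0001"))
```

Hashing the subject identifier rather than the recording file name keeps repeat visits from the same person in one partition, which avoids leakage between training and test data.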
To train and evaluate the model on all datasets, just run:
uv run scripts/train.py model=wav2sleep num_gpus=1 tune_batch_size=True target_batch_size=16 name=wav2sleep-repro inputs=all datasets=all test=True

This will find the largest batch size that fits on your GPU, and accumulate batches for an effective batch size of at least 16. If you're lucky enough to have more than one GPU, you can specify e.g. num_gpus=2 to run across them using distributed data parallel (DDP) training.
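The arithmetic behind "an effective batch size of at least 16" can be sketched as follows. This is only an illustration, not the training script's code: the helper name `accumulation_steps` is hypothetical, and it assumes the effective batch size equals the per-GPU batch size times the number of GPUs times the number of accumulated batches.

```python
import math


def accumulation_steps(per_gpu_batch_size: int, target_batch_size: int, num_gpus: int = 1) -> int:
    """Smallest number of accumulated batches that reaches the target (hypothetical helper)."""
    if per_gpu_batch_size < 1 or target_batch_size < 1 or num_gpus < 1:
        raise ValueError("all arguments must be positive integers")
    samples_per_step = per_gpu_batch_size * num_gpus
    return max(1, math.ceil(target_batch_size / samples_per_step))


# Example: a batch size of 6 fits on one GPU and the target is 16.
steps = accumulation_steps(per_gpu_batch_size=6, target_batch_size=16)
print(steps, 6 * steps)  # 3 accumulated batches, effective batch size 18
```

Rounding up is what makes the effective batch size "at least" the target rather than exactly equal to it when the tuned batch size does not divide the target evenly.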
The predict.py script can be used to run either cardio-respiratory or EOG models.
Note: You may need to reduce batch size depending on GPU size.
Run the EOG model on already-processed parquet files using the --no-preprocess flag:
uv run scripts/predict.py --input-folder /path/to/processed/mesa/test \
    --output-folder /tmp/example-mesa-outputs \
    --model-folder hf://joncarter/wav2sleep-eog \
    --batch-size 16 --no-preprocess

Run the EOG model on raw EDF files, with analysis of up to 14 hours per file:
uv run scripts/predict.py --input-folder /path/to/edf/folder \
    --output-folder /tmp/example-outputs \
    --model-folder hf://joncarter/wav2sleep-eog \
    --batch-size 16 --max-length-hours 14

(This will skip broken EDF files)
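The two behaviours described above (capping the analysed length and skipping unreadable files) can be pictured with the sketch below. It is not the code in predict.py: `load_recording` is a hypothetical reader supplied by the caller, and the 30-second epoch length is an assumption based on the usual sleep staging convention.

```python
from pathlib import Path
from typing import Callable, Dict, List, Sequence

EPOCH_SECONDS = 30  # assumption: standard 30-second sleep staging epochs


def max_epochs(max_length_hours: float) -> int:
    """Number of whole epochs that fit within the length cap."""
    return int(max_length_hours * 3600 // EPOCH_SECONDS)


def process_folder(
    paths: Sequence[Path],
    load_recording: Callable[[Path], List[int]],  # hypothetical reader, one item per epoch
    max_length_hours: float = 14,
) -> Dict[str, List[int]]:
    """Load each file, truncate to the cap, and skip files that fail to load."""
    results = {}
    limit = max_epochs(max_length_hours)
    for path in paths:
        try:
            epochs = load_recording(path)
        except Exception as err:  # a broken file should not stop the whole run
            print(f"Skipping {path.name}: {err}")
            continue
        results[path.name] = epochs[:limit]
    return results


print(max_epochs(14))  # 1680 epochs of 30 seconds in 14 hours
```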
| Model | Signals | Classes | Description |
|---|---|---|---|
| hf://joncarter/wav2sleep | ECG, PPG, ABD, THX | 4 | Cardio-respiratory (Wake, Light, Deep, REM) |
| hf://joncarter/wav2sleep-eog | EOG-L, EOG-R | 5 | EOG-based (Wake, N1, N2, N3, REM) |
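The two label sets in the table are related under the usual convention: N1 and N2 are grouped as Light sleep, and N3 corresponds to Deep sleep. A minimal sketch of collapsing 5-class predictions into the 4-class scheme is shown below. The integer ordering of the classes is an assumption for illustration and may not match the indices the released models emit.

```python
FIVE_CLASS = ["Wake", "N1", "N2", "N3", "REM"]  # assumed index order
FOUR_CLASS = ["Wake", "Light", "Deep", "REM"]  # assumed index order

# N1 and N2 merge into Light; N3 becomes Deep.
FIVE_TO_FOUR = {"Wake": "Wake", "N1": "Light", "N2": "Light", "N3": "Deep", "REM": "REM"}


def collapse_stages(stage_indices):
    """Map 5-class stage indices to 4-class stage indices."""
    return [FOUR_CLASS.index(FIVE_TO_FOUR[FIVE_CLASS[i]]) for i in stage_indices]


print(collapse_stages([0, 1, 2, 3, 4]))  # [0, 1, 1, 2, 3]
```

A mapping like this is useful when comparing outputs of the two models on the same recording, since the comparison has to happen in the coarser 4-class label space.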
Models can be loaded directly from the Hugging Face Hub:
from wav2sleep import load_model, predict_on_folder
# Load from HF Hub
model = load_model("hf://joncarter/wav2sleep-eog")
# Run inference
predict_on_folder(
    input_folder="/path/to/edf_files",
    output_folder="/path/to/predictions",
    model=model,
)

To upload a local model folder to the Hub:
uv run scripts/upload_to_hub.py \
    --local-folder /path/to/model-folder \
    --repo-id your-username/wav2sleep-eog \
    --variant wav2sleep-eog

We use MLFlow to log trained models and evaluation metrics. By default, these will be stored in a local directory (./mlruns), and the results can be visualized by running:
uv run mlflow server

and visiting http://localhost:5000 in your browser:
Figure: Screenshot from an MLFlow dashboard
If you find this code useful for your research, please cite our paper:
@misc{carter2024wav2sleepunifiedmultimodalapproach,
title={wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals},
author={Jonathan F. Carter and Lionel Tarassenko},
year={2024},
eprint={2411.04644},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2411.04644},
}

This project is licensed under the MIT License - see the LICENSE file for details.
This software is provided for research and development purposes only. It is not a medical device, and is not intended for use in clinical decision-making.
This project was created with Cookiecutter and the joncarter1/cookiecutter_research template.
