Official Repository for CEN (Context Enhanced Network), presented at MICCAI-2024.
This repository contains the PyTorch implementation of the Context Enhanced Network (CEN) for mammography analysis. The network uses pairs of mammography views (MLO and CC) to build a shared context and combines them through a fusion mechanism to improve classification performance.
The CEN framework uses a ViT-B/16 backbone (pretrained on ImageNet) to extract features from mammography screening exams. By combining the Craniocaudal (CC) and Mediolateral Oblique (MLO) views, the model exploits inter-view relationships, which improves diagnostic robustness compared with single-view approaches.
This repository also includes a demo environment (demo_train.py and demo_test.py) tailored for educational demonstrations. It uses a small toy dataset (DEMO_DATA) so that it runs on limited hardware.
- Multi-View Fusion: Combines MLO and CC views so the model can learn from their shared context.
- Pretrained Backbone: Uses an ImageNet-pretrained ViT-B/16 for feature extraction.
- Educational Demo Mode: Includes lightweight training and testing scripts designed to run quickly on CPUs or basic GPUs with low memory overhead.
- Automated Visualization: Generates loss curves and prediction visualizations out of the box.
- demo_train.py: Lightweight training script for demonstrations.
- demo_test.py: Testing and evaluation script with visualizations.
- models.py: Network architecture definitions (including MAX_model).
- data.py: PyTorch Dataset implementations for loading mammography pairs.
- calc_metrics2.py: Utilities for calculating performance metrics.
- create_toy_dataset.py: Script to generate the synthetic DEMO_DATA.
- DEMO_DATA/: Folder containing the toy dataset used by the demo scripts.
- demo_output/: Directory where trained models (.pth), logs, and visualization plots are saved.
Ensure you have Python 3.8+ installed. The required packages include:
- torch
- torchvision
- numpy
- Pillow (PIL)
- matplotlib
- tqdm
You can install them via pip:
pip install torch torchvision numpy Pillow matplotlib tqdm
To train the model using the provided toy dataset, simply run:
python demo_train.py
This script runs a fast 2-epoch training loop that logs both training and validation loss, and it saves the best model weights to demo_output/. A loss plot (demo_loss_plot.png) is also generated automatically.
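The pattern described above (train, validate every epoch, keep the weights with the lowest validation loss, then plot both curves) can be sketched as follows. This is not demo_train.py itself: the nn.Linear stand-in model, the random tensors, the Adam settings, and the checkpoint name best_model.pth are hypothetical placeholders; only demo_output/ and demo_loss_plot.png come from this README.

```python
import os

import matplotlib
matplotlib.use("Agg")  # write PNG files without needing a display
import matplotlib.pyplot as plt
import torch
from torch import nn


def plot_losses(train_losses, val_losses, path):
    """Plot per-epoch training and validation loss to a PNG file."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, marker="o", label="train")
    ax.plot(epochs, val_losses, marker="o", label="validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def train_demo(out_dir="demo_output", epochs=2, seed=0):
    """Toy loop: train, validate, keep the best weights, then plot the losses."""
    torch.manual_seed(seed)
    os.makedirs(out_dir, exist_ok=True)
    # Placeholder model and random data; the real script uses CEN and DEMO_DATA.
    model = nn.Linear(8, 2)
    x_train, y_train = torch.randn(64, 8), torch.randint(0, 2, (64,))
    x_val, y_val = torch.randn(16, 8), torch.randint(0, 2, (16,))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()

    best_val = float("inf")
    best_path = os.path.join(out_dir, "best_model.pth")  # hypothetical file name
    train_losses, val_losses = [], []
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(x_train), y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(x_val), y_val).item()
        val_losses.append(val_loss)
        # Overwrite the checkpoint only when validation loss improves.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), best_path)

    plot_losses(train_losses, val_losses, os.path.join(out_dir, "demo_loss_plot.png"))
    return train_losses, val_losses, best_path
```

Selecting the checkpoint by validation loss rather than keeping the last epoch guards against saving weights from an epoch where the model started to overfit.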
After training, evaluate the model performance on the test set:
python demo_test.py
This script loads the best model from demo_output/ and generates prediction visualizations inside demo_output/test_results/.
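Evaluating from a checkpoint usually means rebuilding the same architecture, loading the saved weights, and switching to eval mode. The helper names load_for_inference and predict below are hypothetical, and the sketch assumes the .pth file stores a state_dict rather than a pickled model object.

```python
import torch
from torch import nn


def load_for_inference(model: nn.Module, weights_path: str, device: str = "cpu") -> nn.Module:
    """Load saved weights into an already-constructed model and switch to eval mode."""
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


@torch.no_grad()
def predict(model: nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    # no_grad skips building autograd graphs, which lowers memory use during testing.
    return model(*inputs)
```

The model passed in must have the same architecture that produced the checkpoint, otherwise load_state_dict raises an error about missing or unexpected keys. Calling eval() matters for layers such as dropout and batch normalization, which behave differently during training and inference.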
This repository provides automated visualizations for tracking training progress and inspecting model predictions on test cases.
During training, the loss across epochs is plotted to verify model convergence.
Example loss plot over epochs showing consistent convergence.
The test script draws overlays directly on the mammography images. In these outputs:
- 🟦 Blue Boxes: Ground truth bounding boxes.
- 🟩 Green Boxes: True Positive predictions (regions correctly identified by the CEN model).
- 🟥 Red Boxes: False Positive predictions (incorrectly identified regions).
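The README does not state how a prediction is labelled as a true or false positive. A common convention, assumed here, is intersection over union (IoU) against the ground-truth boxes with a 0.5 threshold and greedy one-to-one matching. The sketch below implements that rule; calc_metrics2.py may use a different one, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixel coordinates


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def label_predictions(preds: Sequence[Box], gts: Sequence[Box], thr: float = 0.5) -> List[str]:
    """Greedily mark each prediction 'TP' or 'FP'; each ground-truth box matches at most once."""
    matched = set()
    labels = []
    for p in preds:  # assumes preds are already sorted by descending confidence
        best_j, best_iou = None, 0.0
        for j, g in enumerate(gts):
            if j in matched:
                continue
            v = iou(p, g)
            if v > best_iou:
                best_j, best_iou = j, v
        if best_j is not None and best_iou >= thr:
            matched.add(best_j)
            labels.append("TP")
        else:
            labels.append("FP")
    return labels


if __name__ == "__main__":
    print(label_predictions([(10, 10, 50, 50), (100, 100, 140, 140)], [(12, 12, 52, 52)]))
    # ['TP', 'FP']
```

The one-to-one matching means a second prediction on an already-matched lesion counts as a false positive, so duplicate detections are penalized.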
Here is an example output from test patient 301:
Left: MLO View | Right: CC View
This side-by-side visualization shows how the model uses context from both views to pinpoint regions of interest.
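A layout like the one above (MLO on the left, CC on the right, color-coded boxes) can be produced with Pillow. This sketch follows the README's color convention; the function name side_by_side and the ((x1, y1, x2, y2), kind) box format are assumptions, not demo_test.py's actual interface.

```python
from PIL import Image, ImageDraw

# Color convention from this README: blue = ground truth, green = TP, red = FP.
COLORS = {"GT": (0, 0, 255), "TP": (0, 255, 0), "FP": (255, 0, 0)}


def side_by_side(mlo: Image.Image, cc: Image.Image, mlo_boxes, cc_boxes, width: int = 3) -> Image.Image:
    """Paste MLO (left) and CC (right) on one canvas and draw color-coded boxes.

    mlo_boxes / cc_boxes: lists of ((x1, y1, x2, y2), kind) with kind in COLORS,
    given in each view's own pixel coordinates.
    """
    canvas = Image.new("RGB", (mlo.width + cc.width, max(mlo.height, cc.height)))
    canvas.paste(mlo.convert("RGB"), (0, 0))
    canvas.paste(cc.convert("RGB"), (mlo.width, 0))
    draw = ImageDraw.Draw(canvas)
    for offset, boxes in ((0, mlo_boxes), (mlo.width, cc_boxes)):
        for (x1, y1, x2, y2), kind in boxes:
            # CC boxes are shifted right by the MLO width so they land on the CC panel.
            draw.rectangle((x1 + offset, y1, x2 + offset, y2), outline=COLORS[kind], width=width)
    return canvas
```

Keeping each view's boxes in its own coordinates and applying the offset only at drawing time means the same box lists can also be fed to a metric function without conversion.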
If you find this code useful for your research, please refer to our MICCAI-2024 publication.
Developed by Hakkı Keman.


