This repository contains code for an enhanced version of 3D-RetinaNet, originally proposed with the ROAD dataset. We extend the model with concept-based explainability, integrating a Concept Embedding Module (CEM) to make the architecture explainable by design. This work is described in more detail in our project paper (preprint).
- Requirements
- Training 3D-RetinaNet
- Testing and Building Tubes
- Performance
- Concept-Based Explainability Extension (CEM)
- Citation
- Reference
## Requirements

We need three things to get started with training: the dataset, Kinetics pre-trained weights, and PyTorch with torchvision and tensorboardX.
We used only the ROAD dataset, introduced in the dataset release paper.
- Install PyTorch and torchvision
- Install tensorboardX via `pip install tensorboardx`
- Download Kinetics-400 pretrained weights into `kinetics-pt/`, using this script or Google Drive
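Put together, the setup might look like the following sketch. The exact `torch`/`torchvision` build depends on your CUDA version (see pytorch.org), and the weight-download step is only hinted at here since the actual script/Drive link lives in the repo:

```shell
# Install PyTorch + torchvision (pick the build matching your CUDA version
# from pytorch.org; a plain pip install may be CPU-only on some platforms).
pip install torch torchvision
pip install tensorboardx

# Create the directory the training command expects the Kinetics-400
# weights to live in; fetch the weights via the script / Drive link above.
mkdir -p kinetics-pt
```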
## Training 3D-RetinaNet

You will need 4 GPUs (each with at least 10 GB VRAM). Example command:

```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py /home/user/ /home/user/ /home/user/kinetics-pt/ \
    --MODE=train --ARCH=resnet50 --MODEL_TYPE=I3D --DATASET=road --TRAIN_SUBSETS=train_3 \
    --SEQ_LEN=8 --TEST_SEQ_LEN=8 --BATCH_SIZE=4 --LR=0.0041
```

## Testing and Building Tubes

For evaluation and tube generation, use:
```shell
python main.py /home/user/ /home/user/ /home/user/kinetics-pt/ \
    --MODE=gen_dets --MODEL_TYPE=I3D --TEST_SEQ_LEN=8 --TRAIN_SUBSETS=train_3 \
    --SEQ_LEN=8 --BATCH_SIZE=4 --LR=0.0041
```

## Performance

Results obtained after 60 training epochs with the explainable model (CEM enabled):
- Agentness MEANAP: 54.67
- Agent MEANAP: 37.10
- Action MEANAP: 23.03
- Location MEANAP: 27.92
- Duplex MEANAP: 26.09
- Triplet MEANAP: 19.49
- Ego-action MEANAP: 42.40
[CEM] Concept Prediction:
- Accuracy: 77.05%
- F1 Micro: 0.2030, F1 Macro: 0.1060
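The gap between the micro and macro F1 above is expected: micro F1 pools true/false positives across all concepts, while macro F1 averages per-concept F1, so rare, poorly-predicted concepts drag the macro score down. A toy sketch (counts are made up, not ROAD numbers) illustrates the difference:

```python
# Toy illustration of micro vs macro F1 (counts are illustrative only).
def f1(tp, fp, fn):
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    return 2 * prec * rec / (prec + rec) if prec + rec else 0.0

# Per-concept (tp, fp, fn): one frequent, well-predicted concept, two rare ones.
counts = [(8, 2, 1), (1, 5, 6), (0, 3, 4)]

macro = sum(f1(*c) for c in counts) / len(counts)   # mean of per-concept F1
micro = f1(*(sum(col) for col in zip(*counts)))     # F1 over pooled counts

print(f"micro={micro:.4f} macro={macro:.4f}")       # micro=0.4615 macro=0.3320
```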
## Concept-Based Explainability Extension (CEM)

We extended 3D-RetinaNet to support explainability by design by integrating a Concept Embedding Module (CEM). Key changes:
- 🧠 CEM Head: Learns interpretable concept embeddings with dual embeddings (active/inactive).
- 🔁 Transformer Encoder: Encodes temporal patterns on the concept bottleneck.
- 🎯 Ego Head: Replaced with a concept-driven prediction head.
- 📊 Concept Supervision: Via BCEWithLogitsLoss using dynamic class balancing.
- 🧩 Triplet Annotations: Dataset modified to include triplet-based concept labels.
- 📉 Loss Tracking: Training logs `cem_loss` alongside detection losses.
- 📈 Evaluation Enhancements: F1-scores, per-concept stats, hardest-concept analysis.
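The dual active/inactive embeddings and the balanced concept loss can be sketched roughly as below. This is an illustrative re-implementation of the Concept Embedding Model idea, not this repo's exact code: the class and variable names, the 512-dim input features, and the per-concept `pos_weight` balancing scheme are assumptions (`num_concepts=68` and `cem_dim=16` match the CLI flags).

```python
# Sketch of a CEM head: one (active, inactive) embedding pair per concept,
# mixed by the predicted concept probability. Names/dims are assumptions.
import torch
import torch.nn as nn

class CEMHead(nn.Module):
    def __init__(self, in_dim, num_concepts=68, cem_dim=16):
        super().__init__()
        # One active and one inactive embedding generator per concept.
        self.active = nn.ModuleList(
            nn.Linear(in_dim, cem_dim) for _ in range(num_concepts))
        self.inactive = nn.ModuleList(
            nn.Linear(in_dim, cem_dim) for _ in range(num_concepts))
        # Scores concept presence from both embedding states.
        self.score = nn.Linear(2 * cem_dim, 1)

    def forward(self, x):
        logits, bottleneck = [], []
        for act, inact in zip(self.active, self.inactive):
            c_pos, c_neg = act(x), inact(x)                        # (B, cem_dim) each
            logit = self.score(torch.cat([c_pos, c_neg], dim=-1))  # (B, 1)
            p = torch.sigmoid(logit)
            # Concept embedding = probability-weighted mix of the two states.
            bottleneck.append(p * c_pos + (1 - p) * c_neg)
            logits.append(logit)
        return torch.cat(logits, dim=-1), torch.cat(bottleneck, dim=-1)

# Concept supervision with BCEWithLogitsLoss and per-concept pos_weight
# as one possible form of dynamic class balancing.
features = torch.randn(4, 512)                   # batch of backbone features
targets = torch.randint(0, 2, (4, 68)).float()   # binary concept labels
head = CEMHead(512, num_concepts=68, cem_dim=16)
logits, bottleneck = head(features)              # (4, 68), (4, 68 * 16)
pos = targets.sum(dim=0)
pos_weight = (targets.shape[0] - pos) / pos.clamp(min=1)
cem_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, targets)
```

The downstream transformer encoder and concept-driven ego head would consume the `bottleneck` tensor rather than raw backbone features, which is what makes the ego predictions traceable back to individual concepts.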
Enable the CEM with the following flags:

```shell
--USE_CEM=True --num_concepts=68 --cem_dim=16
```

## Citation

If this work was helpful, consider citing the original 3D-RetinaNet paper:
```bibtex
@ARTICLE{singh2022road,
  author  = {Singh, Gurkirt and others},
  journal = {IEEE TPAMI},
  title   = {ROAD: The ROad event Awareness Dataset for autonomous Driving},
  year    = {2022},
  doi     = {10.1109/TPAMI.2022.3150906},
}
```