Commit 5d11d51

update xception evaluation code

1 parent 9426298, commit 5d11d51
10 files changed: 150 additions & 151 deletions


.gitignore

Lines changed: 2 additions & 0 deletions

```diff
@@ -440,3 +440,5 @@ docs/_build/
 
 # Pyenv
 .python-version
+
+lightning_logs
```

Cargo.lock

Lines changed: 16 additions & 140 deletions
Some generated files are not rendered by default.

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -131,7 +131,7 @@ Evaluate the predictions. Firstly prepare the predictions as described in the [d
 ```python
 from avdeepfake1m.evaluation import ap_ar_1d, auc
 print(ap_ar_1d("<PREDICTION_JSON>", "<METADATA_JSON>", "file", "fake_segments", 1, [0.5, 0.75, 0.9, 0.95], [50, 30, 20, 10, 5], [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]))
-print(auc("<PREDICTION_TXT>", "<METADATA_JSON>"))
+print(auc("<PREDICTION_TXT>", "<METADATA_JSON>", "file", "fake_segments"))
 ```
 
 ## License
````
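For context, the `<PREDICTION_TXT>` consumed by `auc` is the semicolon-separated file the example inference script writes, one `<file>;<score>` entry per line (see `f.write(f"{file_name};{pred}\n")` in `infer.py` below). A minimal sketch of producing and parsing that format; the file names here are hypothetical:

```python
# Sketch of the prediction TXT format: one "<video_file>;<fake_score>" per line.
# The entries below are made-up examples, not real dataset files.
predictions = {
    "val/video_0001.mp4": 0.97,
    "val/video_0002.mp4": 0.12,
}

with open("xception_val.txt", "w") as f:
    for file_name, score in predictions.items():
        f.write(f"{file_name};{score}\n")

# Parse it back into (file, score) pairs.
with open("xception_val.txt") as f:
    parsed = [line.rstrip("\n").rsplit(";", 1) for line in f]
parsed = [(name, float(score)) for name, score in parsed]
print(parsed[0])  # → ('val/video_0001.mp4', 0.97)
```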

examples/xception/README.md

Lines changed: 27 additions & 3 deletions
````diff
@@ -16,8 +16,32 @@ This example trains a Xception model on the AVDeepfake1M/AVDeepfake1M++ dataset
 ```bash
 python train.py --data_root /path/to/avdeepfake1m --model xception
 ```
+### Output
 
-## Output
-
-* **Checkpoints:** Model checkpoints are saved under `./ckpt1/xception/`. The last checkpoint is saved as `last.ckpt`.
+* **Checkpoints:** Model checkpoints are saved under `./ckpt/xception/`. The last checkpoint is saved as `last.ckpt`.
 * **Logs:** Training logs (including metrics like `train_loss`, `val_loss`, and learning rates) are saved by PyTorch Lightning, typically in a directory named `./lightning_logs/`. You can view these logs using TensorBoard (`tensorboard --logdir ./lightning_logs`).
+
+
+## Inference
+
+After training, you can generate predictions on a dataset subset (train, val, or test) using `infer.py`. This script will save the predictions to a text file, following the format from the [challenge](https://deepfakes1m.github.io/2025/details).
+
+```bash
+python infer.py --data_root /path/to/avdeepfake1m --checkpoint /path/to/your/checkpoint.ckpt --model xception --subset val
+```
+
+The output prediction file will be saved to `output/<model_name>_<subset>.txt` (e.g., `output/xception_val.txt`).
+
+## Evaluation
+
+```bash
+python evaluate.py <path_to_prediction_file> <path_to_metadata_json>
+```
+
+For example:
+
+```bash
+python evaluate.py ./output/xception_val.txt /path/to/avdeepfake1m/val_metadata.json
+```
+
+This will print the AUC score based on your model's predictions.
````

examples/xception/evaluate.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+import argparse
+
+from avdeepfake1m.evaluation import auc
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluation script for AV-Deepfake1M")
+    parser.add_argument("prediction_file_path", type=str, help="Path to the prediction file (e.g., output/results/xception_val.txt)")
+    parser.add_argument("metadata_file_path", type=str, help="Path to the metadata JSON file (e.g., /path/to/val_metadata.json)")
+    args = parser.parse_args()
+
+    print(auc(
+        args.prediction_file_path,
+        args.metadata_file_path,
+        "file",  # As per README, this is usually "file"
+        "fake_segments"  # As per README, this is usually "fake_segments" for AUC
+    ))
```
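`evaluate.py` delegates the metric to `avdeepfake1m.evaluation.auc`; the underlying quantity is standard ROC AUC over per-file fake scores. For illustration only, a pure-Python sketch of ROC AUC via the rank-based (Mann-Whitney U) formulation, which is not the library's implementation:

```python
def roc_auc(labels, scores):
    """ROC AUC as the probability that a randomly chosen positive
    sample is scored above a randomly chosen negative sample
    (ties count as half a win)."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos for n in neg
    )
    return wins / (len(pos) * len(neg))

# Perfect separation of fakes (1) from reals (0):
print(roc_auc([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1]))  # → 1.0
# One inverted pair out of four positive/negative pairs:
print(roc_auc([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1]))  # → 0.75
```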

examples/xception/infer.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -0,0 +1,49 @@
+import argparse
+
+import torch
+from tqdm.auto import tqdm
+from pathlib import Path
+
+from avdeepfake1m.loader import AVDeepfake1mPlusPlusVideo
+from xception import Xception
+
+parser = argparse.ArgumentParser(description="Xception inference")
+parser.add_argument("--data_root", type=str)
+parser.add_argument("--checkpoint", type=str)
+parser.add_argument("--model", type=str)
+parser.add_argument("--batch_size", type=int, default=128)
+parser.add_argument("--subset", type=str, choices=["train", "val", "test"])
+parser.add_argument("--gpus", type=int, default=1)
+parser.add_argument("--take_num", type=int, default=None)
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    use_gpu = args.gpus > 0
+    device = "cuda" if use_gpu else "cpu"
+
+    if args.model == "xception":
+        model = Xception.load_from_checkpoint(args.checkpoint, lr=None, distributed=False).eval()
+    else:
+        raise ValueError(f"Unknown model: {args.model}")
+
+    model.to(device)
+    model.eval()  # keep eval mode for inference (disables dropout / batch-norm updates)
+    test_dataset = AVDeepfake1mPlusPlusVideo(args.subset, args.data_root, take_num=args.take_num)
+
+    save_path = f"output/{args.model}_{args.subset}.txt"
+    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
+    with open(save_path, "w") as f:
+        with torch.inference_mode():
+            for i, (video, _, label) in enumerate(tqdm(test_dataset)):
+                # run the video's frames through the model in chunks of batch_size
+                preds_video = []
+                for j in range(0, len(video), args.batch_size):
+                    batch = video[j:j + args.batch_size].to(device)
+                    preds_video.append(model(batch))
+
+                preds_video = torch.cat(preds_video, dim=0).flatten()
+                # the video-level score is the maximum frame-level prediction
+                pred = preds_video.max().item()
+
+                file_name = test_dataset.metadata[i].file
+                f.write(f"{file_name};{pred}\n")
```
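The aggregation in `infer.py` (score frames in batches, then take the maximum frame-level score as the video-level score) can be sketched without PyTorch. `score_frames` below is a hypothetical stand-in for the model forward pass, not part of the repository:

```python
def video_score(frames, score_frames, batch_size=128):
    """Score a video by running its frames through a per-frame scorer
    in batches and max-pooling the frame-level fake scores."""
    preds = []
    for j in range(0, len(frames), batch_size):
        # score one batch of frames and collect the per-frame predictions
        preds.extend(score_frames(frames[j:j + batch_size]))
    return max(preds)

# Hypothetical scorer for the demo: each "frame" is its own score.
frames = [0.1, 0.2, 0.95, 0.3]
print(video_score(frames, lambda batch: batch, batch_size=2))  # → 0.95
```

Max-pooling makes the video-level label sensitive to any single highly suspicious frame, which matches the per-file score that the prediction TXT format expects.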
