-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_labels.py
More file actions
34 lines (27 loc) · 1.3 KB
/
Copy pathexport_labels.py
File metadata and controls
34 lines (27 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import json
from transformers import ViTForImageClassification, DetrForObjectDetection
def _ordered_label_list(id2label):
    """Return the labels of *id2label* as a list indexed by class id.

    Keys may be ints or numeric strings. The ids must be exactly
    ``0 .. len(id2label) - 1``. Otherwise a ValueError listing the missing
    ids is raised, instead of the bare KeyError an index lookup would give.
    """
    labels = {int(k): v for k, v in id2label.items()}
    missing = sorted(set(range(len(labels))).difference(labels))
    if missing:
        raise ValueError(
            f"class ids are not contiguous from 0; missing ids: {missing[:10]}"
        )
    return [labels[i] for i in range(len(labels))]


def _write_json(path, data):
    """Write *data* to *path* as 2-space-indented JSON (UTF-8 file encoding)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def export_labels(output_dir=os.path.join("src-tauri", "resources")):
    """Export the ViT and DETR class-label tables as JSON files in *output_dir*.

    Writes two files:

    * ``vit_labels.json``: a JSON list whose index is the ViT class id.
    * ``detr_labels.json``: a JSON object mapping class id (as a string) to
      label name. A mapping is used because DETR ids are not assumed to be
      contiguous.

    Only each model's configuration is loaded, through the model class's
    ``config_class``. Model weights are never downloaded.

    Args:
        output_dir: Directory to write into; created if missing. The default
            is resolved relative to the current working directory.

    Raises:
        ValueError: If the ViT class ids are not contiguous from 0.
        OSError: If the directory or the output files cannot be written.
            Config loading failures propagate from ``transformers``.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Loading ViT config...")
    # config_class is the config type paired with the model class. Loading it
    # fetches only the configuration, not the checkpoint.
    vit_config = ViTForImageClassification.config_class.from_pretrained(
        "google/vit-base-patch16-224"
    )
    # The list position doubles as the class id, so ids must be 0..n-1.
    _write_json(
        os.path.join(output_dir, "vit_labels.json"),
        _ordered_label_list(vit_config.id2label),
    )
    print("Saved vit_labels.json")

    print("Loading DETR config...")
    detr_config = DetrForObjectDetection.config_class.from_pretrained(
        "facebook/detr-resnet-50"
    )
    # DETR's class ids may skip values, so keep an id -> label mapping rather
    # than a list. JSON object keys must be strings. Sort numerically so the
    # file contents are stable across runs.
    detr_items = sorted((int(k), v) for k, v in detr_config.id2label.items())
    _write_json(
        os.path.join(output_dir, "detr_labels.json"),
        {str(k): v for k, v in detr_items},
    )
    print("Saved detr_labels.json")
if __name__ == "__main__":
    try:
        export_labels()
    except Exception as e:
        # Report on stderr and exit with a non-zero status so that build
        # steps invoking this script notice the failure. Previously the error
        # was printed and the process still exited 0.
        raise SystemExit(f"Error: {e}") from e