-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
57 lines (48 loc) · 1.07 KB
/
config.py
File metadata and controls
57 lines (48 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
# Hyperparameter sections. Each one is a flat dict of settings; together
# they are merged into the single CONFIG dict defined later in this module.

# Network architecture selection.
MODEL = {
    "net": "resnet56",
    "net_type": "stable",
}

# Dataset selection and data-loading parallelism.
DATA_SET = {
    "dataset": "cifar10",  # options: cifar10 | cifar100 | imagenet
    "num_workers": 2,  # presumably the data-loader worker count — confirm at the call site
}

# Training-loop settings.
TRAIN = {
    "loss": "crossentropy",
    "batch_size": 128,
    "epoch_num": 200,
    "lr": 1e-1,
    "weight_decay": 1e-4,
    # NOTE(review): looks like (train fraction, validation fraction) — confirm
    "train_val_split": (0.9, 0.1),
}

# Toggles for quantizing activations and weights (both off by default here).
QUANTIZATION = {
    "quantize_activations": False,
    "quantize_weights": False,
}

# Stability-related settings.
STABILITY = {
    "stable": True,
    "stability_coeff": 0.1,
}

# Distillation settings: bit widths for the high- and low-precision sides
# and the weight applied to the activation loss term.
DISTILLATION = {
    "highp_bits": 32,
    "lowp_bits": 8,
    "activation_loss_coeff": 1e-5,
}
# Weights & Biases settings.
# Add your WandB account (entity) and project name here to upload data to WandB.
# NOTE(review): None presumably means "don't upload" — confirm in the training script.
WANDB = {
    'wandb_entity': None,
    'wandb_project_name': None,
}
# Flatten every settings section into one dict. As with successive
# dict.update() calls in this order, a key that appears in a later section
# overrides the same key from an earlier one.
CONFIG = {
    **MODEL,
    **DATA_SET,
    **TRAIN,
    **QUANTIZATION,
    **STABILITY,
    **DISTILLATION,
    **WANDB,
}
# Runs at import time: when CUDA is available, print the current device's
# index and name.
if torch.cuda.is_available():
    print(
        f"Running on device {torch.cuda.current_device()}\n"
        f"name {torch.cuda.get_device_name(device=torch.cuda.current_device())}"
    )