-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTorchTrack.py
More file actions
100 lines (87 loc) · 3.28 KB
/
TorchTrack.py
File metadata and controls
100 lines (87 loc) · 3.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
import os
import traceback
class TorchTrack:
    """
    Minimal experiment tracker that appends experiment records to a JSON file.

    Records are stored as elements of a top-level JSON list in ``data.json``
    (a relative path, so it resolves against the current working directory).
    Errors during file access are reported on stdout/stderr rather than
    raised, so tracking problems do not interrupt the calling code.
    """

    def __init__(self, experiment_name: str = 'default') -> None:
        """
        Initialize TorchTrack with an experiment name.

        Args:
            experiment_name (str): Name of the current experiment
        """
        self.data_file = 'data.json'
        self.experiment_name = experiment_name
        self.epoch_data = self._new_epoch_data()
        # Ensure data.json exists and is initialized
        self._initialize_data_file()

    @staticmethod
    def _new_epoch_data() -> dict:
        """Return a fresh, empty container for per-epoch loss/accuracy values."""
        return {
            'loss_per_epoch': [],
            'accuracy_per_epoch': [],
        }

    def _initialize_data_file(self) -> None:
        """
        Create the data file holding an empty JSON list if it does not exist.

        An existing file is left untouched. Any error is printed together
        with its traceback instead of being raised.
        """
        try:
            # Create an empty JSON structure if file doesn't exist
            if not os.path.exists(self.data_file):
                with open(self.data_file, 'w') as f:
                    json.dump([], f)
        except Exception as e:
            print(f"Error initializing data file: {e}")
            traceback.print_exc()

    def clean_previous_data(self) -> None:
        """
        Clean previous experiment data from data.json.

        Overwrites the data file with an empty JSON list and resets the
        in-memory per-epoch data. Any error is printed together with its
        traceback instead of being raised.
        """
        try:
            with open(self.data_file, 'w') as f:
                json.dump([], f)
            print("Previous experiment data cleared successfully.")
            # Reset epoch data
            self.epoch_data = self._new_epoch_data()
        except Exception as e:
            print(f"Error cleaning previous data: {e}")
            traceback.print_exc()

    def log_epoch(self, loss, accuracy) -> None:
        """
        Log individual epoch data.

        The values are only accumulated in memory; they are written to the
        data file by the next call to ``log``.

        Args:
            loss (float): Loss value for the current epoch
            accuracy (float): Accuracy value for the current epoch
        """
        self.epoch_data['loss_per_epoch'].append(loss)
        self.epoch_data['accuracy_per_epoch'].append(accuracy)

    def log(self, hyperparameters, metrics, model_type, model_data) -> None:
        """
        Log experiment data to data.json.

        Appends one record containing the experiment name, the given
        arguments and the epoch data collected so far via ``log_epoch``.
        If the data file does not contain valid JSON, its content is
        discarded and the file is rewritten with only the new record.
        Any other error is printed together with its traceback instead of
        being raised, and the data file is left as it was.

        Args:
            hyperparameters (dict): Hyperparameters used in the experiment
            metrics (dict): Performance metrics of the experiment
            model_type: Value stored under the 'model_type' key
            model_data: Value stored under the 'model_data' key

        All arguments must be JSON-serializable.
        """
        try:
            # Read existing data
            try:
                with open(self.data_file, 'r') as f:
                    data = json.load(f)
            except json.JSONDecodeError:
                # Recover in place: start from an empty list and let the
                # write below replace the corrupted file. (Retrying via
                # _initialize_data_file cannot work because that helper
                # never touches a file that already exists.)
                print("JSON Decode Error. Reinitializing data file.")
                data = []

            # Prepare experiment entry
            experiment_entry = {
                'experiment_name': self.experiment_name,
                'hyperparameters': hyperparameters,
                'metrics': metrics,
                'model_type': model_type,
                'model_data': model_data,
                'epoch_data': self.epoch_data,  # Include epoch-level data
            }

            # Add new experiment data
            data.append(experiment_entry)

            # Serialize BEFORE opening the file for writing: opening with
            # 'w' truncates it, so a non-serializable value would otherwise
            # destroy every previously logged experiment.
            serialized = json.dumps(data, indent=4)

            # Write updated data back to file
            with open(self.data_file, 'w') as f:
                f.write(serialized)
        except Exception as e:
            print(f"Error logging data: {e}")
            traceback.print_exc()