-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
94 lines (79 loc) · 3.61 KB
/
main.py
File metadata and controls
94 lines (79 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torch.utils.data import DataLoader
from algoritm import MyAlgo, load_model_from_file
from model import PointNetClassHead
from dataset import collate_fn
import torch.optim as optim
import torch.nn.functional as F
import logging
import os
# Logging configuration: INFO level, timestamped messages for training progress.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def train(model, train_loader, optimizer, epoch, num_epochs=20):
    """Run one training epoch and log the running loss.

    Args:
        model: network whose forward pass returns a tuple; only the first
            element (class logits) is used for the loss.
        train_loader: iterable of (points, labels) batches. Points arrive as
            (batch, num_points, channels) and are permuted to
            (batch, channels, num_points) before the forward pass.
        optimizer: optimizer over model.parameters().
        epoch: zero-based index of the current epoch (logging only).
        num_epochs: total number of epochs, shown in the progress log.
            Previously the message hard-coded "/100", disagreeing with the
            20 epochs actually run by main(); now it defaults to 20 and can
            be passed explicitly.
    """
    model.train()
    total_loss = 0.0
    for i, (points, labels) in enumerate(train_loader):
        # PointNet-style models expect channels-first point clouds.
        points = points.permute(0, 2, 1)
        optimizer.zero_grad()
        outputs, _, _ = model(points)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % 10 == 0:  # log every 10 batches
            logging.info(
                f'Epoch [{epoch + 1}/{num_epochs}], '
                f'Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}'
            )
    avg_loss = total_loss / len(train_loader)
    logging.info(f'Epoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}')
def test(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for points, labels in test_loader:
points = points.permute(0, 2, 1)
outputs, _, _ = model(points)
preds = outputs.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
accuracy = correct / total
logging.info(f'Accuracy: {accuracy * 100:.2f}%')
return accuracy
def main():
    """Train a PointNet classifier on ModelNet10, keeping the best checkpoint.

    Resumes from best_model.pth if present, caches pre-processed datasets
    under processed/, and saves the model whenever test accuracy improves.
    """
    num_epochs = 20
    num_points = 512  # number of points sampled per cloud
    # MyAlgo handles dataset loading/preprocessing; the model it receives is
    # separate from the one trained below.
    algo = MyAlgo(PointNetClassHead(k=10), 'ModelNet10', num_points=num_points)

    model = PointNetClassHead(k=10)
    # Resume from the best checkpoint if one exists.
    # NOTE(review): consider map_location/weights_only for torch.load — verify
    # against the torch version in use.
    if os.path.exists('best_model.pth'):
        logging.info('Loading pre-trained model...')
        model.load_state_dict(torch.load('best_model.pth'))
    else:
        logging.info('No pre-trained model found. Training from scratch...')

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Reuse cached pre-processed datasets when available; otherwise build
    # them once and cache under processed/.
    if os.path.exists('processed/train_dataset.pt') and os.path.exists('processed/test_dataset.pt'):
        logging.info('Loading pre-processed datasets...')
        train_dataset = torch.load('processed/train_dataset.pt')
        test_dataset = torch.load('processed/test_dataset.pt')
    else:
        logging.info('No pre-processed data found. Processing data...')
        train_dataset = algo.load_data(split='train')
        test_dataset = algo.load_data(split='test')
        os.makedirs('processed', exist_ok=True)
        torch.save(train_dataset, 'processed/train_dataset.pt')
        torch.save(test_dataset, 'processed/test_dataset.pt')

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    best_accuracy = 0.0
    for epoch in range(num_epochs):
        train(model, train_loader, optimizer, epoch)
        accuracy = test(model, test_loader)
        # Checkpoint only on improvement.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            logging.info(f'New best model saved with accuracy: {accuracy * 100:.2f}%')

if __name__ == '__main__':
    main()