-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathengine.py
More file actions
49 lines (32 loc) · 1.16 KB
/
engine.py
File metadata and controls
49 lines (32 loc) · 1.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Training and evaluation loops for a PyTorch binary classifier.
import torch
import torch.nn as nn
def train(data_loader, model, optimizer, device):
    """Run one training pass over every batch in ``data_loader``.

    For each batch, the gradients are cleared and the reviews are fed
    through ``model``. The binary cross-entropy-with-logits loss against the
    targets is then backpropagated, and ``optimizer`` takes a step.

    Args:
        data_loader: Iterable of batches. Each batch is indexable by
            ``'review'`` (cast to ``torch.long``; presumably token ids, so
            confirm against the dataset) and ``'target'`` (cast to
            ``torch.float``).
        model: Module called on the long ``reviews`` tensor. Its output must
            have the same shape as ``target.view(-1, 1)``, i.e.
            ``(batch, 1)``, or ``BCEWithLogitsLoss`` raises.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batch tensors are moved to. Presumably this must
            match the device ``model`` lives on; the caller arranges that.

    Returns:
        None. ``model`` (put in training mode) and ``optimizer`` are updated
        in place.
    """
    model.train()
    # The criterion is stateless, so build it once rather than once per batch.
    loss_fn = nn.BCEWithLogitsLoss()
    for data in data_loader:
        reviews = data['review'].to(device, dtype=torch.long)
        target = data['target'].to(device, dtype=torch.float)
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        predictions = model(reviews)
        # Reshape targets to a (batch, 1) column to match the model output.
        loss = loss_fn(predictions, target.view(-1, 1))
        loss.backward()
        optimizer.step()
def evaluate(data_loader, model, device):
    """Run ``model`` over every batch in ``data_loader`` and collect results.

    Gradient tracking is disabled, and ``model`` is switched to eval mode
    and left there. ``train()`` switches it back with ``model.train()``.

    Args:
        data_loader: Iterable of batches. Each batch is indexable by
            ``'review'`` (cast to ``torch.long``) and ``'target'``.
        model: Module called on the long ``reviews`` tensor.
        device: Device the reviews are moved to. Presumably this must match
            the device ``model`` lives on.

    Returns:
        ``(final_predictions, final_target)``: two Python lists with one
        entry per example.

        - Predictions are the model's raw outputs, converted with
          ``tolist()``. For a ``(batch, 1)`` output, each entry is a
          one-element list. If the model emits logits (as the
          ``BCEWithLogitsLoss`` in ``train()`` suggests), apply a sigmoid
          before thresholding.
        - Targets are taken from the batch as loaded, so they keep the
          dataset's dtype.
    """
    final_predictions = []
    final_target = []
    model.eval()
    # Evaluation never calls backward(), so skip building the autograd graph.
    with torch.no_grad():
        for data in data_loader:
            reviews = data['review'].to(device, dtype=torch.long)
            predictions = model(reviews)
            # Move outputs to the CPU and flatten into plain Python lists.
            final_predictions.extend(predictions.cpu().tolist())
            final_target.extend(data['target'].cpu().tolist())
    return final_predictions, final_target