-
Notifications
You must be signed in to change notification settings - Fork 0
Sourcery refactored master branch #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -204,14 +204,14 @@ def __len__(self): | |
|
|
||
| def train_model(model, criterion,optimizer,scheduler, num_epoch = 25): | ||
| since = time.time() | ||
|
|
||
| best_model_wts = copy.deepcopy(model.state_dict()) | ||
| best_acc = 0.0 | ||
|
|
||
| for epoch in range(num_epoch): | ||
| print('Epoch {}/{}'.format(epoch, num_epoch - 1)) | ||
| print(f'Epoch {epoch}/{num_epoch - 1}') | ||
| print('*'*10) | ||
|
|
||
|
Comment on lines
-207
to
+214
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| #each stage have training and validation phase | ||
| for phase in ['train', 'val']: | ||
| if phase == 'train': | ||
|
|
@@ -223,7 +223,7 @@ def train_model(model, criterion,optimizer,scheduler, num_epoch = 25): | |
| #print("debugging: eval phase") | ||
| running_loss = 0.0 | ||
| running_corrects = 0 | ||
|
|
||
| #Iterate | ||
| for inputs, labels in dataloaders[phase]: | ||
| inputs = inputs.cuda() | ||
|
|
@@ -232,15 +232,15 @@ def train_model(model, criterion,optimizer,scheduler, num_epoch = 25): | |
| #zero the parameter gradient | ||
| #print("debugging: zero grad") | ||
| optimizer.zero_grad() | ||
|
|
||
| #forward | ||
| with torch.set_grad_enabled(phase == 'train'): | ||
| #print("debugging: forward phase") | ||
| inputs = inputs.cuda() | ||
| outputs = model(inputs) | ||
| _, preds = torch.max(outputs,1) | ||
| loss = criterion(outputs,labels) | ||
|
|
||
| #back | ||
| if phase == 'train': | ||
| #print("debugging: backward phase") | ||
|
|
@@ -249,13 +249,13 @@ def train_model(model, criterion,optimizer,scheduler, num_epoch = 25): | |
| #stat | ||
| running_loss += loss.item() *inputs.size(0) | ||
| running_corrects += torch.sum(preds == labels.data) | ||
|
|
||
| epoch_loss = running_loss / datasetSize[phase] | ||
| epoch_acc = running_corrects.double() / datasetSize[phase] | ||
| #loss.append(epoch_loss) | ||
| #acc.append(epoch_acc) | ||
| print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) | ||
|
|
||
| #deep copy the model | ||
| if phase == 'val' and epoch_acc > best_acc: | ||
| best_acc = epoch_acc | ||
|
|
@@ -326,7 +326,7 @@ def visualize_model(model, num_images=4): | |
| fig = plt.figure() | ||
|
|
||
| with torch.no_grad(): | ||
| for i, (inputs, labels) in enumerate(dataloaders['val']): | ||
| for inputs, labels in dataloaders['val']: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| inputs = inputs.to(device) | ||
| labels = labels.to(device) | ||
|
|
||
|
|
@@ -337,7 +337,7 @@ def visualize_model(model, num_images=4): | |
| images_so_far += 1 | ||
| ax = plt.subplot(num_images//2, 2, images_so_far) | ||
| ax.axis('off') | ||
| ax.set_title('predicted: {}'.format(labels_ls[preds[j]])) | ||
| ax.set_title(f'predicted: {labels_ls[preds[j]]}') | ||
| imshow(inputs.cpu().data[j]) | ||
| pred_ls.append(int(preds[j].cpu().numpy())) | ||
|
|
||
|
|
@@ -350,7 +350,7 @@ def visualize_model(model, num_images=4): | |
| # In[30]: | ||
|
|
||
|
|
||
| for i, (inputs, name) in enumerate(dataloaders['val']): | ||
| for inputs, name in dataloaders['val']: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
| count += 1 | ||
| print(inputs.shape) | ||
| print(name) | ||
|
|
@@ -365,14 +365,13 @@ def visualize_model(model, num_images=4): | |
| count = 0 | ||
| model_ft.eval() | ||
| test_pred = torch.LongTensor() | ||
| for i, data in enumerate(testloader): | ||
|
|
||
| for data in testloader: | ||
| data = Variable(data[0], volatile=True) | ||
| if torch.cuda.is_available(): | ||
| data = data.cuda() | ||
|
|
||
| output = model_ft(data) | ||
|
|
||
| pred = output.cpu().data.max(1, keepdim=True)[1] | ||
| test_pred = torch.cat((test_pred, pred), dim=0) | ||
| ''' | ||
|
|
@@ -411,7 +410,7 @@ def visualize_model(model, num_images=4): | |
| Fucking Hell. What the fuck is this. | ||
|
|
||
| ''' | ||
| for i, (inputs, name) in enumerate(testloader): | ||
| for inputs, name in testloader: | ||
| count += 1 | ||
| print(inputs.shape) | ||
| print(name) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -83,14 +83,13 @@ def imshow(inp, title = None): | |
|
|
||
| def train_model(model, criterion, optimizer, scheduler, num_epoch=25): | ||
| since = time.time() | ||
|
|
||
| best_model_wts = copy.deepcopy(model.state_dict()) | ||
| best_acc = 0.0 | ||
|
|
||
| for epoch in range(num_epoch): | ||
| print('Epoch{}/{}'.format(epoch,num_epoch-1)) | ||
| print(f'Epoch{epoch}/{num_epoch - 1}') | ||
| print('-'*10) | ||
|
|
||
|
Comment on lines
-86
to
-93
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| #Each epoch in range(num_epochs): | ||
| for phase in ['train','val']: | ||
| if phase == 'train': | ||
|
|
@@ -121,7 +120,6 @@ def train_model(model, criterion, optimizer, scheduler, num_epoch=25): | |
| if phase == 'train': | ||
| loss.backward() | ||
| optimizer.step() | ||
|
|
||
| #stat | ||
| running_loss += loss.item() * inputs.size(0) | ||
| running_corrects += torch.sum(preds == labels.data) | ||
|
|
@@ -159,7 +157,7 @@ def visualize_model(model, num_images=6): | |
| fig = plt.figure() | ||
|
|
||
| with torch.no_grad(): | ||
| for i, (inputs, labels) in enumerate(dataloaders['val']): | ||
| for inputs, labels in dataloaders['val']: | ||
|
Comment on lines
-162
to
+160
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| inputs = inputs.to(device) | ||
| labels = labels.to(device) | ||
|
|
||
|
|
@@ -170,7 +168,7 @@ def visualize_model(model, num_images=6): | |
| images_so_far += 1 | ||
| ax = plt.subplot(num_images//2, 2, images_so_far) | ||
| ax.axis('off') | ||
| ax.set_title('predicted: {}'.format(class_names[preds[j]])) | ||
| ax.set_title(f'predicted: {class_names[preds[j]]}') | ||
| imshow(inputs.cpu().data[j]) | ||
|
|
||
| if images_so_far == num_images: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,21 +103,21 @@ def forward(self, x): | |
| #train the network | ||
|
|
||
| for epoch in range(2): | ||
|
|
||
| running_loss = 0 | ||
| for i, data in enumerate(trainloader, 0): | ||
| #get the input | ||
| inputs, labels = data | ||
|
|
||
| #zero the parameter gradients | ||
| optimizer.zero_grad() | ||
|
|
||
| #forward + backward + optimize | ||
| output = net(inputs) | ||
| loss = criterion(output, labels) | ||
| loss.backward() | ||
| optimizer.step() | ||
|
|
||
|
Comment on lines
-106
to
+120
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Found the following improvement in Lines |
||
| #print stat | ||
| running_loss += loss.item() | ||
| if i % 2000 == 1999:#every 2000 mini-batches | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines
24-123refactored with the following changes:use-fstring-for-formatting)