-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·111 lines (84 loc) · 2.59 KB
/
main.py
File metadata and controls
executable file
·111 lines (84 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import io
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy
from PIL import Image
import wikipedia
import warnings
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Suppress all warnings process-wide.
warnings.filterwarnings('ignore')
# Load an image file and turn it into a normalized, batched tensor on `device`.
def load_image(path):
    """Open the image at *path* and preprocess it for the classifiers.

    The image is converted to 3-channel RGB, resized to 244x244, turned into
    a tensor, normalized with the per-channel mean/std below, given a leading
    batch dimension and moved to the module-level ``device``.

    Returns:
        A float tensor of shape (1, 3, 244, 244).
    """
    transform = transforms.Compose([
        # NOTE(review): 244 differs from the common 224 ImageNet input size;
        # kept as-is because the fine-tuned ResNet may have been trained at
        # this size -- confirm against the training pipeline.
        transforms.Resize(size=(244, 244)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # convert('RGB') guarantees exactly three channels, which the 3-value
    # Normalize above requires; grayscale, palette, RGBA or CMYK images would
    # otherwise fail inside Normalize. The with-block closes the file handle.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    image = transform(rgb).unsqueeze(0)
    return image.to(device)
# Build the bare VGG19 architecture (pretrained=False) and restore its
# weights from the local checkpoint, mapped onto the selected device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model = models.vgg19(pretrained=False)
model.load_state_dict(
    torch.load('Models/vgg.pth', map_location=device)
)
model = model.to(device)
# Switch to evaluation mode for inference.
model.eval()
# Fine-tuned ResNet152 breed classifier: swap the final fully connected layer
# for a 133-output layer, then load the fine-tuned weights.
resnet = models.resnet152(pretrained=False)
ftrs = resnet.fc.in_features  # input width of the original fully connected layer
# 133 outputs -- presumably one per line of classes/breed.txt; confirm.
resnet.fc = nn.Linear(ftrs, 133)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
resnet.load_state_dict(torch.load('Models/model.pt', map_location=device))
resnet.to(device)
# Evaluation mode is required for inference. In training mode, BatchNorm
# normalizes with the current batch's statistics and keeps updating its
# running statistics on every prediction.
resnet.eval()
# returns the top VGG19 class index, used to decide whether a dog is present
def vgg(path):
    '''Return the index of VGG19's highest-scoring class for an image tensor.

    Despite its name, ``path`` is the preprocessed image tensor produced by
    load_image(), not a file path.

    vgg19 is trained on ImageNet (1000 classes). breed() treats class indices
    151-280 as dogs (including wild ones); an earlier note here said 151-277.
    NOTE(review): confirm which upper bound is intended.
    '''
    with torch.no_grad():  # inference only: don't record an autograd graph
        output = model(path)
    return output.argmax(dim=1).item()
# returns the predicted breed index from the fine-tuned ResNet
def res(path):
    """Return the index of the fine-tuned ResNet's highest-scoring breed.

    Despite its name, *path* is the preprocessed image tensor produced by
    load_image(); breed() looks the returned index up with breed_name().
    """
    with torch.no_grad():  # inference only: don't record an autograd graph
        output = resnet(path)
    return output.argmax(dim=1).item()
# reads the VGG class name, used when the image is not a dog
def class_name_vgg(idx):
    """Return line *idx* (zero-based) of classes/vgg.txt.

    breed() passes the class index returned by vgg(), so each line is
    expected to hold the name of the corresponding VGG class.
    Raises IndexError if *idx* is out of range for the file.
    """
    # The with-block closes the handle as soon as the file has been read.
    with open('classes/vgg.txt', 'r', encoding='utf-8') as file:
        lines = file.read().split('\n')
    return lines[idx]
# returns the breed name from the breed list file
def breed_name(idx):
    """Return line *idx* (zero-based) of classes/breed.txt.

    breed() passes the class index returned by res(), so each line is
    expected to hold the name of the corresponding breed.
    Raises IndexError if *idx* is out of range for the file.
    """
    # The with-block closes the handle as soon as the file has been read.
    with open('classes/breed.txt', 'r', encoding='utf-8') as file:
        lines = file.read().split('\n')
    return lines[idx]
# Inclusive range of VGG19 class indices that breed() treats as a dog
# (including wild ones).
DOG_CLASS_MIN = 151
DOG_CLASS_MAX = 280


# pass the image to the trained models and predict the breed.
def breed(path):
    """Classify the image at *path* as a dog breed or as some other object.

    VGG19 first decides whether the image shows a dog. If its top class index
    lies in [DOG_CLASS_MIN, DOG_CLASS_MAX], the fine-tuned ResNet predicts the
    breed.

    Returns:
        (name, flag): ``(breed name, 0)`` when a dog was detected, otherwise
        ``(VGG class name, 1)`` describing what VGG19 saw instead.
    """
    in_img = load_image(path)
    vgg_class = vgg(in_img)
    if DOG_CLASS_MIN <= vgg_class <= DOG_CLASS_MAX:
        return breed_name(res(in_img)), 0
    # Not a dog: report VGG19's own label so the caller can show what was found.
    return class_name_vgg(vgg_class), 1
# returns information from wikipedia
def wiki(info):
    """Return the Wikipedia summary for "<info> dog", or 'NO DATA FOUND'.

    This is a deliberate best-effort lookup: any exception raised by the
    lookup yields the placeholder string. It catches Exception rather than
    everything, so KeyboardInterrupt and SystemExit still propagate.
    """
    query = info + ' dog'
    try:
        return wikipedia.summary(query)
    except Exception:  # e.g. missing page, disambiguation, network failure
        return 'NO DATA FOUND'