# create_dset.py
# Builds pickled label lookup tables from train_labels.csv and a pickled
# video dataset from the videos listed in those tables.
import csv
import pickle
import random
import imageio
import numpy as np
from skimage.transform import resize
# Class names. A class id used elsewhere in this file is an index into this
# list (presumably the same order as the label columns of train_labels.csv --
# TODO confirm against the CSV header).
classes = ['bird', 'blank', 'cattle', 'chimpanzee', 'elephant',
           'forest buffalo', 'gorilla', 'hippopotamus', 'human', 'hyena',
           'large ungulate', 'leopard', 'lion', 'other (non-primate)',
           'other (primate)', 'pangolin', 'porcupine', 'reptile', 'rodent',
           'small antelope', 'small cat', 'wild dog', 'duiker', 'hog']

# Per-class totals, index-aligned with `classes` (presumably the number of
# training videos per class -- TODO confirm).
counts = [2386, 122270, 372, 5045, 1085,
          9, 174, 175, 20005, 10,
          224, 209, 2, 1876,
          20349, 63, 569, 7, 2899,
          273, 79, 21, 21471, 4557]

# Seed applied before random undersampling so the drawn subset is repeatable.
seed = 6554

# Class ids skipped entirely by make_dataset: 1 ('blank') and 8 ('human').
ignore_list = [1, 8]

#### for positive classes
# Copy the list: the original bound `undersample_number` to the very same
# list object as `counts`, so the overrides below also rewrote `counts`.
undersample_number = list(counts)
undersample_number[-2] = 6000  # cap for 'duiker' (index 22)
undersample_number[1] = 0      # 'blank'
undersample_number[8] = 0      ## assuming we use a nice pretrained model for humans :)
def get_train_labels(csv_path='train_labels.csv',
                     gt_dict_path='data/gt_dict.pkl',
                     list_classes_path='data/list_classes_gt.pkl',
                     num_classes=None):
    """Parse the one-hot training-label CSV and pickle two lookup tables.

    The first CSV row is treated as a header and skipped. In every other
    row, column 0 is the file name and the remaining columns are parsed as
    floats; the class id is the position of the first column equal to 1.

    Two objects are pickled (highest protocol):
      * ``gt_dict_path``      -- dict mapping file name -> class id
      * ``list_classes_path`` -- dict mapping class id -> list of file names,
        with an (initially empty) entry for every id in
        ``range(num_classes)``

    Args:
        csv_path: label CSV to read.
        gt_dict_path: output path of the file-name -> class-id pickle.
        list_classes_path: output path of the class-id -> file-names pickle.
        num_classes: number of class ids to pre-create; defaults to
            ``len(classes)`` (the module-level class list).

    Raises:
        ValueError: if a label column is not a number or a row has no
            column equal to 1.
        KeyError: if a row's class id is >= ``num_classes``.
    """
    import sys  # local import: only needed to choose the csv open mode

    if num_classes is None:
        num_classes = len(classes)

    gt_dict = {}  ## For easier GT extraction
    list_classes_gt = {class_id: [] for class_id in range(num_classes)}  ## For sampling

    # The csv module wants a binary file on Python 2 but a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] < 3:
        open_kwargs = {'mode': 'rb'}
    else:
        open_kwargs = {'mode': 'r', 'newline': ''}

    with open(csv_path, **open_kwargs) as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                # csv yields [] for blank lines; row[0] would raise IndexError.
                continue
            filename = row[0]
            labels = [float(cell) for cell in row[1:]]
            if labels.count(1) > 1:
                # More than one positive label on this row; the first one wins.
                print('gg')
            classid = labels.index(1)
            gt_dict[filename] = classid
            list_classes_gt[classid].append(filename)

    # Protocol -1 is a binary protocol, so the files must be opened in 'wb';
    # the context managers also make sure the data is flushed and closed.
    with open(gt_dict_path, 'wb') as out:
        pickle.dump(gt_dict, out, -1)
    with open(list_classes_path, 'wb') as out:
        pickle.dump(list_classes_gt, out, -1)
def get_test_ids(csv_path='submission_format.csv'):
    """Return the first-column entries of the submission-format CSV.

    The first row is treated as a header and skipped; blank rows are
    ignored.

    Args:
        csv_path: CSV file to read (defaults to the file name the original
            script hard-coded).

    Returns:
        List of the column-0 values, in file order.
    """
    import sys  # local import: only needed to choose the csv open mode

    # The csv module wants a binary file on Python 2 but a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] < 3:
        open_kwargs = {'mode': 'rb'}
    else:
        open_kwargs = {'mode': 'r', 'newline': ''}

    with open(csv_path, **open_kwargs) as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        next(reader, None)  # skip the header row (no-op on an empty file)
        # `if row` drops the empty lists csv yields for blank lines.
        return [row[0] for row in reader if row]
def preprocess(img, shape=None):
    """Optionally resize an image array, scale it to [0, 1], cast to float32.

    Args:
        img: numpy array holding the image / frame data.
        shape: if truthy, the array is first passed through
            ``resize(img, shape)`` (skimage.transform.resize).

    Returns:
        A float32 array. Integer-typed input is always divided by 255;
        floating-point input is divided by 255 only when its maximum
        exceeds 1 (i.e. it looks like 0-255 data).
    """
    if shape:
        img = resize(img, shape)

    # Decide on the dtype first: judging integer frames by their maximum
    # alone leaves a near-black frame (max <= 1) unscaled, so a pixel value
    # of 1 would come out as 1.0 instead of 1/255.
    is_integer = np.issubdtype(img.dtype, np.integer)
    # `img.size` guards max() against empty arrays, which would raise.
    if is_integer or (img.size and img.max() > 1.):
        img = img / 255.

    return img.astype(np.float32)
def _load_video(path, shape):
    """Decode one video file into a zero-padded array laid out as CxDxHxW.

    Frames are copied into a zero-initialised array of ``shape``; frames
    beyond ``shape[0]`` are dropped, and missing frames stay zero. A frame
    whose own shape does not fit ``shape[1:]`` still raises ValueError.
    """
    vid = np.zeros(shape)
    vidreader = imageio.get_reader(path, 'ffmpeg')
    try:
        for frame_idx, frame in enumerate(vidreader):
            if frame_idx >= shape[0]:
                # Longer clip than the buffer: stop instead of IndexError.
                break
            vid[frame_idx] = frame
    finally:
        # Release the underlying ffmpeg reader even if decoding fails.
        vidreader.close()
    return vid.transpose(3, 0, 1, 2)  #### Since we need CxDxHxW


def make_dataset(vid_dir='micro/', shape=(30,64,64,3)):
    """Load the videos of every non-ignored class and pickle them.

    Reads the class-id -> file-names table from 'data/list_classes_gt.pkl',
    skips the class ids in ``ignore_list``, randomly undersamples class 22
    to at most ``undersample_number[22]`` files (seeded with ``seed``), and
    decodes each remaining video with imageio's ffmpeg reader.

    Args:
        vid_dir: prefix concatenated directly with each file name, so it
            must end with a path separator.
        shape: per-video buffer shape before the axis reorder; ``shape[0]``
            is the number of frames kept per video.

    Returns:
        List of ``[video_array, class_id]`` pairs. The same list is also
        pickled to '../data/items.pkl'.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only load a
    # table written by this project (see get_train_labels).
    with open('data/list_classes_gt.pkl', 'rb') as table_file:
        classes_gt = pickle.load(table_file)

    items = []
    for key, values in classes_gt.items():
        print("In class -  {}".format(key))
        if key in ignore_list:
            print("Ignoring this class:  {}".format(classes[key]))
            continue
        if key == 22:
            random.seed(seed)
            # random.sample raises ValueError when asked for more items than
            # the population holds, so never request more than are available.
            sample_size = min(undersample_number[key], len(values))
            values = random.sample(values, sample_size)
        for num, value in enumerate(values):
            if num % 500 == 0:
                print("Processing {}/{} videos".format(num, len(values)))
            items.append([_load_video(vid_dir + value, shape), key])

    # NOTE(review): the input table is read from 'data/' but the result is
    # written to '../data/' -- confirm this difference is intentional.
    savename = "../data/items.pkl"
    # Protocol -1 is binary, so the file has to be opened in 'wb'.
    with open(savename, 'wb') as out:
        pickle.dump(items, out, -1)
    return items
if __name__ == '__main__':
    # Script entry point. Label extraction is left disabled here; enable the
    # next line when the pickled label tables under data/ must be (re)built
    # before building the dataset.
    # get_train_labels()
    make_dataset()