-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_network.py
More file actions
95 lines (75 loc) · 2.44 KB
/
train_network.py
File metadata and controls
95 lines (75 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import pandas as pd
import numpy as np
import glob
import os
import sys
import random
from deepcrispr import deepcrispr
'''
pretraining
'''


def read_pam_sites(path):
    """Return one string per data line of the PAM-site text file at ``path``.

    The first line of the file is skipped. For every remaining line,
    characters 0-19 and characters 21-23 are concatenated; character 20 and
    anything from index 24 onward are dropped.
    """
    # NOTE(review): presumably each line is a 20-nt guide, one separator
    # character, then a 3-nt PAM -- confirm against the PAM_Sites format.
    with open(path) as handle:
        next(handle, None)  # skip the first line; tolerate an empty file
        return [line[:20] + line[21:24] for line in handle]


def split_train_val(samples, val_size):
    """Split ``samples`` into ``(train, val)``; ``val`` is the last ``val_size`` items.

    The split uses an explicit index instead of negative slicing, because
    ``samples[:-0]`` is empty and ``samples[-0:]`` is the whole list: with
    ``val_size == 0`` negative slicing would hand every sample to validation
    and none to training.
    """
    split_at = len(samples) - val_size
    return samples[:split_at], samples[split_at:]


#data directory
data_dir = '../PAM_Sites'
#PAM strings
data = []
#load files
# glob.glob returns paths in a filesystem-dependent order; sorting makes the
# (unshuffled) train/val split below reproducible across runs and machines.
for i, genome in enumerate(sorted(glob.glob(os.path.join(data_dir, '*.txt')))):
    data.extend(read_pam_sites(genome))
    sys.stdout.write("processing file %i \r" % (i+1))
    sys.stdout.flush()
print('\nsaved %i pam sites' % len(data))

#train val split
print('shuffling data into train/val splits')
num_samples = len(data)
val_size = int(0.2 * num_samples)
# NOTE(review): the shuffle below is commented out, so despite the message
# above the validation set is simply the last 20% of the loaded sequences.
# Confirm whether shuffling was disabled on purpose.
#random.shuffle(data)
seqs, val_seqs = split_train_val(data, val_size)
# Pretrain the model on the PAM-site sequences. The savepath given here is
# the same checkpoint path the training stage loads from afterwards.
checkpoint_dir = 'savedmodels'
checkpoint_dir_missing = not os.path.exists(checkpoint_dir)
if checkpoint_dir_missing:
    os.makedirs(checkpoint_dir)

model = deepcrispr()
model.pretrain(
    seqs,
    X_val=val_seqs,
    savepath='savedmodels/deepcrispr_pretrain.ckpt'
)
'''
training
'''


def list_dataset_files(directory):
    """Return every path matching ``directory/*`` in sorted order.

    ``glob.glob`` yields paths in a filesystem-dependent order. Because the
    train/val/test membership below is decided purely by position in this
    list, the list is sorted so the same files land in the same split on
    every run and every machine.
    """
    return sorted(glob.glob(os.path.join(directory, '*')))


def load_split(paths, progress_fmt):
    """Read the CSV files in ``paths`` and gather their sequences and labels.

    Args:
        paths: CSV file paths; each file is read with ``pd.read_csv`` and
            must contain 'Seq+PAM' and 'Prediction' columns.
        progress_fmt: %-style format string taking the 1-based file counter;
            it is written to stdout (and flushed) after each file.

    Returns:
        A ``(sequences, labels)`` pair of lists, in file order then row order.
    """
    sequences = []
    labels = []
    for file_number, path in enumerate(paths, 1):
        frame = pd.read_csv(path)
        sequences.extend(frame['Seq+PAM'])
        labels.extend(frame['Prediction'])
        sys.stdout.write(progress_fmt % file_number)
        sys.stdout.flush()
    return sequences, labels


#data directory
data_dir = '../Ecoli_Training_dataset'
#load files
files = list_dataset_files(data_dir)
# The last 30 files are the test set, the 30 before them the validation set,
# and everything earlier is training data.
# NOTE(review): with 60 or fewer files the training slice is empty -- confirm
# the dataset always has more than 60 files.
#PAM strings
train_data, train_labels = load_split(files[:-60], "processing train file %i \r")
val_data, val_labels = load_split(files[-60:-30], "processing val file %i \r")
test_data, test_labels = load_split(files[-30:], "processing test file %i \r")
# Training stage: load the pretraining checkpoint, train on the labelled
# data, then load the checkpoint saved by training before scoring the test set.
pretrained_ckpt = 'savedmodels/deepcrispr_pretrain.ckpt'
trained_ckpt = 'savedmodels/deepcrispr.ckpt'

model.load(pretrained_ckpt)
model.train(
    train_data,
    train_labels,
    val_data,
    val_labels,
    savepath=trained_ckpt
)

# Score the held-out test files with the checkpoint written by model.train.
model.load(trained_ckpt)
test_metrics = model.fscore(test_data, test_labels)
fscore, precision, recall = test_metrics
print('test set fscore: %.6f' % fscore)
print('test set precision: %.6f' % precision)
print('test set recall: %.6f' % recall)