-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_model.py
More file actions
55 lines (48 loc) · 2.07 KB
/
run_model.py
File metadata and controls
55 lines (48 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
##############################################################################
# Define parameters
##############################################################################
from get_test_results import run_model_on_test
from train_model import train_model
hidden_sz = 10 # Hidden size
fc_sz = 20 # Fully connected size
embed_sz = 50 # Embedding size
dropout_rate = 0.5 # Dropout rate
learning_rate = 1e-3 # Learning rate
training_epochs = 1 # Number of epochs
bidirectional = True # Should the T-LSTM be directional?
survival = True # Should the model use survival output form and survival function?
##############################################################################
# Path
##############################################################################
# Path were to put the results. Should contain a directory data with the data inside
path_results = "xx"
##############################################################################
# Load util functions
##############################################################################
##############################################################################
# Training
##############################################################################
train_model(path_results,
model="TLSTM",
hidden_sz=hidden_sz,
fc_sz=fc_sz,
dropout_rate=dropout_rate,
discount="log",
learning_rate=learning_rate,
training_epochs=training_epochs,
embed_sz=embed_sz,
survival=True,
bidirectional=True)
##############################################################################
# Test
##############################################################################
run_model_on_test(path_results,
model="TLSTM",
hidden_sz=hidden_sz,
fc_sz=fc_sz,
dropout_rate=dropout_rate,
discount="log",
learning_rate=learning_rate,
embed_sz=embed_sz,
l=True,
bidirectional=True)