-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparams.py
More file actions
500 lines (339 loc) · 15.8 KB
/
params.py
File metadata and controls
500 lines (339 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
import sys,os,time
from math import sqrt, floor, ceil
import cPickle as pickle
import numpy as np
import pandas as pd
import argparse
opt = None
################ useful stuff ###############
def combine_dicts(*dicts):
    """Merge any number of dicts into a new dict, skipping None entries.

    Later dicts win on key collisions -- same semantics as the original
    tuple-concatenation implementation, where later (key, value) pairs
    overwrite earlier ones inside dict().
    """
    combined = {}
    for d in dicts:
        if d is not None:
            combined.update(d)
    return combined
def getOrElse(map, key, default=None):
    """Return map[key], or `default` when the key is absent or its
    value is None.  (Parameter name `map` kept for interface
    compatibility, though it shadows the builtin.)
    """
    if key in map:
        value = map[key]
        if value is not None:
            return value
    return default
def init_params(argv):
    """Parse the command-line arguments in `argv`, store the resulting
    namespace in the module-global `opt`, and return it.
    """
    global opt
    # parse command line
    parser = argparse.ArgumentParser()
    # --- run identification and input data files ---
    parser.add_argument('-ID', type=str, nargs='+', default=["test"], help="id strings for this run")
    parser.add_argument('-words', type=str, default="words_2018-03-19.pkl")
    parser.add_argument('-humans', type=str, default="humans_7-14-2018.pkl")
    parser.add_argument('-hid', type=str, nargs='+', default=[None])
    parser.add_argument('-start_net', dest='start_net', type=str, nargs='+', default=None)
    parser.add_argument('-params', dest='params', type=str, default="treatment.params.pkl")
    parser.add_argument('-pi', type=int, nargs='+')
    # --- treatment / lesion run selection ---
    parser.add_argument('-treatment_run', action='store_true', default=False, help="Is this a treatment run?")
    parser.add_argument('-treatments', dest='treatments', type=str, default=None)
    parser.add_argument('-ti', type=int, nargs='+')
    parser.add_argument('-zip', action='store_true', default=False, help="zip pi and hid lists (don't run all combinations)")
    parser.add_argument('-lesion', action='store_true', default=False, help="Is this a lesion run?")
    parser.add_argument('-treatment_language', type=str, default=None)
    parser.add_argument('-flip_language', action='store_true', default=False, help="Flip the treatment language")
    parser.add_argument('-run', type=int, default=None)
    parser.add_argument('-nruns', type=int, default=1) ### obsolete -Uli 7/17/18
    parser.add_argument('-naming_only', action='store_true', default=False, help="train naming assoc only")
    # --- model hyperparameters (frame values in P override these) ---
    parser.add_argument('-norm', type=str, default='classic')
    parser.add_argument('-hood_std_devs', type=float, default=2.0, help="neighborhood size in stdevs of gaussian sigma function")
    parser.add_argument('-nbatch', type=int, default=8)
    parser.add_argument('-maxR', type=float, default=10)
    parser.add_argument('-minR', type=float, default=3)
    parser.add_argument('-act_fun', type=str, default='gauss')
    parser.add_argument('-train_across_method', type=str, default='p')
    # --- lesion strengths and random seeds (-1 seeds appear to mean "unset") ---
    parser.add_argument('-sem_lesion_strength', type=float, default = 0.0)
    parser.add_argument('-eng_lesion_strength', type=float, default = 0.0)
    parser.add_argument('-spa_lesion_strength', type=float, default = 0.0)
    parser.add_argument('-pre_seed', type=int, default = -1)
    parser.add_argument('-sem_lesion_seed', type=int, default = -1)
    parser.add_argument('-eng_lesion_seed', type=int, default = -1)
    parser.add_argument('-spa_lesion_seed', type=int, default = -1)
    parser.add_argument('-post_seed', type=int, default = -1)
    parser.add_argument('-noleak', action='store_true', default=False, help="Don't leak, duh")
    # NOTE(review): store_true with default=True means -force_mono is always
    # True and cannot be disabled from the command line -- confirm intent.
    parser.add_argument('-force_mono', action='store_true', default=True, help="force alpha and sigma do decrease after age 10")
    parser.add_argument('-mask_assoc', action='store_true', default=False, help="Don't apply mask to assoc connections when lesioning")
    parser.add_argument('-treatment_seed', type=int, default = -1)
    # --- testing / saving / runtime behavior ---
    parser.add_argument('-test_freq', type=float, default=1.0, help="test interval")
    parser.add_argument('-no_trans', action='store_true', default=False, help="Don't test translation")
    parser.add_argument('-initial_test', action='store_true', default=False, help="Test before we do anything")
    parser.add_argument('-save_pre', action='store_true', default=False, help="save final net")
    parser.add_argument('-save_post', action='store_true', default=False, help="save final net")
    parser.add_argument('-save_pre_treat', action='store_true', default=False, help="save final net")
    parser.add_argument('-save_post_treat', action='store_true', default=False, help="save final net")
    parser.add_argument('-save_prefix', type=str, default=None)
    parser.add_argument('-save_outputs', action='store_true', default=False, help="save raw test outputs")
    parser.add_argument('-nogpu', action='store_true', default=False, help="Don't run model on GPU")
    parser.add_argument('-nice', action='store_true', default=False, help="Don't hog all GPU resources")
    opt = parser.parse_args(argv)
    return opt
################## PARAMETERS #########################
class P(object):
    """Model run parameters, used as a class-level namespace (never
    instantiated).

    load() reads a pickled pandas frame with one parameter set per row;
    set_pi() selects row `pi` and caches its values on the class,
    falling back to the parsed command-line options in the module-global
    `opt` where the frame does not provide a value.
    """
    file_name = None        # path the parameter frame was loaded from
    frame = None            # pandas frame: one row per parameter set
    pi = None               # index of the currently selected row
    params = None           # the selected row (pandas Series)
    hid = None
    start_net = None
    pre_seed = None
    sem_lesion_seed = None
    eng_lesion_seed = None
    spa_lesion_seed = None
    post_seed = None
    map_size = None
    exposure = None
    min_exp = None
    act_fun = None
    nbatch = None
    alpha_years = None      # sorted years extracted from ALPHA_<y> columns
    alphas = None           # alpha values matching alpha_years
    assoc_alpha_factor = None
    sigma_years = None      # sorted years extracted from SIGMA_<y> columns
    sigmas = None           # sigma values matching sigma_years
    maxR = None
    minR = None
    norm = None
    noise = None
    noise_std = None
    hood_std_devs = None
    train_across_method = None
    lesion_run = None
    treatment_run = None
    semantic_lesion_type = None
    semantic_lesion_strength = None
    assoc_lesion_type = None
    assoc_lesion_strength = None
    phonetic_lesion_type = None
    eng_lesion_strength = None
    spa_lesion_strength = None
    mask_assoc = None
    min_word_freq = None
    papt_thresh = None      # declared for consistency: set_pi assigns it

    @classmethod
    def load(cls, fname):
        """Load the pickled parameter frame from `fname`."""
        cls.frame = pd.read_pickle(fname)
        cls.file_name = fname

    @classmethod
    def getp(cls, name, default=None):
        """Case-tolerant lookup in the selected row: try `name` as
        given, then upper-cased, then lower-cased; otherwise return
        `default`.

        Simplified from the original, which special-cased a None
        default only to return None anyway; behavior is unchanged and
        now matches T.getp / H.getp.
        """
        if name in cls.params:
            return cls.params[name]
        elif name.upper() in cls.params:
            return cls.params[name.upper()]
        elif name.lower() in cls.params:
            return cls.params[name.lower()]
        else:
            return default

    @classmethod
    def sigmaf(cls):
        """Return a function age -> sigma, linearly interpolating the
        SIGMA_<year> schedule."""
        return lambda age: np.interp(age, cls.sigma_years, cls.sigmas)

    @classmethod
    def map_alphaf(cls):
        """Return a function age -> map learning rate, linearly
        interpolating the ALPHA_<year> schedule."""
        return lambda age: np.interp(age, cls.alpha_years, cls.alphas)

    @classmethod
    def assoc_alphaf(cls):
        """Return the map alpha schedule scaled by assoc_alpha_factor."""
        return lambda age: cls.assoc_alpha_factor * np.interp(age, cls.alpha_years, cls.alphas)

    @classmethod
    def set_pi(cls, pi):
        """Select row `pi` of the frame and cache its values on the
        class.  Frame values win; missing entries fall back to the
        command-line options (module-global `opt`).  Returns `pi`.
        """
        cls.pi = pi
        cls.params = cls.frame.loc[pi].copy()
        cls.hid = cls.getp("hid", opt.hid)
        cls.start_net = cls.getp("start_net", opt.start_net)
        cls.lesion_run = cls.getp("lesion_run", opt.lesion)
        cls.treatment_run = cls.getp("treatment_run", opt.treatment_run)
        cls.pre_seed = cls.getp("pre_seed", opt.pre_seed)
        cls.sem_lesion_seed = cls.getp("sem_lesion_seed", opt.sem_lesion_seed)
        cls.eng_lesion_seed = cls.getp("eng_lesion_seed", opt.eng_lesion_seed)
        cls.spa_lesion_seed = cls.getp("spa_lesion_seed", opt.spa_lesion_seed)
        cls.post_seed = cls.getp("post_seed", opt.post_seed)
        #cls.treatment_seed = cls.getp("treatment_seed", opt.treatment_seed)
        cls.map_size = int(cls.getp("MAP_SIZE"))
        cls.exposure = cls.getp("EXPOSURE", 1.0)
        cls.min_exp = cls.getp("MIN_EXP", 0.0)
        cls.maxR = cls.getp("maxR", opt.maxR)
        cls.minR = cls.getp("minR", opt.minR)
        cls.act_fun = cls.getp("ACT_FUN", opt.act_fun)
        cls.nbatch = cls.getp("NBATCH", opt.nbatch)
        # ALPHA_<year> columns define the learning-rate schedule by age.
        cls.alpha_years = cls.params.index[cls.params.index.str.startswith("ALPHA_")].map(lambda col: int(col[6:])).tolist()
        cls.alpha_years.sort()
        ALPHAS = ["ALPHA_%d" % year for year in cls.alpha_years]
        if True:  # opt.force_mono:
            # Force alphas monotonically non-increasing after the first
            # two schedule points (cf. the -force_mono option).
            cls.params[ALPHAS[2:]] = cls.params[ALPHAS[2:]].cummin()
        cls.alphas = cls.params[ALPHAS].tolist()
        cls.assoc_alpha_factor = cls.getp("assoc_alpha_factor", default=1.0)
        cls.train_across_method = cls.getp("train_across_method", default=opt.train_across_method)  # "p","q","or", "and"
        # SIGMA_<year> columns define the neighborhood-size schedule by age.
        cls.sigma_years = cls.params.index[cls.params.index.str.startswith("SIGMA_")].map(lambda col: int(col[6:])).tolist()
        cls.sigma_years.sort()
        SIGMAS = ["SIGMA_%d" % year for year in cls.sigma_years]
        if True:  # opt.force_mono:
            cls.params[SIGMAS[2:]] = cls.params[SIGMAS[2:]].cummin()
        cls.sigmas = cls.params[SIGMAS].tolist()
        cls.noise_std = cls.getp("NOISE_STD", 0.0)
        cls.noise = cls.getp("NOISE", "uniform")
        cls.norm = cls.getp("NORM", default=opt.norm)
        cls.semantic_lesion_type = cls.getp("SEM_LESION_TYPE", default="none")
        cls.semantic_lesion_strength = cls.getp("SEM_LESION_STRENGTH", default=opt.sem_lesion_strength)
        cls.phonetic_lesion_type = cls.getp("PHONETIC_LESION_TYPE", default="none")
        cls.eng_lesion_strength = cls.getp("ENG_LESION_STRENGTH", default=opt.eng_lesion_strength)
        cls.spa_lesion_strength = cls.getp("SPA_LESION_STRENGTH", default=opt.spa_lesion_strength)
        cls.mask_assoc = cls.getp("mask_assoc", default=opt.mask_assoc)
        cls.papt_thresh = cls.getp("PAPT_THRESH", default=0.2)
        cls.min_word_freq = cls.getp("MIN_WORD_FREQ", default=0.0)
        cls.hood_std_devs = cls.getp("HOOD_STD_DEVS", opt.hood_std_devs)
        return pi
################## TREATMENT PARAMS #########################
class T(object):
    """Treatment parameters, used as a class-level namespace.

    load() reads a pickled frame of treatment configurations; set_ti()
    selects row `ti` and caches its values on the class, falling back
    to the command-line options (module-global `opt`).
    """
    file_name = None
    frame = None
    ti = None                 # index of the currently selected row
    treatment = None          # the selected row
    treatment_seed = -1
    select_words = None
    ntreatment_words = None
    treatment_sem_factor = None
    treatment_alpha_factor = None
    noleak = None
    leak_filter = None
    treatment_leak_factor = None
    treatment_leak_factor2 = None
    treatment_hood_size = None
    assoc_train_method = None
    sessions_per_week = None
    treatment_normal_exp_weeks = None
    treatment_normal_exp_train_all = None  # declared for consistency: set_ti assigns it
    leak_named_only = None
    leak_alt_naming = None
    leak_train_map_method = None
    leak_train_map_factor = None
    leak_use_fake_act = None

    @classmethod
    def load(cls, fname):
        """Load the pickled treatment frame from `fname`.

        Opens the file in binary mode and closes it afterwards -- the
        original leaked the file handle and relied on text-mode reads.
        """
        with open(fname, 'rb') as f:
            cls.frame = pickle.load(f)
        cls.file_name = fname

    @classmethod
    def getp(cls, name, default=None):
        """Case-tolerant lookup in the selected treatment row: try
        `name` as given, then upper-cased, then lower-cased; otherwise
        return `default`."""
        if name in cls.treatment:
            return cls.treatment[name]
        elif name.upper() in cls.treatment:
            return cls.treatment[name.upper()]
        elif name.lower() in cls.treatment:
            return cls.treatment[name.lower()]
        else:
            return default

    @classmethod
    def set_ti(cls, ti):
        """Select treatment row `ti` and cache its values on the class."""
        cls.ti = ti
        cls.treatment = cls.frame.ix[ti]  # NOTE: .ix is deprecated in modern pandas
        # BUG FIX: the original assigned only cls.seed here, so the
        # declared class attribute `treatment_seed` always stayed -1.
        # Assign both so readers of either name agree; cls.seed is kept
        # for any existing callers.
        cls.treatment_seed = cls.seed = cls.getp("treatment_seed", opt.treatment_seed)
        cls.ntreatment_words = cls.getp("ntreatment_words", 30)
        cls.treatment_sem_factor = cls.getp("treatment_sem_factor", 0.5)
        cls.treatment_alpha_factor = cls.getp("treatment_alpha_factor", 0.5)
        cls.noleak = cls.getp("noleak", opt.noleak)
        cls.leak_filter = cls.getp("leak_filter", "trans")  ### trans, named, none
        cls.leak_named_only = cls.getp("leak_named_only", True)  ### obsolete
        cls.leak_alt_naming = cls.getp("leak_alt_naming", True)  ### obsolete
        cls.treatment_leak_factor = cls.getp("treatment_leak_factor", 0.5)
        cls.treatment_leak_factor2 = cls.getp("treatment_leak_factor2", 0.0)
        cls.treatment_hood_size = cls.getp("treatment_hood_size", 3.0)
        cls.select_words = cls.getp("select_words", "qpu")
        cls.assoc_train_method = cls.getp("assoc_train_method", "and")  # "none", "or", "and", "p", "always"
        cls.sessions_per_week = cls.getp("sessions_per_week", 2.0)
        ### TODO:
        cls.treatment_normal_exp_weeks = cls.getp("treatment_normal_exp_weeks", 0.5)
        cls.treatment_normal_exp_train_all = cls.getp("treatment_normal_exp_train_all", False)
        cls.leak_train_map_method = cls.getp("leak_train_map_method", "none")  # "none", "sem", "cross", "both"
        cls.leak_train_map_factor = cls.getp("leak_train_map_factor", 0.0)
        cls.leak_use_fake_act = cls.getp("leak_use_fake_act", False)
################## HUMAN DATA #########################
class H:
    """Human subject data, used as a class-level namespace.

    load() reads a pickled frame of human records; set_hid() selects
    record `hid` and caches its fields on the class.
    """
    file_name = None
    frame = None
    hid = None               # id of the currently selected record
    human = None             # the selected record
    stroke = patient = None  # two aliases for the same flag
    eng_exp = None           # per-year English exposure values (see set_hid)
    screener_eng = None
    screener_spa = None
    bnt_eng = None
    bnt_spa = None
    papt = None
    baseline_eng = None
    baseline_spa = None
    age_at_stroke = None
    age = None
    treatment_language = None
    treatment_sessions = None

    @classmethod
    def load(cls, fname):
        # NOTE(review): file handle is never closed and the pickle is
        # opened in text mode -- works on Py2/Unix, but fragile.
        cls.frame = pickle.load(open(fname))
        cls.file_name = fname

    @classmethod
    def getp(cls, name, default=None):
        """Case-tolerant lookup in the selected record: try `name` as
        given, then upper-cased, then lower-cased; otherwise default."""
        if name in cls.human:
            return cls.human[name]
        elif name.upper() in cls.human:
            return cls.human[name.upper()]
        elif name.lower() in cls.human:
            return cls.human[name.lower()]
        else:
            return default

    @classmethod
    def eng_expf(cls, post=False):
        """Return a function age -> English exposure.

        post=False indexes eng_exp with floor(age); post=True uses
        ceil(age) capped at cls.age, so ages past the recorded age keep
        the final exposure value.
        """
        if not post:
            return lambda age: cls.eng_exp[int(floor(age))]
        else:
            return lambda age: cls.eng_exp[int(ceil(min(age, cls.age)))]

    @classmethod
    def set_hid(cls, hid):
        """Select human record `hid` and cache its fields.  Returns hid."""
        cls.hid = hid
        cls.human = cls.frame.ix[hid]  # NOTE(review): .ix is deprecated/removed in modern pandas
        cls.stroke = cls.patient = cls.getp("patient", False)
        cls.age = cls.getp("age")
        cls.age_at_stroke = cls.getp("age_at_stroke")
        # Yearly exposures ENG_EXP0..ENG_EXP<age+1>, clipped so neither
        # language's exposure falls below P.min_exp.
        cls.eng_exp = np.clip([cls.human["ENG_EXP%d"%i] for i in range(int(cls.age+2))], P.min_exp, 1.0 - P.min_exp)
        cls.spa_exp = 1.0 - cls.eng_exp
        # NOTE(review): key is "screener_en", not "screener_eng" -- looks
        # like a typo; confirm against the humans data file before changing.
        cls.screener_eng = cls.getp("screener_en")
        cls.screener_spa = cls.getp("screener_spa")
        cls.bnt_eng = cls.getp("bnt_eng")
        cls.bnt_spa = cls.getp("bnt_spa")
        cls.papt = cls.getp("papt")
        cls.baseline_eng = cls.getp("baseline_eng")
        cls.baseline_spa = cls.getp("baseline_spa")
        cls.treatment_language = cls.getp("treatment_language", None)
        cls.treatment_sessions = cls.getp("ntreatment_sessions", 0)
        return hid
##################### TRAINING DATA ##########################
class W:
    """Training (word) data, used as a class-level namespace."""
    nwords = None
    nsem_features = None
    neng_features = None
    nspa_features = None
    sem_data = None
    eng_data = None
    spa_data = None
    words = None
    rare_words = None
    common_words = None

    @classmethod
    def load(cls, fname):
        """Unpack the pickled word-data tuple and cache its pieces,
        plus derived indexes (words, rare/common split, word order)."""
        (cls.sem_data, cls.eng_data, cls.spa_data, cls.sem_count,
         cls.sem_yes, cls.categories, cls.papt, cls.papt2, cls.papt3,
         rare) = pickle.load(open(fname))
        cls.words = cls.sem_data.index
        cls.nwords, cls.nsem_features = cls.sem_data.shape
        cls.neng_features = cls.eng_data.shape[1]
        cls.nspa_features = cls.spa_data.shape[1]
        cls.rare_words = rare.index
        common_mask = ~cls.sem_data.index.isin(cls.rare_words)
        cls.common_words = cls.sem_data.index[common_mask]
        cls.word_order = pd.Series(index=cls.words, data=np.arange(cls.nwords))

    @classmethod
    def word2int(cls, words):
        """Map word labels to their integer positions in word_order."""
        return cls.word_order.loc[words].values

    @classmethod
    def word_freq(cls):
        """Per-word sampling frequency, normalized to sum to 1.

        Common words weigh 1.0; rare words taper linearly from 1.0
        down to P.min_word_freq.
        """
        freq = pd.Series(index=cls.words, data=np.ones(cls.nwords))
        freq[cls.rare_words] = np.linspace(1.0, P.min_word_freq, len(cls.rare_words))
        return freq / freq.sum()