forked from sagittaeri/htt
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain
More file actions
executable file
·79 lines (69 loc) · 2.49 KB
/
train
File metadata and controls
executable file
·79 lines (69 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python
from rootpy.extern.argparse import ArgumentParser
# Command-line interface of the training script.
parser = ArgumentParser()

# Optional arguments as (flag, add_argument keyword arguments), listed in the
# order in which they are registered on the parser.
_OPTIONAL_ARGUMENTS = (
    ('--max-trees', dict(type=int, default=200)),
    ('--min-trees', dict(type=int, default=1)),
    ('--learning-rate', dict(type=float, default=0.1)),
    ('--max-fraction', dict(type=float, default=0.3)),
    ('--min-fraction', dict(type=float, default=0.001)),
    ('--min-fraction-steps', dict(type=int, default=100)),
    ('--nfold', dict(
        type=int, default=10,
        help='the number of folds in the cross-validation')),
    ('--masses', dict(nargs='+', default=['125'])),
    ('--suffix', dict(default=None)),
    ('--procs', dict(type=int, default=-1)),
)
for _flag, _settings in _OPTIONAL_ARGUMENTS:
    parser.add_argument(_flag, **_settings)

# The single positional argument selects which category to train.
parser.add_argument('category', choices=('vbf', 'boosted'))

args = parser.parse_args()
from mva.categories import Category_VBF, Category_Boosted
from mva.analysis import Analysis
from mva.samples import Higgs
from mva.defaults import TRAIN_FAKES_REGION
# Resolve the requested signal mass points and build the label that is later
# passed to the classifier as its `mass` argument.
if args.masses == ['all']:
    # The literal 'all' selects a copy of every entry in Higgs.MASSES.
    args.masses = Higgs.MASSES[:]
    masses_label = 'all'
else:
    # sorted() always yields a list. The previous map(...) followed by
    # .sort() only worked where map() returns a list (Python 2); on Python 3
    # map() is a lazy iterator without a .sort() method.
    args.masses = sorted(map(int, args.masses))
    # Label is the ascending masses joined by underscores, e.g. '120_125'.
    masses_label = '_'.join(map(str, args.masses))
# Map the positional CLI choice onto its category class; any value other than
# 'vbf' (argparse only admits 'boosted') selects the boosted category.
category = Category_VBF if args.category == 'vbf' else Category_Boosted
# Settings for the analysis used in training: 2012 configuration, systematics
# switched off, the dedicated training fakes region and the optional
# user-supplied suffix.
analysis_settings = dict(
    year=2012,
    systematics=False,
    fakes_region=TRAIN_FAKES_REGION,
    suffix=args.suffix,
)
analysis = Analysis(**analysis_settings)

# Normalize the analysis for the chosen category before any sample is used.
analysis.normalize(category)
# combine embedded and MC Ztt for training
# TODO: account for the fact that N(MC) != N(EMB)
#analysis_eb = get_analysis(args, embedding=True)
#analysis_mc = get_analysis(args, embedding=False)
#analysis_eb.normalize(category)
#analysis_mc.normalize(category)
#analysis_eb.ztautau.scale *= 0.5
#analysis_mc.ztautau.scale *= 0.5
#backgrounds_train = [
# analysis_eb.ztautau,
# analysis_mc.ztautau,
# analysis.others,
# analysis.qcd,
#]
# Backgrounds are taken directly from the analysis.
backgrounds_train = analysis.backgrounds

# One Higgs sample covering the selected masses, restricted to the signal
# modes the category declares for training.
higgs_sample = Higgs(
    year=2012,
    masses=args.masses,
    modes=category.train_signal_modes,
)
signals_train = [higgs_sample]
# Request a classifier for this category and mass label; load=False is passed
# so that the classifier is trained below.
clf = analysis.get_clf(category, load=False, mass=masses_label)

# Training hyper-parameters forwarded unchanged from the command line.
training_settings = dict(
    max_trees=args.max_trees,
    min_trees=args.min_trees,
    learning_rate=args.learning_rate,
    max_fraction=args.max_fraction,
    min_fraction=args.min_fraction,
    min_fraction_steps=args.min_fraction_steps,
    cv_nfold=args.nfold,
    n_jobs=args.procs,
)

clf.train(
    signals=signals_train,
    backgrounds=backgrounds_train,
    remove_negative_weights=True,
    **training_settings)