Skip to content

Commit a48fc33

Browse files
committed
push old uncommitted changes
1 parent e15bc8c commit a48fc33

1 file changed

Lines changed: 233 additions & 30 deletions

File tree

scripts/merge_subarray_table.py

Lines changed: 233 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11

22
import sys
33
from argparse import ArgumentParser
4-
from astropy.table import join, unique, vstack
4+
from astropy.table import join, unique, vstack, Table, setdiff
5+
import math
6+
import re
57
import numpy as np
6-
from pathlib import Path
8+
import pathlib
9+
import astropy.units as u
710

8-
from ctapipe.io import read_table, write_table
11+
from ctapipe.io import HDF5Merger, read_table, write_table
912
from ctapipe.containers import (
1013
ParticleClassificationContainer,
1114
ReconstructedGeometryContainer,
1215
ReconstructedEnergyContainer,
1316
)
14-
from ctapipe.core import Tool, traits
17+
from ctapipe.core import Tool
1518
from ctapipe.core.traits import (
1619
Unicode,
1720
Bool,
21+
Path,
1822
List,
23+
Enum,
24+
classes_with_traits,
1925
)
2026
from ctapipe.instrument import SubarrayDescription
2127
from ctapipe.reco.utils import add_defaults_and_meta
@@ -54,7 +60,7 @@ class MergeSubarrayTables(Tool):
5460

5561
input_url = Path(
5662
help="Input ctapipe HDF5 files including stereoscopic predictions.",
57-
allow_none=False,
63+
allow_none=True,
5864
exists=True,
5965
directory_ok=False,
6066
file_ok=True,
@@ -70,7 +76,7 @@ class MergeSubarrayTables(Tool):
7076
).tag(config=True)
7177

7278
input_files = List(
73-
traits.Path(exists=True, directory_ok=False),
79+
Path(exists=True, directory_ok=False),
7480
default_value=[],
7581
help="Input ctapipe HDF5 files including stereoscopic predictions.",
7682
).tag(config=True)
@@ -80,7 +86,7 @@ class MergeSubarrayTables(Tool):
8086
help="Give a specific file pattern for matching files in ``input_dir``",
8187
).tag(config=True)
8288

83-
output_path = traits.Path(
89+
output_path = Path(
8490
help="Output ctapipe HDF5 file for the merged stereoscopic predictions.",
8591
allow_none=False,
8692
exists=False,
@@ -100,21 +106,43 @@ class MergeSubarrayTables(Tool):
100106
help="List of reconstruction tasks to be used for the stereo combination.",
101107
).tag(config=True)
102108

109+
n_telescopes = Enum(
110+
[3,4],
111+
default_value=4,
112+
allow_none=False,
113+
help="Number of telescopes in the subarray. "
114+
"This is used to determine the telescope combinations.",
115+
).tag(config=True)
116+
103117
overwrite = Bool(
104118
default_value=True,
105119
allow_none=False,
106120
help="Overwrite the table in the hdf5 file if it exists",
107121
).tag(config=True)
108122

109123
parser = ArgumentParser()
110-
parser.add_argument("input_files", nargs="*", type=Path)
124+
parser.add_argument("input_files", nargs="*", type=pathlib.Path)
111125

112126
aliases = {
113127
("i", "input-dir"): "MergeSubarrayTables.input_dir",
114128
("o", "output"): "MergeSubarrayTables.output_path",
115129
("p", "pattern"): "MergeSubarrayTables.file_pattern",
130+
("t", "n_telescopes"): "MergeSubarrayTables.n_telescopes",
116131
}
117132

133+
flags = {
134+
"overwrite": (
135+
{"HDF5Merger": {"overwrite": True}},
136+
"Overwrite existing files",
137+
),
138+
"append": (
139+
{"HDF5Merger": {"append": True}},
140+
"Append to existing files",
141+
),
142+
}
143+
144+
classes = classes_with_traits(HDF5Merger)
145+
118146
def setup(self):
119147
# Set up the containers and colnames based on the reco tasks
120148
self.reco_containers = {
@@ -142,61 +170,225 @@ def setup(self):
142170
"or input files as positional arguments"
143171
)
144172
sys.exit(1)
173+
174+
# Merge the first input file to the output path
175+
with HDF5Merger(
176+
parent=self,
177+
output_path=self.output_path,
178+
) as merger:
179+
merger(self.input_files[0])
145180
# Read the SubarrayDescription from the first input file
146-
self.subarray = SubarrayDescription.read(self.input_files[0])
181+
#self.subarray = SubarrayDescription.read(self.input_files[0])
182+
if self.n_telescopes == 4:
183+
self.tel_id_2_index = {1:0, 2:1, 3:2, 4:3}
184+
self.tel_combinations = [[1,4], [4,3], [3,2], [2,1], [1,3], [4,2]]
185+
elif self.n_telescopes == 3:
186+
self.tel_id_2_index = {1:0, 3:1, 4:2}
187+
self.tel_combinations = [[1,4], [4,3], [1,3]]
188+
189+
dl1b_parameters = []
190+
for tel_id in self.tel_id_2_index.keys():
191+
dl1b_parameters.append(
192+
read_table(
193+
self.input_files[0],
194+
f"/dl1/event/telescope/parameters/tel_{tel_id:03d}"
195+
)
196+
)
197+
198+
dl1b_parameter_table = vstack(dl1b_parameters)
199+
#self.log.info(dl1b_parameter_table)
200+
201+
dl1b_parameter_table.sort(TELESCOPE_EVENT_KEYS)
202+
203+
dl1b_parameter_groups = dl1b_parameter_table.group_by(SUBARRAY_EVENT_KEYS)
204+
self.weights = {}
205+
self.weights["obs_id"] = np.zeros(
206+
len(dl1b_parameter_groups.groups),
207+
dtype=int
208+
)
209+
self.weights["event_id"] = np.zeros(
210+
len(dl1b_parameter_groups.groups),
211+
dtype=int
212+
)
213+
self.weights[f"{self.prefix}_telescopes"] = np.zeros(
214+
(len(dl1b_parameter_groups.groups), len(self.tel_id_2_index)),
215+
dtype=bool
216+
)
217+
self.weights[f"{self.prefix}_is_valid"] = np.ones(
218+
len(dl1b_parameter_groups.groups),
219+
dtype=bool
220+
)
221+
for tel_combination in self.tel_combinations:
222+
self.weights[f"LST{tel_combination[0]}LST{tel_combination[1]}_norm"] = np.zeros(
223+
len(dl1b_parameter_groups.groups),
224+
dtype=float
225+
)
226+
for g, grp in enumerate(dl1b_parameter_groups.groups):
227+
# Save the obs_id and event_id for each group
228+
self.weights["obs_id"][g] = int(grp["obs_id"][0])
229+
self.weights["event_id"][g] = int(grp["event_id"][0])
230+
#if self.weights["obs_id"][g] > 2:
231+
# break
232+
# Create boolean array for tel_ids with hillas_intensity > 25
233+
# tel_id indexing starts from 1
234+
tel_bool_array = np.zeros(len(self.tel_id_2_index), dtype=bool)
235+
tel_hillas_array = np.zeros(len(self.tel_id_2_index), dtype=float)
236+
# Set True for tel_ids with high intensity
237+
tel_mask = grp["hillas_intensity"] > 25
238+
for i, survival_tel_id in enumerate(grp["tel_id"][tel_mask]):
239+
tel_bool_array[self.tel_id_2_index[survival_tel_id]] = True
240+
tel_hillas_array[self.tel_id_2_index[survival_tel_id]] = grp["hillas_intensity"][tel_mask][i] # Adjust for 0-based indexing
241+
self.weights[f"{self.prefix}_telescopes"][g] = tel_bool_array
242+
243+
surviving_tel_ids = set(grp["tel_id"][tel_mask])
244+
if len(grp["tel_id"][tel_mask]) < 2:
245+
self.weights[f"{self.prefix}_is_valid"][g] = False
246+
elif len(grp["tel_id"][tel_mask]) == 2:
247+
for tel_combination in self.tel_combinations:
248+
if all(tel_id in surviving_tel_ids for tel_id in tel_combination):
249+
self.weights[f"LST{tel_combination[0]}LST{tel_combination[1]}_norm"][g] = 1.0
250+
elif len(grp["tel_id"][tel_mask]) > 2:
251+
hillas_sum = {}
252+
for tel_combination in self.tel_combinations:
253+
if all(tel_id in surviving_tel_ids for tel_id in tel_combination):
254+
hillas_sum[f"LST{tel_combination[0]}LST{tel_combination[1]}"] = tel_hillas_array[self.tel_id_2_index[tel_combination[0]]] + tel_hillas_array[self.tel_id_2_index[tel_combination[1]]]
255+
256+
hillas_norms = {key: value / np.sum(list(hillas_sum.values())) for key, value in hillas_sum.items()}
257+
for key, hillas_norm in hillas_norms.items():
258+
self.weights[f"{key}_norm"][g] = hillas_norm
259+
# Convert to astropy Table
260+
self.weights = Table(data=self.weights)
261+
# self.log.info(len(self.weights))
147262

148263
def start(self):
264+
149265
# Loop over the reconstruction tasks and combine the telescope tables to a subarray table
266+
class_table, eng_table, dir_table = None, None, None
150267
for reco_task in self.reco_tasks:
151268
self.log.info("Processing %s...", reco_task)
152269

153-
# Read the subarray tables from the input files
270+
# Read and join the subarray tables from the input files
154271
subarray_tables = []
155272
for input_file in self.input_files:
156-
subarray_tables.append(
157-
read_table(
158-
input_file,
159-
f"{DL2_SUBARRAY_GROUP}/{reco_task}/{self.prefix}",
160-
)
273+
274+
self.log.info("Reading from file: %s", input_file)
275+
276+
trigger = read_table(
277+
input_file,
278+
"/dl1/event/subarray/trigger",
279+
)
280+
self.log.info(trigger)
281+
shower = read_table(
282+
input_file,
283+
"/simulation/event/subarray/shower",
161284
)
162-
# Stack the telescope tables to a common table
163-
subarray_tables = vstack(subarray_tables)
285+
self.log.info(shower)
286+
self.log.info(len(trigger))
287+
self.log.info(len(shower))
288+
289+
dl2_tab = read_table(
290+
input_file,
291+
f"{DL2_SUBARRAY_GROUP}/{reco_task}/{self.prefix}",
292+
)
293+
294+
self.log.info(dl2_tab)
295+
self.log.info(len(dl2_tab))
296+
self.log.info(len(trigger))
297+
self.log.info(len(shower))
298+
dl2_tab.keep_columns(SUBARRAY_EVENT_KEYS + self.reco_colnames[reco_task])
299+
dl2_tab = join(
300+
left=dl2_tab,
301+
right=self.weights,
302+
keys=SUBARRAY_EVENT_KEYS,
303+
)
304+
# Extract telescope IDs from filename by finding digits after 'LST' substrings
305+
tel_ids_comb = [int(match) for match in re.findall(r'LST(\d+)', str(input_file)) if match.isdigit()]
306+
for col_name in self.reco_colnames[reco_task]:
307+
if reco_task == "energy":
308+
dl2_tab[col_name] = np.log10((u.Quantity(dl2_tab[col_name], unit=dl2_tab[col_name].unit).to_value(u.GeV)))
309+
310+
dl2_tab[col_name] = dl2_tab[col_name].data * dl2_tab[f"LST{tel_ids_comb[0]}LST{tel_ids_comb[1]}_norm"].data
311+
subarray_tables.append(dl2_tab)
312+
subarray_table = vstack(subarray_tables)
313+
314+
subarray_table.keep_columns(SUBARRAY_EVENT_KEYS + self.reco_colnames[reco_task] + [f"{self.prefix}_telescopes", f"{self.prefix}_is_valid"])
315+
subarray_table.sort(SUBARRAY_EVENT_KEYS)
316+
if reco_task == "classification":
317+
class_table = subarray_table.copy()
318+
elif reco_task == "energy":
319+
eng_table = subarray_table.copy()
320+
elif reco_task == "geometry":
321+
dir_table = subarray_table.copy()
322+
self.log.info(len(subarray_table))
323+
164324
# Deep copy the table to avoid modifying the original table
165-
predictions = subarray_tables.copy()
325+
predictions = subarray_table.copy()
166326
# Keep only the relevant columns for the mean calculation
167327
predictions.keep_columns(
168328
SUBARRAY_EVENT_KEYS + self.reco_colnames[reco_task]
169329
)
170330
# Group the predictions by the subarray event keys
171331
predictions_grouped = predictions.group_by(SUBARRAY_EVENT_KEYS)
332+
172333
# Calculate the mean predictions for each subarray event
173-
mean_predictions = predictions_grouped.groups.aggregate(np.mean)
334+
mean_predictions = predictions_grouped.groups.aggregate(np.nansum)
335+
if reco_task == "energy":
336+
mean_predictions[self.reco_colnames[reco_task][0]] = 10 ** mean_predictions[self.reco_colnames[reco_task][0]]
337+
174338
# Sort the mean prediction table by the subarray event keys
175339
mean_predictions.sort(SUBARRAY_EVENT_KEYS)
340+
176341
# Unique the subarray tables to avoid duplicates
177-
subarray_table = unique(
178-
subarray_tables, keys=SUBARRAY_EVENT_KEYS
342+
# this is needed because of the vstack above
343+
final_subarray_table = unique(
344+
subarray_table, keys=SUBARRAY_EVENT_KEYS
179345
)
346+
180347
# Remove the columns that will be replace by the mean predictions
181-
subarray_table.remove_columns(self.reco_colnames[reco_task])
348+
final_subarray_table.remove_columns(self.reco_colnames[reco_task])
349+
182350
# Join the mean predictions to the subarray table
183-
subarray_table = join(
184-
left=subarray_table,
351+
final_subarray_table = join(
352+
left=final_subarray_table,
185353
right=mean_predictions,
186354
keys=SUBARRAY_EVENT_KEYS,
187355
)
188-
# Sort the table by the subarray event keys
189-
subarray_table.sort(SUBARRAY_EVENT_KEYS)
356+
final_subarray_table.sort(SUBARRAY_EVENT_KEYS)
357+
#final_subarray_table[f"{self.prefix}_telescopes"] = self.weights[f"{self.prefix}_telescopes"]
358+
#final_subarray_table[f"{self.prefix}_is_valid"] = self.weights[f"{self.prefix}_is_valid"]
359+
360+
for col_name in self.reco_colnames[reco_task]:
361+
# Set the prediction to NaN if the event is not valid
362+
final_subarray_table[col_name] = np.where(
363+
final_subarray_table[f"{self.prefix}_is_valid"],
364+
final_subarray_table[col_name],
365+
np.nan
366+
)
367+
368+
# Add units to the columns
369+
if reco_task == "energy":
370+
final_subarray_table[col_name] = u.Quantity(
371+
final_subarray_table[col_name],
372+
unit=u.GeV,
373+
).to(u.TeV)
374+
375+
if reco_task == "geometry":
376+
final_subarray_table[col_name] = u.Quantity(
377+
final_subarray_table[col_name],
378+
unit=u.deg,
379+
)
380+
190381
# Add the default values and meta data to the table
191382
add_defaults_and_meta(
192-
subarray_table,
383+
final_subarray_table,
193384
self.reco_containers[reco_task],
194385
prefix=self.prefix,
195386
add_tel_prefix=False,
196387
)
388+
197389
# Save the prediction to the file
198390
write_table(
199-
subarray_table,
391+
final_subarray_table,
200392
self.output_path,
201393
f"{DL2_SUBARRAY_GROUP}/{reco_task}/{self.prefix}",
202394
overwrite=self.overwrite,
@@ -206,10 +398,21 @@ def start(self):
206398
self.output_path,
207399
f"{DL2_SUBARRAY_GROUP}/{reco_task}/{self.prefix}",
208400
)
401+
diff_1 = setdiff(
402+
self.weights, class_table, keys=SUBARRAY_EVENT_KEYS
403+
)
404+
self.log.info(diff_1)
405+
diff_2 = setdiff(
406+
self.weights, eng_table, keys=SUBARRAY_EVENT_KEYS
407+
)
408+
self.log.info(diff_2)
409+
diff_3 = setdiff(
410+
self.weights, dir_table, keys=SUBARRAY_EVENT_KEYS
411+
)
412+
self.log.info(diff_3)
413+
209414

210415
def finish(self):
211-
# Write the SubarrayDescription to the output file
212-
self.output_path.to_hdf(self.subarray, overwrite=self.overwrite)
213416
# Shutting down the tool
214417
self.log.info("Tool is shutting down")
215418

0 commit comments

Comments (0)