11
22import sys
33from argparse import ArgumentParser
4- from astropy .table import join , unique , vstack
4+ from astropy .table import join , unique , vstack , Table , setdiff
5+ import math
6+ import re
57import numpy as np
6- from pathlib import Path
8+ import pathlib
9+ import astropy .units as u
710
8- from ctapipe .io import read_table , write_table
11+ from ctapipe .io import HDF5Merger , read_table , write_table
912from ctapipe .containers import (
1013 ParticleClassificationContainer ,
1114 ReconstructedGeometryContainer ,
1215 ReconstructedEnergyContainer ,
1316)
14- from ctapipe .core import Tool , traits
17+ from ctapipe .core import Tool
1518from ctapipe .core .traits import (
1619 Unicode ,
1720 Bool ,
21+ Path ,
1822 List ,
23+ Enum ,
24+ classes_with_traits ,
1925)
2026from ctapipe .instrument import SubarrayDescription
2127from ctapipe .reco .utils import add_defaults_and_meta
@@ -54,7 +60,7 @@ class MergeSubarrayTables(Tool):
5460
5561 input_url = Path (
5662 help = "Input ctapipe HDF5 files including stereoscopic predictions." ,
57- allow_none = False ,
63+ allow_none = True ,
5864 exists = True ,
5965 directory_ok = False ,
6066 file_ok = True ,
@@ -70,7 +76,7 @@ class MergeSubarrayTables(Tool):
7076 ).tag (config = True )
7177
7278 input_files = List (
73- traits . Path (exists = True , directory_ok = False ),
79+ Path (exists = True , directory_ok = False ),
7480 default_value = [],
7581 help = "Input ctapipe HDF5 files including stereoscopic predictions." ,
7682 ).tag (config = True )
@@ -80,7 +86,7 @@ class MergeSubarrayTables(Tool):
8086 help = "Give a specific file pattern for matching files in ``input_dir``" ,
8187 ).tag (config = True )
8288
83- output_path = traits . Path (
89+ output_path = Path (
8490 help = "Output ctapipe HDF5 file for the merged stereoscopic predictions." ,
8591 allow_none = False ,
8692 exists = False ,
@@ -100,21 +106,43 @@ class MergeSubarrayTables(Tool):
100106 help = "List of reconstruction tasks to be used for the stereo combination." ,
101107 ).tag (config = True )
102108
109+ n_telescopes = Enum (
110+ [3 ,4 ],
111+ default_value = 4 ,
112+ allow_none = False ,
113+ help = "Number of telescopes in the subarray. "
114+ "This is used to determine the telescope combinations." ,
115+ ).tag (config = True )
116+
103117 overwrite = Bool (
104118 default_value = True ,
105119 allow_none = False ,
106120 help = "Overwrite the table in the hdf5 file if it exists" ,
107121 ).tag (config = True )
108122
109123 parser = ArgumentParser ()
110- parser .add_argument ("input_files" , nargs = "*" , type = Path )
124+ parser .add_argument ("input_files" , nargs = "*" , type = pathlib . Path )
111125
112126 aliases = {
113127 ("i" , "input-dir" ): "MergeSubarrayTables.input_dir" ,
114128 ("o" , "output" ): "MergeSubarrayTables.output_path" ,
115129 ("p" , "pattern" ): "MergeSubarrayTables.file_pattern" ,
130+ ("t" , "n_telescopes" ): "MergeSubarrayTables.n_telescopes" ,
116131 }
117132
133+ flags = {
134+ "overwrite" : (
135+ {"HDF5Merger" : {"overwrite" : True }},
136+ "Overwrite existing files" ,
137+ ),
138+ "append" : (
139+ {"HDF5Merger" : {"append" : True }},
140+ "Append to existing files" ,
141+ ),
142+ }
143+
144+ classes = classes_with_traits (HDF5Merger )
145+
118146 def setup (self ):
119147 # Set up the containers and colnames based on the reco tasks
120148 self .reco_containers = {
@@ -142,61 +170,225 @@ def setup(self):
142170 "or input files as positional arguments"
143171 )
144172 sys .exit (1 )
173+
174+ # Merge the first input file to the output path
175+ with HDF5Merger (
176+ parent = self ,
177+ output_path = self .output_path ,
178+ ) as merger :
179+ merger (self .input_files [0 ])
145180 # Read the SubarrayDescription from the first input file
146- self .subarray = SubarrayDescription .read (self .input_files [0 ])
181+ #self.subarray = SubarrayDescription.read(self.input_files[0])
182+ if self .n_telescopes == 4 :
183+ self .tel_id_2_index = {1 :0 , 2 :1 , 3 :2 , 4 :3 }
184+ self .tel_combinations = [[1 ,4 ], [4 ,3 ], [3 ,2 ], [2 ,1 ], [1 ,3 ], [4 ,2 ]]
185+ elif self .n_telescopes == 3 :
186+ self .tel_id_2_index = {1 :0 , 3 :1 , 4 :2 }
187+ self .tel_combinations = [[1 ,4 ], [4 ,3 ], [1 ,3 ]]
188+
189+ dl1b_parameters = []
190+ for tel_id in self .tel_id_2_index .keys ():
191+ dl1b_parameters .append (
192+ read_table (
193+ self .input_files [0 ],
194+ f"/dl1/event/telescope/parameters/tel_{ tel_id :03d} "
195+ )
196+ )
197+
198+ dl1b_parameter_table = vstack (dl1b_parameters )
199+ #self.log.info(dl1b_parameter_table)
200+
201+ dl1b_parameter_table .sort (TELESCOPE_EVENT_KEYS )
202+
203+ dl1b_parameter_groups = dl1b_parameter_table .group_by (SUBARRAY_EVENT_KEYS )
204+ self .weights = {}
205+ self .weights ["obs_id" ] = np .zeros (
206+ len (dl1b_parameter_groups .groups ),
207+ dtype = int
208+ )
209+ self .weights ["event_id" ] = np .zeros (
210+ len (dl1b_parameter_groups .groups ),
211+ dtype = int
212+ )
213+ self .weights [f"{ self .prefix } _telescopes" ] = np .zeros (
214+ (len (dl1b_parameter_groups .groups ), len (self .tel_id_2_index )),
215+ dtype = bool
216+ )
217+ self .weights [f"{ self .prefix } _is_valid" ] = np .ones (
218+ len (dl1b_parameter_groups .groups ),
219+ dtype = bool
220+ )
221+ for tel_combination in self .tel_combinations :
222+ self .weights [f"LST{ tel_combination [0 ]} LST{ tel_combination [1 ]} _norm" ] = np .zeros (
223+ len (dl1b_parameter_groups .groups ),
224+ dtype = float
225+ )
226+ for g , grp in enumerate (dl1b_parameter_groups .groups ):
227+ # Save the obs_id and event_id for each group
228+ self .weights ["obs_id" ][g ] = int (grp ["obs_id" ][0 ])
229+ self .weights ["event_id" ][g ] = int (grp ["event_id" ][0 ])
230+ #if self.weights["obs_id"][g] > 2:
231+ # break
232+ # Create boolean array for tel_ids with hillas_intensity > 25
233+ # tel_id indexing starts from 1
234+ tel_bool_array = np .zeros (len (self .tel_id_2_index ), dtype = bool )
235+ tel_hillas_array = np .zeros (len (self .tel_id_2_index ), dtype = float )
236+ # Set True for tel_ids with high intensity
237+ tel_mask = grp ["hillas_intensity" ] > 25
238+ for i , survival_tel_id in enumerate (grp ["tel_id" ][tel_mask ]):
239+ tel_bool_array [self .tel_id_2_index [survival_tel_id ]] = True
240+ tel_hillas_array [self .tel_id_2_index [survival_tel_id ]] = grp ["hillas_intensity" ][tel_mask ][i ] # Adjust for 0-based indexing
241+ self .weights [f"{ self .prefix } _telescopes" ][g ] = tel_bool_array
242+
243+ surviving_tel_ids = set (grp ["tel_id" ][tel_mask ])
244+ if len (grp ["tel_id" ][tel_mask ]) < 2 :
245+ self .weights [f"{ self .prefix } _is_valid" ][g ] = False
246+ elif len (grp ["tel_id" ][tel_mask ]) == 2 :
247+ for tel_combination in self .tel_combinations :
248+ if all (tel_id in surviving_tel_ids for tel_id in tel_combination ):
249+ self .weights [f"LST{ tel_combination [0 ]} LST{ tel_combination [1 ]} _norm" ][g ] = 1.0
250+ elif len (grp ["tel_id" ][tel_mask ]) > 2 :
251+ hillas_sum = {}
252+ for tel_combination in self .tel_combinations :
253+ if all (tel_id in surviving_tel_ids for tel_id in tel_combination ):
254+ hillas_sum [f"LST{ tel_combination [0 ]} LST{ tel_combination [1 ]} " ] = tel_hillas_array [self .tel_id_2_index [tel_combination [0 ]]] + tel_hillas_array [self .tel_id_2_index [tel_combination [1 ]]]
255+
256+ hillas_norms = {key : value / np .sum (list (hillas_sum .values ())) for key , value in hillas_sum .items ()}
257+ for key , hillas_norm in hillas_norms .items ():
258+ self .weights [f"{ key } _norm" ][g ] = hillas_norm
259+ # Convert to astropy Table
260+ self .weights = Table (data = self .weights )
261+ # self.log.info(len(self.weights))
147262
148263 def start (self ):
264+
149265 # Loop over the reconstruction tasks and combine the telescope tables to a subarray table
266+ class_table , eng_table , dir_table = None , None , None
150267 for reco_task in self .reco_tasks :
151268 self .log .info ("Processing %s..." , reco_task )
152269
153- # Read the subarray tables from the input files
270+ # Read and join the subarray tables from the input files
154271 subarray_tables = []
155272 for input_file in self .input_files :
156- subarray_tables .append (
157- read_table (
158- input_file ,
159- f"{ DL2_SUBARRAY_GROUP } /{ reco_task } /{ self .prefix } " ,
160- )
273+
274+ self .log .info ("Reading from file: %s" , input_file )
275+
276+ trigger = read_table (
277+ input_file ,
278+ "/dl1/event/subarray/trigger" ,
279+ )
280+ self .log .info (trigger )
281+ shower = read_table (
282+ input_file ,
283+ "/simulation/event/subarray/shower" ,
161284 )
162- # Stack the telescope tables to a common table
163- subarray_tables = vstack (subarray_tables )
285+ self .log .info (shower )
286+ self .log .info (len (trigger ))
287+ self .log .info (len (shower ))
288+
289+ dl2_tab = read_table (
290+ input_file ,
291+ f"{ DL2_SUBARRAY_GROUP } /{ reco_task } /{ self .prefix } " ,
292+ )
293+
294+ self .log .info (dl2_tab )
295+ self .log .info (len (dl2_tab ))
296+ self .log .info (len (trigger ))
297+ self .log .info (len (shower ))
298+ dl2_tab .keep_columns (SUBARRAY_EVENT_KEYS + self .reco_colnames [reco_task ])
299+ dl2_tab = join (
300+ left = dl2_tab ,
301+ right = self .weights ,
302+ keys = SUBARRAY_EVENT_KEYS ,
303+ )
304+ # Extract telescope IDs from filename by finding digits after 'LST' substrings
305+ tel_ids_comb = [int (match ) for match in re .findall (r'LST(\d+)' , str (input_file )) if match .isdigit ()]
306+ for col_name in self .reco_colnames [reco_task ]:
307+ if reco_task == "energy" :
308+ dl2_tab [col_name ] = np .log10 ((u .Quantity (dl2_tab [col_name ], unit = dl2_tab [col_name ].unit ).to_value (u .GeV )))
309+
310+ dl2_tab [col_name ] = dl2_tab [col_name ].data * dl2_tab [f"LST{ tel_ids_comb [0 ]} LST{ tel_ids_comb [1 ]} _norm" ].data
311+ subarray_tables .append (dl2_tab )
312+ subarray_table = vstack (subarray_tables )
313+
314+ subarray_table .keep_columns (SUBARRAY_EVENT_KEYS + self .reco_colnames [reco_task ] + [f"{ self .prefix } _telescopes" , f"{ self .prefix } _is_valid" ])
315+ subarray_table .sort (SUBARRAY_EVENT_KEYS )
316+ if reco_task == "classification" :
317+ class_table = subarray_table .copy ()
318+ elif reco_task == "energy" :
319+ eng_table = subarray_table .copy ()
320+ elif reco_task == "geometry" :
321+ dir_table = subarray_table .copy ()
322+ self .log .info (len (subarray_table ))
323+
164324 # Deep copy the table to avoid modifying the original table
165- predictions = subarray_tables .copy ()
325+ predictions = subarray_table .copy ()
166326 # Keep only the relevant columns for the mean calculation
167327 predictions .keep_columns (
168328 SUBARRAY_EVENT_KEYS + self .reco_colnames [reco_task ]
169329 )
170330 # Group the predictions by the subarray event keys
171331 predictions_grouped = predictions .group_by (SUBARRAY_EVENT_KEYS )
332+
172333 # Sum the weighted predictions of each subarray event (NaN entries are ignored)
173- mean_predictions = predictions_grouped .groups .aggregate (np .mean )
334+ mean_predictions = predictions_grouped .groups .aggregate (np .nansum )
335+ if reco_task == "energy" :
336+ mean_predictions [self .reco_colnames [reco_task ][0 ]] = 10 ** mean_predictions [self .reco_colnames [reco_task ][0 ]]
337+
174338 # Sort the mean prediction table by the subarray event keys
175339 mean_predictions .sort (SUBARRAY_EVENT_KEYS )
340+
176341 # Unique the subarray tables to avoid duplicates
177- subarray_table = unique (
178- subarray_tables , keys = SUBARRAY_EVENT_KEYS
342+ # this is needed because of the vstack above
343+ final_subarray_table = unique (
344+ subarray_table , keys = SUBARRAY_EVENT_KEYS
179345 )
346+
181348 # Remove the columns that will be replaced by the mean predictions
181- subarray_table .remove_columns (self .reco_colnames [reco_task ])
348+ final_subarray_table .remove_columns (self .reco_colnames [reco_task ])
349+
182350 # Join the mean predictions to the subarray table
183- subarray_table = join (
184- left = subarray_table ,
351+ final_subarray_table = join (
352+ left = final_subarray_table ,
185353 right = mean_predictions ,
186354 keys = SUBARRAY_EVENT_KEYS ,
187355 )
188- # Sort the table by the subarray event keys
189- subarray_table .sort (SUBARRAY_EVENT_KEYS )
356+ final_subarray_table .sort (SUBARRAY_EVENT_KEYS )
357+ #final_subarray_table[f"{self.prefix}_telescopes"] = self.weights[f"{self.prefix}_telescopes"]
358+ #final_subarray_table[f"{self.prefix}_is_valid"] = self.weights[f"{self.prefix}_is_valid"]
359+
360+ for col_name in self .reco_colnames [reco_task ]:
361+ # Set the prediction to NaN if the event is not valid
362+ final_subarray_table [col_name ] = np .where (
363+ final_subarray_table [f"{ self .prefix } _is_valid" ],
364+ final_subarray_table [col_name ],
365+ np .nan
366+ )
367+
368+ # Add units to the columns
369+ if reco_task == "energy" :
370+ final_subarray_table [col_name ] = u .Quantity (
371+ final_subarray_table [col_name ],
372+ unit = u .GeV ,
373+ ).to (u .TeV )
374+
375+ if reco_task == "geometry" :
376+ final_subarray_table [col_name ] = u .Quantity (
377+ final_subarray_table [col_name ],
378+ unit = u .deg ,
379+ )
380+
190381 # Add the default values and meta data to the table
191382 add_defaults_and_meta (
192- subarray_table ,
383+ final_subarray_table ,
193384 self .reco_containers [reco_task ],
194385 prefix = self .prefix ,
195386 add_tel_prefix = False ,
196387 )
388+
197389 # Save the prediction to the file
198390 write_table (
199- subarray_table ,
391+ final_subarray_table ,
200392 self .output_path ,
201393 f"{ DL2_SUBARRAY_GROUP } /{ reco_task } /{ self .prefix } " ,
202394 overwrite = self .overwrite ,
@@ -206,10 +398,21 @@ def start(self):
206398 self .output_path ,
207399 f"{ DL2_SUBARRAY_GROUP } /{ reco_task } /{ self .prefix } " ,
208400 )
401+ diff_1 = setdiff (
402+ self .weights , class_table , keys = SUBARRAY_EVENT_KEYS
403+ )
404+ self .log .info (diff_1 )
405+ diff_2 = setdiff (
406+ self .weights , eng_table , keys = SUBARRAY_EVENT_KEYS
407+ )
408+ self .log .info (diff_2 )
409+ diff_3 = setdiff (
410+ self .weights , dir_table , keys = SUBARRAY_EVENT_KEYS
411+ )
412+ self .log .info (diff_3 )
413+
209414
def finish(self):
    """Log that the tool is shutting down.

    This step only emits an info-level log message; it does not write
    anything to the output file.
    """
    self.log.info("Tool is shutting down")
215418
0 commit comments