From 6a2b87be55f0a454bf38688093fbac08076f3879 Mon Sep 17 00:00:00 2001 From: maadcoen Date: Thu, 23 Jan 2025 10:54:30 +0100 Subject: [PATCH 1/3] PlotVariablesCatsPerProcess with leaf based categorization --- columnflow/tasks/cmsGhent/plotting.py | 35 ++++++++++++++------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/columnflow/tasks/cmsGhent/plotting.py b/columnflow/tasks/cmsGhent/plotting.py index 908b322c7..fa0c593b4 100644 --- a/columnflow/tasks/cmsGhent/plotting.py +++ b/columnflow/tasks/cmsGhent/plotting.py @@ -54,11 +54,12 @@ def run(self): for var_name in variable_tuple ] category_insts = [self.config_inst.get_category(c) for c in self.branch_data.categories] + category_insts_leafs = [c.get_leaf_categories() for c in category_insts] process_inst = self.config_inst.get_process(self.branch_data.process) sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] # histogram data for process - process_hist = 0 + process_hists = {c.name: 0 for c in category_insts} with self.publish_step(f"plotting {self.branch_data.variable} for {process_inst.name}"): for dataset, inp in self.input().items(): @@ -72,44 +73,44 @@ def run(self): # work on a copy h = h_in.copy() - - # axis selections h = h[{ "process": [ hist.loc(p.id) for p in sub_process_insts if p.id in h.axes["process"] ], - "category": [ - hist.loc(c.id) - for c in category_insts - if c.id in h.axes["category"] - ], "shift": [ hist.loc(s.id) for s in plot_shifts if s.id in h.axes["shift"] ], }] - - # axis reductions h = h[{"process": sum}] - # add the histogram - process_hist = h + process_hist + # axis selections + for c, lcs in zip(category_insts, category_insts_leafs): + hc = h[{ + "category": [ + hist.loc(c.id) + for c in lcs + if c.id in h.axes["category"] + ], + }] + + # axis reductions + hc = hc[{"category": sum}] + + # add the histsogram + process_hists[c.name] = hc + process_hists[c.name] # there should be hists to plot - if not process_hist: + if not all(process_hists.values()): raise Exception( "no histograms found to plot; possible reasons:\n" + " - requested variable requires columns that were missing during histogramming\n" + " - selected --processes did not match any value on the process axis of the input histogram", ) - process_hists = OrderedDict( - (cat.name, h[{"category": hist.loc(cat.id)}]) - for cat in category_insts - ) # call the plot function fig, _ = self.call_plot_func( self.plot_function, From ea75e48db388226a1d7b5a64c84502e806d33243 Mon Sep 17 00:00:00 2001 From: maadcoen Date: Fri, 24 Jan 2025 17:56:40 +0100 Subject: [PATCH 2/3] implement hooks --- columnflow/tasks/cmsGhent/plotting.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/cmsGhent/plotting.py b/columnflow/tasks/cmsGhent/plotting.py index fa0c593b4..a4d2f3715 100644 --- a/columnflow/tasks/cmsGhent/plotting.py +++ b/columnflow/tasks/cmsGhent/plotting.py @@ -85,7 +85,6 @@ def run(self): if s.id in h.axes["shift"] ], }] - h = h[{"process": sum}] # axis selections for c, lcs in zip(category_insts, category_insts_leafs): @@ -111,10 +110,17 @@ def run(self): " - selected --processes did not match any value on the process axis of the input histogram", ) + # update histograms using custom hooks + hists = self.invoke_hist_hooks(process_hists) + + for cat in hists: + if "process" in hists[cat].axes.name: + hists[cat] = hists[cat][{"process": sum}] + # call the plot function fig, _ = self.call_plot_func( self.plot_function, - hists=process_hists, + hists=hists, config_inst=self.config_inst, category_inst=process_inst.copy_shallow(), variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts], From c5b4b0c72c5fb7d5270fdd131d53a28ba6695992 Mon Sep 17 00:00:00 2001 From: maadcoen Date: Wed, 5 Feb 2025 10:29:46 +0100 Subject: [PATCH 3/3] fix cat leaf fetching --- columnflow/tasks/cmsGhent/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/cmsGhent/plotting.py b/columnflow/tasks/cmsGhent/plotting.py index a4d2f3715..d70cd628a 100644 --- a/columnflow/tasks/cmsGhent/plotting.py +++ b/columnflow/tasks/cmsGhent/plotting.py @@ -54,7 +54,7 @@ def run(self): for var_name in variable_tuple ] category_insts = [self.config_inst.get_category(c) for c in self.branch_data.categories] - category_insts_leafs = [c.get_leaf_categories() for c in category_insts] + category_insts_leafs = [c.get_leaf_categories() or [c] for c in category_insts] process_inst = self.config_inst.get_process(self.branch_data.process) sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)]