diff --git a/columnflow/tasks/cmsGhent/plotting.py b/columnflow/tasks/cmsGhent/plotting.py index 908b322c7..d70cd628a 100644 --- a/columnflow/tasks/cmsGhent/plotting.py +++ b/columnflow/tasks/cmsGhent/plotting.py @@ -54,11 +54,12 @@ def run(self): for var_name in variable_tuple ] category_insts = [self.config_inst.get_category(c) for c in self.branch_data.categories] + category_insts_leafs = [c.get_leaf_categories() or [c] for c in category_insts] process_inst = self.config_inst.get_process(self.branch_data.process) sub_process_insts = [sub for sub, _, _ in process_inst.walk_processes(include_self=True)] # histogram data for process - process_hist = 0 + process_hists = {c.name: 0 for c in category_insts} with self.publish_step(f"plotting {self.branch_data.variable} for {process_inst.name}"): for dataset, inp in self.input().items(): @@ -72,19 +73,12 @@ def run(self): # work on a copy h = h_in.copy() - - # axis selections h = h[{ "process": [ hist.loc(p.id) for p in sub_process_insts if p.id in h.axes["process"] ], - "category": [ - hist.loc(c.id) - for c in category_insts - if c.id in h.axes["category"] - ], "shift": [ hist.loc(s.id) for s in plot_shifts @@ -92,28 +86,41 @@ def run(self): ], }] - # axis reductions - h = h[{"process": sum}] + # axis selections + for c, lcs in zip(category_insts, category_insts_leafs): + hc = h[{ + "category": [ + hist.loc(c.id) + for c in lcs + if c.id in h.axes["category"] + ], + }] - # add the histogram - process_hist = h + process_hist + # axis reductions + hc = hc[{"category": sum}] + + # add the histsogram + process_hists[c.name] = hc + process_hists[c.name] # there should be hists to plot - if not process_hist: + if not all(process_hists.values()): raise Exception( "no histograms found to plot; possible reasons:\n" + " - requested variable requires columns that were missing during histogramming\n" + " - selected --processes did not match any value on the process axis of the input histogram", ) - process_hists = OrderedDict( - (cat.name, h[{"category": hist.loc(cat.id)}]) - for cat in category_insts - ) + # update histograms using custom hooks + hists = self.invoke_hist_hooks(process_hists) + + for cat in hists: + if "process" in hists[cat].axes.name: + hists[cat] = hists[cat][{"process": sum}] + # call the plot function fig, _ = self.call_plot_func( self.plot_function, - hists=process_hists, + hists=hists, config_inst=self.config_inst, category_inst=process_inst.copy_shallow(), variable_insts=[var_inst.copy_shallow() for var_inst in variable_insts],