From 84228c12e37954736ec0b09f50a4bca4dc622dbb Mon Sep 17 00:00:00 2001 From: harens <12570877+harens@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:31:10 +0100 Subject: [PATCH] fix(context-builder): persist checkpoint metadata Save ContextBuilder architecture settings alongside the state_dict so non-default num_layers, bidirectional, and LSTM configurations can be restored. Keep loading older raw state_dict checkpoints by inferring constructor settings from stored tensor shapes and recurrent layer keys where possible. --- deepcase/context_builder/context_builder.py | 70 +++++++++++++++++---- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/deepcase/context_builder/context_builder.py b/deepcase/context_builder/context_builder.py index 1793315..6189636 100644 --- a/deepcase/context_builder/context_builder.py +++ b/deepcase/context_builder/context_builder.py @@ -60,6 +60,14 @@ def __init__(self, input_size, output_size, hidden_size=128, num_layers=1, # Initialise super super().__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.max_length = max_length + self.bidirectional = bidirectional + self.LSTM = LSTM + ################################################################ # Initialise layers # ################################################################ @@ -610,7 +618,16 @@ def save(self, outfile): File to output model. """ # Save to output file - torch.save(self.state_dict(), outfile) + torch.save({ + 'state_dict' : self.state_dict(), + 'input_size' : self.input_size, + 'output_size' : self.output_size, + 'hidden_size' : self.hidden_size, + 'num_layers' : self.num_layers, + 'max_length' : self.max_length, + 'bidirectional': self.bidirectional, + 'LSTM' : self.LSTM, + }, outfile) @classmethod def load(cls, infile, device=None): @@ -621,17 +638,46 @@ def load(cls, infile, device=None): infile : string File from which to load model. """ - # Load state dictionary - state_dict = torch.load(infile, map_location=device) - - # Get input variables from state_dict - input_size = state_dict.get('embedding.weight').shape[0] - output_size = state_dict.get('decoder_event.out.weight').shape[0] - hidden_size = state_dict.get('embedding.weight').shape[1] - num_layers = 1 # TODO - max_length = state_dict.get('decoder_attention.attn.weight').shape[0] - bidirectional = state_dict.get('decoder_attention.attn.weight').shape[1] // hidden_size != num_layers - LSTM = False # TODO + # Load checkpoint + checkpoint = torch.load(infile, map_location=device) + + if 'state_dict' in checkpoint: + state_dict = checkpoint.get('state_dict') + + input_size = checkpoint.get('input_size') + output_size = checkpoint.get('output_size') + hidden_size = checkpoint.get('hidden_size') + num_layers = checkpoint.get('num_layers') + max_length = checkpoint.get('max_length') + bidirectional = checkpoint.get('bidirectional') + LSTM = checkpoint.get('LSTM') + else: + # Backward compatibility for old raw state_dict checkpoints. + state_dict = checkpoint + + input_size = state_dict.get('embedding.weight').shape[0] + output_size = state_dict.get('decoder_event.out.weight').shape[0] + hidden_size = state_dict.get('embedding.weight').shape[1] + max_length = state_dict.get('decoder_attention.attn.weight').shape[0] + + layer_indices = list() + bidirectional = False + prefix = 'encoder.recurrent.weight_ih_l' + for key in state_dict: + if not key.startswith(prefix): continue + + suffix = key[len(prefix):] + if suffix.endswith('_reverse'): + suffix = suffix[:-len('_reverse')] + bidirectional = True + + if suffix.isdigit(): + layer_indices.append(int(suffix)) + + num_layers = max(layer_indices) + 1 if layer_indices else 1 + + recurrent_weight = state_dict.get('encoder.recurrent.weight_hh_l0') + LSTM = recurrent_weight.shape[0] == 4 * hidden_size # Create ContextBuilder result = cls(