Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 58 additions & 12 deletions deepcase/context_builder/context_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def __init__(self, input_size, output_size, hidden_size=128, num_layers=1,
# Initialise super
super().__init__()

self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.max_length = max_length
self.bidirectional = bidirectional
self.LSTM = LSTM

################################################################
# Initialise layers #
################################################################
Expand Down Expand Up @@ -610,7 +618,16 @@ def save(self, outfile):
File to output model.
"""
# Save to output file
torch.save(self.state_dict(), outfile)
torch.save({
'state_dict' : self.state_dict(),
'input_size' : self.input_size,
'output_size' : self.output_size,
'hidden_size' : self.hidden_size,
'num_layers' : self.num_layers,
'max_length' : self.max_length,
'bidirectional': self.bidirectional,
'LSTM' : self.LSTM,
}, outfile)

@classmethod
def load(cls, infile, device=None):
Expand All @@ -621,17 +638,46 @@ def load(cls, infile, device=None):
infile : string
File from which to load model.
"""
# Load state dictionary
state_dict = torch.load(infile, map_location=device)

# Get input variables from state_dict
input_size = state_dict.get('embedding.weight').shape[0]
output_size = state_dict.get('decoder_event.out.weight').shape[0]
hidden_size = state_dict.get('embedding.weight').shape[1]
num_layers = 1 # TODO
max_length = state_dict.get('decoder_attention.attn.weight').shape[0]
bidirectional = state_dict.get('decoder_attention.attn.weight').shape[1] // hidden_size != num_layers
LSTM = False # TODO
# Load checkpoint
checkpoint = torch.load(infile, map_location=device)

if 'state_dict' in checkpoint:
state_dict = checkpoint.get('state_dict')

input_size = checkpoint.get('input_size')
output_size = checkpoint.get('output_size')
hidden_size = checkpoint.get('hidden_size')
num_layers = checkpoint.get('num_layers')
max_length = checkpoint.get('max_length')
bidirectional = checkpoint.get('bidirectional')
LSTM = checkpoint.get('LSTM')
else:
# Backward compatibility for old raw state_dict checkpoints.
state_dict = checkpoint

input_size = state_dict.get('embedding.weight').shape[0]
output_size = state_dict.get('decoder_event.out.weight').shape[0]
hidden_size = state_dict.get('embedding.weight').shape[1]
max_length = state_dict.get('decoder_attention.attn.weight').shape[0]

layer_indices = list()
bidirectional = False
prefix = 'encoder.recurrent.weight_ih_l'
for key in state_dict:
if not key.startswith(prefix): continue

suffix = key[len(prefix):]
if suffix.endswith('_reverse'):
suffix = suffix[:-len('_reverse')]
bidirectional = True

if suffix.isdigit():
layer_indices.append(int(suffix))

num_layers = max(layer_indices) + 1 if layer_indices else 1

recurrent_weight = state_dict.get('encoder.recurrent.weight_hh_l0')
LSTM = recurrent_weight.shape[0] == 4 * hidden_size

# Create ContextBuilder
result = cls(
Expand Down