diff --git a/self_speculation/llama_model_utils.py b/self_speculation/llama_model_utils.py index 1ef9a53..23b5248 100644 --- a/self_speculation/llama_model_utils.py +++ b/self_speculation/llama_model_utils.py @@ -149,6 +149,10 @@ def crop_past_key_values( return past_key_values +def _compute_position_embeddings(model, hidden_states, position_ids): + return model.model.rotary_emb(hidden_states, position_ids) + + # Our forward_early(...) and forward_remainder(...) functions currently use transformers library's legacy KV cache implementation that is less efficient. # To ensure an apples to apples comparison, we created this forward function to use in autoregressive decoding to ensure it uses the same KV cache implementation instead. # FIXME: update forward_early(...) and forward_remainder(...) to use the updated more efficient KV cache implementation. @@ -189,16 +193,19 @@ def forward( ) hidden_states = inputs_embeds + position_embeddings = _compute_position_embeddings(model, hidden_states, position_ids) + for decoder_layer in model.model.layers: - hidden_states, past_key_values = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=False, use_cache=True, - padding_mask=None, + position_embeddings=position_embeddings, ) + hidden_states = layer_outputs[0] past_key_values = past_key_values.to_legacy_cache() hidden_states = model.model.norm(hidden_states) @@ -249,20 +256,22 @@ def forward_early( ) hidden_states = inputs_embeds + position_embeddings = _compute_position_embeddings(model, hidden_states, position_ids) + for decoder_layer in model.model.layers[:exit_layer]: - hidden_states, past_key_values = decoder_layer( + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=False, use_cache=True, - padding_mask=None, + position_embeddings=position_embeddings, ) + hidden_states = layer_outputs[0] past_key_values = past_key_values.to_legacy_cache() - # next_cache = next_decoder_cache if exit_query_cache is None: exit_query_cache = hidden_states else: @@ -336,51 +345,48 @@ def forward_remainder( full_past_key_values_length, # we have no past for the full model ) - next_decoder_cache = [] hidden_states = inputs_embeds - # TODO simplify full_hidden_states: Optional[torch.FloatTensor] = None + for idx, decoder_layer in enumerate(model.model.layers): is_early_exit = idx < exit_layer - past_key_value = ( - past_key_values[idx] - if (past_key_values is not None and idx < len(past_key_values)) - else None - ) + if is_early_exit: - # early hidden states: B x num_gen x C early_hidden_states = hidden_states[:, -num_tokens_to_generate:] early_position_ids = position_ids[:, -num_tokens_to_generate:] - hidden_states, past_key_values = decoder_layer( + early_position_embeddings = _compute_position_embeddings(model, early_hidden_states, early_position_ids) + + layer_outputs = decoder_layer( early_hidden_states, attention_mask=early_attention_mask, position_ids=early_position_ids, past_key_value=past_key_values, output_attentions=False, use_cache=True, - padding_mask=None, + position_embeddings=early_position_embeddings, ) + hidden_states = layer_outputs[0] else: if full_hidden_states is None and exit_query_cache is not None: - # first time seeing the full hidden states, we need to rely on the - # query cache - # only use if exit query cache exists, if not this is our first call full_hidden_states = torch.cat( [exit_query_cache, hidden_states[:, -num_tokens_to_generate:]], dim=1, ) else: - # we already have seen the fully hidden states we can re-use them now full_hidden_states = hidden_states - hidden_states, past_key_values = decoder_layer( + + full_position_embeddings = _compute_position_embeddings(model, full_hidden_states, position_ids) + + layer_outputs = decoder_layer( full_hidden_states, attention_mask=full_attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=False, use_cache=True, - padding_mask=None, + position_embeddings=full_position_embeddings, ) + hidden_states = layer_outputs[0] past_key_values = past_key_values.to_legacy_cache() hidden_states = model.model.norm(hidden_states)