Thanks for your great work.
Q1:
I found that the code has to execute
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
before calling
past_key_value.update(...) with the pruned key/value states,
because the pruning score is computed per attention head.
This is quite different from the original GQA implementation: GQA keeps the KV cache at (bsz, num_key_value_heads=8, q_len, head_dim/pruned_dim), but your work expands it to (bsz, num_heads=32, q_len, head_dim/pruned_dim), which eliminates GQA's memory advantage.
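To make the shape difference concrete, here is a minimal PyTorch sketch (all sizes are toy values I picked for illustration; `repeat_kv` is reproduced following the HF transformers Llama implementation):

```python
import torch

# Toy shapes for a Llama-style GQA layer (assumed values).
bsz, q_len, head_dim = 1, 16, 128
num_heads, num_key_value_heads = 32, 8
num_key_value_groups = num_heads // num_key_value_heads  # 4

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (bsz, num_kv_heads, slen, head_dim) to
    (bsz, num_kv_heads * n_rep, slen, head_dim), as in HF transformers."""
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)

key_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)

# Original GQA: the cache stores only num_key_value_heads heads.
print(key_states.shape)  # torch.Size([1, 8, 16, 128])

# Per-head pruning forces repeat_kv BEFORE past_key_value.update,
# so the cache stores num_heads heads: 4x more memory here.
expanded = repeat_kv(key_states, num_key_value_groups)
print(expanded.shape)    # torch.Size([1, 32, 16, 128])
```

So even if pruning keeps fewer tokens per head, the head dimension of the cache grows from 8 to 32.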
Q2:
I also noticed that in the prefill stage, although we prune the cached tokens down to max_capacity_prompt (2k), the attention weights are still computed with full attention.
For example, given a 6k-token prompt, the prefill stage selects the 2k most important tokens as key/value_states_compress.
However, the attention weights are still computed as query_states@key_states.T over the full 6k keys (seq_len dim), instead of query_states@key_states_compress.T over the 2k compressed keys.
Why don't we use the pruned 2k (seq_len dim) key_states_compress to compute the attention weights?
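In other words, I would expect something like the following sketch (toy sizes standing in for seq_len=6144 and max_capacity_prompt=2048; `keep_idx` is a hypothetical stand-in for however the important tokens are selected):

```python
import torch

# Toy stand-ins for the shapes in the question
# (real values: seq_len=6144, max_capacity_prompt=2048).
bsz, num_heads, head_dim = 1, 32, 64
seq_len, max_capacity_prompt = 96, 32

query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Observed behaviour: attention weights over ALL prompt keys.
attn_full = query_states @ key_states.transpose(-2, -1)

# Expected alternative: attention weights over only the compressed keys.
# keep_idx is a placeholder for the indices of the selected tokens.
keep_idx = torch.arange(seq_len - max_capacity_prompt, seq_len)
key_states_compress = key_states[:, :, keep_idx, :]
attn_pruned = query_states @ key_states_compress.transpose(-2, -1)

print(attn_full.shape)    # (bsz, num_heads, seq_len, seq_len)
print(attn_pruned.shape)  # (bsz, num_heads, seq_len, max_capacity_prompt)
```

The second matmul would shrink the attention-weight matrix from 6k x 6k to 6k x 2k per head, so I am curious why the full version is kept.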
Thanks a lot!