I'm going through the source code, hunting for potential bugs or speedups.
You're implementing CrossAttention yourself here: https://github.com/EPFLiGHT/MultiMeditron/blob/master/src/multimeditron/model/attention.py
The implementation looks correct, but it probably leaves quite a bit of performance on the table. Why not use PyTorch's built-in `F.scaled_dot_product_attention`? It dispatches to much faster fused kernels (FlashAttention, memory-efficient attention, etc.) when it can.
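If you want to check that the fast path actually kicks in, something like this should work (untested; it assumes PyTorch ≥ 2.3 for `torch.nn.attention.sdpa_kernel`, plus a CUDA device and fp16/bf16 inputs, which the flash backend requires):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the flash-attention backend; if it can't run
# (wrong dtype/device/mask), this raises instead of silently
# falling back to the slow math path.
q = k = v = torch.randn(2, 8, 64, 32, device="cuda", dtype=torch.float16)  # [B, H, T, D]
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```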
How about something like this (notable changes are in `forward`):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.1,
        proj_drop: float = 0.1,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = attn_drop  # kept as a float; SDPA takes dropout_p directly
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _shape(self, x, B, T):
        # [B, T, C] -> [B, num_heads, T, head_dim]
        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,              # [B, N_q, C]
        experts: List[torch.Tensor],  # list of [B, N_i, C]
        attention_mask: Optional[torch.Tensor] = None,  # broadcastable to [B, num_heads, N_q, N_kv]
    ) -> torch.Tensor:
        B, N_q, C = x.shape
        context = torch.cat(experts, dim=1)  # [B, N_kv, C]
        N_kv = context.size(1)

        q = self._shape(self.q_proj(x), B, N_q)
        k = self._shape(self.k_proj(context), B, N_kv)
        v = self._shape(self.v_proj(context), B, N_kv)

        # SDPA handles the 1/sqrt(head_dim) scaling, softmax, and attention
        # dropout internally, and picks a fused kernel when one is available.
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.attn_drop if self.training else 0.0,
            is_causal=False,
        )

        out = out.transpose(1, 2).reshape(B, N_q, C)  # back to [B, N_q, C]
        out = self.proj_drop(self.proj(out))
        return out
```
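And a quick smoke test of the sketch above (the shapes are made up; the boolean mask convention is SDPA's, where `True` means "attend"):

```python
attn = CrossAttention(dim=256, num_heads=8)
x = torch.randn(2, 16, 256)                                  # [B, N_q, C]
experts = [torch.randn(2, 32, 256), torch.randn(2, 8, 256)]  # two expert streams, N_kv = 40

# Optional padding mask over the concatenated expert tokens:
# True = attend, False = masked out; broadcasts to [B, H, N_q, N_kv].
lengths = torch.tensor([40, 25])                             # valid kv tokens per sample
kv_valid = torch.arange(40)[None, :] < lengths[:, None]      # [B, N_kv]
mask = kv_valid[:, None, None, :]                            # [B, 1, 1, N_kv]

out = attn(x, experts, attention_mask=mask)
print(out.shape)  # torch.Size([2, 16, 256])
```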
Please note: this is untested!
I'm not suggesting you just overwrite the current CrossAttention implementation with it; I'm just bringing it to your attention ;)