From 8b3c1f98547105416845d39e100112bcf93f2402 Mon Sep 17 00:00:00 2001 From: Benoit <69694610+BenoitDalFerro@users.noreply.github.com> Date: Tue, 14 Feb 2023 15:42:47 +0100 Subject: [PATCH] Error in model: scaling only the q matrix, not the q.k.T dot product (q.k.T / sqrt(dim_per_head)) As per Vaswani et al., 2017, p. 4, the scaled scores are torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head), not q / math.sqrt(dim_per_head). https://arxiv.org/pdf/1912.05372.pdf --- xlm/model/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xlm/model/transformer.py b/xlm/model/transformer.py index 53660e16..6a61ab0e 100755 --- a/xlm/model/transformer.py +++ b/xlm/model/transformer.py @@ -206,8 +206,8 @@ def unshape(x): k, v = cache[self.layer_id] cache[self.layer_id] = (k, v) - q = q / math.sqrt(dim_per_head) # (bs, n_heads, qlen, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen) + scores = scores / math.sqrt(dim_per_head) # (bs, n_heads, qlen, klen) mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen) scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)