/* TODOS-
* - Style links
* - Eq numbering
* - Sec numbering
* - Fix algos
* - TOC styling
* - Multi-citation fmt
* */
#import "@preview/lovelace:0.3.0": *
/* Custom objects */
#let DR = `Dropout`
#let int = math.integral
#let prod = math.product
#let SM = math.op(`Softmax`)
#let CAT = math.op(`Concat`)
#let LIN = math.op(`Linear`)
#let LN = math.op(`LayerNorm`)
#let MLP = math.op(`MLP`)
#let CA = math.op(`CausalAttention`)
#let NORM = math.op(`Norm`)
#let CUMSUM = math.op(`cumsum`)
#let CUMPROD = math.op(`cumprod`)
#let SEGSUM = math.op(`segsum`)
#let SEGPROD = math.op(`segprod`)
#let SUM = math.op(`sum`)
#let MHA = math.op(`MHA`)
#let TOPK = math.op(`topk`)
#let TR = math.op(`Trace`)
/* #let nice_box(body) = { block(stroke: black, inset: 1em, radius: .5em)[#body] } */
#let nice_box(body) = {
block(fill: luma(238), inset: 1em, radius: .5em)[#body]
}
#let warn_box(body) = {
block(fill: rgb("#FF9999"), inset: 1em, radius: .5em)[#body]
}
#set table(
inset: 6pt,
stroke: none,
)
#show quote: it => [
#align(left)[
#pad(left: 2em, right: 2em)[
#emph[#it.body]
#if it.attribution != none [
#linebreak()
#align(right)[— #it.attribution]
]
]
]
]
#show figure.where(kind: table): set figure.caption(position: top)
#show figure.where(kind: image): set figure.caption(position: bottom)
#show figure.caption: set align(left)
#let content-to-string(content) = {
if content.has("text") {
content.text
} else if content.has("children") {
content.children.map(content-to-string).join("")
} else if content.has("body") {
content-to-string(content.body)
} else if content == [ ] {
" "
}
}
#let conf(
title: "Decoder-Only Transformers",
subtitle: none,
authors: "Garrett Goon",
keywords: (),
date: none,
abstract: none,
cols: 1,
margin: (x: 1.0in, y: 1.0in),
paper: "us-letter",
lang: "en",
region: "US",
font: (),
fontsize: 12pt,
sectionnumbering: "1.",
pagenumbering: "1",
doc,
) = {
set document(
title: title,
author: authors.map(author => content-to-string(author.name)),
keywords: keywords,
)
set page(
paper: paper,
margin: margin,
numbering: pagenumbering,
columns: cols,
)
set par(justify: true)
set text(
lang: lang,
region: region,
size: fontsize,
)
set heading(numbering: sectionnumbering)
set enum(numbering: "1.a.i.")
place(top, float: true, scope: "parent", clearance: 4mm)[
#if title != none {
align(center)[#block(inset: 2em)[
#text(weight: "bold", size: 1.5em)[#title]
#(
if subtitle != none {
parbreak()
text(weight: "bold", size: 1.25em)[#subtitle]
}
)
]]
}
#if authors != none and authors != [] {
let count = authors.len()
let ncols = calc.min(count, 3)
grid(
columns: (1fr,) * ncols,
row-gutter: 1.5em,
..authors.map(author => align(center)[
#author.name \
#author.affiliation \
#author.email
])
)
}
#if date != none {
align(center)[#block(inset: 1em)[
#date
]]
}
#if abstract != none {
block(inset: 2em)[
#text(weight: "semibold")[Abstract] #h(1em) #abstract
]
}
]
doc
}
#show: doc => conf(
title: [Decoders],
authors: (
(name: [Garrett Goon], affiliation: "", email: ""),
),
abstract: [Notes on various aspects of decoder models (and related topics). Conventions
are in the appendix, @app_conventions.
],
pagenumbering: "1",
cols: 1,
doc,
)
/* https://github.com/typst/typst/discussions/4031#discussioncomment-9258528 */
#let appendix(body) = {
set heading(numbering: "A.1", supplement: [Appendix])
counter(heading).update(0)
body
}
/* Outline styling */
#show outline.entry.where(level: 1): it => {
v(12pt, weak: true)
strong(it)
}
#outline(indent: auto)
#show link: it => underline(text(fill: blue)[#it])
/* Eq ref styling https://typst.app/docs/reference/model/ref/ */
#set math.equation(
numbering: "(1)",
number-align: end + bottom,
)
#show math.equation: set block(breakable: true)
#show ref: it => {
let foot = footnote
let eq = math.equation
let el = it.element
set text(blue)
if el != none and el.func() == eq {
// Override equation references.
link(
el.location(),
numbering(
el.numbering,
..counter(eq).at(el.location()),
),
)
} else if el != none and el.func() == foot {
// Footnote styling
link(
el.location(),
"Footnote "
+ numbering(
el.numbering,
..counter(foot).at(el.location()),
),
)
} else {
// Other references as usual.
it
}
}
= Architecture
<architecture>
== Decoder-Only Fundamentals
<sec_decoder_only>
The Transformers architecture @vaswani2017attention, which dominates
Natural Language Processing (NLP) as of July 2023, is a relatively
simple architecture. There are various flavors and variants of
Transformers, but we focus here on the decoder-only versions which underlie
the GPT models
@gpt2radford2019language@gpt3brown2020language@gpt4openai2023.
The full decoder-only architecture can be seen in @fig_transformers_architecture. See
@app_conventions for more on conventions.
#figure(
image("figures/transformer-general.jpg"),
caption: [
The full transformers architecture. Diagram taken from
@korthikanti2022reducing
],
)
<fig_transformers_architecture>
At a high level, decoder-only transformers take in an ordered series of
word-like objects, called tokens, and are trained to predict the next
token in the sequence. Given some initial text, transformers can be used
to give a prediction for the likelihood of any possible continuation of
that text. An outline of the mechanics#footnote[This describes the
vanilla architecture; almost every component is modified in the
available variants.]:
+ Raw text is #strong[tokenized] and turned into a series of
integers#footnote[There are about
#link("https://github.com/ray-project/llm-numbers")[1.3 tokens per word],
on average.] whose values lie in `range(V)`, with $V$ the vocabulary size.
+ The tokenized text is chunked and turned into `(B, S)`-shaped (batch size and
sequence length, respectively) integer tensors, $x_(b s)$.
+ The #strong[embedding layer] converts the integer tensors into
continuous representations of shape `(B, S, D)`, $z_(b s d)$, with $D$ the size
of the hidden dimension. #strong[Positional encodings] have also been
added to the tensor at this stage to help the architecture understand
the relative ordering of the text.
+ The $z_(b s d)$ tensors pass through a series of transformer blocks,
each of which has two primary components:
+ In the #strong[attention] sub-block, components of $z_(b s d)$ at
different positions ($s$-values) interact with each other, resulting
in another `(B, S, D)`-shaped tensor, $z'_(b s d)$.
+ In the #strong[MLP] block, each position in $z'_(b s d)$ is
processed independently and in parallel by a two-layer feed-forward
network, resulting once more in a `(B, S, D)`-shaped tensor.
Importantly, there are #strong[residual connections] around each of
these#footnote[This gives rise to the concept of the #strong[residual
stream] which each transformer block reads from and writes back to
repeatedly.] (the arrows in @fig_transformers_architecture),
meaning that the output of each block is added back to its original
input.
+ Finally, we convert the `(B, S, D)`-shaped tensors to `(B, S, V)`-shaped ones, $y_(b s v)$.
This is the role of the #strong[language model head] (which is often
just the embedding layer used in an inverse manner.)
+ The $y_(b s v)$ predict what the next token will be, i.e.
$x_(b s + 1)$, having seen the #strong[context] of the first $s$
tokens in the sequence. Specifically, removing the batch index for
simplicity, a $SM$ of $y_(s v)$ gives the conditional probability
$p_(s v) = P (t_(s + 1) \| t_s dots.h t_0)$ for the indicated series
of tokens. Because of the chain rule of probability, these individual
probabilities can be combined to form the probability that any
sequence of tokens follows a given initial seed#footnote[In more
detail, these probabilities are created by products:
$P (t_(s + n) dots.h t_(s + 1) \| t_s dots.h t_0) = P (t_(s + n) \| t_(s + n - 1) dots.h t_s dots.h t_0) times dots.h times P (t_(s + 1) \| t_s dots.h t_0)$.].
Each batch (the $b$-index) is processed independently. We omitted $LN$ and $DR$
layers above, as well as the causal mask; these will be covered below as
we step through the architecture in more detail.
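As a small illustration of the chain-rule statement in the list above, the sketch below combines per-step conditional probabilities into the probability of a whole continuation. The numbers are made up and `sequence_log_prob` is a hypothetical helper name; in practice one sums log-probabilities rather than multiplying raw probabilities, to avoid underflow.

```python
import math


def sequence_log_prob(step_probs):
    """Sum the logs of the per-step conditional probabilities of one continuation."""
    return sum(math.log(p) for p in step_probs)


# Made-up conditional probabilities for a three-token continuation.
step_probs = [0.5, 0.25, 0.8]
log_prob = sequence_log_prob(step_probs)
prob = math.exp(log_prob)  # equals the plain product 0.5 * 0.25 * 0.8
```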
=== Embedding Layer and Positional Encodings <subsubsec_embedding_and_pe>
The #strong[embedding] layer is just a simple lookup table: each of the `range(V)` indices in the
vocabulary is mapped to a $D$-dimensional vector via a large `(V, D)`-shaped table/matrix. This layer maps
$x_(b s) arrow.r z_(b s d)$. In `torch`, this is an `nn.Embedding(V, D)` instance.
To each item in a batch, we add identical #strong[positional encodings] to the vectors above with
the goal of adding fixed, position-dependent correlations in the sequence dimension which will
hopefully make it easier for the architecture to pick up on the relative positions of the
inputs#footnote[Positional encodings and the causal mask are the only components in the vanilla
transformers architecture which carry weights with a dimension of size $S$; i.e. they are the only
parts that have explicit sequence-length dependence. A related thought experiment: you can convince
yourself that if the inputs $z_(b s d)$ were just random noise, the transformers architecture
would not be able to predict the $s$-index of each such input in the absence of positional
encodings.]. This layer maps $z_(b s d) arrow.l z_(b s d) + p_(s d)$, with $p_(s d)$ the positional
encoding tensor.
The above components require $(V + S) D approx V D$ parameters per
model.
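A minimal sketch of these two steps, assuming learned positional encodings stored as a plain `(S, D)`-shaped parameter (one of several possible choices) and small illustrative sizes:

```python
import torch
from torch import nn

# Illustrative sizes only; B, S, D, V follow the notation in the text.
B, S, D, V = 2, 5, 8, 11

embedding = nn.Embedding(V, D)  # (V, D)-shaped lookup table
# Assumption: learned positional encodings, one D-dimensional vector per position.
pos_encoding = nn.Parameter(torch.zeros(S, D))

x = torch.randint(V, size=(B, S))  # integer tokens, x_{bs}
z = embedding(x) + pos_encoding    # (B, S, D); p_{sd} broadcasts over the batch

num_params = embedding.weight.numel() + pos_encoding.numel()  # (V + S) * D
```

The same `pos_encoding` tensor is added to every item in the batch through broadcasting, which is what "identical positional encodings" means in practice.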
=== Layer Norm <layer_norm>
The original transformers paper @vaswani2017attention put $LN$ instances
after the #strong[attention] and #strong[MLP] blocks, but now it is
common @xiong2020layer to put them before these blocks#footnote[Which
makes intuitive sense for the purposes of stabilizing the matrix
multiplications in the blocks].
The $LN$ operation acts over the hidden dimension (since this is the
dimension the subsequent $LIN$ instances act on). Spelling it out, given the
input tensor $z_(b s d)$ whose mean and standard deviation over the $d$-index are
$mu_(b s)$ and $sigma_(b s)$, respectively, the $LN$ output is
$
z_( b s d ) & <- ( (z_( b s d ) - mu_( b s ) ) / sigma_( b s ) ) gamma_d
+ beta_( d ) equiv LN_d z_( b s d )
$
where $gamma_d \, beta_d$ are the trainable scale and
bias parameters. In `torch`, this is a `nn.LayerNorm(D)` instance. Since there are two $LN$ instances
in each transformer block, each carrying $2 D$ parameters, these components require $4 D$ parameters per
layer.
We will continue discussing $LN$ instances in what follows in order to adhere to the usual
construction and to discuss methods like sequence-parallelism in their original form (see
@subsec_seq_parallelism), but note: the data-independent $LN$ transformations due to $gamma_d \, beta_d$
are completely redundant when immediately followed by a $LIN$ layer, since both act linearly on their
inputs and $LIN$ is already the most general data-independent linear transformation. Explicitly, the
$gamma_d \, beta_d$ parameters can be absorbed into the $LIN$ parameters:
$
(x_(b s d) gamma_d + beta_d) W_(d d') + b_(d') & = x_(b s d) W'_(d d') + b'_(d') med \, quad W'_(d d') equiv gamma_d W_(d d') med \, quad b'_(d') equiv b_(d') + beta_d W_(d d') med \,
$
for arbitrary $x_(b s d)$. That is, these transformations can be
equivalently performed by the weight matrix and bias (if included) in the $LIN$
layer#footnote[Note the importance of data-independence here: the
data-dependent mean and standard deviation terms cannot be similarly
absorbed. Also, because the usual training algorithms are not invariant
under parameter redefinitions, the above unfortunately does not imply
that removing the $LN$ learnable parameters (`elementwise_affine=False` in `torch`) will have no effect on
training dynamics. $gamma_d \, beta_d$ can be folded into the $LIN$ layer's
parameters as a small inference-time optimization, though.].
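The absorption can be checked numerically. The sketch below uses small illustrative sizes, and the names `ln_plain` and `folded` are ours. It compares a standard `nn.LayerNorm` followed by `nn.Linear` against an affine-free norm followed by a linear layer with the redefined weight and bias:

```python
import torch
from torch import nn

torch.manual_seed(0)
D, D_out = 8, 6  # illustrative sizes

ln = nn.LayerNorm(D)
lin = nn.Linear(D, D_out)
# Give gamma and beta non-trivial values so the check is meaningful.
with torch.no_grad():
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.5, 0.5)

# Affine-free LN followed by a Linear layer with the absorbed parameters.
ln_plain = nn.LayerNorm(D, elementwise_affine=False)
folded = nn.Linear(D, D_out)
with torch.no_grad():
    # nn.Linear stores its weight as (out, in), so W'_{dd'} = gamma_d W_{dd'}
    # becomes a rescaling of the columns of lin.weight.
    folded.weight.copy_(lin.weight * ln.weight)
    folded.bias.copy_(lin.bias + lin.weight @ ln.bias)

x = torch.randn(3, 4, D)
max_diff = (lin(ln(x)) - folded(ln_plain(x))).abs().max().item()
```

The two paths agree up to floating-point error, which is the inference-time equivalence; it says nothing about training dynamics.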
=== Causal Attention <attn_layer>
#strong[Causal attention] is the most complex layer. It features $A$
sets of weight matrices#footnote[There are also bias terms, but we will
often neglect to write them explicitly or account for their (negligible)
parameter count.] $Q_(d e a) \, K_(d e a) \, V_(d e a)$ where
$a in {0 \, dots.h \, A - 1}$ and $e in {0 \, dots.h \, D \/ A - 1}$, where
$D$ is assumed perfectly divisible by $A$. From these, we form three
different tensors:
$
q_(b s e a) & = z_(b s d) Q_(d e a) med \, quad k_(b s e a) = z_(b s d) K_(d e a) med \, quad v_(b s e a) = z_(b s d) V_(d e a)
$
These are the #strong[query, key, and value] tensors,
respectively#footnote[There are of course many variants of the architecture. One
variant which is popular in Summer 2023 is multi-query attention,
in which all heads share #emph[the same] key and value
vectors and only the query changes across heads, as this greatly reduces
inference costs.].
Using the above tensors, we will then build up an #strong[attention map]
$w_(b s s' a)$ which corresponds to how much attention the token at
position $s$ pays to the token at position $s'$. Because we have the
goal of predicting the next token in the sequence, we need these weights
to be causal: the final prediction $y_(b s v)$ should only have access
to information propagated from positions $x_(b s')$ with $s' lt.eq s$.
This corresponds to the condition that $w_(b s s' a) = 0$ if $s' > s$.
The causal Transformers architecture as a whole obeys this
condition: the outputs
$z_(b s d) = mono("CausalTransformer") (x_(b s' d'))$ only depend on
those inputs $x_(b s' d')$ with $s' lt.eq s$.
These weights come from $SM$-ed attention scores, which are just a
normalized dot-product over the head dimension (the $e$-index):
$
w_(b s s' a) & = SM_(s') ((q_(b s e a) k_(b s' e a)) / sqrt(D \/ A) + m_(s s')) , quad "s.t." sum_(s') w_(b s s' a) = 1
$
The tensor $m_(s s')$ is the causal mask which zeroes
out the relevant attention map components above
$
m_(s s') & = cases(
0 & s' <= s,
- infinity & s' > s
)
$
forcing $w_(b s s' a) = 0$ for $s' > s$. In other words, the causal mask ensures that a given
tensor, say $z_(b s d)$, only has dependence on other tensors whose sequence index, say $s'$, obeys
$s' lt.eq s$. This is crucial for inference-time optimizations, in particular the use of the
#strong[kv-cache] in which key-value pairs do not need to be re-computed.
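A small sketch of the masking step for a single head, with random numbers standing in for the scaled query-key scores. Rows are indexed by $s$ and columns by $s'$, so the allowed entries ($s' lt.eq s$) form the lower triangle:

```python
import torch

torch.manual_seed(0)
S = 4  # illustrative sequence length
scores = torch.randn(S, S)  # stand-in for q . k / sqrt(D / A); rows are s, columns are s'

# Additive mask m_{ss'}: 0 where s' <= s, -inf where s' > s.
allowed = torch.tril(torch.ones(S, S, dtype=torch.bool))
mask = torch.zeros(S, S).masked_fill(~allowed, float("-inf"))

attn = (scores + mask).softmax(dim=-1)  # softmax over s'
```

Each row of `attn` still sums to one, and every entry above the diagonal is exactly zero, so position $s$ never receives information from later positions.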
The $sqrt(D \/ A)$ normalization is motivated by demanding that the variance of the $SM$ argument be 1 at
initialization, assuming that other components have been configured so that the query and key
components are i.i.d. draws from a standard normal distribution#footnote[The square root is removed in the normalization espoused in @mutransfer-and-similar-ideas].
The weights above are then passed through a dropout layer and used to
re-weigh the #strong[value] vectors and form the tensors
$
y_( b s e a) & = DR (w_( b s s' a) ) v_( b s'e a )
$<eq_reweighted_values>
and these `(B, S, D/A, A)`-shaped tensors are then concatenated along
the $e$-direction to re-form a `(B, S, D)`-shaped tensor $u_(b s d)$
$ u_(b s d) & = y_(b s (e a)) $ in
#link("https://einops.rocks/1-einops-basics/")[`einops`]-like notation for
concatenation. Finally, another weight matrix $O_(d' d)$ and dropout
layer transform the output once again to get the final output
$
z_( b s d ) & = DR (u_( b s d' ) O_( d'd ) ) .
$
For completeness, the entire operation in condensed notation with
indices left implicit is:
$
z & arrow DR ( CAT ( DR (SM ( ( ( z dot Q_( a ) )dot ( z dot K_( a ) )) / sqrt(D / A) ) )dot z dot V_( a ) ) dot O )
$<eq_causal_attn>
where all of the dot-products are over feature
dimensions (those of size $D$ or $D \/ A$).
Below is pedagogical#footnote[The code is written for clarity, not
speed. An example optimization missing here: there is no need to form
separate $Q_a \, K_a \, V_a$ $LIN$ layers, one large layer which is later
chunked is more efficient] sample code for such a $CA$ layer#footnote[When
using sequence-parallelism, it will be more natural to separate out the
final $DR$ layer and combine it with the subsequent $LN$, as they are sharded
together; see @subsec_seq_parallelism. The same is true for the $MLP$
layer below.]:
```python
import math

import torch
from torch import nn

# K (block size), D (hidden dim), and A (number of heads) are assumed to be
# defined elsewhere; see the conventions.


class CausalAttention(nn.Module):
    def __init__(
        self,
        block_size=K,
        dropout=0.1,
        hidden_dim=D,
        num_attn_heads=A,
    ):
        super().__init__()
        self.block_size = block_size
        self.dropout = dropout
        self.hidden_dim = hidden_dim
        self.num_attn_heads = num_attn_heads
        self.head_dim, remainder = divmod(hidden_dim, num_attn_heads)
        assert not remainder, "num_attn_heads must divide hidden_dim evenly"

        # One separate projection per head, for clarity.
        self.Q = nn.ModuleList(
            [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
        )
        self.K = nn.ModuleList(
            [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
        )
        self.V = nn.ModuleList(
            [nn.Linear(hidden_dim, self.head_dim) for _ in range(num_attn_heads)]
        )
        self.O = nn.Linear(hidden_dim, hidden_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(block_size, block_size)[None]),
        )

    def get_qkv(self, inputs):
        queries = [q(inputs) for q in self.Q]
        keys = [k(inputs) for k in self.K]
        values = [v(inputs) for v in self.V]
        return queries, keys, values

    def get_attn_maps(self, queries, keys):
        S = queries[0].shape[1]
        norm = math.sqrt(self.head_dim)
        non_causal_attn_scores = [(q @ k.transpose(-2, -1)) / norm for q, k in zip(queries, keys)]
        # Note: this mask shape is a bit of a hack to make generation from the KV cache work without
        # specifying an extra boolean. When queries and keys have different sequence lengths and the
        # queries are of seq_len == 1, the query attends to all of the keys; effectively there is
        # no mask at all.
        causal_attn_scores = [
            a.masked_fill(self.causal_mask[:, :S, :S] == 0, float("-inf"))
            for a in non_causal_attn_scores
        ]
        attn_maps = [a.softmax(dim=-1) for a in causal_attn_scores]
        return attn_maps

    def forward(self, inputs):
        queries, keys, values = self.get_qkv(inputs)
        attn_maps = self.get_attn_maps(queries, keys)
        # Re-weigh the values per head, then concatenate the heads along the feature dimension.
        weighted_values = torch.cat(
            [self.attn_dropout(a) @ v for a, v in zip(attn_maps, values)], dim=-1
        )
        z = self.O(weighted_values)
        z = self.out_dropout(z)
        return z
```
The parameter count is dominated by the weight matrices which carry
$4 D^2$ total parameters per layer.
=== MLP <subsubsec_mlp>
The feed-forward network is straightforward and corresponds to
$
z_( b s d ) & -> DR (phi ( z_( b s d' )W^0_( d'e ) ) W^1_( e d ) )
$<eq_mlp>
where $W^0$ and $W^1$ are `(D, E*D)`- and `(E*D, D)`-shaped matrices,
respectively (see @app_conventions for notation) and $phi$ is a
non-linearity#footnote[The `GeLU`
#link("https://pytorch.org/docs/stable/generated/torch.nn.GELU.html")[non-linearity]
is common.]. In code, where we again separate out the last $DR$ layer as we
did in @attn_layer:
```python
from torch import nn

# D (hidden dim) and E (expansion factor) are assumed to be defined elsewhere;
# see the conventions.


class MLP(nn.Module):
    def __init__(
        self,
        hidden_dim=D,
        expansion_factor=E,
        dropout=0.1,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.expansion_factor = expansion_factor
        linear_1 = nn.Linear(hidden_dim, expansion_factor * hidden_dim)
        linear_2 = nn.Linear(expansion_factor * hidden_dim, hidden_dim)
        gelu = nn.GELU()
        self.layers = nn.Sequential(linear_1, gelu, linear_2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        z = self.layers(inputs)
        z = self.dropout(z)
        return z
```
This block requires $2 E D^2$ parameters per layer, only counting the
contribution from weights.
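Combining this with the $4 D^2$ attention weights counted earlier gives a quick estimate of the weight-matrix parameters per block. The helper name `block_weight_count` and the example sizes are ours, chosen only for illustration:

```python
def block_weight_count(hidden_dim, expansion_factor):
    """Weight-matrix parameters in one block, ignoring biases and norms."""
    attention = 4 * hidden_dim**2               # Q, K, V and O projections
    mlp = 2 * expansion_factor * hidden_dim**2  # W^0 and W^1
    return attention + mlp


# Illustrative values, not taken from any particular model.
count = block_weight_count(hidden_dim=1024, expansion_factor=4)
```

With an expansion factor of four, the $MLP$ weights account for two thirds of this total.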
=== Language Model Head <subsubsec_language_model_head>
The layer which converts the `(B, S, D)`-shaped outputs, $z_(b s d)$, to `(B, S, V)`-shaped
predictions over the vocabulary, $y_(b s v)$, is the #strong[Language
Model Head]. It is a linear layer, whose weights are often tied to be
exactly those of the initial embedding layer of
@subsubsec_embedding_and_pe.
=== All Together
<all-together>
It is then relatively straightforward to tie everything together. We
can first create a transformer block, which corresponds to
the schematic function
$
z & arrow z' + MLP ( LN ( z' ) ) , quad z' equiv z + CA ( LN ( z ) ) ,
$
with indices suppressed.
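A minimal sketch of such a block, assuming the pre-$LN$ arrangement with a residual connection around each sub-layer. To keep the snippet runnable on its own, the attention and feed-forward modules are passed in as arguments and replaced by simple `nn.Linear` stand-ins; the class name `TransformerBlock` is ours.

```python
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: each sub-layer reads LN(z) and adds its output back to z."""

    def __init__(self, hidden_dim, attn, mlp):
        super().__init__()
        self.attn_ln = nn.LayerNorm(hidden_dim)
        self.mlp_ln = nn.LayerNorm(hidden_dim)
        self.attn = attn  # any (B, S, D) -> (B, S, D) attention module
        self.mlp = mlp    # any (B, S, D) -> (B, S, D) feed-forward module

    def forward(self, inputs):
        z = inputs + self.attn(self.attn_ln(inputs))  # residual around attention
        z = z + self.mlp(self.mlp_ln(z))              # residual around the MLP
        return z


# Stand-in sub-layers so the sketch runs on its own; in practice these would be
# attention and MLP modules like the ones defined above.
D = 8
block = TransformerBlock(D, attn=nn.Linear(D, D), mlp=nn.Linear(D, D))
out = block(torch.randn(2, 5, D))
```

A full model would then stack a number of these blocks between the embedding layer and the language model head.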
=== The Loss Function
<the-loss-function>
The last necessary component is the loss function. The training loop data is the `(B, K)`-shaped#footnote[`K` is
the block size, the maximum sequence-length for the model. See @app_conventions.] token
inputs ($x_(b s)$) along with their shifted-by-one relatives $y_(b s)$ where `x[:, s + 1] == y[:, s]`. The `(B, K, V)`-shaped outputs
($z_(b s v)$) of the `DecoderOnly` network are treated as the logits which predict the value of the next token,
given the present context:
$
p(x_( b (s+1) )=v| x_( b s ), x_( b (s-1) ), ..., x_( b 0 )) & = SM_( v ) z_( b s v )
$
<eq_transformer_conditional_prob>
and so the model is trained using the usual cross-entropy/maximum-likelihood loss#footnote[Here's an
alternative derivation for why this loss is minimized when the learned distribution perfectly
matches the actual one. Let $p (x)$ be the actual distribution and $q_theta (x)$ be the model.
Taking the continuous case, the expected loss is $cal(L) = - int dif x thin p(x) ln q_(theta)(x)$.
We want to minimize this, subject to the condition that $int dif x thin q_(theta)(x) = 1$. So, we use
the #link("https://en.wikipedia.org/wiki/Calculus_of_variations")[calculus of
variations] on the loss with a Lagrange multiplier: $cal(L)' = cal(L) + lambda int dif x
thin q_(theta)(x)$. Solving $(delta cal(L)') / (delta q_(theta)(x)) = 0$ yields $q_theta (x)
= p (x) \/ lambda$, and the normalization condition then fixes $lambda = 1$, so that $q_theta (x) = p (x)$.
This seems more straightforward and general than the usual argument via the
KL-divergence and Jensen's inequality.]
$
cal(L) & = -1 / (B K) sum_( b,s ) ln p(x_( b (s+1) )=y_( b s )| x_( b s ), x_( b (s-1) ),
..., x_( b 0 )) \
& = - 1 / ( B K ) sum_( b,s ) ln SM_( v ) z_( b s v)|_( v=y_( b s ) ) .
$
Note that the losses for all possible context lengths
are included in the sum, equally weighted#footnote[In Natural Language
Processing (NLP), the perplexity is often reported instead of the loss. It is
just the exponential of the loss, i.e. the inverse of the geometric mean of the gold-answer
probabilities:
$"perplexity" = e^( cal(L) ) = (product_( b, s )p(x _( b
(s+1) )= y_( b s )| x _( b s ), x _( b (s-1) ), ..., x _( b 0 )) ) ^( -1 /( B K ) )$.].
In `torch` code, the loss computation might look like the following (using fake data):
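```python
import torch
import torch.nn.functional as F

# DecoderOnly and the constants A, K, E, D, L, V, B are assumed to be defined elsewhere.
model = DecoderOnly(
    num_attn_heads=A,
    block_size=K,
    dropout=0.1,
    expansion_factor=E,
    hidden_dim=D,
    num_layers=L,
    vocab_size=V,
)
# Fake data: one extra token per row, so that inputs and targets are shifted by one.
tokens = torch.randint(model.vocab_size, size=(B, model.block_size + 1))
inputs, targets = tokens[:, :-1], tokens[:, 1:]
outputs = model(inputs)
# Flatten the batch and sequence dimensions, as F.cross_entropy expects (N, V) logits.
outputs_flat, targets_flat = outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1)
loss = F.cross_entropy(outputs_flat, targets_flat)
```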
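To make the connection to the formula explicit, the following self-contained sketch replaces the model outputs with random logits (an assumption made purely so that the snippet runs without a model). It checks that `F.cross_entropy` matches the manual average of negative log-probabilities, and computes the perplexity from the loss:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, K, V = 2, 5, 7  # illustrative sizes

logits = torch.randn(B, K, V)            # stand-in for the model outputs z_{bsv}
targets = torch.randint(V, size=(B, K))  # stand-in for the shifted tokens

# Library version, flattening the batch and sequence dimensions.
loss = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))

# Manual version: pick out the log-probability of each target and average.
log_probs = logits.log_softmax(dim=-1)
picked = log_probs.gather(dim=-1, index=targets[..., None]).squeeze(-1)
manual_loss = -picked.mean()

perplexity = math.exp(loss.item())
```

The default `reduction="mean"` in `F.cross_entropy` supplies the $1 \/ (B K)$ factor.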
== Architecture and Algorithm Variants
<architecture-and-algorithm-variants>
There are, of course, many variants on the basic architecture. Some
particularly important ones are summarized here.
=== RMS Norm <subsec_rms_norm>
RMS Norm @zhang2019rootmeansquarelayer is a minimal layer norm (@layer_norm) which skips both the
mean subtraction and the bias, and only divides by the root-mean-square of the activations:
$
z_( b s d ) & <- ( (z_( b s d ) ) / sqrt( epsilon + 1 / D sum_d z_( b s d )^( 2 ) ) ) gamma_d
equiv mono("RMSNorm")_d z_( b s d ) .
$
Like in the layer norm case, the weight $gamma_( d )$ is again redundant if immediately followed by
a linear layer.
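A minimal sketch of this operation as a module. The default value of `eps` is an arbitrary choice on our part, and the class is written out by hand rather than taken from a library:

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_dim))  # gamma_d

    def forward(self, inputs):
        # Mean of the squared activations over the hidden dimension.
        mean_sq = inputs.pow(2).mean(dim=-1, keepdim=True)
        return inputs * torch.rsqrt(mean_sq + self.eps) * self.weight


torch.manual_seed(0)
norm = RMSNorm(16)
z = norm(torch.randn(2, 3, 16))
rms = z.pow(2).mean(dim=-1).sqrt()  # close to 1 at initialization, since gamma = 1
```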
=== GLU Variants<subsec_glu_variants>
In @shazeer2020gluvariantsimprovetransformer, Shazeer advocated for
replacing the usual linear-then-activation function pattern,
$ z_(d') & = phi (W_(d' d) x_d) $ with
$ z_(d') & = V_(d' e) x_e phi (W_(d' d) x_d) med . $ That is, we
perform another linear operation on the original input and multiply it
elementwise against the usual activation function output. Biases can also be
included. This construction is typically called "$phi$GLU", where
$phi$ is the name of the activation function: ReGLU, SwiGLU/SiGLU
($phi = x sigma (x)$, used in the LLaMA models), etc.
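As a sketch, here is the SwiGLU case with biases omitted. The attribute names `gate` and `value` are ours, and `F.silu` implements $x sigma (x)$:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """z = (V x) * silu(W x), with silu(x) = x * sigmoid(x)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.gate = nn.Linear(in_dim, out_dim, bias=False)   # W
        self.value = nn.Linear(in_dim, out_dim, bias=False)  # V

    def forward(self, inputs):
        return self.value(inputs) * F.silu(self.gate(inputs))


torch.manual_seed(0)
layer = SwiGLU(8, 32)
x = torch.randn(2, 5, 8)
out = layer(x)
```

Note that this layer carries two weight matrices where the plain pattern has one, so parameter counts are not directly comparable at equal widths.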
=== Multi-Query Attention <subsec_multi_query_attn>
In @shazeer2019fast, the $A$ different key and value matrices are
replaced by a single matrix each, while $A$ different query-heads
remain. The mechanisms are otherwise unchanged: where there were
previously distinct key and value tensors used across different heads,
we just use the same tensors everywhere. This is #strong[Multi-Query
Attention] (MQA).
The primary reason for multi-query attention is that it vastly reduces
the size of the kv-cache (see @sec_kv_cache) during inference time,
decreasing the memory-burden of the cache by a factor of $A$. This
strategy also reduces activation memory during training, but that is
more of a side-effect.
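The factor of $A$ can be checked with a back-of-the-envelope count. The sketch assumes the cache holds one key and one value vector of size $D \/ A$ per key/value head, per position, per layer; the helper name and the configuration are illustrative only.

```python
def kv_cache_elements(num_kv_heads, head_dim, seq_len, num_layers, batch_size=1):
    """Number of cached key and value elements (2 tensors per layer)."""
    return 2 * batch_size * num_layers * seq_len * num_kv_heads * head_dim


# Illustrative configuration, not taken from any particular model.
A, head_dim, seq_len, num_layers = 32, 128, 4096, 32
mha = kv_cache_elements(A, head_dim, seq_len, num_layers)  # one k/v head per query head
mqa = kv_cache_elements(1, head_dim, seq_len, num_layers)  # a single shared k/v head
```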
=== Grouped Attention <subsec_grouped_attn>
#strong[Grouped Query Attention] (GQA) @ainslie2023gqa is the natural
extension of multi-query-attention to using $1 < G < A$ matrices for key
and value generation. Each of the $G$ different keys gets matched up
with $A \/ G$ heads (nice divisibility assumed)#footnote[Llama-2
@touvron2023llama2 uses GQA with $G = 8$, seemingly chosen so that each
group can be sharded and put on its own GPU within a standard 8-GPU
node.].
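A sketch of the head-sharing pattern, using random tensors in place of the projected queries, keys, and values. The sizes are illustrative, and mapping consecutive query heads to the same key/value head via `repeat_interleave` is one possible convention rather than the only one:

```python
import math

import torch

torch.manual_seed(0)
B, S, A, G, H = 2, 5, 8, 2, 4  # illustrative sizes; H is the head dimension

q = torch.randn(B, A, S, H)  # A query heads
k = torch.randn(B, G, S, H)  # only G key heads
v = torch.randn(B, G, S, H)  # only G value heads

# Each key/value head is shared by A // G consecutive query heads.
k_shared = k.repeat_interleave(A // G, dim=1)  # (B, A, S, H)
v_shared = v.repeat_interleave(A // G, dim=1)

scores = q @ k_shared.transpose(-2, -1) / math.sqrt(H)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
attn = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
out = attn @ v_shared  # (B, A, S, H)
```

Setting `G = 1` in this sketch recovers multi-query attention, and `G = A` recovers the standard multi-head case.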
=== Parallel $MLP$ and $CA$ Layers
<parallel-and-layers>
Rather than first pass inputs into the $CA$ layer of each block, and then
pass those outputs on to $MLP$ in series,
#link("https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L303")[GPT-J-6B]
instead processes the $LN$ outputs in #emph[parallel]. That is, instead of
something like
$
z arrow.l z + MLP (LN (z + CA (z)))
$
we instead have#footnote[This alternative layer was also used in PaLM
@chowdhery2022palm where it was claimed that this formulation is
$tilde.op 15 %$ faster due to the ability to fuse the $MLP$ and $CA$ matrix
multiplies together (though this is not done in the GPT-J-6B repo
above).]
$ z arrow.l z + MLP (z) + CA (z) med . $
Note that a $LN$ instance is also removed.
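A minimal sketch of the parallel arrangement, under our reading that a single shared $LN$ feeds both branches. As before, simple `nn.Linear` stand-ins replace the real attention and feed-forward modules so that the snippet runs on its own, and the class name `ParallelBlock` is ours.

```python
import torch
from torch import nn


class ParallelBlock(nn.Module):
    """Attention and MLP both read the same normalized input; outputs are summed."""

    def __init__(self, hidden_dim, attn, mlp):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim)  # a single shared LN
        self.attn = attn
        self.mlp = mlp

    def forward(self, inputs):
        normed = self.ln(inputs)
        return inputs + self.attn(normed) + self.mlp(normed)


# Stand-in sub-layers so that the sketch runs on its own.
D = 8
block = ParallelBlock(D, attn=nn.Linear(D, D), mlp=nn.Linear(D, D))
x = torch.randn(2, 5, D)
out = block(x)
```

Because both branches consume the same `normed` tensor, their first matrix multiplies can in principle be fused into one larger multiply.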
=== RoPE Embeddings
<rope-embeddings>
A shortcoming of traditional embeddings
$x_(b s d) arrow.r x_(b s d) + p_(s d)$ is that they do not generalize
very well: a model trained on such embeddings with a maximum sequence
length $K$ will do very poorly when evaluated on longer sequences. RoPE
(Rotary Position Embedding) @su2022roformer and variants thereof can
extend the viable context length by more clever mechanisms with stronger
implicit biases.
RoPE and its variants can be motivated by a few natural conditions.
Given the queries and keys for an input $q_(s h) \, k_(s h)$
(suppressing batch indices), the corresponding attention scores
computation $a_(s s') (q_s \, k_(s'))$ should reasonably satisfy the
below:
+ The attention score should only depend on the position indices
$s \, s'$ through their difference $s - s'$, i.e., through their
relative distance to each other.
+ The score computation should still be efficient, i.e., based on
matrix-multiplies.
+ The operation should preserve the scale of the intermediate
representations and attention scores, in order to avoid issues with
standard normalization.
These conditions suggest a very natural family of solutions: just rotate
the usual queries and keys by a fixed element of $S O (H)$ using a generator
proportional to the position index. That is, replace the $q_(s h) \, k_(s h)$ by
$
q'_( s h )&eq.triple [e^( i s hat(n)dot T ) ]_( h h' ) q_( s h' ) eq.triple R(s)_( h h' ) q_( s h' ) \
k'_( s h )&eq.triple [e^( i s hat(n)dot T ) ]_( h h' ) k_( s h' ) eq.triple R(s)_( h h' ) k_( s h' ) ,
$<eq_rope>
which makes their dot-product
$q'_(s h) k'_(s' h) = q_(s h) R (s' - s)_(h h') k_(s' h')$,
a function of the positions only through $s' - s$, since $R (s)^T R (s') = R (s' - s)$.
Performing the above computation with a dense element of $S O (H)$ is
infeasible, as it would require a new dense matrix-multiply by a unique
$H times H$ matrix at each sequence position#footnote[For one, the
$cal(O) ( S H ^2 )$ memory cost to store the matrices
would be significant. The FLOPs cost is only $2 B S H^2$, the same as
for other matrix multiplies, but because a different matrix is needed
at each position (it's a batched matrix multiply), these FLOPs would be much
more GPU memory-bandwidth intensive.]. In the original RoPE paper, the
rotation $hat(n)$ was chosen such that the matrices are $2 times 2$
block-diagonal with the entries of the form#footnote[If $H$ isn't even,
the vectors are padded by an extra zero.]
$
R (s)_([h : h + 2] [h : h + 2]) & = mat(delim: "(", cos (s theta_h), - sin (s theta_h); sin (s theta_h), cos (s theta_h))
$
where
$
theta_h = 10^(- 4 h \/ H) med .
$
The RoPE memory costs are thus $cal(O) ( K H )$#footnote[A single RoPE buffer can be shared amongst all attention layers and is broadcast
across all heads, amortizing the memory costs.]. The sparsity present in this constrained form of
the RoPE matrices means that @eq_rope can be computed in $cal(O) ( B S H )$ time, rather than
$cal(O) ( B S H ^2 )$, as it would be for a general rotation matrix. See the paper for explicit
expressions.
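As a rough illustration of the block-diagonal form, the pure-Python sketch below rotates consecutive pairs of head dimensions in $cal(O)(H)$ time per vector and checks numerically that the score depends only on the relative position. It applies the same rotation to queries and keys, as in the original RoPE paper. The helper names `rope_angles`, `rope_rotate`, and `score` are made up for this example, and the inputs are arbitrary.

```python
import math

def rope_angles(H, base=10_000.0):
    """One angle per 2x2 block: theta_h = base^(-h / H) for even h, i.e. 10^(-4h/H)."""
    return [base ** (-h / H) for h in range(0, H, 2)]

def rope_rotate(x, s, thetas):
    """Rotate each consecutive pair (x[h], x[h+1]) of one head vector by s * theta_h."""
    out = []
    for pair_idx, theta in enumerate(thetas):
        a, b = x[2 * pair_idx], x[2 * pair_idx + 1]
        c, sn = math.cos(s * theta), math.sin(s * theta)
        out += [c * a - sn * b, sn * a + c * b]
    return out

def score(q, k, s_q, s_k, thetas):
    """Dot product of the query rotated to position s_q with the key rotated to s_k."""
    q_rot, k_rot = rope_rotate(q, s_q, thetas), rope_rotate(k, s_k, thetas)
    return sum(a * b for a, b in zip(q_rot, k_rot))

H = 8
thetas = rope_angles(H)
q = [0.3, -1.2, 0.7, 0.1, -0.5, 0.9, 1.1, -0.4]  # arbitrary example vectors
k = [1.0, 0.2, -0.6, 0.8, 0.4, -0.3, 0.5, 0.7]

# Both position pairs are 4 apart, so the two scores should agree
# up to floating-point noise.
print(score(q, k, 7, 3, thetas), score(q, k, 104, 100, thetas))
```

The rotation also leaves the norm of each vector unchanged, which is the scale-preservation condition listed earlier.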
While not obvious, the 2x2 form of the RoPE rotations above is in fact completely
general#footnote[Thanks to Davis Wertheimer for teaching me these facts.], after accounting for the
arbitrariness inherent in the key and query projections. A general $S O(N)$ rotation can be
expressed in the form $R = C dot B dot C^T$ where $C in S O (N)$ and $B$ is block diagonal with 2x2
or 1x1 blocks. This can be demonstrated from the real Schur decomposition specialized to $S O (N)$
elements or through the
#link("https://leimao.github.io/blog/Matrix-Block-Hiagonalization-Theorem/")[block diagonal
decomposition] after using specific properties of $S O (N)$ eigenvectors and
eigenvalues#footnote[Namely, that eigenvalues generically come in pairs of phases $exp(plus.minus i
theta)$, the real and imaginary components of eigenvectors have the same norm, and the real and
imaginary components of eigenvectors belonging to different eigenvalue pairs are all mutually
orthogonal.]. From this decomposition, we are guaranteed that we can write $R(1) = C dot B dot
C^T$, and hence $R(s)=R(1)^s=C dot B^s dot C^T$ satisfies all conditions we require, and the $C$
factors can be absorbed into the key and query projectors without loss of generality.
==== RoPE Implementation Notes
There are two common implementation strategies which (annoyingly) differ from each other in
important ways.
*HuggingFace Style*: In this implementation, cosines and sines of shape `(seqlen, head_dim)` are
provided, $cos_( s h )$ and $sin_( s h )$. In this approach, rather than pairing up consecutive
dimensions $h, h+1$ to perform the 2x2 rotations over, we pair up head dimension indices $h,h +
H/2$. So, the rotated queries, say, are given by
$
q_( s h )\' &= q_( s h ) cos_( s h ) + overline(q)_( s h ) sin_( s h ) , space
overline(q)_( s h ) & eq.triple CAT([ - q_( s[H/2:] ), q_( s[:H/2] )]) space .
$
*Complex Style*: In this implementation, phases $p_( s h ) = exp(i theta _( s h ))$ are provided,
where the $h$ index is `head_dim/2` = $H/2$ dimensional. The query, say, is then reshaped to have a
final two-dimensional axis, $q_( s h ) --> q _( s h t)$, $t in {0, 1}$ where the last two axes are
interpreted as the real and imaginary components of a complex tensor, respectively: $q _( s h t) -->
overline(q)_(s h)$. The RoPE-rotated queries then come from taking the real and imaginary parts of
$overline(q)_(s h) exp(i theta _( s h )) $ and reshaping into a `(seqlen, head_dim)` shaped tensor.
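
The two conventions can be compared directly. The following is a minimal pure-Python sketch, not taken from any library: `rope_hf`, `rope_complex`, and `interleaved_to_halves` are hypothetical names, and a single head vector at a single position `s` is used for simplicity. It suggests that the two styles agree once the head dimensions are permuted from the consecutive-pair layout to the $(h, h + H/2)$ layout.

```python
import cmath
import math

def rope_hf(q, s, thetas):
    """HuggingFace-style: pair index h with h + H/2 via the rotate-half trick."""
    half = len(q) // 2
    cos = [math.cos(s * t) for t in thetas] * 2  # length head_dim, angles repeated twice
    sin = [math.sin(s * t) for t in thetas] * 2
    q_bar = [-x for x in q[half:]] + q[:half]    # CAT([-q[H/2:], q[:H/2]])
    return [x * c + xb * sn for x, xb, c, sn in zip(q, q_bar, cos, sin)]

def rope_complex(q, s, thetas):
    """Complex-style: view consecutive pairs as (real, imag), multiply by exp(i s theta)."""
    out = []
    for j, t in enumerate(thetas):
        z = complex(q[2 * j], q[2 * j + 1]) * cmath.exp(1j * s * t)
        out += [z.real, z.imag]
    return out

def interleaved_to_halves(q):
    """Permutation from the consecutive-pair layout to the (h, h + H/2) layout."""
    return q[0::2] + q[1::2]

H, s = 8, 5
thetas = [10_000.0 ** (-h / H) for h in range(0, H, 2)]
q = [0.3, -1.2, 0.7, 0.1, -0.5, 0.9, 1.1, -0.4]  # arbitrary example vector

a = rope_hf(interleaved_to_halves(q), s, thetas)
b = interleaved_to_halves(rope_complex(q, s, thetas))
# The largest discrepancy should be at the level of floating-point noise.
print(max(abs(x - y) for x, y in zip(a, b)))
```

In practice this means weights trained with one convention need their query and key projection outputs permuted before being used with the other.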
==== YaRN
YaRN @peng2023yarnefficientcontextwindow is a commonly-used RoPE extension variant based on the
following observations. Let $S$ be the original context length that the model was trained on and
$S\'$ be the target length: $S\'>S$.
First, assume that $theta_h= b^( -(floor.l h/2 floor.r)/H ) eq.triple (2 pi)/lambda _( h)$ for some
base#footnote[Typical default is $b=10^( 4 )$, as above.], so that head $h$ (and $h+1$) undergoes a
full rotation when used to describe a relative distance $s= lambda_( h )= 2 pi b^( (floor.l h/2 floor.r)/H )$.
Therefore, higher-indexed heads rotate over a longer token horizon, with the maximal horizon being
$cal(O)( b )$. For context lengths $1<< S << b$, low head-index pairs undergo many rotations, making
the RoPE mapping from relative distances to angle degenerate, while high head-index pairs undergo
less than a full rotation, leading to the intuition that low head-index pairs will be better at
capturing local distance information, while high-index pairs can better capture longer-distance
features.
When scaling up to lengths $S\'$, the most naive strategy is to rescale the overall angle: $theta_(
h ) --> S/(S\')times theta_( h )$. This is referred to as _Positional Interpolation_, or PI, and
ensures that the same maximal rotations are achieved when processing sequences of length up to $S\'$
as were seen when originally only processing up to $S$. The naive PI strategy is empirically found
to lead to degraded short-context performance. An alternative strategy is to instead rescale the
base, $b--> kappa times b$ for some $kappa > 1$. The former strategy is a uniform dilation of the
wavelength for all heads, while the latter #footnote[Called ABF for adjusted base frequency
@xiong2023effectivelongcontextscalingfoundation.] performs a greater dilation for higher head
indices. A choice of $kappa$ that connects the two strategies is to choose this parameter such that
the angle for the highest head index, $theta_( H-2 )$, is the same in both cases. This is called the
NTK-aware scheme.
YaRN aims for a middle ground, with the criteria of preserving the wavelengths of the
short-wavelength heads and linearly scaling the long-wavelength heads so that they undergo the same
maximal degree of rotation over horizon $S\'$ as they previously did over $S$. That is:
- If $lambda_( h ) << S$, do not adjust the wavelength.
- If $lambda_( h ) >= S$, use linear interpolation, as in PI.
Let $r_( h ) eq.triple (S)/lambda_( h ) $ be the number of rotations that head-pair index $h$
undergoes over the original context length. We specify two parameters#footnote[Typical: $alpha,
beta=1,32$.], $alpha < beta$, that demarcate the boundaries of the above regions and a linear
interpolation function $0 <=gamma_( h ) <= 1$ such that the YaRN wavelengths are a weighted sum of
the PI and original wavelengths:
$
lambda\'_( h ) &= (1 - gamma_( h )) times (S\') / S times lambda_( h ) + gamma_( h ) lambda_( h )
$
where high head-index/long-wavelength pairs ($gamma_( h ) --> 0$) get the full linear
interpolation and the low head-index pairs ($gamma_( h ) --> 1$) have unchanged wavelengths. The specific
form of $gamma_( h )$ is:
$
gamma_( h ) (r_( h )) = cases(
0 wide &"if" r_( h ) < alpha,
(r_( h ) - alpha)/ (beta -alpha) &"if" alpha <=r_( h ) <= beta,
1 &"if" r_( h ) > beta,
)
$
or as a function of the head index $h$:
$
gamma_( h ) = cases(
1 wide&"if" floor.l h/2 floor.r < H log_b (S / (2 pi beta)) ,
(r_( h ) - alpha)/ (beta -alpha) &"if" H log_b (S / (2 pi beta)) <= floor.l h/2 floor.r <= H log_b (S / (2 pi alpha)),
0 &"if" floor.l h/2 floor.r > H log_b (S / (2 pi alpha)),
) space .
$
Finally, YaRN suggests scaling#footnote[Equivalent to scaling the keys and queries by $0.1
ln((S\')/S) + 1$.] the temperature by $T --> T / (0.1 ln((S\')/S) +
1)^2$, found by a phenomenological fit.
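
To make the interpolation concrete, here is a small pure-Python sketch of the wavelength rule and the query/key scale factor. It follows the exponent convention used in this section, $theta_h = b^(-floor.l h/2 floor.r / H)$. The function names and the example values of `H`, `S`, and `S_new` are assumptions chosen for illustration, not the reference YaRN implementation.

```python
import math

def yarn_wavelengths(H, S, S_new, base=10_000.0, alpha=1.0, beta=32.0):
    """Return a list of (original, YaRN) wavelengths, one entry per head-dimension pair."""
    out = []
    for pair in range(H // 2):
        # Convention of this section: theta = base^(-pair / H), lambda = 2 pi / theta.
        lam = 2 * math.pi * base ** (pair / H)
        r = S / lam  # number of rotations over the original context length
        gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
        lam_new = (1 - gamma) * (S_new / S) * lam + gamma * lam
        out.append((lam, lam_new))
    return out

def yarn_qk_scale(S, S_new):
    """Factor applied to both queries and keys; equivalent to the temperature rescaling."""
    return 0.1 * math.log(S_new / S) + 1

waves = yarn_wavelengths(H=64, S=512, S_new=2048)
print(waves[0])    # shortest wavelength: many rotations over S, left unchanged
print(waves[-1])   # longest wavelength: under one rotation over S, dilated by S_new / S
print(yarn_qk_scale(512, 2048))
```

With these example values the first pair has $r_h > beta$ and the last pair has $r_h < alpha$, so the two printed pairs show the two extremes, and the intermediate pairs are dilated by a factor between $1$ and $S\'/S$.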
=== Flash Attention <subsec_flash_attention>
Flash Attention @dao2022flashattention@dao2023flashattention2 optimizes
the self attention computation by never materializing the
$cal(O) ( S ^2 )$ attention scores in off-chip
memory. This increases the arithmetic intensity of the computation and
reduces the activation memory required, at the expense of needing
recomputation in the backwards pass.
The central idea is to decompose the attention computation in the
following way. Dropping the batch index, let
$q_(s d) \, k_(s d) \, v_(s d)$ be the queries, keys, and values, and
$z_(s d)$ be the final output. Splitting into attention heads as in
$q_(s d) = q_(s (a h)) arrow.r q_(s a h)$ and similar, the computation
is#footnote[We omit the usual $sqrt(D \/ A)$ normalization factor inside
the Softmax to de-clutter the presentation. Really, this normalization
should just be enforced at the level of the matrices which are used to
generate the queries, keys, and values, anyway.]
$
z_( s a h ) &= SM_( s' ) ( q_( s a h' ) k_( s' a h' ) ) v_( s' a h )
$
which is then concatenated as
$z_(s (a h)) -> z_(s d)$ to get the result. We are omitting the
(very important) causal mask for clarity of presentation. Because each
attention head computation is identical, we also omit the $a$-index
going forward in this section.
The issue is that a naive computation would compute all
$cal(O) ( S ^2 )$ components of the attention scores
$q_(s h') k_(s' h')$ for each attention head and their exponential all
at once, which incurs a penalty of shuttling back and forth
$cal(O) ( S ^2
)$ elements to and from on-chip memory multiple times in order
to get the final $z_(s h)$ outputs (in addition to being potentially
memory expensive). Flash Attention functions by instead computing the
exponentials in stages with fewer memory transfers and never populating
the attention scores or exponentials on off-chip memory.
This works by first chunking all of the inputs along their sequence
dimensions as in:
- $q_(s h) = q_((i r) h) arrow.r q_(i r h)$ where
$i in {0 \, dots.h \, I - 1}$ and $r in {0 \, dots.h \, R - 1}$ with
$S = R I$
- $k_(s h) = k_((j c) h) arrow.r k_(j c h) \, v_(s h) = v_((j c) h) arrow.r v_(j c h)$
where $j in {0 \, dots.h \, J - 1}$ and $c in {0 \, dots.h \, C - 1}$
with $S = J C$
The chunk sizes are determined by memory constraints, as discussed
below. Then, the per-attention-head computation is equivalently written
as
$
z_(i r h) &= SM_(j c) ( q_( i r h' ) k_( j c h' ) ) v_( j c h ) \
&= ( exp ( q_( i r h' ) k_( j c h' ) ) ) / ( sum_( j c )exp ( q_( i r h'' ) k_( j c h'' ) ) ) v_( j c h ) \
&eq.triple (sum_( j ) Z_( i r j h ) ) / ( sum_( j'c ) exp (q_( i r h'' ) k_( j'c h'' )) ) \
&eq.triple (sum_( j ) Z_( i r j h ) ) / (sum_( j' )L_( i r j' )) \
&eq.triple ( Z_( i r h ) ) / (L_( i r ))
$
where we introduced the notation which will be used
in the algorithm below. The algorithm proceeds similarly to how it's
outlined above: we compute in chunks, looping over $i$ and an inner $j$
loop which is used to compute the numerator and denominator
simultaneously.
Ignoring the important causal mask and not tracking the maximum logits
(which we should do for numerical stability), the basic version which
captures the essentials of the algorithm is below. Additional
recomputation is needed for the backwards pass.
#figure(
kind: "algorithm",
supplement: [Algorithm],
caption: [Flash Attention (Naive - Missing causal mask/max tracking.)],
pseudocode-list(booktabs: true)[
+ *For* $i in ...$ #h(1fr) `# Computing outputs z[i, r, h] for all r, h`
  + Initialize off-chip tensor $z_(i r h)$ to zeros
  + Move $q_(i r h)$ on-chip, instantiate temps $Z_(i r h)$ and $L_(i r)$ to zeros on-chip
  + *For* $j in ...$ #h(1fr) `# On-chip compute. r, c indices processed in parallel.`
    + Move $k_(j c h) \, v_(j c h)$ on-chip
    + Update numerator $Z_(i r h) arrow.l Z_(i r h) + exp (q_(i r h') k_(j c h')) v_(j c h)$
    + Update denominator $L_(i r) arrow.l L_(i r) + sum_c exp (q_(i r h') k_(j c h'))$
  + $z_(i r h) arrow.l Z_(i r h) / L_(i r)$ #h(1fr) `# Write result off-chip`
],
) <algo_fa_fwd_basic>
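
The chunked accumulation can be checked against a direct softmax with a few lines of pure Python. This is only a sketch of the arithmetic for a single head, with no causal mask and no max-logit tracking, as in the basic algorithm. It says nothing about the actual on-chip/off-chip data movement, and the function names and random inputs are made up for the example.

```python
import math
import random

def naive_attention(q, k, v):
    """Reference: softmax over all keys at once. q, k, v are lists of S vectors of length H."""
    out = []
    for q_s in q:
        w = [math.exp(sum(a * b for a, b in zip(q_s, k_t))) for k_t in k]
        total = sum(w)
        out.append([sum(w_t * v_t[h] for w_t, v_t in zip(w, v)) / total
                    for h in range(len(v[0]))])
    return out

def chunked_attention(q, k, v, R, C):
    """Same result, accumulating numerator Z and denominator L over key/value chunks."""
    S, H = len(q), len(v[0])
    z = [[0.0] * H for _ in range(S)]
    for i0 in range(0, S, R):                 # outer loop over query chunks
        q_chunk = q[i0:i0 + R]
        Z = [[0.0] * H for _ in q_chunk]      # running numerator
        L = [0.0] * len(q_chunk)              # running denominator
        for j0 in range(0, S, C):             # inner loop over key/value chunks
            k_chunk, v_chunk = k[j0:j0 + C], v[j0:j0 + C]
            for r, q_r in enumerate(q_chunk):
                for k_c, v_c in zip(k_chunk, v_chunk):
                    e = math.exp(sum(a * b for a, b in zip(q_r, k_c)))
                    L[r] += e
                    for h in range(H):
                        Z[r][h] += e * v_c[h]
        for r in range(len(q_chunk)):         # final normalization, "written off-chip"
            z[i0 + r] = [Z[r][h] / L[r] for h in range(H)]
    return z

random.seed(0)
S, H = 12, 4
q, k, v = ([[random.uniform(-1, 1) for _ in range(H)] for _ in range(S)] for _ in range(3))
ref, out = naive_attention(q, k, v), chunked_attention(q, k, v, R=3, C=4)
print(max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb)))
```

The full $S times S$ matrix of exponentials is never held at once in `chunked_attention`: only the running `Z` and `L` for the current query chunk are kept.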
We now analyze the memory transfer costs. As a baseline, vanilla
attention requires $cal(O) ( S ^2 + D S )$ memory transfers per attention head, where the two
terms come from the attention scores and $q \, k \, v$, respectively.
For flash attention, we no longer shuttle the attention scores off-chip,
but $k \, v$ are repeatedly moved back and forth. These transfers form
most of the memory operations in the inner loop above, which access
$cal(O) ( I J C H ) ~
cal(O) ( (H S ^2 )/ R )$ elements over the
lifetime of the algorithm (per attention head). The factor $H \/ R$
determines the memory-access advantage, and this number is bounded by the
on-chip memory size. The on-chip queries, keys, and
values take $cal(O) ( C H + R H )$ memory and the temporaries from attention
scores and exponentials require $cal(O) ( R C )$. If we
have $M$ bytes of on-chip memory, then we have the constraint
$C H + R H + R C lt.tilde M$, and assuming the chunks were
chosen to maximize on-chip memory usage, $H / R tilde.op H^2 / M$. Since
$M tilde.op 10^5$ bytes on 2023 GPUs, this is a small factor for the
typical head dimensions $H tilde.op 64$, as desired.
Flash attention is also a big win for activation memory: a naive
algorithm has a $cal(O) ( A B S
^2 )$ per-layer contribution to activation memory due to
needing to save the attention weights, but these are discarded and
re-computed for flash attention. The only additional memory cost comes
from the $cal(O) ( A B S )$ elements in the $ell_(a b s)$
statistics, which are dominated by the $cal(O) ( B S D )$
costs from needing to save inputs, and hence negligible.
==== The Details <subsubsec_fa_details>
Here we give more detailed descriptions of the flash-attention forwards
and backwards passes.
For the forwards pass, we add in maximum-logits tracking for more
numerically stable exponential computation and the causal mask. The
causal mask $C_(s s') = C_((i r) (j c))$ is zero if $s gt.eq s'$ and
$- infinity$ otherwise. The algorithm is as below.
#figure(
kind: "algorithm",
supplement: [Algorithm],
caption: [Flash Attention Forward Pass],
pseudocode-list(booktabs: true)[
+ *For* $i in ...$ #h(1fr) `# Computing outputs z[i, r, h] for all r, h`
+ Initialize off-chip tensors $z _( i r h ), ell _( i r )$ to zeros
+ Move $q _( i r h )$ on-chip, instantiate temp $Z _( i r h )$ to zeros and $M ^"new" _( i r ), M ^"old" _( i r )$ to $-infinity $ on-chip
+ *For* $j in ...$ #h(1fr) `# On-chip compute. r, c indices processed in parallel`
+ Move $k_( j c h ),v _( j c h )$ on-chip
+ $S_( i r j c ) <- q_( i r h' ) k_( j c h' ) + C_( i r j c )$ #h(1fr) `# logits + causal mask`