Commit 5b062db

Add challenge 74: Layer Normalization (Medium)
Layer normalization normalizes each row of an N×D input independently and is a core operation in transformer/LLM architectures. Unlike batch normalization (which normalizes column-wise, across the batch), it requires per-row reductions that cannot be trivially parallelized: solvers must think carefully about shared-memory reductions, work distribution (one block per row), and the two-pass algorithm (mean, then variance).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 0578315 commit 5b062db
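
A minimal sketch of the strategy the commit message describes: one block per row, per-thread strided partial sums, a shared-memory tree reduction, and two passes (mean, then variance). This is not the repository's reference solution; the block size of 256 and the name layer_norm_kernel are illustrative assumptions, while the solve signature matches the CUDA starter below.

#include <cuda_runtime.h>

#define THREADS 256  // assumed block size; any power of two up to 1024 works here

// One block per row. Each pass accumulates per-thread partial sums over a
// strided slice of the row, then combines them with a shared-memory tree reduction.
__global__ void layer_norm_kernel(const float* input, const float* gamma,
                                  const float* beta, float* output,
                                  int D, float eps) {
    __shared__ float partial[THREADS];
    const float* row = input + (size_t)blockIdx.x * D;
    float* out = output + (size_t)blockIdx.x * D;

    // Pass 1: row mean.
    float sum = 0.0f;
    for (int j = threadIdx.x; j < D; j += blockDim.x)
        sum += row[j];
    partial[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) partial[threadIdx.x] += partial[threadIdx.x + s];
        __syncthreads();
    }
    float mean = partial[0] / D;
    __syncthreads();  // every thread has read partial[0]; safe to reuse the buffer

    // Pass 2: biased variance, matching torch.var(unbiased=False).
    float sq = 0.0f;
    for (int j = threadIdx.x; j < D; j += blockDim.x) {
        float d = row[j] - mean;
        sq += d * d;
    }
    partial[threadIdx.x] = sq;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) partial[threadIdx.x] += partial[threadIdx.x + s];
        __syncthreads();
    }
    float inv_std = rsqrtf(partial[0] / D + eps);

    // Normalize, then apply the per-feature scale and shift.
    for (int j = threadIdx.x; j < D; j += blockDim.x)
        out[j] = gamma[j] * (row[j] - mean) * inv_std + beta[j];
}

extern "C" void solve(const float* input, const float* gamma, const float* beta,
                      float* output, int N, int D, float eps) {
    layer_norm_kernel<<<N, THREADS>>>(input, gamma, beta, output, D, eps);
}

Reading partial[0] on every thread keeps mean and inv_std uniform across the block without an extra broadcast; the __syncthreads() after the first read protects the shared buffer before pass 2 overwrites it.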

8 files changed

Lines changed: 414 additions & 0 deletions


Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
<p>
Implement the layer normalization forward pass for a 2D input tensor of shape [N, D], where N is
the number of samples and D is the feature dimension. Unlike batch normalization (which normalizes
across the batch), layer normalization computes independent statistics for each sample and normalizes
across its feature dimension, then applies per-feature learnable scale (<code>gamma</code>) and
shift (<code>beta</code>) parameters.
</p>

<p>
For each sample i, layer normalization computes:
\[
\begin{align}
\mu_i &= \frac{1}{D} \sum_{j=0}^{D-1} x_{i,j} \\
\sigma_i^2 &= \frac{1}{D} \sum_{j=0}^{D-1} (x_{i,j} - \mu_i)^2 \\
\hat{x}_{i,j} &= \frac{x_{i,j} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} \\
y_{i,j} &= \gamma_j \hat{x}_{i,j} + \beta_j
\end{align}
\]
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
</ul>

<h2>Example 1:</h2>
<p>
Input:<br>
\( \text{input} \) (\(2 \times 4\)):
\[
\begin{bmatrix}
1.0 & 2.0 & 3.0 & 4.0 \\
5.0 & 6.0 & 7.0 & 8.0
\end{bmatrix}
\]
\( \text{gamma} \) (\(4\)):
\[
\begin{bmatrix}
1.0 & 1.0 & 1.0 & 1.0
\end{bmatrix}
\]
\( \text{beta} \) (\(4\)):
\[
\begin{bmatrix}
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
\( \epsilon \) = 1e-5<br><br>
Output:<br>
\( \text{output} \) (\(2 \times 4\)):
\[
\begin{bmatrix}
-1.3416 & -0.4472 & 0.4472 & 1.3416 \\
-1.3416 & -0.4472 & 0.4472 & 1.3416
\end{bmatrix}
\]
</p>

<h2>Example 2:</h2>
<p>
Input:<br>
\( \text{input} \) (\(2 \times 4\)):
\[
\begin{bmatrix}
2.0 & 2.0 & 4.0 & 4.0 \\
1.0 & 3.0 & 3.0 & 5.0
\end{bmatrix}
\]
\( \text{gamma} \) (\(4\)):
\[
\begin{bmatrix}
1.0 & 1.0 & 1.0 & 1.0
\end{bmatrix}
\]
\( \text{beta} \) (\(4\)):
\[
\begin{bmatrix}
0.0 & 0.0 & 0.0 & 0.0
\end{bmatrix}
\]
\( \epsilon \) = 1e-5<br><br>
Output:<br>
\( \text{output} \) (\(2 \times 4\)):
\[
\begin{bmatrix}
-1.0 & -1.0 & 1.0 & 1.0 \\
-1.4142 & 0.0 & 0.0 & 1.4142
\end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>N</code> &le; 65,536</li>
<li>1 &le; <code>D</code> &le; 8,192</li>
<li><code>eps</code> = 1e-5</li>
<li>-100.0 &le; input values &le; 100.0</li>
<li>0.1 &le; gamma values &le; 10.0</li>
<li>-10.0 &le; beta values &le; 10.0</li>
<li>Performance is measured with <code>N</code> = 8,192, <code>D</code> = 4,096</li>
</ul>
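
As a sanity check of the formulas against Example 1, the first row works out by direct arithmetic:

\[
\begin{align}
\mu_0 &= \tfrac{1}{4}(1.0 + 2.0 + 3.0 + 4.0) = 2.5 \\
\sigma_0^2 &= \tfrac{1}{4}\left((-1.5)^2 + (-0.5)^2 + (0.5)^2 + (1.5)^2\right) = 1.25 \\
\hat{x}_{0,0} &= \frac{1.0 - 2.5}{\sqrt{1.25 + 10^{-5}}} \approx -1.3416
\end{align}
\]

With \( \gamma_j = 1.0 \) and \( \beta_j = 0.0 \), \( y_{0,j} = \hat{x}_{0,j} \), reproducing the first output row; the second row has the same deviations from its mean of 6.5, hence identical normalized values.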
Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="Layer Normalization",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        input: torch.Tensor,
        gamma: torch.Tensor,
        beta: torch.Tensor,
        output: torch.Tensor,
        N: int,
        D: int,
        eps: float,
    ):
        assert input.shape == (N, D), f"Expected input.shape=({N}, {D}), got {input.shape}"
        assert output.shape == (N, D), f"Expected output.shape=({N}, {D}), got {output.shape}"
        assert gamma.shape == (D,), f"Expected gamma.shape=({D},), got {gamma.shape}"
        assert beta.shape == (D,), f"Expected beta.shape=({D},), got {beta.shape}"
        assert input.dtype == gamma.dtype == beta.dtype == output.dtype == torch.float32
        assert input.device.type == "cuda"
        assert gamma.device.type == "cuda"
        assert beta.device.type == "cuda"
        assert output.device.type == "cuda"

        mean = input.mean(dim=1, keepdim=True)
        var = input.var(dim=1, keepdim=True, unbiased=False)
        normalized = (input - mean) / torch.sqrt(var + eps)
        output.copy_(gamma * normalized + beta)

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "input": (ctypes.POINTER(ctypes.c_float), "in"),
            "gamma": (ctypes.POINTER(ctypes.c_float), "in"),
            "beta": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "N": (ctypes.c_int, "in"),
            "D": (ctypes.c_int, "in"),
            "eps": (ctypes.c_float, "in"),
        }

    def generate_example_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        N, D = 2, 4
        input = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], device="cuda", dtype=dtype
        )
        gamma = torch.tensor([1.0, 1.0, 1.0, 1.0], device="cuda", dtype=dtype)
        beta = torch.tensor([0.0, 0.0, 0.0, 0.0], device="cuda", dtype=dtype)
        output = torch.empty((N, D), device="cuda", dtype=dtype)
        return {
            "input": input,
            "gamma": gamma,
            "beta": beta,
            "output": output,
            "N": N,
            "D": D,
            "eps": 1e-5,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        dtype = torch.float32
        tests = []

        # single_sample_small_d
        N, D = 1, 4
        tests.append(
            {
                "input": torch.tensor([[1.0, 2.0, 3.0, 4.0]], device="cuda", dtype=dtype),
                "gamma": torch.ones(D, device="cuda", dtype=dtype),
                "beta": torch.zeros(D, device="cuda", dtype=dtype),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # single_sample_single_feature — all same value; var=0, norm output = beta
        N, D = 1, 1
        tests.append(
            {
                "input": torch.tensor([[3.0]], device="cuda", dtype=dtype),
                "gamma": torch.tensor([2.0], device="cuda", dtype=dtype),
                "beta": torch.tensor([1.0], device="cuda", dtype=dtype),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # all_zeros_input
        N, D = 4, 8
        tests.append(
            {
                "input": torch.zeros((N, D), device="cuda", dtype=dtype),
                "gamma": torch.ones(D, device="cuda", dtype=dtype),
                "beta": torch.zeros(D, device="cuda", dtype=dtype),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # negative_numbers
        N, D = 2, 4
        tests.append(
            {
                "input": torch.tensor(
                    [[-1.0, -2.0, -3.0, -4.0], [-5.0, -6.0, -7.0, -8.0]],
                    device="cuda",
                    dtype=dtype,
                ),
                "gamma": torch.ones(D, device="cuda", dtype=dtype),
                "beta": torch.zeros(D, device="cuda", dtype=dtype),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # different_gamma_beta
        N, D = 2, 4
        tests.append(
            {
                "input": torch.tensor(
                    [[0.0, 1.0, 2.0, 3.0], [-3.0, -1.0, 1.0, 3.0]],
                    device="cuda",
                    dtype=dtype,
                ),
                "gamma": torch.tensor([2.0, 0.5, 1.0, 3.0], device="cuda", dtype=dtype),
                "beta": torch.tensor([1.0, -1.0, 0.0, 0.5], device="cuda", dtype=dtype),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # power_of_2_medium
        N, D = 16, 64
        tests.append(
            {
                "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-5.0, 5.0),
                "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
                "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-1.0, 1.0),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # power_of_2_large_d
        N, D = 32, 512
        tests.append(
            {
                "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-10.0, 10.0),
                "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
                "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-2.0, 2.0),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # non_power_of_2
        N, D = 30, 100
        tests.append(
            {
                "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-3.0, 3.0),
                "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.1, 3.0),
                "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-5.0, 5.0),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # non_power_of_2_large
        N, D = 255, 300
        tests.append(
            {
                "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-50.0, 50.0),
                "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
                "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-1.0, 1.0),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        # realistic_transformer_size
        N, D = 1024, 768
        tests.append(
            {
                "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-10.0, 10.0),
                "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
                "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-1.0, 1.0),
                "output": torch.empty((N, D), device="cuda", dtype=dtype),
                "N": N,
                "D": D,
                "eps": 1e-5,
            }
        )

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        N, D = 8192, 4096
        return {
            "input": torch.empty((N, D), device="cuda", dtype=dtype).uniform_(-10.0, 10.0),
            "gamma": torch.empty(D, device="cuda", dtype=dtype).uniform_(0.5, 2.0),
            "beta": torch.empty(D, device="cuda", dtype=dtype).uniform_(-1.0, 1.0),
            "output": torch.empty((N, D), device="cuda", dtype=dtype),
            "N": N,
            "D": D,
            "eps": 1e-5,
        }
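
The reference implementation above computes the variance from the already-centered values, the numerically safe two-pass form. A single-pass kernel can instead fuse both reductions via the identity

\[
\sigma_i^2 = \frac{1}{D} \sum_{j=0}^{D-1} x_{i,j}^2 - \mu_i^2,
\]

accumulating \( \sum_j x_{i,j} \) and \( \sum_j x_{i,j}^2 \) in the same sweep. The subtraction loses precision when \( \mu_i^2 \) is large relative to \( \sigma_i^2 \), which is worth weighing against the 1e-4 tolerance given inputs up to \( \pm 100.0 \).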
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// input, gamma, beta, output are device pointers
extern "C" void solve(const float* input, const float* gamma, const float* beta, float* output,
                      int N, int D, float eps) {}
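
At the measured size (N = 8,192, D = 4,096), the block-wide sums can be tightened beyond a plain shared-memory tree by reducing within each warp through registers, so shared memory holds one partial per warp instead of one per thread. A sketch of such a helper, not part of this commit; the name block_reduce_sum is an assumption:

// Block-wide sum: warp-level shuffles plus one shared slot per warp.
// Assumes blockDim.x is a multiple of 32 and at most 1024.
__device__ float block_reduce_sum(float v) {
    __shared__ float warp_sums[32];
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;

    for (int offset = 16; offset > 0; offset >>= 1)   // intra-warp reduction
        v += __shfl_down_sync(0xffffffff, v, offset);
    if (lane == 0) warp_sums[warp] = v;               // one partial per warp
    __syncthreads();

    // The first warp reduces the per-warp partials; result is valid in thread 0.
    v = (threadIdx.x < blockDim.x / 32) ? warp_sums[lane] : 0.0f;
    if (warp == 0)
        for (int offset = 16; offset > 0; offset >>= 1)
            v += __shfl_down_sync(0xffffffff, v, offset);
    return v;
}

Each reduction in a two-pass kernel then becomes a single call to this helper, with thread 0 writing the result to one shared float that the block reads back after a __syncthreads().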
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import cutlass
import cutlass.cute as cute


# input, gamma, beta, output are tensors on the GPU
@cute.jit
def solve(
    input: cute.Tensor,
    gamma: cute.Tensor,
    beta: cute.Tensor,
    output: cute.Tensor,
    N: cute.Int32,
    D: cute.Int32,
    eps: cute.Float32,
):
    pass
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
import jax
import jax.numpy as jnp


# input, gamma, beta are tensors on GPU
@jax.jit
def solve(
    input: jax.Array, gamma: jax.Array, beta: jax.Array, N: int, D: int, eps: float
) -> jax.Array:
    # return output tensor directly
    pass
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

# input, gamma, beta, output are device pointers
@export
def solve(input: UnsafePointer[Float32], gamma: UnsafePointer[Float32],
          beta: UnsafePointer[Float32], output: UnsafePointer[Float32],
          N: Int32, D: Int32, eps: Float32):
    pass
