-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample.py
More file actions
32 lines (26 loc) · 885 Bytes
/
example.py
File metadata and controls
32 lines (26 loc) · 885 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from open_olmo.main import OLMoHybridConfig, OLMoHybrid
def main() -> None:
    """Build a small OLMoHybrid model and smoke-test a single forward pass.

    Prints the model's layer pattern and parameter count, feeds one batch of
    random token ids through the model, prints the resulting logits and their
    shape, and asserts that the shape is ``(batch, seq_len, vocab_size)``.
    """
    # Seed torch's global RNG so the random token batch below is repeatable.
    torch.manual_seed(0)

    cfg = OLMoHybridConfig(
        vocab_size=1024,
        d_model=256,
        num_heads=4,
        num_layers=8,
        hybrid_ratio=3,
        max_seq_len=512,
        chunk_size=32,
    )
    model = OLMoHybrid(cfg)
    print(f"Layer pattern : {model.layer_types}")
    print(f"Parameters : {model.num_parameters():,}")

    batch_size, seq_len = 2, 64
    tokens = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))

    # Only the output shape is checked, so skip building an autograd graph.
    with torch.no_grad():
        # The second value returned by the model is not needed here.
        logits, _ = model(tokens)

    print(logits)
    print(logits.shape)
    expected_shape = (batch_size, seq_len, cfg.vocab_size)
    assert logits.shape == expected_shape, logits.shape

    # NOTE(review): the generation check below was already disabled in this
    # script; the reason is not visible here. Re-enable once confirmed working.
    # gen = model.generate(tokens[:, :8], max_new_tokens=16)
    # assert gen.shape[0] == batch_size and gen.shape[1] == 8 + 16, gen.shape
    # print(f"Forward : {logits.shape} ✓")
    # print(f"Generate : {gen.shape} ✓")


if __name__ == "__main__":
    main()