-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_singleton.py
More file actions
44 lines (34 loc) · 1.32 KB
/
model_singleton.py
File metadata and controls
44 lines (34 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import os
from model import GPT
from config import GPTConfig
from tokenizer import Tokenizer
from contextlib import nullcontext
# When True, the app could serve canned responses instead of the real model.
DEMO_MODE = False

# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Autocast context for inference. MPS has no autocast support, so it gets a
# no-op context; CUDA and CPU both use bfloat16 mixed precision.
if DEVICE == "mps":
    CTX = nullcontext()
else:
    CTX = torch.autocast(DEVICE, dtype=torch.bfloat16)
def load_model(checkpoint_path="data/best_checkpoint_chat.pt"):
    """Load the chat GPT model from a checkpoint and place it on DEVICE.

    Args:
        checkpoint_path: Path to the saved ``.pt`` checkpoint. Defaults to
            the chat checkpoint this module has always used, so existing
            callers are unaffected.

    Returns:
        The GPT model in eval mode, resident on DEVICE.
    """
    print(f"Loading chat model on {DEVICE}...")
    # Load on CPU first when targeting MPS (the checkpoint may not
    # deserialize directly to the MPS device), then move afterwards.
    load_device = "cpu" if DEVICE == "mps" else DEVICE
    checkpoint = torch.load(checkpoint_path, map_location=load_device)
    # Rebuild the model from the config stored alongside the weights.
    model_config = GPTConfig(**checkpoint["model_args"])
    model = GPT(model_config)
    # torch.compile saves weights under an "_orig_mod." prefix; strip it so
    # the state dict matches the uncompiled module's parameter names.
    state_dict = checkpoint["model"]
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model
def get_tokenizer():
    """Build and return the chat-specific tokenizer."""
    tokenizer_path = "data/tokenizer.model"
    return Tokenizer(tokenizer_path)
# Module-level singletons: the model and tokenizer are created once at import
# time, so every importer of this module shares the same instances.
# NOTE(review): importing this module triggers a checkpoint load (disk I/O).
model = load_model()
tokenizer = get_tokenizer()