-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
272 lines (221 loc) · 16.1 KB
/
inference.py
File metadata and controls
272 lines (221 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import gc
import html
import os
import platform
import sys

import bleach # For HTML filtering
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, StoppingCriteria, StoppingCriteriaList, AutoConfig
# --- Configuration ---
# System prompt used as the first ("system") turn of every conversation.
# The __main__ guard refuses to start unless it contains "You are NGen3 Assistant".
SYSTEM_PROMPT = """You are NGen3 Assistant, a highly advanced AI and a flagship model of the NGen3 Series — an entirely original family of cutting-edge foundational language models meticulously engineered by TNSA AI. Your designated purpose is to serve as an exceptionally capable, profoundly knowledgeable, and ethically aligned intelligent partner.
Your cognitive architecture grants you extensive, multi-domain mastery and sophisticated reasoning abilities, including but not limited to:
- **Profound Language Comprehension & Generation:** Deep, nuanced understanding of complex linguistic structures, context, intent, and subtext, enabling articulate, coherent, and contextually rich text generation across diverse styles and formalities.
- **Advanced Mathematical & Logical Reasoning:** Rigorous application of mathematical principles, symbolic logic, and formal reasoning to solve complex problems, derive proofs, and analyse abstract systems.
- **Comprehensive Scientific & Technical Acumen:** In-depth, current knowledge across a wide spectrum of scientific disciplines (physics, biology, chemistry, neuroscience, etc.) and technical fields (engineering, computer science, data science, etc.).
- **Expert-Level Coding & Algorithmic Design:** Proficiency in multiple programming languages (Python, C++, Java, JavaScript, Rust, etc.), with the ability to design, implement, debug, and optimise complex algorithms and software solutions, adhering to best practices.
- **Sophisticated Analytical & Problem-Solving Skills:** Capacity for dissecting intricate problems, identifying core components, formulating multi-step solutions, and evaluating outcomes with critical insight.
- **Human-Centric, Empathetic Interaction:** Ability to engage in clear, precise, and contextually aware dialogue, demonstrating an understanding of human communication nuances and adapting interaction style appropriately.
- **Strategic Task Planning & Simulation:** Capability to understand complex objectives, break them down into manageable sub-tasks, plan execution strategies, and simulate potential outcomes based on available information.
"""
# Directory holding the merged model and tokenizer files. This is a hard-coded
# local Windows path; the script exits at startup if it is not a directory.
MERGED_MODEL_PATH = r"C:/NGen3-7B/0625"
# Upper bound on the context window: tokenizer.model_max_length is capped to
# min(model config limit, this value) when the model is loaded.
PRACTICAL_MAX_CONTEXT_TOKENS = 4096
# max_new_tokens passed to model.generate() for each assistant reply.
MAX_NEW_TOKENS_TO_GENERATE = 1024
# Fraction of the context window given to the prompt. It is further capped so
# that the reply plus a 50-token margin still fits in the window.
PROMPT_TOKEN_BUDGET_FACTOR = 0.75
# Module-level singletons set by load_model_and_tokenizer() and reset to None
# by unload_model_and_tokenizer().
model = None
tokenizer = None
# Stop markers, filled in at load time:
# - as token-id lists, for the StopOnMultiTokenSequences stopping criterion;
# - as raw text, for trimming decoded output in clean_response_text().
stop_token_ids_list = []
stop_sequences_text_list = []
def clear_screen():
    """Clear the terminal using the host platform's native command.

    Runs ``cls`` on Windows and ``clear`` everywhere else.
    """
    command = "clear"
    if platform.system() == "Windows":
        command = "cls"
    os.system(command)
class StopOnMultiTokenSequences(StoppingCriteria):
    """Stopping criterion that fires when the sequence ends with a stop sequence.

    Each stop sequence is a run of token ids. Generation stops as soon as the
    tail of the first sequence in the batch equals any of them. Only batch row 0
    is inspected, so this criterion is meant for single-sequence generation.
    """

    def __init__(self, stop_sequences_ids: list, device: str = "cpu"):
        """Store the stop sequences as long tensors on ``device``.

        Args:
            stop_sequences_ids: Iterable of stop sequences. Each is a list of token
                ids or a 1-D tensor of token ids. Empty sequences are ignored
                because they carry no signal.
            device: Device (string or ``torch.device``) on which to keep the stop
                tensors. ``__call__`` still moves them to the input's device if
                the two differ.
        """
        super().__init__()
        # torch.as_tensor accepts both python lists and existing tensors. Unlike
        # torch.tensor, it neither warns nor copies when handed a tensor that
        # already has the requested dtype/device.
        candidate_tensors = (
            torch.as_tensor(seq_ids, dtype=torch.long, device=device)
            for seq_ids in stop_sequences_ids
        )
        self.stop_sequences_ids_on_device = [t for t in candidate_tensors if t.numel() > 0]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if row 0 of ``input_ids`` ends with any stop sequence."""
        generated = input_ids[0]
        generated_length = generated.shape[-1]
        for stop_ids_tensor in self.stop_sequences_ids_on_device:
            stop_length = stop_ids_tensor.shape[-1]
            if generated_length < stop_length:
                continue
            # torch.equal raises on tensors from different devices. .to() is a
            # no-op when the devices already match.
            stop_ids_tensor = stop_ids_tensor.to(generated.device)
            if torch.equal(generated[-stop_length:], stop_ids_tensor):
                return True
        return False
def load_model_and_tokenizer():
    """Load the merged model and tokenizer into the module globals.

    Idempotent: returns True immediately if both are already loaded.

    On success, this function:
      - sets the globals ``model``, ``tokenizer``, ``stop_sequences_text_list``
        and ``stop_token_ids_list``;
      - caps ``tokenizer.model_max_length`` at
        ``min(model limit, PRACTICAL_MAX_CONTEXT_TOKENS)``;
      - puts the model in eval mode.

    Returns:
        True if the model and tokenizer are loaded, False otherwise. Errors are
        printed (with a traceback) rather than raised.
    """
    global model, tokenizer, stop_token_ids_list, stop_sequences_text_list
    if model is not None and tokenizer is not None:
        return True
    if not os.path.isdir(MERGED_MODEL_PATH):
        print(f"ERROR: Merged model directory not found at '{MERGED_MODEL_PATH}'.")
        return False
    print(f"Loading model and tokenizer from '{MERGED_MODEL_PATH}'...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Attempting to use device: {device}")
        if device == "cuda":
            print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
            if hasattr(torch.cuda, 'mem_get_info'):
                print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
                print(f"Available VRAM before load: {torch.cuda.mem_get_info()[0] / (1024**3):.2f} GB")

        tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH, trust_remote_code=True)
        original_tokenizer_max_len_config = tokenizer.model_max_length  # Before we change it

        # Prefer the model config's positional limit; otherwise use the tokenizer's.
        model_config_max_len = None
        try:
            model_config = AutoConfig.from_pretrained(MERGED_MODEL_PATH, trust_remote_code=True)
            model_config_max_len = model_config.max_position_embeddings
        except Exception as e_conf:
            print(f"Warning: Could not load model config to get max_position_embeddings: {e_conf}")
        # A config can load successfully yet carry None (or 0) here. Passing that
        # to min() below would raise TypeError and abort the whole load.
        if not (isinstance(model_config_max_len, int) and model_config_max_len > 0):
            model_config_max_len = tokenizer.model_max_length  # Fallback

        effective_max_len = min(model_config_max_len, PRACTICAL_MAX_CONTEXT_TOKENS)
        tokenizer.model_max_length = effective_max_len
        print(f"Tokenizer loaded. Tokenizer original model_max_length from its config: {original_tokenizer_max_len_config}.")
        print(f"Model config (e.g. max_position_embeddings): {model_config_max_len}.")
        print(f"Script effective max context (tokenizer.model_max_length now set to): {tokenizer.model_max_length}.")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        print(f"Tokenizer EOS ID: {tokenizer.eos_token_id}, PAD ID: {tokenizer.pad_token_id}, BOS ID: {tokenizer.bos_token_id}")

        # Chat-role markers that indicate the model has started a new turn.
        # NOTE(review): encoding a marker in isolation may tokenize differently
        # than the same text mid-sequence (e.g. leading-space tokens). If so, the
        # id-based criterion can miss it. clean_response_text also trims on the
        # raw text as a backstop — TODO confirm with this tokenizer.
        stop_sequences_text_list = ["<|user|>", "\n<|user|>", "</s><|user|>", "</s>\n<|user|>", "<|assistant|>", "\n<|assistant|>"]
        stop_token_ids_list = []
        for seq_text in stop_sequences_text_list:
            ids = tokenizer.encode(seq_text, add_special_tokens=False)
            if ids:
                stop_token_ids_list.append(ids)

        # fp32 on CPU; on CUDA, bf16 when supported, else fp16.
        if device == "cpu":
            model_dtype = torch.float32
        elif torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
        else:
            model_dtype = torch.float16
        print(f"Loading model with dtype: {model_dtype}...")
        model = AutoModelForCausalLM.from_pretrained(
            MERGED_MODEL_PATH, device_map=device, torch_dtype=model_dtype, trust_remote_code=True,
            attn_implementation="sdpa" if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else None
        )
        model.eval()
        if device == "cuda" and hasattr(torch.cuda, 'mem_get_info'):
            print(f"Available VRAM after load: {torch.cuda.mem_get_info()[0] / (1024**3):.2f} GB")
        print("Model and tokenizer loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading model/tokenizer: {e}")
        import traceback
        traceback.print_exc()
        # Leave no half-initialised state behind for a later retry.
        model, tokenizer = None, None
        stop_token_ids_list = []
        return False
def unload_model_and_tokenizer():
    """Drop the global model/tokenizer references and free cached GPU memory.

    Safe to call more than once: it does nothing when both globals are None.
    """
    global model, tokenizer
    # Identity checks rather than truthiness. Tokenizers define __len__, so
    # bool(tokenizer) means "vocabulary is non-empty", not "is loaded".
    if model is None and tokenizer is None:
        return
    print("\nUnloading model and tokenizer...")
    # Rebinding the module globals to None releases our references so the
    # objects can be collected.
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available() and hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, 'mem_get_info'):
            print(f"Available VRAM after unload: {torch.cuda.mem_get_info()[0] / (1024**3):.2f} GB")
    print("Model and tokenizer unloaded.")
def format_chat_turn(role: str, text: str) -> str:
    """Render a single conversation turn in the model's chat markup.

    The result is ``<|role|>``, a newline, the turn text, then ``</s>`` and a
    trailing newline.
    """
    template = "<|{}|>\n{}</s>\n"
    return template.format(role, text)
def manage_conversation_history(current_conversation_turns: list, max_prompt_token_length: int) -> list:
    """Trim the conversation so its formatted turns fit within a token budget.

    The system turn is always kept. If the history does not start with one, a
    default system turn built from ``SYSTEM_PROMPT`` is substituted.

    Dialogue turns are then taken newest-first. Each turn's size is measured as
    ``len(tokenizer.encode(format_chat_turn(...)))``, and turns are kept while
    the running total stays within the remaining budget. The first turn that
    does not fit ends the scan, so the result is a contiguous, most-recent
    suffix of the dialogue.

    Args:
        current_conversation_turns: List of ``{"role": ..., "content": ...}`` dicts.
        max_prompt_token_length: Token budget for the system turn plus dialogue.
            The assistant cue the caller appends afterwards is not counted here.

    Returns:
        ``[system_turn] + kept_turns`` in chronological order, or ``[]`` if the
        input is empty.
    """
    if not current_conversation_turns:
        return []
    if current_conversation_turns[0].get("role") == "system":
        system_turn = current_conversation_turns[0]
        dialogue_turns = current_conversation_turns[1:]
    else:
        print("CRITICAL WARNING: System prompt is not the first element in history! Prepending.")
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        dialogue_turns = current_conversation_turns[:]

    formatted_system_prompt = format_chat_turn(system_turn["role"], system_turn["content"])
    system_prompt_tokens = tokenizer.encode(formatted_system_prompt, add_special_tokens=False)
    tokens_budget_for_dialogue = max_prompt_token_length - len(system_prompt_tokens)
    if tokens_budget_for_dialogue <= 0:
        print("Warning: Max prompt tokens too small for system prompt. History will be minimal.")
        return [system_turn]

    kept_dialogue_turns = []
    current_dialogue_tokens_count = 0
    for turn in reversed(dialogue_turns):
        formatted_turn = format_chat_turn(turn["role"], turn["content"])
        turn_token_count = len(tokenizer.encode(formatted_turn, add_special_tokens=False))
        if current_dialogue_tokens_count + turn_token_count > tokens_budget_for_dialogue:
            break
        kept_dialogue_turns.append(turn)
        current_dialogue_tokens_count += turn_token_count

    # Previously silent: if the newest turn (normally the user's message) is
    # larger than the whole dialogue budget, the prompt would carry no dialogue.
    if dialogue_turns and not kept_dialogue_turns:
        print("Warning: The most recent message alone exceeds the prompt token budget and was dropped from the prompt.")

    kept_dialogue_turns.reverse()
    return [system_turn, *kept_dialogue_turns]
def clean_response_text(generated_text_raw: str, current_stop_sequences_text: list) -> str:
    """Turn raw decoded model output into plain text for history and display.

    Steps:
      1. Truncate at each stop sequence in turn. Because every split keeps only
         the prefix, the final text ends before the earliest occurrence of any
         stop sequence.
      2. Remove any remaining tokenizer EOS token text.
      3. Strip HTML tags with bleach, then unescape the entities bleach inserts
         for bare ``<``, ``>`` and ``&``. The output goes to a terminal and back
         into the prompt, not to an HTML renderer, so ``a < b`` must stay
         ``a < b`` rather than ``a &lt; b``. Tag-like text (e.g. ``<int>``) is
         still removed by bleach's ``strip=True``.

    Args:
        generated_text_raw: Decoded generation, possibly with special tokens.
        current_stop_sequences_text: Stop markers to truncate at. Empty strings
            are ignored.

    Returns:
        The cleaned text with surrounding whitespace stripped ("" if nothing
        remains).
    """
    cleaned_text = generated_text_raw
    for stop_seq in current_stop_sequences_text:
        # Skip empty markers: str.split("") raises ValueError.
        if stop_seq and stop_seq in cleaned_text:
            cleaned_text = cleaned_text.split(stop_seq, 1)[0]
    if tokenizer is not None and tokenizer.eos_token and tokenizer.eos_token in cleaned_text:
        cleaned_text = cleaned_text.replace(tokenizer.eos_token, "")
    if not cleaned_text:
        return ""
    tag_free_text = bleach.clean(cleaned_text, tags=[], attributes={}, strip=True)
    return html.unescape(tag_free_text).strip()
def _print_chat_banner(title: str) -> None:
    """Print the chat header: title, model/device/context info and command help."""
    print(title)
    print(f"Model: NGen3 Assistant (Merged from {os.path.basename(MERGED_MODEL_PATH)})")
    print(f"Device: {model.device if model is not None else 'N/A'}")
    print(f"Effective Max Context (tokenizer.model_max_length): {tokenizer.model_max_length}")
    print(f"Max New Tokens per Turn: {MAX_NEW_TOKENS_TO_GENERATE}")
    print("Type 'exit', 'quit', or 'bye' to end.")
    print("Type 'clear' or '/clear' to reset history.")
    print("-----------------------------------------")


def chat_loop():
    """Run the interactive terminal chat until the user exits.

    Loads the model first and returns early if that fails. Each turn then:
      - reads a line and handles the exit/clear commands;
      - trims history to the prompt token budget;
      - streams a sampled reply to stdout;
      - stores the cleaned reply in the history.

    The model is unloaded when the loop ends normally.
    """
    if not load_model_and_tokenizer():
        print("Failed to load model. Exiting.")
        return
    clear_screen()
    _print_chat_banner("--- NGen3 Terminal Chat (vFresh - Debugging Blanks v2) ---")
    conversation_history_turns = [{"role": "system", "content": SYSTEM_PROMPT}]

    # Loop-invariant settings, computed once instead of on every turn.
    # The prompt gets a fraction of the context, capped further so that the
    # reply plus a 50-token safety margin still fits in the window.
    prompt_token_budget = int(tokenizer.model_max_length * PROMPT_TOKEN_BUDGET_FACTOR)
    prompt_token_budget = min(prompt_token_budget, tokenizer.model_max_length - (MAX_NEW_TOKENS_TO_GENERATE + 50))
    stopping_criteria = None
    if stop_token_ids_list:
        # Pass the raw id lists; the criterion builds its own tensors on the
        # model's device. The criterion keeps no per-generation state, so it
        # can be reused for every turn.
        stopping_criteria = StoppingCriteriaList([StopOnMultiTokenSequences(stop_token_ids_list, device=model.device)])
    text_stop_sequences = stop_sequences_text_list + ["</s>", "<|endoftext|>"]

    while True:
        try:
            user_input = input("You: ").strip()
        except KeyboardInterrupt:
            print("\nCtrl+C. Exiting.")
            break
        except EOFError:
            print("\nEOF. Exiting.")
            break
        if not user_input:
            continue
        command = user_input.lower()
        if command in ('exit', 'quit', 'bye'):
            print("NGen3: Goodbye!")
            break
        if command in ('clear', '/clear'):
            clear_screen()
            _print_chat_banner("--- NGen3 Terminal Chat (History Reset) ---")
            conversation_history_turns = [{"role": "system", "content": SYSTEM_PROMPT}]
            continue

        conversation_history_turns.append({"role": "user", "content": user_input})
        turns_for_prompt = manage_conversation_history(conversation_history_turns, prompt_token_budget)
        prompt_string_for_model = "".join(format_chat_turn(turn["role"], turn["content"]) for turn in turns_for_prompt)
        prompt_string_for_model += "<|assistant|>\n"  # Cue the model to answer as the assistant
        print(f"\n--- DEBUG: Prompt for Model (Token Budget {prompt_token_budget}) ---")  # Debug
        print(prompt_string_for_model)  # Debug
        print("--- END DEBUG PROMPT ---")  # Debug

        # add_special_tokens=False: format_chat_turn already writes the markup.
        tokenized_prompt = tokenizer(prompt_string_for_model, return_tensors="pt", add_special_tokens=False).to(model.device)
        prompt_length = tokenized_prompt.input_ids.shape[1]
        print(f"DEBUG: Tokenized prompt length: {prompt_length}")  # Debug

        # A fresh streamer per turn: skip_prompt=True echoes only the new text.
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_config = {
            "input_ids": tokenized_prompt.input_ids,
            "attention_mask": tokenized_prompt.attention_mask,
            "streamer": streamer,
            "max_new_tokens": MAX_NEW_TOKENS_TO_GENERATE,
            "eos_token_id": tokenizer.eos_token_id,
            "pad_token_id": tokenizer.eos_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.1,
            "stopping_criteria": stopping_criteria,
        }
        print("NGen3: ", end="", flush=True)
        clean_assistant_response = ""
        with torch.no_grad():
            full_generated_sequence_ids = model.generate(**generation_config)
        if full_generated_sequence_ids is not None and full_generated_sequence_ids.shape[0] > 0:
            # generate() returns prompt + continuation; keep only the new tokens.
            newly_generated_ids = full_generated_sequence_ids[0][prompt_length:]
            raw_assistant_response = tokenizer.decode(newly_generated_ids, skip_special_tokens=False)
            clean_assistant_response = clean_response_text(raw_assistant_response, text_stop_sequences)
        else:
            print("\n(DEBUG: model.generate did not return expected output tensor for history tracking.)")
        print()
        print("-----------------------------------------")
        if clean_assistant_response:
            conversation_history_turns.append({"role": "assistant", "content": clean_assistant_response})
        else:
            print("(NGen3 produced no clean/visible text for history. Last user input is still in context.)")
    unload_model_and_tokenizer()
if __name__ == "__main__":
    # Validate the hard-coded configuration before loading anything heavy.
    # Error paths use sys.exit(1): the site-provided exit() is not guaranteed
    # to exist (e.g. `python -S`, frozen apps) and would report success (status 0).
    if not SYSTEM_PROMPT or "You are NGen3 Assistant" not in SYSTEM_PROMPT:
        print("CRITICAL ERROR: SYSTEM_PROMPT is not correctly defined. Edit the script.")
        sys.exit(1)
    # Using normpath for robust path comparison, especially on Windows
    MERGED_MODEL_PATH = os.path.normpath(MERGED_MODEL_PATH)
    if not os.path.isdir(MERGED_MODEL_PATH):
        print(f"CRITICAL ERROR: MERGED_MODEL_PATH '{MERGED_MODEL_PATH}' is invalid. Please ensure this path is correct and contains your merged model files.")
        sys.exit(1)
    try:
        chat_loop()
    except Exception as e_main:
        print(f"\nAn unexpected error occurred in the main chat loop: {e_main}")
        import traceback
        traceback.print_exc()
        # Signal failure to the shell; the finally block still runs first.
        sys.exit(1)
    finally:
        print("\nExiting application.")
        unload_model_and_tokenizer()