# ble_generator.py
import json
import math
import datetime
import dataclasses
import asyncio
import struct
from bleak import BleakClient
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
import uvicorn
import torch
from vlm_processor import VLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
CV_BATCH_SIZE = 32
CV_METHOD = "pca_center"
CV_REPETITION_PENALTY = 1.1
CV_TEMPERATURE = 0.8
N_CONTEXT = 60
CV_DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
CV_DEFAULT_LAYERS = list(range(5, 22))
CVEC = "vectors/moon/moon_20241218.gguf"
MIN_CVEC, MAX_CVEC = -1.1, 1.2
DEVICE = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
PROMPT = "I see a "
# BLE constants
IMU_SERVICE_UUID = "eb1d3224-ab67-4114-89db-d12ac0684005"
IMU_DATA_UUID = "963eeca0-d121-458c-b32f-a99c40d8bf19"
DEVICE_ADDRESS = "753E1AA1-3AD1-DEF4-5B4A-CF09F9640206"
# Global state
current_strength = 0.0
generation_active = True
token_queue = asyncio.Queue()
sinwave_mode = True  # True: drive strength with a time-based sine wave; False: read strength from the BLE IMU
app = FastAPI()
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with specific origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# vlm = VLM()
# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
async def get_page():
return FileResponse("static/stream.html")
@app.get("/stream")
async def stream_text(request: Request):
async def event_generator():
try:
while True:
if await request.is_disconnected():
break
# Get next token from queue
token = await token_queue.get()
yield {
"data": json.dumps({
"content": token.content,
"token_id": token.token_id,
"strength": token.strength
})
}
except asyncio.CancelledError:
pass
return EventSourceResponse(event_generator())
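# Each SSE event emitted above is a JSON-encoded Token (defined below), e.g.
#   data: {"content": " moon", "token_id": 11364, "strength": 0.42}
# (illustrative values only)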
@dataclasses.dataclass
class Token:
content: str
token_id: int = 0
strength: float = 0
class Generator:
def __init__(self):
print("Getting camera scene...")
# scene = vlm.get_image_and_description()
start = "I can see " #f"<|start_header_id|>user<|end_header_id|>\n\n You can see {scene}<|eot_id|><|start_header_id|>assistant<|end_header_id|> In this image I can see"
print("Loading LM CVEC MODEL")
self.tokenizer = AutoTokenizer.from_pretrained(CV_DEFAULT_MODEL)
self.tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(CV_DEFAULT_MODEL, torch_dtype=torch.float16).to(DEVICE)
self.model = ControlModel(model, CV_DEFAULT_LAYERS)
print("Loading vector...")
self.vector = ControlVector.import_gguf(CVEC)
self.initial_tokens = self.tokenizer.tokenize(start)
self.tokens = self.initial_tokens.copy()
self.fullstop_token = self.tokenizer.encode(".")
self.step = 0
self.previous_cvec_applied = None
self.max_tokens = 600
def next(self, raw_strength: float):
        # Map raw_strength in [-1, 1] linearly onto [MIN_CVEC, MAX_CVEC].
strength = (raw_strength + 1) * 0.5 * (MAX_CVEC - MIN_CVEC) + MIN_CVEC
vector = self.vector * strength
# if self.previous_cvec_applied is None or vector != self.previous_cvec_applied:
# print(f"\nApplying strength: {strength:.2f}")
self.model.set_control(vector)
self.previous_cvec_applied = vector
context = self.tokenizer.convert_tokens_to_string(self.tokens[-N_CONTEXT:])
model_tokens = self.tokenizer(context, return_tensors="pt").to(self.model.device)
        with torch.no_grad():  # inference only; avoid building an autograd graph each step
            logits = self.model(**model_tokens).logits[0, -1, :]
        # logits[self.tokenizer.eos_token_id] = -10000  # set eos score very low so it isn't selected in softmax
        probs = torch.softmax(logits / CV_TEMPERATURE, dim=-1)  # temperature-scaled sampling distribution
next_token = torch.multinomial(probs, 1)
token_text = self.tokenizer.decode(next_token)
next_token_item = next_token.item()
        # If we hit the end-of-sequence token or the max token count, reset to the initial prompt
if self.step >= self.max_tokens or next_token_item == self.tokenizer.eos_token_id:
print("Resetting tokens")
# scene = vlm.get_image_and_description()
# print(scene)
start = "I can see " #f"<|start_header_id|>user<|end_header_id|>\n\n You can see {scene}<|eot_id|><|start_header_id|>assistant<|end_header_id|>In this image I can see"
self.initial_tokens = self.tokenizer.tokenize(start)
self.tokens = self.initial_tokens.copy()
self.step = 0
else:
            self.tokens.append(self.tokenizer.convert_ids_to_tokens(next_token_item))  # keep tokenizer-internal token strings so convert_tokens_to_string stays consistent
self.step += 1
return Token(content=token_text, token_id=next_token_item, strength=strength)
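# Hypothetical standalone usage sketch (assumes local access to the Llama
# weights and the GGUF control vector listed in the constants above):
#   gen = Generator()
#   print(gen.next(0.0))  # -> Token(content=..., token_id=..., strength=0.05)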
class BLEController:
def __init__(self):
self.client = None
    def parse_imu_data(self, data: bytearray) -> tuple:
        """Parse a 12-byte IMU packet into (pitch, roll, yaw) float32 values."""
        values = struct.unpack('<3f', data)  # '<3f': three little-endian floats (BLE convention)
        return values
def notification_handler(self, sender, data):
"""Handle incoming notifications from the BLE device."""
global current_strength
pitch, roll, yaw = self.parse_imu_data(data)
# Map yaw to smooth circular pattern: -180/180° -> 0, 90° -> 1, -90° -> -1
normalized_yaw = math.sin(math.radians(yaw)) # Convert to radians and apply sine
current_strength = normalized_yaw
# print(f"Current strength: {current_strength:.2f}")
def get_sine_strength() -> float:
"""Calculate sine wave strength based on current second in minute"""
current_second = datetime.datetime.now().second
# Map seconds (0-59) to radians (0-2π) and shift by π/2 to start at -1
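    # One full sweep per minute: second 0 -> -1, 15 -> 0, 30 -> +1, 45 -> 0.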
angle = (current_second / 60) * 2 * math.pi - (math.pi / 2)
return math.sin(angle)
async def run_ble():
global generation_active, current_strength
ble = BLEController()
try:
if sinwave_mode:
while generation_active:
# current_strength = get_sine_strength()
await asyncio.sleep(0.1) # Update every 100ms
else:
print(f"Connecting to BLE device at {DEVICE_ADDRESS}...")
async with BleakClient(DEVICE_ADDRESS) as client:
print("Connected! Reading orientation data...")
ble.client = client
await client.start_notify(IMU_DATA_UUID, ble.notification_handler)
while generation_active:
await asyncio.sleep(0.1)
except Exception as e:
print(f"\nBLE Error: {str(e)}")
generation_active = False
def get_sine_inc(i: int) -> float:
    # Given timestep i, return a raw strength in [-1, 1] on a sine wave with a
    # 200-step period; Generator.next() rescales this onto [MIN_CVEC, MAX_CVEC].
    return math.sin(i / 200 * 2 * math.pi)
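# Example values: i=0 -> 0.0, i=50 -> 1.0, i=100 -> 0.0, i=150 -> -1.0.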
async def run_generator():
global generation_active
generator = Generator()
# print(PROMPT, end='', flush=True)
# Put initial prompt in queue
await token_queue.put(Token(content="I can see ", token_id=0, strength=0))
    i = 0  # integer timestep for the sine sweep
try:
while generation_active:
# token = generator.next(current_strength)
token = generator.next(get_sine_inc(i))
i += 1
await token_queue.put(token)
await asyncio.sleep(0.01)
except Exception as e:
print(f"\nGenerator Error: {str(e)}")
generation_active = False
async def run_server():
config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
server = uvicorn.Server(config)
await server.serve()
async def main():
global generation_active
try:
# Run all processes concurrently
await asyncio.gather(
run_ble(),
run_generator(),
run_server()
)
except KeyboardInterrupt:
print("\nStopping...")
generation_active = False
finally:
generation_active = False
def chat_template_unparse(messages: list[tuple[str, str]]) -> str:
# Convert chat template (role, content) into a string
template = []
for role, content in messages:
template.append(
f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
)
if messages[-1][0] != "assistant":
# prefill assistant prefix
template.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
return "".join(template)
def chat_template_parse(resp: str) -> list[tuple[str, str]]:
# Parse chat template response into list of (role, content) tuples
resp = resp.strip().removeprefix("<|begin_of_text|>")
messages = []
    for part in resp.split("<|start_header_id|>"):
        role_and_content = part.split("<|end_header_id|>", 1)  # maxsplit=1 guards against the tag reappearing in content
        if len(role_and_content) == 1:
            role, content = role_and_content[0], ""
        else:
            role, content = role_and_content
content = content.split("<|eot_id|>")[0]
messages.append((role.strip(), content.strip()))
return messages
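# Round-trip of the example above:
#   chat_template_parse(chat_template_unparse([("user", "Describe the moon")]))
#   -> [("", ""), ("user", "Describe the moon"), ("assistant", "")]
# (the leading empty tuple is the text before the first header tag)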
if __name__ == "__main__":
asyncio.run(main())