-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
107 lines (84 loc) · 2.98 KB
/
main.py
File metadata and controls
107 lines (84 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import asyncio
import os
import time
import uuid
from datetime import datetime
from pathlib import Path

import socketio
import uvicorn
from fastapi import FastAPI
from langchain import embeddings
from langchain.chains import RetrievalQA
from langchain.text_splitter import MarkdownTextSplitter
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Identity the bot uses on the chat channel: `agent_id` is sent as the
# `senderId` of every bot message and is compared against incoming
# `senderId` values so the bot does not answer its own messages.
agent_id: str = "83809084-1a8f-4532-bb99-af959827067c"
# Display name sent as `senderName` on bot messages.
agent_name: str = "Union"
def load_documents():
    """Load the Markdown knowledge base.

    Runs a ``DirectoryLoader`` over ``"."`` with the glob ``*.md`` and
    ``TextLoader`` as the per-file loader.

    Returns:
        A new list holding every document the loader produced.
    """
    markdown_loader = DirectoryLoader(".", glob="*.md", loader_cls=TextLoader)
    return list(markdown_loader.load())
# --- Retrieval pipeline -----------------------------------------------------
# Everything below runs once at import time: load the Markdown corpus, split
# it into chunks, embed the chunks into an in-memory FAISS index, and wrap the
# index plus the chat model in a RetrievalQA chain.
docs = load_documents()

# Fail fast with a KeyError if the key is not configured, then pass the key
# explicitly to both OpenAI clients so it is actually used.
api_key = os.environ["OPENAI_API_KEY"]

splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(docs)
if not chunks:
    # An empty corpus would otherwise surface as an opaque error from inside
    # the vector-store construction; report the real cause instead.
    raise RuntimeError(
        "No Markdown content found: load_documents() produced no chunks to index."
    )

# NOTE: this rebinds the module-level name `embeddings`, shadowing the
# `embeddings` object imported from `langchain` at the top of the file.
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
vectorstore = FAISS.from_documents(chunks, embeddings)
retriever = vectorstore.as_retriever()
llm = ChatOpenAI(temperature=0, model="gpt-4o", openai_api_key=api_key)
qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

# --- Transport --------------------------------------------------------------
# Socket.IO server mounted in front of a FastAPI app; non-Socket.IO traffic is
# handed to `app` via `other_asgi_app`.
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
app = FastAPI()
asgi_app = socketio.ASGIApp(sio, other_asgi_app=app)

# room id -> set of session ids currently tracked in that room.
ROOMS = {}
@sio.event
async def connect(sid, environ):
    """Log a newly connected Socket.IO client.

    Args:
        sid: Session id of the connecting client.
        environ: Connection environment passed by the server; unused here.
    """
    log_line = f"Client connected: {sid}"
    print(log_line)
@sio.event
async def disconnect(sid):
    """Log a disconnect and drop the session from every tracked room.

    Rooms left with no members are removed from ``ROOMS``.

    Args:
        sid: Session id of the client that disconnected.
    """
    print(f"Client disconnected: {sid}")

    # First pass removes the session and notes which rooms became empty;
    # deletion happens afterwards so the dict is not resized mid-iteration.
    emptied_rooms = []
    for room_key, members in ROOMS.items():
        members.discard(sid)
        if len(members) == 0:
            emptied_rooms.append(room_key)

    for room_key in emptied_rooms:
        del ROOMS[room_key]
# Strong references to in-flight LLM reply tasks. The event loop keeps only
# weak references to tasks, so a task whose handle is dropped may be
# garbage-collected before it finishes.
_BACKGROUND_TASKS = set()


@sio.event
async def message(sid, data):
    """Dispatch an incoming ``message`` event by its ``type`` field.

    ``data`` is expected to be a dict with ``type`` and ``payload`` keys:

    * ``type == 1``: join the room named by ``payload['roomId']``.
    * ``type == 2``: chat message; unless it was sent by the bot itself,
      schedule an LLM reply to ``payload['roomId']``.

    Malformed events (non-dict ``data``/``payload``, or a missing ``roomId``)
    are ignored instead of raising inside the handler.

    Args:
        sid: Session id of the sending client.
        data: Event body received from the client (untrusted).
    """
    print(f"Received message: {data}")

    if not isinstance(data, dict):
        return
    event_type = data.get('type')
    payload = data.get('payload')
    if not isinstance(payload, dict):
        return

    room_id = payload.get('roomId')
    if room_id is None:
        # Without a room there is nothing to join, and emitting a reply with
        # room=None would address every connected client.
        return

    if event_type == 1:
        await sio.save_session(sid, {'roomId': room_id})
        await sio.enter_room(sid, room_id)
        ROOMS.setdefault(room_id, set()).add(sid)
        print(f"[room join] {sid} joined {room_id}")
    elif event_type == 2:
        text = payload.get('message')
        sender_id = payload.get('senderId')
        # Skip the bot's own messages so it never replies to itself.
        if sender_id != agent_id:
            task = asyncio.create_task(handle_llm_response(text, room_id))
            _BACKGROUND_TASKS.add(task)
            task.add_done_callback(_BACKGROUND_TASKS.discard)
async def handle_llm_response(user_input: str, room_id: str):
    """Answer ``user_input`` with the QA chain and broadcast the reply.

    Emits two events to ``room_id``: ``messageBroadcast`` carrying the bot
    message, then ``messageComplete`` carrying a fresh response id.

    Args:
        user_input: Text of the user's chat message.
        room_id: Room the reply is emitted to.
    """
    try:
        print(f"[llm] Running agent for input: {user_input}")
        # qa.run is synchronous; run it in the default executor so the event
        # loop keeps serving other clients while the chain is working.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, qa.run, user_input)
    except Exception as e:
        print(f"[llm] Agent failed: {e!r}")
        # NOTE(review): the exception text is sent to every client in the
        # room; confirm it cannot contain sensitive details.
        result = f"Sorry, something went wrong: {str(e)}"

    # LLM's message
    bot_msg = {
        "senderId": agent_id,
        "senderName": agent_name,
        "text": result,
        "roomId": room_id,
        # Epoch milliseconds. time.time() is timezone-independent, unlike
        # datetime.utcnow().timestamp(), which interprets the naive UTC value
        # as local time.
        "createdAt": int(time.time() * 1000),
        "source": "llm",
    }

    # Broadcast LLM response
    await sio.emit("messageBroadcast", bot_msg, room=room_id)

    # Emit messageComplete
    await sio.emit(
        "messageComplete",
        {
            "roomId": room_id,
            "responseId": str(uuid.uuid4()),
        },
        room=room_id,
    )
if __name__ == "__main__":
    # Script entry point: serve the combined Socket.IO + HTTP ASGI app.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(asgi_app, **server_options)