-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsupabase_memory.py
More file actions
124 lines (104 loc) · 4.43 KB
/
supabase_memory.py
File metadata and controls
124 lines (104 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import logging
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

from langchain.memory import ConversationBufferMemory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from supabase import create_client, Client

from config import Config
# Module-level logger, named after this module so log records can be filtered by source.
logger = logging.getLogger(__name__)
class SupabaseChatMessageHistory(BaseChatMessageHistory):
    """Chat message history persisted in Supabase.

    Messages are read from and written to the ``chat_messages`` table,
    filtered by ``session_id``; a companion row per session is maintained in
    ``chat_sessions``. Every database operation is best-effort: failures are
    logged and swallowed so a storage problem does not interrupt the caller.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Create a history bound to ``session_id``.

        Args:
            session_id: Identifier of the conversation. A random UUID4 string
                is generated when omitted or empty.

        Raises:
            ValueError: If ``Config.SUPABASE_URL`` or ``Config.SUPABASE_KEY``
                is not set.
        """
        if not Config.SUPABASE_URL or not Config.SUPABASE_KEY:
            logger.error("Supabase credentials required for conversation memory")
            raise ValueError("Supabase URL and KEY must be configured for conversation memory")
        self.session_id = session_id or str(uuid.uuid4())
        self.supabase: Client = create_client(Config.SUPABASE_URL, Config.SUPABASE_KEY)

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time as an ISO-8601 string with UTC offset.

        A naive ``datetime.now().isoformat()`` carries no offset, so a
        timezone-aware column would presumably interpret it in the database
        session's zone rather than the host's. Attaching the local offset
        keeps the wall-clock value unchanged while making the instant
        unambiguous.
        """
        return datetime.now().astimezone().isoformat()

    @property
    def messages(self) -> List[BaseMessage]:
        """Load this session's messages, ordered by ``created_at``.

        Rows whose ``message_type`` is neither ``'human'`` nor ``'ai'`` are
        skipped. Returns an empty list if the query fails.
        """
        try:
            result = (
                self.supabase.table('chat_messages')
                .select("*")
                .eq('session_id', self.session_id)
                .order('created_at')
                .execute()
            )
            loaded: List[BaseMessage] = []
            # Guard against a response without a data payload.
            for record in result.data or []:
                kind = record['message_type']
                if kind == 'human':
                    loaded.append(HumanMessage(content=record['content']))
                elif kind == 'ai':
                    loaded.append(AIMessage(content=record['content']))
            return loaded
        except Exception as e:
            logger.error(f"Error loading messages: {e}")
            return []

    def add_message(self, message: BaseMessage) -> None:
        """Persist one message for this session.

        The session row is created/touched first, then the message is
        inserted. Errors are logged and swallowed.

        Args:
            message: Message to store. Anything that is not a
                ``HumanMessage`` is recorded with ``message_type`` ``'ai'``.
        """
        try:
            self._ensure_session_exists()
            message_type = 'human' if isinstance(message, HumanMessage) else 'ai'
            self.supabase.table('chat_messages').insert({
                'id': str(uuid.uuid4()),
                'session_id': self.session_id,
                'message_type': message_type,
                'content': message.content,
                'created_at': self._timestamp(),
                'metadata': json.dumps(getattr(message, 'additional_kwargs', {})),
            }).execute()
        except Exception as e:
            logger.error(f"Error saving message: {e}")

    def _ensure_session_exists(self):
        """Touch the session row, creating it if it does not exist yet.

        An existing row only has ``updated_at`` refreshed. Upserting the full
        row here would overwrite ``created_at`` and ``metadata`` each time a
        message is added, so the row is inserted only when the update matched
        nothing. Errors are logged and swallowed.
        """
        try:
            now = self._timestamp()
            touched = (
                self.supabase.table('chat_sessions')
                .update({'updated_at': now})
                .eq('id', self.session_id)
                .execute()
            )
            if not touched.data:
                # No row was updated: first use of this session id.
                self.supabase.table('chat_sessions').insert({
                    'id': self.session_id,
                    'created_at': now,
                    'updated_at': now,
                    'metadata': '{}',
                }).execute()
        except Exception as e:
            logger.error(f"Error ensuring session exists: {e}")

    def clear(self) -> None:
        """Delete all ``chat_messages`` rows for this session.

        The ``chat_sessions`` row itself is left in place. Errors are logged
        and swallowed.
        """
        try:
            (
                self.supabase.table('chat_messages')
                .delete()
                .eq('session_id', self.session_id)
                .execute()
            )
        except Exception as e:
            logger.error(f"Error clearing messages: {e}")
def create_supabase_memory(session_id: Optional[str] = None) -> Tuple[ConversationBufferMemory, str]:
    """Build a conversation memory backed by Supabase.

    Args:
        session_id: Existing session to attach to. When omitted, the
            underlying history generates a new id.

    Returns:
        A ``(memory, session_id)`` tuple: the ``ConversationBufferMemory``
        wired to the Supabase-backed history, and the session id actually in
        use (useful when one was generated).

    Raises:
        ValueError: Propagated from ``SupabaseChatMessageHistory`` when the
            Supabase credentials are not configured.
    """
    chat_history = SupabaseChatMessageHistory(session_id=session_id)
    # Make sure the session row exists before any message is written.
    chat_history._ensure_session_exists()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        chat_memory=chat_history,
    )
    return memory, chat_history.session_id
def get_all_sessions(limit: int = 20) -> List[Dict[str, Any]]:
    """List the most recently updated chat sessions.

    Args:
        limit: Maximum number of sessions to return. Defaults to 20, the
            value that was previously hard-coded.

    Returns:
        A list of dicts with keys ``session_id``, ``created_at`` and
        ``updated_at``, ordered by ``updated_at`` descending. An empty list
        is returned when Supabase credentials are missing or the query fails
        (the error is logged).
    """
    try:
        if not Config.SUPABASE_URL or not Config.SUPABASE_KEY:
            return []
        supabase = create_client(Config.SUPABASE_URL, Config.SUPABASE_KEY)
        result = (
            supabase.table('chat_sessions')
            .select("id, created_at, updated_at")
            .order('updated_at', desc=True)
            .limit(limit)
            .execute()
        )
        return [
            {
                "session_id": record['id'],
                "created_at": record['created_at'],
                "updated_at": record['updated_at'],
            }
            for record in result.data
        ]
    except Exception as e:
        logger.error(f"Error getting sessions: {e}")
        return []
def clear_session(session_id: str):
    """Remove every stored message belonging to ``session_id``.

    Any exception raised while building the history or clearing it is
    logged; nothing is propagated to the caller.
    """
    try:
        SupabaseChatMessageHistory(session_id=session_id).clear()
    except Exception as exc:
        logger.error(f"Error clearing session: {exc}")