-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb.py
More file actions
347 lines (283 loc) · 10.6 KB
/
db.py
File metadata and controls
347 lines (283 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""
Database module for PostgreSQL interactions.
Handles user management, message logging, and database operations.
"""
import asyncio
import logging
from datetime import datetime, timezone
from typing import Optional, Dict, Any
import asyncpg
from config import DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME
# Module-scoped logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
async def get_pool(
    *,
    min_size: int = 1,
    max_size: int = 10,
    command_timeout: float = 60,
) -> asyncpg.Pool:
    """
    Create and return a connection pool to the PostgreSQL database.

    Connection parameters are read from the ``config`` module (DB_HOST,
    DB_PORT, DB_USER, DB_PASSWORD, DB_NAME). Pool sizing and the default
    command timeout keep their previous fixed values unless overridden by
    keyword, so existing ``await get_pool()`` callers are unaffected.

    Args:
        min_size: Number of connections the pool is initialized with.
        max_size: Maximum number of connections the pool may hold.
        command_timeout: Default timeout, in seconds, for commands executed
            on pooled connections.

    Returns:
        asyncpg.Pool: Database connection pool

    Raises:
        Exception: If connection to database fails (logged, then re-raised).
    """
    try:
        pool = await asyncpg.create_pool(
            host=DB_HOST,
            port=DB_PORT,
            user=DB_USER,
            password=DB_PASSWORD,
            database=DB_NAME,
            min_size=min_size,
            max_size=max_size,
            command_timeout=command_timeout,
        )
        logger.info(f"Successfully created database connection pool to {DB_HOST}:{DB_PORT}")
        return pool
    except Exception as e:
        logger.error(f"Failed to create database connection pool: {e}")
        raise
async def init_db(pool: asyncpg.Pool) -> None:
    """
    Create the ``users`` and ``messages`` tables and their indexes if missing.

    Every statement uses ``IF NOT EXISTS``, so running this against an
    already initialized database creates nothing new. All DDL runs inside a
    single transaction. PostgreSQL DDL is transactional, so a failure
    part-way through rolls back what this call created instead of leaving a
    partially initialized schema.

    Args:
        pool: Database connection pool

    Raises:
        Exception: If table or index creation fails (logged, then re-raised).
    """
    logger.info("Initializing database tables...")

    # SQL for creating users table
    create_users_table = """
        CREATE TABLE IF NOT EXISTS users (
            user_id BIGINT PRIMARY KEY,
            username TEXT,
            first_name TEXT,
            join_date TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
            is_approved BOOLEAN DEFAULT FALSE,
            spam_reports INTEGER DEFAULT 0
        );
    """

    # SQL for creating messages table; it references users(user_id), so it
    # must be created after the users table.
    create_messages_table = """
        CREATE TABLE IF NOT EXISTS messages (
            message_id SERIAL PRIMARY KEY,
            user_id BIGINT REFERENCES users(user_id) ON DELETE CASCADE,
            message_text TEXT,
            is_spam BOOLEAN DEFAULT FALSE,
            timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW()
        );
    """

    # SQL for creating indexes for better performance
    create_indexes = [
        "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);",
        "CREATE INDEX IF NOT EXISTS idx_users_is_approved ON users(is_approved);",
        "CREATE INDEX IF NOT EXISTS idx_messages_user_id ON messages(user_id);",
        "CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp);",
        "CREATE INDEX IF NOT EXISTS idx_messages_is_spam ON messages(is_spam);"
    ]

    try:
        async with pool.acquire() as connection:
            # One transaction for all DDL: either the whole schema is
            # created, or nothing from this call is kept.
            async with connection.transaction():
                # Create tables
                await connection.execute(create_users_table)
                await connection.execute(create_messages_table)

                # Create indexes
                for index_sql in create_indexes:
                    await connection.execute(index_sql)

            logger.info("Database tables and indexes created successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {e}")
        raise
async def add_new_user(
    pool: asyncpg.Pool,
    user_id: int,
    username: Optional[str],
    first_name: Optional[str]
) -> bool:
    """
    Add a new user to the users table unless one with this ID already exists.

    The existence check and the insert happen in a single atomic statement
    (``INSERT ... ON CONFLICT (user_id) DO NOTHING``). Concurrent calls for
    the same user therefore cannot both pass a check and then collide on the
    primary key.

    Args:
        pool: Database connection pool
        user_id: Telegram user ID
        username: Telegram username (can be None)
        first_name: User's first name (can be None)

    Returns:
        bool: True if user was added successfully, False if user already exists

    Raises:
        Exception: If database operation fails (logged, then re-raised).
    """
    try:
        async with pool.acquire() as connection:
            # RETURNING produces a row only when the INSERT actually happened;
            # on conflict nothing is returned and fetchval yields None.
            inserted_id = await connection.fetchval(
                """
                INSERT INTO users (user_id, username, first_name, join_date)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (user_id) DO NOTHING
                RETURNING user_id
                """,
                user_id, username, first_name, datetime.now(timezone.utc)
            )

            if inserted_id is None:
                logger.info(f"User {user_id} already exists in database")
                return False

            logger.info(f"Successfully added new user {user_id} ({username}) to database")
            return True
    except Exception as e:
        logger.error(f"Failed to add user {user_id}: {e}")
        raise
async def get_user(pool: asyncpg.Pool, user_id: int) -> Optional[Dict[str, Any]]:
    """
    Look up a single user row by Telegram user ID.

    Args:
        pool: Database connection pool
        user_id: Telegram user ID

    Returns:
        Optional[Dict[str, Any]]: A dict with the keys user_id, username,
        first_name, join_date, is_approved and spam_reports, or None when no
        row matches.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    select_sql = """
        SELECT user_id, username, first_name, join_date, is_approved, spam_reports
        FROM users
        WHERE user_id = $1
    """
    columns = (
        'user_id',
        'username',
        'first_name',
        'join_date',
        'is_approved',
        'spam_reports',
    )
    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(select_sql, user_id)
            # Guard clause: no matching row means "unknown user".
            if not row:
                logger.debug(f"User {user_id} not found in database")
                return None
            result = {name: row[name] for name in columns}
            logger.debug(f"Retrieved user data for {user_id}")
            return result
    except Exception as exc:
        logger.error(f"Failed to get user {user_id}: {exc}")
        raise
async def approve_user(pool: asyncpg.Pool, user_id: int) -> bool:
    """
    Mark a user as approved by setting ``is_approved`` to TRUE.

    Args:
        pool: Database connection pool
        user_id: Telegram user ID

    Returns:
        bool: True when a matching user row was updated, False when no row
        has this user_id.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        async with pool.acquire() as conn:
            command_tag = await conn.execute(
                "UPDATE users SET is_approved = TRUE WHERE user_id = $1",
                user_id
            )
            # The status string ends with the number of rows the UPDATE hit.
            updated_rows = int(command_tag.split()[-1])
            if updated_rows <= 0:
                logger.warning(f"User {user_id} not found for approval")
                return False
            logger.info(f"Successfully approved user {user_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to approve user {user_id}: {exc}")
        raise
async def log_message(
    pool: asyncpg.Pool,
    user_id: int,
    message_text: str,
    is_spam: bool = False
) -> int:
    """
    Insert one message row into the ``messages`` table.

    The row's timestamp is supplied explicitly as the current UTC time.
    The column's NOW() default is not used.

    Args:
        pool: Database connection pool
        user_id: Telegram user ID of the sender
        message_text: Text content of the message
        is_spam: Whether the message was classified as spam

    Returns:
        int: The ``message_id`` generated for the new row.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    insert_sql = """
        INSERT INTO messages (user_id, message_text, is_spam, timestamp)
        VALUES ($1, $2, $3, $4)
        RETURNING message_id
    """
    try:
        async with pool.acquire() as conn:
            new_id = await conn.fetchval(
                insert_sql,
                user_id,
                message_text,
                is_spam,
                datetime.now(timezone.utc),
            )
            logger.debug(f"Logged message {new_id} from user {user_id} (spam: {is_spam})")
            return new_id
    except Exception as exc:
        logger.error(f"Failed to log message from user {user_id}: {exc}")
        raise
async def increment_spam_reports(pool: asyncpg.Pool, user_id: int) -> bool:
    """
    Add one to a user's ``spam_reports`` counter.

    The increment is done in SQL (``spam_reports + 1``), not read-modify-write
    in Python.

    Args:
        pool: Database connection pool
        user_id: Telegram user ID

    Returns:
        bool: True when a matching user row was updated, False when no row
        has this user_id.

    Raises:
        Exception: Any database error, after it has been logged.
    """
    try:
        async with pool.acquire() as conn:
            command_tag = await conn.execute(
                "UPDATE users SET spam_reports = spam_reports + 1 WHERE user_id = $1",
                user_id
            )
            # The status string ends with the number of rows the UPDATE hit.
            updated_rows = int(command_tag.split()[-1])
            if updated_rows <= 0:
                logger.warning(f"User {user_id} not found for spam report increment")
                return False
            logger.info(f"Incremented spam reports for user {user_id}")
            return True
    except Exception as exc:
        logger.error(f"Failed to increment spam reports for user {user_id}: {exc}")
        raise
async def get_user_stats(pool: asyncpg.Pool) -> Dict[str, int]:
    """
    Summarize the users table in a single aggregate query.

    Args:
        pool: Database connection pool

    Returns:
        Dict[str, int]: Counts under the keys ``total_users``,
        ``approved_users``, ``users_with_reports`` (users with at least one
        spam report) and ``pending_approval`` (total minus approved).

    Raises:
        Exception: Any database error, after it has been logged.
    """
    stats_sql = """
        SELECT
            COUNT(*) as total_users,
            COUNT(*) FILTER (WHERE is_approved = TRUE) as approved_users,
            COUNT(*) FILTER (WHERE spam_reports > 0) as users_with_reports
        FROM users
    """
    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(stats_sql)
            total = row['total_users']
            approved = row['approved_users']
            return {
                'total_users': total,
                'approved_users': approved,
                'users_with_reports': row['users_with_reports'],
                # Everyone not counted as approved is treated as pending.
                'pending_approval': total - approved,
            }
    except Exception as exc:
        logger.error(f"Failed to get user stats: {exc}")
        raise