Skip to content

Commit 99aeec1

Browse files
committed
feat: use DB connection pool
1 parent 27abf8e commit 99aeec1

2 files changed

Lines changed: 53 additions & 44 deletions

File tree

app/helpers/postgres.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,35 @@
33

44
import psycopg2
55
from fastapi import HTTPException
6-
from psycopg2 import Error
6+
from psycopg2 import Error, pool
77
from psycopg2.extensions import connection, cursor
88

99
from helpers.hasura import untrack_table, track_table
1010
from helpers.timer import Timer
1111
from models import Metadata, BatchRequest, CreateTableResult
1212

13-
conn: connection = None
14-
cur: cursor = None
13+
psql_pool = psycopg2.pool.ThreadedConnectionPool(1, 10,
14+
user=os.environ.get('POSTGRES_USER'),
15+
password=os.environ.get('POSTGRES_PASSWORD'),
16+
host=os.environ.get('POSTGRES_HOST'),
17+
port=os.environ.get('POSTGRES_PORT'),
18+
database=os.environ.get('POSTGRES_DB'))
1519

16-
try:
17-
conn = psycopg2.connect(user=os.environ.get('POSTGRES_USER'),
18-
password=os.environ.get('POSTGRES_PASSWORD'),
19-
host=os.environ.get('POSTGRES_HOST'),
20-
port=os.environ.get('POSTGRES_PORT'),
21-
database=os.environ.get('POSTGRES_DB'))
22-
cur = conn.cursor()
23-
except (Exception, Error) as error:
24-
print("Error while connecting to PostgreSQL", error)
25-
if conn:
26-
conn.close()
27-
if cur:
28-
cur.close()
29-
print("PostgreSQL connection is closed")
30-
exit(1)
20+
21+
# FastAPI dependency
22+
# See https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/
23+
def get_db_conn():
24+
conn = psql_pool.getconn()
25+
try:
26+
yield conn
27+
finally:
28+
psql_pool.putconn(conn)
3129

3230

33-
def execute_up_down(metadata: Metadata):
31+
def shutdown_db():
32+
psql_pool.closeall()
33+
34+
def execute_up_down(cur: cursor, metadata: Metadata):
3435
# Create table with sql_up
3536
try:
3637
cur.execute(metadata.sql_up)
@@ -58,7 +59,7 @@ def execute_up_down(metadata: Metadata):
5859
)
5960

6061

61-
def create_table(metadata: Metadata) -> CreateTableResult:
62+
def create_table(cur: cursor, metadata: Metadata) -> CreateTableResult:
6263
"""
6364
Create table as specified in metadata.
6465
@@ -72,16 +73,16 @@ def create_table(metadata: Metadata) -> CreateTableResult:
7273
# Initialise Tables table if not already
7374
cur.execute(open("app/init.sql", "r").read())
7475

75-
cmd = r"SELECT up, down FROM Tables WHERE table_name = %s"
76+
cmd = r"SELECT up, down FROM tables WHERE table_name = %s"
7677
metadata.table_name = metadata.table_name.lower()
7778
cur.execute(cmd, (metadata.table_name,))
7879
table_sql = cur.fetchone()
7980
if not table_sql:
8081
# Execute create table
81-
execute_up_down(metadata)
82+
execute_up_down(cur, metadata)
8283

8384
# Store metadata
84-
cmd = r"INSERT INTO Tables(table_name, up, down) VALUES (%s, %s, %s)"
85+
cmd = r"INSERT INTO tables(table_name, up, down) VALUES (%s, %s, %s)"
8586
cur.execute(cmd, (metadata.table_name, metadata.sql_up, metadata.sql_down))
8687

8788
return CreateTableResult.CREATED
@@ -90,18 +91,18 @@ def create_table(metadata: Metadata) -> CreateTableResult:
9091

9192
# Re-create
9293
cur.execute(table_sql[1]) # old sql_down
93-
execute_up_down(metadata)
94+
execute_up_down(cur, metadata)
9495

9596
# Store new metadata
96-
cmd = r"UPDATE Tables SET up = %s, down = %s WHERE table_name = %s"
97+
cmd = r"UPDATE tables SET up = %s, down = %s WHERE table_name = %s"
9798
cur.execute(cmd, (metadata.sql_up, metadata.sql_down, metadata.table_name))
9899

99100
return CreateTableResult.UPDATED
100101

101102
return CreateTableResult.NONE
102103

103104

104-
def get_primary_key_columns(table_name: str) -> list[str]:
105+
def get_primary_key_columns(cur: cursor, table_name: str) -> list[str]:
105106
cmd = f"""
106107
SELECT c.column_name
107108
FROM information_schema.columns c
@@ -119,9 +120,9 @@ def get_primary_key_columns(table_name: str) -> list[str]:
119120
return [row[0] for row in cur.fetchall()]
120121

121122

122-
def execute_upsert(metadata: Metadata, payload: list[Any]):
123+
def execute_upsert(cur: cursor, metadata: Metadata, payload: list[Any]):
123124
columns = [f'"{col}"' for col in metadata.columns]
124-
key_columns = [f'"{col}"' for col in get_primary_key_columns(metadata.table_name)]
125+
key_columns = [f'"{col}"' for col in get_primary_key_columns(cur, metadata.table_name)]
125126
non_key_columns = [col for col in columns if col not in key_columns]
126127

127128
cmd = f"""
@@ -135,8 +136,8 @@ def execute_upsert(metadata: Metadata, payload: list[Any]):
135136
cur.executemany(cmd, values)
136137

137138

138-
def execute_delete(metadata: Metadata, payload: list[Any]):
139-
key_columns = get_primary_key_columns(metadata.table_name)
139+
def execute_delete(cur: cursor, metadata: Metadata, payload: list[Any]):
140+
key_columns = get_primary_key_columns(cur, metadata.table_name)
140141
quoted_key_columns = [f'"{col}"' for col in key_columns]
141142

142143
cmd = f"""
@@ -148,7 +149,7 @@ def execute_delete(metadata: Metadata, payload: list[Any]):
148149
cur.execute(cmd, (values,))
149150

150151

151-
def do_insert(metadata: Metadata, payload: list[Any]):
152+
def do_insert(cur: cursor, metadata: Metadata, payload: list[Any]):
152153
t = Timer(f"sql_before of {metadata.table_name}").start()
153154
if metadata.sql_before:
154155
try:
@@ -162,10 +163,10 @@ def do_insert(metadata: Metadata, payload: list[Any]):
162163

163164
t = Timer(f"insert of {metadata.table_name}").start()
164165
try:
165-
execute_upsert(metadata, payload)
166+
execute_upsert(cur, metadata, payload)
166167
if metadata.write_mode == 'overwrite':
167168
# Delete rows not in payload
168-
execute_delete(metadata, payload)
169+
execute_delete(cur, metadata, payload)
169170
except Error as e:
170171
raise HTTPException(
171172
status_code=400,
@@ -185,14 +186,16 @@ def do_insert(metadata: Metadata, payload: list[Any]):
185186
t.stop()
186187

187188

188-
def do_batch_insert(requests: list[BatchRequest]):
189+
def do_batch_insert(conn: connection, requests: list[BatchRequest]):
190+
cur = conn.cursor()
191+
189192
create_table_results = {}
190193
for request in requests:
191194
try:
192-
create_table_result = create_table(request.metadata)
195+
create_table_result = create_table(cur, request.metadata)
193196
create_table_results[request.metadata.table_name.lower()] = create_table_result
194197

195-
do_insert(request.metadata, request.payload)
198+
do_insert(cur, request.metadata, request.payload)
196199
except HTTPException as e:
197200
print(e.detail)
198201
conn.rollback()

app/main.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import os
2-
from typing import Any
2+
from contextlib import asynccontextmanager
3+
from typing import Any, Annotated
34

45
import uvicorn
5-
from fastapi import FastAPI, HTTPException, Depends
6+
from fastapi import FastAPI, Depends
67
from fastapi.middleware.cors import CORSMiddleware
8+
from psycopg2.extensions import connection
79

8-
from helpers.postgres import do_batch_insert
10+
from helpers.postgres import do_batch_insert, get_db_conn, shutdown_db
911
from helpers.auth import validate_api_key
1012
from models import Metadata, BatchRequest
1113

14+
@asynccontextmanager
15+
async def lifespan(_app: FastAPI):
16+
yield
17+
shutdown_db()
1218

13-
app = FastAPI()
19+
app = FastAPI(lifespan=lifespan)
1420
app.add_middleware(
1521
CORSMiddleware,
1622
allow_origins=["http://localhost", "http://scraper"],
@@ -21,14 +27,14 @@
2127

2228

2329
@app.post("/batch_insert", dependencies=[Depends(validate_api_key)])
24-
def batch_insert(requests: list[BatchRequest]):
25-
do_batch_insert(requests)
30+
def batch_insert(requests: list[BatchRequest], conn: Annotated[connection, Depends(get_db_conn)]):
31+
do_batch_insert(conn, requests)
2632
return {}
2733

2834

2935
@app.post("/insert", dependencies=[Depends(validate_api_key)])
30-
def insert(metadata: Metadata, payload: list[Any]):
31-
do_batch_insert([BatchRequest(metadata=metadata, payload=payload)])
36+
def insert(metadata: Metadata, payload: list[Any], conn: Annotated[connection, Depends(get_db_conn)]):
37+
do_batch_insert(conn, [BatchRequest(metadata=metadata, payload=payload)])
3238
return {}
3339

3440

0 commit comments

Comments
 (0)