33
44import psycopg2
55from fastapi import HTTPException
6- from psycopg2 import Error
6+ from psycopg2 import Error , pool
77from psycopg2 .extensions import connection , cursor
88
99from helpers .hasura import untrack_table , track_table
1010from helpers .timer import Timer
1111from models import Metadata , BatchRequest , CreateTableResult
1212
13- conn : connection = None
14- cur : cursor = None
13+ psql_pool = psycopg2 .pool .ThreadedConnectionPool (1 , 10 ,
14+ user = os .environ .get ('POSTGRES_USER' ),
15+ password = os .environ .get ('POSTGRES_PASSWORD' ),
16+ host = os .environ .get ('POSTGRES_HOST' ),
17+ port = os .environ .get ('POSTGRES_PORT' ),
18+ database = os .environ .get ('POSTGRES_DB' ))
1519
16- try :
17- conn = psycopg2 .connect (user = os .environ .get ('POSTGRES_USER' ),
18- password = os .environ .get ('POSTGRES_PASSWORD' ),
19- host = os .environ .get ('POSTGRES_HOST' ),
20- port = os .environ .get ('POSTGRES_PORT' ),
21- database = os .environ .get ('POSTGRES_DB' ))
22- cur = conn .cursor ()
23- except (Exception , Error ) as error :
24- print ("Error while connecting to PostgreSQL" , error )
25- if conn :
26- conn .close ()
27- if cur :
28- cur .close ()
29- print ("PostgreSQL connection is closed" )
30- exit (1 )
20+
21+ # FastAPI dependency
22+ # See https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/
23+ def get_db_conn ():
24+ conn = psql_pool .getconn ()
25+ try :
26+ yield conn
27+ finally :
28+ psql_pool .putconn (conn )
3129
3230
33- def execute_up_down (metadata : Metadata ):
31+ def shutdown_db ():
32+ psql_pool .closeall ()
33+
34+ def execute_up_down (cur : cursor , metadata : Metadata ):
3435 # Create table with sql_up
3536 try :
3637 cur .execute (metadata .sql_up )
@@ -58,7 +59,7 @@ def execute_up_down(metadata: Metadata):
5859 )
5960
6061
61- def create_table (metadata : Metadata ) -> CreateTableResult :
62+ def create_table (cur : cursor , metadata : Metadata ) -> CreateTableResult :
6263 """
6364 Create table as specified in metadata.
6465
@@ -72,16 +73,16 @@ def create_table(metadata: Metadata) -> CreateTableResult:
7273 # Initialise Tables table if not already
7374 cur .execute (open ("app/init.sql" , "r" ).read ())
7475
75- cmd = r"SELECT up, down FROM Tables WHERE table_name = %s"
76+ cmd = r"SELECT up, down FROM tables WHERE table_name = %s"
7677 metadata .table_name = metadata .table_name .lower ()
7778 cur .execute (cmd , (metadata .table_name ,))
7879 table_sql = cur .fetchone ()
7980 if not table_sql :
8081 # Execute create table
81- execute_up_down (metadata )
82+ execute_up_down (cur , metadata )
8283
8384 # Store metadata
84- cmd = r"INSERT INTO Tables (table_name, up, down) VALUES (%s, %s, %s)"
85+ cmd = r"INSERT INTO tables (table_name, up, down) VALUES (%s, %s, %s)"
8586 cur .execute (cmd , (metadata .table_name , metadata .sql_up , metadata .sql_down ))
8687
8788 return CreateTableResult .CREATED
@@ -90,18 +91,18 @@ def create_table(metadata: Metadata) -> CreateTableResult:
9091
9192 # Re-create
9293 cur .execute (table_sql [1 ]) # old sql_down
93- execute_up_down (metadata )
94+ execute_up_down (cur , metadata )
9495
9596 # Store new metadata
96- cmd = r"UPDATE Tables SET up = %s, down = %s WHERE table_name = %s"
97+ cmd = r"UPDATE tables SET up = %s, down = %s WHERE table_name = %s"
9798 cur .execute (cmd , (metadata .sql_up , metadata .sql_down , metadata .table_name ))
9899
99100 return CreateTableResult .UPDATED
100101
101102 return CreateTableResult .NONE
102103
103104
104- def get_primary_key_columns (table_name : str ) -> list [str ]:
105+ def get_primary_key_columns (cur : cursor , table_name : str ) -> list [str ]:
105106 cmd = f"""
106107 SELECT c.column_name
107108 FROM information_schema.columns c
@@ -119,9 +120,9 @@ def get_primary_key_columns(table_name: str) -> list[str]:
119120 return [row [0 ] for row in cur .fetchall ()]
120121
121122
122- def execute_upsert (metadata : Metadata , payload : list [Any ]):
123+ def execute_upsert (cur : cursor , metadata : Metadata , payload : list [Any ]):
123124 columns = [f'"{ col } "' for col in metadata .columns ]
124- key_columns = [f'"{ col } "' for col in get_primary_key_columns (metadata .table_name )]
125+ key_columns = [f'"{ col } "' for col in get_primary_key_columns (cur , metadata .table_name )]
125126 non_key_columns = [col for col in columns if col not in key_columns ]
126127
127128 cmd = f"""
@@ -135,8 +136,8 @@ def execute_upsert(metadata: Metadata, payload: list[Any]):
135136 cur .executemany (cmd , values )
136137
137138
138- def execute_delete (metadata : Metadata , payload : list [Any ]):
139- key_columns = get_primary_key_columns (metadata .table_name )
139+ def execute_delete (cur : cursor , metadata : Metadata , payload : list [Any ]):
140+ key_columns = get_primary_key_columns (cur , metadata .table_name )
140141 quoted_key_columns = [f'"{ col } "' for col in key_columns ]
141142
142143 cmd = f"""
@@ -148,7 +149,7 @@ def execute_delete(metadata: Metadata, payload: list[Any]):
148149 cur .execute (cmd , (values ,))
149150
150151
151- def do_insert (metadata : Metadata , payload : list [Any ]):
152+ def do_insert (cur : cursor , metadata : Metadata , payload : list [Any ]):
152153 t = Timer (f"sql_before of { metadata .table_name } " ).start ()
153154 if metadata .sql_before :
154155 try :
@@ -162,10 +163,10 @@ def do_insert(metadata: Metadata, payload: list[Any]):
162163
163164 t = Timer (f"insert of { metadata .table_name } " ).start ()
164165 try :
165- execute_upsert (metadata , payload )
166+ execute_upsert (cur , metadata , payload )
166167 if metadata .write_mode == 'overwrite' :
167168 # Delete rows not in payload
168- execute_delete (metadata , payload )
169+ execute_delete (cur , metadata , payload )
169170 except Error as e :
170171 raise HTTPException (
171172 status_code = 400 ,
@@ -185,14 +186,16 @@ def do_insert(metadata: Metadata, payload: list[Any]):
185186 t .stop ()
186187
187188
188- def do_batch_insert (requests : list [BatchRequest ]):
189+ def do_batch_insert (conn : connection , requests : list [BatchRequest ]):
190+ cur = conn .cursor ()
191+
189192 create_table_results = {}
190193 for request in requests :
191194 try :
192- create_table_result = create_table (request .metadata )
195+ create_table_result = create_table (cur , request .metadata )
193196 create_table_results [request .metadata .table_name .lower ()] = create_table_result
194197
195- do_insert (request .metadata , request .payload )
198+ do_insert (cur , request .metadata , request .payload )
196199 except HTTPException as e :
197200 print (e .detail )
198201 conn .rollback ()
0 commit comments