11import abc
2+ import collections
23import functools
34import json
45import logging
1213from sqlalchemy .sql .compiler import SQLCompiler
1314from sqlalchemy .sql .expression import ClauseElement , Executable
1415
15- from subsetter .common import DatabaseConfig , parse_table_name
16+ from subsetter .common import DatabaseConfig , parse_table_name , pydantic_search
1617from subsetter .config_model import (
1718 ConflictStrategy ,
1819 DatabaseOutputConfig ,
2122)
2223from subsetter .filters import FilterOmit , FilterView , FilterViewChain
2324from subsetter .metadata import DatabaseMetadata
24- from subsetter .plan_model import SQLTableIdentifier
25+ from subsetter .plan_model import SQLLeftJoin , SQLTableIdentifier
2526from subsetter .planner import SubsetPlan
2627from subsetter .solver import toposort
2728
@@ -69,7 +70,7 @@ def create(
6970 select : sa .Select ,
7071 * ,
7172 name : str = "" ,
72- primary_key : Tuple [str , ...] = (),
73+ indexes : Iterable [ Tuple [str , ...] ] = (),
7374 ) -> Tuple [sa .Table , int ]:
7475 """
7576 Create a temporary table on the passed connection generated by the passed
@@ -82,9 +83,8 @@ def create(
8283 schema: The schema to create the temporary table within. For some dialects
8384 temporary tables always exist in their own schema and this parameter
8485 will be ignored.
85- primary_key: If set will mark the set of columns passed as primary keys in
86- the temporary table. This tuple should match a subset of the
87- column names in the select query.
86+ indexes: creates an index on each tuple of columns listed. This is useful
87+ if future queries are likely to reference these columns.
8888
8989 Returns a tuple containing the generated table object and the number of rows that
9090 were inserted in the table.
@@ -106,10 +106,7 @@ def create(
106106 metadata ,
107107 schema = temp_schema ,
108108 prefixes = ["TEMPORARY" ],
109- * (
110- sa .Column (col .name , col .type , primary_key = col .name in primary_key )
111- for col in select .selected_columns
112- ),
109+ * (sa .Column (col .name , col .type ) for col in select .selected_columns ),
113110 )
114111 try :
115112 metadata .create_all (conn )
@@ -122,6 +119,22 @@ def create(
122119 if "--read-only" not in str (exc ):
123120 raise
124121
122+ for idx , index_cols in enumerate (indexes ):
123+ # For some dialects/data types we may not be able to construct an index. We just do our
124+ # best here instead of hard failing.
125+ try :
126+ sa .Index (
127+ f"{ temp_name } _idx_{ idx } " ,
128+ * (table_obj .columns [col_name ] for col_name in index_cols ),
129+ ).create (bind = conn )
130+ except sa .exc .OperationalError :
131+ LOGGER .warning (
132+ "Failed to create index %s on temporary table %s" ,
133+ index_cols ,
134+ temp_name ,
135+ exc_info = True ,
136+ )
137+
125138 # Copy data into the temporary table
126139 stmt = table_obj .insert ().from_select (list (table_obj .columns ), select )
127140 LOGGER .debug (
@@ -834,6 +847,18 @@ def _materialize_tables(
834847 conn : sa .Connection ,
835848 plan : SubsetPlan ,
836849 ) -> None :
850+ # Figure out what sets of columns are going to be queried for our materialized tables.
851+ joined_columns = collections .defaultdict (set )
852+ for data in pydantic_search (plan ):
853+ if not isinstance (data , SQLLeftJoin ):
854+ continue
855+ table_id = data .right
856+ if not table_id .sampled :
857+ continue
858+ joined_columns [(table_id .table_schema , table_id .table_name )].add (
859+ tuple (data .right_columns )
860+ )
861+
837862 materialization_order = self ._materialization_order (meta , plan )
838863 for schema , table_name , ref_count in materialization_order :
839864 table = meta .tables [(schema , table_name )]
@@ -866,7 +891,7 @@ def _materialize_tables(
866891 schema ,
867892 table_q ,
868893 name = table_name ,
869- primary_key = table . primary_key ,
894+ indexes = joined_columns [( schema , table_name )] ,
870895 )
871896 )
872897 self .cached_table_sizes [(schema , table_name )] = rowcount
@@ -889,7 +914,7 @@ def _materialize_tables(
889914 schema ,
890915 meta .temp_tables [(schema , table_name , 0 )].select (),
891916 name = table_name ,
892- primary_key = table . primary_key ,
917+ indexes = joined_columns [( schema , table_name )] ,
893918 )
894919 )
895920 LOGGER .info (
0 commit comments