diff --git a/dbtai/utils.py b/dbtai/utils.py index 76c0a49..8157c36 100644 --- a/dbtai/utils.py +++ b/dbtai/utils.py @@ -1,6 +1,7 @@ import yaml import os import appdirs +import sqlglot def get_config(): configdir = appdirs.user_data_dir("dbtai", "dbtai") @@ -8,3 +9,27 @@ def get_config(): with open(os.path.join(configdir, "config.yaml"), "r") as f: return yaml.load(f, Loader=yaml.FullLoader) + + +def find_ctes(expression, ctes=[]): + if isinstance(expression, sqlglot.exp.CTE): + ctes.append(expression) + if hasattr(expression, "args"): + for child in expression.args.values(): + if isinstance(child, list): + for item in child: + if hasattr(item, "args"): + find_ctes(item, ctes) + else: + find_ctes(child, ctes) + return ctes + +def final_selected_columns(sql_query): + parsed = sqlglot.parse_one(sql_query) + ctes = find_ctes(parsed) + + columns = [expr.alias_or_name for expr in ctes[-1].this.expressions] + return columns + + +parsed = sqlglot.parse_one(sql_query) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e650c97..60aca2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ inquirer = "^3.2.4" appdirs = "^1.4.4" pyyaml = "^6.0.1" sqlfluff = "^2.3.5" +sqlglot = ">=0.1.0" [tool.poetry.dev-dependencies] pytest = "^8.0.1"