Skip to content

Commit dae5c2e

Browse files
authored
Merge branch 'main' into jo/sort_python_model_dfs_in_tests
2 parents bd98345 + a1b21e3 commit dae5c2e

File tree

6 files changed

+146
-24
lines changed

6 files changed

+146
-24
lines changed

.circleci/continue_config.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ jobs:
9393
- run:
9494
name: Run linters and code style checks
9595
command: make py-style
96-
- run:
97-
name: Exercise the benchmarks
98-
command: make benchmark-ci
96+
# - run:
97+
# name: Exercise the benchmarks
98+
# command: make benchmark-ci
9999
- run:
100100
name: Run cicd tests
101101
command: make cicd-test

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,7 @@ def audit(
22792279
snapshot=snapshot,
22802280
start=start,
22812281
end=end,
2282+
execution_time=execution_time,
22822283
snapshots=self.snapshots,
22832284
):
22842285
audit_id = f"{audit_result.audit.name}"

sqlmesh/core/renderer.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,14 @@ def _resolve_table(table: str | exp.Table) -> str:
196196
**kwargs,
197197
}
198198

199+
if this_model:
200+
render_kwargs["this_model"] = this_model
201+
202+
macro_evaluator.locals.update(render_kwargs)
203+
199204
variables = kwargs.pop("variables", {})
205+
if variables:
206+
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
200207

201208
expressions = [self._expression]
202209
if isinstance(self._expression, d.Jinja):
@@ -268,14 +275,6 @@ def _resolve_table(table: str | exp.Table) -> str:
268275
f"Could not parse the rendered jinja at '{self._path}'.\n{ex}"
269276
) from ex
270277

271-
if this_model:
272-
render_kwargs["this_model"] = this_model
273-
274-
macro_evaluator.locals.update(render_kwargs)
275-
276-
if variables:
277-
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
278-
279278
for definition in self._macro_definitions:
280279
try:
281280
macro_evaluator.evaluate(definition)

sqlmesh/core/scheduler.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ def _dag(
659659
}
660660
snapshots_to_create = snapshots_to_create or set()
661661
original_snapshots_to_create = snapshots_to_create.copy()
662+
upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}
662663

663664
snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
664665
dag = DAG[SchedulingUnit]()
@@ -670,12 +671,15 @@ def _dag(
670671
snapshot = self.snapshots_by_name[snapshot_id.name]
671672
intervals = intervals_per_snapshot.get(snapshot.name, [])
672673

673-
upstream_dependencies: t.List[SchedulingUnit] = []
674+
upstream_dependencies: t.Set[SchedulingUnit] = set()
674675

675676
for p_sid in snapshot.parents:
676-
upstream_dependencies.extend(
677+
upstream_dependencies.update(
677678
self._find_upstream_dependencies(
678-
p_sid, intervals_per_snapshot, original_snapshots_to_create
679+
p_sid,
680+
intervals_per_snapshot,
681+
original_snapshots_to_create,
682+
upstream_dependencies_cache,
679683
)
680684
)
681685

@@ -726,29 +730,42 @@ def _find_upstream_dependencies(
726730
parent_sid: SnapshotId,
727731
intervals_per_snapshot: t.Dict[str, Intervals],
728732
snapshots_to_create: t.Set[SnapshotId],
729-
) -> t.List[SchedulingUnit]:
733+
cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]],
734+
) -> t.Set[SchedulingUnit]:
730735
if parent_sid not in self.snapshots:
731-
return []
736+
return set()
737+
if parent_sid in cache:
738+
return cache[parent_sid]
732739

733740
p_intervals = intervals_per_snapshot.get(parent_sid.name, [])
734741

742+
parent_node: t.Optional[SchedulingUnit] = None
735743
if p_intervals:
736744
if len(p_intervals) > 1:
737-
return [DummyNode(snapshot_name=parent_sid.name)]
738-
interval = p_intervals[0]
739-
return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
740-
if parent_sid in snapshots_to_create:
741-
return [CreateNode(snapshot_name=parent_sid.name)]
745+
parent_node = DummyNode(snapshot_name=parent_sid.name)
746+
else:
747+
interval = p_intervals[0]
748+
parent_node = EvaluateNode(
749+
snapshot_name=parent_sid.name, interval=interval, batch_index=0
750+
)
751+
elif parent_sid in snapshots_to_create:
752+
parent_node = CreateNode(snapshot_name=parent_sid.name)
753+
754+
if parent_node is not None:
755+
cache[parent_sid] = {parent_node}
756+
return {parent_node}
757+
742758
# This snapshot has no intervals and doesn't need creation which means
743759
# that it can be a transitive dependency
744-
transitive_deps: t.List[SchedulingUnit] = []
760+
transitive_deps: t.Set[SchedulingUnit] = set()
745761
parent_snapshot = self.snapshots[parent_sid]
746762
for grandparent_sid in parent_snapshot.parents:
747-
transitive_deps.extend(
763+
transitive_deps.update(
748764
self._find_upstream_dependencies(
749-
grandparent_sid, intervals_per_snapshot, snapshots_to_create
765+
grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
750766
)
751767
)
768+
cache[parent_sid] = transitive_deps
752769
return transitive_deps
753770

754771
def _run_or_audit(

tests/core/test_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12158,3 +12158,21 @@ def test_grants_empty_values():
1215812158
def test_grants_table_type(kind: t.Union[str, _ModelKind], expected: DataObjectType):
1215912159
model = create_sql_model("test_table", parse_one("SELECT 1 as id"), kind=kind)
1216012160
assert model.grants_table_type == expected
12161+
12162+
12163+
def test_model_macro_using_locals_called_from_jinja(assert_exp_eq) -> None:
12164+
@macro()
12165+
def execution_date(evaluator):
12166+
return f"""'{evaluator.locals.get("execution_date")}'"""
12167+
12168+
expressions = d.parse(
12169+
"""
12170+
MODEL (name db.table);
12171+
12172+
JINJA_QUERY_BEGIN;
12173+
SELECT {{ execution_date() }} AS col;
12174+
JINJA_END;
12175+
"""
12176+
)
12177+
model = load_sql_based_model(expressions)
12178+
assert_exp_eq(model.render_query(), '''SELECT '1970-01-01' AS "col"''')

tests/core/test_scheduler.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,3 +1126,90 @@ def test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot
11261126
)
11271127
},
11281128
}
1129+
1130+
1131+
def test_dag_upstream_dependency_caching_with_complex_diamond(mocker: MockerFixture, make_snapshot):
1132+
r"""
1133+
Test that the upstream dependency caching correctly handles a complex diamond dependency graph.
1134+
1135+
Dependency graph:
1136+
A (has intervals)
1137+
/ \
1138+
B C (no intervals - transitive)
1139+
/ \ / \
1140+
D E F (no intervals - transitive)
1141+
\ / \ /
1142+
G H (has intervals - selected)
1143+
1144+
This creates multiple paths from G and H to A. Without caching, A's dependencies would be
1145+
computed multiple times (once for each path). With caching, they should be computed once
1146+
and reused.
1147+
"""
1148+
snapshots = {}
1149+
1150+
for name in ["a", "b", "c", "d", "e", "f", "g", "h"]:
1151+
snapshots[name] = make_snapshot(SqlModel(name=name, query=parse_one("SELECT 1 as id")))
1152+
snapshots[name].categorize_as(SnapshotChangeCategory.BREAKING)
1153+
1154+
# A is the root
1155+
snapshots["b"] = snapshots["b"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
1156+
snapshots["c"] = snapshots["c"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
1157+
1158+
# Middle layer: D, E, F depend on B and/or C
1159+
snapshots["d"] = snapshots["d"].model_copy(update={"parents": (snapshots["b"].snapshot_id,)})
1160+
snapshots["e"] = snapshots["e"].model_copy(
1161+
update={"parents": (snapshots["b"].snapshot_id, snapshots["c"].snapshot_id)}
1162+
)
1163+
snapshots["f"] = snapshots["f"].model_copy(update={"parents": (snapshots["c"].snapshot_id,)})
1164+
1165+
# Bottom layer: G and H depend on D/E and E/F respectively
1166+
snapshots["g"] = snapshots["g"].model_copy(
1167+
update={"parents": (snapshots["d"].snapshot_id, snapshots["e"].snapshot_id)}
1168+
)
1169+
snapshots["h"] = snapshots["h"].model_copy(
1170+
update={"parents": (snapshots["e"].snapshot_id, snapshots["f"].snapshot_id)}
1171+
)
1172+
1173+
scheduler = Scheduler(
1174+
snapshots=list(snapshots.values()),
1175+
snapshot_evaluator=mocker.Mock(),
1176+
state_sync=mocker.Mock(),
1177+
default_catalog=None,
1178+
)
1179+
1180+
batched_intervals = {
1181+
snapshots["a"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1182+
snapshots["g"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1183+
snapshots["h"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
1184+
}
1185+
1186+
full_dag = snapshots_to_dag(snapshots.values())
1187+
dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag)
1188+
1189+
# Verify the DAG structure:
1190+
# 1. A should be evaluated first (no dependencies)
1191+
# 2. Both G and H should depend on A (through transitive dependencies)
1192+
# 3. Transitive nodes (B, C, D, E, F) should not appear as separate evaluation nodes
1193+
expected_a_node = EvaluateNode(
1194+
snapshot_name='"a"',
1195+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1196+
batch_index=0,
1197+
)
1198+
1199+
expected_g_node = EvaluateNode(
1200+
snapshot_name='"g"',
1201+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1202+
batch_index=0,
1203+
)
1204+
1205+
expected_h_node = EvaluateNode(
1206+
snapshot_name='"h"',
1207+
interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
1208+
batch_index=0,
1209+
)
1210+
1211+
assert dag.graph == {
1212+
expected_a_node: set(),
1213+
expected_g_node: {expected_a_node},
1214+
expected_h_node: {expected_a_node},
1215+
}

0 commit comments

Comments
 (0)