Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ use type_system::knowledge::Entity;

use super::expression::{JoinType, TableName, TableReference};
use crate::store::postgres::query::{
Alias, Column, Condition, Distinctness, EqualityOperator, Expression, Function,
PostgresQueryPath, PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _,
WindowStatement,
Alias, Column, Distinctness, EqualityOperator, Expression, Function, PostgresQueryPath,
PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _, WindowStatement,
expression::{FromItem, GroupByExpression, PostgresType},
table::{
DataTypeEmbeddings, DatabaseColumn as _, EntityEditions, EntityEmbeddings,
Expand Down Expand Up @@ -60,7 +59,7 @@ struct PathSelection {
ordering: Option<(Ordering, Option<NullOrdering>)>,
}

type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec<Condition>;
type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec<Expression>;
type ColumnHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Expression) -> Expression;

pub struct SelectCompiler<'p, 'q: 'p, T: QueryRecord> {
Expand Down Expand Up @@ -201,15 +200,15 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
}
}

fn ontology_table_conditions(&mut self, alias: Alias) -> Vec<Condition> {
fn ontology_table_conditions(&mut self, alias: Alias) -> Vec<Expression> {
let table = Table::OntologyTemporalMetadata.aliased(alias);
if let Some(temporal_axes) = self.temporal_axes
&& self.artifacts.table_info.tables.insert(table)
{
let transaction_time_index = self.time_index(temporal_axes, TimeAxis::TransactionTime);
match temporal_axes {
QueryTemporalAxes::DecisionTime { .. } => {
vec![Condition::TimeIntervalContainsTimestamp(
vec![Expression::time_interval_contains_timestamp(
Expression::ColumnReference(
Column::OntologyTemporalMetadata(
OntologyTemporalMetadata::TransactionTime,
Expand All @@ -220,7 +219,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
)]
}
QueryTemporalAxes::TransactionTime { .. } => {
vec![Condition::Overlap(
vec![Expression::overlap(
Expression::ColumnReference(
Column::OntologyTemporalMetadata(
OntologyTemporalMetadata::TransactionTime,
Expand All @@ -236,12 +235,12 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
}
}

fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec<Condition> {
fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec<Expression> {
let mut conditions = Vec::new();
let table = Table::EntityTemporalMetadata.aliased(alias);
if self.artifacts.table_info.tables.insert(table.clone()) {
if !self.include_drafts {
conditions.push(Condition::Exists(Expression::ColumnReference(
conditions.push(Expression::exists(Expression::ColumnReference(
Column::EntityTemporalMetadata(EntityTemporalMetadata::DraftId).aliased(alias),
)));
}
Expand All @@ -255,7 +254,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
// Adds the pinned timestamp condition, so for the projected decision time, we use
// the transaction time and vice versa.
conditions.extend([
Condition::TimeIntervalContainsTimestamp(
Expression::time_interval_contains_timestamp(
Expression::ColumnReference(
Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis(
pinned_axis,
Expand All @@ -264,7 +263,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
),
Expression::Parameter(pinned_time_index),
),
Condition::Overlap(
Expression::overlap(
Expression::ColumnReference(
Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis(
variable_axis,
Expand Down Expand Up @@ -410,7 +409,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
)
}

/// Compiles a [`Filter`] to a `Condition`.
/// Compiles a [`Filter`] to an [`Expression`].
///
/// # Errors
///
Expand All @@ -420,7 +419,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
pub fn compile_filter<'f: 'q>(
&mut self,
filter: &'p Filter<'f, R>,
) -> Result<Condition, Report<SelectCompilerError>>
) -> Result<Expression, Report<SelectCompilerError>>
where
R::QueryPath<'f>: PostgresQueryPath,
{
Expand All @@ -429,41 +428,41 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
}

Ok(match filter {
Filter::All(filters) => Condition::All(
Filter::All(filters) => Expression::all(
filters
.iter()
.map(|filter| self.compile_filter(filter))
.collect::<Result<_, _>>()?,
),
Filter::Any(filters) => Condition::Any(
Filter::Any(filters) => Expression::any(
filters
.iter()
.map(|filter| self.compile_filter(filter))
.collect::<Result<_, _>>()?,
),
Filter::Not(filter) => Condition::Not(Box::new(self.compile_filter(filter)?)),
Filter::Equal(lhs, rhs) => Condition::Equal(
Filter::Not(filter) => Expression::not(self.compile_filter(filter)?),
Filter::Equal(lhs, rhs) => Expression::equal(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Filter::NotEqual(lhs, rhs) => Condition::NotEqual(
Filter::NotEqual(lhs, rhs) => Expression::not_equal(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Filter::Exists { path } => Condition::Exists(self.compile_path_column(path)),
Filter::Greater(lhs, rhs) => Condition::Greater(
Filter::Exists { path } => Expression::exists(self.compile_path_column(path)),
Filter::Greater(lhs, rhs) => Expression::greater(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Filter::GreaterOrEqual(lhs, rhs) => Condition::GreaterOrEqual(
Filter::GreaterOrEqual(lhs, rhs) => Expression::greater_or_equal(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Filter::Less(lhs, rhs) => Condition::Less(
Filter::Less(lhs, rhs) => Expression::less(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Filter::LessOrEqual(lhs, rhs) => Condition::LessOrEqual(
Filter::LessOrEqual(lhs, rhs) => Expression::less_or_equal(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression(rhs).0,
),
Expand Down Expand Up @@ -610,11 +609,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
})
.chain(once(SelectExpression::Expression {
expression: Expression::Function(Function::Min(
Box::new(Expression::CosineDistance(
Box::new(Expression::ColumnReference(
Box::new(Expression::cosine_distance(
Expression::ColumnReference(
embeddings_column.into(),
)),
Box::new(parameter_expression),
),
parameter_expression,
)),
)),
alias: Some("distance"),
Expand Down Expand Up @@ -643,11 +642,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
alias: None,
});
self.statement.distinct.push(distance_expression.clone());
Condition::LessOrEqual(distance_expression, maximum_expression)
Expression::less_or_equal(distance_expression, maximum_expression)
}
_ => bail!(SelectCompilerError::UnsupportedDistanceExpression),
},
Filter::In(lhs, rhs) => Condition::In(
Filter::In(lhs, rhs) => Expression::r#in(
self.compile_filter_expression(lhs).0,
self.compile_filter_expression_list(rhs).0,
),
Expand All @@ -666,7 +665,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
right_filter
};

Condition::StartsWith(left_filter, right_filter)
Expression::starts_with(left_filter, right_filter)
}
Filter::EndsWith(lhs, rhs) => {
let (left_filter, left_parameter) = self.compile_filter_expression(lhs);
Expand All @@ -683,7 +682,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
right_filter
};

Condition::EndsWith(left_filter, right_filter)
Expression::ends_with(left_filter, right_filter)
}
Filter::ContainsSegment(lhs, rhs) => {
let (left_filter, left_parameter) = self.compile_filter_expression(lhs);
Expand All @@ -700,7 +699,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
right_filter
};

Condition::ContainsSegment(left_filter, right_filter)
Expression::contains_segment(left_filter, right_filter)
}
})
}
Expand All @@ -716,7 +715,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
&mut self,
path: &R::QueryPath<'f>,
operator: EqualityOperator,
) -> Condition
) -> Expression
where
R::QueryPath<'f>: PostgresQueryPath,
{
Expand Down Expand Up @@ -765,10 +764,10 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {

match operator {
EqualityOperator::Equal => {
Condition::Equal(version_expression, latest_version_expression)
Expression::equal(version_expression, latest_version_expression)
}
EqualityOperator::NotEqual => {
Condition::NotEqual(version_expression, latest_version_expression)
Expression::not_equal(version_expression, latest_version_expression)
}
}
}
Expand All @@ -778,7 +777,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> {
///
/// The following [`Filter`]s will be special cased:
/// - Comparing the `"version"` field on [`Table::OntologyIds`] with `"latest"` for equality.
fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option<Condition>
fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option<Expression>
where
R::QueryPath<'f>: PostgresQueryPath,
{
Expand Down Expand Up @@ -1232,7 +1231,7 @@ impl<'p, 'q: 'p> SelectCompiler<'p, 'q, Entity> {

Expression::CaseWhen {
conditions: vec![(
Expression::Condition(Box::new(condition)),
condition,
Expression::Function(Function::ArrayLiteral {
elements: vec![compiler.compile_parameter(property_url).0],
element_type: PostgresType::Text,
Expand Down
Loading
Loading