From 35791dcd76b99696ffbf411288b8d7d738f9bf55 Mon Sep 17 00:00:00 2001 From: Tim Diekmann Date: Sun, 22 Feb 2026 23:24:00 +0100 Subject: [PATCH 1/3] BE-415: Merge `Condition` into `Expression` in query builder MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The query builder artificially separated `Condition` (boolean SQL) from `Expression` (value SQL). In SQL, conditions are boolean-valued expressions — there is no fundamental distinction. This prevented natural composition like negating an arbitrary boolean expression. Merge all 16 `Condition` variants into `Expression`, add convenience constructors to avoid `Box::new()` boilerplate, and delete the now- redundant `condition.rs` module. The `Expression::Condition` wrapper variant is also removed since conditions are now first-class expressions. --- .../src/store/postgres/query/compile.rs | 65 ++- .../src/store/postgres/query/condition.rs | 403 -------------- .../postgres/query/expression/conditional.rs | 499 +++++++++++++++++- .../postgres/query/expression/from_item.rs | 17 +- .../store/postgres/query/expression/mod.rs | 2 +- .../postgres/query/expression/where_clause.rs | 8 +- .../src/store/postgres/query/mod.rs | 5 +- .../store/postgres/query/statement/select.rs | 4 +- .../src/store/postgres/query/table.rs | 16 +- 9 files changed, 535 insertions(+), 484 deletions(-) delete mode 100644 libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs index f1b9d347081..23fc9558bbc 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs @@ -18,9 +18,8 @@ use type_system::knowledge::Entity; use super::expression::{JoinType, TableName, TableReference}; use crate::store::postgres::query::{ - Alias, Column, Condition, Distinctness, EqualityOperator, Expression, Function, - PostgresQueryPath, PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _, - WindowStatement, + Alias, Column, Distinctness, EqualityOperator, Expression, Function, PostgresQueryPath, + PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _, WindowStatement, expression::{FromItem, GroupByExpression, PostgresType}, table::{ DataTypeEmbeddings, DatabaseColumn as _, EntityEditions, EntityEmbeddings, @@ -60,7 +59,7 @@ struct PathSelection { ordering: Option<(Ordering, Option)>, } -type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec; +type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec; type ColumnHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Expression) -> Expression; pub struct SelectCompiler<'p, 'q: 'p, T: QueryRecord> { @@ -201,7 +200,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } } - fn ontology_table_conditions(&mut self, alias: Alias) -> Vec { + fn ontology_table_conditions(&mut self, alias: Alias) -> Vec { let table = Table::OntologyTemporalMetadata.aliased(alias); if let Some(temporal_axes) = self.temporal_axes && self.artifacts.table_info.tables.insert(table) @@ -209,7 +208,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { let transaction_time_index = self.time_index(temporal_axes, TimeAxis::TransactionTime); match temporal_axes { QueryTemporalAxes::DecisionTime { .. } => { - vec![Condition::TimeIntervalContainsTimestamp( + vec![Expression::time_interval_contains_timestamp( Expression::ColumnReference( Column::OntologyTemporalMetadata( OntologyTemporalMetadata::TransactionTime, @@ -220,7 +219,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { )] } QueryTemporalAxes::TransactionTime { .. } => { - vec![Condition::Overlap( + vec![Expression::overlap( Expression::ColumnReference( Column::OntologyTemporalMetadata( OntologyTemporalMetadata::TransactionTime, @@ -236,12 +235,12 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } } - fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec { + fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec { let mut conditions = Vec::new(); let table = Table::EntityTemporalMetadata.aliased(alias); if self.artifacts.table_info.tables.insert(table.clone()) { if !self.include_drafts { - conditions.push(Condition::Exists(Expression::ColumnReference( + conditions.push(Expression::exists(Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::DraftId).aliased(alias), ))); } @@ -255,7 +254,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { // Adds the pinned timestamp condition, so for the projected decision time, we use // the transaction time and vice versa. conditions.extend([ - Condition::TimeIntervalContainsTimestamp( + Expression::time_interval_contains_timestamp( Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis( pinned_axis, @@ -264,7 +263,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { ), Expression::Parameter(pinned_time_index), ), - Condition::Overlap( + Expression::overlap( Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis( variable_axis, @@ -410,7 +409,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { ) } - /// Compiles a [`Filter`] to a `Condition`. + /// Compiles a [`Filter`] to an [`Expression`]. /// /// # Errors /// @@ -420,7 +419,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { pub fn compile_filter<'f: 'q>( &mut self, filter: &'p Filter<'f, R>, - ) -> Result> + ) -> Result> where R::QueryPath<'f>: PostgresQueryPath, { @@ -429,41 +428,41 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } Ok(match filter { - Filter::All(filters) => Condition::All( + Filter::All(filters) => Expression::all( filters .iter() .map(|filter| self.compile_filter(filter)) .collect::>()?, ), - Filter::Any(filters) => Condition::Any( + Filter::Any(filters) => Expression::any( filters .iter() .map(|filter| self.compile_filter(filter)) .collect::>()?, ), - Filter::Not(filter) => Condition::Not(Box::new(self.compile_filter(filter)?)), - Filter::Equal(lhs, rhs) => Condition::Equal( + Filter::Not(filter) => Expression::not(self.compile_filter(filter)?), + Filter::Equal(lhs, rhs) => Expression::equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::NotEqual(lhs, rhs) => Condition::NotEqual( + Filter::NotEqual(lhs, rhs) => Expression::not_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::Exists { path } => Condition::Exists(self.compile_path_column(path)), - Filter::Greater(lhs, rhs) => Condition::Greater( + Filter::Exists { path } => Expression::exists(self.compile_path_column(path)), + Filter::Greater(lhs, rhs) => Expression::greater( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::GreaterOrEqual(lhs, rhs) => Condition::GreaterOrEqual( + Filter::GreaterOrEqual(lhs, rhs) => Expression::greater_or_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::Less(lhs, rhs) => Condition::Less( + Filter::Less(lhs, rhs) => Expression::less( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::LessOrEqual(lhs, rhs) => Condition::LessOrEqual( + Filter::LessOrEqual(lhs, rhs) => Expression::less_or_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), @@ -643,11 +642,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { alias: None, }); self.statement.distinct.push(distance_expression.clone()); - Condition::LessOrEqual(distance_expression, maximum_expression) + Expression::less_or_equal(distance_expression, maximum_expression) } _ => bail!(SelectCompilerError::UnsupportedDistanceExpression), }, - Filter::In(lhs, rhs) => Condition::In( + Filter::In(lhs, rhs) => Expression::r#in( self.compile_filter_expression(lhs).0, self.compile_filter_expression_list(rhs).0, ), @@ -666,7 +665,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::StartsWith(left_filter, right_filter) + Expression::starts_with(left_filter, right_filter) } Filter::EndsWith(lhs, rhs) => { let (left_filter, left_parameter) = self.compile_filter_expression(lhs); @@ -683,7 +682,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::EndsWith(left_filter, right_filter) + Expression::ends_with(left_filter, right_filter) } Filter::ContainsSegment(lhs, rhs) => { let (left_filter, left_parameter) = self.compile_filter_expression(lhs); @@ -700,7 +699,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::ContainsSegment(left_filter, right_filter) + Expression::contains_segment(left_filter, right_filter) } }) } @@ -716,7 +715,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { &mut self, path: &R::QueryPath<'f>, operator: EqualityOperator, - ) -> Condition + ) -> Expression where R::QueryPath<'f>: PostgresQueryPath, { @@ -765,10 +764,10 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { match operator { EqualityOperator::Equal => { - Condition::Equal(version_expression, latest_version_expression) + Expression::equal(version_expression, latest_version_expression) } EqualityOperator::NotEqual => { - Condition::NotEqual(version_expression, latest_version_expression) + Expression::not_equal(version_expression, latest_version_expression) } } } @@ -778,7 +777,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { /// /// The following [`Filter`]s will be special cased: /// - Comparing the `"version"` field on [`Table::OntologyIds`] with `"latest"` for equality. - fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option + fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option where R::QueryPath<'f>: PostgresQueryPath, { @@ -1232,7 +1231,7 @@ impl<'p, 'q: 'p> SelectCompiler<'p, 'q, Entity> { Expression::CaseWhen { conditions: vec![( - Expression::Condition(Box::new(condition)), + condition, Expression::Function(Function::ArrayLiteral { elements: vec![compiler.compile_parameter(property_url).0], element_type: PostgresType::Text, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs deleted file mode 100644 index c779184e0ab..00000000000 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs +++ /dev/null @@ -1,403 +0,0 @@ -use core::{fmt, fmt::Write as _}; - -use crate::store::postgres::query::{Expression, Transpile}; - -/// A [`Filter`], which can be transpiled. -/// -/// [`Filter`]: hash_graph_store::filter::Filter -#[derive(Debug, Clone, PartialEq)] -pub enum Condition { - All(Vec), - Any(Vec), - Not(Box), - Equal(Expression, Expression), - NotEqual(Expression, Expression), - Exists(Expression), - Less(Expression, Expression), - LessOrEqual(Expression, Expression), - Greater(Expression, Expression), - GreaterOrEqual(Expression, Expression), - In(Expression, Expression), - TimeIntervalContainsTimestamp(Expression, Expression), - Overlap(Expression, Expression), - StartsWith(Expression, Expression), - EndsWith(Expression, Expression), - ContainsSegment(Expression, Expression), -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum EqualityOperator { - Equal, - NotEqual, -} - -impl Transpile for Condition { - #[expect(clippy::too_many_lines)] - fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::All(conditions) if conditions.is_empty() => fmt.write_str("TRUE"), - Self::Any(conditions) if conditions.is_empty() => fmt.write_str("FALSE"), - Self::All(conditions) => { - for (idx, condition) in conditions.iter().enumerate() { - if idx > 0 { - fmt.write_str(" AND ")?; - } - fmt.write_char('(')?; - condition.transpile(fmt)?; - fmt.write_char(')')?; - } - Ok(()) - } - Self::Any(conditions) => { - if conditions.len() > 1 { - fmt.write_char('(')?; - } - for (idx, condition) in conditions.iter().enumerate() { - if idx > 0 { - fmt.write_str(" OR ")?; - } - fmt.write_char('(')?; - condition.transpile(fmt)?; - fmt.write_char(')')?; - } - if conditions.len() > 1 { - fmt.write_char(')')?; - } - Ok(()) - } - Self::Not(condition) => { - if let Self::Exists(path) = &**condition { - path.transpile(fmt)?; - fmt.write_str(" IS NOT NULL") - } else { - fmt.write_str("NOT(")?; - condition.transpile(fmt)?; - fmt.write_char(')') - } - } - Self::Exists(path) => { - path.transpile(fmt)?; - fmt.write_str(" IS NULL") - } - Self::Equal(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ")?; - rhs.transpile(fmt) - } - Self::NotEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" != ")?; - rhs.transpile(fmt) - } - Self::Less(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" < ")?; - rhs.transpile(fmt) - } - Self::LessOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" <= ")?; - rhs.transpile(fmt) - } - Self::Greater(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" > ")?; - rhs.transpile(fmt) - } - Self::GreaterOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" >= ")?; - rhs.transpile(fmt) - } - Self::In(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ANY(")?; - rhs.transpile(fmt)?; - fmt.write_char(')') - } - Self::TimeIntervalContainsTimestamp(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" @> ")?; - rhs.transpile(fmt)?; - fmt.write_str("::TIMESTAMPTZ") - } - Self::Overlap(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" && ")?; - rhs.transpile(fmt) - } - Self::StartsWith(lhs, rhs) => { - fmt.write_str("starts_with(")?; - lhs.transpile(fmt)?; - fmt.write_str(", ")?; - rhs.transpile(fmt)?; - fmt.write_char(')') - } - Self::EndsWith(lhs, rhs) => { - fmt.write_str("right(")?; - lhs.transpile(fmt)?; - fmt.write_str(", length(")?; - rhs.transpile(fmt)?; - fmt.write_str(")) = ")?; - rhs.transpile(fmt) - } - Self::ContainsSegment(lhs, rhs) => { - fmt.write_str("strpos(")?; - lhs.transpile(fmt)?; - fmt.write_str(", ")?; - rhs.transpile(fmt)?; - fmt.write_str(") > 0") - } - } - } -} - -#[cfg(test)] -mod tests { - use alloc::borrow::Cow; - - use hash_codec::numeric::Real; - use hash_graph_store::{ - data_type::DataTypeQueryPath, - filter::{Filter, FilterExpression, Parameter}, - }; - use postgres_types::ToSql; - use type_system::ontology::DataTypeWithMetadata; - - use crate::store::postgres::query::{SelectCompiler, Transpile as _}; - - fn test_condition<'p, 'f: 'p>( - filter: &'f Filter<'p, DataTypeWithMetadata>, - rendered: &'static str, - parameters: &[&'p dyn ToSql], - ) { - let mut compiler = SelectCompiler::new(None, false); - let condition = compiler - .compile_filter(filter) - .expect("failed to compile filter"); - - assert_eq!(condition.transpile_to_string(), rendered); - - let parameter_list = parameters - .iter() - .map(|parameter| format!("{parameter:?}")) - .collect::>(); - let expected_parameters = compiler - .compile() - .1 - .iter() - .map(|parameter| format!("{parameter:?}")) - .collect::>(); - - assert_eq!(parameter_list, expected_parameters); - } - - #[test] - fn transpile_empty_condition() { - test_condition(&Filter::All(vec![]), "TRUE", &[]); - test_condition(&Filter::Any(vec![]), "FALSE", &[]); - } - - #[test] - fn transpile_exists_condition() { - test_condition( - &Filter::Exists { - path: DataTypeQueryPath::Description, - }, - r#""data_types_0_1_0"."schema"->>'description' IS NULL"#, - &[], - ); - - test_condition( - &Filter::Not(Box::new(Filter::Exists { - path: DataTypeQueryPath::Description, - })), - r#""data_types_0_1_0"."schema"->>'description' IS NOT NULL"#, - &[], - ); - } - - #[test] - fn transpile_all_condition() { - test_condition( - &Filter::All(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - - test_condition( - &Filter::All(vec![ - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::BaseUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/", - )), - convert: None, - }, - ), - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Version, - }, - FilterExpression::Parameter { - parameter: Parameter::Decimal(Real::from_natural(1, 1)), - convert: None, - }, - ), - ]), - r#"("ontology_ids_0_1_0"."base_url" = $1) AND ("ontology_ids_0_1_0"."version" = $2)"#, - &[ - &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", - &Real::from_natural(1, 1), - ], - ); - } - - #[test] - fn transpile_any_condition() { - test_condition( - &Filter::Any(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - - test_condition( - &Filter::Any(vec![ - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::BaseUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/", - )), - convert: None, - }, - ), - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Version, - }, - FilterExpression::Parameter { - parameter: Parameter::Decimal(Real::from_natural(1, 1)), - convert: None, - }, - ), - ]), - r#"(("ontology_ids_0_1_0"."base_url" = $1) OR ("ontology_ids_0_1_0"."version" = $2))"#, - &[ - &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", - &Real::from_natural(1, 1), - ], - ); - } - - #[test] - fn transpile_not_condition() { - test_condition( - &Filter::Not(Box::new(Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - ))), - r#"NOT("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - } - - #[test] - fn transpile_starts_with_condition() { - test_condition( - &Filter::StartsWith( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("foo")), - convert: None, - }, - ), - r#"starts_with("data_types_0_1_0"."schema"->>'title', $1)"#, - &[&"foo"], - ); - } - - #[test] - fn transpile_ends_with_condition() { - test_condition( - &Filter::EndsWith( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("bar")), - convert: None, - }, - ), - r#"right("data_types_0_1_0"."schema"->>'title', length($1)) = $1"#, - &[&"bar"], - ); - } - - #[test] - fn transpile_contains_segment_condition() { - test_condition( - &Filter::ContainsSegment( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("baz")), - convert: None, - }, - ), - r#"strpos("data_types_0_1_0"."schema"->>'title', $1) > 0"#, - &[&"baz"], - ); - } - - #[test] - fn render_without_parameters() { - test_condition( - &Filter::Any(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Description, - }, - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'description' = "data_types_0_1_0"."schema"->>'title')"#, - &[], - ); - } -} diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs index 4ea98e332ad..15dc7e2ed31 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs @@ -5,9 +5,7 @@ use core::fmt::{ use hash_graph_store::filter::PathToken; use super::ColumnReference; -use crate::store::postgres::query::{ - Condition, SelectStatement, Table, Transpile, WindowStatement, -}; +use crate::store::postgres::query::{SelectStatement, Table, Transpile, WindowStatement}; #[derive(Debug, Clone, PartialEq)] pub enum Function { @@ -207,9 +205,20 @@ impl Transpile for PostgresType { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum EqualityOperator { + Equal, + NotEqual, +} + /// A compiled expression in Postgres. +/// +/// This type unifies both value expressions and boolean conditions. In SQL, conditions are +/// boolean-valued expressions — there is no fundamental distinction between a "condition" and +/// an "expression". This allows natural composition, e.g. negating any boolean expression. #[derive(Debug, Clone, PartialEq)] pub enum Expression { + // --- Value expressions --- ColumnReference(ColumnReference<'static>), /// A parameter are transpiled as a placeholder, e.g. `$1`, in order to prevent SQL injection. Parameter(usize), @@ -242,21 +251,121 @@ pub enum Expression { /// Optional else result if no condition matches. else_result: Option>, }, - /// Wraps a [`Condition`] for use in expression contexts. - /// - /// This allows conditions (which evaluate to boolean) to be used where expressions - /// are expected, such as in CASE WHEN conditions. - /// - /// # Example SQL - /// ```sql - /// CASE WHEN (a = b AND c != d) THEN 'yes' ELSE 'no' END - /// ``` - Condition(Box), + + // --- Boolean conditions --- + /// Conjunction of conditions. Transpiles to `(c1) AND (c2) AND ...`. + /// Empty list transpiles to `TRUE`. + All(Vec), + /// Disjunction of conditions. Transpiles to `((c1) OR (c2) OR ...)`. + /// Empty list transpiles to `FALSE`. + Any(Vec), + /// Negation. Transpiles to `NOT(expr)`. + /// Special case: `Not(Exists(expr))` transpiles to `expr IS NOT NULL`. + Not(Box), + Equal(Box, Box), + NotEqual(Box, Box), + /// Null check. Transpiles to `expr IS NULL`. + Exists(Box), + Less(Box, Box), + LessOrEqual(Box, Box), + Greater(Box, Box), + GreaterOrEqual(Box, Box), + In(Box, Box), + TimeIntervalContainsTimestamp(Box, Box), + Overlap(Box, Box), + StartsWith(Box, Box), + EndsWith(Box, Box), + ContainsSegment(Box, Box), +} + +/// Convenience constructors for condition variants to avoid `Box::new()` boilerplate. +impl Expression { + #[must_use] + pub const fn all(conditions: Vec) -> Self { + Self::All(conditions) + } + + #[must_use] + pub const fn any(conditions: Vec) -> Self { + Self::Any(conditions) + } + + #[must_use] + pub fn not(inner: Self) -> Self { + Self::Not(Box::new(inner)) + } + + #[must_use] + pub fn equal(lhs: Self, rhs: Self) -> Self { + Self::Equal(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn not_equal(lhs: Self, rhs: Self) -> Self { + Self::NotEqual(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn exists(expr: Self) -> Self { + Self::Exists(Box::new(expr)) + } + + #[must_use] + pub fn less(lhs: Self, rhs: Self) -> Self { + Self::Less(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn less_or_equal(lhs: Self, rhs: Self) -> Self { + Self::LessOrEqual(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn greater(lhs: Self, rhs: Self) -> Self { + Self::Greater(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn greater_or_equal(lhs: Self, rhs: Self) -> Self { + Self::GreaterOrEqual(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn r#in(lhs: Self, rhs: Self) -> Self { + Self::In(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn time_interval_contains_timestamp(lhs: Self, rhs: Self) -> Self { + Self::TimeIntervalContainsTimestamp(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn overlap(lhs: Self, rhs: Self) -> Self { + Self::Overlap(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn starts_with(lhs: Self, rhs: Self) -> Self { + Self::StartsWith(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn ends_with(lhs: Self, rhs: Self) -> Self { + Self::EndsWith(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn contains_segment(lhs: Self, rhs: Self) -> Self { + Self::ContainsSegment(Box::new(lhs), Box::new(rhs)) + } } impl Transpile for Expression { + #[expect(clippy::too_many_lines)] fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { + // --- Value expressions --- Self::ColumnReference(column) => column.transpile(fmt), Self::Parameter(index) => write!(fmt, "${index}"), Self::Constant(constant) => constant.transpile(fmt), @@ -301,11 +410,121 @@ impl Transpile for Expression { } fmt.write_str(" END") } - Self::Condition(condition) => { - fmt.write_char('(')?; - condition.transpile(fmt)?; + + // --- Boolean conditions --- + Self::All(conditions) if conditions.is_empty() => fmt.write_str("TRUE"), + Self::Any(conditions) if conditions.is_empty() => fmt.write_str("FALSE"), + Self::All(conditions) => { + for (idx, condition) in conditions.iter().enumerate() { + if idx > 0 { + fmt.write_str(" AND ")?; + } + fmt.write_char('(')?; + condition.transpile(fmt)?; + fmt.write_char(')')?; + } + Ok(()) + } + Self::Any(conditions) => { + if conditions.len() > 1 { + fmt.write_char('(')?; + } + for (idx, condition) in conditions.iter().enumerate() { + if idx > 0 { + fmt.write_str(" OR ")?; + } + fmt.write_char('(')?; + condition.transpile(fmt)?; + fmt.write_char(')')?; + } + if conditions.len() > 1 { + fmt.write_char(')')?; + } + Ok(()) + } + Self::Not(inner) => { + if let Self::Exists(path) = &**inner { + path.transpile(fmt)?; + fmt.write_str(" IS NOT NULL") + } else { + fmt.write_str("NOT(")?; + inner.transpile(fmt)?; + fmt.write_char(')') + } + } + Self::Exists(path) => { + path.transpile(fmt)?; + fmt.write_str(" IS NULL") + } + Self::Equal(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" = ")?; + rhs.transpile(fmt) + } + Self::NotEqual(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" != ")?; + rhs.transpile(fmt) + } + Self::Less(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" < ")?; + rhs.transpile(fmt) + } + Self::LessOrEqual(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" <= ")?; + rhs.transpile(fmt) + } + Self::Greater(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" > ")?; + rhs.transpile(fmt) + } + Self::GreaterOrEqual(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" >= ")?; + rhs.transpile(fmt) + } + Self::In(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" = ANY(")?; + rhs.transpile(fmt)?; fmt.write_char(')') } + Self::TimeIntervalContainsTimestamp(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" @> ")?; + rhs.transpile(fmt)?; + fmt.write_str("::TIMESTAMPTZ") + } + Self::Overlap(lhs, rhs) => { + lhs.transpile(fmt)?; + fmt.write_str(" && ")?; + rhs.transpile(fmt) + } + Self::StartsWith(lhs, rhs) => { + fmt.write_str("starts_with(")?; + lhs.transpile(fmt)?; + fmt.write_str(", ")?; + rhs.transpile(fmt)?; + fmt.write_char(')') + } + Self::EndsWith(lhs, rhs) => { + fmt.write_str("right(")?; + lhs.transpile(fmt)?; + fmt.write_str(", length(")?; + rhs.transpile(fmt)?; + fmt.write_str(")) = ")?; + rhs.transpile(fmt) + } + Self::ContainsSegment(lhs, rhs) => { + fmt.write_str("strpos(")?; + lhs.transpile(fmt)?; + fmt.write_str(", ")?; + rhs.transpile(fmt)?; + fmt.write_str(") > 0") + } } } } @@ -322,11 +541,20 @@ where #[cfg(test)] mod tests { - use hash_graph_store::data_type::DataTypeQueryPath; + use alloc::borrow::Cow; + + use hash_codec::numeric::Real; + use hash_graph_store::{ + data_type::DataTypeQueryPath, + filter::{Filter, FilterExpression, Parameter}, + }; + use postgres_types::ToSql; + use type_system::ontology::DataTypeWithMetadata; use super::*; use crate::store::postgres::query::{ - Alias, PostgresQueryPath as _, test_helper::max_version_expression, + Alias, PostgresQueryPath as _, SelectCompiler, Transpile as _, + test_helper::max_version_expression, }; #[test] @@ -435,4 +663,239 @@ mod tests { }); assert_eq!(empty_array.transpile_to_string(), "ARRAY[]::text[]"); } + + fn test_condition<'p, 'f: 'p>( + filter: &'f Filter<'p, DataTypeWithMetadata>, + rendered: &'static str, + parameters: &[&'p dyn ToSql], + ) { + let mut compiler = SelectCompiler::new(None, false); + let condition = compiler + .compile_filter(filter) + .expect("failed to compile filter"); + + assert_eq!(condition.transpile_to_string(), rendered); + + let parameter_list = parameters + .iter() + .map(|parameter| format!("{parameter:?}")) + .collect::>(); + let expected_parameters = compiler + .compile() + .1 + .iter() + .map(|parameter| format!("{parameter:?}")) + .collect::>(); + + assert_eq!(parameter_list, expected_parameters); + } + + #[test] + fn transpile_empty_condition() { + test_condition(&Filter::All(vec![]), "TRUE", &[]); + test_condition(&Filter::Any(vec![]), "FALSE", &[]); + } + + #[test] + fn transpile_exists_condition() { + test_condition( + &Filter::Exists { + path: DataTypeQueryPath::Description, + }, + r#""data_types_0_1_0"."schema"->>'description' IS NULL"#, + &[], + ); + + test_condition( + &Filter::Not(Box::new(Filter::Exists { + path: DataTypeQueryPath::Description, + })), + r#""data_types_0_1_0"."schema"->>'description' IS NOT NULL"#, + &[], + ); + } + + #[test] + fn transpile_all_condition() { + test_condition( + &Filter::All(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + + test_condition( + &Filter::All(vec![ + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::BaseUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/", + )), + convert: None, + }, + ), + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Version, + }, + FilterExpression::Parameter { + parameter: Parameter::Decimal(Real::from_natural(1, 1)), + convert: None, + }, + ), + ]), + r#"("ontology_ids_0_1_0"."base_url" = $1) AND ("ontology_ids_0_1_0"."version" = $2)"#, + &[ + &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", + &Real::from_natural(1, 1), + ], + ); + } + + #[test] + fn transpile_any_condition() { + test_condition( + &Filter::Any(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + + test_condition( + &Filter::Any(vec![ + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::BaseUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/", + )), + convert: None, + }, + ), + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Version, + }, + FilterExpression::Parameter { + parameter: Parameter::Decimal(Real::from_natural(1, 1)), + convert: None, + }, + ), + ]), + r#"(("ontology_ids_0_1_0"."base_url" = $1) OR ("ontology_ids_0_1_0"."version" = $2))"#, + &[ + &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", + &Real::from_natural(1, 1), + ], + ); + } + + #[test] + fn transpile_not_condition() { + test_condition( + &Filter::Not(Box::new(Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + ))), + r#"NOT("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + } + + #[test] + fn transpile_starts_with_condition() { + test_condition( + &Filter::StartsWith( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("foo")), + convert: None, + }, + ), + r#"starts_with("data_types_0_1_0"."schema"->>'title', $1)"#, + &[&"foo"], + ); + } + + #[test] + fn transpile_ends_with_condition() { + test_condition( + &Filter::EndsWith( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("bar")), + convert: None, + }, + ), + r#"right("data_types_0_1_0"."schema"->>'title', length($1)) = $1"#, + &[&"bar"], + ); + } + + #[test] + fn transpile_contains_segment_condition() { + test_condition( + &Filter::ContainsSegment( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("baz")), + convert: None, + }, + ), + r#"strpos("data_types_0_1_0"."schema"->>'title', $1) > 0"#, + &[&"baz"], + ); + } + + #[test] + fn render_without_parameters() { + test_condition( + &Filter::Any(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Description, + }, + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'description' = "data_types_0_1_0"."schema"->>'title')"#, + &[], + ); + } } diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs index 69a7038f3f9..ec93ebcc8f5 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs @@ -1,7 +1,7 @@ use core::fmt::{self, Write as _}; -use super::{ColumnName, Function, JoinType, TableReference, TableSample}; -use crate::store::postgres::query::{Condition, SelectStatement, Transpile}; +use super::{ColumnName, Expression, Function, JoinType, TableReference, TableSample}; +use crate::store::postgres::query::{SelectStatement, Transpile}; /// A FROM item in a PostgreSQL query. /// @@ -171,7 +171,7 @@ pub enum FromItem<'id> { /// /// Multiple conditions support composite joins (e.g., multi-column foreign keys). /// When empty, transpiles to `ON TRUE`, producing a cartesian product. - condition: Vec, + condition: Vec, }, /// A JOIN using a USING clause with specified column names. @@ -320,7 +320,7 @@ impl<'id> FromItem<'id> { self, #[builder(start_fn)] r#type: JoinType, #[builder(start_fn, into)] from: Self, - #[builder(setters(vis = ""))] condition: Vec, + #[builder(setters(vis = ""))] condition: Vec, #[builder(setters(vis = ""))] join_using_alias: Option>, #[builder(setters(vis = ""))] columns: Vec>, ) -> Self { @@ -581,15 +581,12 @@ mod from_item_join_builder_impl { IsSet, IsUnset, SetColumns, SetCondition, SetJoinUsingAlias, State, }, }; - use crate::store::postgres::query::{ - Condition, - expression::{ColumnName, TableReference}, - }; + use crate::store::postgres::query::expression::{ColumnName, Expression, TableReference}; impl<'id, S: State> FromItemJoinBuilder<'id, S> { pub fn on( self, - conditions: Vec, + conditions: Vec, ) -> FromItemJoinBuilder<'id, SetCondition>> where S: State, @@ -958,7 +955,7 @@ mod tests { let base = FromItem::table(Table::DataTypes).build(); - let conditions = vec![Condition::Equal( + let conditions = vec![Expression::equal( Expression::ColumnReference(Column::DataTypes(DataTypes::OntologyId).into()), Expression::ColumnReference(Column::OntologyIds(OntologyIds::OntologyId).into()), )]; diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs index 402ccc95e05..737762e4bb2 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs @@ -13,7 +13,7 @@ mod with_clause; pub use self::{ column_reference::{ColumnName, ColumnReference}, - conditional::{Constant, Expression, Function, PostgresType}, + conditional::{Constant, EqualityOperator, Expression, Function, PostgresType}, from_item::FromItem, group_by_clause::GroupByExpression, join_type::JoinType, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs index f9b1bd4d48b..1ae7fa895d3 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs @@ -2,13 +2,11 @@ use core::fmt; use hash_graph_store::query::{NullOrdering, Ordering}; -use crate::store::postgres::query::{ - Condition, Expression, Transpile, expression::conditional::Transpiler, -}; +use crate::store::postgres::query::{Expression, Transpile, expression::conditional::Transpiler}; #[derive(Debug, Clone, Default, PartialEq)] pub struct WhereExpression { - pub conditions: Vec, + pub conditions: Vec, pub cursor: Vec<( Expression, Option, @@ -18,7 +16,7 @@ pub struct WhereExpression { } impl WhereExpression { - pub fn add_condition(&mut self, condition: Condition) { + pub fn add_condition(&mut self, condition: Expression) { self.conditions.push(condition); } diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs index a212ae82224..11ec1956bf8 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs @@ -3,7 +3,6 @@ //! Postgres implementation to compile queries. mod compile; -mod condition; mod data_type; mod entity; mod entity_type; @@ -31,9 +30,9 @@ use type_system::knowledge::{Entity, PropertyValue}; pub use self::{ compile::{SelectCompiler, SelectCompilerError}, - condition::{Condition, EqualityOperator}, expression::{ - Constant, Expression, Function, SelectExpression, WhereExpression, WithExpression, + Constant, EqualityOperator, Expression, Function, SelectExpression, WhereExpression, + WithExpression, }, statement::{ Distinctness, InsertStatementBuilder, SelectStatement, Statement, WindowStatement, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs index f6490aa6142..3f10b19f70c 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs @@ -1379,8 +1379,8 @@ mod tests { &compiler, r#" SELECT ("entity_editions_0_1_0"."properties" - (CASE WHEN - (($1 = ANY("entity_is_of_type_ids_0_1_0"."base_urls")) - AND ("entity_temporal_metadata_0_0_0"."entity_uuid" != $2)) + ($1 = ANY("entity_is_of_type_ids_0_1_0"."base_urls")) + AND ("entity_temporal_metadata_0_0_0"."entity_uuid" != $2) THEN ARRAY[$3]::text[] ELSE ARRAY[]::text[] END)) FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs index f3cc2fa7845..8e3b708eb2b 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs @@ -12,9 +12,7 @@ use hash_graph_temporal_versioning::TimeAxis; use postgres_types::ToSql; use super::expression::{ColumnName, ColumnReference, TableName, TableReference}; -use crate::store::postgres::query::{ - Condition, Constant, Expression, Transpile, expression::JoinType, -}; +use crate::store::postgres::query::{Constant, Expression, Transpile, expression::JoinType}; /// The name of a [`Table`] in the Postgres database. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -1927,13 +1925,13 @@ impl ForeignKeyReference { } } - pub fn conditions(self, on_alias: Alias, join_alias: Alias) -> Vec { + pub fn conditions(self, on_alias: Alias, join_alias: Alias) -> Vec { match self { Self::Single { join, on, join_type: _, - } => vec![Condition::Equal( + } => vec![Expression::equal( Expression::ColumnReference(join.aliased(join_alias)), Expression::ColumnReference(on.aliased(on_alias)), )], @@ -1942,11 +1940,11 @@ impl ForeignKeyReference { on: [on1, on2], join_type: _, } => vec![ - Condition::Equal( + Expression::equal( Expression::ColumnReference(join1.aliased(join_alias)), Expression::ColumnReference(on1.aliased(on_alias)), ), - Condition::Equal( + Expression::equal( Expression::ColumnReference(join2.aliased(join_alias)), Expression::ColumnReference(on2.aliased(on_alias)), ), @@ -2154,7 +2152,7 @@ impl Relation { } #[must_use] - pub fn additional_conditions(self, table: &TableReference<'_>) -> Vec { + pub fn additional_conditions(self, table: &TableReference<'_>) -> Vec { match self { Self::Reference { table: reference_table, @@ -2166,7 +2164,7 @@ impl Relation { column .inheritance_depth() .map_or_else(Vec::new, |inheritance_depth| { - vec![Condition::LessOrEqual( + vec![Expression::less_or_equal( Expression::ColumnReference( column.aliased(table.alias.unwrap_or_default()), ), From b2809c613facdb7866b4749ccb966cf0b2d06a5e Mon Sep 17 00:00:00 2001 From: Tim Diekmann Date: Mon, 23 Feb 2026 00:11:22 +0100 Subject: [PATCH 2/3] BE-415: Extract unary and binary operations into dedicated types Introduce `UnaryExpression`/`UnaryOperator` and `BinaryExpression`/`BinaryOperator` structs to replace the 14 individual condition variants on `Expression`. The enum shrinks from 26 to 15 variants while the external API (convenience constructors) remains unchanged. --- .../src/store/postgres/query/compile.rs | 8 +- .../store/postgres/query/expression/binary.rs | 80 ++++++++ .../postgres/query/expression/conditional.rs | 172 ++++++++---------- .../store/postgres/query/expression/mod.rs | 4 + .../store/postgres/query/expression/unary.rs | 46 +++++ 5 files changed, 208 insertions(+), 102 deletions(-) create mode 100644 libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs create mode 100644 libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs index 23fc9558bbc..1e5a35a1264 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs @@ -609,11 +609,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { }) .chain(once(SelectExpression::Expression { expression: Expression::Function(Function::Min( - Box::new(Expression::CosineDistance( - Box::new(Expression::ColumnReference( + Box::new(Expression::cosine_distance( + Expression::ColumnReference( embeddings_column.into(), - )), - Box::new(parameter_expression), + ), + parameter_expression, )), )), alias: Some("distance"), diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs new file mode 100644 index 00000000000..987b04ed164 --- /dev/null +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs @@ -0,0 +1,80 @@ +use core::fmt::{self, Write as _}; + +use crate::store::postgres::query::{Expression, Transpile}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[expect( + clippy::doc_paragraphs_missing_punctuation, + reason = "The documentation is only the transpiled symbols" +)] +pub enum BinaryOperator { + /// ` = ` + Equal, + /// ` != ` + NotEqual, + /// ` > ` + Greater, + /// ` >= ` + GreaterOrEqual, + /// ` < ` + Less, + /// ` <= ` + LessOrEqual, + /// ` = ANY()` + In, + /// ` @> ` + TimeIntervalContainsTimestamp, + /// ` && ::TIMESTAMPTZ` + Overlap, + /// ` <=> ` + CosineDistance, +} + +impl BinaryOperator { + fn transpile(self, fmt: &mut fmt::Formatter) -> fmt::Result { + let string = match self { + Self::Equal => " = ", + Self::NotEqual => " != ", + Self::Greater => " > ", + Self::GreaterOrEqual => " >= ", + Self::Less => " < ", + Self::LessOrEqual => " <= ", + Self::In => " = ANY(", + Self::TimeIntervalContainsTimestamp => " @> ", + Self::Overlap => " && ", + Self::CosineDistance => " <=> ", + }; + fmt.write_str(string) + } + + fn transpile_post(self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::In => fmt.write_char(')'), + Self::TimeIntervalContainsTimestamp => fmt.write_str("::TIMESTAMPTZ"), + Self::Equal + | Self::NotEqual + | Self::Greater + | Self::GreaterOrEqual + | Self::Less + | Self::LessOrEqual + | Self::Overlap + | Self::CosineDistance => Ok(()), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BinaryExpression { + pub op: BinaryOperator, + pub left: Box, + pub right: Box, +} + +impl Transpile for BinaryExpression { + fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.left.transpile(fmt)?; + self.op.transpile(fmt)?; + self.right.transpile(fmt)?; + self.op.transpile_post(fmt) + } +} diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs index 15dc7e2ed31..9b478c687be 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs @@ -5,7 +5,10 @@ use core::fmt::{ use hash_graph_store::filter::PathToken; use super::ColumnReference; -use crate::store::postgres::query::{SelectStatement, Table, Transpile, WindowStatement}; +use crate::store::postgres::query::{ + SelectStatement, Table, Transpile, WindowStatement, + expression::{BinaryExpression, BinaryOperator, UnaryExpression, UnaryOperator}, +}; #[derive(Debug, Clone, PartialEq)] pub enum Function { @@ -218,7 +221,6 @@ pub enum EqualityOperator { /// an "expression". This allows natural composition, e.g. negating any boolean expression. #[derive(Debug, Clone, PartialEq)] pub enum Expression { - // --- Value expressions --- ColumnReference(ColumnReference<'static>), /// A parameter are transpiled as a placeholder, e.g. `$1`, in order to prevent SQL injection. Parameter(usize), @@ -226,7 +228,6 @@ pub enum Expression { /// prevent SQL injection and no user input should ever be used as a [`Constant`]. Constant(Constant), Function(Function), - CosineDistance(Box, Box), Window(Box, WindowStatement), Cast(Box, PostgresType), /// Row expansion - expands a composite type into its constituent columns. @@ -252,27 +253,16 @@ pub enum Expression { else_result: Option>, }, - // --- Boolean conditions --- + Unary(UnaryExpression), + Binary(BinaryExpression), + /// Conjunction of conditions. Transpiles to `(c1) AND (c2) AND ...`. /// Empty list transpiles to `TRUE`. All(Vec), /// Disjunction of conditions. Transpiles to `((c1) OR (c2) OR ...)`. /// Empty list transpiles to `FALSE`. Any(Vec), - /// Negation. Transpiles to `NOT(expr)`. - /// Special case: `Not(Exists(expr))` transpiles to `expr IS NOT NULL`. - Not(Box), - Equal(Box, Box), - NotEqual(Box, Box), - /// Null check. Transpiles to `expr IS NULL`. - Exists(Box), - Less(Box, Box), - LessOrEqual(Box, Box), - Greater(Box, Box), - GreaterOrEqual(Box, Box), - In(Box, Box), - TimeIntervalContainsTimestamp(Box, Box), - Overlap(Box, Box), + StartsWith(Box, Box), EndsWith(Box, Box), ContainsSegment(Box, Box), @@ -292,57 +282,108 @@ impl Expression { #[must_use] pub fn not(inner: Self) -> Self { - Self::Not(Box::new(inner)) + Self::Unary(UnaryExpression { + op: UnaryOperator::Not, + expr: Box::new(inner), + }) } #[must_use] pub fn equal(lhs: Self, rhs: Self) -> Self { - Self::Equal(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::Equal, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn not_equal(lhs: Self, rhs: Self) -> Self { - Self::NotEqual(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::NotEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn exists(expr: Self) -> Self { - Self::Exists(Box::new(expr)) + Self::Unary(UnaryExpression { + op: UnaryOperator::IsNull, + expr: Box::new(expr), + }) } #[must_use] pub fn less(lhs: Self, rhs: Self) -> Self { - Self::Less(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::Less, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn less_or_equal(lhs: Self, rhs: Self) -> Self { - Self::LessOrEqual(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::LessOrEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn greater(lhs: Self, rhs: Self) -> Self { - Self::Greater(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::Greater, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn greater_or_equal(lhs: Self, rhs: Self) -> Self { - Self::GreaterOrEqual(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::GreaterOrEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn r#in(lhs: Self, rhs: Self) -> Self { - Self::In(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::In, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn time_interval_contains_timestamp(lhs: Self, rhs: Self) -> Self { - Self::TimeIntervalContainsTimestamp(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::TimeIntervalContainsTimestamp, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] pub fn overlap(lhs: Self, rhs: Self) -> Self { - Self::Overlap(Box::new(lhs), Box::new(rhs)) + Self::Binary(BinaryExpression { + op: BinaryOperator::Overlap, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn cosine_distance(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::CosineDistance, + left: Box::new(lhs), + right: Box::new(rhs), + }) } #[must_use] @@ -362,7 +403,6 @@ impl Expression { } impl Transpile for Expression { - #[expect(clippy::too_many_lines)] fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { // --- Value expressions --- @@ -370,11 +410,6 @@ impl Transpile for Expression { Self::Parameter(index) => write!(fmt, "${index}"), Self::Constant(constant) => constant.transpile(fmt), Self::Function(function) => function.transpile(fmt), - Self::CosineDistance(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" <=> ")?; - rhs.transpile(fmt) - } Self::Window(expression, window) => { expression.transpile(fmt)?; fmt.write_str(" OVER (")?; @@ -411,6 +446,9 @@ impl Transpile for Expression { fmt.write_str(" END") } + Self::Unary(unary) => unary.transpile(fmt), + Self::Binary(binary) => binary.transpile(fmt), + // --- Boolean conditions --- Self::All(conditions) if conditions.is_empty() => fmt.write_str("TRUE"), Self::Any(conditions) if conditions.is_empty() => fmt.write_str("FALSE"), @@ -442,67 +480,6 @@ impl Transpile for Expression { } Ok(()) } - Self::Not(inner) => { - if let Self::Exists(path) = &**inner { - path.transpile(fmt)?; - fmt.write_str(" IS NOT NULL") - } else { - fmt.write_str("NOT(")?; - inner.transpile(fmt)?; - fmt.write_char(')') - } - } - Self::Exists(path) => { - path.transpile(fmt)?; - fmt.write_str(" IS NULL") - } - Self::Equal(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ")?; - rhs.transpile(fmt) - } - Self::NotEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" != ")?; - rhs.transpile(fmt) - } - Self::Less(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" < ")?; - rhs.transpile(fmt) - } - Self::LessOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" <= ")?; - rhs.transpile(fmt) - } - Self::Greater(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" > ")?; - rhs.transpile(fmt) - } - Self::GreaterOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" >= ")?; - rhs.transpile(fmt) - } - Self::In(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ANY(")?; - rhs.transpile(fmt)?; - fmt.write_char(')') - } - Self::TimeIntervalContainsTimestamp(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" @> ")?; - rhs.transpile(fmt)?; - fmt.write_str("::TIMESTAMPTZ") - } - Self::Overlap(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" && ")?; - rhs.transpile(fmt) - } Self::StartsWith(lhs, rhs) => { fmt.write_str("starts_with(")?; lhs.transpile(fmt)?; @@ -553,8 +530,7 @@ mod tests { use super::*; use crate::store::postgres::query::{ - Alias, PostgresQueryPath as _, SelectCompiler, Transpile as _, - test_helper::max_version_expression, + Alias, PostgresQueryPath as _, SelectCompiler, test_helper::max_version_expression, }; #[test] diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs index 737762e4bb2..f24be237ff7 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs @@ -1,3 +1,4 @@ +mod binary; mod column_reference; mod conditional; mod from_item; @@ -8,10 +9,12 @@ mod order_clause; mod select_clause; mod table_reference; mod table_sample; +mod unary; mod where_clause; mod with_clause; pub use self::{ + binary::{BinaryExpression, BinaryOperator}, column_reference::{ColumnName, ColumnReference}, conditional::{Constant, EqualityOperator, Expression, Function, PostgresType}, from_item::FromItem, @@ -21,6 +24,7 @@ pub use self::{ select_clause::SelectExpression, table_reference::{TableName, TableReference}, table_sample::TableSample, + unary::{UnaryExpression, UnaryOperator}, where_clause::WhereExpression, with_clause::WithExpression, }; diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs new file mode 100644 index 00000000000..8dbbee87d52 --- /dev/null +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs @@ -0,0 +1,46 @@ +use core::fmt::{self, Write as _}; + +use crate::store::postgres::query::{Expression, Transpile}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[expect( + clippy::doc_paragraphs_missing_punctuation, + reason = "The documentation is only the transpiled symbols" +)] +pub enum UnaryOperator { + /// `NOT()` + Not, + /// ` IS NULL` + IsNull, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnaryExpression { + pub op: UnaryOperator, + pub expr: Box, +} + +impl Transpile for UnaryExpression { + fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self.op { + UnaryOperator::Not => { + if let Expression::Unary(Self { + op: UnaryOperator::IsNull, + expr, + }) = &*self.expr + { + expr.transpile(fmt)?; + fmt.write_str(" IS NOT NULL") + } else { + fmt.write_str("NOT(")?; + self.expr.transpile(fmt)?; + fmt.write_char(')') + } + } + UnaryOperator::IsNull => { + self.expr.transpile(fmt)?; + fmt.write_str(" IS NULL") + } + } + } +} From 030ccf84a6b6b84c347a13cfe9f5eb94c0dedd0d Mon Sep 17 00:00:00 2001 From: Tim Diekmann Date: Mon, 23 Feb 2026 02:48:38 +0100 Subject: [PATCH 3/3] BE-415: Fix doc comments for TimeIntervalContainsTimestamp and Overlap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ::TIMESTAMPTZ suffix was on the wrong operator — it belongs to TimeIntervalContainsTimestamp (which appends it to the RHS via transpile_post), not to Overlap (which is plain &&). --- .../src/store/postgres/query/expression/binary.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs index 987b04ed164..8ad1b074325 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs @@ -22,9 +22,9 @@ pub enum BinaryOperator { LessOrEqual, /// ` = ANY()` In, - /// ` @> ` + /// ` @> ::TIMESTAMPTZ` TimeIntervalContainsTimestamp, - /// ` && ::TIMESTAMPTZ` + /// ` && ` Overlap, /// ` <=> ` CosineDistance,