diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs index f1b9d347081..1e5a35a1264 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/compile.rs @@ -18,9 +18,8 @@ use type_system::knowledge::Entity; use super::expression::{JoinType, TableName, TableReference}; use crate::store::postgres::query::{ - Alias, Column, Condition, Distinctness, EqualityOperator, Expression, Function, - PostgresQueryPath, PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _, - WindowStatement, + Alias, Column, Distinctness, EqualityOperator, Expression, Function, PostgresQueryPath, + PostgresRecord, SelectExpression, SelectStatement, Table, Transpile as _, WindowStatement, expression::{FromItem, GroupByExpression, PostgresType}, table::{ DataTypeEmbeddings, DatabaseColumn as _, EntityEditions, EntityEmbeddings, @@ -60,7 +59,7 @@ struct PathSelection { ordering: Option<(Ordering, Option)>, } -type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec; +type TableHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Alias) -> Vec; type ColumnHook<'p, 'q, T> = fn(&mut SelectCompiler<'p, 'q, T>, Expression) -> Expression; pub struct SelectCompiler<'p, 'q: 'p, T: QueryRecord> { @@ -201,7 +200,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } } - fn ontology_table_conditions(&mut self, alias: Alias) -> Vec { + fn ontology_table_conditions(&mut self, alias: Alias) -> Vec { let table = Table::OntologyTemporalMetadata.aliased(alias); if let Some(temporal_axes) = self.temporal_axes && self.artifacts.table_info.tables.insert(table) @@ -209,7 +208,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { let transaction_time_index = self.time_index(temporal_axes, TimeAxis::TransactionTime); match temporal_axes { QueryTemporalAxes::DecisionTime { .. } => { - vec![Condition::TimeIntervalContainsTimestamp( + vec![Expression::time_interval_contains_timestamp( Expression::ColumnReference( Column::OntologyTemporalMetadata( OntologyTemporalMetadata::TransactionTime, @@ -220,7 +219,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { )] } QueryTemporalAxes::TransactionTime { .. } => { - vec![Condition::Overlap( + vec![Expression::overlap( Expression::ColumnReference( Column::OntologyTemporalMetadata( OntologyTemporalMetadata::TransactionTime, @@ -236,12 +235,12 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } } - fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec { + fn temporal_metadata_conditions(&mut self, alias: Alias) -> Vec { let mut conditions = Vec::new(); let table = Table::EntityTemporalMetadata.aliased(alias); if self.artifacts.table_info.tables.insert(table.clone()) { if !self.include_drafts { - conditions.push(Condition::Exists(Expression::ColumnReference( + conditions.push(Expression::exists(Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::DraftId).aliased(alias), ))); } @@ -255,7 +254,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { // Adds the pinned timestamp condition, so for the projected decision time, we use // the transaction time and vice versa. conditions.extend([ - Condition::TimeIntervalContainsTimestamp( + Expression::time_interval_contains_timestamp( Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis( pinned_axis, @@ -264,7 +263,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { ), Expression::Parameter(pinned_time_index), ), - Condition::Overlap( + Expression::overlap( Expression::ColumnReference( Column::EntityTemporalMetadata(EntityTemporalMetadata::from_time_axis( variable_axis, @@ -410,7 +409,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { ) } - /// Compiles a [`Filter`] to a `Condition`. + /// Compiles a [`Filter`] to an [`Expression`]. /// /// # Errors /// @@ -420,7 +419,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { pub fn compile_filter<'f: 'q>( &mut self, filter: &'p Filter<'f, R>, - ) -> Result> + ) -> Result> where R::QueryPath<'f>: PostgresQueryPath, { @@ -429,41 +428,41 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { } Ok(match filter { - Filter::All(filters) => Condition::All( + Filter::All(filters) => Expression::all( filters .iter() .map(|filter| self.compile_filter(filter)) .collect::>()?, ), - Filter::Any(filters) => Condition::Any( + Filter::Any(filters) => Expression::any( filters .iter() .map(|filter| self.compile_filter(filter)) .collect::>()?, ), - Filter::Not(filter) => Condition::Not(Box::new(self.compile_filter(filter)?)), - Filter::Equal(lhs, rhs) => Condition::Equal( + Filter::Not(filter) => Expression::not(self.compile_filter(filter)?), + Filter::Equal(lhs, rhs) => Expression::equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::NotEqual(lhs, rhs) => Condition::NotEqual( + Filter::NotEqual(lhs, rhs) => Expression::not_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::Exists { path } => Condition::Exists(self.compile_path_column(path)), - Filter::Greater(lhs, rhs) => Condition::Greater( + Filter::Exists { path } => Expression::exists(self.compile_path_column(path)), + Filter::Greater(lhs, rhs) => Expression::greater( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::GreaterOrEqual(lhs, rhs) => Condition::GreaterOrEqual( + Filter::GreaterOrEqual(lhs, rhs) => Expression::greater_or_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::Less(lhs, rhs) => Condition::Less( + Filter::Less(lhs, rhs) => Expression::less( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), - Filter::LessOrEqual(lhs, rhs) => Condition::LessOrEqual( + Filter::LessOrEqual(lhs, rhs) => Expression::less_or_equal( self.compile_filter_expression(lhs).0, self.compile_filter_expression(rhs).0, ), @@ -610,11 +609,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { }) .chain(once(SelectExpression::Expression { expression: Expression::Function(Function::Min( - Box::new(Expression::CosineDistance( - Box::new(Expression::ColumnReference( + Box::new(Expression::cosine_distance( + Expression::ColumnReference( embeddings_column.into(), - )), - Box::new(parameter_expression), + ), + parameter_expression, )), )), alias: Some("distance"), @@ -643,11 +642,11 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { alias: None, }); self.statement.distinct.push(distance_expression.clone()); - Condition::LessOrEqual(distance_expression, maximum_expression) + Expression::less_or_equal(distance_expression, maximum_expression) } _ => bail!(SelectCompilerError::UnsupportedDistanceExpression), }, - Filter::In(lhs, rhs) => Condition::In( + Filter::In(lhs, rhs) => Expression::r#in( self.compile_filter_expression(lhs).0, self.compile_filter_expression_list(rhs).0, ), @@ -666,7 +665,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::StartsWith(left_filter, right_filter) + Expression::starts_with(left_filter, right_filter) } Filter::EndsWith(lhs, rhs) => { let (left_filter, left_parameter) = self.compile_filter_expression(lhs); @@ -683,7 +682,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::EndsWith(left_filter, right_filter) + Expression::ends_with(left_filter, right_filter) } Filter::ContainsSegment(lhs, rhs) => { let (left_filter, left_parameter) = self.compile_filter_expression(lhs); @@ -700,7 +699,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { right_filter }; - Condition::ContainsSegment(left_filter, right_filter) + Expression::contains_segment(left_filter, right_filter) } }) } @@ -716,7 +715,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { &mut self, path: &R::QueryPath<'f>, operator: EqualityOperator, - ) -> Condition + ) -> Expression where R::QueryPath<'f>: PostgresQueryPath, { @@ -765,10 +764,10 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { match operator { EqualityOperator::Equal => { - Condition::Equal(version_expression, latest_version_expression) + Expression::equal(version_expression, latest_version_expression) } EqualityOperator::NotEqual => { - Condition::NotEqual(version_expression, latest_version_expression) + Expression::not_equal(version_expression, latest_version_expression) } } } @@ -778,7 +777,7 @@ impl<'p, 'q: 'p, R: PostgresRecord> SelectCompiler<'p, 'q, R> { /// /// The following [`Filter`]s will be special cased: /// - Comparing the `"version"` field on [`Table::OntologyIds`] with `"latest"` for equality. - fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option + fn compile_special_filter<'f: 'q>(&mut self, filter: &'p Filter<'f, R>) -> Option where R::QueryPath<'f>: PostgresQueryPath, { @@ -1232,7 +1231,7 @@ impl<'p, 'q: 'p> SelectCompiler<'p, 'q, Entity> { Expression::CaseWhen { conditions: vec![( - Expression::Condition(Box::new(condition)), + condition, Expression::Function(Function::ArrayLiteral { elements: vec![compiler.compile_parameter(property_url).0], element_type: PostgresType::Text, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs deleted file mode 100644 index c779184e0ab..00000000000 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/condition.rs +++ /dev/null @@ -1,403 +0,0 @@ -use core::{fmt, fmt::Write as _}; - -use crate::store::postgres::query::{Expression, Transpile}; - -/// A [`Filter`], which can be transpiled. -/// -/// [`Filter`]: hash_graph_store::filter::Filter -#[derive(Debug, Clone, PartialEq)] -pub enum Condition { - All(Vec), - Any(Vec), - Not(Box), - Equal(Expression, Expression), - NotEqual(Expression, Expression), - Exists(Expression), - Less(Expression, Expression), - LessOrEqual(Expression, Expression), - Greater(Expression, Expression), - GreaterOrEqual(Expression, Expression), - In(Expression, Expression), - TimeIntervalContainsTimestamp(Expression, Expression), - Overlap(Expression, Expression), - StartsWith(Expression, Expression), - EndsWith(Expression, Expression), - ContainsSegment(Expression, Expression), -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum EqualityOperator { - Equal, - NotEqual, -} - -impl Transpile for Condition { - #[expect(clippy::too_many_lines)] - fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::All(conditions) if conditions.is_empty() => fmt.write_str("TRUE"), - Self::Any(conditions) if conditions.is_empty() => fmt.write_str("FALSE"), - Self::All(conditions) => { - for (idx, condition) in conditions.iter().enumerate() { - if idx > 0 { - fmt.write_str(" AND ")?; - } - fmt.write_char('(')?; - condition.transpile(fmt)?; - fmt.write_char(')')?; - } - Ok(()) - } - Self::Any(conditions) => { - if conditions.len() > 1 { - fmt.write_char('(')?; - } - for (idx, condition) in conditions.iter().enumerate() { - if idx > 0 { - fmt.write_str(" OR ")?; - } - fmt.write_char('(')?; - condition.transpile(fmt)?; - fmt.write_char(')')?; - } - if conditions.len() > 1 { - fmt.write_char(')')?; - } - Ok(()) - } - Self::Not(condition) => { - if let Self::Exists(path) = &**condition { - path.transpile(fmt)?; - fmt.write_str(" IS NOT NULL") - } else { - fmt.write_str("NOT(")?; - condition.transpile(fmt)?; - fmt.write_char(')') - } - } - Self::Exists(path) => { - path.transpile(fmt)?; - fmt.write_str(" IS NULL") - } - Self::Equal(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ")?; - rhs.transpile(fmt) - } - Self::NotEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" != ")?; - rhs.transpile(fmt) - } - Self::Less(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" < ")?; - rhs.transpile(fmt) - } - Self::LessOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" <= ")?; - rhs.transpile(fmt) - } - Self::Greater(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" > ")?; - rhs.transpile(fmt) - } - Self::GreaterOrEqual(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" >= ")?; - rhs.transpile(fmt) - } - Self::In(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" = ANY(")?; - rhs.transpile(fmt)?; - fmt.write_char(')') - } - Self::TimeIntervalContainsTimestamp(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" @> ")?; - rhs.transpile(fmt)?; - fmt.write_str("::TIMESTAMPTZ") - } - Self::Overlap(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" && ")?; - rhs.transpile(fmt) - } - Self::StartsWith(lhs, rhs) => { - fmt.write_str("starts_with(")?; - lhs.transpile(fmt)?; - fmt.write_str(", ")?; - rhs.transpile(fmt)?; - fmt.write_char(')') - } - Self::EndsWith(lhs, rhs) => { - fmt.write_str("right(")?; - lhs.transpile(fmt)?; - fmt.write_str(", length(")?; - rhs.transpile(fmt)?; - fmt.write_str(")) = ")?; - rhs.transpile(fmt) - } - Self::ContainsSegment(lhs, rhs) => { - fmt.write_str("strpos(")?; - lhs.transpile(fmt)?; - fmt.write_str(", ")?; - rhs.transpile(fmt)?; - fmt.write_str(") > 0") - } - } - } -} - -#[cfg(test)] -mod tests { - use alloc::borrow::Cow; - - use hash_codec::numeric::Real; - use hash_graph_store::{ - data_type::DataTypeQueryPath, - filter::{Filter, FilterExpression, Parameter}, - }; - use postgres_types::ToSql; - use type_system::ontology::DataTypeWithMetadata; - - use crate::store::postgres::query::{SelectCompiler, Transpile as _}; - - fn test_condition<'p, 'f: 'p>( - filter: &'f Filter<'p, DataTypeWithMetadata>, - rendered: &'static str, - parameters: &[&'p dyn ToSql], - ) { - let mut compiler = SelectCompiler::new(None, false); - let condition = compiler - .compile_filter(filter) - .expect("failed to compile filter"); - - assert_eq!(condition.transpile_to_string(), rendered); - - let parameter_list = parameters - .iter() - .map(|parameter| format!("{parameter:?}")) - .collect::>(); - let expected_parameters = compiler - .compile() - .1 - .iter() - .map(|parameter| format!("{parameter:?}")) - .collect::>(); - - assert_eq!(parameter_list, expected_parameters); - } - - #[test] - fn transpile_empty_condition() { - test_condition(&Filter::All(vec![]), "TRUE", &[]); - test_condition(&Filter::Any(vec![]), "FALSE", &[]); - } - - #[test] - fn transpile_exists_condition() { - test_condition( - &Filter::Exists { - path: DataTypeQueryPath::Description, - }, - r#""data_types_0_1_0"."schema"->>'description' IS NULL"#, - &[], - ); - - test_condition( - &Filter::Not(Box::new(Filter::Exists { - path: DataTypeQueryPath::Description, - })), - r#""data_types_0_1_0"."schema"->>'description' IS NOT NULL"#, - &[], - ); - } - - #[test] - fn transpile_all_condition() { - test_condition( - &Filter::All(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - - test_condition( - &Filter::All(vec![ - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::BaseUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/", - )), - convert: None, - }, - ), - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Version, - }, - FilterExpression::Parameter { - parameter: Parameter::Decimal(Real::from_natural(1, 1)), - convert: None, - }, - ), - ]), - r#"("ontology_ids_0_1_0"."base_url" = $1) AND ("ontology_ids_0_1_0"."version" = $2)"#, - &[ - &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", - &Real::from_natural(1, 1), - ], - ); - } - - #[test] - fn transpile_any_condition() { - test_condition( - &Filter::Any(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - - test_condition( - &Filter::Any(vec![ - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::BaseUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/", - )), - convert: None, - }, - ), - Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Version, - }, - FilterExpression::Parameter { - parameter: Parameter::Decimal(Real::from_natural(1, 1)), - convert: None, - }, - ), - ]), - r#"(("ontology_ids_0_1_0"."base_url" = $1) OR ("ontology_ids_0_1_0"."version" = $2))"#, - &[ - &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", - &Real::from_natural(1, 1), - ], - ); - } - - #[test] - fn transpile_not_condition() { - test_condition( - &Filter::Not(Box::new(Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::VersionedUrl, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed( - "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", - )), - convert: None, - }, - ))), - r#"NOT("data_types_0_1_0"."schema"->>'$id' = $1)"#, - &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], - ); - } - - #[test] - fn transpile_starts_with_condition() { - test_condition( - &Filter::StartsWith( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("foo")), - convert: None, - }, - ), - r#"starts_with("data_types_0_1_0"."schema"->>'title', $1)"#, - &[&"foo"], - ); - } - - #[test] - fn transpile_ends_with_condition() { - test_condition( - &Filter::EndsWith( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("bar")), - convert: None, - }, - ), - r#"right("data_types_0_1_0"."schema"->>'title', length($1)) = $1"#, - &[&"bar"], - ); - } - - #[test] - fn transpile_contains_segment_condition() { - test_condition( - &Filter::ContainsSegment( - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - FilterExpression::Parameter { - parameter: Parameter::Text(Cow::Borrowed("baz")), - convert: None, - }, - ), - r#"strpos("data_types_0_1_0"."schema"->>'title', $1) > 0"#, - &[&"baz"], - ); - } - - #[test] - fn render_without_parameters() { - test_condition( - &Filter::Any(vec![Filter::Equal( - FilterExpression::Path { - path: DataTypeQueryPath::Description, - }, - FilterExpression::Path { - path: DataTypeQueryPath::Title, - }, - )]), - r#"("data_types_0_1_0"."schema"->>'description' = "data_types_0_1_0"."schema"->>'title')"#, - &[], - ); - } -} diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs new file mode 100644 index 00000000000..8ad1b074325 --- /dev/null +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/binary.rs @@ -0,0 +1,80 @@ +use core::fmt::{self, Write as _}; + +use crate::store::postgres::query::{Expression, Transpile}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[expect( + clippy::doc_paragraphs_missing_punctuation, + reason = "The documentation is only the transpiled symbols" +)] +pub enum BinaryOperator { + /// ` = ` + Equal, + /// ` != ` + NotEqual, + /// ` > ` + Greater, + /// ` >= ` + GreaterOrEqual, + /// ` < ` + Less, + /// ` <= ` + LessOrEqual, + /// ` = ANY()` + In, + /// ` @> ::TIMESTAMPTZ` + TimeIntervalContainsTimestamp, + /// ` && ` + Overlap, + /// ` <=> ` + CosineDistance, +} + +impl BinaryOperator { + fn transpile(self, fmt: &mut fmt::Formatter) -> fmt::Result { + let string = match self { + Self::Equal => " = ", + Self::NotEqual => " != ", + Self::Greater => " > ", + Self::GreaterOrEqual => " >= ", + Self::Less => " < ", + Self::LessOrEqual => " <= ", + Self::In => " = ANY(", + Self::TimeIntervalContainsTimestamp => " @> ", + Self::Overlap => " && ", + Self::CosineDistance => " <=> ", + }; + fmt.write_str(string) + } + + fn transpile_post(self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::In => fmt.write_char(')'), + Self::TimeIntervalContainsTimestamp => fmt.write_str("::TIMESTAMPTZ"), + Self::Equal + | Self::NotEqual + | Self::Greater + | Self::GreaterOrEqual + | Self::Less + | Self::LessOrEqual + | Self::Overlap + | Self::CosineDistance => Ok(()), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BinaryExpression { + pub op: BinaryOperator, + pub left: Box, + pub right: Box, +} + +impl Transpile for BinaryExpression { + fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.left.transpile(fmt)?; + self.op.transpile(fmt)?; + self.right.transpile(fmt)?; + self.op.transpile_post(fmt) + } +} diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs index 4ea98e332ad..9b478c687be 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/conditional.rs @@ -6,7 +6,8 @@ use hash_graph_store::filter::PathToken; use super::ColumnReference; use crate::store::postgres::query::{ - Condition, SelectStatement, Table, Transpile, WindowStatement, + SelectStatement, Table, Transpile, WindowStatement, + expression::{BinaryExpression, BinaryOperator, UnaryExpression, UnaryOperator}, }; #[derive(Debug, Clone, PartialEq)] @@ -207,7 +208,17 @@ impl Transpile for PostgresType { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum EqualityOperator { + Equal, + NotEqual, +} + /// A compiled expression in Postgres. +/// +/// This type unifies both value expressions and boolean conditions. In SQL, conditions are +/// boolean-valued expressions — there is no fundamental distinction between a "condition" and +/// an "expression". This allows natural composition, e.g. negating any boolean expression. #[derive(Debug, Clone, PartialEq)] pub enum Expression { ColumnReference(ColumnReference<'static>), @@ -217,7 +228,6 @@ pub enum Expression { /// prevent SQL injection and no user input should ever be used as a [`Constant`]. Constant(Constant), Function(Function), - CosineDistance(Box, Box), Window(Box, WindowStatement), Cast(Box, PostgresType), /// Row expansion - expands a composite type into its constituent columns. @@ -242,30 +252,164 @@ pub enum Expression { /// Optional else result if no condition matches. else_result: Option>, }, - /// Wraps a [`Condition`] for use in expression contexts. - /// - /// This allows conditions (which evaluate to boolean) to be used where expressions - /// are expected, such as in CASE WHEN conditions. - /// - /// # Example SQL - /// ```sql - /// CASE WHEN (a = b AND c != d) THEN 'yes' ELSE 'no' END - /// ``` - Condition(Box), + + Unary(UnaryExpression), + Binary(BinaryExpression), + + /// Conjunction of conditions. Transpiles to `(c1) AND (c2) AND ...`. + /// Empty list transpiles to `TRUE`. + All(Vec), + /// Disjunction of conditions. Transpiles to `((c1) OR (c2) OR ...)`. + /// Empty list transpiles to `FALSE`. + Any(Vec), + + StartsWith(Box, Box), + EndsWith(Box, Box), + ContainsSegment(Box, Box), +} + +/// Convenience constructors for condition variants to avoid `Box::new()` boilerplate. +impl Expression { + #[must_use] + pub const fn all(conditions: Vec) -> Self { + Self::All(conditions) + } + + #[must_use] + pub const fn any(conditions: Vec) -> Self { + Self::Any(conditions) + } + + #[must_use] + pub fn not(inner: Self) -> Self { + Self::Unary(UnaryExpression { + op: UnaryOperator::Not, + expr: Box::new(inner), + }) + } + + #[must_use] + pub fn equal(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::Equal, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn not_equal(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::NotEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn exists(expr: Self) -> Self { + Self::Unary(UnaryExpression { + op: UnaryOperator::IsNull, + expr: Box::new(expr), + }) + } + + #[must_use] + pub fn less(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::Less, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn less_or_equal(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::LessOrEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn greater(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::Greater, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn greater_or_equal(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::GreaterOrEqual, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn r#in(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::In, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn time_interval_contains_timestamp(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::TimeIntervalContainsTimestamp, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn overlap(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::Overlap, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn cosine_distance(lhs: Self, rhs: Self) -> Self { + Self::Binary(BinaryExpression { + op: BinaryOperator::CosineDistance, + left: Box::new(lhs), + right: Box::new(rhs), + }) + } + + #[must_use] + pub fn starts_with(lhs: Self, rhs: Self) -> Self { + Self::StartsWith(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn ends_with(lhs: Self, rhs: Self) -> Self { + Self::EndsWith(Box::new(lhs), Box::new(rhs)) + } + + #[must_use] + pub fn contains_segment(lhs: Self, rhs: Self) -> Self { + Self::ContainsSegment(Box::new(lhs), Box::new(rhs)) + } } impl Transpile for Expression { fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { + // --- Value expressions --- Self::ColumnReference(column) => column.transpile(fmt), Self::Parameter(index) => write!(fmt, "${index}"), Self::Constant(constant) => constant.transpile(fmt), Self::Function(function) => function.transpile(fmt), - Self::CosineDistance(lhs, rhs) => { - lhs.transpile(fmt)?; - fmt.write_str(" <=> ")?; - rhs.transpile(fmt) - } Self::Window(expression, window) => { expression.transpile(fmt)?; fmt.write_str(" OVER (")?; @@ -301,11 +445,63 @@ impl Transpile for Expression { } fmt.write_str(" END") } - Self::Condition(condition) => { - fmt.write_char('(')?; - condition.transpile(fmt)?; + + Self::Unary(unary) => unary.transpile(fmt), + Self::Binary(binary) => binary.transpile(fmt), + + // --- Boolean conditions --- + Self::All(conditions) if conditions.is_empty() => fmt.write_str("TRUE"), + Self::Any(conditions) if conditions.is_empty() => fmt.write_str("FALSE"), + Self::All(conditions) => { + for (idx, condition) in conditions.iter().enumerate() { + if idx > 0 { + fmt.write_str(" AND ")?; + } + fmt.write_char('(')?; + condition.transpile(fmt)?; + fmt.write_char(')')?; + } + Ok(()) + } + Self::Any(conditions) => { + if conditions.len() > 1 { + fmt.write_char('(')?; + } + for (idx, condition) in conditions.iter().enumerate() { + if idx > 0 { + fmt.write_str(" OR ")?; + } + fmt.write_char('(')?; + condition.transpile(fmt)?; + fmt.write_char(')')?; + } + if conditions.len() > 1 { + fmt.write_char(')')?; + } + Ok(()) + } + Self::StartsWith(lhs, rhs) => { + fmt.write_str("starts_with(")?; + lhs.transpile(fmt)?; + fmt.write_str(", ")?; + rhs.transpile(fmt)?; fmt.write_char(')') } + Self::EndsWith(lhs, rhs) => { + fmt.write_str("right(")?; + lhs.transpile(fmt)?; + fmt.write_str(", length(")?; + rhs.transpile(fmt)?; + fmt.write_str(")) = ")?; + rhs.transpile(fmt) + } + Self::ContainsSegment(lhs, rhs) => { + fmt.write_str("strpos(")?; + lhs.transpile(fmt)?; + fmt.write_str(", ")?; + rhs.transpile(fmt)?; + fmt.write_str(") > 0") + } } } } @@ -322,11 +518,19 @@ where #[cfg(test)] mod tests { - use hash_graph_store::data_type::DataTypeQueryPath; + use alloc::borrow::Cow; + + use hash_codec::numeric::Real; + use hash_graph_store::{ + data_type::DataTypeQueryPath, + filter::{Filter, FilterExpression, Parameter}, + }; + use postgres_types::ToSql; + use type_system::ontology::DataTypeWithMetadata; use super::*; use crate::store::postgres::query::{ - Alias, PostgresQueryPath as _, test_helper::max_version_expression, + Alias, PostgresQueryPath as _, SelectCompiler, test_helper::max_version_expression, }; #[test] @@ -435,4 +639,239 @@ mod tests { }); assert_eq!(empty_array.transpile_to_string(), "ARRAY[]::text[]"); } + + fn test_condition<'p, 'f: 'p>( + filter: &'f Filter<'p, DataTypeWithMetadata>, + rendered: &'static str, + parameters: &[&'p dyn ToSql], + ) { + let mut compiler = SelectCompiler::new(None, false); + let condition = compiler + .compile_filter(filter) + .expect("failed to compile filter"); + + assert_eq!(condition.transpile_to_string(), rendered); + + let parameter_list = parameters + .iter() + .map(|parameter| format!("{parameter:?}")) + .collect::>(); + let expected_parameters = compiler + .compile() + .1 + .iter() + .map(|parameter| format!("{parameter:?}")) + .collect::>(); + + assert_eq!(parameter_list, expected_parameters); + } + + #[test] + fn transpile_empty_condition() { + test_condition(&Filter::All(vec![]), "TRUE", &[]); + test_condition(&Filter::Any(vec![]), "FALSE", &[]); + } + + #[test] + fn transpile_exists_condition() { + test_condition( + &Filter::Exists { + path: DataTypeQueryPath::Description, + }, + r#""data_types_0_1_0"."schema"->>'description' IS NULL"#, + &[], + ); + + test_condition( + &Filter::Not(Box::new(Filter::Exists { + path: DataTypeQueryPath::Description, + })), + r#""data_types_0_1_0"."schema"->>'description' IS NOT NULL"#, + &[], + ); + } + + #[test] + fn transpile_all_condition() { + test_condition( + &Filter::All(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + + test_condition( + &Filter::All(vec![ + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::BaseUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/", + )), + convert: None, + }, + ), + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Version, + }, + FilterExpression::Parameter { + parameter: Parameter::Decimal(Real::from_natural(1, 1)), + convert: None, + }, + ), + ]), + r#"("ontology_ids_0_1_0"."base_url" = $1) AND ("ontology_ids_0_1_0"."version" = $2)"#, + &[ + &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", + &Real::from_natural(1, 1), + ], + ); + } + + #[test] + fn transpile_any_condition() { + test_condition( + &Filter::Any(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + + test_condition( + &Filter::Any(vec![ + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::BaseUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/", + )), + convert: None, + }, + ), + Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Version, + }, + FilterExpression::Parameter { + parameter: Parameter::Decimal(Real::from_natural(1, 1)), + convert: None, + }, + ), + ]), + r#"(("ontology_ids_0_1_0"."base_url" = $1) OR ("ontology_ids_0_1_0"."version" = $2))"#, + &[ + &"https://blockprotocol.org/@blockprotocol/types/data-type/text/", + &Real::from_natural(1, 1), + ], + ); + } + + #[test] + fn transpile_not_condition() { + test_condition( + &Filter::Not(Box::new(Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::VersionedUrl, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed( + "https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1", + )), + convert: None, + }, + ))), + r#"NOT("data_types_0_1_0"."schema"->>'$id' = $1)"#, + &[&"https://blockprotocol.org/@blockprotocol/types/data-type/text/v/1"], + ); + } + + #[test] + fn transpile_starts_with_condition() { + test_condition( + &Filter::StartsWith( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("foo")), + convert: None, + }, + ), + r#"starts_with("data_types_0_1_0"."schema"->>'title', $1)"#, + &[&"foo"], + ); + } + + #[test] + fn transpile_ends_with_condition() { + test_condition( + &Filter::EndsWith( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("bar")), + convert: None, + }, + ), + r#"right("data_types_0_1_0"."schema"->>'title', length($1)) = $1"#, + &[&"bar"], + ); + } + + #[test] + fn transpile_contains_segment_condition() { + test_condition( + &Filter::ContainsSegment( + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + FilterExpression::Parameter { + parameter: Parameter::Text(Cow::Borrowed("baz")), + convert: None, + }, + ), + r#"strpos("data_types_0_1_0"."schema"->>'title', $1) > 0"#, + &[&"baz"], + ); + } + + #[test] + fn render_without_parameters() { + test_condition( + &Filter::Any(vec![Filter::Equal( + FilterExpression::Path { + path: DataTypeQueryPath::Description, + }, + FilterExpression::Path { + path: DataTypeQueryPath::Title, + }, + )]), + r#"("data_types_0_1_0"."schema"->>'description' = "data_types_0_1_0"."schema"->>'title')"#, + &[], + ); + } } diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs index 69a7038f3f9..ec93ebcc8f5 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/from_item.rs @@ -1,7 +1,7 @@ use core::fmt::{self, Write as _}; -use super::{ColumnName, Function, JoinType, TableReference, TableSample}; -use crate::store::postgres::query::{Condition, SelectStatement, Transpile}; +use super::{ColumnName, Expression, Function, JoinType, TableReference, TableSample}; +use crate::store::postgres::query::{SelectStatement, Transpile}; /// A FROM item in a PostgreSQL query. /// @@ -171,7 +171,7 @@ pub enum FromItem<'id> { /// /// Multiple conditions support composite joins (e.g., multi-column foreign keys). /// When empty, transpiles to `ON TRUE`, producing a cartesian product. - condition: Vec, + condition: Vec, }, /// A JOIN using a USING clause with specified column names. @@ -320,7 +320,7 @@ impl<'id> FromItem<'id> { self, #[builder(start_fn)] r#type: JoinType, #[builder(start_fn, into)] from: Self, - #[builder(setters(vis = ""))] condition: Vec, + #[builder(setters(vis = ""))] condition: Vec, #[builder(setters(vis = ""))] join_using_alias: Option>, #[builder(setters(vis = ""))] columns: Vec>, ) -> Self { @@ -581,15 +581,12 @@ mod from_item_join_builder_impl { IsSet, IsUnset, SetColumns, SetCondition, SetJoinUsingAlias, State, }, }; - use crate::store::postgres::query::{ - Condition, - expression::{ColumnName, TableReference}, - }; + use crate::store::postgres::query::expression::{ColumnName, Expression, TableReference}; impl<'id, S: State> FromItemJoinBuilder<'id, S> { pub fn on( self, - conditions: Vec, + conditions: Vec, ) -> FromItemJoinBuilder<'id, SetCondition>> where S: State, @@ -958,7 +955,7 @@ mod tests { let base = FromItem::table(Table::DataTypes).build(); - let conditions = vec![Condition::Equal( + let conditions = vec![Expression::equal( Expression::ColumnReference(Column::DataTypes(DataTypes::OntologyId).into()), Expression::ColumnReference(Column::OntologyIds(OntologyIds::OntologyId).into()), )]; diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs index 402ccc95e05..f24be237ff7 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/mod.rs @@ -1,3 +1,4 @@ +mod binary; mod column_reference; mod conditional; mod from_item; @@ -8,12 +9,14 @@ mod order_clause; mod select_clause; mod table_reference; mod table_sample; +mod unary; mod where_clause; mod with_clause; pub use self::{ + binary::{BinaryExpression, BinaryOperator}, column_reference::{ColumnName, ColumnReference}, - conditional::{Constant, Expression, Function, PostgresType}, + conditional::{Constant, EqualityOperator, Expression, Function, PostgresType}, from_item::FromItem, group_by_clause::GroupByExpression, join_type::JoinType, @@ -21,6 +24,7 @@ pub use self::{ select_clause::SelectExpression, table_reference::{TableName, TableReference}, table_sample::TableSample, + unary::{UnaryExpression, UnaryOperator}, where_clause::WhereExpression, with_clause::WithExpression, }; diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs new file mode 100644 index 00000000000..8dbbee87d52 --- /dev/null +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/unary.rs @@ -0,0 +1,46 @@ +use core::fmt::{self, Write as _}; + +use crate::store::postgres::query::{Expression, Transpile}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[expect( + clippy::doc_paragraphs_missing_punctuation, + reason = "The documentation is only the transpiled symbols" +)] +pub enum UnaryOperator { + /// `NOT()` + Not, + /// ` IS NULL` + IsNull, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UnaryExpression { + pub op: UnaryOperator, + pub expr: Box, +} + +impl Transpile for UnaryExpression { + fn transpile(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self.op { + UnaryOperator::Not => { + if let Expression::Unary(Self { + op: UnaryOperator::IsNull, + expr, + }) = &*self.expr + { + expr.transpile(fmt)?; + fmt.write_str(" IS NOT NULL") + } else { + fmt.write_str("NOT(")?; + self.expr.transpile(fmt)?; + fmt.write_char(')') + } + } + UnaryOperator::IsNull => { + self.expr.transpile(fmt)?; + fmt.write_str(" IS NULL") + } + } + } +} diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs index f9b1bd4d48b..1ae7fa895d3 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/expression/where_clause.rs @@ -2,13 +2,11 @@ use core::fmt; use hash_graph_store::query::{NullOrdering, Ordering}; -use crate::store::postgres::query::{ - Condition, Expression, Transpile, expression::conditional::Transpiler, -}; +use crate::store::postgres::query::{Expression, Transpile, expression::conditional::Transpiler}; #[derive(Debug, Clone, Default, PartialEq)] pub struct WhereExpression { - pub conditions: Vec, + pub conditions: Vec, pub cursor: Vec<( Expression, Option, @@ -18,7 +16,7 @@ pub struct WhereExpression { } impl WhereExpression { - pub fn add_condition(&mut self, condition: Condition) { + pub fn add_condition(&mut self, condition: Expression) { self.conditions.push(condition); } diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs index a212ae82224..11ec1956bf8 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/mod.rs @@ -3,7 +3,6 @@ //! Postgres implementation to compile queries. mod compile; -mod condition; mod data_type; mod entity; mod entity_type; @@ -31,9 +30,9 @@ use type_system::knowledge::{Entity, PropertyValue}; pub use self::{ compile::{SelectCompiler, SelectCompilerError}, - condition::{Condition, EqualityOperator}, expression::{ - Constant, Expression, Function, SelectExpression, WhereExpression, WithExpression, + Constant, EqualityOperator, Expression, Function, SelectExpression, WhereExpression, + WithExpression, }, statement::{ Distinctness, InsertStatementBuilder, SelectStatement, Statement, WindowStatement, diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs index f6490aa6142..3f10b19f70c 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/statement/select.rs @@ -1379,8 +1379,8 @@ mod tests { &compiler, r#" SELECT ("entity_editions_0_1_0"."properties" - (CASE WHEN - (($1 = ANY("entity_is_of_type_ids_0_1_0"."base_urls")) - AND ("entity_temporal_metadata_0_0_0"."entity_uuid" != $2)) + ($1 = ANY("entity_is_of_type_ids_0_1_0"."base_urls")) + AND ("entity_temporal_metadata_0_0_0"."entity_uuid" != $2) THEN ARRAY[$3]::text[] ELSE ARRAY[]::text[] END)) FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" diff --git a/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs b/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs index f3cc2fa7845..8e3b708eb2b 100644 --- a/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs +++ b/libs/@local/graph/postgres-store/src/store/postgres/query/table.rs @@ -12,9 +12,7 @@ use hash_graph_temporal_versioning::TimeAxis; use postgres_types::ToSql; use super::expression::{ColumnName, ColumnReference, TableName, TableReference}; -use crate::store::postgres::query::{ - Condition, Constant, Expression, Transpile, expression::JoinType, -}; +use crate::store::postgres::query::{Constant, Expression, Transpile, expression::JoinType}; /// The name of a [`Table`] in the Postgres database. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -1927,13 +1925,13 @@ impl ForeignKeyReference { } } - pub fn conditions(self, on_alias: Alias, join_alias: Alias) -> Vec { + pub fn conditions(self, on_alias: Alias, join_alias: Alias) -> Vec { match self { Self::Single { join, on, join_type: _, - } => vec![Condition::Equal( + } => vec![Expression::equal( Expression::ColumnReference(join.aliased(join_alias)), Expression::ColumnReference(on.aliased(on_alias)), )], @@ -1942,11 +1940,11 @@ impl ForeignKeyReference { on: [on1, on2], join_type: _, } => vec![ - Condition::Equal( + Expression::equal( Expression::ColumnReference(join1.aliased(join_alias)), Expression::ColumnReference(on1.aliased(on_alias)), ), - Condition::Equal( + Expression::equal( Expression::ColumnReference(join2.aliased(join_alias)), Expression::ColumnReference(on2.aliased(on_alias)), ), @@ -2154,7 +2152,7 @@ impl Relation { } #[must_use] - pub fn additional_conditions(self, table: &TableReference<'_>) -> Vec { + pub fn additional_conditions(self, table: &TableReference<'_>) -> Vec { match self { Self::Reference { table: reference_table, @@ -2166,7 +2164,7 @@ impl Relation { column .inheritance_depth() .map_or_else(Vec::new, |inheritance_depth| { - vec![Condition::LessOrEqual( + vec![Expression::less_or_equal( Expression::ColumnReference( column.aliased(table.alias.unwrap_or_default()), ),