Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
SupportsTxIsolationSnapshot |
SupportsFiltersOnRelationsWithoutJoins |
SupportsDefaultInInsert |
NativeUpsert |
PartialIndex
// InsertReturning | DeleteReturning - unimplemented.
});
Expand Down
4 changes: 2 additions & 2 deletions quaint/src/ast/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub enum OnConflict<'a> {
///
/// let expected_sql = indoc!(
/// "
/// MERGE INTO [users]
/// MERGE INTO [users] WITH (HOLDLOCK)
/// USING (SELECT @P1 AS [id]) AS [dual] ([id])
/// ON [dual].[id] = [users].[id]
/// WHEN NOT MATCHED THEN
Expand All @@ -88,7 +88,7 @@ pub enum OnConflict<'a> {
/// [`DefaultValue::Generated`]: enum.DefaultValue.html#variant.Generated
/// [column has a default value]: struct.Column.html#method.default
DoNothing,
/// ON CONFLICT UPDATE is supported for Sqlite and Postgres
/// ON CONFLICT UPDATE is supported for Sqlite, Postgres, and MSSQL (via MERGE)
Update(Update<'a>, Vec<Column<'a>>),
}

Expand Down
191 changes: 149 additions & 42 deletions quaint/src/ast/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::convert::TryFrom;
pub struct Merge<'a> {
pub(crate) table: Table<'a>,
pub(crate) using: Using<'a>,
pub(crate) when_matched: Option<Update<'a>>,
pub(crate) when_not_matched: Option<Query<'a>>,
pub(crate) returning: Option<Vec<Column<'a>>>,
}
Expand All @@ -23,11 +24,17 @@ impl<'a> Merge<'a> {
Self {
table: table.into(),
using: using.into(),
when_matched: None,
when_not_matched: None,
returning: None,
}
}

pub(crate) fn when_matched(mut self, update: Update<'a>) -> Self {
self.when_matched = Some(update);
self
}

pub(crate) fn when_not_matched<Q>(mut self, query: Q) -> Self
where
Q: Into<Query<'a>>,
Expand All @@ -44,6 +51,146 @@ impl<'a> Merge<'a> {
self.returning = Some(columns.into_iter().map(|k| k.into()).collect());
self
}

/// Build a MERGE from an INSERT with `OnConflict::Update`.
///
/// The ON condition is derived from the explicit constraint columns
/// (not from `table.index_definitions`).
pub(crate) fn from_insert_with_update(insert: Insert<'a>) -> crate::Result<Self> {
let table = insert.table.ok_or_else(|| {
let kind = ErrorKind::conversion("Insert needs to point to a table for conversion to Merge.");
Error::builder(kind).build()
})?;

let (update, constraints) = match insert.on_conflict {
Some(OnConflict::Update(update, constraints)) => (update, constraints),
_ => {
let kind = ErrorKind::conversion("Insert must have OnConflict::Update for this conversion.");
return Err(Error::builder(kind).build());
}
};

if constraints.is_empty() {
let kind = ErrorKind::conversion("OnConflict::Update requires non-empty constraint columns.");
return Err(Error::builder(kind).build());
}

let columns = insert.columns;

for constraint in &constraints {
if !columns.iter().any(|column| column.name == constraint.name) {
let kind = ErrorKind::conversion(format!(
"OnConflict::Update constraint column `{}` must be present in the insert columns.",
constraint.name
));

return Err(Error::builder(kind).build());
}
}

let query = build_using_query(&columns, insert.values)?;
let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect();

// Build ON conditions from the explicit constraint columns.
// If the table has an alias, ON conditions must reference the alias
// (T-SQL requires using the alias once it is declared on the MERGE target).
let table_ref = match &table.typ {
TableType::Table(name) => {
let effective_name = table.alias.clone().unwrap_or_else(|| name.clone());
Table {
typ: TableType::Table(effective_name),
alias: None,
database: if table.alias.is_some() { None } else { table.database.clone() },
index_definitions: Vec::new(),
}
}
_ => {
let kind = ErrorKind::conversion("Merge target must be a simple table.");
return Err(Error::builder(kind).build());
}
};
let on_conditions = build_on_conditions_from_constraints(&constraints, &table_ref);

let using = query.into_using("dual", bare_columns.clone()).on(on_conditions);

let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect();
let not_matched = Insert::multi(bare_columns).values(dual_columns);
let mut merge = Merge::new(table, using)
.when_matched(update)
.when_not_matched(not_matched);

if let Some(columns) = insert.returning {
merge = merge.returning(columns);
}

Ok(merge)
}
}

/// Build ON conditions from explicit constraint columns (AND-joined).
fn build_on_conditions_from_constraints<'a>(constraints: &[Column<'a>], table: &Table<'a>) -> ConditionTree<'a> {
let mut conditions: Option<ConditionTree<'a>> = None;

for col in constraints {
let bare_name = col.name.clone();
let dual_col = Column::new(bare_name.clone()).table("dual");
let table_col = Column::new(bare_name).table(table.clone());
let cond = dual_col.equals(table_col);

conditions = Some(match conditions {
None => cond.into(),
Some(existing) => existing.and(cond),
});
}

conditions.unwrap_or(ConditionTree::NoCondition)
}

/// Extract the USING query from insert values — shared between DoNothing and Update paths.
fn build_using_query<'a>(columns: &[Column<'a>], values: Expression<'a>) -> crate::Result<Query<'a>> {
match values.kind {
ExpressionKind::Row(row) => {
let cols_vals = columns.iter().zip(row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

Ok(Query::from(select))
}
ExpressionKind::Values(values) => {
let mut rows = values.rows.into_iter();
let first_row = rows.next().ok_or_else(|| {
let kind = ErrorKind::conversion("Insert values cannot be empty.");
Error::builder(kind).build()
})?;
let cols_vals = columns.iter().zip(first_row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

let union = rows.fold(Union::new(select), |union, row| {
let cols_vals = columns.iter().zip(row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

union.all(select)
});
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Ok(Query::from(union))
}
ExpressionKind::Selection(selection) => Ok(Query::from(selection)),
ExpressionKind::Parameterized(value) => {
Ok(Select::default().value(ExpressionKind::ParameterizedRow(value)).into())
}
_ => {
let kind = ErrorKind::conversion("Insert type not supported.");
Err(Error::builder(kind).build())
}
}
}

impl<'a> From<Merge<'a>> for Query<'a> {
Expand Down Expand Up @@ -103,53 +250,13 @@ impl<'a> TryFrom<Insert<'a>> for Merge<'a> {
}

let columns = insert.columns;

let query = match insert.values.kind {
ExpressionKind::Row(row) => {
let cols_vals = columns.iter().zip(row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

Query::from(select)
}
ExpressionKind::Values(values) => {
let mut rows = values.rows;
let row = rows.pop().unwrap();
let cols_vals = columns.iter().zip(row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

let union = rows.into_iter().fold(Union::new(select), |union, row| {
let cols_vals = columns.iter().zip(row.values);

let select = cols_vals.fold(Select::default(), |query, (col, val)| {
query.value(val.alias(col.name.clone()))
});

union.all(select)
});

Query::from(union)
}
ExpressionKind::Selection(selection) => Query::from(selection),
ExpressionKind::Parameterized(value) => {
Select::default().value(ExpressionKind::ParameterizedRow(value)).into()
}
_ => {
let kind = ErrorKind::conversion("Insert type not supported.");
return Err(Error::builder(kind).build());
}
};
let query = build_using_query(&columns, insert.values)?;

let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect();

let using = query
.into_using("dual", bare_columns.clone())
.on(table.join_conditions(&columns).unwrap());
.on(table.join_conditions(&columns)?);

let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect();
let not_matched = Insert::multi(bare_columns).values(dual_columns);
Expand Down
6 changes: 3 additions & 3 deletions quaint/src/tests/upsert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::test_api::*;
use crate::{connector::Queryable, prelude::*};
use quaint_test_macros::test_each_connector;

#[test_each_connector(tags("postgresql", "sqlite"))]
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
async fn upsert_on_primary_key(api: &mut dyn TestApi) -> crate::Result<()> {
let table = api.create_temp_table("id int primary key, x int").await?;

Expand Down Expand Up @@ -39,7 +39,7 @@ fn upsert_on_primary_key_query(table: &str) -> Query<'_> {
.into()
}

#[test_each_connector(tags("postgresql", "sqlite"))]
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
async fn upsert_on_unique_field(api: &mut dyn TestApi) -> crate::Result<()> {
let table = api.create_temp_table("id int primary key, x int UNIQUE, y int").await?;

Expand Down Expand Up @@ -82,7 +82,7 @@ fn upsert_on_unique_field_query(table: &str) -> Query<'_> {
.into()
}

#[test_each_connector(tags("postgresql", "sqlite"))]
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
async fn upsert_on_multiple_unique_fields(api: &mut dyn TestApi) -> crate::Result<()> {
let table = api
.create_temp_table("id int primary key, x int, y int, CONSTRAINT ux_x_y UNIQUE (x, y)")
Expand Down
6 changes: 3 additions & 3 deletions quaint/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ pub trait Visitor<'a> {

/// A point to modify an incoming query to make it compatible with the
/// underlying database.
fn compatibility_modifications(&self, query: Query<'a>) -> Query<'a> {
query
fn compatibility_modifications(&self, query: Query<'a>) -> crate::Result<Query<'a>> {
Ok(query)
}

fn surround_with<F>(&mut self, begin: &str, end: &str, f: F) -> Result
Expand Down Expand Up @@ -512,7 +512,7 @@ pub trait Visitor<'a> {

/// A walk through a complete `Query` statement
fn visit_query(&mut self, mut query: Query<'a>) -> Result {
query = self.compatibility_modifications(query);
query = self.compatibility_modifications(query)?;

match query {
Query::Select(select) => self.visit_select(*select),
Expand Down
Loading