Skip to content

Commit a0f221c

Browse files
committed
feat(mssql): support native upsert via MERGE WITH (HOLDLOCK)
Enable NativeUpsert capability for MSSQL by translating INSERT ... ON CONFLICT UPDATE into MERGE ... WITH (HOLDLOCK) statements, providing atomic upsert semantics equivalent to PostgreSQL's and SQLite's ON CONFLICT DO UPDATE. - Add Merge::from_insert_with_update() to convert INSERT with OnConflict::Update into MERGE with WHEN MATCHED THEN UPDATE - Emit WITH (HOLDLOCK) on all MSSQL MERGE statements for serializable isolation under concurrent load - Place WITH (HOLDLOCK) before table alias per T-SQL syntax rules - Use alias in ON clause when merge target is aliased - Extract shared build_using_query() helper to eliminate duplication between DoNothing and Update paths - Make Visitor::compatibility_modifications return Result to propagate errors from Merge construction instead of panicking
1 parent 3c6e192 commit a0f221c

7 files changed

Lines changed: 477 additions & 70 deletions

File tree

psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
5555
SupportsTxIsolationSnapshot |
5656
SupportsFiltersOnRelationsWithoutJoins |
5757
SupportsDefaultInInsert |
58+
NativeUpsert |
5859
PartialIndex
5960
// InsertReturning | DeleteReturning - unimplemented.
6061
});

quaint/src/ast/insert.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub enum OnConflict<'a> {
6363
///
6464
/// let expected_sql = indoc!(
6565
/// "
66-
/// MERGE INTO [users]
66+
/// MERGE INTO [users] WITH (HOLDLOCK)
6767
/// USING (SELECT @P1 AS [id]) AS [dual] ([id])
6868
/// ON [dual].[id] = [users].[id]
6969
/// WHEN NOT MATCHED THEN
@@ -88,7 +88,7 @@ pub enum OnConflict<'a> {
8888
/// [`DefaultValue::Generated`]: enum.DefaultValue.html#variant.Generated
8989
/// [column has a default value]: struct.Column.html#method.default
9090
DoNothing,
91-
/// ON CONFLICT UPDATE is supported for Sqlite and Postgres
91+
/// ON CONFLICT UPDATE is supported for Sqlite, Postgres, and MSSQL (via MERGE)
9292
Update(Update<'a>, Vec<Column<'a>>),
9393
}
9494

quaint/src/ast/merge.rs

Lines changed: 149 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::convert::TryFrom;
1010
pub struct Merge<'a> {
1111
pub(crate) table: Table<'a>,
1212
pub(crate) using: Using<'a>,
13+
pub(crate) when_matched: Option<Update<'a>>,
1314
pub(crate) when_not_matched: Option<Query<'a>>,
1415
pub(crate) returning: Option<Vec<Column<'a>>>,
1516
}
@@ -23,11 +24,17 @@ impl<'a> Merge<'a> {
2324
Self {
2425
table: table.into(),
2526
using: using.into(),
27+
when_matched: None,
2628
when_not_matched: None,
2729
returning: None,
2830
}
2931
}
3032

33+
pub(crate) fn when_matched(mut self, update: Update<'a>) -> Self {
34+
self.when_matched = Some(update);
35+
self
36+
}
37+
3138
pub(crate) fn when_not_matched<Q>(mut self, query: Q) -> Self
3239
where
3340
Q: Into<Query<'a>>,
@@ -44,6 +51,146 @@ impl<'a> Merge<'a> {
4451
self.returning = Some(columns.into_iter().map(|k| k.into()).collect());
4552
self
4653
}
54+
55+
/// Build a MERGE from an INSERT with `OnConflict::Update`.
56+
///
57+
/// The ON condition is derived from the explicit constraint columns
58+
/// (not from `table.index_definitions`).
59+
pub(crate) fn from_insert_with_update(insert: Insert<'a>) -> crate::Result<Self> {
60+
let table = insert.table.ok_or_else(|| {
61+
let kind = ErrorKind::conversion("Insert needs to point to a table for conversion to Merge.");
62+
Error::builder(kind).build()
63+
})?;
64+
65+
let (update, constraints) = match insert.on_conflict {
66+
Some(OnConflict::Update(update, constraints)) => (update, constraints),
67+
_ => {
68+
let kind = ErrorKind::conversion("Insert must have OnConflict::Update for this conversion.");
69+
return Err(Error::builder(kind).build());
70+
}
71+
};
72+
73+
if constraints.is_empty() {
74+
let kind = ErrorKind::conversion("OnConflict::Update requires non-empty constraint columns.");
75+
return Err(Error::builder(kind).build());
76+
}
77+
78+
let columns = insert.columns;
79+
80+
for constraint in &constraints {
81+
if !columns.iter().any(|column| column.name == constraint.name) {
82+
let kind = ErrorKind::conversion(format!(
83+
"OnConflict::Update constraint column `{}` must be present in the insert columns.",
84+
constraint.name
85+
));
86+
87+
return Err(Error::builder(kind).build());
88+
}
89+
}
90+
91+
let query = build_using_query(&columns, insert.values)?;
92+
let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect();
93+
94+
// Build ON conditions from the explicit constraint columns.
95+
// If the table has an alias, ON conditions must reference the alias
96+
// (T-SQL requires using the alias once it is declared on the MERGE target).
97+
let table_ref = match &table.typ {
98+
TableType::Table(name) => {
99+
let effective_name = table.alias.clone().unwrap_or_else(|| name.clone());
100+
Table {
101+
typ: TableType::Table(effective_name),
102+
alias: None,
103+
database: if table.alias.is_some() { None } else { table.database.clone() },
104+
index_definitions: Vec::new(),
105+
}
106+
}
107+
_ => {
108+
let kind = ErrorKind::conversion("Merge target must be a simple table.");
109+
return Err(Error::builder(kind).build());
110+
}
111+
};
112+
let on_conditions = build_on_conditions_from_constraints(&constraints, &table_ref);
113+
114+
let using = query.into_using("dual", bare_columns.clone()).on(on_conditions);
115+
116+
let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect();
117+
let not_matched = Insert::multi(bare_columns).values(dual_columns);
118+
let mut merge = Merge::new(table, using)
119+
.when_matched(update)
120+
.when_not_matched(not_matched);
121+
122+
if let Some(columns) = insert.returning {
123+
merge = merge.returning(columns);
124+
}
125+
126+
Ok(merge)
127+
}
128+
}
129+
130+
/// Build ON conditions from explicit constraint columns (AND-joined).
131+
fn build_on_conditions_from_constraints<'a>(constraints: &[Column<'a>], table: &Table<'a>) -> ConditionTree<'a> {
132+
let mut conditions: Option<ConditionTree<'a>> = None;
133+
134+
for col in constraints {
135+
let bare_name = col.name.clone();
136+
let dual_col = Column::new(bare_name.clone()).table("dual");
137+
let table_col = Column::new(bare_name).table(table.clone());
138+
let cond = dual_col.equals(table_col);
139+
140+
conditions = Some(match conditions {
141+
None => cond.into(),
142+
Some(existing) => existing.and(cond),
143+
});
144+
}
145+
146+
conditions.unwrap_or(ConditionTree::NoCondition)
147+
}
148+
149+
/// Extract the USING query from insert values — shared between DoNothing and Update paths.
150+
fn build_using_query<'a>(columns: &[Column<'a>], values: Expression<'a>) -> crate::Result<Query<'a>> {
151+
match values.kind {
152+
ExpressionKind::Row(row) => {
153+
let cols_vals = columns.iter().zip(row.values);
154+
155+
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
156+
query.value(val.alias(col.name.clone()))
157+
});
158+
159+
Ok(Query::from(select))
160+
}
161+
ExpressionKind::Values(values) => {
162+
let mut rows = values.rows.into_iter();
163+
let first_row = rows.next().ok_or_else(|| {
164+
let kind = ErrorKind::conversion("Insert values cannot be empty.");
165+
Error::builder(kind).build()
166+
})?;
167+
let cols_vals = columns.iter().zip(first_row.values);
168+
169+
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
170+
query.value(val.alias(col.name.clone()))
171+
});
172+
173+
let union = rows.fold(Union::new(select), |union, row| {
174+
let cols_vals = columns.iter().zip(row.values);
175+
176+
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
177+
query.value(val.alias(col.name.clone()))
178+
});
179+
180+
union.all(select)
181+
});
182+
183+
Ok(Query::from(union))
184+
}
185+
ExpressionKind::Selection(selection) => Ok(Query::from(selection)),
186+
ExpressionKind::Parameterized(value) => {
187+
Ok(Select::default().value(ExpressionKind::ParameterizedRow(value)).into())
188+
}
189+
_ => {
190+
let kind = ErrorKind::conversion("Insert type not supported.");
191+
Err(Error::builder(kind).build())
192+
}
193+
}
47194
}
48195

49196
impl<'a> From<Merge<'a>> for Query<'a> {
@@ -103,53 +250,13 @@ impl<'a> TryFrom<Insert<'a>> for Merge<'a> {
103250
}
104251

105252
let columns = insert.columns;
106-
107-
let query = match insert.values.kind {
108-
ExpressionKind::Row(row) => {
109-
let cols_vals = columns.iter().zip(row.values);
110-
111-
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
112-
query.value(val.alias(col.name.clone()))
113-
});
114-
115-
Query::from(select)
116-
}
117-
ExpressionKind::Values(values) => {
118-
let mut rows = values.rows;
119-
let row = rows.pop().unwrap();
120-
let cols_vals = columns.iter().zip(row.values);
121-
122-
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
123-
query.value(val.alias(col.name.clone()))
124-
});
125-
126-
let union = rows.into_iter().fold(Union::new(select), |union, row| {
127-
let cols_vals = columns.iter().zip(row.values);
128-
129-
let select = cols_vals.fold(Select::default(), |query, (col, val)| {
130-
query.value(val.alias(col.name.clone()))
131-
});
132-
133-
union.all(select)
134-
});
135-
136-
Query::from(union)
137-
}
138-
ExpressionKind::Selection(selection) => Query::from(selection),
139-
ExpressionKind::Parameterized(value) => {
140-
Select::default().value(ExpressionKind::ParameterizedRow(value)).into()
141-
}
142-
_ => {
143-
let kind = ErrorKind::conversion("Insert type not supported.");
144-
return Err(Error::builder(kind).build());
145-
}
146-
};
253+
let query = build_using_query(&columns, insert.values)?;
147254

148255
let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect();
149256

150257
let using = query
151258
.into_using("dual", bare_columns.clone())
152-
.on(table.join_conditions(&columns).unwrap());
259+
.on(table.join_conditions(&columns)?);
153260

154261
let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect();
155262
let not_matched = Insert::multi(bare_columns).values(dual_columns);

quaint/src/tests/upsert.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::test_api::*;
22
use crate::{connector::Queryable, prelude::*};
33
use quaint_test_macros::test_each_connector;
44

5-
#[test_each_connector(tags("postgresql", "sqlite"))]
5+
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
66
async fn upsert_on_primary_key(api: &mut dyn TestApi) -> crate::Result<()> {
77
let table = api.create_temp_table("id int primary key, x int").await?;
88

@@ -39,7 +39,7 @@ fn upsert_on_primary_key_query(table: &str) -> Query<'_> {
3939
.into()
4040
}
4141

42-
#[test_each_connector(tags("postgresql", "sqlite"))]
42+
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
4343
async fn upsert_on_unique_field(api: &mut dyn TestApi) -> crate::Result<()> {
4444
let table = api.create_temp_table("id int primary key, x int UNIQUE, y int").await?;
4545

@@ -82,7 +82,7 @@ fn upsert_on_unique_field_query(table: &str) -> Query<'_> {
8282
.into()
8383
}
8484

85-
#[test_each_connector(tags("postgresql", "sqlite"))]
85+
#[test_each_connector(tags("postgresql", "sqlite", "mssql"))]
8686
async fn upsert_on_multiple_unique_fields(api: &mut dyn TestApi) -> crate::Result<()> {
8787
let table = api
8888
.create_temp_table("id int primary key, x int, y int, CONSTRAINT ux_x_y UNIQUE (x, y)")

quaint/src/visitor.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ pub trait Visitor<'a> {
8383

8484
/// A point to modify an incoming query to make it compatible with the
8585
/// underlying database.
86-
fn compatibility_modifications(&self, query: Query<'a>) -> Query<'a> {
87-
query
86+
fn compatibility_modifications(&self, query: Query<'a>) -> crate::Result<Query<'a>> {
87+
Ok(query)
8888
}
8989

9090
fn surround_with<F>(&mut self, begin: &str, end: &str, f: F) -> Result
@@ -514,7 +514,7 @@ pub trait Visitor<'a> {
514514

515515
/// A walk through a complete `Query` statement
516516
fn visit_query(&mut self, mut query: Query<'a>) -> Result {
517-
query = self.compatibility_modifications(query);
517+
query = self.compatibility_modifications(query)?;
518518

519519
match query {
520520
Query::Select(select) => self.visit_select(*select),

0 commit comments

Comments
 (0)