Skip to content

Commit 68f7a36

Browse files
authored
Merge pull request dbt-msft#710 from dbt-msft/fix/708-dbt-transactions
Implements `dbt_sqlserver_use_dbt_transactions` behavior flag
2 parents 9f86896 + 2e1c605 commit 68f7a36

7 files changed

Lines changed: 441 additions & 6 deletions

File tree

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,24 @@ The same setting is also honoured via `vars:` for backwards compatibility; the b
129129

130130
*(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed.
131131

132+
### `dbt_sqlserver_use_dbt_transactions`
133+
134+
_(default: `false`)_ When enabled, makes dbt's transaction hooks real at the SQL Server level by emitting `BEGIN TRANSACTION` / `COMMIT TRANSACTION` through the adapter's `add_begin_query` and `add_commit_query` methods.
135+
136+
The default is `false`, preserving existing behavior where `begin`/`commit` hooks are logical no-ops and the ODBC driver auto-commits each statement. When `dbt_sqlserver_use_dbt_transactions: true`, the adapter emits real T-SQL transaction statements, and rollback uses `IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION`.
137+
138+
The driver connection remains in autocommit mode (`autocommit=true`) in both modes.
139+
140+
This mode is opt-in and should be tested carefully with project-specific materializations and hooks.
141+
142+
```yaml
143+
# dbt_project.yml
144+
flags:
145+
dbt_sqlserver_use_dbt_transactions: true # <-- opt-in; default is false
146+
```
147+
148+
**Compatibility notes:** Enabling `dbt_sqlserver_use_dbt_transactions: true` may expose transaction-state assumptions hidden by autocommit-only mode. Explicit transaction macros may interact with dbt-managed transactions, and cleanup after failed DDL/DML may differ. Review pre/post hooks for in-transaction vs out-of-transaction semantics.
149+
132150
## Contributing
133151

134152
[![Unit tests](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml)

dbt/adapters/sqlserver/sqlserver_adapter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def __init__(self, config, mp_context=None):
6262
)
6363
if self.behavior.dbt_sqlserver_use_native_string_types:
6464
self.Column = SQLServerColumnNative
65+
# add_begin_query/add_commit_query read the instance flag, while dbt-core
66+
# rollback handling is classmethod-based and reads the class flag.
67+
use_dbt_transactions = bool(self.behavior.dbt_sqlserver_use_dbt_transactions)
68+
SQLServerConnectionManager._dbt_sqlserver_use_dbt_transactions = use_dbt_transactions
69+
self.connections._dbt_sqlserver_use_dbt_transactions = use_dbt_transactions
6570

6671
@property
6772
def _behavior_flags(self) -> List[BehaviorFlag]:
@@ -106,6 +111,18 @@ def _behavior_flags(self) -> List[BehaviorFlag]:
106111
"The new behaviour is intended to become the default in a future release."
107112
),
108113
},
114+
{
115+
"name": "dbt_sqlserver_use_dbt_transactions",
116+
"default": False,
117+
"description": (
118+
"When True, dbt transaction hooks (begin/commit) emit real T-SQL "
119+
"BEGIN TRANSACTION / COMMIT TRANSACTION statements. "
120+
"When False (default and legacy), begin/commit are no-ops and each statement "
121+
"is auto-committed by the driver. This means earlier successful statements "
122+
"are not rolled back if a later statement fails. "
123+
"This behavior is intended to become the default in a future release."
124+
),
125+
},
109126
]
110127

111128
@available.parse(lambda *a, **k: [])

dbt/adapters/sqlserver/sqlserver_connections.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime as dt
22
import time
3+
import traceback
34
from contextlib import contextmanager
45
from typing import (
56
Any,
@@ -25,6 +26,7 @@
2526
from dbt.adapters.events.types import (
2627
AdapterEventDebug,
2728
ConnectionUsed,
29+
RollbackFailed,
2830
SQLQuery,
2931
SQLQueryStatus,
3032
)
@@ -63,6 +65,8 @@
6365
class SQLServerConnectionManager(SQLConnectionManager):
6466
TYPE = "sqlserver"
6567

68+
_dbt_sqlserver_use_dbt_transactions: bool = False
69+
6670
@contextmanager
6771
def exception_handler(self, sql):
6872
"""Translate backend database errors and re-raise everything else.
@@ -142,10 +146,42 @@ def cancel(self, connection: Connection):
142146
logger.debug("Cancel query")
143147

144148
def add_begin_query(self):
145-
pass
149+
if self._dbt_sqlserver_use_dbt_transactions:
150+
return self.add_query("BEGIN TRANSACTION", auto_begin=False)
146151

147152
def add_commit_query(self):
148-
pass
153+
if self._dbt_sqlserver_use_dbt_transactions:
154+
return self.add_query("IF @@TRANCOUNT > 0 COMMIT TRANSACTION", auto_begin=False)
155+
156+
@classmethod
157+
def _rollback_handle(cls, connection: Connection) -> None:
158+
if cls._dbt_sqlserver_use_dbt_transactions:
159+
cursor = None
160+
try:
161+
cursor = connection.handle.cursor()
162+
cursor.execute("IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION")
163+
except Exception:
164+
fire_event(
165+
RollbackFailed(
166+
conn_name=cast_to_str(connection.name),
167+
exc_info=traceback.format_exc(),
168+
node_info=get_node_info(),
169+
)
170+
)
171+
finally:
172+
if cursor is not None:
173+
cursor.close()
174+
else:
175+
try:
176+
connection.handle.rollback()
177+
except Exception:
178+
fire_event(
179+
RollbackFailed(
180+
conn_name=cast_to_str(connection.name),
181+
exc_info=traceback.format_exc(),
182+
node_info=get_node_info(),
183+
)
184+
)
149185

150186
def add_query(
151187
self,

dbt/include/sqlserver/macros/materializations/hooks.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
{% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %}
33
{% if not inside_transaction and loop.first %}
44
{% call statement(auto_begin=inside_transaction) %}
5-
if @@trancount > 0 commit;
5+
{% if not adapter.behavior.dbt_sqlserver_use_dbt_transactions %}
6+
if @@trancount > 0 commit; -- post hooks after fictitious transaction work as expected
7+
{% else %}
8+
commit; -- align transaction=False hook behavior with dbt-core transaction semantics.
9+
{% endif %}
610
{% endcall %}
711
{% endif %}
812
{% set rendered = render(hook.get('sql')) | trim %}

dbt/include/sqlserver/macros/materializations/models/table/table_dml_refresh.sql

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,19 @@
5454
{%- set column_list = target_columns | map(attribute='quoted') | join(', ') -%}
5555

5656
{# Atomic DML swap — RCSI protects concurrent readers #}
57-
{# dbt-sqlserver uses autocommit=True and add_begin_query/add_commit_query #}
58-
{# are no-ops, so this creates a simple (non-nested) transaction. #}
57+
{# When dbt_sqlserver_use_dbt_transactions is off (default), autocommit #}
58+
{# means we need the explicit BEGIN/COMMIT. When the flag is on, dbt #}
59+
{# already wraps the statement call in a transaction, so skip it. #}
5960
{% call statement('dml_refresh_swap') -%}
61+
{% if not adapter.behavior.dbt_sqlserver_use_dbt_transactions %}
6062
BEGIN TRANSACTION;
63+
{% endif %}
6164
DELETE FROM {{ target_relation }};
6265
INSERT INTO {{ target_relation }} ({{ column_list }})
6366
SELECT {{ column_list }} FROM {{ refresh_relation }};
67+
{% if not adapter.behavior.dbt_sqlserver_use_dbt_transactions %}
6468
COMMIT TRANSACTION;
69+
{% endif %}
6570
{%- endcall %}
6671

6772
{# Cleanup scratch table #}
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import pytest
2+
3+
from dbt.tests.util import run_dbt
4+
5+
6+
class BaseTransactionsEnabled:
7+
@pytest.fixture(scope="class")
8+
def project_config_update(self):
9+
return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}}
10+
11+
12+
class TestTableMaterializationTransactionsOn(BaseTransactionsEnabled):
13+
@pytest.fixture(scope="class")
14+
def models(self):
15+
return {
16+
"table_model.sql": """
17+
{{ config(materialized='table') }}
18+
select 1 as id, 'hello' as name
19+
""",
20+
}
21+
22+
def test_table_materialization(self, project):
23+
results = run_dbt(["run"])
24+
assert len(results) == 1
25+
26+
rows = project.run_sql("select id, name from {schema}.table_model", fetch="all")
27+
assert len(rows) == 1
28+
assert rows[0][0] == 1
29+
assert rows[0][1] == "hello"
30+
31+
32+
class TestViewMaterializationTransactionsOn(BaseTransactionsEnabled):
33+
@pytest.fixture(scope="class")
34+
def models(self):
35+
return {
36+
"view_model.sql": """
37+
{{ config(materialized='view') }}
38+
select 42 as answer
39+
""",
40+
}
41+
42+
def test_view_materialization(self, project):
43+
results = run_dbt(["run"])
44+
assert len(results) == 1
45+
46+
rows = project.run_sql("select answer from {schema}.view_model", fetch="all")
47+
assert len(rows) == 1
48+
assert rows[0][0] == 42
49+
50+
51+
class TestIncrementalMaterializationTransactionsOn(BaseTransactionsEnabled):
52+
@pytest.fixture(scope="class")
53+
def models(self):
54+
return {
55+
"incremental_model.sql": """
56+
{{ config(materialized='incremental', unique_key='id') }}
57+
select 1 as id, 'first' as value
58+
{% if is_incremental() %}
59+
union all
60+
select 2 as id, 'second' as value
61+
{% endif %}
62+
""",
63+
}
64+
65+
def test_incremental_materialization(self, project):
66+
results = run_dbt(["run"])
67+
assert len(results) == 1
68+
69+
rows = project.run_sql(
70+
"select count(*) as cnt from {schema}.incremental_model", fetch="one"
71+
)
72+
assert rows[0] == 1
73+
74+
results = run_dbt(["run"])
75+
assert len(results) == 1
76+
77+
rows = project.run_sql(
78+
"select count(*) as cnt from {schema}.incremental_model", fetch="one"
79+
)
80+
assert rows[0] == 2
81+
82+
83+
class BaseFailingModelWithSideEffect:
84+
@pytest.fixture(scope="class")
85+
def models(self):
86+
return {
87+
"failing_model.sql": """
88+
{{ config(
89+
materialized='table',
90+
pre_hook=[
91+
"INSERT INTO {{ this.schema }}.audit_log "
92+
"(msg, created_at) VALUES ('from_model', getdate())"
93+
]
94+
) }}
95+
select 1/0 as boom
96+
""",
97+
}
98+
99+
100+
class TestRollbackWithoutFlag(BaseFailingModelWithSideEffect):
101+
@pytest.fixture(scope="class")
102+
def project_config_update(self):
103+
return {"flags": {"dbt_sqlserver_use_dbt_transactions": False}}
104+
105+
@pytest.mark.xfail(
106+
strict=True,
107+
reason="Without transactions flag, DML in pre-hooks is auto-committed and not rolled back,"
108+
" remove after migration to always use transactions.",
109+
)
110+
def test_side_effect_rolled_back(self, project):
111+
project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)")
112+
run_dbt(["run", "-m", "failing_model"], expect_pass=False)
113+
rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one")
114+
assert rows[0] == 0
115+
116+
117+
class TestRollbackWithFlag(BaseFailingModelWithSideEffect):
118+
@pytest.fixture(scope="class")
119+
def project_config_update(self):
120+
return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}}
121+
122+
def test_side_effect_rolled_back(self, project):
123+
project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)")
124+
run_dbt(["run", "-m", "failing_model"], expect_pass=False)
125+
rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one")
126+
assert rows[0] == 0
127+
128+
129+
class TestAfterCommitModelHookTransactionsOn(BaseTransactionsEnabled):
130+
@pytest.fixture(scope="class")
131+
def project_config_update(self):
132+
return {
133+
"flags": {"dbt_sqlserver_use_dbt_transactions": True},
134+
"models": {
135+
"test": {
136+
"post-hook": [
137+
{"sql": "select 1", "transaction": False},
138+
],
139+
}
140+
},
141+
}
142+
143+
@pytest.fixture(scope="class")
144+
def models(self):
145+
return {"after_commit_hook_model.sql": "select 1 as id"}
146+
147+
def test_after_commit_post_hook_does_not_double_commit(self, project):
148+
run_dbt()
149+
150+
151+
class TestFailedModelThenSuccessTransactionsOn(BaseTransactionsEnabled):
152+
@pytest.fixture(scope="class")
153+
def models(self):
154+
return {
155+
"good_model.sql": """
156+
{{ config(materialized='table') }}
157+
select 1 as id
158+
""",
159+
"bad_model.sql": """
160+
{{ config(materialized='table') }}
161+
select 1/0 as boom
162+
""",
163+
}
164+
165+
def test_failed_then_successful_run(self, project):
166+
results = run_dbt(["run", "-m", "bad_model"], expect_pass=False)
167+
assert len(results) == 1
168+
assert results[0].status == "error"
169+
170+
results = run_dbt(["run", "-m", "good_model"])
171+
assert len(results) == 1
172+
assert results[0].status == "success"
173+
174+
rows = project.run_sql("select id from {schema}.good_model", fetch="all")
175+
assert len(rows) == 1
176+
assert rows[0][0] == 1
177+
178+
179+
_snapshot_seed_csv = """id,name,updated_at
180+
1,alice,2024-01-01 00:00:00
181+
2,bob,2024-01-01 00:00:00
182+
"""
183+
184+
_snapshot_sql = """
185+
{% snapshot snap %}
186+
{{ config(
187+
target_schema=schema,
188+
unique_key='id',
189+
strategy='timestamp',
190+
updated_at='updated_at',
191+
) }}
192+
select * from {{ ref('snap_seed') }}
193+
{% endsnapshot %}
194+
"""
195+
196+
197+
class TestSnapshotTransactionsOn(BaseTransactionsEnabled):
198+
@pytest.fixture(scope="class")
199+
def seeds(self):
200+
return {"snap_seed.csv": _snapshot_seed_csv}
201+
202+
@pytest.fixture(scope="class")
203+
def snapshots(self):
204+
return {"snap.sql": _snapshot_sql}
205+
206+
@pytest.fixture(scope="class")
207+
def models(self):
208+
return {}
209+
210+
def test_snapshot_create_and_merge(self, project):
211+
run_dbt(["seed"])
212+
results = run_dbt(["snapshot"])
213+
assert len(results) == 1
214+
assert results[0].status == "success"
215+
216+
rows = project.run_sql("select count(*) from {schema}.snap", fetch="one")
217+
assert rows[0] == 2
218+
219+
results = run_dbt(["snapshot"])
220+
assert len(results) == 1
221+
assert results[0].status == "success"

0 commit comments

Comments
 (0)