Skip to content

Commit 171b99f

Browse files
Merge pull request #25 from patterninc/fix/snowflake-warehouse-set
Fix/snowflake warehouse set
2 parents 84ad6cf + 96ce66d commit 171b99f

4 files changed

Lines changed: 15 additions & 24 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ds-platform-utils"
3-
version = "0.4.1"
3+
version = "0.4.2"
44
description = "Utility library for Pattern Data Science."
55
readme = "README.md"
66
authors = [

src/ds_platform_utils/metaflow/snowflake_connection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,17 @@ def _create_snowflake_connection(
8080
conn: SnowflakeConnection = Snowflake(
8181
integration=SNOWFLAKE_INTEGRATION,
8282
client_session_keep_alive=True,
83-
warehouse=warehouse,
8483
timezone="UTC" if use_utc else None,
8584
session_parameters={"QUERY_TAG": query_tag},
8685
).cn # type: ignore[attr-defined]
8786

87+
# Doing this in the connection parameters result in silently failing to set the warehouse,
88+
# so we have to execute a raw query to set it.
89+
try:
90+
conn.execute_string("USE WAREHOUSE {}".format(warehouse))
91+
except Exception as e:
92+
raise RuntimeError(f"Failed to set Snowflake warehouse to {warehouse}: {e}") from e
93+
8894
return conn
8995

9096

tests/unit_tests/snowflake/test__execute_sql.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,21 @@
11
"""Functional test for _execute_sql."""
22

33
from typing import Generator
4-
from unittest.mock import MagicMock
54

65
import pytest
76
from snowflake.connector import SnowflakeConnection
87

98
from ds_platform_utils._snowflake.run_query import _execute_sql
10-
from ds_platform_utils.metaflow.snowflake_connection import _create_snowflake_connection
9+
from ds_platform_utils.metaflow.snowflake_connection import get_snowflake_connection
1110

1211

1312
@pytest.fixture(scope="module")
14-
def patched_current() -> Generator[MagicMock, None, None]:
15-
"""Patch Metaflow `current` object for modules used in this test file."""
16-
mock_current = MagicMock("metaflow.current")
17-
mock_current.tags = ["ds.domain:testing", "ds.project:unit-tests"]
18-
mock_current.flow_name = "DummyFlow"
19-
mock_current.project_name = "dummy-project"
20-
mock_current.step_name = "dummy-step"
21-
mock_current.run_id = "123"
22-
mock_current.username = "tester"
23-
mock_current.is_production = False
24-
mock_current.namespace = "user:tester"
25-
mock_current.is_running_flow = True
26-
mock_current.card = []
27-
yield mock_current
28-
29-
30-
@pytest.fixture(scope="module")
31-
def snowflake_conn(patched_current) -> Generator[SnowflakeConnection, None, None]:
13+
def snowflake_conn() -> Generator[SnowflakeConnection, None, None]:
3214
"""Get a Snowflake connection for testing."""
33-
yield _create_snowflake_connection(warehouse=None, use_utc=True)
15+
from metaflow import current
16+
17+
current.is_production = False # Ensure we're in non-prod for testing
18+
yield get_snowflake_connection(warehouse=None, use_utc=True)
3419

3520

3621
def test_execute_sql_empty_string(snowflake_conn):

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)