|
| 1 | +"""Snowflake harness for the engine-parity suite. |
| 2 | +
|
| 3 | +Owns the Snowpark session, the per-run `IDENTITIES_ENGINE_PARITY_<uuid>` |
| 4 | +TEMPORARY table (auto-drops at session close, no teardown), and the |
| 5 | +batched `INSERT ... SELECT UNION ALL` and `SELECT ... UNION ALL` shapes |
| 6 | +that round-trip through Snowflake in two queries per session. |
| 7 | +""" |
| 8 | + |
| 9 | +import os |
| 10 | +import uuid |
| 11 | +from collections.abc import Iterator |
| 12 | +from contextlib import contextmanager |
| 13 | + |
| 14 | +from snowflake.snowpark import Session |
| 15 | + |
| 16 | +from flagsmith_sql_flag_engine.dialect import Dialect |
| 17 | +from flagsmith_sql_flag_engine.dialects.snowflake import SnowflakeDialect |
| 18 | +from flagsmith_sql_flag_engine.utils import escape_string |
| 19 | +from tests.harnesses._base import EvaluationCase, IdentityRow |
| 20 | + |
| 21 | +# Cases the SQL translator can't match the engine on under Snowflake; |
| 22 | +# xfail keeps the divergence visible without masking a regression |
| 23 | +# elsewhere. Entries are file stems (matching `EngineTestCase.name`); |
| 24 | +# add the why inline. |
| 25 | +_XFAIL_CASE_NAMES: set[str] = { |
| 26 | + # Engine sorts semver prereleases (1.0.0-rc.2 < 1.0.0-rc.3); the SQL |
| 27 | + # semver-sort-key collapses to major.minor.patch only. |
| 28 | + "test_semver_greater_than_prerelease__should_match", |
| 29 | + "test_semver_less_than_prerelease__should_match", |
| 30 | + # Engine does trait-first dispatch: a row with a trait literally named |
| 31 | + # `$.identity` shadows the JSONPath lookup. Replicating per-row trait |
| 32 | + # fallback in SQL roughly doubles the cost of every wrapped JSONPath |
| 33 | + # condition (Snowflake evaluates both IFF arms), so we accept the |
| 34 | + # divergence on this niche shape (`$.`-prefixed trait names) and let |
| 35 | + # callers fall back to the engine. |
| 36 | + "test_jsonpath_like_trait__existing_jsonpath__should_match_trait", |
| 37 | +} |
| 38 | + |
| 39 | + |
| 40 | +def _q(s: str) -> str: |
| 41 | + """Quote a value for inclusion in a single-quoted Snowflake string |
| 42 | + literal. Snowflake string literals process `\\` as an escape, so JSON |
| 43 | + traits with `\\uXXXX` or `\\"` would lose their backslash before |
| 44 | + reaching PARSE_JSON; double the backslashes here. The single-quote |
| 45 | + doubling is the SQL-standard escape that `escape_string` already |
| 46 | + handles.""" |
| 47 | + return escape_string(s.replace("\\", "\\\\")) |
| 48 | + |
| 49 | + |
| 50 | +class SnowflakeHarness: |
| 51 | + name: str = "snowflake" |
| 52 | + dialect: Dialect = SnowflakeDialect() |
| 53 | + xfail_case_names: set[str] = _XFAIL_CASE_NAMES |
| 54 | + |
| 55 | + @contextmanager |
| 56 | + def session(self) -> Iterator[Session]: |
| 57 | + config: dict[str, str] = { |
| 58 | + "account": os.environ["SNOWFLAKE_ACCOUNT"], |
| 59 | + "user": os.environ["SNOWFLAKE_USER"], |
| 60 | + "role": os.environ.get("SNOWFLAKE_ROLE", "ACCOUNTADMIN"), |
| 61 | + "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH"), |
| 62 | + "database": os.environ.get("SNOWFLAKE_DATABASE", "FS_TEST"), |
| 63 | + "schema": os.environ.get("SNOWFLAKE_SCHEMA", "PUBLIC"), |
| 64 | + "private_key_file": os.environ["SNOWFLAKE_PRIVATE_KEY_PATH"], |
| 65 | + } |
| 66 | + sess = Session.builder.configs(config).create() |
| 67 | + try: |
| 68 | + yield sess |
| 69 | + finally: |
| 70 | + sess.close() |
| 71 | + |
| 72 | + def setup_identities(self, session: Session, rows: list[IdentityRow]) -> str: |
| 73 | + suffix = uuid.uuid4().hex[:8] |
| 74 | + db = os.environ.get("SNOWFLAKE_DATABASE", "FS_TEST") |
| 75 | + schema = os.environ.get("SNOWFLAKE_SCHEMA", "PUBLIC") |
| 76 | + table = f"{db}.{schema}.IDENTITIES_ENGINE_PARITY_{suffix}" |
| 77 | + # TEMPORARY so the table auto-drops at session close — no teardown. |
| 78 | + session.sql( |
| 79 | + f""" |
| 80 | + CREATE TEMPORARY TABLE {table} ( |
| 81 | + environment_id STRING NOT NULL, |
| 82 | + id NUMBER NOT NULL, |
| 83 | + identifier STRING NOT NULL, |
| 84 | + identity_key STRING NOT NULL, |
| 85 | + traits VARIANT |
| 86 | + ) |
| 87 | + """ |
| 88 | + ).collect() |
| 89 | + |
| 90 | + if not rows: |
| 91 | + return table |
| 92 | + |
| 93 | + selects = [ |
| 94 | + f"SELECT '{_q(r.environment_id)}', {r.id}, " |
| 95 | + f"'{_q(r.identifier)}', '{_q(r.identity_key)}', " |
| 96 | + + (f"PARSE_JSON('{_q(r.traits_json)}')" if r.traits_json else "NULL") |
| 97 | + for r in rows |
| 98 | + ] |
| 99 | + session.sql( |
| 100 | + f"INSERT INTO {table} " |
| 101 | + "(environment_id, id, identifier, identity_key, traits) " |
| 102 | + + "\nUNION ALL\n".join(selects) |
| 103 | + ).collect() |
| 104 | + return table |
| 105 | + |
| 106 | + def evaluate( |
| 107 | + self, |
| 108 | + session: Session, |
| 109 | + identity_table: str, |
| 110 | + cases: list[EvaluationCase], |
| 111 | + ) -> dict[str, bool]: |
| 112 | + select_clauses = [ |
| 113 | + f"SELECT '{_q(c.pair_id)}' AS pair_id, " |
| 114 | + f"EXISTS (SELECT 1 FROM {identity_table} i " |
| 115 | + f"WHERE i.environment_id = '{_q(c.environment_key)}' " |
| 116 | + f"AND ({c.predicate_sql})) AS m" |
| 117 | + for c in cases |
| 118 | + ] |
| 119 | + rows = session.sql("\nUNION ALL\n".join(select_clauses)).collect() |
| 120 | + return {row["PAIR_ID"]: bool(row["M"]) for row in rows} |
0 commit comments