Skip to content

Commit a50645c

Browse files
fix: address CodeRabbit review round 3 - cleanup and hardening
- test-warehouse.yml: replace for-loop with case statement for Docker adapter check - dateadd.sql: use bare TIMESTAMPADD keywords instead of SQL_TSI_* constants, add case-insensitive datepart matching - spark.py: harden connection cleanup with None-init + conditional close, escape single quotes in container_path - dremio.py: switch from PyYAML to ruamel.yaml for project consistency, log non-file parsing failures, make seeding idempotent with DROP TABLE before CREATE TABLE Co-Authored-By: Itamar Hartstein <haritamar@gmail.com>
1 parent a5dcef9 commit a50645c

4 files changed

Lines changed: 39 additions & 26 deletions

File tree

.github/workflows/test-warehouse.yml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,10 @@ jobs:
186186
# Docker-based adapters use ephemeral containers, so a fixed schema
187187
# name is safe (the concurrency group prevents parallel collisions).
188188
# This enables caching the seeded database state between runs.
189-
DOCKER_ADAPTERS="postgres clickhouse trino dremio duckdb spark"
190189
IS_DOCKER=false
191-
for adapter in $DOCKER_ADAPTERS; do
192-
if [ "$adapter" = "${{ inputs.warehouse-type }}" ]; then
193-
IS_DOCKER=true
194-
break
195-
fi
196-
done
190+
case "${{ inputs.warehouse-type }}" in
191+
postgres|clickhouse|trino|dremio|duckdb|spark) IS_DOCKER=true ;;
192+
esac
197193
198194
if [ "$IS_DOCKER" = "true" ]; then
199195
SCHEMA_NAME="elementary_tests"

elementary/monitor/dbt_project/macros/overrides/dateadd.sql

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,28 @@
1010
#}
1111
1212
{% macro dremio__dateadd(datepart, interval, from_date_or_timestamp) %}
13+
{% set datepart = datepart | lower %}
1314
{% set interval = interval | string %}
1415
{# dbt-dremio's original macro wraps the result in a scalar subquery
1516
("select TIMESTAMPADD(...) order by 1"), so when we receive the
1617
interval from upstream it may carry a trailing "order by 1". #}
1718
{% set interval = interval.replace('order by 1', '') %}
1819
{% if datepart == 'year' %}
19-
TIMESTAMPADD(SQL_TSI_YEAR, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
20+
TIMESTAMPADD(YEAR, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
2021
{% elif datepart == 'quarter' %}
21-
TIMESTAMPADD(SQL_TSI_QUARTER, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
22+
TIMESTAMPADD(QUARTER, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
2223
{% elif datepart == 'month' %}
23-
TIMESTAMPADD(SQL_TSI_MONTH, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
24+
TIMESTAMPADD(MONTH, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
2425
{% elif datepart == 'week' %}
25-
TIMESTAMPADD(SQL_TSI_WEEK, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
26+
TIMESTAMPADD(WEEK, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
2627
{% elif datepart == 'hour' %}
27-
TIMESTAMPADD(SQL_TSI_HOUR, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
28+
TIMESTAMPADD(HOUR, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
2829
{% elif datepart == 'minute' %}
29-
TIMESTAMPADD(SQL_TSI_MINUTE, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
30+
TIMESTAMPADD(MINUTE, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
3031
{% elif datepart == 'second' %}
31-
TIMESTAMPADD(SQL_TSI_SECOND, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
32+
TIMESTAMPADD(SECOND, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
3233
{% elif datepart == 'day' %}
33-
TIMESTAMPADD(SQL_TSI_DAY, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
34+
TIMESTAMPADD(DAY, CAST({{interval}} as int), CAST({{from_date_or_timestamp}} as TIMESTAMP))
3435
{% else %}
3536
{{ exceptions.raise_compiler_error("dremio__dateadd: unrecognized datepart '" ~ datepart ~ "'. Supported: year, quarter, month, week, day, hour, minute, second.") }}
3637
{% endif %}

tests/e2e_dbt_project/external_seeders/dremio.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import re
77
import time
88

9-
import yaml
9+
from ruamel.yaml import YAML
10+
1011
from external_seeders.base import ExternalSeeder
1112

1213

@@ -21,8 +22,9 @@ def _docker_defaults() -> dict[str, str]:
2122
# --- docker-compose.yml: MinIO credentials ---
2223
compose_path = os.path.join(project_dir, "docker-compose.yml")
2324
try:
25+
_yaml = YAML()
2426
with open(compose_path) as fh:
25-
cfg = yaml.safe_load(fh)
27+
cfg = _yaml.load(fh)
2628
services = cfg.get("services", {})
2729
for item in services.get("dremio-minio", {}).get("environment", []):
2830
if isinstance(item, str) and "=" in item:
@@ -35,8 +37,10 @@ def _docker_defaults() -> dict[str, str]:
3537
defaults["MINIO_ACCESS_KEY"] = v
3638
elif k == "MINIO_ROOT_PASSWORD":
3739
defaults["MINIO_SECRET_KEY"] = v
38-
except Exception:
40+
except FileNotFoundError:
3941
pass
42+
except Exception as e:
43+
print(f" Warning: failed parsing docker defaults from {compose_path}: {e}")
4044

4145
# --- dremio-setup.sh: Dremio login credentials ---
4246
# Extract default values from bash variable assignments like:
@@ -51,8 +55,10 @@ def _docker_defaults() -> dict[str, str]:
5155
m = re.search(r'DREMIO_USER="\$\{DREMIO_USER:-([^}]+)\}"', content)
5256
if m:
5357
defaults["DREMIO_USER"] = m.group(1)
54-
except Exception:
58+
except FileNotFoundError:
5559
pass
60+
except Exception as e:
61+
print(f" Warning: failed parsing dremio defaults from {setup_path}: {e}")
5662

5763
return defaults
5864

@@ -364,9 +370,13 @@ def load(self) -> None:
364370

365371
fqn = f"{nessie_ns}.{qi(table_name)}"
366372

367-
# Create empty Iceberg table with VARCHAR columns
373+
# Drop + recreate to ensure idempotent seeding (no stale data)
368374
col_defs = ", ".join(f"{qi(c)} VARCHAR" for c in cols)
369-
create_sql = f"CREATE TABLE IF NOT EXISTS {fqn} ({col_defs})"
375+
try:
376+
self._sql(token, f"DROP TABLE IF EXISTS {fqn}", timeout=30)
377+
except Exception:
378+
pass # Table may not exist yet
379+
create_sql = f"CREATE TABLE {fqn} ({col_defs})"
370380
try:
371381
self._sql(token, create_sql, timeout=60)
372382
created_tables.append(table_name)

tests/e2e_dbt_project/external_seeders/spark.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ def load(self) -> None:
4040
port = int(os.environ.get("SPARK_PORT", "10000"))
4141

4242
print(f"Connecting to Spark Thrift at {host}:{port}...")
43-
conn = hive.Connection(host=host, port=port, username="dbt")
44-
cursor = conn.cursor()
43+
conn = None
44+
cursor = None
4545
try:
46+
conn = hive.Connection(host=host, port=port, username="dbt")
47+
cursor = conn.cursor()
4648
print(f"Creating schema '{seed_schema}'...")
4749
cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{seed_schema}`")
4850

@@ -69,13 +71,15 @@ def load(self) -> None:
6971
continue
7072

7173
container_path = f"/seed-data/{subdir}/{fname}"
74+
# Escape single quotes in path to prevent SQL injection
75+
safe_path = container_path.replace("'", "''")
7276
tmp_view = f"_tmp_csv_{table_name}"
7377
print(f" Loading: {table_name}")
7478
try:
7579
cursor.execute(
7680
f"CREATE OR REPLACE TEMPORARY VIEW {q(tmp_view)} "
7781
f"USING csv "
78-
f"OPTIONS (path '{container_path}', header 'true', "
82+
f"OPTIONS (path '{safe_path}', header 'true', "
7983
f"inferSchema 'true')"
8084
)
8185
cursor.execute(
@@ -88,8 +92,10 @@ def load(self) -> None:
8892
except Exception as e:
8993
failures.append(f"{table_name}: {e}")
9094
finally:
91-
cursor.close()
92-
conn.close()
95+
if cursor is not None:
96+
cursor.close()
97+
if conn is not None:
98+
conn.close()
9399
if failures:
94100
raise RuntimeError(
95101
"Spark seed loading failed:\n - " + "\n - ".join(failures)

0 commit comments

Comments
 (0)