Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 778626c

Browse files
committed
fix unit tests
1 parent 802d0d3 commit 778626c

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

data_diff/dbt.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,13 @@ def _local_diff(diff_vars: DiffVars) -> None:
222222

223223

224224
def _cloud_diff(diff_vars: DiffVars) -> None:
225-
datafold_host = os.environ.get("DATAFOLD_HOST", "https://app.datafold.com").rstrip("/")
226-
api_key = os.environ.get("DATAFOLD_API_KEY")
225+
datafold_host = os.environ.get("DATAFOLD_HOST")
226+
if datafold_host is None:
227+
datafold_host = "https://app.datafold.com"
228+
datafold_host = datafold_host.rstrip("/")
227229
rich.print(f"Cloud datafold host: {datafold_host}")
230+
231+
api_key = os.environ.get("DATAFOLD_API_KEY")
228232
if not api_key:
229233
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
230234
yes_or_no = Confirm.ask("Would you like to generate a new API key?")

tests/test_dbt.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_integration_basic_dbt(self):
331331
artifacts_path = os.getcwd() + '/tests/dbt_artifacts'
332332
test_project_path = os.environ.get("DATA_DIFF_DBT_PROJ") or artifacts_path
333333
diff = run_datadiff_cli("--dbt", "--dbt-project-dir", test_project_path, "--dbt-profiles-dir", test_project_path)
334-
assert diff[-1].decode("utf-8") == "Diffs Complete!"
334+
assert "Diffs Complete!" in '\n'.join(d.decode("utf-8") for d in diff)
335335

336336
# assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
337337
if test_project_path == artifacts_path:
@@ -340,7 +340,7 @@ def test_integration_basic_dbt(self):
340340
assert diff_string.count('<>') == 5
341341
# 4 with no diffs
342342
assert diff_string.count('No row differences') == 4
343-
# 1 with a diff
343+
# 1 with a diff
344344
assert diff_string.count('| Rows Added | Rows Removed') == 1
345345

346346

@@ -425,7 +425,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
425425
_cloud_diff(diff_vars)
426426

427427
mock_request.assert_called_once()
428-
mock_print.assert_called_once()
428+
self.assertEqual(len(mock_print.call_args_list), 2)
429429
request_data_dict = mock_request.call_args[1]["json"]
430430
self.assertEqual(
431431
mock_request.call_args[1]["headers"]["Authorization"],
@@ -455,17 +455,19 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
455455
_cloud_diff(diff_vars)
456456

457457
mock_request.assert_not_called()
458-
mock_print.assert_not_called()
458+
mock_print.assert_called_once()
459459

460460
@patch("data_diff.dbt.rich.print")
461461
@patch("data_diff.dbt.os.environ")
462462
@patch("data_diff.dbt.requests.request")
463-
def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print):
463+
@patch("data_diff.dbt.Confirm.ask")
464+
def test_cloud_diff_api_key_none(self, mock_confirm_answer, mock_request, mock_os_environ, mock_print):
464465
expected_api_key = None
465466
mock_response = Mock()
466467
mock_response.json.return_value = {"id": 123}
467468
mock_request.return_value = mock_response
468469
mock_os_environ.get.return_value = expected_api_key
470+
mock_confirm_answer.return_value = False
469471
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
470472
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
471473
expected_datasource_id = 1
@@ -475,7 +477,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
475477
_cloud_diff(diff_vars)
476478

477479
mock_request.assert_not_called()
478-
mock_print.assert_not_called()
480+
self.assertEqual(len(mock_print.call_args_list), 2)
479481

480482
@patch("data_diff.dbt._get_diff_vars")
481483
@patch("data_diff.dbt._local_diff")

0 commit comments

Comments
 (0)