Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit b9cae0f

Browse files
authored
Merge pull request #462 from pik94/datafold-cloud-api-token-flow
cloud api token flow
2 parents 831c7c1 + e711b97 commit b9cae0f

5 files changed

Lines changed: 29 additions & 13 deletions

File tree

data_diff/dbt.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
22
import os
33
import time
4+
import webbrowser
45
import rich
6+
from rich.prompt import Confirm
57

68
from collections import defaultdict
79
from dataclasses import dataclass
@@ -220,16 +222,28 @@ def _local_diff(diff_vars: DiffVars) -> None:
220222

221223

222224
def _cloud_diff(diff_vars: DiffVars) -> None:
225+
datafold_host = os.environ.get("DATAFOLD_HOST")
226+
if datafold_host is None:
227+
datafold_host = "https://app.datafold.com"
228+
datafold_host = datafold_host.rstrip("/")
229+
rich.print(f"Cloud datafold host: {datafold_host}")
230+
223231
api_key = os.environ.get("DATAFOLD_API_KEY")
232+
if not api_key:
233+
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
234+
yes_or_no = Confirm.ask("Would you like to generate a new API key?")
235+
if yes_or_no:
236+
webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
237+
return
238+
else:
239+
raise ValueError("Cannot diff because the API key is not provided")
224240

225241
if diff_vars.datasource_id is None:
226242
raise ValueError(
227243
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
228244
)
229-
if api_key is None:
230-
raise ValueError("API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
231245

232-
url = "https://app.datafold.com/api/v1/datadiffs"
246+
url = f"{datafold_host}/api/v1/datadiffs"
233247

234248
payload = {
235249
"data_source1_id": diff_vars.datasource_id,
@@ -255,8 +269,7 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
255269
response.raise_for_status()
256270
data = response.json()
257271
diff_id = data["id"]
258-
# TODO in future we should support self hosted datafold
259-
diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
272+
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
260273
rich.print(
261274
"[red]"
262275
+ ".".join(diff_vars.prod_path)

data_diff/diff_tables.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def get_stats_string(self, is_dbt: bool = False):
153153
string_output += f"\n{k}: {v}"
154154

155155
else:
156-
157156
string_output = ""
158157
string_output += f"{diff_stats.table1_count} rows in table A\n"
159158
string_output += f"{diff_stats.table2_count} rows in table B\n"

data_diff/joindiff_tables.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def _diff_segments(
201201
if self.materialize_to_table
202202
else None,
203203
):
204-
205204
assert len(a_cols) == len(b_cols)
206205
logger.debug("Querying for different rows")
207206
diff = db.query(diff_rows, list)

data_diff/tracking.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,15 @@ def set_entrypoint_name(s):
5656
global entrypoint_name
5757
entrypoint_name = s
5858

59+
5960
dbt_user_id = None
6061

62+
6163
def set_dbt_user_id(s):
6264
global dbt_user_id
6365
dbt_user_id = s
6466

67+
6568
def get_anonymous_id():
6669
global g_anonymous_id
6770
if g_anonymous_id is None:

tests/test_dbt.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_integration_basic_dbt(self):
331331
artifacts_path = os.getcwd() + '/tests/dbt_artifacts'
332332
test_project_path = os.environ.get("DATA_DIFF_DBT_PROJ") or artifacts_path
333333
diff = run_datadiff_cli("--dbt", "--dbt-project-dir", test_project_path, "--dbt-profiles-dir", test_project_path)
334-
assert diff[-1].decode("utf-8") == "Diffs Complete!"
334+
assert "Diffs Complete!" in '\n'.join(d.decode("utf-8") for d in diff)
335335

336336
# assertions for the diff that exists in tests/dbt_artifacts/jaffle_shop.duckdb
337337
if test_project_path == artifacts_path:
@@ -340,7 +340,7 @@ def test_integration_basic_dbt(self):
340340
assert diff_string.count('<>') == 5
341341
# 4 with no diffs
342342
assert diff_string.count('No row differences') == 4
343-
# 1 with a diff
343+
# 1 with a diff
344344
assert diff_string.count('| Rows Added | Rows Removed') == 1
345345

346346

@@ -425,7 +425,7 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
425425
_cloud_diff(diff_vars)
426426

427427
mock_request.assert_called_once()
428-
mock_print.assert_called_once()
428+
self.assertEqual(len(mock_print.call_args_list), 2)
429429
request_data_dict = mock_request.call_args[1]["json"]
430430
self.assertEqual(
431431
mock_request.call_args[1]["headers"]["Authorization"],
@@ -455,17 +455,19 @@ def test_cloud_diff_ds_id_none(self, mock_request, mock_os_environ, mock_print):
455455
_cloud_diff(diff_vars)
456456

457457
mock_request.assert_not_called()
458-
mock_print.assert_not_called()
458+
mock_print.assert_called_once()
459459

460460
@patch("data_diff.dbt.rich.print")
461461
@patch("data_diff.dbt.os.environ")
462462
@patch("data_diff.dbt.requests.request")
463-
def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print):
463+
@patch("data_diff.dbt.Confirm.ask")
464+
def test_cloud_diff_api_key_none(self, mock_confirm_answer, mock_request, mock_os_environ, mock_print):
464465
expected_api_key = None
465466
mock_response = Mock()
466467
mock_response.json.return_value = {"id": 123}
467468
mock_request.return_value = mock_response
468469
mock_os_environ.get.return_value = expected_api_key
470+
mock_confirm_answer.return_value = False
469471
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
470472
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
471473
expected_datasource_id = 1
@@ -475,7 +477,7 @@ def test_cloud_diff_api_key_none(self, mock_request, mock_os_environ, mock_print
475477
_cloud_diff(diff_vars)
476478

477479
mock_request.assert_not_called()
478-
mock_print.assert_not_called()
480+
self.assertEqual(len(mock_print.call_args_list), 2)
479481

480482
@patch("data_diff.dbt._get_diff_vars")
481483
@patch("data_diff.dbt._local_diff")

0 commit comments

Comments
 (0)