diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py index c512b734e0..622e370d72 100644 --- a/graphistry/PlotterBase.py +++ b/graphistry/PlotterBase.py @@ -10,7 +10,7 @@ from graphistry.plugins_types.hypergraph import HypergraphResult from graphistry.render.resolve_render_mode import resolve_render_mode from graphistry.Engine import Engine, EngineAbstractType, df_to_engine -import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, sys, uuid, warnings +import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, requests, sys, uuid, warnings from functools import lru_cache, partialmethod from weakref import WeakValueDictionary @@ -41,7 +41,7 @@ error, hash_pdf, in_ipython, in_databricks, make_iframe, random_string, warn, cache_coercion, cache_coercion_helper, WeakValueWrapper ) -from graphistry.otel import otel_traced, otel_detail_enabled +from graphistry.otel import otel_traced, otel_detail_enabled, inject_trace_headers from .bolt_util import ( bolt_graph_to_edges_dataframe, @@ -2498,11 +2498,44 @@ def plot( 'type': 'arrow', 'viztoken': str(uuid.uuid4()) } - # Validate collections in url_params (catches bypass of .collections() method) from graphistry.validate.validate_collections import normalize_collections_url_params url_params = normalize_collections_url_params(self._url_params, validate=validate_mode, warn=warn) + token = self.session.api_token + if token: + resp = None + try: + server_base = '%s://%s' % (self.session.protocol, self.session.hostname) + resp = requests.post( + '%s/api/v1/auth/jwt/ott/' % server_base, + headers=inject_trace_headers({'Authorization': 'Bearer %s' % token}), + verify=self.session.certificate_validation, + timeout=30, + ) + resp.raise_for_status() + content_type = resp.headers.get('content-type', '') + if 'application/json' not in content_type: + raise ValueError( + 'OTT endpoint returned non-JSON (content-type: %s) — ' + 'server may not have the OTT endpoint deployed yet. ' + 'Body: %.200s' % (content_type, resp.text)) + url_params['token'] = resp.json()['ott'] + except requests.HTTPError as e: + assert resp is not None + logger.warning( + "OTT exchange failed — cross-origin iframe embedding will require " + "re-login (SameSite cookies blocked). " + "Ensure OTT_EXCHANGE_SECRET is set on the server. " + "Error: %s (status=%s, body=%.200s)", + e, resp.status_code, resp.text) + except Exception as e: + logger.warning( + "OTT exchange failed — cross-origin iframe embedding will require " + "re-login (SameSite cookies blocked). " + "Error: %s (body=%.200s)", + e, resp.text if resp is not None else '') + viz_url = self._pygraphistry._viz_url(info, url_params) cfg_client_protocol_hostname = self.session.client_protocol_hostname full_url = ('%s:%s' % (self.session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url diff --git a/graphistry/tests/test_trace_headers_behavior.py b/graphistry/tests/test_trace_headers_behavior.py index 96014e2c0a..a74e87154c 100644 --- a/graphistry/tests/test_trace_headers_behavior.py +++ b/graphistry/tests/test_trace_headers_behavior.py @@ -3,10 +3,12 @@ from unittest import mock import pandas as pd +import requests -# Import the ArrowFileUploader MODULE before graphistry shadows it with the class -# This ensures sys.modules has the module, allowing proper mock patching +# Import modules before graphistry shadows them with classes/symbols. +# This ensures sys.modules has the modules, allowing proper mock patching. import graphistry.ArrowFileUploader as _arrow_file_uploader_module # noqa: F401 +import graphistry.PlotterBase as _plotter_base_module # noqa: F401 import graphistry from graphistry.compute.ast import n, e_forward @@ -55,13 +57,13 @@ def _post_response_for_plot(url: str): return _mock_response({"is_valid": True, "is_uploaded": True}) if "/api/v2/share/link/" in url: return _mock_response({"success": True}) + if "/api/v1/auth/jwt/ott/" in url: + return _mock_response({"ott": "test-ott-token"}) raise AssertionError(f"Unexpected POST url: {url}") -@mock.patch("graphistry.arrow_uploader.inject_trace_headers") @mock.patch("requests.post") -def test_plot_injects_traceparent(mock_post, mock_inject): - mock_inject.side_effect = _inject_trace +def test_plot_injects_traceparent(mock_post): headers_seen = [] def _fake_post(url, **kwargs): @@ -70,22 +72,139 @@ def _fake_post(url, **kwargs): mock_post.side_effect = _fake_post - g = _make_graph() - g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False) + plotter_base_module = sys.modules["graphistry.PlotterBase"] + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): + g = _make_graph() + g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False) assert headers_seen assert all(h.get("traceparent") == TRACEPARENT for h in headers_seen) -@mock.patch("graphistry.arrow_uploader.inject_trace_headers") @mock.patch("requests.post") -def test_upload_injects_traceparent(mock_post, mock_inject_uploader): - # Patch ArrowFileUploader module's inject_trace_headers via sys.modules - # This is needed because graphistry.ArrowFileUploader resolves to the class, - # not the module (due to re-exports in graphistry/__init__.py) - arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"] +def test_plot_ott_in_url(mock_post): + """OTT from JWT exchange must appear as ?token= in the returned viz URL.""" + mock_post.side_effect = lambda url, **kw: _post_response_for_plot(url) + + plotter_base_module = sys.modules["graphistry.PlotterBase"] + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): + g = _make_graph() + url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + + assert "token=test-ott-token" in url, f"OTT missing from viz URL: {url}" + + +def _patch_inject(fn): + """Decorator: patch inject_trace_headers in both modules that use it.""" + import functools + @functools.wraps(fn) + @mock.patch("requests.post") + def wrapper(mock_post, *args, **kwargs): + plotter_base_module = sys.modules["graphistry.PlotterBase"] + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): + return fn(mock_post, *args, **kwargs) + return wrapper + + +@_patch_inject +def test_plot_ott_http_error_degrades_gracefully(mock_post): + """503 from OTT endpoint → URL has no ?token= (degrades to cookie auth).""" + def _side_effect(url, **kw): + if "/api/v1/auth/jwt/ott/" in url: + resp = _mock_response({"error": "server error"}, status=503) + resp.raise_for_status = mock.Mock( + side_effect=requests.HTTPError("503 Server Error", response=resp)) + return resp + return _post_response_for_plot(url) + + mock_post.side_effect = _side_effect + g = _make_graph() + url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + assert "&token=" not in url, f"?token= must be absent on OTT failure: {url}" + - mock_inject_uploader.side_effect = _inject_trace +@_patch_inject +def test_plot_ott_missing_key_degrades_gracefully(mock_post): + """Malformed OTT response (no 'ott' key) → URL has no ?token=.""" + def _side_effect(url, **kw): + if "/api/v1/auth/jwt/ott/" in url: + return _mock_response({}) # missing 'ott' key + return _post_response_for_plot(url) + + mock_post.side_effect = _side_effect + g = _make_graph() + url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + assert "&token=" not in url, f"?token= must be absent on malformed response: {url}" + + +@_patch_inject +def test_plot_ott_html_response_degrades_gracefully(mock_post): + """Non-JSON (HTML) response from OTT endpoint → URL has no ?token=. + + Reproduces the JSONDecodeError seen in Colab when the server redirects to + a login page (HTTP 200 + text/html) because the endpoint isn't deployed yet. + """ + def _side_effect(url, **kw): + if "/api/v1/auth/jwt/ott/" in url: + resp = mock.Mock() + resp.status_code = 200 + resp.headers = {"content-type": "text/html; charset=utf-8"} + resp.text = "Please log in" + resp.raise_for_status = mock.Mock() # 200, does not raise + return resp + return _post_response_for_plot(url) + + mock_post.side_effect = _side_effect + g = _make_graph() + url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + assert "&token=" not in url, f"?token= must be absent when server returns HTML: {url}" + + +@_patch_inject +def test_plot_ott_connection_error_degrades_gracefully(mock_post): + """Network error on OTT exchange → URL has no ?token=.""" + def _side_effect(url, **kw): + if "/api/v1/auth/jwt/ott/" in url: + raise requests.ConnectionError("connection refused") + return _post_response_for_plot(url) + + mock_post.side_effect = _side_effect + g = _make_graph() + url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + assert "&token=" not in url, f"?token= must be absent on connection error: {url}" + + +@_patch_inject +def test_plot_ott_failure_warns_about_iframe(mock_post): + """Warning message on OTT failure must mention cross-origin iframe re-login.""" + def _side_effect(url, **kw): + if "/api/v1/auth/jwt/ott/" in url: + resp = _mock_response({"error": "misconfigured"}, status=503) + resp.raise_for_status = mock.Mock( + side_effect=requests.HTTPError("503", response=resp)) + return resp + return _post_response_for_plot(url) + + mock_post.side_effect = _side_effect + g = _make_graph() + plotter_base_module = sys.modules["graphistry.PlotterBase"] + with mock.patch.object(plotter_base_module.logger, "warning") as mock_warn: + g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False) + assert mock_warn.called, "Expected a warning on OTT failure" + warning_text = " ".join(str(a) for a in mock_warn.call_args[0]) + assert "cross-origin" in warning_text, f"Warning must mention cross-origin: {warning_text}" + + +@mock.patch("requests.post") +def test_upload_injects_traceparent(mock_post): headers_seen = [] def _fake_post(url, **kwargs): @@ -94,7 +213,17 @@ def _fake_post(url, **kwargs): mock_post.side_effect = _fake_post - with mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace): + # Patch inject_trace_headers in all three modules that make POST requests: + # arrow_uploader.py, ArrowFileUploader.py, and PlotterBase.py (OTT exchange). + # Use sys.modules because graphistry/__init__.py re-exports some names as classes, + # shadowing the module attributes on the graphistry package. + arrow_uploader_module = sys.modules["graphistry.arrow_uploader"] + arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"] + plotter_base_module = sys.modules["graphistry.PlotterBase"] + + with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \ + mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace): g = _make_graph() g.upload(validate=False, warn=False, memoize=False, erase_files_on_fail=False)