Skip to content
39 changes: 36 additions & 3 deletions graphistry/PlotterBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from graphistry.plugins_types.hypergraph import HypergraphResult
from graphistry.render.resolve_render_mode import resolve_render_mode
from graphistry.Engine import Engine, EngineAbstractType, df_to_engine
import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, sys, uuid, warnings
import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, requests, sys, uuid, warnings
from functools import lru_cache, partialmethod
from weakref import WeakValueDictionary

Expand Down Expand Up @@ -41,7 +41,7 @@
error, hash_pdf, in_ipython, in_databricks, make_iframe, random_string, warn,
cache_coercion, cache_coercion_helper, WeakValueWrapper
)
from graphistry.otel import otel_traced, otel_detail_enabled
from graphistry.otel import otel_traced, otel_detail_enabled, inject_trace_headers

from .bolt_util import (
bolt_graph_to_edges_dataframe,
Expand Down Expand Up @@ -2498,11 +2498,44 @@ def plot(
'type': 'arrow',
'viztoken': str(uuid.uuid4())
}

# Validate collections in url_params (catches bypass of .collections() method)
from graphistry.validate.validate_collections import normalize_collections_url_params
url_params = normalize_collections_url_params(self._url_params, validate=validate_mode, warn=warn)

token = self.session.api_token
if token:
resp = None
try:
server_base = '%s://%s' % (self.session.protocol, self.session.hostname)
resp = requests.post(
'%s/api/v1/auth/jwt/ott/' % server_base,
headers=inject_trace_headers({'Authorization': 'Bearer %s' % token}),
verify=self.session.certificate_validation,
timeout=30,
)
resp.raise_for_status()
content_type = resp.headers.get('content-type', '')
if 'application/json' not in content_type:
raise ValueError(
'OTT endpoint returned non-JSON (content-type: %s) — '
'server may not have the OTT endpoint deployed yet. '
'Body: %.200s' % (content_type, resp.text))
url_params['token'] = resp.json()['ott']
except requests.HTTPError as e:
assert resp is not None
logger.warning(
"OTT exchange failed — cross-origin iframe embedding will require "
"re-login (SameSite cookies blocked). "
"Ensure OTT_EXCHANGE_SECRET is set on the server. "
"Error: %s (status=%s, body=%.200s)",
e, resp.status_code, resp.text)
except Exception as e:
logger.warning(
"OTT exchange failed — cross-origin iframe embedding will require "
"re-login (SameSite cookies blocked). "
"Error: %s (body=%.200s)",
e, resp.text if resp is not None else '<no response>')

viz_url = self._pygraphistry._viz_url(info, url_params)
cfg_client_protocol_hostname = self.session.client_protocol_hostname
full_url = ('%s:%s' % (self.session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url
Expand Down
159 changes: 144 additions & 15 deletions graphistry/tests/test_trace_headers_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from unittest import mock

import pandas as pd
import requests

# Import the ArrowFileUploader MODULE before graphistry shadows it with the class
# This ensures sys.modules has the module, allowing proper mock patching
# Import modules before graphistry shadows them with classes/symbols.
# This ensures sys.modules has the modules, allowing proper mock patching.
import graphistry.ArrowFileUploader as _arrow_file_uploader_module # noqa: F401
import graphistry.PlotterBase as _plotter_base_module # noqa: F401

import graphistry
from graphistry.compute.ast import n, e_forward
Expand Down Expand Up @@ -55,13 +57,13 @@ def _post_response_for_plot(url: str):
return _mock_response({"is_valid": True, "is_uploaded": True})
if "/api/v2/share/link/" in url:
return _mock_response({"success": True})
if "/api/v1/auth/jwt/ott/" in url:
return _mock_response({"ott": "test-ott-token"})
raise AssertionError(f"Unexpected POST url: {url}")


@mock.patch("graphistry.arrow_uploader.inject_trace_headers")
@mock.patch("requests.post")
def test_plot_injects_traceparent(mock_post, mock_inject):
mock_inject.side_effect = _inject_trace
def test_plot_injects_traceparent(mock_post):
headers_seen = []

def _fake_post(url, **kwargs):
Expand All @@ -70,22 +72,139 @@ def _fake_post(url, **kwargs):

mock_post.side_effect = _fake_post

g = _make_graph()
g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False)
plotter_base_module = sys.modules["graphistry.PlotterBase"]
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]

with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
g = _make_graph()
g.plot(render="g", as_files=False, validate=False, warn=False, memoize=False)

assert headers_seen
assert all(h.get("traceparent") == TRACEPARENT for h in headers_seen)


@mock.patch("graphistry.arrow_uploader.inject_trace_headers")
@mock.patch("requests.post")
def test_upload_injects_traceparent(mock_post, mock_inject_uploader):
# Patch ArrowFileUploader module's inject_trace_headers via sys.modules
# This is needed because graphistry.ArrowFileUploader resolves to the class,
# not the module (due to re-exports in graphistry/__init__.py)
arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"]
def test_plot_ott_in_url(mock_post):
"""OTT from JWT exchange must appear as ?token= in the returned viz URL."""
mock_post.side_effect = lambda url, **kw: _post_response_for_plot(url)

plotter_base_module = sys.modules["graphistry.PlotterBase"]
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]

with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
g = _make_graph()
url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)

assert "token=test-ott-token" in url, f"OTT missing from viz URL: {url}"


def _patch_inject(fn):
"""Decorator: patch inject_trace_headers in both modules that use it."""
import functools
@functools.wraps(fn)
@mock.patch("requests.post")
def wrapper(mock_post, *args, **kwargs):
plotter_base_module = sys.modules["graphistry.PlotterBase"]
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]
with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
return fn(mock_post, *args, **kwargs)
return wrapper


@_patch_inject
def test_plot_ott_http_error_degrades_gracefully(mock_post):
"""503 from OTT endpoint → URL has no ?token= (degrades to cookie auth)."""
def _side_effect(url, **kw):
if "/api/v1/auth/jwt/ott/" in url:
resp = _mock_response({"error": "server error"}, status=503)
resp.raise_for_status = mock.Mock(
side_effect=requests.HTTPError("503 Server Error", response=resp))
return resp
return _post_response_for_plot(url)

mock_post.side_effect = _side_effect
g = _make_graph()
url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)
assert "&token=" not in url, f"?token= must be absent on OTT failure: {url}"


mock_inject_uploader.side_effect = _inject_trace
@_patch_inject
def test_plot_ott_missing_key_degrades_gracefully(mock_post):
"""Malformed OTT response (no 'ott' key) → URL has no ?token=."""
def _side_effect(url, **kw):
if "/api/v1/auth/jwt/ott/" in url:
return _mock_response({}) # missing 'ott' key
return _post_response_for_plot(url)

mock_post.side_effect = _side_effect
g = _make_graph()
url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)
assert "&token=" not in url, f"?token= must be absent on malformed response: {url}"


@_patch_inject
def test_plot_ott_html_response_degrades_gracefully(mock_post):
"""Non-JSON (HTML) response from OTT endpoint → URL has no ?token=.

Reproduces the JSONDecodeError seen in Colab when the server redirects to
a login page (HTTP 200 + text/html) because the endpoint isn't deployed yet.
"""
def _side_effect(url, **kw):
if "/api/v1/auth/jwt/ott/" in url:
resp = mock.Mock()
resp.status_code = 200
resp.headers = {"content-type": "text/html; charset=utf-8"}
resp.text = "<html><body>Please log in</body></html>"
resp.raise_for_status = mock.Mock() # 200, does not raise
return resp
return _post_response_for_plot(url)

mock_post.side_effect = _side_effect
g = _make_graph()
url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)
assert "&token=" not in url, f"?token= must be absent when server returns HTML: {url}"


@_patch_inject
def test_plot_ott_connection_error_degrades_gracefully(mock_post):
"""Network error on OTT exchange → URL has no ?token=."""
def _side_effect(url, **kw):
if "/api/v1/auth/jwt/ott/" in url:
raise requests.ConnectionError("connection refused")
return _post_response_for_plot(url)

mock_post.side_effect = _side_effect
g = _make_graph()
url = g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)
assert "&token=" not in url, f"?token= must be absent on connection error: {url}"


@_patch_inject
def test_plot_ott_failure_warns_about_iframe(mock_post):
"""Warning message on OTT failure must mention cross-origin iframe re-login."""
def _side_effect(url, **kw):
if "/api/v1/auth/jwt/ott/" in url:
resp = _mock_response({"error": "misconfigured"}, status=503)
resp.raise_for_status = mock.Mock(
side_effect=requests.HTTPError("503", response=resp))
return resp
return _post_response_for_plot(url)

mock_post.side_effect = _side_effect
g = _make_graph()
plotter_base_module = sys.modules["graphistry.PlotterBase"]
with mock.patch.object(plotter_base_module.logger, "warning") as mock_warn:
g.plot(render="url", as_files=False, validate=False, warn=False, memoize=False)
assert mock_warn.called, "Expected a warning on OTT failure"
warning_text = " ".join(str(a) for a in mock_warn.call_args[0])
assert "cross-origin" in warning_text, f"Warning must mention cross-origin: {warning_text}"


@mock.patch("requests.post")
def test_upload_injects_traceparent(mock_post):
headers_seen = []

def _fake_post(url, **kwargs):
Expand All @@ -94,7 +213,17 @@ def _fake_post(url, **kwargs):

mock_post.side_effect = _fake_post

with mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace):
# Patch inject_trace_headers in all three modules that make POST requests:
# arrow_uploader.py, ArrowFileUploader.py, and PlotterBase.py (OTT exchange).
# Use sys.modules because graphistry/__init__.py re-exports some names as classes,
# shadowing the module attributes on the graphistry package.
arrow_uploader_module = sys.modules["graphistry.arrow_uploader"]
arrow_file_uploader_module = sys.modules["graphistry.ArrowFileUploader"]
plotter_base_module = sys.modules["graphistry.PlotterBase"]

with mock.patch.object(arrow_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(arrow_file_uploader_module, "inject_trace_headers", side_effect=_inject_trace), \
mock.patch.object(plotter_base_module, "inject_trace_headers", side_effect=_inject_trace):
g = _make_graph()
g.upload(validate=False, warn=False, memoize=False, erase_files_on_fail=False)

Expand Down
Loading