Skip to content

Commit 257d059

Browse files
authored
Merge branch 'plotly:master' into patch-3
2 parents 250a09b + d93391e commit 257d059

9 files changed

Lines changed: 1539 additions & 31 deletions

File tree

src/py/CHANGELOG.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
v1.2.0
2+
- Try to use plotly JSON encoder instead of default
3+
14
v1.1.0
25
- Add testing
36
- Fix a variety of type bugs
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""These are really just scripts, run them as such. Include pickle dev group."""
Binary file not shown.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Create pickled fig for use in integration tests."""
2+
3+
import pickle
4+
from pathlib import Path
5+
6+
import datashader as ds
7+
import datashader.transfer_functions as tf
8+
import pandas as pd
9+
import plotly.express as px
10+
import zstandard as zstd
11+
from colorcet import fire
12+
13+
cctx = zstd.ZstdCompressor(level=20)
14+
15+
df = pd.read_csv(
16+
"https://raw.githubusercontent.com/plotly/datasets/master/uber-rides-data1.csv",
17+
)
18+
dff = (
19+
df.query("Lat < 40.82")
20+
.query("Lat > 40.70")
21+
.query("Lon > -74.02")
22+
.query("Lon < -73.91")
23+
)
24+
25+
26+
cvs = ds.Canvas(plot_width=1000, plot_height=1000)
27+
agg = cvs.points(dff, x="Lon", y="Lat")
28+
# agg is an xarray object, see http://xarray.pydata.org/en/stable/ for more details
29+
coords_lat, coords_lon = agg.coords["Lat"].to_numpy(), agg.coords["Lon"].to_numpy()
30+
# Corners of the image
31+
coordinates = [
32+
[coords_lon[0], coords_lat[0]],
33+
[coords_lon[-1], coords_lat[0]],
34+
[coords_lon[-1], coords_lat[-1]],
35+
[coords_lon[0], coords_lat[-1]],
36+
]
37+
38+
39+
img = tf.shade(agg, cmap=fire)[::-1].to_pil()
40+
41+
42+
# Trick to create rapidly a figure with map axes
43+
fig = px.scatter_map(dff[:1], lat="Lat", lon="Lon", zoom=12)
44+
# Add the datashader image as a tile map layer image
45+
fig.update_layout(
46+
map_style="carto-darkmatter",
47+
map_layers=[{"sourcetype": "image", "source": img, "coordinates": coordinates}],
48+
)
49+
50+
raw = pickle.dumps(fig, protocol=5) # >=3.8
51+
compressed = cctx.compress(raw)
52+
with Path(f"./figs/{Path(__file__).stem}.pkl.zst").open("wb") as f:
53+
f.write(compressed)
54+
55+
print( # noqa: T201
56+
f"{Path(__file__).stem}.pkl: "
57+
f"{len(raw) / 1024:.1f} -> {len(compressed) / 1024:.1f} KB",
58+
)

src/py/kaleido/kaleido.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@
2828

2929
_logger = logistro.getLogger(__name__)
3030

31+
try:
32+
from plotly.utils import PlotlyJSONEncoder # noqa: I001
33+
from choreographer import channels
34+
35+
channels.register_custom_encoder(PlotlyJSONEncoder)
36+
_logger.debug("Successfully registered PlotlyJSONEncoder.")
37+
except ImportError as e:
38+
_logger.debug(f'Couldn\'t import plotly due to "{e!s}" - skipping.')
39+
3140
# Show a warning if the installed Plotly version
3241
# is incompatible with this version of Kaleido
3342
warn_incompatible_plotly()

src/py/pyproject.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ kaleido = ['vendor/**']
1414
[project]
1515
name = "kaleido"
1616
description = "Plotly graph export library"
17-
license = {file = "LICENSE.md"}
17+
license = { "file" = "LICENSE.md" }
1818
readme = "README.md"
1919
requires-python = ">=3.8"
2020
dynamic = ["version"]
@@ -26,7 +26,7 @@ maintainers = [
2626
{name = "Andrew Pikul", email = "ajpikul@gmail.com"},
2727
]
2828
dependencies = [
29-
"choreographer>=1.0.10",
29+
"choreographer>=1.1.1",
3030
"logistro>=1.0.8",
3131
"orjson>=3.10.15",
3232
"packaging",
@@ -55,6 +55,13 @@ dev = [
5555
"typing-extensions>=4.12.2",
5656
"hypothesis>=6.113.0",
5757
]
58+
pickles = [
59+
"colorcet>=3.1.0",
60+
"datashader>=0.15.2",
61+
"pillow>=10.4.0",
62+
"plotly[express]>=6.3.0",
63+
"zstandard>=0.23.0",
64+
]
5865

5966
[tool.ruff.lint]
6067
select = ["ALL"]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import numpy as np
2+
import plotly.graph_objects as go
3+
import pytest
4+
from plotly.subplots import make_subplots
5+
6+
import kaleido
7+
8+
pytestmark = pytest.mark.asyncio(loop_scope="function")
9+
10+
rng = np.random.default_rng() # creates a Generator instance
11+
12+
13+
async def test_complex_plotly_encoder():
14+
"""Test that kaleido can handle complex Plotly figures with numpy arrays."""
15+
16+
# Create complex numpy arrays
17+
x = np.linspace(0, 4 * np.pi, 100)
18+
y1 = np.sin(x) * np.exp(-x / 10)
19+
20+
# Create a 2D array for heatmap
21+
z = np.outer(
22+
np.sin(np.linspace(0, np.pi, 20)),
23+
np.cos(np.linspace(0, np.pi, 30)),
24+
)
25+
26+
# Create subplot figure
27+
fig = make_subplots(
28+
rows=2,
29+
cols=2,
30+
subplot_titles=(
31+
"Complex Scatter",
32+
"Heatmap",
33+
"Bar with Color",
34+
"3D Surface",
35+
),
36+
specs=[
37+
[{"type": "scatter"}, {"type": "heatmap"}],
38+
[{"type": "bar"}, {"type": "surface"}],
39+
],
40+
)
41+
42+
# Scatter with numpy marker sizes and complex styling
43+
fig.add_trace(
44+
go.Scatter(
45+
x=x,
46+
y=y1,
47+
mode="lines+markers",
48+
line={"color": "blue", "width": 2},
49+
marker={
50+
"size": [rng.integers(2, 12) for _ in range(len(x))],
51+
"color": np.abs(y1),
52+
"colorscale": "Viridis",
53+
"showscale": True,
54+
},
55+
),
56+
row=1,
57+
col=1,
58+
)
59+
60+
# Heatmap with 2D numpy array
61+
fig.add_trace(
62+
go.Heatmap(
63+
z=z,
64+
colorscale="Plasma",
65+
),
66+
row=1,
67+
col=2,
68+
)
69+
70+
# Bar chart with numpy data and color mapping
71+
categories = np.array(["A", "B", "C", "D", "E", "F"])
72+
values = rng.normal(50, 20, len(categories))
73+
74+
fig.add_trace(
75+
go.Bar(
76+
x=categories,
77+
y=values,
78+
marker={
79+
"color": np.abs(values),
80+
"colorscale": "Blues",
81+
"line": {"width": 2, "color": "black"},
82+
},
83+
),
84+
row=2,
85+
col=1,
86+
)
87+
88+
# 3D surface with complex numpy operations
89+
x_surf = np.linspace(-3, 3, 30)
90+
y_surf = np.linspace(-3, 3, 30)
91+
X, Y = np.meshgrid(x_surf, y_surf) # noqa: N806
92+
Z = np.sin(np.sqrt(X**2 + Y**2)) * np.exp(-(X**2 + Y**2) / 10) # noqa: N806
93+
94+
fig.add_trace(
95+
go.Surface(
96+
x=X,
97+
y=Y,
98+
z=Z,
99+
colorscale="Cividis",
100+
),
101+
row=2,
102+
col=2,
103+
)
104+
105+
# Complex layout with numpy-based annotations
106+
fig.update_layout(
107+
title="Complex Numpy Figure for Encoder Testing",
108+
height=800,
109+
width=1200,
110+
)
111+
112+
# Render with kaleido
113+
img_bytes = await kaleido.calc_fig(fig)
114+
115+
assert isinstance(img_bytes, bytes)

src/py/tests/test_page_generator.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,32 +175,29 @@ def st_mathjax(dir_path: Path):
175175
# Test default combinations
176176
@pytest.mark.order(1)
177177
async def test_defaults_no_plotly_available():
178-
"""Test defaults when plotly package is not available."""
179-
if not find_spec("plotly"):
180-
raise ImportError("Tests must be run with plotly installed to function")
181-
182-
old_path = sys.path
183-
sys.path = sys.path[:1]
184-
if find_spec("plotly"):
185-
raise RuntimeError(
186-
"Plotly cannot be imported during this test, "
187-
"as this tests default behavior while trying to import plotly. "
188-
"The best solution is to make sure this test always runs first, "
189-
"or if you really need to, run it separately and then skip it "
190-
"in the main group.",
191-
)
192-
193-
# Test no imports (plotly not available)
194-
no_imports = PageGenerator().generate_index()
195-
scripts, _encodings = get_scripts_from_html(no_imports)
196-
197-
# Should have mathjax, plotly default, and kaleido_scopes
198-
assert len(scripts) == 3 # noqa: PLR2004
199-
assert scripts[0] == DEFAULT_MATHJAX
200-
assert scripts[1] == DEFAULT_PLOTLY
201-
assert scripts[2].endswith("kaleido_scopes.js")
178+
"""
179+
Test defaults when plotly package is not available.
202180
203-
sys.path = old_path
181+
When we generate_index(), if we don't have plotly in path, we use a CDN.
182+
"""
183+
_old_path = sys.path
184+
try:
185+
sys.path = []
186+
_plotly_mo = sys.modules.pop("plotly", None)
187+
188+
# Test no imports (plotly not available)
189+
no_imports = PageGenerator().generate_index()
190+
scripts, _encodings = get_scripts_from_html(no_imports)
191+
192+
# Should have mathjax, plotly default, and kaleido_scopes
193+
assert len(scripts) == 3 # noqa: PLR2004
194+
assert scripts[0] == DEFAULT_MATHJAX
195+
assert scripts[1] == DEFAULT_PLOTLY
196+
assert scripts[2].endswith("kaleido_scopes.js")
197+
finally:
198+
sys.path = _old_path
199+
if _plotly_mo:
200+
sys.modules.update({"plotly": _plotly_mo})
204201

205202

206203
async def test_defaults_with_plotly_available():

0 commit comments

Comments
 (0)