Skip to content

Commit 6d47c89

Browse files
committed
add semi transparent background colored overlay (for colorbar)
1 parent e8ba433 commit 6d47c89

10 files changed

Lines changed: 341 additions & 24 deletions

File tree

webgpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .background import Background
12
from .clipping import Clipping
23
from .colormap import Colormap, Colorbar
34
from .font import Font

webgpu/background.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""Generic rounded-rectangle background overlay renderer."""
2+
3+
import ctypes as ct
4+
5+
from .renderer import Renderer, RenderOptions
6+
from .uniforms import UniformBase
7+
from .utils import (
8+
create_bind_group,
9+
get_device,
10+
read_shader_file,
11+
)
12+
from .webgpu_api import (
13+
BlendComponent,
14+
BlendFactor,
15+
BlendOperation,
16+
BlendState,
17+
ColorTargetState,
18+
CompareFunction,
19+
DepthStencilState,
20+
FragmentState,
21+
PrimitiveState,
22+
VertexState,
23+
)
24+
25+
26+
_BINDING = 50
27+
28+
29+
class BackgroundUniforms(UniformBase):
30+
_binding = _BINDING
31+
_fields_ = [
32+
("position", ct.c_float * 2),
33+
("width", ct.c_float),
34+
("height", ct.c_float),
35+
("bg_color", ct.c_float * 3),
36+
("_pad", ct.c_float),
37+
]
38+
39+
40+
class Background(Renderer):
41+
"""Semi-transparent rounded-rectangle background overlay.
42+
43+
Place this before other renderers in a MultipleRenderer to provide
44+
a readable backdrop behind text or UI elements.
45+
46+
@param position: (x, y) top-left corner in NDC
47+
@param width: width in NDC
48+
@param height: height in NDC (of the content area, padding is added automatically)
49+
"""
50+
vertex_entry_point: str = "background_vertex"
51+
fragment_entry_point: str = "background_fragment"
52+
select_entry_point: str = ""
53+
n_vertices: int = 6
54+
n_instances: int = 1
55+
56+
def __init__(self, position=(0, 0), width=1, height=0.05):
57+
super().__init__()
58+
self._position = position
59+
self._width = width
60+
self._height = height
61+
self.uniforms = None
62+
63+
@property
64+
def position(self):
65+
return self._position
66+
67+
@position.setter
68+
def position(self, value):
69+
self._position = value
70+
if self.uniforms is not None:
71+
self.uniforms.position = value
72+
self.uniforms.update_buffer()
73+
self.set_needs_update()
74+
75+
@property
76+
def width(self):
77+
return self._width
78+
79+
@width.setter
80+
def width(self, value):
81+
self._width = value
82+
if self.uniforms is not None:
83+
self.uniforms.width = value
84+
self.uniforms.update_buffer()
85+
self.set_needs_update()
86+
87+
@property
88+
def height(self):
89+
return self._height
90+
91+
@height.setter
92+
def height(self, value):
93+
self._height = value
94+
if self.uniforms is not None:
95+
self.uniforms.height = value
96+
self.uniforms.update_buffer()
97+
self.set_needs_update()
98+
99+
def get_shader_code(self):
100+
return read_shader_file("background.wgsl")
101+
102+
def get_bindings(self):
103+
return self.uniforms.get_bindings()
104+
105+
def update(self, options: RenderOptions):
106+
if self.uniforms is None:
107+
self.uniforms = BackgroundUniforms()
108+
self.uniforms.position = self.position
109+
self.uniforms.width = self.width
110+
self.uniforms.height = self.height
111+
self.uniforms.bg_color = (1.0, 1.0, 1.0)
112+
self.uniforms.update_buffer()
113+
114+
def create_render_pipeline(self, options: RenderOptions) -> None:
115+
bindings = options.get_bindings() + self.get_bindings()
116+
117+
if bindings == self._last_bindings:
118+
return
119+
120+
layout, self.group = create_bind_group(
121+
self.device, options.get_bindings() + self.get_bindings()
122+
)
123+
pipeline_layout = self.device.createPipelineLayout([layout])
124+
125+
depth_stencil = DepthStencilState(
126+
format=options.canvas.depth_format,
127+
depthWriteEnabled=False,
128+
depthCompare=CompareFunction.always,
129+
)
130+
131+
bg_color_target = ColorTargetState(
132+
format=options.canvas.format,
133+
blend=BlendState(
134+
color=BlendComponent(
135+
srcFactor=BlendFactor.src_alpha,
136+
dstFactor=BlendFactor.one_minus_src_alpha,
137+
operation=BlendOperation.add,
138+
),
139+
alpha=BlendComponent(
140+
srcFactor=BlendFactor.one,
141+
dstFactor=BlendFactor.one_minus_src_alpha,
142+
operation=BlendOperation.add,
143+
),
144+
),
145+
)
146+
147+
shader_module = self.device.createShaderModule(self._get_preprocessed_shader_code())
148+
self.pipeline = self.device.createRenderPipeline(
149+
pipeline_layout,
150+
vertex=VertexState(
151+
module=shader_module,
152+
entryPoint=self.vertex_entry_point,
153+
buffers=[],
154+
),
155+
fragment=FragmentState(
156+
module=shader_module,
157+
entryPoint=self.fragment_entry_point,
158+
targets=[bg_color_target],
159+
),
160+
primitive=PrimitiveState(topology=self.topology),
161+
depthStencil=depth_stencil,
162+
multisample=options.canvas.multisample,
163+
label="Background",
164+
)
165+
self._select_pipeline = None
166+
self._transparent_pipeline = None
167+
self._last_bindings = bindings
168+
169+
def get_export_descriptor(self, options, buffer_registry):
170+
desc = super().get_export_descriptor(options, buffer_registry)
171+
desc.depth_write = False
172+
desc.pass_type = "transparent"
173+
return desc
174+
175+
def get_theme_buffer_id(self, registry):
176+
"""Return the buffer id for theme color updates (bg_color at offset 16)."""
177+
if self.uniforms is None or self.uniforms._buffer is None:
178+
return None
179+
key = id(self.uniforms._buffer)
180+
if key in registry._buffers:
181+
return registry._buffers[key][0]
182+
return None

webgpu/colormap.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22

3+
from .background import Background
34
from .labels import Labels
4-
from .renderer import BaseRenderer, Renderer, RenderOptions
5+
from .renderer import BaseRenderer, MultipleRenderer, Renderer, RenderOptions
56
from .uniforms import Binding, UniformBase, ct
67
from .utils import (
78
SamplerBinding,
@@ -253,12 +254,34 @@ def _create_texture(self):
253254
)
254255

255256

256-
class Colorbar(Renderer):
257+
class ColorbarStrip(Renderer):
258+
"""Renders the colored strip of the colorbar."""
257259
vertex_entry_point: str = "colormap_vertex"
258260
fragment_entry_point: str = "colormap_fragment"
259261
select_entry_point: str = ""
260262
n_vertices: int = 3
261263

264+
def __init__(self, get_bindings_fn):
265+
super().__init__()
266+
self._get_bindings_fn = get_bindings_fn
267+
268+
def get_shader_code(self):
269+
return read_shader_file("colormap.wgsl")
270+
271+
def get_bindings(self):
272+
return self._get_bindings_fn()
273+
274+
def update(self, options: RenderOptions):
275+
pass
276+
277+
def get_export_descriptor(self, options, buffer_registry):
278+
desc = super().get_export_descriptor(options, buffer_registry)
279+
desc.pass_type = "transparent"
280+
return desc
281+
282+
283+
class Colorbar(MultipleRenderer):
284+
262285
def __init__(
263286
self,
264287
colormap: Colormap | None = None,
@@ -267,16 +290,20 @@ def __init__(
267290
height=0.05,
268291
number_format=None,
269292
):
270-
super().__init__()
271-
self.gpu_objects.colormap = colormap or Colormap()
293+
self.colormap = colormap or Colormap()
272294
self.number_format = number_format
273-
self.gpu_objects.labels = Labels([], [], font_size=14, h_align="center", v_align="top")
274295
self.uniforms = None
275296

276297
self._position = position
277298
self._width = width
278299
self._height = height
279-
colormap._callbacks.append(self.set_needs_update)
300+
301+
self._bg = Background(position=position, width=width, height=height)
302+
self._strip = ColorbarStrip(lambda: self._get_all_bindings())
303+
self._labels = Labels([], [], font_size=14, h_align="center", v_align="top")
304+
305+
super().__init__([self._bg, self._strip, self._labels])
306+
self.colormap._callbacks.append(self.set_needs_update)
280307

281308
@property
282309
def position(self):
@@ -287,6 +314,7 @@ def position(self, value):
287314
self._position = value
288315
if self.uniforms is not None:
289316
self.uniforms.position = value
317+
self._bg.position = value
290318
self.set_needs_update()
291319

292320
@property
@@ -298,6 +326,7 @@ def width(self, value):
298326
self._width = value
299327
if self.uniforms is not None:
300328
self.uniforms.width = value
329+
self._bg.width = value
301330
self.set_needs_update()
302331

303332
@property
@@ -309,14 +338,12 @@ def height(self, value):
309338
self._height = value
310339
if self.uniforms is not None:
311340
self.uniforms.height = value
341+
self._bg.height = value
312342
self.set_needs_update()
313343

314-
def get_shader_code(self):
315-
return read_shader_file("colormap.wgsl")
316-
317-
def get_bindings(self):
344+
def _get_all_bindings(self):
318345
return (
319-
self.gpu_objects.colormap.get_bindings() + self.gpu_objects.labels.get_bindings() + self.uniforms.get_bindings()
346+
self.colormap.get_bindings() + self._labels.get_bindings() + self.uniforms.get_bindings()
320347
)
321348

322349
def update(self, options: RenderOptions):
@@ -327,38 +354,33 @@ def update(self, options: RenderOptions):
327354
self.uniforms.height = self.height
328355

329356
self.uniforms.update_buffer()
357+
self.colormap.update(options)
330358

331-
self.n_instances = 2 * self.gpu_objects.colormap.n_colors
359+
self._strip.n_instances = 2 * self.colormap.n_colors
332360

333-
self.gpu_objects.labels.labels = [
361+
self._labels.labels = [
334362
format_number(v, self.number_format)
335363
for v in [
336-
self.gpu_objects.colormap.minval + i / 4 * (self.gpu_objects.colormap.maxval - self.gpu_objects.colormap.minval)
364+
self.colormap.minval + i / 4 * (self.colormap.maxval - self.colormap.minval)
337365
for i in range(6)
338366
]
339367
]
340-
self.gpu_objects.labels.positions = [
368+
self._labels.positions = [
341369
(
342370
self.position[0] + i * self.width / 4,
343371
self.position[1] - 0.01,
344372
0,
345373
)
346374
for i in range(5)
347375
]
348-
self.gpu_objects.labels.set_needs_update()
349-
350-
def render(self, options: RenderOptions):
351-
super().render(options)
352-
self.gpu_objects.labels.render(options)
353-
354-
376+
super().update(options)
355377

356378
def set_min(self, minval):
357-
self.gpu_objects.colormap.set_min(minval)
379+
self.colormap.set_min(minval)
358380
self.set_needs_update()
359381

360382
def set_max(self, maxval):
361-
self.gpu_objects.colormap.set_max(maxval)
383+
self.colormap.set_max(maxval)
362384
self.set_needs_update()
363385

364386

webgpu/engine/engine.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class RenderEngine {
178178
interactions: descriptor.interactions || [],
179179
camera: descriptor.camera || {},
180180
light: descriptor.light || {},
181+
theme: descriptor.theme || {},
181182
};
182183

183184
this.buffers = _toMap(descriptor.buffers);
@@ -905,6 +906,14 @@ class RenderEngine {
905906
if (this.canvas && this.canvas.style) {
906907
this.canvas.style.backgroundColor = dark ? DARK_CANVAS_BG : LIGHT_CANVAS_BG;
907908
}
909+
// Update colorbar background color to match theme
910+
if (this.scene && this.scene.theme && this.scene.theme.buffer_id) {
911+
const buf = this.buffers && this.buffers.get(this.scene.theme.buffer_id);
912+
if (buf) {
913+
const c = this.clearColor;
914+
this.device.queue.writeBuffer(buf, 16, new Float32Array([c.r, c.g, c.b]));
915+
}
916+
}
908917
}
909918

910919
_setupThemeObserver() {

webgpu/export/capture.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def _capture(scene, live: bool):
5757
interactions=_detect_interactions(scene, registry),
5858
camera=_export_camera(options.camera, options, registry),
5959
light=_export_light(options.light, registry, live=live),
60+
theme=_export_theme(scene, registry),
6061
)
6162
return export, registry
6263

@@ -106,6 +107,28 @@ def _capture_renderer(obj, options, registry, render_passes, compute_passes):
106107
compute_passes.extend(attr.get_export_compute_passes(options, registry))
107108

108109

110+
def _export_theme(scene, registry) -> dict:
111+
"""Find any renderer's buffer for theme color updates."""
112+
from ..renderer import Renderer
113+
for obj in scene.render_objects:
114+
buf_id = _find_theme_buffer(obj, registry)
115+
if buf_id:
116+
return {"buffer_id": buf_id}
117+
return {}
118+
119+
120+
def _find_theme_buffer(obj, registry) -> str | None:
121+
"""Recursively check if a renderer provides a theme-sensitive buffer."""
122+
if hasattr(obj, 'get_theme_buffer_id'):
123+
return obj.get_theme_buffer_id(registry)
124+
if hasattr(obj, 'render_objects'):
125+
for child in obj.render_objects:
126+
buf_id = _find_theme_buffer(child, registry)
127+
if buf_id:
128+
return buf_id
129+
return None
130+
131+
109132
def _export_camera(camera, options, registry) -> dict:
110133
"""Export camera state and identify the camera buffer."""
111134
t = camera.transform

0 commit comments

Comments
 (0)