Skip to content

Commit fe8d2fe

Browse files
committed
mesh 2d element colors
1 parent 06f83b6 commit fe8d2fe

6 files changed

Lines changed: 57 additions & 24 deletions

File tree

ngsolve_webgpu/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .clipping import ClippingCF
44
from webgpu.colormap import Colorbar, Colormap
55
from webgpu.clipping import Clipping
6+
from .geometry import GeometryRenderer
67

78

89
from webgpu.utils import register_shader_directory as _register_shader_directory

ngsolve_webgpu/cf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from webgpu.vectors import BaseVectorRenderer, VectorRenderer
1616
from webgpu.webgpu_api import Buffer
1717

18-
from .mesh import Binding as MeshBinding, MeshElements2d
18+
from .mesh import Binding as MeshBinding, BaseMeshElements2d
1919
from .mesh import ElType, MeshData
2020

2121

@@ -264,7 +264,7 @@ def vandermonde_3d(order):
264264
return _vandermonde_mats[order]
265265

266266

267-
class CFRenderer(MeshElements2d):
267+
class CFRenderer(BaseMeshElements2d):
268268
"""Use "vertices", "index" and "trig_function_values" buffers to render a mesh"""
269269

270270
fragment_entry_point = "fragmentTrig"

ngsolve_webgpu/geometry.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ class Binding:
2020

2121

2222
class BaseGeometryRenderer(Renderer):
23+
clipping: Clipping | None = None
24+
25+
def __init__(self, clipping, *args, **kwargs):
26+
self.clipping = clipping
27+
super().__init__(*args, **kwargs)
28+
2329
def create_render_pipeline(self, options):
2430
super().create_render_pipeline(options)
2531
self.create_pick_pipeline(options)
@@ -71,10 +77,9 @@ def pick_index_render(self, encoder, texture, depth_texture, load_op):
7177

7278
class GeometryFaceRenderer(BaseGeometryRenderer):
7379
n_vertices: int = 3
74-
clipping: Clipping | None = None
7580

76-
def __init__(self, geo):
77-
super().__init__(label="GeometryFaces")
81+
def __init__(self, geo, clipping):
82+
super().__init__(clipping, label="GeometryFaces")
7883
self.geo = geo
7984
self.colors = None
8085
self.active = True
@@ -119,15 +124,14 @@ def get_shader_code(self):
119124
class GeometryEdgeRenderer(BaseGeometryRenderer):
120125
n_vertices: int = 4
121126
topology: PrimitiveTopology = PrimitiveTopology.triangle_strip
122-
clipping: Clipping | None = None
123127

124128
# make sure that edges are rendered on top of faces
125129
depthBias: int = -5
126130
depthBiasSlopeScale: int = -5
127131

128-
def __init__(self, geo):
132+
def __init__(self, geo, clipping):
129133
self.geo = geo
130-
super().__init__(label="GeometryEdges")
134+
super().__init__(clipping, label="GeometryEdges")
131135
self.active = True
132136
self.thickness = 0.005
133137
self._buffers = {}
@@ -163,11 +167,10 @@ def get_bindings(self):
163167
class GeometryVertexRenderer(BaseGeometryRenderer):
164168
n_vertices: int = 4
165169
topology: PrimitiveTopology = PrimitiveTopology.triangle_strip
166-
clipping: Clipping | None = None
167170

168-
def __init__(self, geo):
171+
def __init__(self, geo, clipping):
169172
self.geo = geo
170-
super().__init__(label="GeometryVertices")
173+
super().__init__(clipping, label="GeometryVertices")
171174
self.active = True
172175
self.thickness = 0.05
173176
self._buffers = {}
@@ -203,12 +206,12 @@ def get_bindings(self):
203206

204207

205208
class GeometryRenderer(MultipleRenderer):
206-
def __init__(self, geo, label="Geometry"):
209+
def __init__(self, geo, label="Geometry", clipping=None):
207210
self.geo = geo
208-
self.faces = GeometryFaceRenderer(geo)
209-
self.edges = GeometryEdgeRenderer(geo)
210-
self.vertices = GeometryVertexRenderer(geo)
211-
self.clipping = Clipping()
211+
self.clipping = clipping or Clipping()
212+
self.faces = GeometryFaceRenderer(geo, self.clipping)
213+
self.edges = GeometryEdgeRenderer(geo, self.clipping)
214+
self.vertices = GeometryVertexRenderer(geo, self.clipping)
212215
self.faces.clipping = self.clipping
213216
self.edges.clipping = self.clipping
214217
self.vertices.clipping = self.clipping

ngsolve_webgpu/mesh.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from webgpu.clipping import Clipping
66
from webgpu.font import Font
77
from webgpu.renderer import Renderer, RenderOptions, check_timestamp
8+
from webgpu.colormap import Colormap
89

910
# from webgpu.uniforms import Binding
1011
from webgpu.uniforms import UniformBase, ct
@@ -206,7 +207,7 @@ def _create_data(self):
206207
self.num_elements[ElType.TRIG] = len(trigs)
207208
trigs_data = np.zeros((len(trigs), 4), dtype=np.uint32)
208209
trigs_data[:, :3] = trigs["nodes"][:, :3] - 1
209-
trigs_data[:, 3] = trigs["index"]
210+
trigs_data[:, 3] = trigs["index"] - 1
210211
self.elements[ElType.TRIG] = trigs_data
211212

212213
# 3d Elements
@@ -276,7 +277,7 @@ def get_buffers(self):
276277
return self.gpu_elements
277278

278279

279-
class MeshElements2d(Renderer):
280+
class BaseMeshElements2d(Renderer):
280281
depthBias: int = 1
281282
depthBiasSlopeScale: float = 1.0
282283
vertex_entry_point: str = "vertexTrigP1Indexed"
@@ -328,8 +329,29 @@ def get_bindings(self):
328329
def get_shader_code(self):
329330
return read_shader_file("ngsolve/mesh.wgsl")
330331

332+
class MeshElements2d(BaseMeshElements2d):
333+
fragment_entry_point = "fragment2dElement"
331334

332-
class MeshWireframe2d(MeshElements2d):
335+
def __init__(self, data: MeshData, clipping=None,
336+
colors: list | None = None,
337+
label="MeshElements2d"):
338+
super().__init__(data, label=label, clipping=clipping)
339+
if colors is None:
340+
mesh = data.mesh
341+
colors = [[int(ci * 255) for ci in fd.color] for fd in mesh.FaceDescriptors()]
342+
self.colormap = Colormap(colormap=colors, minval=-0.5, maxval=len(colors)-0.5)
343+
self.colormap.discrete = 0
344+
self.colormap.n_colors = 4*len(colors)
345+
346+
def update(self, options: RenderOptions):
347+
super().update(options)
348+
self.colormap.update(options)
349+
350+
def get_bindings(self):
351+
return super().get_bindings() + self.colormap.get_bindings()
352+
353+
354+
class MeshWireframe2d(BaseMeshElements2d):
333355
depthBias: int = 0
334356
depthBiasSlopeScale: float = 0.
335357
topology: PrimitiveTopology = PrimitiveTopology.line_strip
@@ -374,6 +396,7 @@ def shrink(self, value):
374396
self._shrink = value
375397
if self.uniforms is not None:
376398
self.uniforms.shrink = value
399+
self.uniforms.update_buffer()
377400

378401
def get_bounding_box(self) -> tuple[list[float], list[float]] | None:
379402
return self.data.get_bounding_box()

ngsolve_webgpu/shaders/mesh.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ fn vertexMeshHex(@builtin(vertex_index) vertId: u32, @builtin(instance_index) el
247247
@fragment
248248
fn fragment2dElement(input: VertexOutput2d) -> @location(0) vec4<f32> {
249249
checkClipping(input.p);
250-
return lightCalcColor(input.p, input.n, u_mesh_color);
250+
return lightCalcColor(input.p, input.n, getColor(f32(input.index)));
251251
}
252252

253253
@fragment

ngsolve_webgpu/shaders/shader.wgsl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct VertexOutput2d {
1313
@location(1) lam: vec2<f32>,
1414
@location(2) @interpolate(flat) id: u32,
1515
@location(3) n: vec3<f32>,
16+
@location(4) @interpolate(flat) index: u32,
1617
};
1718

1819
struct VertexOutput3d {
@@ -37,7 +38,8 @@ fn vertexEdgeP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) e
3738
return VertexOutput1d(position, p, lam, edgeId);
3839
}
3940

40-
fn calcTrig(p: array<vec3<f32>, 3>, vertexId: u32, trigId: u32) -> VertexOutput2d {
41+
fn calcTrig(p: array<vec3<f32>, 3>, vertexId: u32, trigId: u32, index: u32)
42+
-> VertexOutput2d {
4143
let subdivision = u_subdivision;
4244
let h = 1.0 / f32(subdivision);
4345

@@ -83,7 +85,8 @@ fn calcTrig(p: array<vec3<f32>, 3>, vertexId: u32, trigId: u32) -> VertexOutput2
8385

8486
let mapped_position = cameraMapPoint(position);
8587

86-
return VertexOutput2d(mapped_position, position, lam, trigId, normal);
88+
return VertexOutput2d(mapped_position, position, lam, trigId, normal,
89+
index);
8790
}
8891

8992

@@ -95,12 +98,13 @@ fn vertexTrigP1Indexed(@builtin(vertex_index) vertexId: u32, @builtin(instance_i
9598
trigs[4 * trigId + 2]
9699
);
97100

101+
let index = trigs[4 * trigId + 3];
98102
var p = array<vec3<f32>, 3>(
99103
vec3<f32>(vertices[vid[0] ], vertices[vid[0] + 1], vertices[vid[0] + 2]),
100104
vec3<f32>(vertices[vid[1] ], vertices[vid[1] + 1], vertices[vid[1] + 2]),
101105
vec3<f32>(vertices[vid[2] ], vertices[vid[2] + 1], vertices[vid[2] + 2])
102106
);
103-
return calcTrig(p, vertexId, trigId);
107+
return calcTrig(p, vertexId, trigId, index);
104108
}
105109

106110
@vertex
@@ -110,6 +114,7 @@ fn vertexWireframe2d(@builtin(vertex_index) vertexId: u32, @builtin(instance_ind
110114
trigs[4 * trigId + 1],
111115
trigs[4 * trigId + 2]
112116
);
117+
let index = trigs[4 * trigId + 3];
113118

114119
var p = array<vec3<f32>, 3>(
115120
vec3<f32>(vertices[vid[0] ], vertices[vid[0] + 1], vertices[vid[0] + 2]),
@@ -155,7 +160,8 @@ fn vertexWireframe2d(@builtin(vertex_index) vertexId: u32, @builtin(instance_ind
155160
position += u_deformation_scale * evalTrigVec3(&u_deformation_values_2d, trigId, lam);
156161
}
157162
return VertexOutput2d(cameraMapPoint(position), position, lam, trigId,
158-
normalize(cross(p[1] - p[0], p[2] - p[0])));
163+
normalize(cross(p[1] - p[0], p[2] - p[0])),
164+
index);
159165
}
160166

161167

0 commit comments

Comments
 (0)