|
| 1 | +import ngsolve as ngs |
| 2 | +import numpy as np |
| 3 | +from webgpu.shapes import ShapeRenderer, generate_cone, generate_cylinder |
| 4 | +from webgpu.utils import ( |
| 5 | + BufferUsage, |
| 6 | + BufferBinding, |
| 7 | + ReadBuffer, |
| 8 | + UniformBinding, |
| 9 | + read_buffer, |
| 10 | + read_shader_file, |
| 11 | + run_compute_shader, |
| 12 | + uniform_from_array, |
| 13 | + buffer_from_array, |
| 14 | + write_array_to_buffer, |
| 15 | +) |
| 16 | + |
| 17 | +from .cf import FunctionData, MeshData, Binding as FunctionBinding |
| 18 | +from .mesh import Binding as MeshBinding |
| 19 | +from .mesh import ElType |
| 20 | + |
| 21 | + |
| 22 | +class SurfaceVectors(ShapeRenderer): |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + function_data: FunctionData, |
| 26 | + mesh: MeshData, |
| 27 | + grid_size: float = 0.02, |
| 28 | + ): |
| 29 | + self.function_data = function_data |
| 30 | + self.mesh = mesh |
| 31 | + |
| 32 | + bbox = mesh.get_bounding_box() |
| 33 | + grid_size = np.linalg.norm(np.array(bbox[1]) - np.array(bbox[0])) * grid_size |
| 34 | + |
| 35 | + cyl = generate_cylinder(8, 0.05, 0.5, bottom_face=True) |
| 36 | + cone = generate_cone(8, 0.2, 0.5, bottom_face=True) |
| 37 | + arrow = cyl + cone.move((0, 0, 0.5)) |
| 38 | + |
| 39 | + super().__init__(arrow, None, None) |
| 40 | + # self.scale_mode = ShapeRenderer.SCALE_Z |
| 41 | + |
| 42 | + def get_bounding_box(self): |
| 43 | + return self.mesh.get_bounding_box() |
| 44 | + |
| 45 | + def get_compute_bindings(self): |
| 46 | + return [] |
| 47 | + |
| 48 | + def compute_vectors(self): |
| 49 | + self.u_nvectors = buffer_from_array( |
| 50 | + np.array([0], dtype=np.uint32), |
| 51 | + label="n_vectors", |
| 52 | + usage=BufferUsage.STORAGE | BufferUsage.COPY_DST | BufferUsage.COPY_SRC, |
| 53 | + ) |
| 54 | + |
| 55 | + mesh_buffers = self.mesh.get_buffers() |
| 56 | + func_buffers = self.function_data.get_buffers() |
| 57 | + n_trigs = self.mesh.num_elements[ElType.TRIG] |
| 58 | + self.u_ntrigs = uniform_from_array(np.array([n_trigs], dtype=np.uint32), label="n_trigs") |
| 59 | + |
| 60 | + positions = buffer_from_array(np.array([0], dtype=np.float32), label="positions") |
| 61 | + directions = buffer_from_array(np.array([0], dtype=np.float32), label="positions") |
| 62 | + values = buffer_from_array(np.array([0], dtype=np.float32), label="positions") |
| 63 | + |
| 64 | + bindings = [ |
| 65 | + *self.colormap.get_bindings(), |
| 66 | + BufferBinding(MeshBinding.VERTICES, mesh_buffers["vertices"]), |
| 67 | + BufferBinding(MeshBinding.TRIGS_INDEX, mesh_buffers[ElType.TRIG]), |
| 68 | + BufferBinding(22, positions, read_only=False), |
| 69 | + BufferBinding(23, directions, read_only=False), |
| 70 | + BufferBinding(25, values, read_only=False), |
| 71 | + BufferBinding(21, self.u_nvectors, read_only=False), |
| 72 | + UniformBinding(24, self.u_ntrigs), |
| 73 | + BufferBinding(MeshBinding.CURVATURE_VALUES_2D, mesh_buffers["curvature_2d"]), |
| 74 | + # BufferBinding(MeshBinding.DEFORMATION_VALUES, mesh_buffers["deformation_2d"]), |
| 75 | + UniformBinding(MeshBinding.DEFORMATION_SCALE, mesh_buffers["deformation_scale"]), |
| 76 | + BufferBinding(FunctionBinding.FUNCTION_VALUES_2D, func_buffers["data_2d"]), |
| 77 | + ] |
| 78 | + run_compute_shader( |
| 79 | + read_shader_file("ngsolve/surface_vectors.wgsl"), |
| 80 | + bindings, |
| 81 | + min(n_trigs // 256 + 1, 1024), |
| 82 | + entry_point="compute_surface_vectors", |
| 83 | + defines={ |
| 84 | + "MODE": 0, |
| 85 | + "MAX_EVAL_ORDER": self.function_data.order, |
| 86 | + "MAX_EVAL_ORDER_VEC3": self.function_data.order, |
| 87 | + }, |
| 88 | + ) |
| 89 | + |
| 90 | + self.n_vectors = int(read_buffer(self.u_nvectors, np.uint32)[0]) |
| 91 | + write_array_to_buffer(self.u_nvectors, np.array([0], dtype=np.uint32)) |
| 92 | + buffers = {} |
| 93 | + for name in ["positions", "directions"]: |
| 94 | + buffers[name] = self.device.createBuffer( |
| 95 | + size=3 * 4 * self.n_vectors, |
| 96 | + usage=BufferUsage.STORAGE | BufferUsage.COPY_SRC, |
| 97 | + label=name, |
| 98 | + ) |
| 99 | + buffers["values"] = self.device.createBuffer( |
| 100 | + size=4 * self.n_vectors, |
| 101 | + usage=BufferUsage.STORAGE | BufferUsage.COPY_SRC, |
| 102 | + label="values", |
| 103 | + ) |
| 104 | + |
| 105 | + bindings = [ |
| 106 | + *self.colormap.get_bindings(), |
| 107 | + BufferBinding(MeshBinding.VERTICES, mesh_buffers["vertices"]), |
| 108 | + BufferBinding(MeshBinding.TRIGS_INDEX, mesh_buffers[ElType.TRIG]), |
| 109 | + BufferBinding(22, buffers["positions"], read_only=False), |
| 110 | + BufferBinding(23, buffers["directions"], read_only=False), |
| 111 | + BufferBinding(25, buffers["values"], read_only=False), |
| 112 | + BufferBinding(21, self.u_nvectors, read_only=False), |
| 113 | + BufferBinding(MeshBinding.CURVATURE_VALUES_2D, mesh_buffers["curvature_2d"]), |
| 114 | + BufferBinding(FunctionBinding.FUNCTION_VALUES_2D, func_buffers["data_2d"]), |
| 115 | + # BufferBinding(MeshBinding.DEFORMATION_VALUES, mesh_buffers["deformation_2d"]), |
| 116 | + UniformBinding(MeshBinding.DEFORMATION_SCALE, mesh_buffers["deformation_scale"]), |
| 117 | + UniformBinding(24, self.u_ntrigs), |
| 118 | + ] |
| 119 | + |
| 120 | + run_compute_shader( |
| 121 | + read_shader_file("ngsolve/surface_vectors.wgsl"), |
| 122 | + bindings, |
| 123 | + min(n_trigs // 256 + 1, 1024), |
| 124 | + entry_point="compute_surface_vectors", |
| 125 | + defines={ |
| 126 | + "MODE": 1, |
| 127 | + "MAX_EVAL_ORDER": self.function_data.order, |
| 128 | + "MAX_EVAL_ORDER_VEC3": self.function_data.order, |
| 129 | + }, |
| 130 | + ) |
| 131 | + |
| 132 | + self.positions = read_buffer(buffers["positions"], np.float32).reshape(-1) |
| 133 | + self.values = read_buffer(buffers["values"], np.float32).reshape(-1) |
| 134 | + self.directions = read_buffer(buffers["directions"], np.float32).reshape(-1) |
| 135 | + |
| 136 | + def update(self, options): |
| 137 | + self.mesh.update(options) |
| 138 | + self.function_data.update(options) |
| 139 | + self.colormap.update(options) |
| 140 | + self.compute_vectors() |
| 141 | + super().update(options) |
| 142 | + return |
0 commit comments