@@ -42,7 +42,7 @@ class TetraMesh(object):
4242 Nodal coordinates, shape ``(n_nodes, 3)``. Each row defines the
4343 coordinates of a mesh node.
4444 enod : torch.Tensor
45- Tetra element nodes, shape ``(n_elm, n_nodes )``. ``enod[i, :]`` gives
45+ Tetra element nodes, shape ``(n_elm, 4 )``. ``enod[i, :]`` gives
4646 the nodal indices of element ``i``.
4747 dof : torch.Tensor
4848 Per node degrees of freedom. ``dof[i, :]`` gives the degrees of
@@ -94,7 +94,7 @@ def generate_mesh_from_vertices(cls, coord, enod):
9494 Nodal coordinates, shape ``(n_nodes, 3)``. Each row defines the
9595 coordinates of a mesh node.
9696 enod : numpy.ndarray
97- Tetra element nodes, shape ``(n_elm, n_nodes )``. ``enod[i, :]``
97+ Tetra element nodes, shape ``(n_elm, 4 )``. ``enod[i, :]``
9898 gives the nodal indices of element ``i``.
9999
100100 Returns
@@ -138,20 +138,20 @@ def generate_mesh_from_levelset(
138138 mesh generation fails.
139139 """
140140 from skimage import measure
141-
141+
142142 # Create a grid to sample the level set
143143 # Use a reasonably fine grid to capture the surface accurately
144144 n = max (int (2 * bounding_radius / (max_cell_circumradius * 0.5 )), 20 )
145145 x = np .linspace (- bounding_radius , bounding_radius , n , dtype = np .float64 )
146146 y = np .linspace (- bounding_radius , bounding_radius , n , dtype = np .float64 )
147147 z = np .linspace (- bounding_radius , bounding_radius , n , dtype = np .float64 )
148148 X , Y , Z = np .meshgrid (x , y , z , indexing = 'ij' )
149-
149+
150150 # Evaluate level set function on the grid
151151 points_grid = np .column_stack ([X .ravel (), Y .ravel (), Z .ravel ()])
152152 values = np .array ([level_set (p ) for p in points_grid ], dtype = np .float64 )
153153 volume = values .reshape (X .shape )
154-
154+
155155 # Extract the surface mesh using marching cubes at level=0
156156 # Note: marching_cubes extracts the surface where volume == level
157157 try :
@@ -164,28 +164,28 @@ def generate_mesh_from_levelset(
164164 "This likely means the level set doesn't intersect the zero level "
165165 "within the bounding box, or doesn't define a valid closed surface."
166166 )
167-
167+
168168 # Transform vertices from grid coordinates to world coordinates and ensure float64
169169 verts = verts .astype (np .float64 ) + np .array ([x [0 ], y [0 ], z [0 ]], dtype = np .float64 )
170170 faces = faces .astype (np .int32 ) # Ensure integer type for faces
171-
171+
172172 # Verify we have a valid surface
173173 if len (verts ) == 0 or len (faces ) == 0 :
174174 raise ValueError (
175175 "Marching cubes produced no surface. "
176176 "Check that the level set crosses zero within the bounding_radius."
177177 )
178-
178+
179179 # Create mesh info for meshpy
180180 mesh_info = tet .MeshInfo ()
181181 mesh_info .set_points (verts .tolist ())
182-
182+
183183 # Set the surface facets - these define the boundary
184184 mesh_info .set_facets (faces .tolist ())
185-
185+
186186 # Build the volume mesh with quality constraints
187187 max_volume = (max_cell_circumradius ** 3 ) / 6.0
188-
188+
189189 try :
190190 mesh = tet .build (
191191 mesh_info ,
@@ -198,20 +198,20 @@ def generate_mesh_from_levelset(
198198 "The surface may be self-intersecting or have other topological issues. "
199199 "Try adjusting bounding_radius or max_cell_circumradius."
200200 )
201-
201+
202202 # Convert to meshio format with explicit double precision
203203 vertices = np .array (mesh .points , dtype = np .float64 )
204204 elements = np .array (mesh .elements , dtype = np .int64 )
205-
205+
206206 # Verify mesh generation succeeded
207207 if len (vertices ) == 0 or len (elements ) == 0 :
208208 raise ValueError (
209209 "Mesh generation produced no tetrahedra. "
210210 "This may indicate an issue with the surface topology."
211211 )
212-
212+
213213 mesh_obj = meshio .Mesh (vertices , [("tetra" , elements )])
214-
214+
215215 return cls ._build_tetramesh (mesh_obj )
216216
217217 def translate (self , translation_vector ):
@@ -225,7 +225,7 @@ def translate(self, translation_vector):
225225 # Convert to numpy if torch tensor
226226 if torch .is_tensor (translation_vector ):
227227 translation_vector = utils .ensure_numpy (translation_vector )
228-
228+
229229 self ._mesh .points += translation_vector
230230 self .coord = utils .ensure_torch (self ._mesh .points , dtype = torch .float64 )
231231 self .ecentroids += utils .ensure_torch (translation_vector , dtype = torch .float64 )
@@ -291,7 +291,7 @@ def save(self, file, element_data=None):
291291 element_data [key ] = [list (utils .ensure_numpy (element_data [key ]))]
292292 else :
293293 element_data [key ] = [list (element_data [key ])]
294-
294+
295295 # Convert coord and enod to numpy if they're torch tensors
296296 coord = utils .ensure_numpy (self .coord ) if torch .is_tensor (self .coord ) else self .coord
297297 enod = utils .ensure_numpy (self .enod ) if torch .is_tensor (self .enod ) else self .enod
@@ -336,24 +336,24 @@ def _build_tetramesh(cls, mesh):
336336 A new mesh object with all data as torch tensors.
337337 """
338338 tetmesh = cls ()
339-
339+
340340 # Convert core mesh data to tensors immediately
341341 tetmesh .coord = utils .ensure_torch (mesh .points , dtype = torch .float64 )
342-
342+
343343 # Handle enod - ensure it's 2D
344344 enod_data = mesh .cells_dict ["tetra" ]
345345 enod_tensor = utils .ensure_torch (enod_data , dtype = torch .int64 )
346346 # If enod is 1D (single tetrahedron), reshape to 2D
347347 if enod_tensor .ndim == 1 :
348348 enod_tensor = enod_tensor .reshape (1 , - 1 )
349349 tetmesh .enod = enod_tensor
350-
350+
351351 # Store tensor versions in _mesh for persistence
352352 tetmesh ._mesh = meshio .Mesh (
353353 points = utils .ensure_numpy (tetmesh .coord ),
354354 cells = [("tetra" , utils .ensure_numpy (tetmesh .enod ))]
355355 )
356-
356+
357357 # Complete initialization using tensor data
358358 tetmesh ._set_fem_matrices ()
359359 tetmesh ._expand_mesh_data ()
@@ -380,9 +380,9 @@ def _compute_mesh_faces(self, enod):
380380 enod = torch .tensor (enod , dtype = torch .int64 )
381381 elif enod .dtype != torch .int64 :
382382 enod = enod .to (dtype = torch .int64 )
383-
383+
384384 # Create permutations as long tensor
385- permutations = torch .tensor ([[0 , 1 , 2 ], [0 , 1 , 3 ], [0 , 2 , 3 ], [1 , 2 , 3 ]],
385+ permutations = torch .tensor ([[0 , 1 , 2 ], [0 , 1 , 3 ], [0 , 2 , 3 ], [1 , 2 , 3 ]],
386386 dtype = torch .int64 , device = enod .device )
387387 efaces = enod [:, permutations ]
388388 return efaces
@@ -487,7 +487,7 @@ def _compute_mesh_spheres(self, coord, enod):
487487 # Ensure inputs are torch tensors
488488 coord = utils .ensure_torch (coord , dtype = torch .float64 )
489489 enod = utils .ensure_torch (enod , dtype = torch .int64 )
490-
490+
491491 vertices = coord [enod ]
492492 n_tetra = enod .shape [0 ]
493493 pairs = torch .tensor (
@@ -568,18 +568,18 @@ def _set_fem_matrices(self):
568568 """
569569 # Convert coordinates to float64 tensor
570570 self .coord = utils .ensure_torch (self ._mesh .points , dtype = torch .float64 )
571-
571+
572572 # Convert element indices to long (int64) tensor for indexing
573573 self .enod = utils .ensure_torch (self ._mesh .cells_dict ["tetra" ], dtype = torch .int64 )
574574 if self .enod .dtype != torch .int64 :
575575 self .enod = self .enod .to (dtype = torch .int64 )
576-
576+
577577 # Generate and convert DOF indices to long tensor
578578 self .dof = utils .ensure_torch (
579- np .arange (0 , self .coord .shape [0 ] * 3 ).reshape (self .coord .shape [0 ], 3 ),
579+ np .arange (0 , self .coord .shape [0 ] * 3 ).reshape (self .coord .shape [0 ], 3 ),
580580 dtype = torch .int64
581581 )
582-
582+
583583 self .number_of_elements = self .enod .shape [0 ]
584584
585585 def _expand_mesh_data (self ):
@@ -588,7 +588,7 @@ def _expand_mesh_data(self):
588588 self .efaces = self ._compute_mesh_faces (self .enod )
589589 if self .efaces .dtype != torch .int64 :
590590 self .efaces = self .efaces .to (dtype = torch .int64 )
591-
591+
592592 # Compute mesh normals and other geometric properties
593593 self .enormals = utils .ensure_torch (
594594 self ._compute_mesh_normals (self .coord , self .enod , self .efaces ),
@@ -598,12 +598,12 @@ def _expand_mesh_data(self):
598598 self ._compute_mesh_centroids (self .coord , self .enod ),
599599 dtype = torch .float64
600600 )
601-
601+
602602 # Compute bounding spheres
603603 eradius_np , espherecentroids_np = self ._compute_mesh_spheres (self .coord , self .enod )
604604 self .eradius = utils .ensure_torch (eradius_np , dtype = torch .float64 )
605605 self .espherecentroids = utils .ensure_torch (espherecentroids_np , dtype = torch .float64 )
606-
606+
607607 # Compute global properties
608608 self .centroid = torch .mean (self .ecentroids , dim = 0 )
609609 self .evolumes = utils .ensure_torch (
0 commit comments