Skip to content

Commit 18c91db

Browse files
committed
Remove duplicate code and add helper functions + proper download data testing
1 parent c859a43 commit 18c91db

2 files changed

Lines changed: 122 additions & 93 deletions

File tree

src/io4dolfinx/backends/exodus/backend.py

Lines changed: 86 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from pathlib import Path
11-
from typing import Any
11+
from typing import Any, Literal, cast
1212

1313
from mpi4py import MPI
1414

@@ -205,6 +205,64 @@ def write_mesh(
205205
raise NotImplementedError("The Exodus backend cannot write meshes.")
206206

207207

208+
def _read_mesh_geometry(infile: netCDF4.Dataset) -> tuple[int, npt.NDArray[np.floating]]:
209+
# use page 171 of manual to extract data
210+
num_nodes = infile.dimensions["num_nodes"].size
211+
gdim = infile.dimensions["num_dim"].size
212+
213+
# Get coordinates of mesh
214+
coord_var = infile.variables.get("coord")
215+
if coord_var is None:
216+
coordinates = np.zeros((num_nodes, gdim), dtype=np.float64)
217+
for i, coord in enumerate(["x", "y", "z"]):
218+
coord_i = infile.variables.get(f"coord{coord}")
219+
if coord_i is not None:
220+
coordinates[: coord_i.size, i] = coord_i[:]
221+
else:
222+
coordinates = np.asarray(coord_var)
223+
return gdim, coordinates
224+
225+
226+
def _get_entity_blocks(
227+
infile: netCDF4.Dataset, search_type: Literal["cell", "facet"]
228+
) -> tuple[int, list[netCDF4.Variable]]:
229+
# use page 171 of manual to extract data
230+
num_blocks = infile.dimensions["num_el_blk"].size
231+
232+
# Get element connectivity
233+
all_connectivity_variables = [infile.variables[f"connect{i + 1}"] for i in range(num_blocks)]
234+
235+
# Compute max topological dimension in mesh and find the correct
236+
max_tdim = _compute_tdim(max(all_connectivity_variables, key=_compute_tdim))
237+
238+
# Extract only the connectivity blocks that we need
239+
if search_type == "cell":
240+
search_dim = max_tdim
241+
elif search_type == "facet":
242+
search_dim = max_tdim - 1
243+
else:
244+
raise RuntimeError(f"Unknown entity type: {search_type}")
245+
return search_dim, list(
246+
filter(lambda el: _compute_tdim(el) == search_dim, all_connectivity_variables)
247+
)
248+
249+
250+
def _extract_connectivity_data(
251+
entity_blocks: list[netCDF4.Variable],
252+
) -> tuple[list[npt.NDArray[np.int64]], dolfinx.mesh.CellType, list[int]]:
253+
connectivity_arrays = []
254+
cell_types = []
255+
entity_block_index = []
256+
for entity_block in entity_blocks:
257+
connectivity_arrays.append(entity_block[:] - 1)
258+
cell_types.append(_get_cell_type(entity_block))
259+
entity_block_index.append(int(entity_block.name.removeprefix("connect")) - 1)
260+
for cell in cell_types:
261+
assert cell_types[0] == cell, "Mixed cell types not supported"
262+
cell_type = cell_types[0]
263+
return connectivity_arrays, cell_type, entity_block_index
264+
265+
208266
def read_mesh_data(
209267
filename: Path | str,
210268
comm: MPI.Intracomm,
@@ -225,53 +283,21 @@ def read_mesh_data(
225283
check_file_exists(filename)
226284
with netCDF4.Dataset(filename, "r") as infile:
227285
if comm.rank == 0:
228-
# use page 171 of manual to extract data
229-
num_nodes = infile.dimensions["num_nodes"].size
230-
gdim = infile.dimensions["num_dim"].size
231-
num_blocks = infile.dimensions["num_el_blk"].size
232-
233-
# Get coordinates of mesh
234-
coord_var = infile.variables.get("coord")
235-
if coord_var is None:
236-
coordinates = np.zeros((num_nodes, gdim), dtype=np.float64)
237-
for i, coord in enumerate(["x", "y", "z"]):
238-
coord_i = infile.variables.get(f"coord{coord}")
239-
if coord_i is not None:
240-
coordinates[: coord_i.size, i] = coord_i[:]
241-
else:
242-
coordinates = np.asarray(coord_var)
243-
# Get element connectivity
244-
all_connectivity_variables = [
245-
infile.variables[f"connect{i + 1}"] for i in range(num_blocks)
246-
]
247-
248-
# Compute max topological dimension in mesh and find the correct
249-
max_tdim = _compute_tdim(max(all_connectivity_variables, key=_compute_tdim))
286+
gdim, coordinates = _read_mesh_geometry(infile)
250287

251-
# Extract only the connectivity blocks that we need
252-
entity_blocks = list(
253-
filter(lambda el: _compute_tdim(el) == max_tdim, all_connectivity_variables)
254-
)
288+
_tdim, entity_blocks = _get_entity_blocks(infile, "cell")
255289
if len(entity_blocks) > 0:
256290
# Extract markers directly from entity-blocks
257-
connectivity_arrays = []
258-
cell_types = []
259-
num_entities = []
260-
entity_block_index = []
261-
for entity_block in entity_blocks:
262-
connectivity_arrays.append(entity_block[:] - 1)
263-
num_entities.append(entity_block.shape[0])
264-
cell_types.append(_get_cell_type(entity_block))
265-
entity_block_index.append(int(entity_block.name.removeprefix("connect")) - 1)
266-
for cell in cell_types:
267-
assert cell_types[0] == cell, "Mixed cell types not supported"
268-
cell_type = cell_types[0]
291+
connectivity_arrays, cell_type, _entity_block_index = _extract_connectivity_data(
292+
entity_blocks
293+
)
269294

270295
cells = np.vstack(connectivity_arrays)
271296
if isinstance(cells, np.ma.MaskedArray):
272297
cells = cells.filled()
273298
else:
274299
raise ValueError(f"No blocks found in {filename}")
300+
275301
perm = dolfinx.cpp.io.perm_vtk(cell_type, cells.shape[1])
276302
cells = cells[:, perm]
277303
cell_type, gdim, xtype, num_dofs_per_cell = comm.bcast(
@@ -328,51 +354,27 @@ def read_meshtags_data(
328354
"""
329355
if comm.rank == 0:
330356
with netCDF4.Dataset(filename, "r") as infile:
331-
# use page 171 of manual to extract data
332-
num_blocks = infile.dimensions["num_el_blk"].size
333-
334-
# Extract all connectivity blocks
335-
all_connectivity_variables = [
336-
infile.variables[f"connect{i + 1}"] for i in range(num_blocks)
337-
]
338-
339357
# Compute max topological dimension in mesh and find the correct
340-
max_tdim = _compute_tdim(max(all_connectivity_variables, key=_compute_tdim))
341-
if name == "cell":
342-
search_dim = max_tdim
343-
elif name == "facet":
344-
search_dim = max_tdim - 1
358+
if name == "cell" or name == "facet":
359+
search_dim, entity_blocks = _get_entity_blocks(
360+
infile, cast(Literal["cell", "facet"], name)
361+
)
345362
else:
346-
raise ValueError(f"Only name 'cell' or 'facet' is supported, got '{name}'")
347-
348-
# Extract only the connectivity blocks that we need
349-
entity_blocks = list(
350-
filter(lambda el: _compute_tdim(el) == search_dim, all_connectivity_variables)
351-
)
363+
raise RuntimeError("Expected name='cell' or 'facet' got {name}")
352364

353365
if len(entity_blocks) > 0:
354366
# Extract markers directly from entity-blocks
355-
connectivity_arrays = []
356-
cell_types = []
357-
num_entities = []
358-
entity_block_index = []
359-
for entity_block in entity_blocks:
360-
connectivity_arrays.append(entity_block[:] - 1)
361-
num_entities.append(entity_block.shape[0])
362-
cell_types.append(_get_cell_type(entity_block))
363-
entity_block_index.append(int(entity_block.name.removeprefix("connect")) - 1)
364-
for cell in cell_types:
365-
assert cell_types[0] == cell, "Mixed cell types not supported"
366-
cell_type = cell_types[0]
367-
367+
connectivity_arrays, cell_type, entity_block_index = _extract_connectivity_data(
368+
entity_blocks
369+
)
368370
marked_entities = np.vstack(connectivity_arrays)
369371
entity_values = np.zeros(marked_entities.shape[0], dtype=np.int64)
370372
if "eb_prop1" in infile.variables.keys():
371373
block_values = infile.variables["eb_prop1"][:]
372374

373375
# First check if entities are in eb_prop1
374-
insert_offset = np.zeros(len(num_entities) + 1, dtype=np.int64)
375-
insert_offset[1:] = np.cumsum(num_entities)
376+
insert_offset = np.zeros(len(connectivity_arrays) + 1, dtype=np.int64)
377+
insert_offset[1:] = np.cumsum([c_arr.shape[0] for c_arr in connectivity_arrays])
376378
for i, index in enumerate(entity_block_index):
377379
entity_values[insert_offset[i] : insert_offset[i + 1]] = block_values[index]
378380
else:
@@ -381,9 +383,7 @@ def read_meshtags_data(
381383
elif name == "facet" and "ss_prop1" in infile.variables.keys():
382384
# If we haven't found the cell type as a block, we should be extracting facets
383385
# (from side-sets), then we need the parent cell
384-
entity_blocks = list(
385-
filter(lambda el: _compute_tdim(el) == max_tdim, all_connectivity_variables)
386-
)
386+
_tdim, entity_blocks = _get_entity_blocks(infile, "cell")
387387
cell_types = []
388388
for entity_block in entity_blocks:
389389
cell_types.append(_get_cell_type(entity_block))
@@ -630,7 +630,7 @@ def read_cell_data(
630630
node_names = netCDF4.chartostring(raw_names)
631631
if name not in node_names:
632632
raise ValueError(
633-
f"Point data with name {name} not found in file.",
633+
f"Cell data with name {name} not found in file.",
634634
f"Available variables: {node_names}",
635635
)
636636
index = np.flatnonzero(name == node_names)[0] + 1
@@ -662,10 +662,12 @@ def read_cell_data(
662662
num_components = dataset.shape[1]
663663
# Broadcast num components to all other ranks
664664
num_components = comm.bcast(num_components, root=0)
665+
665666
# Zero data on all other processes
666667
if comm.rank != 0:
667668
dataset = np.zeros((0, num_components), dtype=np.float64)
668669
_time = float(time) if time is not None else None
670+
669671
topology = read_mesh_data(filename, comm, _time, False, backend_args=None).cells
670672
return topology, dataset
671673

@@ -683,7 +685,13 @@ def read_function_names(
683685
Returns:
684686
A list of function names.
685687
"""
686-
raise NotImplementedError("The Exodus backend does not support reading function names.")
688+
with netCDF4.Dataset(filename, "r") as infile:
689+
function_names: list[str] = []
690+
for key in ["name_elem_var", "name_nod_var"]:
691+
raw_names = infile.variables[key][:].data
692+
decoded_names = netCDF4.chartostring(raw_names)
693+
function_names.extend(decoded_names)
694+
return function_names
687695

688696

689697
def write_data(
@@ -705,11 +713,3 @@ def write_data(
705713
backend_args: The backend arguments
706714
"""
707715
raise NotImplementedError("Exodus has not implemented this yet")
708-
709-
710-
def getNames(model, key):
711-
# name of the element variables
712-
name_var = []
713-
for vname in np.ma.getdata(model.variables[key][:]).astype("U8"):
714-
name_var.append("".join(vname))
715-
return name_var

tests/test_exodus.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import urllib.request
2+
from enum import Enum
3+
14
from mpi4py import MPI
25

36
import pytest
@@ -7,18 +10,40 @@
710
netcdf4 = pytest.importorskip("netCDF4")
811

912

10-
def download_file_if_not_exists(url, filename):
11-
if not filename.exists():
12-
import urllib.request
13+
class DownloadStatus(Enum):
14+
success = 1
15+
failed = -1
16+
no_connection = -2
17+
1318

14-
urllib.request.urlretrieve(url, filename)
19+
def download_file_if_not_exists(
20+
url, filename, comm: MPI.Intracomm = MPI.COMM_WORLD, rank: int = 0
21+
) -> DownloadStatus:
22+
status = DownloadStatus.failed
23+
if comm.rank == rank:
24+
if not filename.exists():
25+
try:
26+
urllib.request.urlretrieve(url, filename)
27+
status = DownloadStatus.success
28+
except urllib.error.URLError as e:
29+
if str(e) == "<urlopen error [Errno -3] Temporary failure in name resolution>":
30+
status = DownloadStatus.no_connection
31+
else:
32+
status = DownloadStatus.failed
33+
else:
34+
status = DownloadStatus.success
35+
status = comm.bcast(status, root=rank)
36+
comm.Barrier()
37+
return status
1538

1639

1740
def test_read_mesh_and_cell_data(tmp_path):
41+
tmp_path = MPI.COMM_WORLD.bcast(tmp_path, root=0)
1842
filename = tmp_path / "openmc_master_out_openmc0.e"
1943
url = "https://github.com/neams-th-coe/cardinal/blob/devel/test/tests/neutronics/feedback/single_level/gold/openmc_master_out_openmc0.e?raw=true"
20-
download_file_if_not_exists(url, filename)
21-
44+
status = download_file_if_not_exists(url, filename)
45+
if status == DownloadStatus.no_connection:
46+
pytest.skip("No internet connection")
2247
mesh = io4dolfinx.read_mesh(filename, MPI.COMM_WORLD, backend="exodus")
2348
io4dolfinx.read_meshtags(filename, mesh, meshtag_name="cell", backend="exodus")
2449
io4dolfinx.read_meshtags(filename, mesh, meshtag_name="facet", backend="exodus")
@@ -28,9 +53,13 @@ def test_read_mesh_and_cell_data(tmp_path):
2853

2954

3055
def test_read_mesh_point_data(tmp_path):
56+
tmp_path = MPI.COMM_WORLD.bcast(tmp_path, root=0)
57+
3158
filename = tmp_path / "openmc_master_out_openmc0.e"
3259
url = "https://github.com/idaholab/moose/blob/next/test/tests/kernels/2d_diffusion/gold/matdiffusion_out.e?raw=true"
33-
download_file_if_not_exists(url, filename)
60+
status = download_file_if_not_exists(url, filename)
61+
if status == DownloadStatus.no_connection:
62+
pytest.skip("No internet connection")
3463

3564
mesh = io4dolfinx.read_mesh(filename, MPI.COMM_WORLD, backend="exodus")
3665
io4dolfinx.read_point_data(filename, name="u", mesh=mesh, backend="exodus", time=1.0)

0 commit comments

Comments
 (0)