Skip to content
15 changes: 11 additions & 4 deletions cf/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -7586,18 +7586,25 @@ def collapse(

# ---------------------------------------------------------
# Update dimension coordinates, auxiliary coordinates,
# cell measures and domain ancillaries
# cell measures, domain ancillaries, domain_topologies,
# and cell connectivities.
# ---------------------------------------------------------
for axis, domain_axis in collapse_axes.items():
# Ignore axes which are already size 1
size = domain_axis.get_size()
if size == 1:
continue

# REMOVE all cell measures and domain ancillaries
# which span this axis
# REMOVE all cell measures, domain ancillaries,
# domain_topologies, and cell connectivities which
# span this axis
c = f.constructs.filter(
filter_by_type=("cell_measure", "domain_ancillary"),
filter_by_type=(
"cell_measure",
"domain_ancillary",
"domain_topology",
"cell_connectivity",
),
filter_by_axis=(axis,),
axis_mode="or",
todict=True,
Expand Down
145 changes: 145 additions & 0 deletions cf/test/create_test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,6 +2228,150 @@ def _make_ugrid_2(filename):
return filename


def _make_ugrid_3(filename):
"""Create a UGRID mesh topology and no fields/domains."""
n = netCDF4.Dataset(filename, "w")

n.Conventions = f"CF-{VN}"

n.createDimension("nMesh3_node", 7)
n.createDimension("nMesh3_edge", 9)
n.createDimension("nMesh3_face", 3)
n.createDimension("connectivity2", 2)
n.createDimension("connectivity4", 4)
n.createDimension("connectivity5", 5)

Mesh3 = n.createVariable("Mesh3", "i4", ())
Mesh3.cf_role = "mesh_topology"
Mesh3.topology_dimension = 2
Mesh3.node_coordinates = "Mesh3_node_x Mesh3_node_y"
Mesh3.face_node_connectivity = "Mesh3_face_nodes"
Mesh3.edge_node_connectivity = "Mesh3_edge_nodes"
Mesh3.face_dimension = "nMesh3_face"
Mesh3.edge_dimension = "nMesh3_edge"
Mesh3.face_face_connectivity = "Mesh3_face_links"
Mesh3.edge_edge_connectivity = "Mesh3_edge_links"

# Node
Mesh3_node_x = n.createVariable("Mesh3_node_x", "f4", ("nMesh3_node",))
Mesh3_node_x.standard_name = "longitude"
Mesh3_node_x.units = "degrees_east"
Mesh3_node_x[...] = [-45, -43, -45, -43, -45, -43, -40]

Mesh3_node_y = n.createVariable("Mesh3_node_y", "f4", ("nMesh3_node",))
Mesh3_node_y.standard_name = "latitude"
Mesh3_node_y.units = "degrees_north"
Mesh3_node_y[...] = [35, 35, 33, 33, 31, 31, 34]

Mesh3_edge_nodes = n.createVariable(
"Mesh3_edge_nodes", "i4", ("nMesh3_edge", "connectivity2")
)
Mesh3_edge_nodes.long_name = "Maps every edge to its two nodes"
Mesh3_edge_nodes[...] = [
[1, 6],
[3, 6],
[3, 1],
[0, 1],
[2, 0],
[2, 3],
[2, 4],
[5, 4],
[3, 5],
]

# Face
Mesh3_face_x = n.createVariable(
"Mesh3_face_x", "f8", ("nMesh3_face",), fill_value=-99
)
Mesh3_face_x.standard_name = "longitude"
Mesh3_face_x.units = "degrees_east"
Mesh3_face_x[...] = [-44, -44, -42]

Mesh3_face_y = n.createVariable(
"Mesh3_face_y", "f8", ("nMesh3_face",), fill_value=-99
)
Mesh3_face_y.standard_name = "latitude"
Mesh3_face_y.units = "degrees_north"
Mesh3_face_y[...] = [34, 32, 34]

Mesh3_face_nodes = n.createVariable(
"Mesh3_face_nodes",
"i4",
("nMesh3_face", "connectivity4"),
fill_value=-99,
)
Mesh3_face_nodes.long_name = "Maps every face to its corner nodes"
Mesh3_face_nodes[...] = [[2, 3, 1, 0], [4, 5, 3, 2], [6, 1, 3, -99]]

Mesh3_face_links = n.createVariable(
"Mesh3_face_links",
"i4",
("nMesh3_face", "connectivity4"),
fill_value=-99,
)
Mesh3_face_links.long_name = "neighbour faces for faces"
Mesh3_face_links[...] = [
[1, 2, -99, -99],
[0, -99, -99, -99],
[0, -99, -99, -99],
]

# Edge
Mesh3_edge_x = n.createVariable(
"Mesh3_edge_x", "f8", ("nMesh3_edge",), fill_value=-99
)
Mesh3_edge_x.standard_name = "longitude"
Mesh3_edge_x.units = "degrees_east"
Mesh3_edge_x[...] = [-41.5, -41.5, -43, -44, -45, -44, -45, -44, -43]

Mesh3_edge_y = n.createVariable(
"Mesh3_edge_y", "f8", ("nMesh3_edge",), fill_value=-99
)
Mesh3_edge_y.standard_name = "latitude"
Mesh3_edge_y.units = "degrees_north"
Mesh3_edge_y[...] = [34.5, 33.5, 34, 35, 34, 33, 32, 31, 32]

Mesh3_edge_links = n.createVariable(
"Mesh3_edge_links",
"i4",
("nMesh3_edge", "connectivity5"),
fill_value=-99,
)
Mesh3_edge_links.long_name = "neighbour edges for edges"
Mesh3_edge_links[...] = [
[1, 2, 3, -99, -99],
[0, 2, 5, 8, -99],
[3, 0, 1, 5, 8],
[4, 2, 0, -99, -99],
[
3,
5,
6,
-99,
-99,
],
[4, 6, 2, 1, 8],
[
4,
5,
7,
-99,
-99,
],
[
6,
8,
-99,
-99,
-99,
],
[7, 5, 2, 1, -99],
]

n.close()
return filename


def _make_aggregation_value(filename):
"""Create an aggregation variable with 'unique_values'."""
n = netCDF4.Dataset(filename, "w")
Expand Down Expand Up @@ -2341,6 +2485,7 @@ def _make_aggregation_value(filename):

ugrid_1 = _make_ugrid_1("ugrid_1.nc")
ugrid_2 = _make_ugrid_2("ugrid_2.nc")
ugrid_3 = _make_ugrid_3("ugrid_3.nc")

aggregation_value = _make_aggregation_value("aggregation_value.nc")

Expand Down
148 changes: 131 additions & 17 deletions cf/test/test_UGRID.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import atexit
import datetime
import faulthandler
import itertools
import os
import tempfile
import unittest

import netCDF4
import numpy as np

faulthandler.enable() # to debug seg faults and timeouts
Expand All @@ -14,12 +16,12 @@
warnings = False

# Set up temporary files
n_tmpfiles = 1
n_tmpfiles = 2
tmpfiles = [
tempfile.mkstemp("_test_read_write.nc", dir=os.getcwd())[1]
tempfile.mkstemp("_test_ugrid.nc", dir=os.getcwd())[1]
for i in range(n_tmpfiles)
]
[tmpfile1] = tmpfiles
[tmpfile, tmpfile1] = tmpfiles


def _remove_tmpfiles():
Expand All @@ -34,6 +36,31 @@ def _remove_tmpfiles():
atexit.register(_remove_tmpfiles)


def n_mesh_variables(filename):
"""Return the number of mesh variables in the file."""
nc = netCDF4.Dataset(filename, "r")
n = 0
for v in nc.variables.values():
try:
v.getncattr("topology_dimension")
except AttributeError:
pass
else:
n += 1

nc.close()
return n


def combinations(face, edge, point):
"""Return combinations for field/domain indexing."""
return [
i
for n in range(1, 4)
for i in itertools.permutations([face, edge, point], n)
]


class UGRIDTest(unittest.TestCase):
"""Test UGRID field constructs."""

Expand All @@ -45,6 +72,10 @@ class UGRIDTest(unittest.TestCase):
os.path.dirname(os.path.abspath(__file__)), "ugrid_2.nc"
)

filename3 = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "ugrid_3.nc"
)

def setUp(self):
"""Preparations called immediately before each test method."""
# Disable log messages to silence expected warnings
Expand Down Expand Up @@ -76,10 +107,6 @@ def test_UGRID_read(self):
g.cell_connectivity().get_connectivity(), "edge"
)

# Check that all fields have the same mesh id
mesh_ids1 = set(g.get_mesh_id() for g in f1)
self.assertEqual(len(mesh_ids1), 1)

f2 = cf.read(self.filename2)
self.assertEqual(len(f2), 3)
for g in f2:
Expand All @@ -98,13 +125,6 @@ def test_UGRID_read(self):
g.cell_connectivity().get_connectivity(), "edge"
)

# Check that all fields have the same mesh id
mesh_ids2 = set(g.get_mesh_id() for g in f2)
self.assertEqual(len(mesh_ids2), 1)

# Check that the different files have different mesh ids
self.assertNotEqual(mesh_ids1, mesh_ids2)

def test_UGRID_data(self):
"""Test reading of UGRID data."""
node1, face1, edge1 = cf.read(self.filename1)
Expand Down Expand Up @@ -177,9 +197,103 @@ def test_read_UGRID_domain(self):
g.cell_connectivity().get_connectivity(), "edge"
)

# Check that all domains have the same mesh id
mesh_ids1 = set(g.get_mesh_id() for g in d1)
self.assertEqual(len(mesh_ids1), 1)
def test_read_write_UGRID_field(self):
"""Test the cf.read and cf.write with UGRID fields."""
# Face, edge, and point fields that are all part of the same
# UGRID mesh
ugrid = cf.example_fields(8, 9, 10)
face, edge, point = (0, 1, 2)

tmpfile = "tmpfileu.nc"
# Test for equality with the fields defined in memory. Only
# works for face and edge fields.
for cell in (face, edge):
f = ugrid[cell]
cf.write(f, tmpfile)
g = cf.read(tmpfile)
self.assertEqual(len(g), 1)
self.assertTrue(g[0].equals(f))

# Test round-tripping of field combinations
for cells in combinations(face, edge, point):
f = []
for cell in cells:
f.append(ugrid[cell])

cf.write(f, tmpfile)

# Check that there's only one mesh variable in the file
self.assertEqual(n_mesh_variables(tmpfile), 1)

g = cf.read(tmpfile)
self.assertEqual(len(g), len(f))

cf.write(g, tmpfile1)

# Check that there's only one mesh variable in the file
self.assertEqual(n_mesh_variables(tmpfile1), 1)

h = cf.read(tmpfile1)
self.assertEqual(len(h), len(g))
self.assertTrue(h[0].equals(g[0]))

def test_read_write_UGRID_domain(self):
"""Test the cf.read and cf.write with UGRID domains."""
# Face, edge, and point fields/domains that are all part of
# the same UGRID mesh
ugrid = [f.domain for f in cf.example_fields(8, 9, 10)]
face, edge, point = (0, 1, 2)

# Test for equality with the fields defined in memory. Only
# works for face and edge domains.
for cell in (face, edge):
d = ugrid[cell]
cf.write(d, tmpfile)
e = cf.read(tmpfile, domain=True)
self.assertEqual(len(e), 2)
self.assertTrue(e[0].equals(d))
self.assertEqual(e[1].domain_topology().get_cell(), "point")

# Test round-tripping of domain combinations for the
# example_field domains, and also the domain read from
# 'ugrid_3.nc'.
for iteration in ("memory", "file"):
for cells in combinations(face, edge, point):
d = []
for cell in cells:
d.append(ugrid[cell])

if point not in cells:
# When we write a non-point domains, we also get
# the point locations.
d.append(ugrid[point])
elif cells == (point,):
# When we write a point domain on its own, we also
# get the edge location.
d.append(ugrid[edge])

cf.write(d, tmpfile)

# Check that there's only one mesh variable in the file
self.assertEqual(n_mesh_variables(tmpfile), 1)

e = cf.read(tmpfile, domain=True)

self.assertEqual(len(e), len(d))

cf.write(e, tmpfile1)

# Check that there's only one mesh variable in the file
self.assertEqual(n_mesh_variables(tmpfile1), 1)

f = cf.read(tmpfile1, domain=True)
self.assertEqual(len(f), len(e))
for i, j in zip(f, e):
self.assertTrue(i.equals(j))

# Set up for the 'file' iteration
ugrid = cf.read(self.filename3, domain=True)
face, edge, point = (2, 1, 0)


if __name__ == "__main__":
Expand Down
Loading