Skip to content

Commit 4418341

Browse files
committed
Add support for NETCDF3_64BIT_DATA in write_netcdf()
For performance, this involves first writing a `NETCDF4` file and then using `ncks` to convert it to `NETCDF3_64BIT_DATA` (CDF5) format.
1 parent 08cb31c commit 4418341

1 file changed

Lines changed: 70 additions & 14 deletions

File tree

  • conda_package/mpas_tools

conda_package/mpas_tools/io.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
1-
from __future__ import absolute_import, division, print_function, \
2-
unicode_literals
1+
import os
2+
import subprocess
3+
import sys
4+
from datetime import datetime
35

4-
import numpy
56
import netCDF4
6-
from datetime import datetime
7-
import sys
7+
import numpy
88

9+
from mpas_tools.logging import check_call
910

1011
default_format = 'NETCDF3_64BIT'
1112
default_engine = None
1213
default_char_dim_name = 'StrLen'
1314
default_fills = netCDF4.default_fillvals
1415

1516

16-
def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
17-
char_dim_name=None):
17+
def write_netcdf(
18+
ds,
19+
fileName,
20+
fillValues=None,
21+
format=None,
22+
engine=None,
23+
char_dim_name=None,
24+
logger=None,
25+
):
1826
"""
1927
Write an xarray.Dataset to a file with NetCDF4 fill values and the given
2028
name of the string dimension. Also adds the time and command-line to the
2129
history attribute.
2230
31+
Note: the ``NETCDF3_64BIT_DATA`` format is handled as a special case
32+
because xarray output with this format is not performant. First, the file
33+
is written in `NETCDF4` format, which supports larger files and variables.
34+
Then, the `ncks` command is used to convert the file to the
35+
`NETCDF3_64BIT_DATA` format.
36+
37+
2338
Parameters
2439
----------
2540
ds : xarray.Dataset
@@ -50,7 +65,11 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
5065
``mpas_tools.io.default_char_dim_name``, which can be modified but
5166
which defaults to ``'StrLen'``
5267
53-
"""
68+
logger : logging.Logger, optional
69+
A logger to write messages to write the output of `ncks` conversion
70+
calls to. If None, `ncks` output is suppressed. This is only
71+
relevant if `format` is 'NETCDF3_64BIT_DATA'
72+
""" # noqa: E501
5473
if format is None:
5574
format = default_format
5675

@@ -71,8 +90,9 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
7190
dtype = ds[variableName].dtype
7291
for fillType in fillValues:
7392
if dtype == numpy.dtype(fillType):
74-
encodingDict[variableName] = \
75-
{'_FillValue': fillValues[fillType]}
93+
encodingDict[variableName] = {
94+
'_FillValue': fillValues[fillType]
95+
}
7696
break
7797
else:
7898
encodingDict[variableName] = {'_FillValue': None}
@@ -88,14 +108,50 @@ def write_netcdf(ds, fileName, fillValues=None, format=None, engine=None,
88108
# reading Time otherwise
89109
ds.encoding['unlimited_dims'] = {'Time'}
90110

91-
ds.to_netcdf(fileName, encoding=encodingDict, format=format, engine=engine)
111+
# for performance, we have to handle this as a special case
112+
convert = format == 'NETCDF3_64BIT_DATA'
113+
114+
if convert:
115+
basename, extension = os.path.splitext(fileName)
116+
out_filename = f'{basename}.netcdf4{extension}'
117+
format = 'NETCDF4'
118+
if engine == 'scipy':
119+
# that's not going to work
120+
engine = 'netcdf4'
121+
else:
122+
out_filename = fileName
123+
124+
ds.to_netcdf(
125+
out_filename, encoding=encodingDict, format=format, engine=engine
126+
)
127+
128+
if convert:
129+
args = [
130+
'ncks',
131+
'-O',
132+
'-5',
133+
out_filename,
134+
fileName,
135+
]
136+
if logger is None:
137+
subprocess.run(
138+
args,
139+
check=True,
140+
stdout=subprocess.DEVNULL,
141+
stderr=subprocess.DEVNULL,
142+
)
143+
else:
144+
check_call(args, logger=logger)
92145

93146

94147
def update_history(ds):
95-
'''Add or append history to attributes of a data set'''
148+
"""Add or append history to attributes of a data set"""
96149

97-
thiscommand = datetime.now().strftime("%a %b %d %H:%M:%S %Y") + ": " + \
98-
" ".join(sys.argv[:])
150+
thiscommand = (
151+
datetime.now().strftime('%a %b %d %H:%M:%S %Y')
152+
+ ': '
153+
+ ' '.join(sys.argv[:])
154+
)
99155
if 'history' in ds.attrs:
100156
newhist = '\n'.join([thiscommand, ds.attrs['history']])
101157
else:

0 commit comments

Comments
 (0)