Skip to content

Commit 17e6044

Browse files
committed
refactor(provider): drop matplotlib.use("Agg") for the Figure API
Use matplotlib's object-oriented Figure class directly instead of pyplot. Figure.savefig works without an interactive backend, so the module-level matplotlib.use("Agg") call -- and the E402 noqa it forced on every following import -- are no longer needed. Library code now touches no matplotlib global state. test: move matplotlib backend selection into conftest.py The Agg backend is now selected once in tests/conftest.py, before any test module is imported, instead of via a module-level matplotlib.use call in test_ref_tutorials.py. The test file's imports go back to a clean sorted block with no E402 noqa.
1 parent aac5e00 commit 17e6044

3 files changed

Lines changed: 43 additions & 56 deletions

File tree

src/ref_tutorials/provider.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,28 @@
1818

1919
import importlib.metadata
2020
from pathlib import Path
21-
from typing import Any
2221

23-
import matplotlib
24-
25-
matplotlib.use("Agg") # Diagnostics run headless; no display is available.
26-
27-
import matplotlib.pyplot as plt # noqa: E402
28-
import xarray as xr # noqa: E402
29-
from climate_ref_core.constraints import ( # noqa: E402
22+
import xarray as xr
23+
from climate_ref_core.constraints import (
3024
AddSupplementaryDataset,
3125
RequireContiguousTimerange,
3226
)
33-
from climate_ref_core.datasets import FacetFilter, SourceDatasetType # noqa: E402
34-
from climate_ref_core.diagnostics import ( # noqa: E402
27+
from climate_ref_core.datasets import FacetFilter, SourceDatasetType
28+
from climate_ref_core.diagnostics import (
3529
DataRequirement,
3630
Diagnostic,
3731
ExecutionDefinition,
3832
ExecutionResult,
3933
)
40-
from climate_ref_core.metric_values.typing import SeriesMetricValue # noqa: E402
41-
from climate_ref_core.providers import DiagnosticProvider # noqa: E402
42-
from climate_ref_core.pycmec.metric import CMECMetric # noqa: E402
43-
from climate_ref_core.pycmec.output import CMECOutput, OutputCV # noqa: E402
34+
from climate_ref_core.metric_values.typing import SeriesMetricValue
35+
from climate_ref_core.providers import DiagnosticProvider
36+
from climate_ref_core.pycmec.metric import CMECMetric
37+
from climate_ref_core.pycmec.output import CMECOutput, OutputCV
38+
39+
# The object-oriented matplotlib API. Using Figure directly -- rather than
40+
# pyplot -- means this library module touches no global state and needs no
41+
# interactive backend, so there is no matplotlib.use("Agg") call.
42+
from matplotlib.figure import Figure
4443

4544
#: Name of the NetCDF file the diagnostic writes into its output directory.
4645
_OUTPUT_FILENAME = "annual_mean_global_mean_tas.nc"
@@ -115,18 +114,19 @@ def make_figures(ds: xr.Dataset, output_directory: Path) -> dict[str, Path]:
115114
tas = ds["tas"].values
116115

117116
timeseries_path = output_directory / _TIMESERIES_PLOT
118-
fig, ax = plt.subplots(figsize=(7, 4))
117+
fig = Figure(figsize=(7, 4))
118+
ax = fig.subplots()
119119
ax.plot(years, tas, marker="o", color="#1f77b4")
120120
ax.set_xlabel("Year")
121121
ax.set_ylabel("Global-mean tas (K)")
122122
ax.set_title("Annual-mean global-mean near-surface air temperature")
123123
fig.tight_layout()
124124
fig.savefig(timeseries_path, dpi=150)
125-
plt.close(fig)
126125

127126
anomaly_path = output_directory / _ANOMALY_PLOT
128127
anomaly = tas - tas.mean()
129-
fig, ax = plt.subplots(figsize=(7, 4))
128+
fig = Figure(figsize=(7, 4))
129+
ax = fig.subplots()
130130
colors = ["#d62728" if a >= 0 else "#1f77b4" for a in anomaly]
131131
ax.bar(years, anomaly, color=colors)
132132
ax.axhline(0, color="#444444", linewidth=0.8)
@@ -135,7 +135,6 @@ def make_figures(ds: xr.Dataset, output_directory: Path) -> dict[str, Path]:
135135
ax.set_title("Annual-mean tas anomaly relative to the period mean")
136136
fig.tight_layout()
137137
fig.savefig(anomaly_path, dpi=150)
138-
plt.close(fig)
139138

140139
return {
141140
"Annual-mean global-mean tas timeseries": timeseries_path,
@@ -186,13 +185,11 @@ def _series_values(
186185
class AnnualMeanGlobalMeanTas(Diagnostic):
187186
"""Annual-mean, global-mean near-surface air temperature.
188187
189-
A minimal custom diagnostic for the tutorials. It requires CMIP6 ``tas``
190-
data, and pulls in the matching ``areacella`` cell-area field as a
188+
A minimal custom diagnostic for the tutorials.
189+
It requires CMIP6 ``tas`` data, and pulls in the matching ``areacella`` cell-area field as a
191190
supplementary dataset so the global mean can be area-weighted.
192191
193-
Its execution registers scalar metric values, a series metric value, and
194-
two figures -- a representative cross-section of what a real diagnostic
195-
produces.
192+
Its execution registers scalar metric values, a series metric value, and two figures.
196193
"""
197194

198195
name = "Annual Mean Global Mean Temperature"
@@ -202,10 +199,13 @@ class AnnualMeanGlobalMeanTas(Diagnostic):
202199
(
203200
DataRequirement(
204201
source_type=SourceDatasetType.CMIP6,
202+
# We only look at each
205203
filters=(FacetFilter(facets={"variable_id": ("tas",)}),),
206204
group_by=("source_id", "experiment_id", "variant_label"),
207205
constraints=(
206+
# Ensure that we have a contiguous time range to compute the annual mean from
208207
RequireContiguousTimerange(group_by=("instance_id",)),
208+
# Add the matching areacella dataset as a supplementary dataset
209209
AddSupplementaryDataset.from_defaults(
210210
"areacella", SourceDatasetType.CMIP6
211211
),
@@ -225,6 +225,7 @@ class AnnualMeanGlobalMeanTas(Diagnostic):
225225
def execute(self, definition: ExecutionDefinition) -> None:
226226
"""Compute the diagnostic: write the output NetCDF and the figures."""
227227
input_datasets = definition.datasets[SourceDatasetType.CMIP6]
228+
228229
result = calculate_annual_mean_global_mean(input_datasets.path.to_list())
229230
if "time_bnds" in result:
230231
result = result.drop_vars("time_bnds")
@@ -235,21 +236,17 @@ def execute(self, definition: ExecutionDefinition) -> None:
235236
def build_execution_result(self, definition: ExecutionDefinition) -> ExecutionResult:
236237
"""Package the output into an :class:`ExecutionResult`.
237238
238-
``execute`` has already written the NetCDF file and the figures. Here we
239-
register the scalar metric values, the series metric value, and the
240-
figures so the REF records them.
239+
``execute`` has already written the NetCDF file and the figures.
240+
241+
Here we register the scalar metric values, the series metric value,
242+
and the figures so the REF records them.
241243
"""
242244
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
243245
ds = xr.open_dataset(
244246
definition.output_directory / _OUTPUT_FILENAME, decode_times=time_coder
245247
)
246248

247249
selectors = definition.datasets[SourceDatasetType.CMIP6].selector_dict()
248-
input_selectors: dict[str, Any] = {
249-
"source_id": selectors["source_id"],
250-
"experiment_id": selectors["experiment_id"],
251-
"variant_label": selectors["variant_label"],
252-
}
253250

254251
# Register the figures written by execute().
255252
output_bundle = CMECOutput.create_template()
@@ -265,14 +262,14 @@ def build_execution_result(self, definition: ExecutionDefinition) -> ExecutionRe
265262
OutputCV.FILENAME.value: relative_path,
266263
OutputCV.LONG_NAME.value: caption,
267264
OutputCV.DESCRIPTION.value: caption,
268-
OutputCV.DIMENSIONS.value: input_selectors,
265+
OutputCV.DIMENSIONS.value: selectors,
269266
}
270267

271268
return ExecutionResult.build_from_output_bundle(
272269
definition,
273270
cmec_output_bundle=output_bundle,
274-
cmec_metric_bundle=_scalar_metric_bundle(ds, input_selectors),
275-
series=_series_values(ds, input_selectors),
271+
cmec_metric_bundle=_scalar_metric_bundle(ds, selectors),
272+
series=_series_values(ds, selectors),
276273
)
277274

278275

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Shared fixtures"""
2+
import matplotlib
3+
4+
# The plotting helpers use pyplot.
5+
# Force a non-interactive backend so the figure-producing tests run headless
6+
matplotlib.use("Agg")

tests/test_ref_tutorials.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,22 @@
1-
"""Unit tests for the ``ref_tutorials`` helper package.
2-
3-
These tests cover the *observable contract* of the helper functions, not their
4-
internals. They do not require network access or the generated API client.
5-
"""
1+
"""Unit tests for the ``ref_tutorials`` helper package. """
62

73
from __future__ import annotations
84

95
from dataclasses import dataclass, field
106

117
import matplotlib
8+
import pandas as pd
9+
import pytest
1210

13-
matplotlib.use("Agg") # Headless backend for CI.
14-
15-
import pandas as pd # noqa: E402
16-
import pytest # noqa: E402
17-
18-
from ref_tutorials import ( # noqa: E402
11+
from ref_tutorials import (
1912
metric_values_to_dataframe,
2013
model_comparison_figure,
2114
save_figure,
2215
set_publication_style,
2316
)
24-
from ref_tutorials.plotting import PUBLICATION_RCPARAMS # noqa: E402
17+
from ref_tutorials.plotting import PUBLICATION_RCPARAMS
2518

26-
# --- Test doubles -----------------------------------------------------------
27-
# Minimal stand-ins for the metric value objects returned by climate_ref_client.
19+
# The Agg backend is selected in conftest.py before this module is imported.
2820

2921

3022
@dataclass
@@ -42,8 +34,6 @@ def make(cls, value: float, **dims: object) -> _MetricValue:
4234
return cls(value=value, dimensions=_Dimensions(additional_properties=dict(dims)))
4335

4436

45-
# --- metric_values_to_dataframe ---------------------------------------------
46-
4737

4838
def test_metric_values_to_dataframe_flattens_dimensions():
4939
values = [
@@ -79,8 +69,6 @@ def test_metric_values_to_dataframe_empty_input():
7969
assert df.empty
8070

8171

82-
# --- set_publication_style --------------------------------------------------
83-
8472

8573
def test_set_publication_style_applies_rcparams():
8674
matplotlib.rcParams["font.size"] = 99 # Pollute first.
@@ -90,8 +78,6 @@ def test_set_publication_style_applies_rcparams():
9078
assert matplotlib.rcParams["font.size"] == PUBLICATION_RCPARAMS["font.size"]
9179

9280

93-
# --- model_comparison_figure ------------------------------------------------
94-
9581

9682
def test_model_comparison_figure_returns_fig_and_axes():
9783
df = pd.DataFrame(
@@ -122,8 +108,6 @@ def test_model_comparison_figure_aggregates_ensemble_members():
122108
assert len(ax.patches) == 2
123109

124110

125-
# --- save_figure ------------------------------------------------------------
126-
127111

128112
def test_save_figure_writes_png_and_pdf(tmp_path):
129113
df = pd.DataFrame({"source_id": ["ModelA"], "value": [1.0]})

0 commit comments

Comments
 (0)