Skip to content

Commit 6ed821f

Browse files
committed
Add geography tests and menu_instance_with_geography fixture
1 parent 6852992 commit 6ed821f

2 files changed

Lines changed: 227 additions & 0 deletions

File tree

tests/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,49 @@ def save_ssprff_econ(tmp_path):
193193

194194
ssp_econ.to_zarr(d / "integration-econ-bc39.zarr")
195195
rff_econ.to_netcdf(d / "rff_global_socioeconomics.nc4")
196+
197+
198+
@pytest.fixture(scope="module")
199+
def menu_instance_with_geography(menu_class, discount_types, econ, climate):
200+
"""Menu instance with geography parameter."""
201+
datadir = os.path.join(os.path.dirname(__file__), "data")
202+
yield menu_class(
203+
sector_path=[{"dummy_sector": os.path.join(datadir, "damages")}],
204+
save_path=None,
205+
discrete_discounting=True,
206+
econ_vars=econ,
207+
climate_vars=climate,
208+
fit_type="ols",
209+
variable=[{"dummy_sector": "damages"}],
210+
sector="dummy_sector",
211+
discounting_type=discount_types,
212+
ext_method="global_c_ratio",
213+
save_files=[
214+
"damage_function_points",
215+
"global_consumption",
216+
"damage_function_coefficients",
217+
"damage_function_fit",
218+
],
219+
ce_path=os.path.join(datadir, "CEs"),
220+
subset_dict={
221+
"ssp": ["SSP2", "SSP3", "SSP4"],
222+
"region": [
223+
"IND.21.317.1249",
224+
"CAN.2.33.913",
225+
"USA.14.608",
226+
"EGY.11",
227+
"SDN.4.11.50.164",
228+
"NGA.25.510",
229+
"SAU.7",
230+
"RUS.16.430.430",
231+
"SOM.2.5",
232+
],
233+
},
234+
formula="damages ~ -1 + anomaly + np.power(anomaly, 2)",
235+
extrap_formula=None,
236+
fair_aggregation=["median_params", "ce", "mean"],
237+
weitzman_parameter=[0.1],
238+
geography="globe",
239+
country_mapping_path=None,
240+
individual_region=None,
241+
)

tests/test_geography.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""Tests for geography functionality and backward compatibility."""
2+
3+
import pandas
4+
import xarray as xr
5+
import numpy as np
6+
import pytest
7+
8+
from dscim.menu.risk_aversion import RiskAversionRecipe
9+
10+
11+
class TestGlobeGeographyEquivalence:
12+
"""Tests that xarray path produces same results as pandas path for globe."""
13+
14+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
15+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
16+
def test_damages_dataset_equals_global_damages_calculation(self, menu_instance):
17+
df_pandas = menu_instance.global_damages_calculation()
18+
ds_xarray = menu_instance.damages_dataset(geography="globe")
19+
20+
df_xarray = ds_xarray.to_dataframe().reset_index()
21+
22+
assert "damages" in df_pandas.columns
23+
assert "damages" in df_xarray.columns
24+
25+
damages_pandas = df_pandas["damages"].sort_values().reset_index(drop=True)
26+
damages_xarray = df_xarray["damages"].sort_values().reset_index(drop=True)
27+
28+
np.testing.assert_allclose(
29+
damages_pandas.values,
30+
damages_xarray.values,
31+
rtol=1e-10,
32+
atol=1e-10,
33+
)
34+
35+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
36+
@pytest.mark.parametrize(
37+
"discount_types", ["euler_ramsey", "euler_gwr", "constant"], indirect=True
38+
)
39+
def test_damages_dataset_returns_dataset(self, menu_instance):
40+
result = menu_instance.damages_dataset(geography="globe")
41+
assert isinstance(result, xr.Dataset)
42+
assert "damages" in result.data_vars
43+
44+
45+
class TestGeographyAggregation:
46+
"""Tests for _aggregate_by_geography method."""
47+
48+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
49+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
50+
def test_aggregate_globe_sums_all_regions(self, menu_instance):
51+
damages = menu_instance.calculated_damages * menu_instance.collapsed_pop
52+
53+
expected = damages.sum(dim="region")
54+
actual = menu_instance._aggregate_by_geography(damages, "globe")
55+
56+
assert actual.region.values == ["globe"]
57+
58+
actual_values = actual.squeeze(dim="region", drop=True)
59+
xr.testing.assert_allclose(expected, actual_values)
60+
61+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
62+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
63+
def test_aggregate_ir_preserves_regions(self, menu_instance):
64+
damages = menu_instance.calculated_damages * menu_instance.collapsed_pop
65+
66+
result = menu_instance._aggregate_by_geography(damages, "ir")
67+
68+
xr.testing.assert_allclose(result, damages)
69+
70+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
71+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
72+
def test_invalid_geography_raises_error(self, menu_instance):
73+
damages = menu_instance.calculated_damages * menu_instance.collapsed_pop
74+
75+
with pytest.raises(ValueError, match="Unknown geography"):
76+
menu_instance._aggregate_by_geography(damages, "invalid_geography")
77+
78+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
79+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
80+
def test_country_without_mapping_raises_error(self, menu_instance):
81+
damages = menu_instance.calculated_damages * menu_instance.collapsed_pop
82+
83+
menu_instance.country_mapping = None
84+
85+
with pytest.raises(ValueError, match="country_mapping"):
86+
menu_instance._aggregate_by_geography(damages, "country")
87+
88+
89+
class TestBackwardCompatibility:
90+
"""Tests for backward compatibility with existing API."""
91+
92+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
93+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
94+
def test_global_damages_calculation_returns_dataframe(self, menu_instance):
95+
result = menu_instance.global_damages_calculation()
96+
assert isinstance(result, pandas.DataFrame)
97+
assert "region" not in result.columns
98+
99+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
100+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
101+
def test_damage_function_points_returns_dataframe(self, menu_instance):
102+
result = menu_instance.damage_function_points
103+
assert isinstance(result, pandas.DataFrame)
104+
105+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
106+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
107+
def test_default_geography_is_globe(self, menu_instance):
108+
assert menu_instance.geography == "globe"
109+
110+
111+
class TestDualPathEquivalence:
112+
"""Tests for pandas vs xarray path equivalence."""
113+
114+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
115+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
116+
def test_pandas_path_used_for_globe(self, menu_instance):
117+
assert menu_instance.geography == "globe"
118+
119+
result = menu_instance.damage_function_points
120+
assert isinstance(result, pandas.DataFrame)
121+
122+
expected = menu_instance._damage_function_points_pandas()
123+
pandas.testing.assert_frame_equal(result, expected)
124+
125+
@pytest.mark.parametrize("menu_class", [RiskAversionRecipe], indirect=True)
126+
@pytest.mark.parametrize("discount_types", ["euler_ramsey"], indirect=True)
127+
def test_xarray_path_matches_pandas_path_for_globe(self, menu_instance):
128+
# Compare full pipeline: damages, climate merge, illegal filtering
129+
pandas_result = menu_instance._damage_function_points_pandas()
130+
131+
original_geography = menu_instance.geography
132+
menu_instance.geography = "globe"
133+
xarray_result = menu_instance._damage_function_points_xarray()
134+
menu_instance.geography = original_geography
135+
136+
assert isinstance(pandas_result, pandas.DataFrame)
137+
assert isinstance(xarray_result, pandas.DataFrame)
138+
139+
assert "damages" in pandas_result.columns
140+
assert "damages" in xarray_result.columns
141+
142+
sort_cols = [c for c in ["year", "ssp", "model", "gcm", "rcp"] if c in pandas_result.columns]
143+
pandas_sorted = pandas_result.sort_values(sort_cols).reset_index(drop=True)
144+
xarray_sorted = xarray_result.sort_values(sort_cols).reset_index(drop=True)
145+
146+
np.testing.assert_allclose(
147+
pandas_sorted["damages"].values,
148+
xarray_sorted["damages"].values,
149+
rtol=1e-10,
150+
atol=1e-10,
151+
)
152+
153+
if "anomaly" in pandas_sorted.columns and "anomaly" in xarray_sorted.columns:
154+
pandas_nan = pandas_sorted["anomaly"].isna()
155+
xarray_nan = xarray_sorted["anomaly"].isna()
156+
assert (pandas_nan == xarray_nan).all()
157+
158+
pandas_valid = pandas_sorted.loc[~pandas_nan, "anomaly"].values
159+
xarray_valid = xarray_sorted.loc[~xarray_nan, "anomaly"].values
160+
np.testing.assert_allclose(
161+
pandas_valid,
162+
xarray_valid,
163+
rtol=1e-10,
164+
atol=1e-10,
165+
)
166+
167+
168+
class TestCountryAggregation:
169+
"""Tests for country-level aggregation."""
170+
171+
@pytest.mark.skip(reason="Requires country_mapping fixture")
172+
def test_country_aggregation(self):
173+
pass
174+
175+
176+
class TestIndividualRegion:
177+
"""Tests for individual region calculations."""
178+
179+
@pytest.mark.skip(reason="For future individual_region support")
180+
def test_individual_region_filter(self):
181+
pass

0 commit comments

Comments
 (0)