Skip to content

Commit 7de3a1e

Browse files
committed
Make simulation input exports safe by default
1 parent 4ab6c10 commit 7de3a1e

3 files changed

Lines changed: 196 additions & 5 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Exclude pseudo-inputs and calculated values from simulation input exports by default.

policyengine_core/simulations/simulation.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,10 +1588,82 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool:
15881588

15891589
return True
15901590

1591+
def get_input_variables(self, include_computed_variables: bool = True) -> List[str]:
    """List the names of variables held as inputs by this simulation.

    Args:
        include_computed_variables: If ``True`` (the default), a copy of
            ``self.input_variables`` is returned as a list. If ``False``,
            the tax-benefit system's variables are filtered down to those
            for which ``_get_exportable_input_periods`` reports at least
            one period.

    Returns:
        List[str]: The selected variable names.
    """
    if not include_computed_variables:
        exportable_names = []
        for name in self.tax_benefit_system.variables:
            periods = self._get_exportable_input_periods(
                name,
                include_computed_variables=False,
            )
            # Keep only variables that have something to export.
            if len(periods) > 0:
                exportable_names.append(name)
        return exportable_names

    # Legacy behaviour: hand back the runtime list as-is.
    return list(self.input_variables)
1617+
1618+
@property
def true_input_variables(self) -> List[str]:
    """Names of stored variables that are safe to reload as source inputs.

    Shorthand for ``get_input_variables(include_computed_variables=False)``.
    """
    safe_names = self.get_input_variables(include_computed_variables=False)
    return safe_names
1622+
1623+
def _is_exportable_input_variable(self, variable_name: str) -> bool:
    """Tell whether ``variable_name`` names a structurally input variable.

    Returns ``False`` when the tax-benefit system has no variable under that
    name; otherwise returns the result of the variable's
    ``is_input_variable()``.
    """
    candidate = self.tax_benefit_system.get_variable(variable_name)
    if candidate is None:
        return False
    return candidate.is_input_variable()
1626+
1627+
def _get_exportable_input_periods(
    self,
    variable_name: str,
    include_computed_variables: bool,
) -> List[Period]:
    """Return the periods of ``variable_name`` that should be exported.

    Args:
        variable_name: Name of the variable to inspect.
        include_computed_variables: If ``True``, every period known to the
            variable's holder is returned without filtering. If ``False``,
            only input variables are considered, and only periods recorded
            in ``_user_input_keys`` for the current branch.

    Returns:
        List[Period]: The exportable periods; empty when the variable is not
        an input variable or has no recorded user input on this branch.
    """
    if include_computed_variables:
        return self.get_holder(variable_name).get_known_periods()

    if not self._is_exportable_input_variable(variable_name):
        return []

    # ``_user_input_keys`` may be absent on the instance; treat that as empty.
    recorded_keys = getattr(self, "_user_input_keys", set())
    periods_set_by_user = set()
    for key_variable, key_branch, key_period in recorded_keys:
        if key_variable == variable_name and key_branch == self.branch_name:
            periods_set_by_user.add(key_period)

    if not periods_set_by_user:
        return []

    definition = self.tax_benefit_system.get_variable(variable_name)
    stored = self.get_holder(variable_name)

    # ETERNITY variables skip the intersection below and return every
    # period the holder knows about.
    if definition.definition_period == ETERNITY:
        return stored.get_known_periods()

    # Otherwise keep only user-set periods that the holder still knows,
    # ordered by their string form.
    still_known = set(stored.get_known_periods())
    return sorted(periods_set_by_user & still_known, key=str)
1653+
15911654
def to_input_dataframe(
15921655
self,
1656+
include_computed_variables: bool = False,
15931657
) -> pd.DataFrame:
1594-
"""Exports a DataFrame which can be loaded back to a new Simulation to reproduce the same results.
1658+
"""Exports a DataFrame that can be loaded back into a new Simulation.
1659+
1660+
By default, only structurally input variables populated through
1661+
``set_input`` are exported. This avoids serializing pseudo-inputs and
1662+
stale calculated values that would override formulas when reloaded.
1663+
1664+
Args:
1665+
include_computed_variables: If ``True``, export every variable with
1666+
a known period, matching the historical unsafe behavior.
15951667
15961668
Returns:
15971669
pd.DataFrame: The DataFrame containing the input values.
@@ -1601,7 +1673,9 @@ def to_input_dataframe(
16011673

16021674
for variable in self.tax_benefit_system.variables:
16031675
variable_meta = self.tax_benefit_system.variables[variable]
1604-
for period in self.get_holder(variable).get_known_periods():
1676+
for period in self._get_exportable_input_periods(
1677+
variable, include_computed_variables
1678+
):
16051679
# Test if period matches entity definition period
16061680
if variable_meta.definition_period != period.unit:
16071681
continue
@@ -1611,8 +1685,16 @@ def to_input_dataframe(
16111685

16121686
return df
16131687

1614-
def to_input_dict(self) -> dict:
1615-
"""Exports a dictionary which can be loaded back to a new Simulation to reproduce the same results.
1688+
def to_input_dict(self, include_computed_variables: bool = False) -> dict:
1689+
"""Exports a dictionary that can be loaded back into a new Simulation.
1690+
1691+
By default, only structurally input variables populated through
1692+
``set_input`` are exported. This avoids serializing pseudo-inputs and
1693+
stale calculated values that would override formulas when reloaded.
1694+
1695+
Args:
1696+
include_computed_variables: If ``True``, export every variable with
1697+
a known period, matching the historical unsafe behavior.
16161698
16171699
Returns:
16181700
dict: The dictionary containing the input values.
@@ -1621,7 +1703,9 @@ def to_input_dict(self) -> dict:
16211703

16221704
for variable in self.tax_benefit_system.variables:
16231705
data[variable] = {}
1624-
for period in self.get_holder(variable).get_known_periods():
1706+
for period in self._get_exportable_input_periods(
1707+
variable, include_computed_variables
1708+
):
16251709
values = self.calculate(variable, period, map_to="person")
16261710
if values is not None:
16271711
data[variable][str(period)] = values.tolist()

tests/core/test_simulations.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from policyengine_core.country_template.situation_examples import single
2+
from policyengine_core.country_template import Simulation as CountryTemplateSimulation
3+
from policyengine_core.country_template.entities import Person
4+
from policyengine_core.data import Dataset
5+
from policyengine_core.model_api import Variable
6+
from policyengine_core.periods import MONTH
27
from policyengine_core.simulations import SimulationBuilder
38
import policyengine_core.simulations.simulation as simulation_module
49
from policyengine_core.simulations.simulation_macro_cache import (
510
SimulationMacroCache,
611
)
712
import importlib.metadata
813
import numpy as np
14+
import pandas as pd
915
from pathlib import Path
1016

1117

@@ -112,3 +118,103 @@ def __init__(self, tax_benefit_system):
112118
simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)
113119

114120
simulation.calculate("income_tax", "2017-01")
121+
122+
123+
class formula_component_for_safe_export(Variable):
    # Monthly, person-level float variable computed by a formula. It is the
    # single component listed in ``pseudo_input_for_safe_export.adds``.
    label = "Formula component for safe export tests."
    entity = Person
    value_type = float
    definition_period = MONTH

    def formula(person, period):
        # Multiplies the person's salary for the period by zero.
        salary = person("salary", period)
        return salary * 0
131+
132+
133+
class pseudo_input_for_safe_export(Variable):
    # Monthly, person-level float variable with no explicit formula; it is
    # declared through ``adds`` over ``formula_component_for_safe_export``.
    label = "Pseudo-input for safe export tests."
    entity = Person
    value_type = float
    definition_period = MONTH
    adds = ["formula_component_for_safe_export"]
139+
140+
141+
def _safe_export_dataset(dataframe):
    """Build a ``Dataset`` from ``dataframe``, passing "2022-01" as the second argument."""
    dataset = Dataset.from_dataframe(dataframe, "2022-01")
    return dataset
143+
144+
145+
def _safe_export_simulation(isolated_tax_benefit_system):
    """Build a one-person simulation whose dataset stores a pseudo-input value.

    Registers the two safe-export test variables on the given tax-benefit
    system, then loads a single-row dataset that sets ``salary`` to 0.0 and
    ``pseudo_input_for_safe_export`` to 999.0 for 2022-01.
    """
    for variable_class in (
        formula_component_for_safe_export,
        pseudo_input_for_safe_export,
    ):
        isolated_tax_benefit_system.add_variable(variable_class)

    columns = {
        "person_id__2022": [0],
        "household_id__2022": [0],
        "person_household_id__2022": [0],
        "person_household_role__2022": ["parent"],
        "household_weight__2022": [1.0],
        "salary__2022-01": [0.0],
        "pseudo_input_for_safe_export__2022-01": [999.0],
    }
    source_frame = pd.DataFrame(columns)
    source_dataset = _safe_export_dataset(source_frame)
    return CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=source_dataset,
    )
164+
165+
166+
def test__given_pseudo_input_in_dataset__then_input_dataframe_excludes_it(
    isolated_tax_benefit_system,
):
    """The default DataFrame export drops a pseudo-input stored in the dataset.

    After a round trip through ``to_input_dataframe`` the pseudo-input is
    recalculated (0.0) instead of keeping the stored 999.0.
    """
    pseudo_name = "pseudo_input_for_safe_export"
    pseudo_column = "pseudo_input_for_safe_export__2022-01"

    # Given: the source simulation reads the stored value from the dataset.
    simulation = _safe_export_simulation(isolated_tax_benefit_system)
    stored_value = simulation.calculate(pseudo_name, "2022-01")[0]
    assert stored_value == 999.0

    # When: export with defaults and reload the export into a new simulation.
    dataframe = simulation.to_input_dataframe()
    reloaded_dataset = _safe_export_dataset(dataframe)
    reloaded = CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=reloaded_dataset,
    )

    # Then: only the real input is exported by default.
    assert "salary__2022-01" in dataframe.columns
    assert pseudo_column not in dataframe.columns
    assert "salary" in simulation.true_input_variables
    assert pseudo_name not in simulation.true_input_variables

    # The opt-in flag still exports the pseudo-input column.
    unsafe_columns = simulation.to_input_dataframe(
        include_computed_variables=True
    ).columns
    assert pseudo_column in unsafe_columns

    recalculated_value = reloaded.calculate(pseudo_name, "2022-01")[0]
    assert recalculated_value == 0.0
191+
192+
193+
def test__given_pseudo_input_in_dataset__then_input_dict_h5_round_trip_excludes_it(
    isolated_tax_benefit_system, tmp_path
):
    """The default dict export, saved to H5 and reloaded, drops the pseudo-input.

    The reloaded simulation recalculates the pseudo-input (0.0) rather than
    reading back a stored value.
    """
    pseudo_name = "pseudo_input_for_safe_export"

    # Given: a default export of the source simulation and a target H5 path.
    simulation = _safe_export_simulation(isolated_tax_benefit_system)
    exported_data = simulation.to_input_dict()
    h5_path = tmp_path / "safe_export.h5"

    class SafeExportDataset(Dataset):
        # Throwaway dataset definition pointing at the temporary H5 file.
        name = "safe_export"
        label = "Safe export"
        file_path = h5_path
        data_format = Dataset.TIME_PERIOD_ARRAYS

    # When: persist the export and load it into a new simulation.
    writer = SafeExportDataset()
    writer.save_dataset(exported_data)
    reloaded_dataset = Dataset.from_file(h5_path)
    reloaded = CountryTemplateSimulation(
        tax_benefit_system=isolated_tax_benefit_system,
        dataset=reloaded_dataset,
    )

    # Then: only the real input is exported by default.
    assert "salary" in exported_data
    assert pseudo_name not in exported_data

    # The opt-in flag still exports the pseudo-input.
    unsafe_export = simulation.to_input_dict(include_computed_variables=True)
    assert pseudo_name in unsafe_export

    recalculated_value = reloaded.calculate(pseudo_name, "2022-01")[0]
    assert recalculated_value == 0.0

0 commit comments

Comments
 (0)