Skip to content

Commit 6a616d6

Browse files
authored
Make simulation input exports safe by default (#496)
* Make simulation input exports safe by default * Preserve inherited inputs in safe exports
1 parent 4ab6c10 commit 6a616d6

3 files changed

Lines changed: 230 additions & 5 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Exclude pseudo-inputs and calculated values from simulation input exports by default.

policyengine_core/simulations/simulation.py

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,10 +1588,92 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool:
15881588

15891589
return True
15901590

1591+
def get_input_variables(self, include_computed_variables: bool = True) -> List[str]:
1592+
"""Return variable names stored as inputs on this simulation.
1593+
1594+
Args:
1595+
include_computed_variables: When ``True``, return the legacy
1596+
runtime list of variables with stored values. When ``False``,
1597+
return only structurally input variables that were populated
1598+
through ``set_input`` on the current branch.
1599+
1600+
Returns:
1601+
List[str]: Stored input variable names.
1602+
"""
1603+
if include_computed_variables:
1604+
return list(self.input_variables)
1605+
1606+
return [
1607+
variable_name
1608+
for variable_name in self.tax_benefit_system.variables
1609+
if len(
1610+
self._get_exportable_input_periods(
1611+
variable_name,
1612+
include_computed_variables=False,
1613+
)
1614+
)
1615+
> 0
1616+
]
1617+
1618+
@property
1619+
def true_input_variables(self) -> List[str]:
1620+
"""Stored variables that are safe to reload as source inputs."""
1621+
return self.get_input_variables(include_computed_variables=False)
1622+
1623+
def _is_exportable_input_variable(self, variable_name: str) -> bool:
1624+
variable = self.tax_benefit_system.get_variable(variable_name)
1625+
return variable is not None and variable.is_input_variable()
1626+
1627+
def _get_visible_branch_names(self) -> List[str]:
1628+
branch_names = [self.branch_name]
1629+
parent = getattr(self, "parent_branch", None)
1630+
while parent is not None:
1631+
branch_names.append(parent.branch_name)
1632+
parent = getattr(parent, "parent_branch", None)
1633+
branch_names.append("default")
1634+
return list(dict.fromkeys(branch_names))
1635+
1636+
def _get_exportable_input_periods(
1637+
self,
1638+
variable_name: str,
1639+
include_computed_variables: bool,
1640+
) -> List[Period]:
1641+
if include_computed_variables:
1642+
return self.get_holder(variable_name).get_known_periods()
1643+
1644+
if not self._is_exportable_input_variable(variable_name):
1645+
return []
1646+
1647+
user_input_periods = {
1648+
period
1649+
for input_variable_name, branch_name, period in getattr(
1650+
self, "_user_input_keys", set()
1651+
)
1652+
if input_variable_name == variable_name
1653+
and branch_name in self._get_visible_branch_names()
1654+
}
1655+
if not user_input_periods:
1656+
return []
1657+
variable = self.tax_benefit_system.get_variable(variable_name)
1658+
holder = self.get_holder(variable_name)
1659+
if variable.definition_period == ETERNITY:
1660+
return holder.get_known_periods()
1661+
known_periods = set(holder.get_known_periods())
1662+
return sorted(user_input_periods & known_periods, key=str)
1663+
15911664
def to_input_dataframe(
15921665
self,
1666+
include_computed_variables: bool = False,
15931667
) -> pd.DataFrame:
1594-
"""Exports a DataFrame which can be loaded back to a new Simulation to reproduce the same results.
1668+
"""Exports a DataFrame that can be loaded back into a new Simulation.
1669+
1670+
By default, only structurally input variables populated through
1671+
``set_input`` are exported. This avoids serializing pseudo-inputs and
1672+
stale calculated values that would override formulas when reloaded.
1673+
1674+
Args:
1675+
include_computed_variables: If ``True``, export every variable with
1676+
a known period, matching the historical unsafe behavior.
15951677
15961678
Returns:
15971679
pd.DataFrame: The DataFrame containing the input values.
@@ -1601,7 +1683,9 @@ def to_input_dataframe(
16011683

16021684
for variable in self.tax_benefit_system.variables:
16031685
variable_meta = self.tax_benefit_system.variables[variable]
1604-
for period in self.get_holder(variable).get_known_periods():
1686+
for period in self._get_exportable_input_periods(
1687+
variable, include_computed_variables
1688+
):
16051689
# Test if period matches entity definition period
16061690
if variable_meta.definition_period != period.unit:
16071691
continue
@@ -1611,8 +1695,16 @@ def to_input_dataframe(
16111695

16121696
return df
16131697

1614-
def to_input_dict(self) -> dict:
1615-
"""Exports a dictionary which can be loaded back to a new Simulation to reproduce the same results.
1698+
def to_input_dict(self, include_computed_variables: bool = False) -> dict:
1699+
"""Exports a dictionary that can be loaded back into a new Simulation.
1700+
1701+
By default, only structurally input variables populated through
1702+
``set_input`` are exported. This avoids serializing pseudo-inputs and
1703+
stale calculated values that would override formulas when reloaded.
1704+
1705+
Args:
1706+
include_computed_variables: If ``True``, export every variable with
1707+
a known period, matching the historical unsafe behavior.
16161708
16171709
Returns:
16181710
dict: The dictionary containing the input values.
@@ -1621,7 +1713,9 @@ def to_input_dict(self) -> dict:
16211713

16221714
for variable in self.tax_benefit_system.variables:
16231715
data[variable] = {}
1624-
for period in self.get_holder(variable).get_known_periods():
1716+
for period in self._get_exportable_input_periods(
1717+
variable, include_computed_variables
1718+
):
16251719
values = self.calculate(variable, period, map_to="person")
16261720
if values is not None:
16271721
data[variable][str(period)] = values.tolist()

tests/core/test_simulations.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from policyengine_core.country_template.situation_examples import single
2+
from policyengine_core.country_template import Simulation as CountryTemplateSimulation
3+
from policyengine_core.country_template.entities import Person
4+
from policyengine_core.data import Dataset
5+
from policyengine_core.model_api import Variable
6+
from policyengine_core.periods import MONTH
27
from policyengine_core.simulations import SimulationBuilder
38
import policyengine_core.simulations.simulation as simulation_module
49
from policyengine_core.simulations.simulation_macro_cache import (
510
SimulationMacroCache,
611
)
712
import importlib.metadata
813
import numpy as np
14+
import pandas as pd
915
from pathlib import Path
1016

1117

@@ -112,3 +118,127 @@ def __init__(self, tax_benefit_system):
112118
simulation = SimulationBuilder().build_default_simulation(tax_benefit_system)
113119

114120
simulation.calculate("income_tax", "2017-01")
121+
122+
123+
class formula_component_for_safe_export(Variable):
124+
value_type = float
125+
entity = Person
126+
definition_period = MONTH
127+
label = "Formula component for safe export tests."
128+
129+
def formula(person, period):
130+
return person("salary", period) * 0
131+
132+
133+
class pseudo_input_for_safe_export(Variable):
134+
value_type = float
135+
entity = Person
136+
definition_period = MONTH
137+
label = "Pseudo-input for safe export tests."
138+
adds = ["formula_component_for_safe_export"]
139+
140+
141+
def _safe_export_dataset(dataframe):
142+
return Dataset.from_dataframe(dataframe, "2022-01")
143+
144+
145+
def _safe_export_simulation(isolated_tax_benefit_system):
146+
isolated_tax_benefit_system.add_variable(formula_component_for_safe_export)
147+
isolated_tax_benefit_system.add_variable(pseudo_input_for_safe_export)
148+
149+
dataframe = pd.DataFrame(
150+
{
151+
"person_id__2022": [0],
152+
"household_id__2022": [0],
153+
"person_household_id__2022": [0],
154+
"person_household_role__2022": ["parent"],
155+
"household_weight__2022": [1.0],
156+
"salary__2022-01": [0.0],
157+
"pseudo_input_for_safe_export__2022-01": [999.0],
158+
}
159+
)
160+
return CountryTemplateSimulation(
161+
tax_benefit_system=isolated_tax_benefit_system,
162+
dataset=_safe_export_dataset(dataframe),
163+
)
164+
165+
166+
def test__given_pseudo_input_in_dataset__then_input_dataframe_excludes_it(
167+
isolated_tax_benefit_system,
168+
):
169+
# Given
170+
simulation = _safe_export_simulation(isolated_tax_benefit_system)
171+
172+
assert simulation.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 999.0
173+
174+
# When
175+
dataframe = simulation.to_input_dataframe()
176+
reloaded = CountryTemplateSimulation(
177+
tax_benefit_system=isolated_tax_benefit_system,
178+
dataset=_safe_export_dataset(dataframe),
179+
)
180+
181+
# Then
182+
assert "salary__2022-01" in dataframe.columns
183+
assert "pseudo_input_for_safe_export__2022-01" not in dataframe.columns
184+
assert "salary" in simulation.true_input_variables
185+
assert "pseudo_input_for_safe_export" not in simulation.true_input_variables
186+
assert (
187+
"pseudo_input_for_safe_export__2022-01"
188+
in simulation.to_input_dataframe(include_computed_variables=True).columns
189+
)
190+
assert reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 0.0
191+
192+
193+
def test__given_pseudo_input_in_dataset__then_input_dict_h5_round_trip_excludes_it(
194+
isolated_tax_benefit_system, tmp_path
195+
):
196+
# Given
197+
simulation = _safe_export_simulation(isolated_tax_benefit_system)
198+
exported_data = simulation.to_input_dict()
199+
h5_path = tmp_path / "safe_export.h5"
200+
201+
class SafeExportDataset(Dataset):
202+
name = "safe_export"
203+
label = "Safe export"
204+
file_path = h5_path
205+
data_format = Dataset.TIME_PERIOD_ARRAYS
206+
207+
# When
208+
SafeExportDataset().save_dataset(exported_data)
209+
reloaded = CountryTemplateSimulation(
210+
tax_benefit_system=isolated_tax_benefit_system,
211+
dataset=Dataset.from_file(h5_path),
212+
)
213+
214+
# Then
215+
assert "salary" in exported_data
216+
assert "pseudo_input_for_safe_export" not in exported_data
217+
assert "pseudo_input_for_safe_export" in simulation.to_input_dict(
218+
include_computed_variables=True
219+
)
220+
assert reloaded.calculate("pseudo_input_for_safe_export", "2022-01")[0] == 0.0
221+
222+
223+
def test__given_branch_inherits_dataset_inputs__then_safe_exports_include_them(
224+
isolated_tax_benefit_system,
225+
):
226+
# Given
227+
simulation = _safe_export_simulation(isolated_tax_benefit_system)
228+
branch = simulation.get_branch("reform")
229+
230+
assert branch.calculate("salary", "2022-01")[0] == 0.0
231+
232+
# When
233+
dataframe = branch.to_input_dataframe()
234+
exported_data = branch.to_input_dict()
235+
236+
# Then
237+
assert "person_id__ETERNITY" in dataframe.columns
238+
assert "household_id__ETERNITY" in dataframe.columns
239+
assert "household_weight__2022" in dataframe.columns
240+
assert "salary__2022-01" in dataframe.columns
241+
assert "pseudo_input_for_safe_export__2022-01" not in dataframe.columns
242+
assert "salary" in exported_data
243+
assert "pseudo_input_for_safe_export" not in exported_data
244+
assert "salary" in branch.true_input_variables

0 commit comments

Comments
 (0)