Skip to content

Commit 6a7ed01

Browse files
authored
Default eCPS reweighting to zero L0 penalty (#1124)
1 parent 104ff16 commit 6a7ed01

5 files changed

Lines changed: 17 additions & 8 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Default enhanced CPS reweighting to no L0 sparsity penalty while preserving the configurable penalty for large calibration runs.

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def reweight(
490490
targets_array,
491491
log_path="calibration_log.csv",
492492
epochs=500,
493-
l0_lambda=2.6445e-07,
493+
l0_lambda=0.0,
494494
init_mean=0.999, # initial proportion with non-zero weights
495495
temperature=0.25,
496496
seed=1456,
@@ -534,7 +534,7 @@ def loss(weights):
534534
return rel_error_normalized.mean()
535535

536536
logging.info(
537-
f"Sparse optimization using seed {seed}, temp {temperature} "
537+
f"Hard-concrete optimization using seed {seed}, temp {temperature} "
538538
+ f"init_mean {init_mean}, l0_lambda {l0_lambda}"
539539
)
540540
set_seeds(seed)
@@ -600,7 +600,7 @@ def loss(weights):
600600
final_weights_sparse,
601601
loss_matrix,
602602
targets_array,
603-
"L0 Sparse Solution",
603+
"L0 Sparse Solution" if l0_lambda else "Unpenalized HardConcrete Solution",
604604
target_names=target_names,
605605
)
606606

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: 3.14",
2323
]
2424
dependencies = [
25-
"policyengine-us==1.705.1",
25+
"policyengine-us==1.705.15",
2626
# policyengine-core 3.26.1 is the current 3.26.x runtime and includes the fix for
2727
# PolicyEngine/policyengine-core#482 (user-set ETERNITY inputs lost
2828
# after _invalidate_all_caches) and is required by policyengine-us 1.682.1+.

tests/unit/test_enhanced_cps.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
13
import numpy as np
24

35
from policyengine_us_data.datasets.cps import enhanced_cps
@@ -9,6 +11,12 @@
911
)
1012

1113

14+
def test_reweight_default_does_not_penalize_l0():
15+
signature = inspect.signature(enhanced_cps.reweight)
16+
17+
assert signature.parameters["l0_lambda"].default == 0.0
18+
19+
1220
def test_get_base_aca_takeup_uses_stored_values():
1321
data = {
1422
"takes_up_aca_if_eligible": {

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)