@@ -190,6 +190,7 @@ def assign_takeup_with_reported_anchors(
190190 rates ,
191191 reported_mask : Optional [np .ndarray ] = None ,
192192 group_keys : Optional [np .ndarray ] = None ,
193+ eligible_mask : Optional [np .ndarray ] = None ,
193194) -> np .ndarray :
194195 """Apply the SSI/SNAP-style reported-first takeup pattern.
195196
@@ -206,22 +207,30 @@ def assign_takeup_with_reported_anchors(
206207 if len (rates_arr ) != len (draws ):
207208 raise ValueError ("rates and draws must align" )
208209
210+ if eligible_mask is None :
211+ eligible_mask = np .ones (len (draws ), dtype = bool )
212+ else :
213+ eligible_mask = np .asarray (eligible_mask , dtype = bool )
214+ if len (eligible_mask ) != len (draws ):
215+ raise ValueError ("eligible_mask and draws must align" )
216+
209217 baseline = draws < rates_arr
210218 if reported_mask is None :
211- return baseline
219+ return eligible_mask & baseline
212220
213221 reported_mask = np .asarray (reported_mask , dtype = bool )
214222 if len (reported_mask ) != len (draws ):
215223 raise ValueError ("reported_mask and draws must align" )
216224
225+ eligible_mask = eligible_mask | reported_mask
217226 result = reported_mask .copy ()
218227
219228 if group_keys is None :
220229 unique_rates = np .unique (rates_arr )
221230 if len (unique_rates ) != 1 :
222231 raise ValueError ("group_keys required when rates vary by entity" )
223- target_count = int (unique_rates [0 ] * len ( draws ))
224- non_reporters = ~ reported_mask
232+ target_count = int (unique_rates [0 ] * int ( eligible_mask . sum () ))
233+ non_reporters = eligible_mask & ~ reported_mask
225234 remaining_needed = max (0 , target_count - int (reported_mask .sum ()))
226235 adjusted_rate = (
227236 remaining_needed / int (non_reporters .sum ()) if non_reporters .any () else 0
@@ -238,10 +247,11 @@ def assign_takeup_with_reported_anchors(
238247 group_rates = np .unique (rates_arr [group_mask ])
239248 if len (group_rates ) != 1 :
240249 raise ValueError ("Each takeup group must have a single rate" )
241- target_count = int (group_rates [0 ] * int (group_mask .sum ()))
250+ group_eligible = group_mask & eligible_mask
251+ target_count = int (group_rates [0 ] * int (group_eligible .sum ()))
242252 group_reported = reported_mask [group_mask ]
243253 remaining_needed = max (0 , target_count - int (group_reported .sum ()))
244- group_non_reporters = group_mask & ~ reported_mask
254+ group_non_reporters = group_eligible & ~ reported_mask
245255 adjusted_rate = (
246256 remaining_needed / int (group_non_reporters .sum ())
247257 if group_non_reporters .any ()
@@ -423,6 +433,7 @@ def compute_block_takeup_for_entities(
423433 entity_hh_ids : np .ndarray = None ,
424434 entity_clone_ids : np .ndarray = None ,
425435 reported_mask : Optional [np .ndarray ] = None ,
436+ eligible_mask : Optional [np .ndarray ] = None ,
426437) -> np .ndarray :
427438 """Compute boolean takeup via block-level seeded draws."""
428439 draws = compute_block_takeup_draws_for_entities (
@@ -448,6 +459,7 @@ def compute_block_takeup_for_entities(
448459 rates ,
449460 reported_mask = reported_mask ,
450461 group_keys = group_keys ,
462+ eligible_mask = eligible_mask ,
451463 )
452464
453465
@@ -660,6 +672,7 @@ def apply_block_takeup_to_arrays(
660672 takeup_filter : List [str ] = None ,
661673 precomputed_rates : Optional [Dict [str , Any ]] = None ,
662674 reported_anchors : Optional [Dict [str , np .ndarray ]] = None ,
675+ eligibility_masks : Optional [Dict [str , np .ndarray ]] = None ,
663676 voluntary_filing_inputs : Optional [Dict [str , np .ndarray ]] = None ,
664677) -> Dict [str , np .ndarray ]:
665678 """Compute takeup draws from raw arrays.
@@ -686,13 +699,18 @@ def apply_block_takeup_to_arrays(
686699 precomputed_rates: Optional {rate_key: rate_or_dict} cache.
687700 When provided, skips ``load_take_up_rate`` calls and
688701 uses cached values instead.
702+ reported_anchors: Optional {takeup variable: bool array}; reported
703+ recipients are always assigned take-up.
704+ eligibility_masks: Optional {takeup variable: bool array}; non-reported
705+ take-up is drawn only from the matching eligible entity rows.
689706
690707 Returns:
691708 {variable_name: bool_array} for each takeup variable.
692709 """
693710 filter_set = set (takeup_filter ) if takeup_filter is not None else None
694711 result = {}
695712 reported_anchors = reported_anchors or {}
713+ eligibility_masks = eligibility_masks or {}
696714
697715 for spec in SIMPLE_TAKEUP_VARS :
698716 var_name = spec ["variable" ]
@@ -716,6 +734,9 @@ def apply_block_takeup_to_arrays(
716734 reported_mask = reported_anchors .get (var_name )
717735 if reported_mask is not None and len (reported_mask ) != n_ent :
718736 raise ValueError (f"reported anchor for { var_name } has wrong length" )
737+ eligible_mask = eligibility_masks .get (var_name )
738+ if eligible_mask is not None and len (eligible_mask ) != n_ent :
739+ raise ValueError (f"eligibility mask for { var_name } has wrong length" )
719740 if var_name == "would_file_taxes_voluntarily" :
720741 if voluntary_filing_inputs is None :
721742 raise ValueError (
@@ -739,6 +760,7 @@ def apply_block_takeup_to_arrays(
739760 ent_hh_ids ,
740761 ent_clone_indices ,
741762 reported_mask = reported_mask ,
763+ eligible_mask = eligible_mask ,
742764 )
743765 result [var_name ] = bools
744766
0 commit comments