Skip to content

Commit dc62ec8

Browse files
authored
Speed up cloned geography collision checks (#939)
1 parent 028c198 commit dc62ec8

2 files changed

Lines changed: 17 additions & 17 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Speed up cloned household geography assignment for large local-area calibration builds.

policyengine_us_data/calibration/clone_and_assign.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,36 +296,35 @@ def _sample(size, mask_slice=None, fixed_slice=None):
296296
# Clone 0: unrestricted draw
297297
indices[:n_records] = _sample(n_records, extreme_mask, fixed_states)
298298

299-
assigned_cds = np.empty((n_clones, n_records), dtype=object)
300-
assigned_cds[0] = cds[indices[:n_records]]
299+
_, cd_codes = np.unique(cds, return_inverse=True)
300+
cd_codes = cd_codes.astype(np.int32, copy=False)
301+
record_positions = np.arange(n_records)
302+
used_cd_by_record = np.zeros((n_records, cd_codes.max() + 1), dtype=bool)
303+
used_cd_by_record[record_positions, cd_codes[indices[:n_records]]] = True
301304

302305
for clone_idx in range(1, n_clones):
303306
start = clone_idx * n_records
304307
clone_indices = _sample(n_records, extreme_mask, fixed_states)
305-
clone_cds = cds[clone_indices]
306-
307-
collisions = np.zeros(n_records, dtype=bool)
308-
for prev in range(clone_idx):
309-
collisions |= clone_cds == assigned_cds[prev]
308+
clone_cd_codes = cd_codes[clone_indices]
309+
collisions = used_cd_by_record[record_positions, clone_cd_codes]
310310

311311
for _ in range(50):
312-
n_bad = collisions.sum()
312+
n_bad = int(collisions.sum())
313313
if n_bad == 0:
314314
break
315315
bad_mask = collisions
316316
if extreme_mask is not None and agi_probs is not None:
317-
replacement = _sample(n_records, extreme_mask, fixed_states)
318-
clone_indices[bad_mask] = replacement[bad_mask]
317+
fixed_bad = fixed_states[bad_mask] if fixed_states is not None else None
318+
replacement = _sample(n_bad, extreme_mask[bad_mask], fixed_bad)
319319
else:
320-
replacement = _sample(n_records, fixed_slice=fixed_states)
321-
clone_indices[collisions] = replacement[collisions]
322-
clone_cds = cds[clone_indices]
323-
collisions = np.zeros(n_records, dtype=bool)
324-
for prev in range(clone_idx):
325-
collisions |= clone_cds == assigned_cds[prev]
320+
fixed_bad = fixed_states[bad_mask] if fixed_states is not None else None
321+
replacement = _sample(n_bad, fixed_slice=fixed_bad)
322+
clone_indices[bad_mask] = replacement
323+
clone_cd_codes = cd_codes[clone_indices]
324+
collisions = used_cd_by_record[record_positions, clone_cd_codes]
326325

327326
indices[start : start + n_records] = clone_indices
328-
assigned_cds[clone_idx] = clone_cds
327+
used_cd_by_record[record_positions, clone_cd_codes] = True
329328

330329
assigned_blocks = blocks[indices]
331330
return GeographyAssignment(

0 commit comments

Comments
 (0)