@@ -296,36 +296,35 @@ def _sample(size, mask_slice=None, fixed_slice=None):
296296 # Clone 0: unrestricted draw
297297 indices [:n_records ] = _sample (n_records , extreme_mask , fixed_states )
298298
299- assigned_cds = np .empty ((n_clones , n_records ), dtype = object )
300- assigned_cds [0 ] = cds [indices [:n_records ]]
299+ _ , cd_codes = np .unique (cds , return_inverse = True )
300+ cd_codes = cd_codes .astype (np .int32 , copy = False )
301+ record_positions = np .arange (n_records )
302+ used_cd_by_record = np .zeros ((n_records , cd_codes .max () + 1 ), dtype = bool )
303+ used_cd_by_record [record_positions , cd_codes [indices [:n_records ]]] = True
301304
302305 for clone_idx in range (1 , n_clones ):
303306 start = clone_idx * n_records
304307 clone_indices = _sample (n_records , extreme_mask , fixed_states )
305- clone_cds = cds [clone_indices ]
306-
307- collisions = np .zeros (n_records , dtype = bool )
308- for prev in range (clone_idx ):
309- collisions |= clone_cds == assigned_cds [prev ]
308+ clone_cd_codes = cd_codes [clone_indices ]
309+ collisions = used_cd_by_record [record_positions , clone_cd_codes ]
310310
311311 for _ in range (50 ):
312- n_bad = collisions .sum ()
312+ n_bad = int ( collisions .sum () )
313313 if n_bad == 0 :
314314 break
315315 bad_mask = collisions
316316 if extreme_mask is not None and agi_probs is not None :
317- replacement = _sample ( n_records , extreme_mask , fixed_states )
318- clone_indices [ bad_mask ] = replacement [bad_mask ]
317+ fixed_bad = fixed_states [ bad_mask ] if fixed_states is not None else None
318+ replacement = _sample ( n_bad , extreme_mask [bad_mask ], fixed_bad )
319319 else :
320- replacement = _sample (n_records , fixed_slice = fixed_states )
321- clone_indices [collisions ] = replacement [collisions ]
322- clone_cds = cds [clone_indices ]
323- collisions = np .zeros (n_records , dtype = bool )
324- for prev in range (clone_idx ):
325- collisions |= clone_cds == assigned_cds [prev ]
320+ fixed_bad = fixed_states [bad_mask ] if fixed_states is not None else None
321+ replacement = _sample (n_bad , fixed_slice = fixed_bad )
322+ clone_indices [bad_mask ] = replacement
323+ clone_cd_codes = cd_codes [clone_indices ]
324+ collisions = used_cd_by_record [record_positions , clone_cd_codes ]
326325
327326 indices [start : start + n_records ] = clone_indices
328- assigned_cds [ clone_idx ] = clone_cds
327+ used_cd_by_record [ record_positions , clone_cd_codes ] = True
329328
330329 assigned_blocks = blocks [indices ]
331330 return GeographyAssignment (
0 commit comments