Skip to content

Commit 1be75e2

Browse files
committed
Add random seed
1 parent 5a441e2 commit 1be75e2

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

src/qcatch/find_retained_cells/cell_calling.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545

4646
MAX_MEM_GB = 0.3
4747

48+
RNG = np.random.default_rng(seed=42)
49+
4850
NonAmbientBarcodeResult = namedtuple(
4951
"NonAmbientBarcodeResult",
5052
[
@@ -360,15 +362,15 @@ def simulate_multinomial_loglikelihoods(
360362
logger.debug("Range of N: %d", num_all_n)
361363
logger.debug("Number of features: %d", len(profile_p))
362364

363-
sampled_features = np.random.choice(len(profile_p), size=n_sample_feature_block, p=profile_p, replace=True)
365+
sampled_features = RNG.choice(len(profile_p), size=n_sample_feature_block, p=profile_p, replace=True)
364366
k = 0
365367

366368
log_profile_p = np.log(profile_p)
367369

368370
for sim_idx in range(num_sims):
369371
if verbose and sim_idx % 1000 == 999:
370372
logger.debug("Simulation progress: completed %d/%d simulations", sim_idx + 1, num_sims)
371-
curr_counts = np.ravel(sp_stats.multinomial.rvs(distinct_n[0], profile_p, size=1))
373+
curr_counts = np.ravel(sp_stats.multinomial.rvs(distinct_n[0], profile_p, size=1, random_state=RNG))
372374

373375
curr_loglk = sp_stats.multinomial.logpmf(curr_counts, distinct_n[0], p=profile_p)
374376

@@ -378,7 +380,7 @@ def simulate_multinomial_loglikelihoods(
378380
step = distinct_n[i] - distinct_n[i - 1]
379381
if step >= jump:
380382
# Instead of iterating for each n, sample the intermediate ns all at once
381-
curr_counts += np.ravel(sp_stats.multinomial.rvs(step, profile_p, size=1))
383+
curr_counts += np.ravel(sp_stats.multinomial.rvs(step, profile_p, size=1, random_state=RNG))
382384
curr_loglk = sp_stats.multinomial.logpmf(curr_counts, distinct_n[i], p=profile_p)
383385
assert not np.isnan(curr_loglk)
384386
else:
@@ -388,7 +390,7 @@ def simulate_multinomial_loglikelihoods(
388390
k += 1
389391
if k >= n_sample_feature_block:
390392
# Amortize this operation
391-
sampled_features = np.random.choice(
393+
sampled_features = RNG.choice(
392394
len(profile_p), size=n_sample_feature_block, p=profile_p, replace=True
393395
)
394396
k = 0
@@ -555,6 +557,7 @@ def find_nonambient_barcodes(
555557
return None
556558

557559
assert not np.any(np.isin(eval_bcs, orig_cells))
560+
558561
logger.debug(f"Number of candidate bcs: {len(eval_bcs)}")
559562
logger.debug(f"Range candidate bc umis: {umis_per_bc[eval_bcs].min()}, {umis_per_bc[eval_bcs].max()}")
560563

0 commit comments

Comments
 (0)