Skip to content

Commit 917c928

Browse files
committed
Complete implementation of saving positions per timestamps (untested)
1 parent 024d73f commit 917c928

1 file changed

Lines changed: 149 additions & 79 deletions

File tree

pybromo/diffusion.py

Lines changed: 149 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def simulate_diffusion(self, save_pos=False, total_emission=True,
623623
save_pos=save_pos, radial=radial,
624624
wrap_func=wrap_func)
625625

626-
## Append em to the permanent storage
626+
# Append em to the permanent storage
627627
# if total_emission, data is just a linear array
628628
# otherwise is a 2-D array (self.num_particles, c_size)
629629
em_store.append(em)
@@ -685,108 +685,171 @@ def timestamp_names(self):
685685
names.append(node.name)
686686
return names
687687

688-
def _sim_timestamps(self, max_rate, bg_rate, emission, i_start, rs,
689-
ip_start=0, scale=10, sort=True, save_pos=False):
690-
"""Simulate timestamps from emission trajectories.
691-
692-
Uses attributes: `.t_step`.
688+
@staticmethod
689+
def _timestamps_from_counts(counts, time_axis, max_rate, bg_rate,
690+
position=None, sort=True):
691+
"""Compute timestamps from timetraces of counts.
692+
693+
This function operates on a given "group" of particles
694+
(a population) and a given time chunk.
695+
Number of particles is `counts.shape[0]` and number of time bins is
696+
`counts.shape[1] == len(time_axis)`.
697+
698+
The function takes a 2D array `counts` (1 row per particle,
699+
1 column per time bin), and a 1D array of times `time_axis`
700+
and generates an array of timestamps where counts > 0.
701+
When counts is > 1, there will be multiple identical timestamps
702+
(for example if counts is 3, there will be 3 identical timestamps).
703+
This function also computes particle number and (optionally)
704+
particle position for each timestamp.
705+
706+
If `positions` is not None, returns also position of particles at each
707+
timestamps. `positions` should be an array of shape
708+
(num_particles, num_spatial_dims, num_time_bins) containing positions
709+
for the same particles and time bins in `counts`.
693710
694711
Returns:
695-
A tuple of two arrays: timestamps and particles.
712+
A tuple of 3 arrays: timestamps, particles and positions.
713+
714+
Notes:
715+
Particles number always starts at 0, as this function is unaware
716+
of the "real" particle ID. If needed, the caller should add an
717+
offset to the returned particles array to obtain
718+
the real particle ID.
696719
"""
697-
counts_chunk = sim_timetrace_bg(emission, max_rate, bg_rate,
698-
self.t_step, rs=rs)
699-
nrows = emission.shape[0]
700-
if bg_rate is not None:
701-
nrows += 1
702-
assert counts_chunk.shape == (nrows, emission.shape[1])
703-
max_counts = counts_chunk.max()
720+
max_counts = counts.max()
704721
if max_counts == 0:
705722
return (np.array([], dtype=np.int64), # timestamps
706723
np.array([], dtype=np.int64), # particles
707724
np.array([], dtype=np.float32)) # positions
708725

709-
time_start = i_start * scale
710-
time_stop = time_start + counts_chunk.shape[1] * scale
711-
ts_range = np.arange(time_start, time_stop, scale, dtype='int64')
712-
726+
# These lists will contain one array per particle
727+
ts_times_parlist = []
728+
ts_particles_parlist = []
729+
ts_positions_parlist = []
713730
# Loop for each particle to compute timestamps
714-
times_chunk_p = []
715-
par_index_chunk_p = []
716-
for ip, counts_chunk_ip in enumerate(counts_chunk):
731+
for ip, counts_ip in enumerate(counts):
717732
# Compute timestamps for particle ip for all bins with counts > 0
718-
times_c_ip = []
733+
ts_times_by_num_counts = []
734+
ts_positions_by_num_counts = []
719735
for v in range(1, max_counts + 1):
720-
times_c_ip.append(ts_range[counts_chunk_ip >= v])
736+
mask = counts_ip >= v
737+
ts_times_by_num_counts.append(time_axis[mask])
738+
if position is not None:
739+
# list of 2D arrays
740+
ts_positions_by_num_counts.append(position[ip, :, mask])
721741

722742
# Stack the timestamps from different "counts"
723-
t = np.hstack(times_c_ip)
724-
# Append current particle
725-
times_chunk_p.append(t)
726-
par_index_chunk_p.append(np.full(t.size, ip + ip_start, dtype='u1'))
727-
if save_pos:
728-
# TODO compute and asve positions
729-
pass
743+
ts = np.hstack(ts_times_by_num_counts)
744+
ts_times_parlist.append(ts)
745+
ts_particles_parlist.append(np.full(ts.size, ip, dtype='u1'))
746+
if position is not None:
747+
# concatenate 2D arrays by columns
748+
pos_current_particle = np.hstack(ts_positions_by_num_counts)
749+
ts_positions_parlist.append(pos_current_particle)
730750

731751
# Merge the arrays of different particles
732-
times_chunk = np.hstack(times_chunk_p)
733-
par_index_chunk = np.hstack(par_index_chunk_p)
734-
positions_chunk = np.zeros(
735-
(times_chunk.shape[-1], self.positions.shape[1]), dtype=np.float32)
752+
ts_times = np.hstack(ts_times_parlist)
753+
ts_particles = np.hstack(ts_particles_parlist)
754+
# ts_positions are 2D, concatenate columns
755+
# (rows are spatial coordinates)
756+
ts_positions = np.hstack(ts_positions_parlist)
736757

737758
if sort:
738-
# Sort timestamps inside the merged chunk
739-
index_sort = times_chunk.argsort(kind='mergesort')
740-
times_chunk = times_chunk[index_sort]
741-
par_index_chunk = par_index_chunk[index_sort]
742-
if save_pos:
743-
# TODO: reorder positions
744-
pass
759+
# Sort merged timestamps (from all particles)
760+
index_sort = ts_times.argsort(kind='mergesort')
761+
ts_times = ts_times[index_sort]
762+
ts_particles = ts_particles[index_sort]
763+
if position is not None:
764+
ts_positions = ts_positions[index_sort]
745765

746-
return times_chunk, par_index_chunk, positions_chunk
766+
return ts_times, ts_particles, ts_positions
747767

748768
def _sim_timestamps_populations(self, emission, max_rates, populations,
749-
bg_rates, i_start, rs, scale=10,
750-
save_pos=False):
769+
bg_rates, i_start, rs,
770+
position=None, scale=10):
771+
"""Simulate timestamps for all the populations of particles.
772+
773+
This method simulates timestamps for a time-chunk starting at
774+
the trajectory index `i_start` and ending at
775+
`i_start + emission.shape[1]`.
776+
777+
Arguments:
778+
emission (array): 2D array of normalized emission rates
779+
(max emission is 1).
780+
Each row is a particle and each column a time step.
781+
This is emission is for a time-slice starting at `i_start`
782+
in the full trajectory.
783+
max_rates (list): list of max emission rates in Hz for each
784+
population.
785+
populations (list of 2-elemnt tuples): list of populations. Each
786+
population is define as a slice. For example,
787+
slice(4, 7) is a population with particles 4, 5, and 6.
788+
Particle IDs start at 0.
789+
i_start (int): index in the full trajectory where the passed
790+
`emission` array starts.
791+
scale (int): factor to convert a time index to timestamps.
792+
For example, if a simulation has a time-step of
793+
500 nm, and scale = 10, the timestamps will increment in
794+
units of 50 ns.
795+
positions (None or array): array of shape
796+
`(num_particles, num_spatial_dims, num_time_bins)` containing
797+
particle positions for the same time chunk covered by
798+
the `emission` array.
799+
800+
Returns:
801+
3 arrays for the current time-chunk:
802+
- `ts_times`: timestamps with unit `t_step / scale`
803+
- `ts_particles`: particle IDs for each timestamp
804+
- `ts_positions`: particle position for each timestamp
805+
"""
751806
if populations is None:
752807
populations = [slice(0, self.num_particles)]
808+
save_pos = position is not None
753809

754-
# Loop for each population
755-
ts_chunk_pop_list, par_index_chunk_pop_list = [], []
756-
positions_chunk_pop_list = []
810+
times = (i_start + np.arange(emission.shape[1], dtype='int64')) * scale
811+
812+
# These lists will contain one array per population
813+
ts_times_poplist = []
814+
ts_particles_poplist = []
815+
ts_positions_poplist = []
757816
# Loop through populations
758-
for rate, pop, bg in zip(max_rates, populations, bg_rates):
817+
for max_rate, pop, bg_rate in zip(max_rates, populations, bg_rates):
759818
emission_pop = emission[pop]
760-
ts_chunk_pop, par_index_chunk_pop, positions_pop = \
761-
self._sim_timestamps(
762-
rate, bg, emission_pop, i_start, ip_start=pop.start,
763-
rs=rs, scale=scale, sort=False, save_pos=save_pos)
764-
765-
ts_chunk_pop_list.append(ts_chunk_pop)
766-
par_index_chunk_pop_list.append(par_index_chunk_pop)
819+
position_pop = position[pop] if save_pos else None
820+
counts_pop = sim_counts_timetrace_with_bg(
821+
emission_pop, max_rate, bg_rate, self.t_step, rs=rs)
822+
ts_times_pop, ts_particles_pop, ts_positions_pop = \
823+
self._timestamps_from_counts(
824+
counts_pop, times, max_rate=max_rate, bg_rate=bg_rate,
825+
sort=False, position=position_pop)
826+
ts_particles_pop += pop.start
827+
ts_times_poplist.append(ts_times_pop)
828+
ts_particles_poplist.append(ts_particles_pop)
767829
if save_pos:
768-
positions_chunk_pop_list.append(positions_pop)
830+
ts_positions_poplist.append(ts_positions_pop)
769831

770832
# Merge populations
771-
times_chunk_s = np.hstack(ts_chunk_pop_list)
772-
par_index_chunk_s = np.hstack(par_index_chunk_pop_list)
833+
ts_times = np.hstack(ts_times_poplist)
834+
ts_particles = np.hstack(ts_particles_poplist)
773835
if save_pos:
774-
positions_chunk_s = np.hstack(positions_chunk_pop_list)
836+
ts_positions = np.hstack(ts_positions_poplist)
837+
assert ts_positions.shape[-1] == ts_times.shape[0]
775838

776-
# Sort timestamps inside the merged chunk
777-
index_sort = times_chunk_s.argsort(kind='mergesort')
778-
times_chunk_s = times_chunk_s[index_sort]
779-
par_index_chunk_s = par_index_chunk_s[index_sort]
839+
# Sort the merged timestamps (from all populations)
840+
index_sort = ts_times.argsort(kind='mergesort')
841+
ts_times = ts_times[index_sort]
842+
ts_particles = ts_particles[index_sort]
780843
if save_pos:
781-
positions_chunk_s = positions_chunk_s[index_sort]
782-
return times_chunk_s, par_index_chunk_s, positions_chunk_s
844+
ts_positions = ts_positions[:, index_sort]
845+
return ts_times, ts_particles, ts_positions
783846

784847
def simulate_timestamps_mix(self, max_rates, populations, bg_rate,
785848
rs=None, seed=1, chunksize=2**16,
786849
comp_filter=None, overwrite=False,
787850
skip_existing=False, scale=10, save_pos=False,
788851
path=None, t_chunksize=None, timeslice=None):
789-
"""Compute one timestamps array for a mixture of N populations.
852+
"""Compute a timestamps array for a mixture of N populations.
790853
791854
Timestamp data are saved to disk and accessible as pytables arrays in
792855
`._timestamps` and `._tparticles`.
@@ -860,6 +923,7 @@ def simulate_timestamps_mix(self, max_rates, populations, bg_rate,
860923
bg_rates = [None] * (len(max_rates) - 1) + [bg_rate]
861924
prev_time = 0
862925
# Loop through time and for each time-slice simulate all populations
926+
pos_chunk = None
863927
for i_start, i_end in iter_chunk_index(timeslice_size, t_chunksize):
864928

865929
curr_time = np.around(i_start * self.t_step, decimals=0)
@@ -868,22 +932,24 @@ def simulate_timestamps_mix(self, max_rates, populations, bg_rate,
868932
prev_time = curr_time
869933

870934
em_chunk = self.emission[:, i_start:i_end]
935+
if save_pos:
936+
pos_chunk = self.position[:, :, i_start:i_end]
871937

872-
times_chunk_s, par_index_chunk_s, positions_chunk_s = \
938+
ts_times_chunk, ts_particles_chunk, ts_positions_chunk = \
873939
self._sim_timestamps_populations(
874940
em_chunk, max_rates, populations, bg_rates, i_start,
875-
rs, scale, save_pos=save_pos)
941+
rs, scale, position=pos_chunk)
876942

877943
# Save sorted "photons" (suffix '_s')
878-
ts_list.append(times_chunk_s)
879-
part_list.append(par_index_chunk_s)
880-
pos_list.append(positions_chunk_s) # it may be a list of None
944+
ts_list.append(ts_times_chunk)
945+
part_list.append(ts_particles_chunk)
946+
pos_list.append(ts_positions_chunk) # it may be a list of None
881947

882948
for ts, part, pos in zip(ts_list, part_list, pos_list):
883949
self._timestamps.append(ts)
884950
self._tparticles.append(part)
885951
if save_pos:
886-
self._tpositions.append(pos)
952+
self._tpositions.append(pos.T)
887953

888954
# Save current random state so it can be resumed in the next session
889955
self.ts_group._v_attrs['last_random_state'] = rs.get_state()
@@ -1001,12 +1067,12 @@ def simulate_timestamps_mix_da(self, max_rates_d, max_rates_a,
10011067
times_chunk_s_d, par_index_chunk_s_d, _ = \
10021068
self._sim_timestamps_populations(
10031069
em_chunk, max_rates_d, populations, bg_rates_d, i_start,
1004-
rs, scale)
1070+
rs=rs, scale=scale)
10051071

10061072
times_chunk_s_a, par_index_chunk_s_a, _ = \
10071073
self._sim_timestamps_populations(
10081074
em_chunk, max_rates_a, populations, bg_rates_a, i_start,
1009-
rs, scale)
1075+
rs=rs, scale=scale)
10101076

10111077
# Save sorted timestamps (suffix '_s') and corresponding particles
10121078
self._timestamps_d.append(times_chunk_s_d)
@@ -1132,12 +1198,12 @@ def simulate_timestamps_mix_da_online(self, max_rates_d, max_rates_a,
11321198
times_chunk_s_d, par_index_chunk_s_d, _ = \
11331199
self._sim_timestamps_populations(
11341200
em_chunk, max_rates_d, populations, bg_rates_d, i_start,
1135-
rs, scale)
1201+
rs=rs, scale=scale)
11361202

11371203
times_chunk_s_a, par_index_chunk_s_a, _ = \
11381204
self._sim_timestamps_populations(
11391205
em_chunk, max_rates_a, populations, bg_rates_a, i_start,
1140-
rs, scale)
1206+
rs=rs, scale=scale)
11411207

11421208
# Save sorted timestamps (suffix '_s') and corresponding particles
11431209
self._timestamps_d.append(times_chunk_s_d)
@@ -1159,9 +1225,12 @@ def sim_timetrace(emission, max_rate, t_step):
11591225
return np.random.poisson(lam=emission_rates).astype(np.uint8)
11601226

11611227

1162-
def sim_timetrace_bg(emission, max_rate, bg_rate, t_step, rs=None):
1228+
def sim_counts_timetrace_with_bg(emission, max_rate, bg_rate, t_step, rs=None):
11631229
"""Draw random emitted photons from r.v. ~ Poisson(emission_rates).
11641230
1231+
Generate an array of counts on a binned time axis
1232+
for one or more particles. Optionally, adds a trace for background counts.
1233+
11651234
Arguments:
11661235
emission (2D array): array of normalized emission rates. One row per
11671236
particle (axis = 0). Columns are the different time steps.
@@ -1175,9 +1244,9 @@ def sim_timetrace_bg(emission, max_rate, bg_rate, t_step, rs=None):
11751244
11761245
Returns:
11771246
`counts` an 2D uint8 array of counts in each time bin, for each
1178-
particle. If `bg_rate` is None counts.shape == emission.shape.
1247+
particle. If `bg_rate` is None, then `counts.shape == emission.shape`.
11791248
Otherwise, `counts` has one row more than `emission` for storing
1180-
the constant Poisson background.
1249+
the background counts.
11811250
"""
11821251
if rs is None:
11831252
rs = np.random.RandomState()
@@ -1198,6 +1267,7 @@ def sim_timetrace_bg(emission, max_rate, bg_rate, t_step, rs=None):
11981267
counts[-1] = rs.poisson(lam=bg_rate * t_step, size=em.shape[1])
11991268
return counts
12001269

1270+
12011271
def sim_timetrace_bg2(emission, max_rate, bg_rate, t_step, rs=None):
12021272
"""Draw random emitted photons from r.v. ~ Poisson(emission_rates).
12031273

0 commit comments

Comments
 (0)