@@ -116,6 +116,196 @@ def test_set_contact_ids_rejects_wrong_size():
116116 probe .set_contact_ids (["a" , "b" , "c" ])
117117
118118
119+ def _make_probegroup (n_probes = 3 ):
120+ """Helper: build a ProbeGroup with device channel indices set."""
121+ probegroup = ProbeGroup ()
122+ nchan = 0
123+ for i in range (n_probes ):
124+ probe = generate_dummy_probe ()
125+ probe .move ([i * 100 , i * 80 ])
126+ n = probe .get_contact_count ()
127+ probe .set_device_channel_indices (np .arange (n ) + nchan )
128+ probe .set_contact_ids ([f"e{ j } " for j in range (nchan , nchan + n )])
129+ nchan += n
130+ probegroup .add_probe (probe )
131+ return probegroup
132+
133+
134+ def _make_probegroup_full (n_probes = 3 ):
135+ """Helper: build a ProbeGroup where **every** probe is added."""
136+ probegroup = ProbeGroup ()
137+ nchan = 0
138+ for i in range (n_probes ):
139+ probe = generate_dummy_probe ()
140+ probe .move ([i * 100 , i * 80 ])
141+ n = probe .get_contact_count ()
142+ probe .set_device_channel_indices (np .arange (n ) + nchan )
143+ probe .set_contact_ids ([f"e{ j } " for j in range (nchan , nchan + n )])
144+ probegroup .add_probe (probe )
145+ nchan += n
146+ return probegroup
147+
148+
149+ # ── copy() tests ────────────────────────────────────────────────────────────
150+
151+
152+ def test_copy_returns_new_object ():
153+ pg = _make_probegroup_full (2 )
154+ pg_copy = pg .copy ()
155+ assert pg_copy is not pg
156+ assert len (pg_copy .probes ) == len (pg .probes )
157+ for orig , copied in zip (pg .probes , pg_copy .probes ):
158+ assert orig is not copied
159+
160+
161+ def test_copy_preserves_positions ():
162+ pg = _make_probegroup_full (2 )
163+ pg_copy = pg .copy ()
164+ for orig , copied in zip (pg .probes , pg_copy .probes ):
165+ np .testing .assert_array_equal (orig .contact_positions , copied .contact_positions )
166+
167+
168+ def test_copy_preserves_device_channel_indices ():
169+ pg = _make_probegroup_full (2 )
170+ pg_copy = pg .copy ()
171+ np .testing .assert_array_equal (
172+ pg .get_global_device_channel_indices (),
173+ pg_copy .get_global_device_channel_indices (),
174+ )
175+
176+
177+ def test_copy_does_not_preserve_contact_ids ():
178+ """Probe.copy() intentionally does not copy contact_ids."""
179+ pg = _make_probegroup_full (2 )
180+ pg_copy = pg .copy ()
181+ # All contact_ids should be empty strings after copy
182+ assert all (cid == "" for cid in pg_copy .get_global_contact_ids ())
183+
184+
185+ def test_copy_is_independent ():
186+ """Mutating the copy must not affect the original."""
187+ pg = _make_probegroup_full (2 )
188+ original_positions = pg .probes [0 ].contact_positions .copy ()
189+ pg_copy = pg .copy ()
190+ pg_copy .probes [0 ].move ([999 , 999 ])
191+ np .testing .assert_array_equal (pg .probes [0 ].contact_positions , original_positions )
192+
193+
194+ # ── get_slice() tests ───────────────────────────────────────────────────────
195+
196+
197+ def test_get_slice_by_bool ():
198+ pg = _make_probegroup_full (2 )
199+ total = pg .get_contact_count ()
200+ sel = np .zeros (total , dtype = bool )
201+ sel [:5 ] = True # first 5 contacts from the first probe
202+ sliced = pg .get_slice (sel )
203+ assert sliced .get_contact_count () == 5
204+
205+
206+ def test_get_slice_by_index ():
207+ pg = _make_probegroup_full (2 )
208+ indices = np .array ([0 , 1 , 2 , 33 , 34 ]) # contacts from both probes
209+ sliced = pg .get_slice (indices )
210+ assert sliced .get_contact_count () == 5
211+
212+
213+ def test_get_slice_preserves_device_channel_indices ():
214+ pg = _make_probegroup_full (2 )
215+ indices = np .array ([0 , 1 , 2 ])
216+ sliced = pg .get_slice (indices )
217+ orig_chans = pg .get_global_device_channel_indices ()["device_channel_indices" ][:3 ]
218+ sliced_chans = sliced .get_global_device_channel_indices ()["device_channel_indices" ]
219+ np .testing .assert_array_equal (sliced_chans , orig_chans )
220+
221+
222+ def test_get_slice_preserves_positions ():
223+ pg = _make_probegroup_full (2 )
224+ indices = np .array ([0 , 1 , 2 ])
225+ sliced = pg .get_slice (indices )
226+ expected = pg .get_global_contact_positions ()[indices ]
227+ np .testing .assert_array_equal (sliced .get_global_contact_positions (), expected )
228+
229+
230+ def test_get_slice_empty_selection ():
231+ pg = _make_probegroup_full (2 )
232+ sliced = pg .get_slice (np .array ([], dtype = int ))
233+ assert sliced .get_contact_count () == 0
234+ assert len (sliced .probes ) == 0
235+
236+
237+ def test_get_slice_wrong_bool_size ():
238+ pg = _make_probegroup_full (2 )
239+ with pytest .raises (AssertionError ):
240+ pg .get_slice (np .array ([True , False ])) # wrong size
241+
242+
243+ def test_get_slice_out_of_bounds ():
244+ pg = _make_probegroup_full (2 )
245+ total = pg .get_contact_count ()
246+ with pytest .raises (AssertionError ):
247+ pg .get_slice (np .array ([total + 10 ]))
248+
249+
250+ def test_get_slice_all_contacts ():
251+ """Slicing with all contacts should give an equivalent ProbeGroup."""
252+ pg = _make_probegroup_full (2 )
253+ total = pg .get_contact_count ()
254+ sliced = pg .get_slice (np .arange (total ))
255+ assert sliced .get_contact_count () == total
256+ np .testing .assert_array_equal (
257+ sliced .get_global_contact_positions (),
258+ pg .get_global_contact_positions (),
259+ )
260+
261+
262+ # ── get_global_contact_positions() tests ────────────────────────────────────
263+
264+
265+ def test_get_global_contact_positions_shape ():
266+ pg = _make_probegroup_full (3 )
267+ pos = pg .get_global_contact_positions ()
268+ assert pos .shape == (pg .get_contact_count (), pg .ndim )
269+
270+
271+ def test_get_global_contact_positions_matches_per_probe ():
272+ pg = _make_probegroup_full (3 )
273+ pos = pg .get_global_contact_positions ()
274+ offset = 0
275+ for probe in pg .probes :
276+ n = probe .get_contact_count ()
277+ np .testing .assert_array_equal (pos [offset : offset + n ], probe .contact_positions )
278+ offset += n
279+
280+
281+ def test_get_global_contact_positions_single_probe ():
282+ pg = _make_probegroup_full (1 )
283+ pos = pg .get_global_contact_positions ()
284+ np .testing .assert_array_equal (pos , pg .probes [0 ].contact_positions )
285+
286+
287+ def test_get_global_contact_positions_3d ():
288+ pg = ProbeGroup ()
289+ for i in range (2 ):
290+ probe = generate_dummy_probe ().to_3d ()
291+ probe .move ([i * 100 , i * 80 , i * 30 ])
292+ pg .add_probe (probe )
293+ pos = pg .get_global_contact_positions ()
294+ assert pos .shape [1 ] == 3
295+ assert pos .shape [0 ] == pg .get_contact_count ()
296+
297+
298+ def test_get_global_contact_positions_reflects_move ():
299+ """Positions should reflect probe movement."""
300+ pg = ProbeGroup ()
301+ probe = generate_dummy_probe ()
302+ original_pos = probe .contact_positions .copy ()
303+ probe .move ([50 , 60 ])
304+ pg .add_probe (probe )
305+ pos = pg .get_global_contact_positions ()
306+ np .testing .assert_array_equal (pos , original_pos + np .array ([50 , 60 ]))
307+
308+
119309if __name__ == "__main__" :
120310 test_probegroup ()
121311 # ~ test_probegroup_3d()
0 commit comments