66import numpy as np
77
88
9- def test_probegroup ():
9+ @pytest .fixture
10+ def probegroup ():
11+ """Fixture: a ProbeGroup with 3 probes, each with device channel indices set."""
1012 probegroup = ProbeGroup ()
11-
1213 nchan = 0
1314 for i in range (3 ):
1415 probe = generate_dummy_probe ()
1516 probe .move ([i * 100 , i * 80 ])
1617 n = probe .get_contact_count ()
17- probe .set_device_channel_indices (np .arange (n )[::- 1 ] + nchan )
18- shank_ids = np .ones (n )
19- shank_ids [: n // 2 ] *= i * 2
20- shank_ids [n // 2 :] *= i * 2 + 1
21- probe .set_shank_ids (shank_ids )
18+ probe .set_device_channel_indices (np .arange (n ) + nchan )
2219 probegroup .add_probe (probe )
23-
2420 nchan += n
21+ return probegroup
2522
23+
24+ def test_probegroup (probegroup ):
2625 indices = probegroup .get_global_device_channel_indices ()
2726
2827 ids = probegroup .get_global_contact_ids ()
2928
3029 df = probegroup .to_dataframe ()
31- # ~ print(df['global_contact_ids'])
3230
3331 arr = probegroup .to_numpy (complete = False )
3432 other = ProbeGroup .from_numpy (arr )
@@ -38,12 +36,6 @@ def test_probegroup():
3836 d = probegroup .to_dict ()
3937 other = ProbeGroup .from_dict (d )
4038
41- # ~ from probeinterface.plotting import plot_probe_group, plot_probe
42- # ~ import matplotlib.pyplot as plt
43- # ~ plot_probe_group(probegroup)
44- # ~ plot_probe_group(other)
45- # ~ plt.show()
46-
4739 # checking automatic generation of ids with new dummy probes
4840 probegroup .probes = []
4941 for i in range (3 ):
@@ -116,6 +108,152 @@ def test_set_contact_ids_rejects_wrong_size():
116108 probe .set_contact_ids (["a" , "b" , "c" ])
117109
118110
111+ # ── get_global_contact_positions() tests ────────────────────────────────────
112+
113+
114+ def test_get_global_contact_positions_shape (probegroup ):
115+ pos = probegroup .get_global_contact_positions ()
116+ assert pos .shape == (probegroup .get_contact_count (), probegroup .ndim )
117+
118+
119+ def test_get_global_contact_positions_matches_per_probe (probegroup ):
120+ pos = probegroup .get_global_contact_positions ()
121+ offset = 0
122+ for probe in probegroup .probes :
123+ n = probe .get_contact_count ()
124+ np .testing .assert_array_equal (pos [offset : offset + n ], probe .contact_positions )
125+ offset += n
126+
127+
128+ def test_get_global_contact_positions_single_probe (probegroup ):
129+ pos = probegroup .get_global_contact_positions ()
130+ np .testing .assert_array_equal (
131+ pos [: probegroup .probes [0 ].get_contact_count ()], probegroup .probes [0 ].contact_positions
132+ )
133+
134+
135+ def test_get_global_contact_positions_3d ():
136+ pg = ProbeGroup ()
137+ for i in range (2 ):
138+ probe = generate_dummy_probe ().to_3d ()
139+ probe .move ([i * 100 , i * 80 , i * 30 ])
140+ pg .add_probe (probe )
141+ pos = pg .get_global_contact_positions ()
142+ assert pos .shape [1 ] == 3
143+ assert pos .shape [0 ] == pg .get_contact_count ()
144+
145+
146+ def test_get_global_contact_positions_reflects_move ():
147+ """Positions should reflect probe movement."""
148+ pg = ProbeGroup ()
149+ probe = generate_dummy_probe ()
150+ original_pos = probe .contact_positions .copy ()
151+ probe .move ([50 , 60 ])
152+ pg .add_probe (probe )
153+ pos = pg .get_global_contact_positions ()
154+ np .testing .assert_array_equal (pos , original_pos + np .array ([50 , 60 ]))
155+
156+
157+ # ── copy() tests ────────────────────────────────────────────────────────────
158+
159+
160+ def test_copy_returns_new_object (probegroup ):
161+ pg_copy = probegroup .copy ()
162+ assert pg_copy is not probegroup
163+ assert len (pg_copy .probes ) == len (probegroup .probes )
164+ for orig , copied in zip (probegroup .probes , pg_copy .probes ):
165+ assert orig is not copied
166+
167+
168+ def test_copy_preserves_positions (probegroup ):
169+ pg_copy = probegroup .copy ()
170+ for orig , copied in zip (probegroup .probes , pg_copy .probes ):
171+ np .testing .assert_array_equal (orig .contact_positions , copied .contact_positions )
172+
173+
174+ def test_copy_preserves_device_channel_indices (probegroup ):
175+ pg_copy = probegroup .copy ()
176+ np .testing .assert_array_equal (
177+ probegroup .get_global_device_channel_indices (),
178+ pg_copy .get_global_device_channel_indices (),
179+ )
180+
181+
182+ def test_copy_does_not_preserve_contact_ids (probegroup ):
183+ """Probe.copy() intentionally does not copy contact_ids."""
184+ pg_copy = probegroup .copy ()
185+ # All contact_ids should be empty strings after copy
186+ assert all (cid == "" for cid in pg_copy .get_global_contact_ids ())
187+
188+
189+ def test_copy_is_independent (probegroup ):
190+ """Mutating the copy must not affect the original."""
191+ original_positions = probegroup .probes [0 ].contact_positions .copy ()
192+ pg_copy = probegroup .copy ()
193+ pg_copy .probes [0 ].move ([999 , 999 ])
194+ np .testing .assert_array_equal (probegroup .probes [0 ].contact_positions , original_positions )
195+
196+
197+ # ── get_slice() tests ───────────────────────────────────────────────────────
198+
199+
200+ def test_get_slice_by_bool (probegroup ):
201+ total = probegroup .get_contact_count ()
202+ sel = np .zeros (total , dtype = bool )
203+ sel [:5 ] = True # first 5 contacts from the first probe
204+ sliced = probegroup .get_slice (sel )
205+ assert sliced .get_contact_count () == 5
206+
207+
208+ def test_get_slice_by_index (probegroup ):
209+ indices = np .array ([0 , 1 , 2 , 33 , 34 ]) # contacts from both probes
210+ sliced = probegroup .get_slice (indices )
211+ assert sliced .get_contact_count () == 5
212+
213+
214+ def test_get_slice_preserves_device_channel_indices (probegroup ):
215+ indices = np .array ([0 , 1 , 2 ])
216+ sliced = probegroup .get_slice (indices )
217+ orig_chans = probegroup .get_global_device_channel_indices ()["device_channel_indices" ][:3 ]
218+ sliced_chans = sliced .get_global_device_channel_indices ()["device_channel_indices" ]
219+ np .testing .assert_array_equal (sliced_chans , orig_chans )
220+
221+
222+ def test_get_slice_preserves_positions (probegroup ):
223+ indices = np .array ([0 , 1 , 2 ])
224+ sliced = probegroup .get_slice (indices )
225+ expected = probegroup .get_global_contact_positions ()[indices ]
226+ np .testing .assert_array_equal (sliced .get_global_contact_positions (), expected )
227+
228+
229+ def test_get_slice_empty_selection (probegroup ):
230+ sliced = probegroup .get_slice (np .array ([], dtype = int ))
231+ assert sliced .get_contact_count () == 0
232+ assert len (sliced .probes ) == 0
233+
234+
235+ def test_get_slice_wrong_bool_size (probegroup ):
236+ with pytest .raises (AssertionError ):
237+ probegroup .get_slice (np .array ([True , False ])) # wrong size
238+
239+
240+ def test_get_slice_out_of_bounds (probegroup ):
241+ total = probegroup .get_contact_count ()
242+ with pytest .raises (AssertionError ):
243+ probegroup .get_slice (np .array ([total + 10 ]))
244+
245+
246+ def test_get_slice_all_contacts (probegroup ):
247+ """Slicing with all contacts should give an equivalent ProbeGroup."""
248+ total = probegroup .get_contact_count ()
249+ sliced = probegroup .get_slice (np .arange (total ))
250+ assert sliced .get_contact_count () == total
251+ np .testing .assert_array_equal (
252+ sliced .get_global_contact_positions (),
253+ probegroup .get_global_contact_positions (),
254+ )
255+
256+
119257if __name__ == "__main__" :
120258 test_probegroup ()
121259 # ~ test_probegroup_3d()
0 commit comments