-
Notifications
You must be signed in to change notification settings - Fork 96
Expand file tree
/
Copy pathtest_fetch.py
More file actions
508 lines (391 loc) · 17 KB
/
test_fetch.py
File metadata and controls
508 lines (391 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
"""Tests for the modern fetch API: to_dicts, to_pandas, to_arrays, keys, fetch1"""
import decimal
import itertools
import os
import shutil
from operator import itemgetter
import numpy as np
import pandas
import pytest
import datajoint as dj
from tests import schema
def test_getattribute(subject):
    """Fetching named attributes must agree with fetching whole rows.

    Projected rows and ``keys()`` must yield the same primary keys, and the
    arrays returned by ``to_arrays(*attrs)`` must hold the same values
    (order-insensitive) as the matching columns of a full ``to_arrays()``.
    """
    by_id = itemgetter("subject_id")
    projected = sorted(subject.proj().to_dicts(), key=by_id)
    primary_keys = sorted(subject.keys(), key=by_id)
    for proj_row, pk_row in zip(projected, primary_keys):
        assert proj_row == pk_row, "Primary key is not returned correctly"

    full = subject.to_arrays(order_by="subject_id")
    notes, real_ids = subject.to_arrays("subject_notes", "real_id")
    for fetched, column in ((notes, "subject_notes"), (real_ids, "real_id")):
        np.testing.assert_array_equal(sorted(fetched), sorted(full[column]))
def test_getattribute_for_fetch1(subject):
    """``fetch1`` returns a scalar for one attribute and a tuple for several."""
    subject_10 = subject & "subject_id=10"
    assert subject_10.fetch1("subject_id") == 10
    assert subject_10.fetch1("subject_id", "species") == (10, "monkey")
def test_order_by(lang, languages):
    """Tests order_by sorting order for every ASC/DESC combination.

    ``languages`` holds ``(name, language)`` tuples. Two stable sorts, first
    on the secondary key (language) and then on the primary key (name),
    reproduce ``ORDER BY name <dir>, language <dir>``.
    """
    for ord_name, ord_lang in itertools.product(*2 * [["ASC", "DESC"]]):
        cur = lang.to_arrays(order_by=("name " + ord_name, "language " + ord_lang))
        languages.sort(key=itemgetter(1), reverse=ord_lang == "DESC")
        languages.sort(key=itemgetter(0), reverse=ord_name == "DESC")
        # zip() truncates silently, so a short result must fail explicitly
        assert len(cur) == len(languages), "Length is not correct"
        for row, expected in zip(cur, languages):
            # Use a real reduction: np.all(<generator>) evaluates to the
            # (always truthy) generator object and never checked anything.
            assert all(
                got == want for got, want in zip(row, expected)
            ), "Sorting order is different"
def test_order_by_default(lang, languages):
    """An ``order_by`` attribute without an explicit direction sorts ascending."""
    rows = lang.to_arrays(order_by=("language", "name DESC"))
    # Expected order via two stable sorts: the secondary key first
    # (name, descending), then the primary key (language, ascending).
    languages.sort(key=itemgetter(0), reverse=True)
    languages.sort(key=itemgetter(1))
    for row, expected in zip(rows, languages):
        matches = [got == want for got, want in zip(row, expected)]
        assert np.all(matches), "Sorting order is different"
def test_limit(lang):
    """The ``limit`` keyword caps the number of fetched rows."""
    max_rows = 4
    fetched = lang.to_arrays(limit=max_rows)
    assert len(fetched) == max_rows, "Length is not correct"
def test_order_by_limit(lang, languages):
    """``limit`` combined with ``order_by`` keeps the first rows of the sorted result."""
    n_rows = 4
    fetched = lang.to_arrays(limit=n_rows, order_by=["language", "name DESC"])
    # Expected order: language ascending, then name descending (stable sorts).
    languages.sort(key=itemgetter(0), reverse=True)
    languages.sort(key=itemgetter(1), reverse=False)
    assert len(fetched) == n_rows, "Length is not correct"
    for row, expected in zip(fetched, languages[:n_rows]):
        assert np.all([a == b for a, b in zip(row, expected)]), "Sorting order is different"
def test_head_tail(schema_any):
    """``head(n)`` and ``tail(n)`` return ``n`` rows as a list of dicts."""
    query = schema.User * schema.Language
    for fetch_rows, n in ((query.head, 5), (query.tail, 4)):
        rows = fetch_rows(n)
        assert isinstance(rows, list)
        assert len(rows) == n
        assert all(isinstance(row, dict) for row in rows)
def test_limit_offset(lang, languages):
    """``offset`` skips rows of the sorted result, then ``limit`` caps what remains."""
    skip, take = 2, 4
    fetched = lang.to_arrays(offset=skip, limit=take, order_by=["language", "name DESC"])
    # Expected order: language ascending, then name descending (stable sorts).
    languages.sort(key=itemgetter(0), reverse=True)
    languages.sort(key=itemgetter(1), reverse=False)
    assert len(fetched) == take, "Length is not correct"
    window = languages[skip : skip + take]
    for row, expected in zip(fetched, window):
        assert np.all([a == b for a, b in zip(row, expected)]), "Sorting order is different"
def test_iter(lang, languages):
    """Test that ``to_dicts(order_by=...)`` yields one dict per row in sorted order.

    Each returned dict is compared against the ``(name, language)`` tuples of
    ``languages``, sorted by language ascending and then name descending.
    """
    # Two stable sorts: secondary key (name, descending) first, then primary key.
    expected = sorted(languages, key=itemgetter(0), reverse=True)
    expected.sort(key=itemgetter(1))
    result = list(lang.to_dicts(order_by=["language", "name DESC"]))
    # zip() truncates silently; an empty or short result must not pass vacuously
    assert len(result) == len(expected), "Length is not correct"
    for row, (tname, tlang) in zip(result, expected):
        assert row["name"] == tname and row["language"] == tlang, "Values are not the same"
def test_keys(lang, languages):
    """Test that ``keys()`` returns primary-key dicts ordered like ``to_arrays``.

    ``languages`` holds the table's ``(name, language)`` tuples. The expected
    order (language ascending, then name descending) is built with two stable
    sorts and checked against the fetched attribute rows.
    """
    expected = languages.copy()
    expected.sort(key=itemgetter(0), reverse=True)
    expected.sort(key=itemgetter(1), reverse=False)
    # Use to_arrays for attribute fetch
    cur = lang.to_arrays("name", "language", order_by=("language", "name DESC"))
    # Use keys() for primary key fetch
    cur2 = list(lang.keys(order_by=["language", "name DESC"]))
    rows = list(zip(*cur))
    # The expected order used to be computed but never compared; zip() would
    # also hide a short result, so pin the lengths explicitly.
    assert len(rows) == len(cur2) == len(expected), "Length is not correct"
    for c, c2, exp in zip(rows, cur2, expected):
        assert c == exp, "Sorting order is different"
        assert c == tuple(c2.values()), "Values are not the same"
def test_fetch1_step1(lang, languages):
    """``fetch1`` returns the single row matching a full primary key.

    Also checks that the table's ``contents`` match the ``languages`` fixture,
    whose last entry is the row fetched here.
    """
    expected_contents = [
        ("Fabian", "English"),
        ("Edgar", "English"),
        ("Dimitri", "English"),
        ("Dimitri", "Ukrainian"),
        ("Fabian", "German"),
        ("Edgar", "Japanese"),
    ]
    assert (
        lang.contents == languages == expected_contents
    ), "Unexpected contents in Language table"

    key = {"name": "Edgar", "language": "Japanese"}
    restricted = lang & key
    row = restricted.fetch1()
    # Every value of fetch1() must equal the fixture tuple and fetch1(attr).
    for expected, (attr, value) in zip(languages[-1], row.items()):
        assert expected == value == restricted.fetch1(attr), "Values are not the same"
def test_misspelled_attribute(schema_any):
    """Restricting by an unknown attribute name raises ``DataJointError``."""
    bad_restriction = 'lang = "ENGLISH"'
    with pytest.raises(dj.DataJointError):
        (schema.Language & bad_restriction).to_dicts()
def test_to_dicts(lang):
    """Every element produced by ``to_dicts`` is a ``dict``."""
    rows = lang.to_dicts()
    assert all(isinstance(row, dict) for row in rows)
def test_offset(lang, languages):
    """``offset=1`` skips the first sorted row before ``limit`` is applied."""
    fetched = lang.to_arrays(limit=4, offset=1, order_by=["language", "name DESC"])
    # Expected order: language ascending, then name descending (stable sorts).
    languages.sort(key=itemgetter(0), reverse=True)
    languages.sort(key=itemgetter(1), reverse=False)
    assert len(fetched) == 4, "Length is not correct"
    for row, expected in zip(fetched, languages[1:5]):
        assert np.all([a == b for a, b in zip(row, expected)]), "Sorting order is different"
def test_len(lang):
    """``len(table)`` equals the number of rows fetched by ``to_arrays``."""
    fetched_count = len(lang.to_arrays())
    assert fetched_count == len(lang), "__len__ is not behaving properly"
def test_fetch1_step2(lang):
    """``fetch1()`` on a table with several rows raises ``DataJointError``."""
    pytest.raises(dj.DataJointError, lang.fetch1)
def test_fetch1_step3(lang):
    """``fetch1(attr)`` on a table with several rows raises ``DataJointError``."""
    pytest.raises(dj.DataJointError, lang.fetch1, "name")
def test_decimal(schema_any):
    """Decimal primary keys can be fetched and reused as restrictions (see issue #334)."""
    table = schema.DecimalPrimaryKey()
    assert len(table.to_arrays()), "Table DecimalPrimaryKey contents are empty"
    table.insert1([decimal.Decimal("3.1415926")])

    # A fetched record restricts the table to exactly one row.
    records = table.to_arrays()
    assert len(records) > 0
    assert len(table & records[0]) == 1

    # The same holds for a primary-key dict from keys().
    primary_keys = table.keys()
    assert len(primary_keys) >= 2
    assert len(table & primary_keys[1]) == 1
def test_nullable_numbers(schema_any):
    """Numeric columns can hold a mixture of values and NULLs."""
    table = schema.NullableNumbers()
    random_rows = (
        (idx, np.random.randn(), np.random.randint(-1000, 1000), np.random.randn())
        for idx in range(10)
    )
    table.insert(random_rows)
    table.insert1((100, None, None, None))

    def _is_missing(value):
        # A nullable float may come back as None or as NaN.
        return value is None or (isinstance(value, float) and np.isnan(value))

    fvalues, dvalues, ivalues = table.to_arrays("fvalue", "dvalue", "ivalue")
    assert None in ivalues
    assert any(_is_missing(v) for v in dvalues)
    assert any(_is_missing(v) for v in fvalues)
def test_to_pandas(subject):
    """``to_pandas`` returns a DataFrame indexed by the table's primary key."""
    frame = subject.to_pandas(order_by="subject_id")
    assert isinstance(frame, pandas.DataFrame)
    index_names = frame.index.names
    assert index_names == subject.primary_key
def test_to_polars(subject):
    """``to_polars`` returns a ``polars.DataFrame``; skipped when polars is unavailable."""
    pl = pytest.importorskip("polars")
    frame = subject.to_polars()
    assert isinstance(frame, pl.DataFrame)
def test_to_arrow(subject):
    """``to_arrow`` returns a ``pyarrow.Table``; skipped when pyarrow is unavailable."""
    pa = pytest.importorskip("pyarrow")
    arrow_table = subject.to_arrow()
    assert isinstance(arrow_table, pa.Table)
def test_same_secondary_attribute(schema_any):
    """Joining Child with Parent's primary-key projection yields one child named "Dan"."""
    joined = schema.Child * schema.Parent().proj()
    names = joined.to_arrays()["name"]
    assert len(names) == 1
    assert names[0] == "Dan"
def test_query_caching(schema_any):
    """Test query caching with to_arrays.

    While a query cache is active, reads come from the cache and inserts are
    rejected with ``DataJointError``. Disabling the cache exposes rows
    inserted afterwards, and re-enabling it returns the earlier cached result.
    Caching is always disabled and the cache directory removed afterwards,
    even if an assertion fails, so later tests are not affected.
    """
    cache_dir = os.path.expanduser("~/dj_query_cache")
    # initialize cache directory
    os.makedirs(cache_dir, exist_ok=True)
    try:
        with dj.config.override(query_cache=cache_dir):
            conn = schema.TTest3.connection
            try:
                # insert sample data and load cache
                schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)])
                conn.set_query_cache(query_cache="main")
                cached_res = schema.TTest3().to_arrays()
                # inserts must be refused while query caching is enabled
                with pytest.raises(dj.DataJointError):
                    schema.TTest3.insert([dict(key=200 + i, value=400 + i) for i in range(2)])
                conn.set_query_cache()
                # insert new data
                schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)])
                # re-enable cache to access old results
                conn.set_query_cache(query_cache="main")
                previous_cache = schema.TTest3().to_arrays()
                # verify properly cached (zip() alone would hide a length mismatch)
                assert len(previous_cache) == len(cached_res)
                assert all(c == p for c, p in zip(cached_res, previous_cache))
                # disabling the cache refreshes results
                conn.set_query_cache()
                uncached_res = schema.TTest3().to_arrays()
                assert len(uncached_res) > len(cached_res)
                # purge query cache
                conn.purge_query_cache()
            finally:
                # never leave caching enabled, even when an assertion failed
                conn.set_query_cache()
    finally:
        # reset cache directory state
        shutil.rmtree(cache_dir, ignore_errors=True)
def test_fetch_group_by(schema_any):
    """``keys()`` ordered by a non-key attribute returns only primary-key dicts.

    https://github.com/datajoint/datajoint-python/issues/914
    """
    expected_keys = [{"parent_id": 1}]
    fetched_keys = schema.Parent().keys(order_by="name")
    assert fetched_keys == expected_keys
def test_dj_u_distinct(schema_any):
    """
    Restricting ``dj.U`` by a table returns each distinct attribute combination once.

    Written to check that dropping DISTINCT from generated SELECT statements
    does not break the dj.U universal set implementation.
    """
    schema.Stimulus.insert([(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)])
    unique_pairs = dj.U("contrast", "brightness") & schema.Stimulus()
    fetched = unique_pairs.to_dicts(order_by=("contrast", "brightness"))
    # Remove the inserted rows before asserting so a failure leaves nothing behind.
    schema.Stimulus.delete_quick()
    # Rows 1 and 2 share a (contrast, brightness) pair and must collapse to one.
    assert fetched == [
        {"contrast": 2, "brightness": 3},
        {"contrast": 3, "brightness": 2},
        {"contrast": 5, "brightness": 5},
    ]
def test_backslash(schema_any):
    """Test that a value containing a backslash round-trips through restriction.

    https://github.com/datajoint/datajoint-python/issues/999
    """
    expected = "She\\Hulk"
    schema.Parent.insert([(2, expected)])
    try:
        q = schema.Parent & dict(name=expected)
        assert q.fetch1("name") == expected
    finally:
        # Clean up even when the assertion fails. Restrict by the primary key
        # instead of reusing ``q``: if backslash escaping is broken, ``q``
        # matches nothing and the inserted row would otherwise be left behind.
        (schema.Parent & dict(parent_id=2)).delete()
def test_lazy_iteration(lang, languages):
    """Iterating a table returns a generator whose items are row dicts."""
    import types

    rows = iter(lang)
    assert isinstance(rows, types.GeneratorType)
    head = next(rows)
    assert isinstance(head, dict)
    assert "name" in head and "language" in head
def test_to_arrays_include_key(lang, languages):
    """``include_key=True`` prepends a list of primary-key dicts to the arrays."""
    keys, names, langs = lang.to_arrays("name", "language", include_key=True, order_by="KEY")

    # Keys: a list of dicts holding exactly the primary-key columns.
    primary_key = {"name", "language"}
    assert isinstance(keys, list)
    assert all(isinstance(k, dict) for k in keys)
    assert all(set(k) == primary_key for k in keys)

    # Requested attributes come back as numpy arrays of matching length.
    for column in (names, langs):
        assert isinstance(column, np.ndarray)
    assert len(keys) == len(names) == len(langs) == len(languages)

    # Each key agrees with the attribute values in the same position.
    for key, name, language in zip(keys, names, langs):
        assert key["name"] == name
        assert key["language"] == language

    # A returned key restricts the table to exactly its own row.
    sample_key = keys[0]
    single_row = lang & sample_key
    assert len(single_row) == 1
    assert single_row.fetch1("name") == sample_key["name"]
def test_to_arrays_include_key_single_attr(subject):
    """``include_key=True`` also works when only one attribute is requested."""
    keys, species = subject.to_arrays("species", include_key=True)
    assert isinstance(keys, list)
    assert isinstance(species, np.ndarray)
    assert len(keys) == len(species)
    # Every key carries the primary-key column subject_id.
    assert all("subject_id" in key for key in keys)
def test_to_arrays_without_include_key(lang):
    """Without ``include_key``, ``to_arrays`` returns just the requested arrays."""
    result = lang.to_arrays("name", "language")
    # A plain 2-tuple of numpy arrays, with no leading list of keys.
    assert isinstance(result, tuple)
    assert len(result) == 2
    assert all(isinstance(column, np.ndarray) for column in result)
def test_to_arrays_inhomogeneous_shapes(schema_any):
    """``to_arrays`` returns blobs of differing first-axis-compatible shapes unmerged.

    Regression test for https://github.com/datajoint/datajoint-python/issues/1380
    """
    shapes = [(100,), (100, 1), (100, 2)]
    table = schema.Longblob()
    table.delete()
    # Shapes that numpy would otherwise try to broadcast into one array.
    table.insert(
        [{"id": idx, "data": np.random.randn(*shape)} for idx, shape in enumerate(shapes)]
    )

    # This must not raise ValueError.
    data = table.to_arrays("data", order_by="id")

    # One object-dtype element per row, each keeping its original shape.
    assert data.dtype == object
    assert len(data) == len(shapes)
    for blob, shape in zip(data, shapes):
        assert blob.shape == shape
def test_to_arrays_inhomogeneous_shapes_second_axis(schema_any):
    """``to_arrays`` returns blobs that differ along the second axis unmerged.

    Regression test for https://github.com/datajoint/datajoint-python/issues/1380
    """
    shapes = [(100,), (1, 100), (2, 100)]
    table = schema.Longblob()
    table.delete()
    # Arrays whose shapes differ on the second axis.
    table.insert(
        [{"id": idx, "data": np.random.randn(*shape)} for idx, shape in enumerate(shapes)]
    )

    # This must not raise ValueError.
    data = table.to_arrays("data", order_by="id")

    # One object-dtype element per row, each keeping its original shape.
    assert data.dtype == object
    assert len(data) == len(shapes)
    for blob, shape in zip(data, shapes):
        assert blob.shape == shape
def test_fetch_KEY(lang, languages):
    """Legacy ``fetch('KEY')`` returns a list of primary-key dicts.

    Regression test for https://github.com/datajoint/datajoint-python/issues/1381
    """
    import warnings

    # Silence the DeprecationWarning emitted by the legacy fetch() call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        keys = lang.fetch("KEY")

    assert isinstance(keys, list)
    assert len(keys) == len(languages)
    assert all(isinstance(k, dict) for k in keys)
    # The Language primary key is (name, language).
    assert all(set(k.keys()) == {"name", "language"} for k in keys)
def test_fetch1_KEY(lang):
    """``fetch1('KEY')`` returns the single row's primary key as a dict.

    Regression test for https://github.com/datajoint/datajoint-python/issues/1381
    """
    expected_key = {"name": "Edgar", "language": "Japanese"}
    fetched_key = (lang & expected_key).fetch1("KEY")
    assert isinstance(fetched_key, dict)
    assert fetched_key == expected_key
def test_fetch_KEY_with_other_attrs(lang):
    """Legacy ``fetch('KEY', 'name')`` returns ``(list_of_key_dicts, name_array)``.

    Regression test for https://github.com/datajoint/datajoint-python/issues/1381
    """
    import warnings

    # Silence the DeprecationWarning emitted by the legacy fetch() call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        keys, names = lang.fetch("KEY", "name")

    assert isinstance(keys, list)
    assert all(isinstance(k, dict) for k in keys)
    assert isinstance(names, np.ndarray)
    assert len(keys) == len(names)