Skip to content

Commit fbab54a

Browse files
author
Johannes Otepka
committed
test code for labels added
1 parent 07ccc38 commit fbab54a

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

ipyparallel/tests/test_label.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Tests for task label functionality"""
2+
3+
# Copyright (c) IPython Development Team.
4+
# Distributed under the terms of the Modified BSD License.
5+
import logging
6+
import os
7+
from unittest import TestCase
8+
9+
import pytest
10+
11+
import ipyparallel as ipp
12+
from ipyparallel.cluster.launcher import LocalControllerLauncher
13+
14+
15+
def speudo_wait(t):
16+
import time
17+
18+
tic = time.time()
19+
print(f"waiting for {t}s...")
20+
# time.sleep(t) # do NOT wait for t seconds to speed up tests
21+
print("done")
22+
return time.time() - tic
23+
24+
25+
class TaskLabelTest:
26+
def setUp(self):
27+
self.cluster = ipp.Cluster(
28+
n=2, log_level=10, controller=self.get_controller_launcher()
29+
)
30+
self.cluster.start_cluster_sync()
31+
32+
self.rc = self.cluster.connect_client_sync()
33+
self.rc.wait_for_engines(n=2)
34+
35+
def get_controller_launcher(self):
36+
raise NotImplementedError
37+
38+
def tearDown(self):
39+
self.cluster.stop_engines()
40+
self.cluster.stop_controller()
41+
# self.cluster.close()
42+
43+
def run_tasks(self, view):
44+
ar_list = []
45+
# use context to set label
46+
with view.temp_flags(label="mylabel_map"):
47+
ar_list.append(view.map_async(speudo_wait, [1.1, 1.2, 1.3, 1.4, 1.5]))
48+
# use set_flags to set label
49+
ar_list.extend(
50+
[
51+
view.set_flags(label=f"mylabel_apply_{i:02}").apply_async(
52+
speudo_wait, 2 + i / 10
53+
)
54+
for i in range(5)
55+
]
56+
)
57+
view.wait(ar_list)
58+
59+
# build list of used labels
60+
map_labels = ["mylabel_map"]
61+
apply_labels = []
62+
for i in range(5):
63+
apply_labels.append(f"mylabel_apply_{i:02}")
64+
return map_labels, apply_labels
65+
66+
def check_labels(self, labels):
67+
# query database
68+
data = self.rc.db_query({'label': {"$nin": ""}}, keys=['msg_id', 'label'])
69+
for d in data:
70+
msg_id = d['msg_id']
71+
label = d['label']
72+
assert label in labels
73+
labels.remove(label)
74+
75+
assert len(labels) == 0
76+
77+
def clear_db(self):
78+
self.rc.purge_everything()
79+
80+
def test_balanced_view(self):
81+
bview = self.rc.load_balanced_view()
82+
map_labels, apply_labels = self.run_tasks(bview)
83+
labels = map_labels * 5 + apply_labels
84+
self.check_labels(labels)
85+
self.clear_db()
86+
87+
def test_direct_view(self):
88+
dview = self.rc[:]
89+
map_labels, apply_labels = self.run_tasks(dview)
90+
labels = map_labels * 2 + apply_labels * 2
91+
self.check_labels(labels)
92+
self.clear_db()
93+
94+
95+
class TestLabelDictDB(TaskLabelTest, TestCase):
96+
def get_controller_launcher(self):
97+
class dictDB(LocalControllerLauncher):
98+
controller_args = ["--dictdb"]
99+
100+
return dictDB
101+
102+
103+
class TestLabelSqliteDB(TaskLabelTest, TestCase):
104+
def get_controller_launcher(self):
105+
class sqliteDB(LocalControllerLauncher):
106+
controller_args = ["--sqlitedb"]
107+
108+
return sqliteDB

0 commit comments

Comments
 (0)