Skip to content

Commit 1b87bef

Browse files
author
Johannes Otepka
committed
corrections to get map and apply to work for balanced and direct view
1 parent 12b1536 commit 1b87bef

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

docs/source/examples/basic_task_label.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,26 @@ def wait(t):
2020

2121
# use load balanced view
2222
bview = rc.load_balanced_view()
23-
ar_list1 = [bview.map_async(wait, [2], label=f"mylabel_{i:02}") for i in range(10)]
24-
bview.wait(ar_list1)
23+
ar_list_b1 = [
24+
bview.map_async(wait, [2], label=f"mylabel_map_{i:02}") for i in range(10)
25+
]
26+
ar_list_b2 = [
27+
bview.apply_async(wait, 2, label=f"mylabel_apply_{i:02}") for i in range(10)
28+
]
29+
bview.wait(ar_list_b1)
30+
bview.wait(ar_list_b2)
31+
2532

2633
# use direct view
2734
dview = rc[:]
28-
ar_list2 = [dview.apply_async(wait, 2, label=f"mylabel_{i + 10:02}") for i in range(10)]
29-
dview.wait(ar_list2)
35+
ar_list_d1 = [
36+
dview.apply_async(wait, 2, label=f"mylabel_map_{i + 10:02}") for i in range(10)
37+
]
38+
ar_list_d2 = [
39+
dview.map_async(wait, [2], label=f"mylabel_apply_{i + 10:02}") for i in range(10)
40+
]
41+
dview.wait(ar_list_d1)
42+
dview.wait(ar_list_d2)
3043

3144
# query database
3245
data = rc.db_query({'label': {"$nin": ""}}, keys=['msg_id', 'label', 'engine_uuid'])

ipyparallel/client/view.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ class View(HasTraits):
9898
block = Bool(False)
9999
track = Bool(False)
100100
targets = Any()
101+
label = Any()
101102

102103
history = List()
103104
outstanding = Set()
104105
results = Dict()
105106
client = Instance('ipyparallel.Client', allow_none=True)
106107

107108
_socket = Any()
108-
_flag_names = List(['targets', 'block', 'track'])
109+
_flag_names = List(['targets', 'block', 'track', 'label'])
109110
_in_sync_results = Bool(False)
110111
_targets = Any()
111112
_idents = Any()
@@ -569,9 +570,12 @@ def _really_apply(
569570
block = self.block if block is None else block
570571
track = self.track if track is None else track
571572
targets = self.targets if targets is None else targets
573+
label = (
574+
self.label if label is None else label
575+
) # comes into play when calling map[_async] (self.label)
572576
label = (
573577
kwargs.pop("label") if "label" in kwargs and label is None else label
574-
) # is this the correct/best way of retieving label?
578+
) # this is required can calling apply[_async]
575579
metadata = dict(label=label)
576580

577581
_idents, _targets = self.client._build_targets(targets)
@@ -658,7 +662,12 @@ def map(
658662

659663
assert len(sequences) > 0, "must have some sequences to map onto!"
660664
pf = ParallelFunction(
661-
self, f, block=block, track=track, return_exceptions=return_exceptions
665+
self,
666+
f,
667+
block=block,
668+
track=track,
669+
return_exceptions=return_exceptions,
670+
label=label,
662671
)
663672
return pf.map(*sequences)
664673

@@ -1308,11 +1317,6 @@ def set_flags(self, **kwargs):
13081317
raise ValueError(f"Invalid timeout: {t}")
13091318

13101319
self.timeout = t
1311-
if 'label' in kwargs:
1312-
l = kwargs['label']
1313-
if not isinstance(l, (str, type(None))):
1314-
raise TypeError(f"Invalid type for label: {type(l)!r}")
1315-
self.label = l
13161320

13171321
@sync_results
13181322
@save_ids
@@ -1384,7 +1388,12 @@ def _really_apply(
13841388
follow = self.follow if follow is None else follow
13851389
timeout = self.timeout if timeout is None else timeout
13861390
targets = self.targets if targets is None else targets
1387-
label = self.label if label is None else label
1391+
label = (
1392+
self.label if label is None else label
1393+
) # comes into play when calling map[_async] (self.label)
1394+
label = (
1395+
kwargs.pop("label") if "label" in kwargs and label is None else label
1396+
) # this is required can calling apply[_async]
13881397

13891398
if not isinstance(retries, int):
13901399
raise TypeError(f'retries must be int, not {type(retries)!r}')

0 commit comments

Comments
 (0)