Skip to content

Commit 57b21ac

Browse files
authored
[Workflow API] Resolve issues in setting unserializable objects as private attributes (#1605)
* Delayed aggregator initialization Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Added test case Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Added test case Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Documentation update Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Incorporated review comments Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Incorporated review comments. Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Incorporated review comments Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> * Incorporated review comments. Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com> --------- Signed-off-by: Sachin Gupta <sachin.gupta.dsp@gmail.com>
1 parent 518afb7 commit 57b21ac

6 files changed

Lines changed: 308 additions & 31 deletions

File tree

docs/about/features_index/workflowinterface.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ Some important points to remember while creating callback function and private a
236236
- If no Callback Function or private attributes is specified then the Participant shall not have any *private attributes*
237237
- In above example multiple collaborators have the same callback function or private attributes. Depending on the Federated Learning requirements, user can specify unique callback function or private attributes for each Participant
238238
- *Private attributes* needs to be set after instantiating the participant.
239+
- **Known Limitations**: When using a `callable` to initialize *private attributes* that are **not serializable**, users should be aware of following limitations:
240+
* `checkpoint` should not be enabled with `LocalRuntime`. Users should ensure that default (disabled) setting of checkpoint is used or it is explicitly disabled :code:`flow = FederatedFlow( ..., checkpoint = false)`
241+
* filtering of attributes (via `include` or `exclude`) cannot be used during the transition from aggregator step to collaborator steps. This limitation applies to **all attributes** if any non-serializable private attribute is present in aggregator. The flow logic must be updated to avoid filtering in steps that transition control from aggregator to collaborators
239242

240243
Now let's see how the runtime for a flow is assigned, and the flow gets run:
241244

@@ -558,6 +561,7 @@ In a distributed environment consisting of Director, Envoys and User Node (where
558561

559562
**IMPORTANT**: While this information is useful for debugging, depending on your workflow it may require significant disk space. For this reason, checkpoint is disabled by default.
560563

564+
561565
Future Plans
562566
==============
563567
Following functionalities are planned for inclusion in future releases of the Workflow Interface:

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,14 @@ def _setup_initial_state(self) -> None:
170170
"""
171171
self._metaflow_interface = MetaflowInterface(self.__class__, self.runtime.backend)
172172
self._run_id = self._metaflow_interface.create_run()
173-
# Initialize aggregator private attributes
174-
self.runtime.initialize_aggregator()
175-
self._foreach_methods = []
176173
FLSpec._reset_clones()
177174
FLSpec._create_clones(self, self.runtime.collaborators)
178-
# Initialize collaborator private attributes
175+
176+
# Initialize participant private attributes
177+
self.runtime.initialize_aggregator()
179178
self.runtime.initialize_collaborators()
179+
self._foreach_methods = []
180+
180181
if self._checkpoint:
181182
print(f"Created flow {self.__class__.__name__}")
182183

tests/end_to_end/test_suites/wf_local_func_tests.py

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,32 @@
88
import random
99
from metaflow import Step
1010

11-
from tests.end_to_end.utils.wf_common_fixtures import fx_local_federated_workflow, fx_local_federated_workflow_prvt_attr
11+
from tests.end_to_end.utils.wf_common_fixtures import (
12+
fx_local_federated_workflow,
13+
fx_local_federated_workflow_prvt_attr,
14+
fx_local_fed_wf_unserializable_pvt_attrs,
15+
)
16+
1217
from tests.end_to_end.workflow.exclude_flow import TestFlowExclude
1318
from tests.end_to_end.workflow.include_exclude_flow import TestFlowIncludeExclude
1419
from tests.end_to_end.workflow.include_flow import TestFlowInclude
1520
from tests.end_to_end.workflow.internal_loop import TestFlowInternalLoop
1621
from tests.end_to_end.workflow.reference_flow import TestFlowReference
1722
from tests.end_to_end.workflow.subset_flow import TestFlowSubsetCollaborators
18-
from tests.end_to_end.workflow.private_attr_wo_callable import TestFlowPrivateAttributesWoCallable
23+
from tests.end_to_end.workflow.private_attr_wo_callable import (
24+
TestFlowPrivateAttributesWoCallable,
25+
)
1926
from tests.end_to_end.workflow.private_attributes_flow import TestFlowPrivateAttributes
2027
from tests.end_to_end.workflow.private_attr_both import TestFlowPrivateAttributesBoth
28+
from tests.end_to_end.workflow.unserializable_private_attr import (
29+
TestFlowUnserializablePrivateAttributes,
30+
)
2131

2232
from tests.end_to_end.utils import wf_helper as wf_helper
2333

2434
log = logging.getLogger(__name__)
2535

36+
2637
def test_exclude_flow(request, fx_local_federated_workflow):
2738
"""
2839
Test if variable is excluded, variables not show in next step
@@ -73,7 +84,9 @@ def test_internal_loop(request, fx_local_federated_workflow):
7384
model = None
7485
optimizer = None
7586

76-
flflow = TestFlowInternalLoop(model, optimizer, request.config.num_rounds, checkpoint=True)
87+
flflow = TestFlowInternalLoop(
88+
model, optimizer, request.config.num_rounds, checkpoint=True
89+
)
7790
flflow.runtime = fx_local_federated_workflow.runtime
7891
flflow.run()
7992

@@ -87,25 +100,37 @@ def test_internal_loop(request, fx_local_federated_workflow):
87100
"end",
88101
]
89102

90-
steps_present_in_cli, missing_steps_in_cli, extra_steps_in_cli = wf_helper.validate_flow(
91-
flflow, expected_flow_steps
92-
)
93-
94-
assert len(steps_present_in_cli) == len(expected_flow_steps), "Number of steps fetched from Datastore through CLI do not match the Expected steps provided"
95-
assert len(missing_steps_in_cli) == 0, f"Following steps missing from Datastore: {missing_steps_in_cli}"
96-
assert len(extra_steps_in_cli) == 0, f"Following steps are extra in Datastore: {extra_steps_in_cli}"
103+
steps_present_in_cli, missing_steps_in_cli, extra_steps_in_cli = (
104+
wf_helper.validate_flow(flflow, expected_flow_steps)
105+
)
106+
107+
assert len(steps_present_in_cli) == len(
108+
expected_flow_steps
109+
), "Number of steps fetched from Datastore through CLI do not match the Expected steps provided"
110+
assert (
111+
len(missing_steps_in_cli) == 0
112+
), f"Following steps missing from Datastore: {missing_steps_in_cli}"
113+
assert (
114+
len(extra_steps_in_cli) == 0
115+
), f"Following steps are extra in Datastore: {extra_steps_in_cli}"
97116
assert flflow.end_count == 1, "End function called more than one time"
98117

99-
log.info("\n Summary of internal flow testing \n"
100-
"No issues found and below are the tests that ran successfully\n"
101-
"1. Number of training completed is equal to training rounds\n"
102-
"2. CLI steps and Expected steps are matching\n"
103-
"3. Number of tasks are aligned with number of rounds and number of collaborators\n"
104-
"4. End function executed one time")
118+
log.info(
119+
"\n Summary of internal flow testing \n"
120+
"No issues found and below are the tests that ran successfully\n"
121+
"1. Number of training completed is equal to training rounds\n"
122+
"2. CLI steps and Expected steps are matching\n"
123+
"3. Number of tasks are aligned with number of rounds and number of collaborators\n"
124+
"4. End function executed one time"
125+
)
105126
log.info("Successfully ended test_internal_loop")
106127

107128

108-
@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborator_private_attr_index", "int", None )], indirect=True)
129+
@pytest.mark.parametrize(
130+
"fx_local_federated_workflow",
131+
[("init_collaborator_private_attr_index", "int", None)],
132+
indirect=True,
133+
)
109134
def test_reference_flow(request, fx_local_federated_workflow):
110135
"""
111136
Test reference variables matched through out the flow
@@ -118,7 +143,12 @@ def test_reference_flow(request, fx_local_federated_workflow):
118143
flflow.run()
119144
log.info("Successfully ended test_reference_flow")
120145

121-
@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborator_private_attr_name", "str", None )], indirect=True)
146+
147+
@pytest.mark.parametrize(
148+
"fx_local_federated_workflow",
149+
[("init_collaborator_private_attr_name", "str", None)],
150+
indirect=True,
151+
)
122152
def test_subset_collaborators(request, fx_local_federated_workflow):
123153
"""
124154
Test the subset of collaborators in a federated workflow.
@@ -158,16 +188,16 @@ def test_subset_collaborators(request, fx_local_federated_workflow):
158188
)
159189

160190
assert len(list(step)) == len(subset_collaborators), (
161-
f"...Flow only ran for {len(list(step))} "
162-
+ f"instead of the {len(subset_collaborators)} expected "
163-
+ f"collaborators- Testcase Failed."
164-
)
191+
f"...Flow only ran for {len(list(step))} "
192+
+ f"instead of the {len(subset_collaborators)} expected "
193+
+ f"collaborators- Testcase Failed."
194+
)
165195
log.info(
166196
f"Found {len(list(step))} tasks for each of the "
167197
+ f"{len(subset_collaborators)} collaborators"
168198
)
169-
log.info(f'subset_collaborators = {subset_collaborators}')
170-
log.info(f'collaborators_ran = {collaborators_ran}')
199+
log.info(f"subset_collaborators = {subset_collaborators}")
200+
log.info(f"collaborators_ran = {collaborators_ran}")
171201
for collaborator_name in subset_collaborators:
172202
assert collaborator_name in collaborators_ran, (
173203
f"...Flow did not execute for "
@@ -177,7 +207,8 @@ def test_subset_collaborators(request, fx_local_federated_workflow):
177207

178208
log.info(
179209
f"Testing FederatedFlow - Ending test for validating "
180-
+ f"the subset of collaborators.")
210+
+ f"the subset of collaborators."
211+
)
181212
log.info("Successfully ended test_subset_collaborators")
182213

183214

@@ -194,7 +225,11 @@ def test_private_attr_wo_callable(request, fx_local_federated_workflow_prvt_attr
194225
log.info("Successfully ended test_private_attr_wo_callable")
195226

196227

197-
@pytest.mark.parametrize("fx_local_federated_workflow", [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np" )], indirect=True)
228+
@pytest.mark.parametrize(
229+
"fx_local_federated_workflow",
230+
[("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np")],
231+
indirect=True,
232+
)
198233
def test_private_attributes(request, fx_local_federated_workflow):
199234
"""
200235
Set private attribute through callable function
@@ -208,7 +243,11 @@ def test_private_attributes(request, fx_local_federated_workflow):
208243
log.info("Successfully ended test_private_attributes")
209244

210245

211-
@pytest.mark.parametrize("fx_local_federated_workflow_prvt_attr", [("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np" )], indirect=True)
246+
@pytest.mark.parametrize(
247+
"fx_local_federated_workflow_prvt_attr",
248+
[("init_collaborate_pvt_attr_np", "int", "init_agg_pvt_attr_np")],
249+
indirect=True,
250+
)
212251
def test_private_attr_both(request, fx_local_federated_workflow_prvt_attr):
213252
"""
214253
Set private attribute through callable function and direct assignment
@@ -220,3 +259,25 @@ def test_private_attr_both(request, fx_local_federated_workflow_prvt_attr):
220259
log.info(f"Starting round {i}...")
221260
flflow.run()
222261
log.info("Successfully ended test_private_attr_both")
262+
263+
264+
@pytest.mark.parametrize(
265+
"fx_local_fed_wf_unserializable_pvt_attrs",
266+
[
267+
("callable_to_init_collab_unserializable_pvt_attrs",
268+
"int",
269+
"callable_to_init_agg_unserializable_pvt_attrs")
270+
],
271+
indirect=True,
272+
)
273+
def test_unserializable_private_attr(
274+
request, fx_local_fed_wf_unserializable_pvt_attrs
275+
):
276+
"""
277+
Validate unserializable objects are accessible as private attributes
278+
"""
279+
log.info("Starting Test for unserializable private attributes")
280+
flflow = TestFlowUnserializablePrivateAttributes(rounds=request.config.num_rounds, checkpoint=False)
281+
flflow.runtime = fx_local_fed_wf_unserializable_pvt_attrs.runtime
282+
flflow.run()
283+
log.info("Successfully ended Test for unserializable private attributes")

tests/end_to_end/utils/wf_common_fixtures.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,62 @@ def fx_local_federated_workflow_prvt_attr(request):
148148
collaborators=collaborators_list,
149149
runtime=local_runtime,
150150
)
151+
152+
153+
@pytest.fixture(scope="function")
154+
def fx_local_fed_wf_unserializable_pvt_attrs(request):
155+
"""
156+
Fixture to set up a local federated workflow for testing.
157+
This fixture initializes an `Aggregator` and sets up a list of collaborators
158+
based on the number specified in the test configuration. It also configures
159+
a `LocalRuntime` with the aggregator, collaborators, and an optional backend
160+
if specified in the test configuration.
161+
Args:
162+
request (FixtureRequest): The pytest request object that provides access
163+
to the test configuration.
164+
Yields:
165+
LocalRuntime: An instance of `LocalRuntime` configured with the aggregator,
166+
collaborators, and backend.
167+
"""
168+
# Inline import
169+
from tests.end_to_end.utils.wf_helper import (
170+
callable_to_init_agg_unserializable_pvt_attrs,
171+
callable_to_init_collab_unserializable_pvt_attrs
172+
)
173+
collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None
174+
collab_value = request.param[1] if hasattr(request, 'param') and request.param else None
175+
agg_callback_func = request.param[2] if hasattr(request, 'param') and request.param else None
176+
177+
# Get the callback functions from the locals using string
178+
collab_callback_func_name = locals()[collab_callback_func] if collab_callback_func else None
179+
agg_callback_func_name = locals()[agg_callback_func] if agg_callback_func else None
180+
collaborators_list = []
181+
182+
# Setup aggregator
183+
if agg_callback_func_name:
184+
aggregator = Aggregator(name="agg",
185+
private_attributes_callable=agg_callback_func_name)
186+
else:
187+
aggregator = Aggregator()
188+
189+
# Setup collaborators
190+
for i in range(request.config.num_collaborators):
191+
func_var = i if collab_value == "int" else f"collaborator{i}" if collab_value == "str" else None
192+
collab = Collaborator(
193+
name=f"collaborator{i}",
194+
private_attributes_callable=collab_callback_func_name
195+
)
196+
collaborators_list.append(collab)
197+
198+
workflow_backend = request.config.workflow_backend if hasattr(request.config, 'workflow_backend') else None
199+
if workflow_backend:
200+
local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators_list, backend=workflow_backend)
201+
else:
202+
local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators_list)
203+
204+
# Return the federation fixture
205+
return workflow_local_fixture(
206+
aggregator=aggregator,
207+
collaborators=collaborators_list,
208+
runtime=local_runtime,
209+
)

tests/end_to_end/utils/wf_helper.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from metaflow import Flow
55
import logging
66
import numpy as np
7+
from openfl.databases import TensorDB
8+
from openfl.utilities import TensorKey
79

810
import tests.end_to_end.utils.exceptions as ex
911

@@ -116,6 +118,20 @@ def init_agg_pvt_attr_np():
116118
return {"test_loader": np.random.rand(10, 28, 28)}
117119

118120

121+
def callable_to_init_collab_unserializable_pvt_attrs():
122+
"""
123+
Create and return a TensorDB
124+
"""
125+
return {"col_tensor_db": TensorDB()}
126+
127+
128+
def callable_to_init_agg_unserializable_pvt_attrs():
129+
"""
130+
Create and return a TensorDB
131+
"""
132+
return {"agg_tensor_db": TensorDB()}
133+
134+
119135
def run_notebook(notebook_path, output_notebook_path):
120136
"""
121137
Function to run the notebook.

0 commit comments

Comments
 (0)