Skip to content

Commit 259f67d

Browse files
authored
Merge pull request #9178 from OpenMined/eelco/l2-test-flow-fixes
L2 flow fixes
2 parents 909b61d + e151321 commit 259f67d

5 files changed

Lines changed: 59 additions & 10 deletions

File tree

packages/syft/src/syft/service/code/user_code.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,8 @@ def assets(self) -> DictTuple[str, Asset] | SyftError:
724724
all_inputs = {}
725725
inputs = self.input_policy_init_kwargs or {}
726726
for vals in inputs.values():
727-
all_inputs.update(vals)
727+
# Only keep UIDs, filter out Constants
728+
all_inputs.update({k: v for k, v in vals.items() if isinstance(v, UID)})
728729

729730
# map the action_id to the asset
730731
used_assets: list[Asset] = []
@@ -753,20 +754,47 @@ def action_objects(self) -> dict:
753754
action_objects = {
754755
arg_name: str(uid)
755756
for arg_name, uid in all_inputs.items()
756-
if arg_name not in self.assets.keys()
757+
if arg_name not in self.assets.keys() and isinstance(uid, UID)
757758
}
758759

759760
return action_objects
760761

762+
@property
763+
def constants(self) -> dict[str, Constant]:
764+
if not self.input_policy_init_kwargs:
765+
return {}
766+
767+
all_inputs = {}
768+
for vals in self.input_policy_init_kwargs.values():
769+
all_inputs.update(vals)
770+
771+
# filter out the assets
772+
constants = {
773+
arg_name: item
774+
for arg_name, item in all_inputs.items()
775+
if isinstance(item, Constant)
776+
}
777+
778+
return constants
779+
761780
@property
762781
def inputs(self) -> dict:
763782
inputs = {}
764-
if self.action_objects:
765-
inputs["action_objects"] = self.action_objects
766-
if self.assets:
783+
784+
assets = self.assets
785+
action_objects = self.action_objects
786+
constants = self.constants
787+
if action_objects:
788+
inputs["action_objects"] = action_objects
789+
if assets:
767790
inputs["assets"] = {
768791
argument: asset._get_dict_for_user_code_repr()
769-
for argument, asset in self.assets.items()
792+
for argument, asset in assets.items()
793+
}
794+
if self.constants:
795+
inputs["constants"] = {
796+
argument: constant._get_dict_for_user_code_repr()
797+
for argument, constant in constants.items()
770798
}
771799
return inputs
772800

packages/syft/src/syft/service/job/job_stash.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ def check_user_code_id(self) -> Self:
156156

157157
@property
158158
def result_id(self) -> UID | None:
159-
if self.result is None:
160-
return None
161-
return self.result.id.id
159+
if isinstance(self.result, ActionObject):
160+
return self.result.id.id
161+
return None
162162

163163
@property
164164
def action_display_name(self) -> str:

packages/syft/src/syft/service/output/output_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def from_ids(
115115
)
116116
else:
117117
job_link = None
118+
119+
if input_ids is not None:
120+
input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
118121
return cls(
119122
output_ids=output_ids,
120123
user_code_link=user_code_link,

packages/syft/src/syft/service/policy/policy.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,15 @@ def transform_kwarg(
275275
return Ok(obj.syft_action_data)
276276
return Ok(self.val)
277277

278+
def _get_dict_for_user_code_repr(self) -> dict[str, Any]:
279+
return self._coll_repr_()
280+
281+
def _coll_repr_(self) -> dict[str, Any]:
282+
return {
283+
"klass": self.klass.__qualname__,
284+
"val": str(self.val),
285+
}
286+
278287

279288
@serializable()
280289
class UserOwned(PolicyRule):

packages/syft/src/syft/service/request/request.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,8 @@ def _create_output_history_for_deposited_result(
792792
if input_policy is not None:
793793
for input_ in input_policy.inputs.values():
794794
input_ids.update(input_)
795+
796+
input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
795797
res = api.services.code.store_execution_output(
796798
user_code_id=code.id,
797799
outputs=result,
@@ -1088,6 +1090,7 @@ def _deposit_result_l2(
10881090
for inps in code.input_policy.inputs.values():
10891091
input_ids.update(inps)
10901092

1093+
input_ids = {k: v for k, v in input_ids.items() if isinstance(v, UID)}
10911094
res = api.services.code.store_execution_output(
10921095
user_code_id=code.id,
10931096
outputs=result,
@@ -1104,7 +1107,13 @@ def _deposit_result_l2(
11041107
else JobStatus.COMPLETED
11051108
)
11061109

1107-
existing_result = job.result.id if job.result is not None else None
1110+
existing_result = None
1111+
if isinstance(job.result, ActionObject):
1112+
existing_result = job.result.id
1113+
elif isinstance(job.result, Err):
1114+
existing_result = job.result
1115+
else:
1116+
existing_result = job.result
11081117
print(
11091118
f"Job({job.id}) Setting new result {existing_result} -> {job_info.result.id}"
11101119
)

0 commit comments

Comments
 (0)