Skip to content

Commit 50800b7

Browse files
authored
Merge pull request #9345 from OpenMined/deposit_result_test
Added admin methods for get and set
2 parents e424573 + 142a6ea commit 50800b7

5 files changed

Lines changed: 106 additions & 2 deletions

File tree

packages/syft/src/syft/service/migration/migration_service.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# stdlib
22
from collections import defaultdict
33
import logging
4+
from typing import Any
45

56
# syft absolute
67
import syft
@@ -16,6 +17,7 @@
1617
from ...types.syft_object import SyftObject
1718
from ...types.syft_object_registry import SyftObjectRegistry
1819
from ...types.twin_object import TwinObject
20+
from ...types.uid import UID
1921
from ..action.action_object import Action
2022
from ..action.action_object import ActionObject
2123
from ..action.action_permissions import ActionObjectPermission
@@ -26,7 +28,10 @@
2628
from ..response import SyftSuccess
2729
from ..service import AbstractService
2830
from ..service import service_method
31+
from ..sync.sync_service import get_store
32+
from ..sync.sync_service import get_store_by_type
2933
from ..user.user_roles import ADMIN_ROLE_LEVEL
34+
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
3035
from ..worker.utils import DEFAULT_WORKER_POOL_NAME
3136
from .object_migration_state import MigrationData
3237
from .object_migration_state import StoreMetadata
@@ -493,3 +498,29 @@ def reset_and_restore(
493498
)
494499

495500
return SyftSuccess(message="Database reset successfully.")
501+
502+
@service_method(
503+
path="migration._get_object",
504+
name="_get_object",
505+
roles=DATA_SCIENTIST_ROLE_LEVEL,
506+
)
507+
def _get_object(
508+
self, context: AuthedServiceContext, uid: UID, object_type: type
509+
) -> Any:
510+
return (
511+
get_store_by_type(context, object_type)
512+
.get_by_uid(credentials=context.credentials, uid=uid)
513+
.unwrap()
514+
)
515+
516+
@service_method(
517+
path="migration._update_object",
518+
name="_update_object",
519+
roles=ADMIN_ROLE_LEVEL,
520+
)
521+
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
522+
return (
523+
get_store(context, object)
524+
.update(credentials=context.credentials, obj=object)
525+
.unwrap()
526+
)

packages/syft/src/syft/service/request/request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat
585585
# which tries to send an email to the admin and ends up here
586586
pass # lets keep going
587587

588+
self.refresh()
588589
if len(self.history) == 0:
589590
return RequestStatus.PENDING
590591

packages/syft/src/syft/service/sync/sync_service.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@
3939

4040

4141
def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash:
42-
if isinstance(item, ActionObject):
42+
return get_store_by_type(context=context, obj_type=type(item))
43+
44+
45+
def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash:
46+
if issubclass(obj_type, ActionObject):
4347
service = context.server.services.action # type: ignore
4448
return service.stash # type: ignore
45-
service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore
49+
service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore
4650
return service.stash
4751

4852

packages/syft/src/syft/types/syft_object.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,17 @@ def make_id(cls, values: Any) -> Any:
429429
__table_coll_widths__: ClassVar[list[str] | None] = None
430430
__table_sort_attr__: ClassVar[str | None] = None
431431

432+
def refresh(self) -> None:
433+
try:
434+
api = self._get_api()
435+
new_object = api.services.migration._get_object(
436+
uid=self.id, object_type=type(self)
437+
)
438+
if type(new_object) == type(self):
439+
self.__dict__.update(new_object.__dict__)
440+
except Exception as _:
441+
return
442+
432443
def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
433444
funcs = print_type_cache[type(self)]
434445
if len(funcs) > 0:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# third party
2+
3+
# syft absolute
4+
import syft as sy
5+
from syft.client.datasite_client import DatasiteClient
6+
from syft.service.action.action_object import ActionObject
7+
from syft.service.dataset.dataset import Dataset
8+
9+
10+
def get_ds_client(client: DatasiteClient) -> DatasiteClient:
11+
client.register(
12+
name="a",
13+
email="a@a.com",
14+
password="asdf",
15+
password_verify="asdf",
16+
)
17+
return client.login(email="a@a.com", password="asdf")
18+
19+
20+
def test_get_set_object(high_worker):
21+
high_client: DatasiteClient = high_worker.root_client
22+
_ = get_ds_client(high_client)
23+
root_datasite_client = high_worker.root_client
24+
dataset = sy.Dataset(
25+
name="local_test",
26+
asset_list=[
27+
sy.Asset(
28+
name="local_test",
29+
data=[1, 2, 3],
30+
mock=[1, 1, 1],
31+
)
32+
],
33+
)
34+
root_datasite_client.upload_dataset(dataset)
35+
dataset = root_datasite_client.datasets[0]
36+
37+
other_dataset = high_client.api.services.migration._get_object(
38+
uid=dataset.id, object_type=Dataset
39+
)
40+
other_dataset.server_uid = dataset.server_uid
41+
assert dataset == other_dataset
42+
other_dataset.name = "new_name"
43+
updated_dataset = high_client.api.services.migration._update_object(
44+
object=other_dataset
45+
)
46+
assert updated_dataset.name == "new_name"
47+
48+
asset = root_datasite_client.datasets[0].assets[0]
49+
source_ao = high_client.api.services.action.get(uid=asset.action_id)
50+
ao = high_client.api.services.migration._get_object(
51+
uid=asset.action_id, object_type=ActionObject
52+
)
53+
ao._set_obj_location_(
54+
high_worker.id,
55+
root_datasite_client.credentials,
56+
)
57+
assert source_ao == ao

0 commit comments

Comments
 (0)