Skip to content

Commit ae0e7b3

Browse files
authored
Add log_dataset API (#197)
* Remove EXECUTION_RESULT from Run Detail page
  Signed-off-by: kerthcet <kerthcet@gmail.com>
* Add log_dataset to APIs
  Signed-off-by: kerthcet <kerthcet@gmail.com>
---------
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 757fc62 commit ae0e7b3

13 files changed

Lines changed: 417 additions & 212 deletions

File tree

alphatrion/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from alphatrion.log.log import log_artifact, log_metrics, log_params, log_result
1+
from alphatrion.log.log import log_artifact, log_dataset, log_metrics, log_params
22
from alphatrion.runtime.runtime import init
33

44
__all__ = [
55
"init",
66
"log_artifact",
77
"log_params",
88
"log_metrics",
9-
"log_result",
9+
"log_dataset",
1010
]

alphatrion/envs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@
2121

2222
# Runtime related envs
2323
ROOT_PATH = "ALPHATRION_ROOT_PATH"
24-
AUTO_CLEANUP = "ALPHATRION_AUTO_CLEANUP"

alphatrion/log/log.py

Lines changed: 42 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import asyncio
2+
import json
23
import os
4+
import tempfile
35
from collections.abc import Callable
46
from typing import Any
57

68
from alphatrion.runtime.contextvars import current_exp_id, current_run_id
79
from alphatrion.runtime.runtime import global_runtime
810
from alphatrion.snapshot.snapshot import (
9-
ExecutionKind,
10-
build_run_execution,
1111
checkpoint_path,
12-
snapshot_path,
1312
)
1413

1514
BEST_RESULT_PATH = "best_result_path"
@@ -151,79 +150,48 @@ async def log_metrics(metrics: dict[str, float]) -> bool:
151150
return is_best_metric
152151

153152

154-
# log_result is used to log the result of a run/experiment,
155-
# including both input and output, e.g. you want to save the code snippet.
156-
# It will be stored in the object storage as a JSON file if object storage
157-
# is enabled or locally otherwise.
158-
# NOTE: will be deprecated in v0.3.0; use log_dataset instead.
159-
async def log_result(
160-
output: dict[str, Any],
161-
input: dict[str, Any] | None = None,
162-
phase: str = "success",
163-
kind: ExecutionKind = ExecutionKind.RUN,
164-
):
165-
result = None
166-
167-
if kind == ExecutionKind.RUN:
168-
result = build_run_execution(output=output, input=input, phase=phase)
169-
else:
170-
raise NotImplementedError(
171-
f"Logging record of kind {result.kind} is not implemented yet."
172-
)
173-
174-
# Can I get the file size to store in the database?
153+
# log_records is used to log a list of records, which is similar to log_metrics
154+
# but for tracing the execution of the code.
155+
# async def log_records():
175156

176-
path = snapshot_path()
177-
if os.path.exists(path) is False:
178-
os.makedirs(path, exist_ok=True)
179157

180-
# Will eventually be cleanup on Experiment done() if AUTO_CLEANUP is enabled.
181-
# Considering the record file is small, we just save it locally first.
182-
# If this changes in the future, we should delete them after uploading.
183-
with open(os.path.join(path, "result.json"), "w") as f:
184-
f.write(result.model_dump_json())
158+
async def log_dataset(
159+
name: str,
160+
data: dict[str, Any],
161+
):
162+
"""
163+
Log dataset to the database and artifact registry.
185164
186-
file_size = os.path.getsize(os.path.join(path, "result.json"))
165+
:param name: the name of the dataset.
166+
:param data: the data to be logged, currently support dict only,
167+
will support more types in the future.
168+
"""
187169
runtime = global_runtime()
188170

189-
# If not enabled, only save to local disk.
190-
if runtime.artifact_storage_enabled():
191-
path = await log_artifact(
192-
paths=os.path.join(path, "result.json"),
193-
repo_name="execution",
194-
)
195-
runtime.metadb.update_run(
196-
run_id=current_run_id.get(),
197-
meta={
198-
EXECUTION_RESULT: {
199-
"path": path,
200-
"size": file_size,
201-
"file_name": "result.json",
202-
}
203-
},
204-
)
205-
206-
207-
# log_records is used to log a list of records, which is similar to log_metrics
208-
# but for tracing the execution of the code.
209-
# async def log_records():
210-
211-
# log_dataset will store the data in the artifacts as well as a record in the database.
212-
# async def log_dataset(
213-
# name: str,
214-
# paths: str | list[str],
215-
# version: str | None = None,
216-
# ):
217-
# path = await log_artifact(
218-
# paths=paths,
219-
# repo_name="dataset",
220-
# version=version,
221-
# )
222-
223-
# runtime = global_runtime()
224-
# runtime.metadb.create_dataset(
225-
# name=name,
226-
# team_id=runtime._team_id,
227-
# path=path,
228-
# version=version,
229-
# )
171+
if isinstance(data, dict):
172+
with tempfile.TemporaryDirectory() as tmpdir:
173+
os.chdir(tmpdir)
174+
with open(name, "w") as f:
175+
f.write(json.dumps(data))
176+
177+
file_size = os.path.getsize(name)
178+
179+
path = await log_artifact(
180+
paths=name,
181+
repo_name="dataset",
182+
)
183+
184+
runtime.metadb.create_dataset(
185+
name=name,
186+
team_id=runtime.team_id,
187+
user_id=runtime.user_id,
188+
path=path,
189+
experiment_id=current_exp_id.get(),
190+
run_id=current_run_id.get(),
191+
meta={"size": file_size},
192+
)
193+
return
194+
195+
raise NotImplementedError(
196+
f"Logging dataset of type {type(data)} is not implemented yet."
197+
)

alphatrion/server/graphql/resolvers.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
CreateTeamInput,
2222
CreateUserInput,
2323
DailyTokenUsage,
24+
Dataset,
2425
Experiment,
2526
GraphQLExperimentType,
2627
GraphQLExperimentTypeEnum,
@@ -631,6 +632,89 @@ def get_daily_token_usage(
631632
print(f"Failed to fetch daily token usage: {e}")
632633
return []
633634

635+
@staticmethod
636+
def list_datasets(
637+
team_id: strawberry.ID,
638+
page: int = 0,
639+
page_size: int = 20,
640+
order_by: str = "created_at",
641+
order_desc: bool = True,
642+
) -> list[Dataset]:
643+
metadb = runtime.storage_runtime().metadb
644+
datasets = metadb.list_datasets_by_team_id(
645+
team_id=uuid.UUID(team_id),
646+
page=page,
647+
page_size=page_size,
648+
order_by=order_by,
649+
order_desc=order_desc,
650+
)
651+
return [
652+
Dataset(
653+
id=d.uuid,
654+
name=d.name,
655+
description=d.description,
656+
meta=d.meta,
657+
team_id=d.team_id,
658+
experiment_id=d.experiment_id,
659+
run_id=d.run_id,
660+
user_id=d.user_id,
661+
created_at=d.created_at,
662+
updated_at=d.updated_at,
663+
)
664+
for d in datasets
665+
]
666+
667+
@staticmethod
668+
def get_dataset(id: strawberry.ID) -> Dataset | None:
669+
metadb = runtime.storage_runtime().metadb
670+
dataset = metadb.get_dataset(dataset_id=uuid.UUID(id))
671+
if dataset:
672+
return Dataset(
673+
id=dataset.uuid,
674+
name=dataset.name,
675+
description=dataset.description,
676+
meta=dataset.meta,
677+
team_id=dataset.team_id,
678+
experiment_id=dataset.experiment_id,
679+
run_id=dataset.run_id,
680+
user_id=dataset.user_id,
681+
created_at=dataset.created_at,
682+
updated_at=dataset.updated_at,
683+
)
684+
return None
685+
686+
@staticmethod
687+
def list_datasets_by_experiment(
688+
experiment_id: strawberry.ID,
689+
page: int = 0,
690+
page_size: int = 20,
691+
order_by: str = "created_at",
692+
order_desc: bool = True,
693+
) -> list[Dataset]:
694+
metadb = runtime.storage_runtime().metadb
695+
datasets = metadb.list_datasets_by_exp_id(
696+
exp_id=uuid.UUID(experiment_id),
697+
page=page,
698+
page_size=page_size,
699+
order_by=order_by,
700+
order_desc=order_desc,
701+
)
702+
return [
703+
Dataset(
704+
id=d.uuid,
705+
name=d.name,
706+
description=d.description,
707+
meta=d.meta,
708+
team_id=d.team_id,
709+
experiment_id=d.experiment_id,
710+
run_id=d.run_id,
711+
user_id=d.user_id,
712+
created_at=d.created_at,
713+
updated_at=d.updated_at,
714+
)
715+
for d in datasets
716+
]
717+
634718

635719
class GraphQLMutations:
636720
@staticmethod
@@ -758,3 +842,8 @@ def delete_experiments(experiment_ids: list[strawberry.ID]) -> int:
758842
uuids = [uuid.UUID(exp_id) for exp_id in experiment_ids]
759843
# Soft delete experiments by setting is_del flag
760844
return metadb.delete_experiments(experiment_ids=uuids)
845+
846+
@staticmethod
847+
def delete_dataset(dataset_id: strawberry.ID) -> bool:
848+
metadb = runtime.storage_runtime().metadb
849+
return metadb.delete_dataset(dataset_id=dataset_id)

alphatrion/server/graphql/schema.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
CreateTeamInput,
1010
CreateUserInput,
1111
DailyTokenUsage,
12+
Dataset,
1213
Experiment,
1314
RemoveUserFromTeamInput,
1415
Run,
@@ -103,6 +104,43 @@ async def artifact_content(
103104
) -> ArtifactContent:
104105
return await GraphQLResolvers.get_artifact_content(str(team_id), tag, repo_name)
105106

107+
# Dataset queries
108+
@strawberry.field
109+
def datasets(
110+
self,
111+
team_id: strawberry.ID,
112+
page: int = 0,
113+
page_size: int = 20,
114+
order_by: str = "created_at",
115+
order_desc: bool = True,
116+
) -> list[Dataset]:
117+
return GraphQLResolvers.list_datasets(
118+
team_id=team_id,
119+
page=page,
120+
page_size=page_size,
121+
order_by=order_by,
122+
order_desc=order_desc,
123+
)
124+
125+
dataset: Dataset | None = strawberry.field(resolver=GraphQLResolvers.get_dataset)
126+
127+
@strawberry.field
128+
def datasets_by_experiment(
129+
self,
130+
experiment_id: strawberry.ID,
131+
page: int = 0,
132+
page_size: int = 20,
133+
order_by: str = "created_at",
134+
order_desc: bool = True,
135+
) -> list[Dataset]:
136+
return GraphQLResolvers.list_datasets_by_experiment(
137+
experiment_id=experiment_id,
138+
page=page,
139+
page_size=page_size,
140+
order_by=order_by,
141+
order_desc=order_desc,
142+
)
143+
106144

107145
@strawberry.type
108146
class Mutation:
@@ -134,5 +172,9 @@ def delete_experiment(self, experiment_id: strawberry.ID) -> bool:
134172
def delete_experiments(self, experiment_ids: list[strawberry.ID]) -> int:
135173
return GraphQLMutations.delete_experiments(experiment_ids=experiment_ids)
136174

175+
@strawberry.mutation
176+
def delete_dataset(self, dataset_id: strawberry.ID) -> bool:
177+
return GraphQLMutations.delete_dataset(dataset_id=dataset_id)
178+
137179

138180
schema = strawberry.Schema(query=Query, mutation=Mutation)

alphatrion/server/graphql/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,20 @@ class Metric:
203203
created_at: datetime
204204

205205

206+
@strawberry.type
207+
class Dataset:
208+
id: strawberry.ID
209+
name: str
210+
description: str | None
211+
meta: JSON | None
212+
team_id: strawberry.ID
213+
experiment_id: strawberry.ID | None
214+
run_id: strawberry.ID | None
215+
user_id: strawberry.ID
216+
created_at: datetime
217+
updated_at: datetime
218+
219+
206220
# Input types for mutations
207221
@strawberry.input
208222
class CreateUserInput:

alphatrion/snapshot/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +0,0 @@
1-
from alphatrion.snapshot.snapshot import (
2-
ExecutionKind,
3-
ExecutionResult,
4-
Metadata,
5-
Spec,
6-
Status,
7-
)
8-
9-
__all__ = [
10-
"ExecutionKind",
11-
"ExecutionResult",
12-
"Metadata",
13-
"Spec",
14-
"Status",
15-
]

0 commit comments

Comments (0)