Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions alphatrion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from alphatrion.log.log import log_artifact, log_metrics, log_params, log_result
from alphatrion.log.log import log_artifact, log_dataset, log_metrics, log_params
from alphatrion.runtime.runtime import init

__all__ = [
"init",
"log_artifact",
"log_params",
"log_metrics",
"log_result",
"log_dataset",
]
1 change: 0 additions & 1 deletion alphatrion/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@

# Runtime related envs
ROOT_PATH = "ALPHATRION_ROOT_PATH"
AUTO_CLEANUP = "ALPHATRION_AUTO_CLEANUP"
116 changes: 42 additions & 74 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import asyncio
import json
import os
import tempfile
from collections.abc import Callable
from typing import Any

from alphatrion.runtime.contextvars import current_exp_id, current_run_id
from alphatrion.runtime.runtime import global_runtime
from alphatrion.snapshot.snapshot import (
ExecutionKind,
build_run_execution,
checkpoint_path,
snapshot_path,
)

BEST_RESULT_PATH = "best_result_path"
Expand Down Expand Up @@ -151,79 +150,48 @@ async def log_metrics(metrics: dict[str, float]) -> bool:
return is_best_metric


# log_result is used to log the result of a run/experiment,
# including both input and output, e.g. you want to save the code snippet.
# It will be stored in the object storage as a JSON file if object storage
# is enabled or locally otherwise.
# NOTE: will be deprecated in the v0.3.0, use log_dataset instead.
async def log_result(
output: dict[str, Any],
input: dict[str, Any] | None = None,
phase: str = "success",
kind: ExecutionKind = ExecutionKind.RUN,
):
result = None

if kind == ExecutionKind.RUN:
result = build_run_execution(output=output, input=input, phase=phase)
else:
raise NotImplementedError(
f"Logging record of kind {result.kind} is not implemented yet."
)

# Can I get the file size to store in the database?
# log_records is used to log a list of records, which is similar to log_metrics
# but for tracing the execution of the code.
# async def log_records():

path = snapshot_path()
if os.path.exists(path) is False:
os.makedirs(path, exist_ok=True)

# Will eventually be cleanup on Experiment done() if AUTO_CLEANUP is enabled.
# Considering the record file is small, we just save it locally first.
# If this changes in the future, we should delete them after uploading.
with open(os.path.join(path, "result.json"), "w") as f:
f.write(result.model_dump_json())
async def log_dataset(
    name: str,
    data: dict[str, Any],
):
    """
    Log a dataset to the database and the artifact registry.

    :param name: the name of the dataset; also used as the file name
        under which the serialized data is uploaded.
    :param data: the data to be logged, currently support dict only,
        will support more types in the future.
    :raises NotImplementedError: if ``data`` is not a ``dict``.
    """
    runtime = global_runtime()

    if not isinstance(data, dict):
        raise NotImplementedError(
            f"Logging dataset of type {type(data)} is not implemented yet."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        # We chdir into the temp dir so log_artifact receives the bare file
        # name and the artifact is stored under `name`, not a temp path.
        # Remember the current working directory and restore it afterwards:
        # without this the process is left inside a directory that is
        # deleted when the TemporaryDirectory context exits.
        prev_cwd = os.getcwd()
        os.chdir(tmpdir)
        try:
            with open(name, "w") as f:
                f.write(json.dumps(data))

            # Record the on-disk size so it can be surfaced in the metadata.
            file_size = os.path.getsize(name)

            path = await log_artifact(
                paths=name,
                repo_name="dataset",
            )
        finally:
            # Restore the caller's CWD even if serialization/upload fails.
            os.chdir(prev_cwd)

    runtime.metadb.create_dataset(
        name=name,
        team_id=runtime.team_id,
        user_id=runtime.user_id,
        path=path,
        experiment_id=current_exp_id.get(),
        run_id=current_run_id.get(),
        meta={"size": file_size},
    )
89 changes: 89 additions & 0 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CreateTeamInput,
CreateUserInput,
DailyTokenUsage,
Dataset,
Experiment,
GraphQLExperimentType,
GraphQLExperimentTypeEnum,
Expand Down Expand Up @@ -631,6 +632,89 @@ def get_daily_token_usage(
print(f"Failed to fetch daily token usage: {e}")
return []

@staticmethod
def list_datasets(
    team_id: strawberry.ID,
    page: int = 0,
    page_size: int = 20,
    order_by: str = "created_at",
    order_desc: bool = True,
) -> list[Dataset]:
    """Return a page of a team's datasets as GraphQL objects."""
    rows = runtime.storage_runtime().metadb.list_datasets_by_team_id(
        team_id=uuid.UUID(team_id),
        page=page,
        page_size=page_size,
        order_by=order_by,
        order_desc=order_desc,
    )
    results: list[Dataset] = []
    for row in rows:
        results.append(
            Dataset(
                id=row.uuid,
                name=row.name,
                description=row.description,
                meta=row.meta,
                team_id=row.team_id,
                experiment_id=row.experiment_id,
                run_id=row.run_id,
                user_id=row.user_id,
                created_at=row.created_at,
                updated_at=row.updated_at,
            )
        )
    return results

@staticmethod
def get_dataset(id: strawberry.ID) -> Dataset | None:
    """Fetch a single dataset by id, or None when it does not exist."""
    record = runtime.storage_runtime().metadb.get_dataset(dataset_id=uuid.UUID(id))
    # Guard clause: nothing found in the metadata DB.
    if not record:
        return None
    return Dataset(
        id=record.uuid,
        name=record.name,
        description=record.description,
        meta=record.meta,
        team_id=record.team_id,
        experiment_id=record.experiment_id,
        run_id=record.run_id,
        user_id=record.user_id,
        created_at=record.created_at,
        updated_at=record.updated_at,
    )

@staticmethod
def list_datasets_by_experiment(
    experiment_id: strawberry.ID,
    page: int = 0,
    page_size: int = 20,
    order_by: str = "created_at",
    order_desc: bool = True,
) -> list[Dataset]:
    """Return a page of datasets attached to one experiment."""

    def _to_graphql(row) -> Dataset:
        # Map a metadb row onto the GraphQL Dataset type.
        return Dataset(
            id=row.uuid,
            name=row.name,
            description=row.description,
            meta=row.meta,
            team_id=row.team_id,
            experiment_id=row.experiment_id,
            run_id=row.run_id,
            user_id=row.user_id,
            created_at=row.created_at,
            updated_at=row.updated_at,
        )

    metadb = runtime.storage_runtime().metadb
    rows = metadb.list_datasets_by_exp_id(
        exp_id=uuid.UUID(experiment_id),
        page=page,
        page_size=page_size,
        order_by=order_by,
        order_desc=order_desc,
    )
    return [_to_graphql(row) for row in rows]


class GraphQLMutations:
@staticmethod
Expand Down Expand Up @@ -758,3 +842,8 @@ def delete_experiments(experiment_ids: list[strawberry.ID]) -> int:
uuids = [uuid.UUID(exp_id) for exp_id in experiment_ids]
# Soft delete experiments by setting is_del flag
return metadb.delete_experiments(experiment_ids=uuids)

@staticmethod
def delete_dataset(dataset_id: strawberry.ID) -> bool:
    """Delete a dataset by id.

    :param dataset_id: GraphQL ID of the dataset to delete.
    :returns: the metadb layer's success flag.
    """
    metadb = runtime.storage_runtime().metadb
    # Convert the GraphQL ID to uuid.UUID before hitting the metadb,
    # matching the sibling resolvers (get_dataset, delete_experiments),
    # which never pass the raw string through.
    return metadb.delete_dataset(dataset_id=uuid.UUID(dataset_id))
42 changes: 42 additions & 0 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CreateTeamInput,
CreateUserInput,
DailyTokenUsage,
Dataset,
Experiment,
RemoveUserFromTeamInput,
Run,
Expand Down Expand Up @@ -103,6 +104,43 @@ async def artifact_content(
) -> ArtifactContent:
return await GraphQLResolvers.get_artifact_content(str(team_id), tag, repo_name)

# Dataset queries
@strawberry.field
def datasets(
    self,
    team_id: strawberry.ID,
    page: int = 0,
    page_size: int = 20,
    order_by: str = "created_at",
    order_desc: bool = True,
) -> list[Dataset]:
    """Paginated list of datasets owned by *team_id*."""
    query_args = dict(
        team_id=team_id,
        page=page,
        page_size=page_size,
        order_by=order_by,
        order_desc=order_desc,
    )
    return GraphQLResolvers.list_datasets(**query_args)

# Single-dataset lookup by id; delegates straight to the resolver.
dataset: Dataset | None = strawberry.field(resolver=GraphQLResolvers.get_dataset)

@strawberry.field
def datasets_by_experiment(
    self,
    experiment_id: strawberry.ID,
    page: int = 0,
    page_size: int = 20,
    order_by: str = "created_at",
    order_desc: bool = True,
) -> list[Dataset]:
    """Paginated list of datasets logged under one experiment."""
    return GraphQLResolvers.list_datasets_by_experiment(
        experiment_id,
        page,
        page_size,
        order_by,
        order_desc,
    )


@strawberry.type
class Mutation:
Expand Down Expand Up @@ -134,5 +172,9 @@ def delete_experiment(self, experiment_id: strawberry.ID) -> bool:
def delete_experiments(self, experiment_ids: list[strawberry.ID]) -> int:
return GraphQLMutations.delete_experiments(experiment_ids=experiment_ids)

@strawberry.mutation
def delete_dataset(self, dataset_id: strawberry.ID) -> bool:
    """Delete the dataset with *dataset_id*; returns the mutation's success flag."""
    return GraphQLMutations.delete_dataset(dataset_id)


schema = strawberry.Schema(query=Query, mutation=Mutation)
14 changes: 14 additions & 0 deletions alphatrion/server/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,20 @@ class Metric:
created_at: datetime


@strawberry.type
class Dataset:
    """GraphQL representation of a logged dataset record."""

    id: strawberry.ID
    name: str
    description: str | None
    # Arbitrary metadata, e.g. {"size": <file size in bytes>} as written
    # by log_dataset — confirm other producers before relying on keys.
    meta: JSON | None
    team_id: strawberry.ID
    # Optional links back to the experiment/run the dataset was logged from.
    experiment_id: strawberry.ID | None
    run_id: strawberry.ID | None
    user_id: strawberry.ID
    created_at: datetime
    updated_at: datetime


# Input types for mutations
@strawberry.input
class CreateUserInput:
Expand Down
15 changes: 0 additions & 15 deletions alphatrion/snapshot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +0,0 @@
from alphatrion.snapshot.snapshot import (
ExecutionKind,
ExecutionResult,
Metadata,
Spec,
Status,
)

__all__ = [
"ExecutionKind",
"ExecutionResult",
"Metadata",
"Spec",
"Status",
]
Loading
Loading