Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def _start(
name: str,
description: str | None = None,
labels: str | None = None,
tags: list[str] | None = None,
meta: dict | None = None,
params: dict | None = None,
):
Expand Down Expand Up @@ -244,6 +245,7 @@ def _start(
user_id=self._runtime._user_id,
description=description,
labels=labels,
tags=tags,
meta=meta,
params=params,
status=Status.RUNNING,
Expand Down
2 changes: 2 additions & 0 deletions alphatrion/experiment/craft_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def start(
name: str,
description: str | None = None,
labels: str | None = None,
tags: list[str] | None = None,
meta: dict | None = None,
params: dict | None = None,
config: base.ExperimentConfig | None = None,
Expand All @@ -29,6 +30,7 @@ def start(
name=name,
description=description,
labels=labels,
tags=tags,
meta=meta,
params=params,
)
Expand Down
35 changes: 17 additions & 18 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ def list_labels_by_exp_id(experiment_id: strawberry.ID) -> list[Label]:
for label in labels
]

@staticmethod
def list_tags_by_exp_id(experiment_id: strawberry.ID) -> list[str]:
metadb = runtime.storage_runtime().metadb
tags = metadb.list_tags_by_exp_id(experiment_id=experiment_id)
return [t.tag for t in tags]

@staticmethod
def list_experiments(
team_id: strawberry.ID,
Expand All @@ -110,26 +116,19 @@ def list_experiments(
order_desc: bool = True,
label_name: str | None = None,
label_value: str | None = None,
tag: str | None = None,
) -> list[Experiment]:
metadb = runtime.storage_runtime().metadb
if label_name:
exps = metadb.list_exps_by_label(
team_id=uuid.UUID(team_id),
label_name=label_name,
label_value=label_value,
page=page,
page_size=page_size,
order_by=order_by,
order_desc=order_desc,
)
else:
exps = metadb.list_exps_by_team_id(
team_id=uuid.UUID(team_id),
page=page,
page_size=page_size,
order_by=order_by,
order_desc=order_desc,
)
exps = metadb.list_experiments(
team_id=uuid.UUID(team_id),
label_name=label_name,
label_value=label_value,
tag=tag,
page=page,
page_size=page_size,
order_by=order_by,
order_desc=order_desc,
)

return [
Experiment(
Expand Down
2 changes: 2 additions & 0 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def experiments(
order_desc: bool = True,
label_name: str | None = None,
label_value: str | None = None,
tag: str | None = None,
) -> list[Experiment]:
return GraphQLResolvers.list_experiments(
team_id=team_id,
Expand All @@ -47,6 +48,7 @@ def experiments(
order_desc=order_desc,
label_name=label_name,
label_value=label_value,
tag=tag,
)

experiment: Experiment | None = strawberry.field(
Expand Down
6 changes: 6 additions & 0 deletions alphatrion/server/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def labels(self) -> list[Label]:

return GraphQLResolvers.list_labels_by_exp_id(experiment_id=self.id)

@strawberry.field
def tags(self) -> list[str]:
from .resolvers import GraphQLResolvers

return GraphQLResolvers.list_tags_by_exp_id(experiment_id=self.id)

@strawberry.field
def metrics(self) -> list["Metric"]:
from .resolvers import GraphQLResolvers
Expand Down
1 change: 1 addition & 0 deletions alphatrion/storage/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create_experiment(
name: str,
description: str | None = None,
labels: str | None = None,
tags: list[str] | None = None,
meta: dict | None = None,
params: dict | None = None,
) -> int:
Expand Down
21 changes: 21 additions & 0 deletions alphatrion/storage/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,27 @@ class ExperimentLabel(Base):
)


class ExperimentTag(Base):
__tablename__ = "experiment_tags"

uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
team_id = Column(UUID(as_uuid=True), nullable=False)
experiment_id = Column(UUID(as_uuid=True), nullable=False)
tag = Column(String, nullable=False)

created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)

__table_args__ = (
Index("idx_experiment_tag_lookup", "experiment_id", "tag"),
Index("idx_experiment_tag_team", "team_id", "tag"),
)


class Dataset(Base):
__tablename__ = "datasets"

Expand Down
75 changes: 40 additions & 35 deletions alphatrion/storage/sqlstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dataset,
Experiment,
ExperimentLabel,
ExperimentTag,
Metric,
Run,
Status,
Expand Down Expand Up @@ -340,6 +341,7 @@ def create_experiment(
user_id: uuid.UUID,
description: str | None = None,
labels: str | None = None,
tags: list[str] | None = None,
meta: dict | None = None,
params: dict | None = None,
status: Status = Status.PENDING,
Expand Down Expand Up @@ -392,6 +394,16 @@ def create_experiment(
)
session.add(exp_label)

if tags:
for tag in [t.strip() for t in tags]:
if tag:
exp_tag = ExperimentTag(
team_id=team_id,
experiment_id=uid,
tag=tag,
)
session.add(exp_tag)

session.commit()

exp_id = new_exp.uuid
Expand Down Expand Up @@ -430,19 +442,37 @@ def get_exp_by_name(
session.close()
return trial

def list_exps_by_team_id(
def list_experiments(
self,
team_id: uuid.UUID,
label_name: str | None = None,
label_value: str | None = None,
tag: str | None = None,
page: int = 0,
page_size: int = 10,
order_by: str = "created_at",
order_desc: bool = True,
) -> list[Experiment]:
session = self._session()
query = session.query(Experiment).filter(
Experiment.team_id == team_id,
Experiment.is_del == 0,
)

if label_name:
query = query.join(
ExperimentLabel, ExperimentLabel.experiment_id == Experiment.uuid
).filter(ExperimentLabel.label_name == label_name)
if label_value is not None:
query = query.filter(ExperimentLabel.label_value == label_value)

if tag:
query = query.join(
ExperimentTag, ExperimentTag.experiment_id == Experiment.uuid
).filter(ExperimentTag.tag == tag)

exps = (
session.query(Experiment)
.filter(Experiment.team_id == team_id, Experiment.is_del == 0)
.order_by(
query.order_by(
getattr(Experiment, order_by).desc()
if order_desc
else getattr(Experiment, order_by)
Expand All @@ -465,41 +495,16 @@ def list_labels_by_exp_id(self, experiment_id: uuid.UUID) -> list[ExperimentLabe
session.close()
return labels

def list_exps_by_label(
self,
team_id: uuid.UUID,
label_name: str,
label_value: str | None = None,
page: int = 0,
page_size: int = 10,
order_by: str = "created_at",
order_desc: bool = True,
) -> list[Experiment]:
def list_tags_by_exp_id(self, experiment_id: uuid.UUID) -> list[ExperimentTag]:
session = self._session()
query = (
session.query(Experiment)
.join(ExperimentLabel, ExperimentLabel.experiment_id == Experiment.uuid)
.filter(
Experiment.team_id == team_id,
Experiment.is_del == 0,
ExperimentLabel.label_name == label_name,
)
)
if label_value is not None:
query = query.filter(ExperimentLabel.label_value == label_value)

exps = (
query.order_by(
getattr(Experiment, order_by).desc()
if order_desc
else getattr(Experiment, order_by)
)
.offset(page * page_size)
.limit(page_size)
tags = (
session.query(ExperimentTag)
.filter(ExperimentTag.experiment_id == experiment_id)
.order_by(ExperimentTag.created_at.asc())
.all()
)
session.close()
return exps
return tags

def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None:
session = self._session()
Expand Down
6 changes: 4 additions & 2 deletions dashboard/src/lib/graphql-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ export const queries = {
`,

listExperiments: `
query ListExperiments($teamId: ID!, $labelName: String, $labelValue: String, $page: Int, $pageSize: Int) {
experiments(teamId: $teamId, labelName: $labelName, labelValue: $labelValue, page: $page, pageSize: $pageSize) {
query ListExperiments($teamId: ID!, $labelName: String, $labelValue: String, $tag: String, $page: Int, $pageSize: Int) {
experiments(teamId: $teamId, labelName: $labelName, labelValue: $labelValue, tag: $tag, page: $page, pageSize: $pageSize) {
id
teamId
userId
Expand All @@ -162,6 +162,7 @@ export const queries = {
name
value
}
tags
duration
status
createdAt
Expand All @@ -185,6 +186,7 @@ export const queries = {
name
value
}
tags
duration
status
createdAt
Expand Down
Loading
Loading