diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index ec11e5f7..f7fe68bc 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -206,6 +206,7 @@ def _start( name: str, description: str | None = None, labels: str | None = None, + tags: list[str] | None = None, meta: dict | None = None, params: dict | None = None, ): @@ -244,6 +245,7 @@ def _start( user_id=self._runtime._user_id, description=description, labels=labels, + tags=tags, meta=meta, params=params, status=Status.RUNNING, diff --git a/alphatrion/experiment/craft_experiment.py b/alphatrion/experiment/craft_experiment.py index 84c9fd7e..d84aca71 100644 --- a/alphatrion/experiment/craft_experiment.py +++ b/alphatrion/experiment/craft_experiment.py @@ -15,6 +15,7 @@ def start( name: str, description: str | None = None, labels: str | None = None, + tags: list[str] | None = None, meta: dict | None = None, params: dict | None = None, config: base.ExperimentConfig | None = None, @@ -29,6 +30,7 @@ def start( name=name, description=description, labels=labels, + tags=tags, meta=meta, params=params, ) diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index 1850a939..f7c8d8cd 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -101,6 +101,12 @@ def list_labels_by_exp_id(experiment_id: strawberry.ID) -> list[Label]: for label in labels ] + @staticmethod + def list_tags_by_exp_id(experiment_id: strawberry.ID) -> list[str]: + metadb = runtime.storage_runtime().metadb + tags = metadb.list_tags_by_exp_id(experiment_id=experiment_id) + return [t.tag for t in tags] + @staticmethod def list_experiments( team_id: strawberry.ID, @@ -110,26 +116,19 @@ def list_experiments( order_desc: bool = True, label_name: str | None = None, label_value: str | None = None, + tag: str | None = None, ) -> list[Experiment]: metadb = runtime.storage_runtime().metadb - if label_name: - exps = metadb.list_exps_by_label( - team_id=uuid.UUID(team_id), - label_name=label_name, - label_value=label_value, - page=page, - page_size=page_size, - order_by=order_by, - order_desc=order_desc, - ) - else: - exps = metadb.list_exps_by_team_id( - team_id=uuid.UUID(team_id), - page=page, - page_size=page_size, - order_by=order_by, - order_desc=order_desc, - ) + exps = metadb.list_experiments( + team_id=uuid.UUID(team_id), + label_name=label_name, + label_value=label_value, + tag=tag, + page=page, + page_size=page_size, + order_by=order_by, + order_desc=order_desc, + ) return [ Experiment( diff --git a/alphatrion/server/graphql/schema.py b/alphatrion/server/graphql/schema.py index 8c1cad47..d7c25d49 100644 --- a/alphatrion/server/graphql/schema.py +++ b/alphatrion/server/graphql/schema.py @@ -38,6 +38,7 @@ def experiments( order_desc: bool = True, label_name: str | None = None, label_value: str | None = None, + tag: str | None = None, ) -> list[Experiment]: return GraphQLResolvers.list_experiments( team_id=team_id, @@ -47,6 +48,7 @@ def experiments( order_desc=order_desc, label_name=label_name, label_value=label_value, + tag=tag, ) experiment: Experiment | None = strawberry.field( diff --git a/alphatrion/server/graphql/types.py b/alphatrion/server/graphql/types.py index d65c41cb..fb9b5521 100644 --- a/alphatrion/server/graphql/types.py +++ b/alphatrion/server/graphql/types.py @@ -150,6 +150,12 @@ def labels(self) -> list[Label]: return GraphQLResolvers.list_labels_by_exp_id(experiment_id=self.id) + @strawberry.field + def tags(self) -> list[str]: + from .resolvers import GraphQLResolvers + + return GraphQLResolvers.list_tags_by_exp_id(experiment_id=self.id) + @strawberry.field def metrics(self) -> list["Metric"]: from .resolvers import GraphQLResolvers diff --git a/alphatrion/storage/metastore.py b/alphatrion/storage/metastore.py index f0c9d751..785adf32 100644 --- a/alphatrion/storage/metastore.py +++ b/alphatrion/storage/metastore.py @@ -73,6 +73,7 @@ def create_experiment( name: str, description: str | None = None, labels: str | None = None, + tags: list[str] | None = None, meta: dict | None = None, params: dict | None = None, ) -> int: diff --git a/alphatrion/storage/sql_models.py b/alphatrion/storage/sql_models.py index 9ede15fe..5c0d8597 100644 --- a/alphatrion/storage/sql_models.py +++ b/alphatrion/storage/sql_models.py @@ -303,6 +303,27 @@ class ExperimentLabel(Base): ) +class ExperimentTag(Base): + __tablename__ = "experiment_tags" + + uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + team_id = Column(UUID(as_uuid=True), nullable=False) + experiment_id = Column(UUID(as_uuid=True), nullable=False) + tag = Column(String, nullable=False) + + created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + ) + + __table_args__ = ( + Index("idx_experiment_tag_lookup", "experiment_id", "tag"), + Index("idx_experiment_tag_team", "team_id", "tag"), + ) + + class Dataset(Base): __tablename__ = "datasets" diff --git a/alphatrion/storage/sqlstore.py b/alphatrion/storage/sqlstore.py index 1db83359..2c9648ce 100644 --- a/alphatrion/storage/sqlstore.py +++ b/alphatrion/storage/sqlstore.py @@ -10,6 +10,7 @@ Dataset, Experiment, ExperimentLabel, + ExperimentTag, Metric, Run, Status, @@ -340,6 +341,7 @@ def create_experiment( user_id: uuid.UUID, description: str | None = None, labels: str | None = None, + tags: list[str] | None = None, meta: dict | None = None, params: dict | None = None, status: Status = Status.PENDING, @@ -392,6 +394,16 @@ def create_experiment( ) session.add(exp_label) + if tags: + for tag in [t.strip() for t in tags]: + if tag: + exp_tag = ExperimentTag( + team_id=team_id, + experiment_id=uid, + tag=tag, + ) + session.add(exp_tag) + session.commit() exp_id = new_exp.uuid @@ -430,19 +442,37 @@ def get_exp_by_name( session.close() return trial - def list_exps_by_team_id( + def list_experiments( self, team_id: uuid.UUID, + label_name: str | None = None, + label_value: str | None = None, + tag: str | None = None, page: int = 0, page_size: int = 10, order_by: str = "created_at", order_desc: bool = True, ) -> list[Experiment]: session = self._session() + query = session.query(Experiment).filter( + Experiment.team_id == team_id, + Experiment.is_del == 0, + ) + + if label_name: + query = query.join( + ExperimentLabel, ExperimentLabel.experiment_id == Experiment.uuid + ).filter(ExperimentLabel.label_name == label_name) + if label_value is not None: + query = query.filter(ExperimentLabel.label_value == label_value) + + if tag: + query = query.join( + ExperimentTag, ExperimentTag.experiment_id == Experiment.uuid + ).filter(ExperimentTag.tag == tag) + exps = ( - session.query(Experiment) - .filter(Experiment.team_id == team_id, Experiment.is_del == 0) - .order_by( + query.order_by( getattr(Experiment, order_by).desc() if order_desc else getattr(Experiment, order_by) @@ -465,41 +495,16 @@ def list_labels_by_exp_id(self, experiment_id: uuid.UUID) -> list[ExperimentLabe session.close() return labels - def list_exps_by_label( - self, - team_id: uuid.UUID, - label_name: str, - label_value: str | None = None, - page: int = 0, - page_size: int = 10, - order_by: str = "created_at", - order_desc: bool = True, - ) -> list[Experiment]: + def list_tags_by_exp_id(self, experiment_id: uuid.UUID) -> list[ExperimentTag]: session = self._session() - query = ( - session.query(Experiment) - .join(ExperimentLabel, ExperimentLabel.experiment_id == Experiment.uuid) - .filter( - Experiment.team_id == team_id, - Experiment.is_del == 0, - ExperimentLabel.label_name == label_name, - ) - ) - if label_value is not None: - query = query.filter(ExperimentLabel.label_value == label_value) - - exps = ( - query.order_by( - getattr(Experiment, order_by).desc() - if order_desc - else getattr(Experiment, order_by) - ) - .offset(page * page_size) - .limit(page_size) + tags = ( + session.query(ExperimentTag) + .filter(ExperimentTag.experiment_id == experiment_id) + .order_by(ExperimentTag.created_at.asc()) .all() ) session.close() - return exps + return tags def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None: session = self._session() diff --git a/dashboard/src/lib/graphql-client.ts b/dashboard/src/lib/graphql-client.ts index 3a892395..6e1fe4b4 100644 --- a/dashboard/src/lib/graphql-client.ts +++ b/dashboard/src/lib/graphql-client.ts @@ -148,8 +148,8 @@ export const queries = { `, listExperiments: ` - query ListExperiments($teamId: ID!, $labelName: String, $labelValue: String, $page: Int, $pageSize: Int) { - experiments(teamId: $teamId, labelName: $labelName, labelValue: $labelValue, page: $page, pageSize: $pageSize) { + query ListExperiments($teamId: ID!, $labelName: String, $labelValue: String, $tag: String, $page: Int, $pageSize: Int) { + experiments(teamId: $teamId, labelName: $labelName, labelValue: $labelValue, tag: $tag, page: $page, pageSize: $pageSize) { id teamId userId @@ -162,6 +162,7 @@ export const queries = { name value } + tags duration status createdAt @@ -185,6 +186,7 @@ export const queries = { name value } + tags duration status createdAt diff --git a/dashboard/src/pages/experiments/index.tsx b/dashboard/src/pages/experiments/index.tsx index e5adc945..3232f8cf 100644 --- a/dashboard/src/pages/experiments/index.tsx +++ b/dashboard/src/pages/experiments/index.tsx @@ -78,12 +78,15 @@ const LABEL_COLORS = [ { bg: 'bg-stone-100', text: 'text-stone-700', border: 'border-stone-300' }, ]; +const TAG_COLOR = { bg: 'bg-stone-100', text: 'text-stone-700', arrow: 'border-r-stone-100' }; + const PAGE_SIZE = 10; export function ExperimentsPage() { const { selectedTeamId } = useTeamContext(); const [statusFilter, setStatusFilter] = useState('ALL'); const [labelFilters, setLabelFilters] = useState([]); + const [tagFilters, setTagFilters] = useState([]); const [searchQuery, setSearchQuery] = useState(''); const [currentPage, setCurrentPage] = useState(0); const [selectedExperiments, setSelectedExperiments] = useState>(new Set()); @@ -172,6 +175,22 @@ export function ExperimentsPage() { return options; }, [experiments]); + // Build tag options from all experiments + const tagOptions = useMemo(() => { + if (!experiments || experiments.length === 0) return []; + + const uniqueTags = new Set(); + experiments.forEach(exp => { + exp.tags?.forEach(tag => uniqueTags.add(tag)); + }); + + return Array.from(uniqueTags).sort().map(tag => ({ + value: tag, + label: tag, + group: 'Tags', + })); + }, [experiments]); + // Filter and sort experiments const filteredExperiments = useMemo(() => { if (!experiments) return []; @@ -189,7 +208,8 @@ export function ExperimentsPage() { exp.labels?.some(label => label.name.toLowerCase().includes(query) || label.value.toLowerCase().includes(query) - ) + ) || + exp.tags?.some(tag => tag.toLowerCase().includes(query)) ); } @@ -218,11 +238,18 @@ export function ExperimentsPage() { }); } + // Apply tag filters (AND logic - experiment must have ALL selected tags) + if (tagFilters.length > 0) { + filtered = filtered.filter(exp => + tagFilters.every(tag => exp.tags?.includes(tag)) + ); + } + // Sort by creation time descending (newest first) filtered.sort((a, b) => new Date(b.createdAt).getTime() - new Date(a.createdAt).getTime()); return filtered; - }, [experiments, statusFilter, labelFilters, searchQuery]); + }, [experiments, statusFilter, labelFilters, tagFilters, searchQuery]); // Check if all filtered experiments are selected const allSelected = filteredExperiments.length > 0 && @@ -301,10 +328,21 @@ export function ExperimentsPage() { values={labelFilters} onChange={(values) => setLabelFilters(values)} options={labelOptions} - className="w-64" + className="w-48" placeholder="Filter by labels..." /> + {/* Tag Filter */} + {tagOptions.length > 0 && ( + setTagFilters(values)} + options={tagOptions} + className="w-48" + placeholder="Filter by tags..." + /> + )} + {/* Status Filter */}
- {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 ? ( + {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 || tagFilters.length > 0 ? ( ) : ( )}

- {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 + {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 || tagFilters.length > 0 ? 'No experiments match your filters' : 'No experiments found'}

- {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 + {searchQuery.trim() || statusFilter !== 'ALL' || labelFilters.length > 0 || tagFilters.length > 0 ? 'Try adjusting your filters or search query' : 'Experiments will appear here once created'}

@@ -369,6 +407,7 @@ export function ExperimentsPage() { Name Labels + Tags Status Created @@ -418,6 +457,25 @@ export function ExperimentsPage() { - )} + + {experiment.tags && experiment.tags.length > 0 ? ( +
+ {experiment.tags.map((tag, idx) => ( + + + + {tag} + + + ))} +
+ ) : ( + - + )} +
{experiment.status} diff --git a/dashboard/src/types/index.ts b/dashboard/src/types/index.ts index 1605574d..20b15783 100644 --- a/dashboard/src/types/index.ts +++ b/dashboard/src/types/index.ts @@ -64,6 +64,7 @@ export interface Experiment { meta: Record | null; params: Record | null; labels: Label[]; + tags: string[]; duration: number; status: Status; createdAt: string; diff --git a/migrations/versions/faa471d8accb_add_new_table_tags.py b/migrations/versions/faa471d8accb_add_new_table_tags.py new file mode 100644 index 00000000..e08329e6 --- /dev/null +++ b/migrations/versions/faa471d8accb_add_new_table_tags.py @@ -0,0 +1,34 @@ +"""add new table tags + +Revision ID: faa471d8accb +Revises: 467107424ef6 +Create Date: 2026-03-12 23:51:17.473606 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'faa471d8accb' +down_revision: Union[str, Sequence[str], None] = '467107424ef6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('unique_experiment_tag'), 'experiment_tags', type_='unique') + op.create_index('idx_experiment_tag_lookup', 'experiment_tags', ['experiment_id', 'tag'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('idx_experiment_tag_lookup', table_name='experiment_tags') + op.create_unique_constraint(op.f('unique_experiment_tag'), 'experiment_tags', ['experiment_id', 'tag'], postgresql_nulls_not_distinct=False) + # ### end Alembic commands ### diff --git a/tests/unit/experiment/test_experimant.py b/tests/unit/experiment/test_experiment.py similarity index 93% rename from tests/unit/experiment/test_experimant.py rename to tests/unit/experiment/test_experiment.py index fe075d2d..8047f887 100644 --- a/tests/unit/experiment/test_experimant.py +++ b/tests/unit/experiment/test_experiment.py @@ -363,10 +363,39 @@ async def test_experiment_with_labels(): exp_obj = exp._get_obj() assert exp_obj is not None - exp_labels = exp._runtime.metadb.list_exps_by_label( + exp_labels = exp._runtime.metadb.list_experiments( team_id=team_id, label_name="foo", label_value="bar", ) assert len(exp_labels) == 1 + + +@pytest.mark.asyncio +async def test_experiment_with_tags(): + team_id = uuid.uuid4() + user_id = uuid.uuid4() + init( + team_id=team_id, + user_id=user_id, + ) + + async with CraftExperiment.start( + name="first-experiment", + tags=["foo", "bar"], + ) as exp: + exp_obj = exp._get_obj() + assert exp_obj is not None + + exp_tags = exp._runtime.metadb.list_experiments( + team_id=team_id, + tag="foo", + ) + + assert len(exp_tags) == 1 + + all_tags = exp._runtime.metadb.list_tags_by_exp_id( + experiment_id=exp.id, + ) + assert len(all_tags) == 2