Skip to content

Commit 77adbf6

Browse files
authored
Merge pull request #1982 from transformerlab/add/experiments-ui
Revamp the experiment list UI
2 parents e33d5b5 + 49c10d8 commit 77adbf6

12 files changed

Lines changed: 965 additions & 82 deletions

File tree

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""add_user_experiment_access_table
2+
3+
Revision ID: 46378c10f132
4+
Revises: 6ccd4a4d9ca1
5+
Create Date: 2026-05-04 13:23:27.122716
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
import sqlalchemy as sa
13+
from transformerlab.db.migration_utils import table_exists
14+
15+
16+
# revision identifiers, used by Alembic.
17+
revision: str = "46378c10f132"
18+
down_revision: Union[str, Sequence[str], None] = "6ccd4a4d9ca1"
19+
branch_labels: Union[str, Sequence[str], None] = None
20+
depends_on: Union[str, Sequence[str], None] = None
21+
22+
23+
def upgrade() -> None:
24+
connection = op.get_bind()
25+
if not table_exists(connection, "user_experiment_access"):
26+
op.create_table(
27+
"user_experiment_access",
28+
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
29+
sa.Column("user_id", sa.String(), nullable=False),
30+
sa.Column("team_id", sa.String(), nullable=False),
31+
sa.Column("experiment_id", sa.String(), nullable=False),
32+
sa.Column(
33+
"last_opened_at",
34+
sa.DateTime(),
35+
server_default=sa.text("(CURRENT_TIMESTAMP)"),
36+
nullable=False,
37+
),
38+
sa.PrimaryKeyConstraint("id"),
39+
sa.UniqueConstraint(
40+
"user_id",
41+
"team_id",
42+
"experiment_id",
43+
name="uq_user_experiment_access",
44+
),
45+
)
46+
op.create_index(
47+
"idx_user_experiment_access_user_team",
48+
"user_experiment_access",
49+
["user_id", "team_id"],
50+
)
51+
52+
53+
def downgrade() -> None:
54+
op.drop_index("idx_user_experiment_access_user_team", table_name="user_experiment_access", if_exists=True)
55+
op.drop_table("user_experiment_access", if_exists=True)

api/test/api/test_experiment_service.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ async def test_missing_experiment_returns_none(tmp_experiments_dir):
4646
assert await experiment_service.experiment_get("no_such_experiment") is None
4747

4848

49+
@pytest.mark.asyncio
50+
async def test_duplicate_experiment_create_raises_file_exists(tmp_experiments_dir):
51+
_ = tmp_experiments_dir
52+
name = f"duplicate_exp_{uuid.uuid4().hex[:8]}"
53+
await experiment_service.experiment_create(name, {"a": 1})
54+
55+
with pytest.raises(FileExistsError):
56+
await experiment_service.experiment_create(name, {"a": 2})
57+
58+
4959
# Added test to hit the new FileNotFoundError except-clauses in experiment_service
5060
@pytest.mark.asyncio
5161
async def test_missing_experiment_operations_handle_FileNotFound(tmp_experiments_dir):
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
3+
from sqlalchemy.exc import IntegrityError
4+
5+
import transformerlab.services.experiment_access_service as svc
6+
7+
8+
async def test_touch_experiment_upserts_record():
9+
mock_session = AsyncMock()
10+
mock_session.add = MagicMock()
11+
mock_result = MagicMock(rowcount=0)
12+
mock_session.execute.return_value = mock_result
13+
14+
await svc.touch_experiment(mock_session, "user1", "team1", "exp1")
15+
16+
mock_session.add.assert_called_once()
17+
mock_session.commit.assert_called_once()
18+
19+
20+
async def test_touch_experiment_updates_existing_record():
21+
mock_session = AsyncMock()
22+
mock_session.add = MagicMock()
23+
mock_result = MagicMock(rowcount=1)
24+
mock_session.execute.return_value = mock_result
25+
26+
await svc.touch_experiment(mock_session, "user1", "team1", "exp1")
27+
28+
mock_session.add.assert_not_called()
29+
mock_session.commit.assert_called_once()
30+
31+
32+
async def test_touch_experiment_handles_insert_race_integrity_error():
33+
mock_session = AsyncMock()
34+
mock_session.add = MagicMock()
35+
mock_result = MagicMock(rowcount=0)
36+
mock_session.execute.return_value = mock_result
37+
mock_session.commit.side_effect = [
38+
IntegrityError("stmt", "params", Exception("duplicate key")),
39+
]
40+
41+
await svc.touch_experiment(mock_session, "user1", "team1", "exp1")
42+
43+
mock_session.rollback.assert_called_once()
44+
45+
46+
async def test_get_recent_experiment_ids_returns_ordered_list():
47+
mock_session = AsyncMock()
48+
record1 = MagicMock()
49+
record1.experiment_id = "exp_b"
50+
record2 = MagicMock()
51+
record2.experiment_id = "exp_a"
52+
mock_result = MagicMock()
53+
mock_result.scalars.return_value.all.return_value = [record1, record2]
54+
mock_session.execute.return_value = mock_result
55+
56+
result = await svc.get_recent_experiment_ids(mock_session, "user1", "team1", limit=3)
57+
58+
assert result == ["exp_b", "exp_a"]

api/transformerlab/routers/experiment/experiment.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from typing import Annotated
44

5-
from fastapi import APIRouter, Body, Depends
5+
from fastapi import APIRouter, Body, Depends, HTTPException
66
from sqlalchemy.ext.asyncio import AsyncSession
77

88
import transformerlab.services.experiment_service as experiment_service
9+
import transformerlab.services.experiment_access_service as access_service
910
from lab import Experiment, storage
1011
from transformerlab.shared import shared
1112
from transformerlab.routers.experiment import (
@@ -16,7 +17,8 @@
1617
)
1718
from transformerlab.routers.auth import get_user_and_team
1819
from transformerlab.services.permission_service import check_permission, get_user_team, require_permission
19-
from transformerlab.shared.models.models import TeamRole
20+
from sqlalchemy import select
21+
from transformerlab.shared.models.models import TeamRole, UserExperimentAccess
2022
from transformerlab.shared.models.user_model import get_async_session
2123

2224
from werkzeug.utils import secure_filename
@@ -54,7 +56,7 @@ async def experiments_get_all(
5456
session: AsyncSession = Depends(get_async_session),
5557
user_and_team: dict = Depends(get_user_and_team),
5658
):
57-
"""Get a list of all experiments"""
59+
"""Get a list of all experiments, filtered by role, with per-user last_opened_at."""
5860
experiments = await experiment_service.experiment_get_all()
5961
user = user_and_team["user"]
6062
team_id = user_and_team["team_id"]
@@ -63,34 +65,99 @@ async def experiments_get_all(
6365
user_team = await get_user_team(session, user_id, team_id)
6466
if user_team is None:
6567
return []
68+
69+
# Role-based filtering (existing logic)
6670
if user_team.role == TeamRole.OWNER.value:
67-
return experiments
68-
69-
filtered_experiments = []
70-
for experiment in experiments:
71-
experiment_id = str(experiment.get("id"))
72-
if not experiment_id:
73-
continue
74-
allowed = await check_permission(
75-
session=session,
76-
user_id=user_id,
77-
team_id=team_id,
78-
resource_type="experiment",
79-
resource_id=experiment_id,
80-
action="read",
81-
user_team=user_team,
71+
filtered = experiments
72+
else:
73+
filtered = []
74+
for experiment in experiments:
75+
experiment_id = str(experiment.get("id"))
76+
if not experiment_id:
77+
continue
78+
allowed = await check_permission(
79+
session=session,
80+
user_id=user_id,
81+
team_id=team_id,
82+
resource_type="experiment",
83+
resource_id=experiment_id,
84+
action="read",
85+
user_team=user_team,
86+
)
87+
if allowed:
88+
filtered.append(experiment)
89+
90+
# Attach per-user last_opened_at
91+
access_records = await session.execute(
92+
select(UserExperimentAccess).where(
93+
UserExperimentAccess.user_id == user_id,
94+
UserExperimentAccess.team_id == team_id,
8295
)
83-
if allowed:
84-
filtered_experiments.append(experiment)
85-
return filtered_experiments
96+
)
97+
access_map = {row.experiment_id: row.last_opened_at.isoformat() for row in access_records.scalars().all()}
98+
99+
for exp in filtered:
100+
exp_id = str(exp.get("id", ""))
101+
exp["last_opened_at"] = access_map.get(exp_id)
102+
103+
return filtered
104+
105+
106+
@router.post("/{id}/touch", summary="Record experiment opened", tags=["experiment"])
107+
async def experiment_touch(
108+
id: str,
109+
session: AsyncSession = Depends(get_async_session),
110+
user_and_team: dict = Depends(get_user_and_team),
111+
_: None = Depends(require_permission("experiment", "read")),
112+
):
113+
user_id = str(user_and_team["user"].id)
114+
team_id = str(user_and_team["team_id"])
115+
await access_service.touch_experiment(session, user_id, team_id, id)
116+
return {"status": "ok"}
117+
118+
119+
@router.get("/recent", summary="Get recently opened experiments", tags=["experiment"])
120+
async def experiments_get_recent(
121+
session: AsyncSession = Depends(get_async_session),
122+
user_and_team: dict = Depends(get_user_and_team),
123+
):
124+
"""Return last 3 experiments opened by the current user that the user still has access to.
125+
Falls back to 3 permitted experiments if no access records exist."""
126+
user = user_and_team["user"]
127+
team_id = str(user_and_team["team_id"])
128+
user_id = str(user.id)
129+
130+
user_team = await get_user_team(session, user_id, team_id)
131+
if user_team is None:
132+
return []
133+
134+
recent_ids = await access_service.get_recent_experiment_ids(session, user_id, team_id, limit=3)
135+
permitted_experiments = await experiments_get_all(session=session, user_and_team=user_and_team)
136+
if not recent_ids:
137+
return permitted_experiments[:3]
138+
139+
permitted_by_id = {
140+
str(exp.get("id")): exp for exp in permitted_experiments if isinstance(exp, dict) and exp.get("id")
141+
}
142+
ordered_recent = [permitted_by_id[exp_id] for exp_id in recent_ids if exp_id in permitted_by_id]
143+
return ordered_recent[:3]
86144

87145

88146
@router.get("/create", summary="Create Experiment", tags=["experiment"])
89-
async def experiments_create(name: str):
147+
async def experiments_create(
148+
name: str,
149+
user_and_team: dict = Depends(get_user_and_team),
150+
):
90151
# Apply secure filename validation to the experiment name
91152
secure_name = secure_filename(name)
153+
if not secure_name:
154+
raise HTTPException(status_code=422, detail="Invalid experiment name")
155+
user_id = str(user_and_team["user"].id)
92156

93-
newid = await experiment_service.experiment_create(secure_name, {})
157+
try:
158+
newid = await experiment_service.experiment_create(secure_name, {}, created_by=user_id)
159+
except FileExistsError as e:
160+
raise HTTPException(status_code=409, detail=f"Experiment '{secure_name}' already exists") from e
94161
return newid
95162

96163

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import logging
2+
from datetime import datetime, timezone
3+
4+
from sqlalchemy import select, update
5+
from sqlalchemy.exc import IntegrityError
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
8+
from transformerlab.shared.models.models import UserExperimentAccess
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
async def touch_experiment(session: AsyncSession, user_id: str, team_id: str, experiment_id: str) -> None:
14+
"""Upsert last_opened_at for a user-experiment pair."""
15+
now = datetime.now(timezone.utc)
16+
result = await session.execute(
17+
update(UserExperimentAccess)
18+
.where(
19+
UserExperimentAccess.user_id == user_id,
20+
UserExperimentAccess.team_id == team_id,
21+
UserExperimentAccess.experiment_id == experiment_id,
22+
)
23+
.values(last_opened_at=now)
24+
)
25+
if result.rowcount == 0:
26+
try:
27+
session.add(
28+
UserExperimentAccess(
29+
user_id=user_id,
30+
team_id=team_id,
31+
experiment_id=experiment_id,
32+
last_opened_at=now,
33+
)
34+
)
35+
await session.commit()
36+
except IntegrityError:
37+
# Another concurrent request inserted first; treat as success.
38+
await session.rollback()
39+
else:
40+
await session.commit()
41+
42+
43+
async def get_recent_experiment_ids(session: AsyncSession, user_id: str, team_id: str, limit: int = 3) -> list[str]:
44+
"""Return experiment IDs ordered by last_opened_at DESC for a user."""
45+
result = await session.execute(
46+
select(UserExperimentAccess)
47+
.where(
48+
UserExperimentAccess.user_id == user_id,
49+
UserExperimentAccess.team_id == team_id,
50+
)
51+
.order_by(UserExperimentAccess.last_opened_at.desc())
52+
.limit(limit)
53+
)
54+
return [row.experiment_id for row in result.scalars().all()]

api/transformerlab/services/experiment_service.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import json
44
import os
55

6+
from sqlalchemy import delete
67
from lab import Experiment
78
from lab import dirs as lab_dirs
89
from lab import storage
910

11+
from transformerlab.db.session import async_session
1012
from transformerlab.services.cache_service import cache, cached
13+
from transformerlab.shared.models.models import UserExperimentAccess
1114

1215
logger = logging.getLogger(__name__)
1316
EXPERIMENT_LIST_CONCURRENCY = max(1, int(os.getenv("TLAB_EXPERIMENT_LIST_CONCURRENCY", "24")))
@@ -100,7 +103,9 @@ async def _read_with_limit(exp_path: str) -> dict | None:
100103
return experiments
101104

102105

103-
async def experiment_create(name: str, config: dict) -> str:
106+
async def experiment_create(name: str, config: dict, created_by: str | None = None) -> str:
107+
if created_by:
108+
config = {**config, "created_by": created_by}
104109
await Experiment.create_with_config(name, config)
105110
# Ensure the experiment dropdown refreshes immediately after creation.
106111
await cache.invalidate("experiments")
@@ -131,6 +136,9 @@ async def experiment_delete(id):
131136
try:
132137
exp = await Experiment.get(id)
133138
await exp.delete()
139+
async with async_session() as session:
140+
await session.execute(delete(UserExperimentAccess).where(UserExperimentAccess.experiment_id == str(id)))
141+
await session.commit()
134142
await cache.invalidate("experiments")
135143
except FileNotFoundError:
136144
print(f"Experiment with id '{id}' not found")

api/transformerlab/shared/models/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,22 @@ class JobQueue(Base):
346346
Index("idx_job_queue_status_type", "status", "queue_type"),
347347
Index("idx_job_queue_job_id", "job_id"),
348348
)
349+
350+
351+
class UserExperimentAccess(Base):
352+
"""Tracks when each user last opened each experiment."""
353+
354+
__tablename__ = "user_experiment_access"
355+
356+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
357+
user_id: Mapped[str] = mapped_column(String, nullable=False)
358+
team_id: Mapped[str] = mapped_column(String, nullable=False)
359+
experiment_id: Mapped[str] = mapped_column(String, nullable=False)
360+
last_opened_at: Mapped[DateTime] = mapped_column(
361+
DateTime, server_default=func.now(), onupdate=func.now(), nullable=False
362+
)
363+
364+
__table_args__ = (
365+
UniqueConstraint("user_id", "team_id", "experiment_id", name="uq_user_experiment_access"),
366+
Index("idx_user_experiment_access_user_team", "user_id", "team_id"),
367+
)

0 commit comments

Comments
 (0)