Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""add_user_experiment_access_table

Revision ID: 46378c10f132
Revises: 6ccd4a4d9ca1
Create Date: 2026-05-04 13:23:27.122716

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from transformerlab.db.migration_utils import table_exists


# revision identifiers, used by Alembic.
revision: str = "46378c10f132"
down_revision: Union[str, Sequence[str], None] = "6ccd4a4d9ca1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
connection = op.get_bind()
if not table_exists(connection, "user_experiment_access"):
op.create_table(
"user_experiment_access",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("team_id", sa.String(), nullable=False),
sa.Column("experiment_id", sa.String(), nullable=False),
sa.Column(
"last_opened_at",
sa.DateTime(),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"user_id",
"team_id",
"experiment_id",
name="uq_user_experiment_access",
),
)
op.create_index(
"idx_user_experiment_access_user_team",
"user_experiment_access",
["user_id", "team_id"],
)


def downgrade() -> None:
op.drop_index("idx_user_experiment_access_user_team", table_name="user_experiment_access", if_exists=True)
op.drop_table("user_experiment_access", if_exists=True)
10 changes: 10 additions & 0 deletions api/test/api/test_experiment_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ async def test_missing_experiment_returns_none(tmp_experiments_dir):
assert await experiment_service.experiment_get("no_such_experiment") is None


@pytest.mark.asyncio
async def test_duplicate_experiment_create_raises_file_exists(tmp_experiments_dir):
_ = tmp_experiments_dir
name = f"duplicate_exp_{uuid.uuid4().hex[:8]}"
await experiment_service.experiment_create(name, {"a": 1})

with pytest.raises(FileExistsError):
await experiment_service.experiment_create(name, {"a": 2})


# Added test to hit the new FileNotFoundError except-clauses in experiment_service
@pytest.mark.asyncio
async def test_missing_experiment_operations_handle_FileNotFound(tmp_experiments_dir):
Expand Down
58 changes: 58 additions & 0 deletions api/test/test_experiment_access_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from unittest.mock import AsyncMock, MagicMock

from sqlalchemy.exc import IntegrityError

import transformerlab.services.experiment_access_service as svc


async def test_touch_experiment_upserts_record():
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_result = MagicMock(rowcount=0)
mock_session.execute.return_value = mock_result

await svc.touch_experiment(mock_session, "user1", "team1", "exp1")

mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()


async def test_touch_experiment_updates_existing_record():
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_result = MagicMock(rowcount=1)
mock_session.execute.return_value = mock_result

await svc.touch_experiment(mock_session, "user1", "team1", "exp1")

mock_session.add.assert_not_called()
mock_session.commit.assert_called_once()


async def test_touch_experiment_handles_insert_race_integrity_error():
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_result = MagicMock(rowcount=0)
mock_session.execute.return_value = mock_result
mock_session.commit.side_effect = [
IntegrityError("stmt", "params", Exception("duplicate key")),
]

await svc.touch_experiment(mock_session, "user1", "team1", "exp1")

mock_session.rollback.assert_called_once()


async def test_get_recent_experiment_ids_returns_ordered_list():
mock_session = AsyncMock()
record1 = MagicMock()
record1.experiment_id = "exp_b"
record2 = MagicMock()
record2.experiment_id = "exp_a"
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [record1, record2]
mock_session.execute.return_value = mock_result

result = await svc.get_recent_experiment_ids(mock_session, "user1", "team1", limit=3)

assert result == ["exp_b", "exp_a"]
113 changes: 90 additions & 23 deletions api/transformerlab/routers/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import Annotated

from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

import transformerlab.services.experiment_service as experiment_service
import transformerlab.services.experiment_access_service as access_service
from lab import Experiment, storage
from transformerlab.shared import shared
from transformerlab.routers.experiment import (
Expand All @@ -16,7 +17,8 @@
)
from transformerlab.routers.auth import get_user_and_team
from transformerlab.services.permission_service import check_permission, get_user_team, require_permission
from transformerlab.shared.models.models import TeamRole
from sqlalchemy import select
from transformerlab.shared.models.models import TeamRole, UserExperimentAccess
from transformerlab.shared.models.user_model import get_async_session

from werkzeug.utils import secure_filename
Expand Down Expand Up @@ -54,7 +56,7 @@ async def experiments_get_all(
session: AsyncSession = Depends(get_async_session),
user_and_team: dict = Depends(get_user_and_team),
):
"""Get a list of all experiments"""
"""Get a list of all experiments, filtered by role, with per-user last_opened_at."""
experiments = await experiment_service.experiment_get_all()
user = user_and_team["user"]
team_id = user_and_team["team_id"]
Expand All @@ -63,34 +65,99 @@ async def experiments_get_all(
user_team = await get_user_team(session, user_id, team_id)
if user_team is None:
return []

# Role-based filtering (existing logic)
if user_team.role == TeamRole.OWNER.value:
return experiments

filtered_experiments = []
for experiment in experiments:
experiment_id = str(experiment.get("id"))
if not experiment_id:
continue
allowed = await check_permission(
session=session,
user_id=user_id,
team_id=team_id,
resource_type="experiment",
resource_id=experiment_id,
action="read",
user_team=user_team,
filtered = experiments
else:
filtered = []
for experiment in experiments:
experiment_id = str(experiment.get("id"))
if not experiment_id:
continue
allowed = await check_permission(
session=session,
user_id=user_id,
team_id=team_id,
resource_type="experiment",
resource_id=experiment_id,
action="read",
user_team=user_team,
)
if allowed:
filtered.append(experiment)

# Attach per-user last_opened_at
access_records = await session.execute(
select(UserExperimentAccess).where(
UserExperimentAccess.user_id == user_id,
UserExperimentAccess.team_id == team_id,
)
if allowed:
filtered_experiments.append(experiment)
return filtered_experiments
)
access_map = {row.experiment_id: row.last_opened_at.isoformat() for row in access_records.scalars().all()}

for exp in filtered:
exp_id = str(exp.get("id", ""))
exp["last_opened_at"] = access_map.get(exp_id)

return filtered


@router.post("/{id}/touch", summary="Record experiment opened", tags=["experiment"])
async def experiment_touch(
id: str,
session: AsyncSession = Depends(get_async_session),
user_and_team: dict = Depends(get_user_and_team),
_: None = Depends(require_permission("experiment", "read")),
):
user_id = str(user_and_team["user"].id)
team_id = str(user_and_team["team_id"])
await access_service.touch_experiment(session, user_id, team_id, id)
return {"status": "ok"}


@router.get("/recent", summary="Get recently opened experiments", tags=["experiment"])
async def experiments_get_recent(
session: AsyncSession = Depends(get_async_session),
user_and_team: dict = Depends(get_user_and_team),
):
"""Return last 3 experiments opened by the current user that the user still has access to.
Falls back to 3 permitted experiments if no access records exist."""
user = user_and_team["user"]
team_id = str(user_and_team["team_id"])
user_id = str(user.id)

user_team = await get_user_team(session, user_id, team_id)
if user_team is None:
return []

recent_ids = await access_service.get_recent_experiment_ids(session, user_id, team_id, limit=3)
permitted_experiments = await experiments_get_all(session=session, user_and_team=user_and_team)
if not recent_ids:
return permitted_experiments[:3]

permitted_by_id = {
str(exp.get("id")): exp for exp in permitted_experiments if isinstance(exp, dict) and exp.get("id")
}
ordered_recent = [permitted_by_id[exp_id] for exp_id in recent_ids if exp_id in permitted_by_id]
return ordered_recent[:3]


@router.get("/create", summary="Create Experiment", tags=["experiment"])
async def experiments_create(name: str):
async def experiments_create(
name: str,
user_and_team: dict = Depends(get_user_and_team),
):
# Apply secure filename validation to the experiment name
secure_name = secure_filename(name)
if not secure_name:
raise HTTPException(status_code=422, detail="Invalid experiment name")
user_id = str(user_and_team["user"].id)

newid = await experiment_service.experiment_create(secure_name, {})
try:
newid = await experiment_service.experiment_create(secure_name, {}, created_by=user_id)
except FileExistsError as e:
raise HTTPException(status_code=409, detail=f"Experiment '{secure_name}' already exists") from e
return newid


Expand Down
54 changes: 54 additions & 0 deletions api/transformerlab/services/experiment_access_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from transformerlab.shared.models.models import UserExperimentAccess

logger = logging.getLogger(__name__)


async def touch_experiment(session: AsyncSession, user_id: str, team_id: str, experiment_id: str) -> None:
"""Upsert last_opened_at for a user-experiment pair."""
now = datetime.now(timezone.utc)
result = await session.execute(
update(UserExperimentAccess)
.where(
UserExperimentAccess.user_id == user_id,
UserExperimentAccess.team_id == team_id,
UserExperimentAccess.experiment_id == experiment_id,
)
.values(last_opened_at=now)
)
if result.rowcount == 0:
try:
session.add(
UserExperimentAccess(
user_id=user_id,
team_id=team_id,
experiment_id=experiment_id,
last_opened_at=now,
)
)
await session.commit()
except IntegrityError:
# Another concurrent request inserted first; treat as success.
await session.rollback()
else:
await session.commit()


async def get_recent_experiment_ids(session: AsyncSession, user_id: str, team_id: str, limit: int = 3) -> list[str]:
"""Return experiment IDs ordered by last_opened_at DESC for a user."""
result = await session.execute(
select(UserExperimentAccess)
.where(
UserExperimentAccess.user_id == user_id,
UserExperimentAccess.team_id == team_id,
)
.order_by(UserExperimentAccess.last_opened_at.desc())
.limit(limit)
)
return [row.experiment_id for row in result.scalars().all()]
10 changes: 9 additions & 1 deletion api/transformerlab/services/experiment_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import json
import os

from sqlalchemy import delete
from lab import Experiment
from lab import dirs as lab_dirs
from lab import storage

from transformerlab.db.session import async_session
from transformerlab.services.cache_service import cache, cached
from transformerlab.shared.models.models import UserExperimentAccess

logger = logging.getLogger(__name__)
EXPERIMENT_LIST_CONCURRENCY = max(1, int(os.getenv("TLAB_EXPERIMENT_LIST_CONCURRENCY", "24")))
Expand Down Expand Up @@ -100,7 +103,9 @@ async def _read_with_limit(exp_path: str) -> dict | None:
return experiments


async def experiment_create(name: str, config: dict) -> str:
async def experiment_create(name: str, config: dict, created_by: str | None = None) -> str:
if created_by:
config = {**config, "created_by": created_by}
await Experiment.create_with_config(name, config)
# Ensure the experiment dropdown refreshes immediately after creation.
await cache.invalidate("experiments")
Expand Down Expand Up @@ -131,6 +136,9 @@ async def experiment_delete(id):
try:
exp = await Experiment.get(id)
await exp.delete()
async with async_session() as session:
await session.execute(delete(UserExperimentAccess).where(UserExperimentAccess.experiment_id == str(id)))
await session.commit()
await cache.invalidate("experiments")
except FileNotFoundError:
print(f"Experiment with id '{id}' not found")
Expand Down
19 changes: 19 additions & 0 deletions api/transformerlab/shared/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,22 @@ class JobQueue(Base):
Index("idx_job_queue_status_type", "status", "queue_type"),
Index("idx_job_queue_job_id", "job_id"),
)


class UserExperimentAccess(Base):
"""Tracks when each user last opened each experiment."""

__tablename__ = "user_experiment_access"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(String, nullable=False)
team_id: Mapped[str] = mapped_column(String, nullable=False)
experiment_id: Mapped[str] = mapped_column(String, nullable=False)
last_opened_at: Mapped[DateTime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now(), nullable=False
)

__table_args__ = (
UniqueConstraint("user_id", "team_id", "experiment_id", name="uq_user_experiment_access"),
Index("idx_user_experiment_access_user_team", "user_id", "team_id"),
)
Loading
Loading