diff --git a/backend/recotem/tests/test_assign_owners_command.py b/backend/recotem/tests/test_assign_owners_command.py new file mode 100644 index 00000000..f433039b --- /dev/null +++ b/backend/recotem/tests/test_assign_owners_command.py @@ -0,0 +1,80 @@ +"""Tests for the assign_owners management command.""" + +from io import StringIO + +import pytest +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError + +from recotem.api.models import EvaluationConfig, Project, SplitConfig + +User = get_user_model() + + +@pytest.fixture +def target_user(db): + return User.objects.create_user(username="assign_target", password="pass") + + +@pytest.mark.django_db +class TestAssignOwnersCommand: + def test_assigns_owner_to_unowned_projects(self, target_user): + p = Project.objects.create( + name="unowned", user_column="u", item_column="i", owner=None + ) + out = StringIO() + call_command("assign_owners", "--user", target_user.username, stdout=out) + + p.refresh_from_db() + assert p.owner == target_user + assert "assigned" in out.getvalue() + + def test_assigns_created_by_to_configs(self, target_user): + sc = SplitConfig.objects.create(created_by=None) + ec = EvaluationConfig.objects.create(created_by=None) + + out = StringIO() + call_command("assign_owners", "--user", target_user.username, stdout=out) + + sc.refresh_from_db() + ec.refresh_from_db() + assert sc.created_by == target_user + assert ec.created_by == target_user + + def test_dry_run_no_changes(self, target_user): + p = Project.objects.create( + name="dry_run_proj", user_column="u", item_column="i", owner=None + ) + sc = SplitConfig.objects.create(created_by=None) + + out = StringIO() + call_command( + "assign_owners", "--user", target_user.username, "--dry-run", stdout=out + ) + + p.refresh_from_db() + sc.refresh_from_db() + assert p.owner is None + assert sc.created_by is None + assert "Dry run complete" in out.getvalue() + + def test_invalid_user_raises(self): + with pytest.raises(CommandError, match="does not exist"): + call_command("assign_owners", "--user", "ghost_user") + + def test_no_unowned_records(self, target_user): + # Create records that already have owners + Project.objects.create( + name="owned_proj", + user_column="u", + item_column="i", + owner=target_user, + ) + SplitConfig.objects.create(created_by=target_user) + EvaluationConfig.objects.create(created_by=target_user) + + out = StringIO() + call_command("assign_owners", "--user", target_user.username, stdout=out) + + assert "no unowned records" in out.getvalue() diff --git a/backend/recotem/tests/test_authentication.py b/backend/recotem/tests/test_authentication.py index ed96aa3c..bd42b6b8 100644 --- a/backend/recotem/tests/test_authentication.py +++ b/backend/recotem/tests/test_authentication.py @@ -2,7 +2,9 @@ import pytest from django.contrib.auth import get_user_model +from django.contrib.auth.hashers import make_password from django.test import RequestFactory +from rest_framework.exceptions import AuthenticationFailed from recotem.api.authentication import ( API_KEY_PREFIX, @@ -143,3 +145,76 @@ def test_authenticate_header(self): factory = RequestFactory() request = factory.get("/") assert auth.authenticate_header(request) == "X-API-Key" + + +@pytest.mark.django_db +class TestRequireManagementScope: + """Test RequireManagementScope permission class.""" + + def test_read_scope_for_get(self, api_key_data, user): + """GET requires 'read' scope.""" + from recotem.api.authentication import RequireManagementScope + + full_key, key_obj = api_key_data + # key_obj has scopes=["read", "predict"] + factory = RequestFactory() + request = factory.get("/") + request.api_key = key_obj + request.user = user + perm = RequireManagementScope() + assert perm.has_permission(request, None) is True + + def test_write_scope_for_post(self, api_key_data, user): + """POST requires 'write' scope -- key only has read+predict, should fail.""" + from recotem.api.authentication import RequireManagementScope + + full_key, key_obj = api_key_data + factory = RequestFactory() + request = factory.post("/") + request.api_key = key_obj + request.user = user + perm = RequireManagementScope() + assert perm.has_permission(request, None) is False + + def test_jwt_always_allowed(self, user): + """JWT user (no api_key attr) passes all scope checks.""" + from recotem.api.authentication import RequireManagementScope + + factory = RequestFactory() + request = factory.post("/") + request.user = user + # No api_key attribute -> JWT + perm = RequireManagementScope() + assert perm.has_permission(request, None) is True + + +@pytest.mark.django_db +class TestAmbiguousApiKeyPrefix: + def test_ambiguous_prefix(self, user, project): + """Two keys with same prefix -> 'Ambiguous API key prefix'.""" + # Create two keys with the same prefix + prefix = "SAMEPRFX" + ApiKey.objects.create( + project=project, + owner=user, + name="key1", + key_prefix=prefix, + hashed_key=make_password("dummy1"), + scopes=["read"], + ) + ApiKey.objects.create( + project=project, + owner=user, + name="key2", + key_prefix=prefix, + hashed_key=make_password("dummy2"), + scopes=["read"], + ) + + factory = RequestFactory() + request = factory.get( + "/", HTTP_X_API_KEY=f"{API_KEY_PREFIX}{prefix}longenoughkey" + ) + auth = ApiKeyAuthentication() + with pytest.raises(AuthenticationFailed, match="Ambiguous"): + auth.authenticate(request) diff --git a/backend/recotem/tests/test_create_api_key_command.py b/backend/recotem/tests/test_create_api_key_command.py new file mode 100644 index 00000000..5e259966 --- /dev/null +++ b/backend/recotem/tests/test_create_api_key_command.py @@ -0,0 +1,141 @@ +"""Tests for the create_api_key management command.""" + +from io import StringIO + +import pytest +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError +from django.utils import timezone + +from recotem.api.models import ApiKey, Project + +User = get_user_model() + + +@pytest.fixture +def admin_user(db): + return User.objects.create_user(username="admin", password="pass") + + +@pytest.fixture +def project(admin_user): + return Project.objects.create( + name="test_project", user_column="user", item_column="item", owner=admin_user + ) + + +@pytest.mark.django_db +class TestCreateApiKeyCommand: + def test_creates_key_prints_to_stdout(self, project, admin_user): + out = StringIO() + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "my-key", + "--owner", + admin_user.username, + stdout=out, + ) + raw_key = out.getvalue().strip() + assert raw_key.startswith("rctm_") + assert ApiKey.objects.filter(project=project, name="my-key").exists() + + def test_invalid_project_raises(self, admin_user): + with pytest.raises(CommandError, match="not found"): + call_command( + "create_api_key", + "--project-id", + "99999", + "--name", + "bad-key", + "--owner", + admin_user.username, + ) + + def test_invalid_owner_raises(self, project): + with pytest.raises(CommandError, match="not found"): + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "bad-key", + "--owner", + "nonexistent_user", + ) + + def test_invalid_scope_raises(self, project, admin_user): + with pytest.raises(CommandError, match="Invalid scope"): + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "bad-scope-key", + "--scopes", + "predict,badscope", + "--owner", + admin_user.username, + ) + + def test_owner_project_mismatch_raises(self, project): + other_user = User.objects.create_user(username="other", password="pass") + with pytest.raises(CommandError, match="does not match project owner"): + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "mismatch-key", + "--owner", + other_user.username, + ) + + def test_duplicate_name_raises(self, project, admin_user): + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "dup-key", + "--owner", + admin_user.username, + stdout=StringIO(), + ) + with pytest.raises(CommandError, match="already exists"): + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "dup-key", + "--owner", + admin_user.username, + ) + + def test_expires_in_days_sets_expiry(self, project, admin_user): + before = timezone.now() + call_command( + "create_api_key", + "--project-id", + str(project.id), + "--name", + "expiring-key", + "--expires-in-days", + "30", + "--owner", + admin_user.username, + stdout=StringIO(), + ) + after = timezone.now() + + key = ApiKey.objects.get(project=project, name="expiring-key") + assert key.expires_at is not None + from datetime import timedelta + + assert ( + before + timedelta(days=30) <= key.expires_at <= after + timedelta(days=30) + ) diff --git a/backend/recotem/tests/test_create_test_users_command.py b/backend/recotem/tests/test_create_test_users_command.py new file mode 100644 index 00000000..56351f01 --- /dev/null +++ b/backend/recotem/tests/test_create_test_users_command.py @@ -0,0 +1,33 @@ +"""Tests for the create_test_users management command.""" + +import pytest +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.core.management.base import CommandError + +User = get_user_model() + + +@pytest.mark.django_db +class TestCreateTestUsersCommand: + def test_creates_new_user(self): + call_command("create_test_users", "--user", "newuser:secret123") + + user = User.objects.get(username="newuser") + assert user.check_password("secret123") + + def test_updates_existing_user(self): + User.objects.create_user(username="existing", password="oldpass") + call_command("create_test_users", "--user", "existing:newpass") + + user = User.objects.get(username="existing") + assert user.check_password("newpass") + assert not user.check_password("oldpass") + + def test_invalid_format_raises(self): + with pytest.raises(CommandError, match="Expected format"): + call_command("create_test_users", "--user", "nocolonhere") + + def test_empty_username_raises(self): + with pytest.raises(CommandError, match="Username and password are required"): + call_command("create_test_users", "--user", ":password") diff --git a/backend/recotem/tests/test_data_post.py b/backend/recotem/tests/test_data_post.py index 5ba74fdb..b19a09f1 100644 --- a/backend/recotem/tests/test_data_post.py +++ b/backend/recotem/tests/test_data_post.py @@ -11,6 +11,17 @@ from recotem.api.models import TrainingData + +@pytest.fixture(autouse=True) +def _use_locmem_cache(settings): + """Use in-memory cache so tests work without Redis.""" + settings.CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + } + } + + I_O_functions: list[ tuple[ str, @@ -334,3 +345,91 @@ def test_metadata_post( assert "URL" in columns assert "movieId" not in columns json_file.close() + + +@pytest.mark.django_db +class TestTrainingDataPreview: + # These tests need a project and training data with an actual file + + def test_preview_returns_data(self, client, ml100k): + """Preview returns columns, rows, total_rows.""" + login_client(client) + project_url = reverse("project-list") + resp = client.post( + project_url, + dict(name="preview_project", user_column="userId", item_column="movieId"), + ) + project_id = resp.json()["id"] + + data_url = reverse("training_data-list") + csv_file = NamedTemporaryFile(suffix=".csv") + ml100k.to_csv(csv_file, index=False) + csv_file.seek(0) + resp = client.post(data_url, dict(project=project_id, file=csv_file)) + assert resp.status_code == 201 + data_id = resp.json()["id"] + + preview_url = reverse("training_data-preview", args=[data_id]) + resp = client.get(preview_url) + assert resp.status_code == 200 + data = resp.json() + assert "columns" in data + assert "rows" in data + assert "total_rows" in data + assert len(data["columns"]) > 0 + assert data["total_rows"] > 0 + + def test_preview_n_rows_param(self, client, ml100k): + """n_rows limits returned rows.""" + login_client(client) + project_url = reverse("project-list") + resp = client.post( + project_url, + dict(name="preview_nrows", user_column="userId", item_column="movieId"), + ) + project_id = resp.json()["id"] + + data_url = reverse("training_data-list") + csv_file = NamedTemporaryFile(suffix=".csv") + ml100k.to_csv(csv_file, index=False) + csv_file.seek(0) + resp = client.post(data_url, dict(project=project_id, file=csv_file)) + data_id = resp.json()["id"] + + preview_url = reverse("training_data-preview", args=[data_id]) + resp = client.get(preview_url, {"n_rows": 5}) + assert resp.status_code == 200 + assert resp.json()["total_rows"] <= 5 + + def test_preview_unauthenticated_401(self, client): + """Unauthenticated -> 401.""" + # Try accessing any preview URL without login + # Use a fake ID since we want auth to fail first + resp = client.get("/api/v1/training_data/99999/preview/") + assert resp.status_code in (401, 403) + + def test_preview_cross_owner_denied(self, client, ml100k): + """User B can't preview User A's data.""" + login_client(client) # logs in as "admin" + project_url = reverse("project-list") + resp = client.post( + project_url, + dict(name="preview_cross", user_column="userId", item_column="movieId"), + ) + project_id = resp.json()["id"] + + data_url = reverse("training_data-list") + csv_file = NamedTemporaryFile(suffix=".csv") + ml100k.to_csv(csv_file, index=False) + csv_file.seek(0) + resp = client.post(data_url, dict(project=project_id, file=csv_file)) + data_id = resp.json()["id"] + + # Login as different user + User = get_user_model() + other = User.objects.create_user(username="other_preview_user", password="pass") + client.force_login(other) + + preview_url = reverse("training_data-preview", args=[data_id]) + resp = client.get(preview_url) + assert resp.status_code == 404 # OwnedResourceMixin filters it out diff --git a/backend/recotem/tests/test_event_views.py b/backend/recotem/tests/test_event_views.py index 35c23b86..6ef58350 100644 --- a/backend/recotem/tests/test_event_views.py +++ b/backend/recotem/tests/test_event_views.py @@ -6,7 +6,9 @@ from django.test import Client from django.urls import reverse +from recotem.api.authentication import generate_api_key from recotem.api.models import ( + ApiKey, ConversionEvent, DeploymentSlot, ModelConfiguration, @@ -18,6 +20,16 @@ User = get_user_model() +@pytest.fixture(autouse=True) +def _use_locmem_cache(settings): + """Use in-memory cache so tests work without Redis.""" + settings.CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + } + } + + @pytest.fixture def user(db): return User.objects.create_user(username="event_user", password="pass") @@ -144,3 +156,146 @@ def test_unauthenticated_cannot_list(self, client: Client): url = reverse("conversion_event-list") resp = client.get(url) assert resp.status_code in (401, 403) + + +@pytest.fixture +def other_user(db): + return User.objects.create_user(username="other_event_user", password="pass") + + +@pytest.fixture +def other_project(other_user): + return Project.objects.create( + name="other_project", user_column="u", item_column="i", owner=other_user + ) + + +@pytest.fixture +def api_key_data(user, project): + full_key, prefix, hashed = generate_api_key() + key_obj = ApiKey.objects.create( + project=project, + owner=user, + name="event-key", + key_prefix=prefix, + hashed_key=hashed, + scopes=["predict"], + ) + return full_key, key_obj + + +@pytest.fixture +def other_slot(other_project): + mc = ModelConfiguration.objects.create( + name="other_cfg", + project=other_project, + recommender_class_name="IALSRecommender", + parameters_json={}, + ) + td = TrainingData.objects.create(project=other_project) + td.file.save("other_data.csv", ContentFile(b"u,i\n1,2\n")) + tm = TrainedModel.objects.create(configuration=mc, data_loc=td) + return DeploymentSlot.objects.create( + project=other_project, name="other_slot", trained_model=tm, weight=100 + ) + + +@pytest.mark.django_db +class TestConversionEventAccessControl: + def test_api_key_wrong_project_rejected( + self, client: Client, api_key_data, other_project, slot + ): + """API key for project A can't create event for project B.""" + full_key, key_obj = api_key_data + url = reverse("conversion_event-list") + resp = client.post( + url, + { + "project": other_project.id, + "deployment_slot": slot.id, + "user_id": "user-1", + "event_type": "impression", + }, + content_type="application/json", + HTTP_X_API_KEY=full_key, + ) + assert resp.status_code == 403 + + def test_jwt_non_owner_rejected(self, client: Client, other_user, project, slot): + """JWT user who doesn't own project gets 403.""" + client.force_login(other_user) + url = reverse("conversion_event-list") + resp = client.post( + url, + { + "project": project.id, + "deployment_slot": slot.id, + "user_id": "user-1", + "event_type": "impression", + }, + content_type="application/json", + ) + assert resp.status_code == 403 + + def test_slot_project_mismatch_rejected(self, auth_client, project, other_slot): + """Slot from other_project used with project should be rejected.""" + url = reverse("conversion_event-list") + resp = auth_client.post( + url, + { + "project": project.id, + "deployment_slot": other_slot.id, + "user_id": "user-1", + "event_type": "impression", + }, + content_type="application/json", + ) + assert resp.status_code == 403 + + def test_api_key_predict_scope_allowed( + self, client: Client, api_key_data, project, slot + ): + """API key with predict scope can create event for its own project.""" + full_key, key_obj = api_key_data + url = reverse("conversion_event-list") + resp = client.post( + url, + { + "project": project.id, + "deployment_slot": slot.id, + "user_id": "user-1", + "event_type": "impression", + }, + content_type="application/json", + HTTP_X_API_KEY=full_key, + ) + assert resp.status_code == 201 + + def test_batch_validates_each_event( + self, client: Client, api_key_data, project, slot, other_project + ): + """Batch rejects if any event references the wrong project.""" + full_key, key_obj = api_key_data + url = reverse("conversion_event-batch") + resp = client.post( + url, + { + "events": [ + { + "project": project.id, + "deployment_slot": slot.id, + "user_id": "u1", + "event_type": "impression", + }, + { + "project": other_project.id, + "deployment_slot": slot.id, + "user_id": "u2", + "event_type": "click", + }, + ] + }, + content_type="application/json", + HTTP_X_API_KEY=full_key, + ) + assert resp.status_code == 403 diff --git a/backend/recotem/tests/test_project_service.py b/backend/recotem/tests/test_project_service.py new file mode 100644 index 00000000..4cb870cd --- /dev/null +++ b/backend/recotem/tests/test_project_service.py @@ -0,0 +1,147 @@ +"""Unit tests for project_service.py — project lookup and summary aggregation.""" + +import pytest +from django.contrib.auth import get_user_model +from django.core.files.base import ContentFile + +from recotem.api.exceptions import ResourceNotFoundError +from recotem.api.models import ( + EvaluationConfig, + ModelConfiguration, + ParameterTuningJob, + Project, + SplitConfig, + TrainedModel, + TrainingData, +) +from recotem.api.services.project_service import get_project_or_404, get_project_summary + +User = get_user_model() + + +@pytest.fixture(autouse=True) +def _use_locmem_cache(settings): + """Use in-memory cache so tests work without Redis.""" + settings.CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + } + } + + +@pytest.fixture +def user(db): + return User.objects.create_user(username="owner", password="OwnerPass123!") + + +@pytest.fixture +def other_user(db): + return User.objects.create_user(username="other", password="OtherPass123!") + + +@pytest.fixture +def staff_user(db): + return User.objects.create_user( + username="staff", password="StaffPass123!", is_staff=True + ) + + +@pytest.fixture +def project(user): + return Project.objects.create( + name="TestProject", + owner=user, + user_column="user_id", + item_column="item_id", + ) + + +@pytest.fixture +def unowned_project(db): + return Project.objects.create( + name="LegacyProject", + owner=None, + user_column="user_id", + item_column="item_id", + ) + + +@pytest.mark.django_db +class TestGetProjectOr404: + def test_existing_project_returned(self, project, user): + """Returns project when user is owner.""" + result = get_project_or_404(project.pk, user=user) + assert result.pk == project.pk + assert result.name == "TestProject" + + def test_nonexistent_project_raises(self, user): + """Raises ResourceNotFoundError for missing pk.""" + with pytest.raises(ResourceNotFoundError, match="not found"): + get_project_or_404(999999, user=user) + + def test_staff_bypasses_ownership(self, project, staff_user): + """Staff user can access any project regardless of ownership.""" + result = get_project_or_404(project.pk, user=staff_user) + assert result.pk == project.pk + + def test_non_owner_denied(self, project, other_user): + """Non-owner non-staff gets ResourceNotFoundError.""" + with pytest.raises(ResourceNotFoundError, match="not found"): + get_project_or_404(project.pk, user=other_user) + + def test_unowned_project_visible(self, unowned_project, other_user): + """Project with owner=None accessible by any authenticated user.""" + result = get_project_or_404(unowned_project.pk, user=other_user) + assert result.pk == unowned_project.pk + + def test_no_user_returns_project(self, project): + """user=None skips ownership check entirely.""" + result = get_project_or_404(project.pk, user=None) + assert result.pk == project.pk + + +@pytest.mark.django_db +class TestGetProjectSummary: + def test_counts_correct(self, project, user): + """Correct n_data, n_complete_jobs, n_models for populated project.""" + # Create TrainingData with a file and set filesize manually + # (post_save signal only fires on create, but file is saved after) + td = TrainingData.objects.create(project=project) + td.file.save("data.csv", ContentFile(b"user_id,item_id\n1,2\n")) + td.filesize = td.file.size + td.save(update_fields=["filesize"]) + + # Create ModelConfiguration linked to project + config = ModelConfiguration.objects.create( + name="cfg", + project=project, + recommender_class_name="TopPopRecommender", + parameters_json={}, + ) + + # Create a ParameterTuningJob with best_config pointing to config + sc = SplitConfig.objects.create(created_by=user) + ec = EvaluationConfig.objects.create(created_by=user) + ParameterTuningJob.objects.create( + data=td, split=sc, evaluation=ec, best_config=config + ) + + # Create TrainedModel with a file and set filesize manually + tm = TrainedModel.objects.create(configuration=config, data_loc=td) + tm.file.save("model.pkl", ContentFile(b"fake model data")) + tm.filesize = tm.file.size + tm.save(update_fields=["filesize"]) + + summary = get_project_summary(project) + assert summary["n_data"] == 1 + assert summary["n_complete_jobs"] == 1 + assert summary["n_models"] == 1 + assert summary["ins_datetime"] == project.ins_datetime + + def test_empty_project_zeros(self, project): + """All zeros for project with no data, configs, or models.""" + summary = get_project_summary(project) + assert summary["n_data"] == 0 + assert summary["n_complete_jobs"] == 0 + assert summary["n_models"] == 0 + assert summary["ins_datetime"] == project.ins_datetime diff --git a/backend/recotem/tests/test_schedule_service.py b/backend/recotem/tests/test_schedule_service.py index 9d2aac1a..39848b77 100644 --- a/backend/recotem/tests/test_schedule_service.py +++ b/backend/recotem/tests/test_schedule_service.py @@ -4,7 +4,11 @@ import pytest -from recotem.api.services.schedule_service import _parse_cron, sync_schedule_to_beat +from recotem.api.services.schedule_service import ( + _parse_cron, + delete_beat_task, + sync_schedule_to_beat, +) class TestParseCron: @@ -69,3 +73,24 @@ def test_enabled_schedule_creates_task(self, mock_crontab, mock_periodic): mock_periodic.objects.update_or_create.assert_called_once() call_kwargs = mock_periodic.objects.update_or_create.call_args assert "recotem_retrain_schedule_42" in str(call_kwargs) + + +@pytest.mark.django_db +class TestDeleteBeatTask: + @patch("recotem.api.services.schedule_service.PeriodicTask") + def test_deletes_existing_task(self, mock_periodic): + schedule = MagicMock() + schedule.id = 5 + delete_beat_task(schedule) + mock_periodic.objects.filter.assert_called_once_with( + name="recotem_retrain_schedule_5" + ) + mock_periodic.objects.filter.return_value.delete.assert_called_once() + + @patch("recotem.api.services.schedule_service.PeriodicTask") + def test_no_task_no_error(self, mock_periodic): + """Missing task -> no error (filter returns empty qs, delete does nothing).""" + schedule = MagicMock() + schedule.id = 999 + mock_periodic.objects.filter.return_value.delete.return_value = 0 + delete_beat_task(schedule) # Should not raise diff --git a/backend/recotem/tests/test_serializers.py b/backend/recotem/tests/test_serializers.py index 844047ec..825a520b 100644 --- a/backend/recotem/tests/test_serializers.py +++ b/backend/recotem/tests/test_serializers.py @@ -2,6 +2,7 @@ import pytest from django.contrib.auth import get_user_model +from rest_framework import serializers from rest_framework.test import APIRequestFactory from recotem.api.models import ( @@ -9,6 +10,9 @@ Project, SplitConfig, ) +from recotem.api.serializers.ab_test import ABTestSerializer +from recotem.api.serializers.deployment import DeploymentSlotSerializer +from recotem.api.serializers.events import ConversionEventSerializer from recotem.api.serializers.project import ProjectSerializer User = get_user_model() @@ -519,3 +523,64 @@ def test_missing_item_column_rejected(self, user, factory): serializer = ProjectSerializer(data=data, context={"request": request}) assert not serializer.is_valid() assert "item_column" in serializer.errors + + +@pytest.mark.django_db +class TestDeploymentSlotSerializer: + def test_weight_below_zero_rejected(self, user, factory): + p = Project.objects.create( + name="DeployWeight", user_column="u", item_column="i", owner=user + ) + request = factory.post("/") + request.user = user + s = DeploymentSlotSerializer( + data={"project": p.id, "name": "s", "weight": -1, "trained_model": 1}, + context={"request": request}, + ) + s.is_valid() + assert "weight" in s.errors + + def test_weight_above_100_rejected(self, user, factory): + p = Project.objects.create( + name="DeployOver", user_column="u", item_column="i", owner=user + ) + request = factory.post("/") + request.user = user + s = DeploymentSlotSerializer( + data={"project": p.id, "name": "s", "weight": 101, "trained_model": 1}, + context={"request": request}, + ) + s.is_valid() + assert "weight" in s.errors + + def test_weight_boundary_values(self, user, factory): + """0 and 100 should be accepted (weight validation only).""" + s0 = DeploymentSlotSerializer() + assert s0.validate_weight(0) == 0 + assert s0.validate_weight(100) == 100 + + +@pytest.mark.django_db +class TestABTestSerializerValidation: + def test_confidence_below_0_5_rejected(self): + s = ABTestSerializer() + with pytest.raises(serializers.ValidationError): + s.validate_confidence_level(0.4) + + def test_confidence_above_0_99_rejected(self): + s = ABTestSerializer() + with pytest.raises(serializers.ValidationError): + s.validate_confidence_level(1.0) + + def test_invalid_target_metric_rejected(self): + s = ABTestSerializer() + with pytest.raises(serializers.ValidationError): + s.validate_target_metric_name("invalid_metric") + + +@pytest.mark.django_db +class TestConversionEventSerializerValidation: + def test_invalid_event_type_rejected(self): + s = ConversionEventSerializer() + with pytest.raises(serializers.ValidationError): + s.validate_event_type("invalid_type") diff --git a/backend/recotem/tests/test_tasks.py b/backend/recotem/tests/test_tasks.py index 30ea4893..468e403c 100644 --- a/backend/recotem/tests/test_tasks.py +++ b/backend/recotem/tests/test_tasks.py @@ -9,19 +9,42 @@ from django_celery_results.models import TaskResult from recotem.api.models import ( + DeploymentSlot, EvaluationConfig, ModelConfiguration, ParameterTuningJob, Project, + RetrainingRun, + RetrainingSchedule, SplitConfig, TrainedModel, TrainingData, ) -from recotem.api.tasks import start_tuning_job, train_recommender_func +from recotem.api.tasks import ( + DEFAULT_SEARCH_RECOMMENDERS, + _auto_deploy_model, + _fail_retraining_run_for_job, + _finalize_retraining_run, + _get_search_recommender_classes, + _resolve_recommender_class_name, + start_tuning_job, + task_scheduled_retrain, + train_recommender_func, +) User = get_user_model() +@pytest.fixture(autouse=True) +def _use_locmem_cache(settings): + """Use in-memory cache so tests work without Redis.""" + settings.CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + } + } + + @pytest.fixture def user(db): return User.objects.create_user(username="task_tester", password="pass") @@ -460,3 +483,426 @@ def test_start_tuning_job_error_handling( job.refresh_from_db() assert job.status == ParameterTuningJob.Status.FAILED + + +# --------------------------------------------------------------------------- +# Tests for _resolve_recommender_class_name +# --------------------------------------------------------------------------- + + +class TestResolveRecommenderClassName: + def test_exact_name(self): + """A valid recommender class name resolves to itself.""" + result = _resolve_recommender_class_name("TopPopRecommender") + assert result == "TopPopRecommender" + + def test_unknown_returns_none(self): + """A completely unknown algorithm name returns None.""" + result = _resolve_recommender_class_name("CompletelyFakeAlgorithm") + assert result is None + + +# --------------------------------------------------------------------------- +# Tests for _get_search_recommender_classes +# --------------------------------------------------------------------------- + + +class TestGetSearchRecommenderClasses: + def test_none_returns_defaults(self): + """Passing None returns a copy of DEFAULT_SEARCH_RECOMMENDERS.""" + result = _get_search_recommender_classes(None) + assert result == DEFAULT_SEARCH_RECOMMENDERS + # Must be a copy, not the same list object + assert result is not DEFAULT_SEARCH_RECOMMENDERS + + def test_all_invalid_returns_defaults(self): + """When all algorithm names are invalid, fall back to defaults.""" + result = _get_search_recommender_classes(["FakeAlgo1", "FakeAlgo2"]) + assert result == DEFAULT_SEARCH_RECOMMENDERS + + +# --------------------------------------------------------------------------- +# Tests for _auto_deploy_model +# --------------------------------------------------------------------------- + + +@pytest.fixture +def training_data(project): + from django.core.files.uploadedfile import SimpleUploadedFile + + return TrainingData.objects.create( + project=project, + file=SimpleUploadedFile( + "dummy.csv", b"col1,col2\n1,2\n", content_type="text/csv" + ), + ) + + +@pytest.fixture +def model_config(project): + return ModelConfiguration.objects.create( + name="auto_deploy_config", + project=project, + recommender_class_name="IALSRecommender", + parameters_json={}, + ) + + +@pytest.fixture +def trained_model(model_config, training_data): + return TrainedModel.objects.create( + configuration=model_config, + data_loc=training_data, + ) + + +@pytest.fixture +def schedule(project): + return RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + ) + + +@pytest.mark.django_db +class TestAutoDeployModel: + def test_creates_new_slot(self, schedule, trained_model): + """Creates a new auto-deploy slot with weight=100 and is_active=True.""" + _auto_deploy_model(schedule, trained_model) + + slot = DeploymentSlot.objects.get( + project=schedule.project, + name=f"auto-deploy-{schedule.project.name}", + ) + assert slot.trained_model == trained_model + assert slot.weight == 100 + assert slot.is_active is True + + def test_updates_existing_slot( + self, schedule, trained_model, model_config, training_data + ): + """Updates an existing auto-deploy slot with a new model.""" + # Create first slot + _auto_deploy_model(schedule, trained_model) + + # Create a second trained model + new_model = TrainedModel.objects.create( + configuration=model_config, + data_loc=training_data, + ) + _auto_deploy_model(schedule, new_model) + + slots = DeploymentSlot.objects.filter( + project=schedule.project, + name=f"auto-deploy-{schedule.project.name}", + ) + assert slots.count() == 1 + slot = slots.first() + assert slot.trained_model == new_model + + +# --------------------------------------------------------------------------- +# Tests for _finalize_retraining_run +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db +class TestFinalizeRetrainingRun: + def test_no_linked_run( + self, user, project, split_config, eval_config, training_data + ): + """No matching RetrainingRun for the job — returns silently.""" + job = ParameterTuningJob.objects.create( + data=training_data, + split=split_config, + evaluation=eval_config, + status=ParameterTuningJob.Status.RUNNING, + ) + config = ModelConfiguration.objects.create( + name="finalize_config", + project=project, + recommender_class_name="IALSRecommender", + parameters_json={}, + ) + model = TrainedModel.objects.create( + configuration=config, + data_loc=training_data, + ) + # Should not raise + _finalize_retraining_run(job, model) + + def test_marks_completed( + self, + user, + project, + split_config, + eval_config, + training_data, + schedule, + ): + """Sets run status to COMPLETED, assigns model, sets completed_at.""" + job = ParameterTuningJob.objects.create( + data=training_data, + split=split_config, + evaluation=eval_config, + status=ParameterTuningJob.Status.RUNNING, + ) + config = ModelConfiguration.objects.create( + name="finalize_completed_config", + project=project, + recommender_class_name="IALSRecommender", + parameters_json={}, + ) + model = TrainedModel.objects.create( + configuration=config, + data_loc=training_data, + ) + run = RetrainingRun.objects.create( + schedule=schedule, + tuning_job=job, + status=RetrainingRun.Status.RUNNING, + ) + + _finalize_retraining_run(job, model) + + run.refresh_from_db() + assert run.status == RetrainingRun.Status.COMPLETED + assert run.trained_model == model + assert run.completed_at is not None + + schedule.refresh_from_db() + assert schedule.last_run_status == RetrainingRun.Status.COMPLETED + + def test_auto_deploy_triggered( + self, user, project, split_config, eval_config, training_data + ): + """When auto_deploy=True, a deployment slot is created.""" + auto_schedule = RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + auto_deploy=True, + ) + job = ParameterTuningJob.objects.create( + data=training_data, + split=split_config, + evaluation=eval_config, + status=ParameterTuningJob.Status.RUNNING, + ) + config = ModelConfiguration.objects.create( + name="finalize_auto_deploy_config", + project=project, + recommender_class_name="IALSRecommender", + parameters_json={}, + ) + model = TrainedModel.objects.create( + configuration=config, + data_loc=training_data, + ) + RetrainingRun.objects.create( + schedule=auto_schedule, + tuning_job=job, + status=RetrainingRun.Status.RUNNING, + ) + + _finalize_retraining_run(job, model) + + slot = DeploymentSlot.objects.get( + project=project, + name=f"auto-deploy-{project.name}", + ) + assert slot.trained_model == model + assert slot.weight == 100 + assert slot.is_active is True + + +# --------------------------------------------------------------------------- +# Tests for _fail_retraining_run_for_job +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db +class TestFailRetrainingRunForJob: + def test_marks_failed( + self, + user, + project, + split_config, + eval_config, + training_data, + schedule, + ): + """Sets RUNNING run to FAILED and updates schedule.""" + job = ParameterTuningJob.objects.create( + data=training_data, + split=split_config, + evaluation=eval_config, + status=ParameterTuningJob.Status.RUNNING, + ) + run = RetrainingRun.objects.create( + schedule=schedule, + tuning_job=job, + status=RetrainingRun.Status.RUNNING, + ) + + _fail_retraining_run_for_job(job.id) + + run.refresh_from_db() + assert run.status == RetrainingRun.Status.FAILED + assert run.completed_at is not None + assert run.error_message != "" + + schedule.refresh_from_db() + assert schedule.last_run_status == "FAILED" + + def test_non_running_not_changed( + self, + user, + project, + split_config, + eval_config, + training_data, + schedule, + ): + """A COMPLETED run is not modified.""" + job = ParameterTuningJob.objects.create( + data=training_data, + split=split_config, + evaluation=eval_config, + status=ParameterTuningJob.Status.COMPLETED, + ) + run = RetrainingRun.objects.create( + schedule=schedule, + tuning_job=job, + status=RetrainingRun.Status.COMPLETED, + ) + + _fail_retraining_run_for_job(job.id) + + run.refresh_from_db() + assert run.status == RetrainingRun.Status.COMPLETED + + def test_no_matching_run(self): + """No linked run for a nonexistent job ID — no crash.""" + # Use a job ID that does not exist + _fail_retraining_run_for_job(999999) + + +# --------------------------------------------------------------------------- +# Tests for task_scheduled_retrain +# --------------------------------------------------------------------------- + + +@pytest.mark.django_db +class TestTaskScheduledRetrain: + def test_schedule_not_found(self): + """Missing schedule ID returns early without crashing.""" + task_scheduled_retrain._orig_run(999999) + + def test_disabled_skipped(self, project): + """Disabled schedule is skipped.""" + disabled_schedule = RetrainingSchedule.objects.create( + project=project, + is_enabled=False, + cron_expression="0 2 * * 0", + ) + # Should return early without creating a RetrainingRun + task_scheduled_retrain._orig_run(disabled_schedule.id) + assert not RetrainingRun.objects.filter(schedule=disabled_schedule).exists() + + def test_no_training_data(self, user): + """No training data available — returns early.""" + empty_project = Project.objects.create( + name="empty_project", + owner=user, + user_column="userId", + item_column="movieId", + ) + sched = RetrainingSchedule.objects.create( + project=empty_project, + is_enabled=True, + cron_expression="0 2 * * 0", + ) + task_scheduled_retrain._orig_run(sched.id) + assert not RetrainingRun.objects.filter(schedule=sched).exists() + + @patch("recotem.api.tasks.train_and_save_model") + def test_train_with_config(self, mock_train, project, training_data, model_config): + """Schedule with model_configuration trains directly.""" + sched = RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + model_configuration=model_config, + training_data=training_data, + ) + + task_scheduled_retrain._orig_run(sched.id) + + mock_train.assert_called_once() + run = RetrainingRun.objects.filter(schedule=sched).first() + assert run is not None + assert run.status == RetrainingRun.Status.COMPLETED + assert run.trained_model is not None + + def test_no_config_no_retune_skipped(self, project, training_data): + """No model config and retune=False results in SKIPPED.""" + sched = RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + training_data=training_data, + retune=False, + ) + + task_scheduled_retrain._orig_run(sched.id) + + run = RetrainingRun.objects.filter(schedule=sched).first() + assert run is not None + assert run.status == RetrainingRun.Status.SKIPPED + assert "No model configuration" in run.error_message + + @patch( + "recotem.api.tasks.train_and_save_model", + side_effect=RuntimeError("training exploded"), + ) + def test_exception_marks_failed( + self, mock_train, project, training_data, model_config + ): + """Exception during training marks run as FAILED with error_message.""" + sched = RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + model_configuration=model_config, + training_data=training_data, + ) + + task_scheduled_retrain._orig_run(sched.id) + + run = RetrainingRun.objects.filter(schedule=sched).first() + assert run is not None + assert run.status == RetrainingRun.Status.FAILED + assert "training exploded" in run.error_message + + @patch("recotem.api.tasks.train_and_save_model") + def test_schedule_metadata_updated( + self, mock_train, project, training_data, model_config + ): + """last_run_at and last_run_status are updated after a run.""" + sched = RetrainingSchedule.objects.create( + project=project, + is_enabled=True, + cron_expression="0 2 * * 0", + model_configuration=model_config, + training_data=training_data, + ) + assert sched.last_run_at is None + assert sched.last_run_status is None + + task_scheduled_retrain._orig_run(sched.id) + + sched.refresh_from_db() + assert sched.last_run_at is not None + assert sched.last_run_status == RetrainingRun.Status.COMPLETED diff --git a/backend/recotem/tests/test_user_service.py b/backend/recotem/tests/test_user_service.py new file mode 100644 index 00000000..684cf770 --- /dev/null +++ b/backend/recotem/tests/test_user_service.py @@ -0,0 +1,99 @@ +"""Unit tests for user_service.py — user creation, activation, +and password management.""" + +import pytest +from django.contrib.auth import get_user_model +from django.core.exceptions import ValidationError + +from recotem.api.services.user_service import ( + activate_user, + admin_reset_password, + create_user, + deactivate_user, +) + +User = get_user_model() + + +@pytest.mark.django_db +class TestCreateUser: + def test_creates_user(self): + """User created with hashed password (not stored as plaintext).""" + user = create_user(username="newuser", password="SecurePass123!") + assert user.pk is not None + assert user.username == "newuser" + assert user.check_password("SecurePass123!") + # Password should be hashed, not stored as plaintext + assert user.password != "SecurePass123!" + + def test_weak_password_raises(self): + """Django validators reject weak password.""" + with pytest.raises(ValidationError): + create_user(username="weakuser", password="123") + + def test_staff_flag(self): + """is_staff=True propagated to created user.""" + user = create_user( + username="staffuser", password="StaffPass123!", is_staff=True + ) + assert user.is_staff is True + + +@pytest.mark.django_db +class TestDeactivateUser: + def test_sets_inactive(self, db): + """is_active set to False after deactivation.""" + user = User.objects.create_user(username="active", password="ActivePass123!") + assert user.is_active is True + + result = deactivate_user(user) + + assert result.is_active is False + # Verify persisted to database + user.refresh_from_db() + assert user.is_active is False + + +@pytest.mark.django_db +class TestActivateUser: + def test_sets_active(self, db): + """is_active set to True after activation.""" + user = User.objects.create_user( + username="inactive", password="InactivePass123!" + ) + user.is_active = False + user.save(update_fields=["is_active"]) + + result = activate_user(user) + + assert result.is_active is True + # Verify persisted to database + user.refresh_from_db() + assert user.is_active is True + + +@pytest.mark.django_db +class TestAdminResetPassword: + def test_password_changed(self, db): + """New password verifiable after reset.""" + user = User.objects.create_user( + username="resetuser", password="OldPassword123!" + ) + assert user.check_password("OldPassword123!") + + admin_reset_password(user, "NewPassword456!") + + # New password works + user.refresh_from_db() + assert user.check_password("NewPassword456!") + # Old password no longer works + assert not user.check_password("OldPassword123!") + + def test_weak_password_rejected(self, db): + """Raises ValidationError for weak new password.""" + user = User.objects.create_user( + username="resetuser2", password="StrongPassword123!" + ) + + with pytest.raises(ValidationError): + admin_reset_password(user, "123") diff --git a/frontend/src/__tests__/pages/DataUploadPage.test.ts b/frontend/src/__tests__/pages/DataUploadPage.test.ts index adf476df..79360e48 100644 --- a/frontend/src/__tests__/pages/DataUploadPage.test.ts +++ b/frontend/src/__tests__/pages/DataUploadPage.test.ts @@ -160,7 +160,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); // Should show progress bar while uploading expect(wrapper.find(".progress-bar").exists()).toBe(true); @@ -187,7 +187,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerUploadProgress(50, 100); await nextTick(); @@ -205,7 +205,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(200); await flushPromises(); @@ -223,7 +223,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); expect(wrapper.find(".progress-bar").exists()).toBe(true); xhrInstances[0].triggerLoad(200); @@ -239,7 +239,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(400, JSON.stringify({ detail: "Invalid CSV format" })); await flushPromises(); @@ -253,7 +253,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(400, JSON.stringify(["File too large"])); await flushPromises(); @@ -266,7 +266,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(400, JSON.stringify({ file: ["This field is required."] })); await flushPromises(); @@ -279,7 +279,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(500, "Internal Server Error"); await flushPromises(); @@ -292,7 +292,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerError(); await flushPromises(); @@ -305,7 +305,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerAbort(); await flushPromises(); @@ -318,7 +318,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); expect(wrapper.find(".progress-bar").exists()).toBe(true); xhrInstances[0].triggerLoad(400, JSON.stringify({ detail: "Bad request" })); @@ -335,7 +335,7 @@ describe("DataUploadPage", () => { // Trigger first upload that fails emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); xhrInstances[0].triggerLoad(500, "error"); await flushPromises(); @@ -345,7 +345,7 @@ describe("DataUploadPage", () => { // Click retry await retryBtn!.trigger("click"); - await nextTick(); + await flushPromises(); // A new XHR should be created expect(xhrInstances.length).toBe(2); @@ -377,7 +377,7 @@ describe("DataUploadPage", () => { const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); emitUploader(wrapper, mockFile); - await nextTick(); + await flushPromises(); expect(xhrInstances[0].setRequestHeader).toHaveBeenCalledWith( "Authorization", @@ -385,4 +385,303 @@ describe("DataUploadPage", () => { ); }); }); + + describe("pre-upload token refresh", () => { + it("calls ensureFreshToken before creating XHR", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + const spy = vi.spyOn(authStore, "ensureFreshToken").mockResolvedValue(true); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + expect(spy).toHaveBeenCalled(); + // XHR should still be created after ensureFreshToken + expect(xhrInstances.length).toBe(1); + expect(xhrInstances[0].send).toHaveBeenCalled(); + spy.mockRestore(); + }); + + it("proceeds with upload even when ensureFreshToken returns false", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + // ensureFreshToken returns false (no tokens) but upload should still attempt + const spy = vi.spyOn(authStore, "ensureFreshToken").mockResolvedValue(false); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // XHR is still created — the server will reject it with 401 + expect(xhrInstances.length).toBe(1); + expect(xhrInstances[0].send).toHaveBeenCalled(); + spy.mockRestore(); + }); + }); + + describe("401 retry logic", () => { + it("retries upload after refreshing token on 401 response", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + const refreshSpy = vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = "new-token"; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // First XHR returns 401 + expect(xhrInstances.length).toBe(1); + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Given token not valid for any token type" })); + await flushPromises(); + + // Should have called refreshAccessToken and created a second XHR + expect(refreshSpy).toHaveBeenCalled(); + expect(xhrInstances.length).toBe(2); + expect(xhrInstances[1].send).toHaveBeenCalled(); + + // Complete the retry successfully + xhrInstances[1].triggerLoad(200); + await flushPromises(); + + expect(notifyMock.success).toHaveBeenCalled(); + expect(wrapper.find(".message").exists()).toBe(false); + refreshSpy.mockRestore(); + }); + + it("uses refreshed token in retry XHR", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = "fresh-new-token"; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // First XHR used old token + expect(xhrInstances[0].setRequestHeader).toHaveBeenCalledWith( + "Authorization", + "Bearer old-token", + ); + + // Trigger 401 + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Token expired" })); + await flushPromises(); + + // Retry XHR should use the new token + expect(xhrInstances[1].setRequestHeader).toHaveBeenCalledWith( + "Authorization", + "Bearer fresh-new-token", + ); + }); + + it("shows error when retry also fails with non-401", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = "new-token"; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // First XHR returns 401 + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Token expired" })); + await flushPromises(); + + // Retry XHR returns 500 + xhrInstances[1].triggerLoad(500, JSON.stringify({ detail: "Server error" })); + await flushPromises(); + + // Should show error + expect(wrapper.find(".message").exists()).toBe(true); + }); + + it("shows error when token refresh fails during 401 retry", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + // refreshAccessToken fails → sets accessToken to null + vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = null; + authStore.refreshToken = null; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // First XHR returns 401 + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Token expired" })); + await flushPromises(); + + // Refresh failed — should show original error, no retry XHR created + expect(xhrInstances.length).toBe(1); + expect(wrapper.find(".message").exists()).toBe(true); + }); + + it("does not retry on non-401 errors", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + const refreshSpy = vi.spyOn(authStore, "refreshAccessToken"); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // 400 Bad Request — should NOT trigger retry + xhrInstances[0].triggerLoad(400, JSON.stringify({ detail: "Invalid CSV" })); + await flushPromises(); + + expect(refreshSpy).not.toHaveBeenCalled(); + expect(xhrInstances.length).toBe(1); + expect(wrapper.find(".message").exists()).toBe(true); + expect(wrapper.text()).toContain("Invalid CSV"); + refreshSpy.mockRestore(); + }); + + it("does not retry on 500 errors", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + const refreshSpy = vi.spyOn(authStore, "refreshAccessToken"); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + xhrInstances[0].triggerLoad(500, "Internal Server Error"); + await flushPromises(); + + expect(refreshSpy).not.toHaveBeenCalled(); + expect(xhrInstances.length).toBe(1); + refreshSpy.mockRestore(); + }); + + it("resets progress to 0 before retry upload", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = "new-token"; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // Simulate progress reaching 80% + xhrInstances[0].triggerUploadProgress(80, 100); + await nextTick(); + expect(wrapper.find(".progress-bar").attributes("data-value")).toBe("80"); + + // 401 triggers retry + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Token expired" })); + await flushPromises(); + + // Progress should be reset to 0 for the retry + expect(wrapper.find(".progress-bar").attributes("data-value")).toBe("0"); + }); + + it("hides progress bar after successful retry", async () => { + vi.useFakeTimers(); + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + authStore.accessToken = "old-token"; + authStore.refreshToken = "valid-refresh"; + + vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = "new-token"; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // 401 → retry + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Token expired" })); + await flushPromises(); + + // Retry succeeds + xhrInstances[1].triggerLoad(200); + await flushPromises(); + + expect(wrapper.find(".progress-bar").exists()).toBe(false); + expect(notifyMock.success).toHaveBeenCalled(); + vi.advanceTimersByTime(500); + expect(mockPush).toHaveBeenCalledWith("/projects/1/data"); + + vi.useRealTimers(); + }); + }); + + describe("error status property", () => { + it("attaches HTTP status to error on non-2xx response", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + // Spy on refreshAccessToken to verify 401 detection works via status + const refreshSpy = vi.spyOn(authStore, "refreshAccessToken").mockImplementation(async () => { + authStore.accessToken = null; + }); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // 401 should be detected by the retry logic (proves .status is set) + xhrInstances[0].triggerLoad(401, JSON.stringify({ detail: "Unauthorized" })); + await flushPromises(); + + expect(refreshSpy).toHaveBeenCalled(); + refreshSpy.mockRestore(); + }); + + it("does not trigger 401 retry for 403 forbidden", async () => { + const wrapper = mountPage(); + const { useAuthStore } = await import("@/stores/auth"); + const authStore = useAuthStore(); + const refreshSpy = vi.spyOn(authStore, "refreshAccessToken"); + + const mockFile = new File(["test"], "test.csv", { type: "text/csv" }); + emitUploader(wrapper, mockFile); + await flushPromises(); + + // 403 is different from 401 + xhrInstances[0].triggerLoad(403, JSON.stringify({ detail: "Forbidden" })); + await flushPromises(); + + expect(refreshSpy).not.toHaveBeenCalled(); + expect(xhrInstances.length).toBe(1); + expect(wrapper.text()).toContain("Forbidden"); + refreshSpy.mockRestore(); + }); + }); }); diff --git a/frontend/src/__tests__/stores/auth.test.ts b/frontend/src/__tests__/stores/auth.test.ts index fcabccde..cfe0a546 100644 --- a/frontend/src/__tests__/stores/auth.test.ts +++ b/frontend/src/__tests__/stores/auth.test.ts @@ -274,6 +274,177 @@ describe("useAuthStore", () => { }); }); + describe("ensureFreshToken", () => { + function makeToken(expMs: number): string { + const payload = { exp: expMs / 1000 }; + return `header.${btoa(JSON.stringify(payload))}.signature`; + } + + it("should_return_true_without_refreshing_when_token_has_more_than_60s_remaining", async () => { + // Token expires in 5 minutes — well above the 60s threshold + store.accessToken = makeToken(Date.now() + 300_000); + store.refreshToken = "valid-refresh"; + + const result = await store.ensureFreshToken(); + + expect(result).toBe(true); + // Should NOT have called the refresh endpoint + expect(mockOfetch).not.toHaveBeenCalled(); + }); + + it("should_refresh_and_return_true_when_token_expires_within_60s", async () => { + vi.useFakeTimers(); + // Token expires in 45s: within ensureFreshToken's 60s threshold, + // but proactive refresh sets a timer (delay > 0) instead of firing immediately + store.refreshToken = "valid-refresh"; + const newToken = makeToken(Date.now() + 300_000); + mockOfetch.mockResolvedValueOnce({ access: newToken }); + store.accessToken = makeToken(Date.now() + 45_000); + + const result = await store.ensureFreshToken(); + + expect(result).toBe(true); + expect(mockOfetch).toHaveBeenCalledWith( + "/api/v1/auth/token/refresh/", + expect.objectContaining({ method: "POST" }), + ); + expect(store.accessToken).toBe(newToken); + vi.useRealTimers(); + }); + + it("should_refresh_and_return_true_when_token_is_already_expired", async () => { + // Expired token triggers both proactive refresh and ensureFreshToken. + // Provide two mock responses so both can succeed. + store.refreshToken = "valid-refresh"; + const newToken = makeToken(Date.now() + 300_000); + mockOfetch + .mockResolvedValueOnce({ access: newToken }) // proactive refresh (via watch) + .mockResolvedValueOnce({ access: newToken }); // ensureFreshToken (if still needed) + store.accessToken = makeToken(Date.now() - 10_000); + + const result = await store.ensureFreshToken(); + + expect(result).toBe(true); + expect(store.accessToken).toBe(newToken); + }); + + it("should_return_false_when_no_access_token", async () => { + store.accessToken = null; + store.refreshToken = "valid-refresh"; + + const result = await store.ensureFreshToken(); + + expect(result).toBe(false); + expect(mockOfetch).not.toHaveBeenCalled(); + }); + + it("should_return_false_when_no_refresh_token", async () => { + store.accessToken = makeToken(Date.now() + 30_000); + store.refreshToken = null; + + const result = await store.ensureFreshToken(); + + expect(result).toBe(false); + expect(mockOfetch).not.toHaveBeenCalled(); + }); + + it("should_return_false_when_refresh_fails", async () => { + vi.useFakeTimers(); + // Token expires in 45s: within ensureFreshToken's threshold, + // proactive refresh uses a timer (delay > 0) + store.refreshToken = "expired-refresh"; + mockOfetch.mockRejectedValueOnce(new Error("Refresh failed")); + store.accessToken = makeToken(Date.now() + 45_000); + + const result = await store.ensureFreshToken(); + + // refreshAccessToken failed → logout → accessToken becomes null + expect(result).toBe(false); + expect(store.accessToken).toBeNull(); + vi.useRealTimers(); + }); + + it("should_return_true_for_token_with_exactly_60s_remaining", async () => { + vi.useFakeTimers(); + // Token expires in ~59s — within ensureFreshToken's 60s threshold, + // but proactive refresh calculates delay > 0 so uses a timer + store.refreshToken = "valid-refresh"; + const newToken = makeToken(Date.now() + 300_000); + mockOfetch.mockResolvedValueOnce({ access: newToken }); + store.accessToken = makeToken(Date.now() + 59_000); + + const result = await store.ensureFreshToken(); + + expect(result).toBe(true); + expect(mockOfetch).toHaveBeenCalled(); + vi.useRealTimers(); + }); + + it("should_not_refresh_when_token_has_no_parseable_exp", async () => { + // Token without valid exp field — getTokenExpiry returns null + store.accessToken = "not-a-real-jwt"; + store.refreshToken = "valid-refresh"; + + const result = await store.ensureFreshToken(); + + // exp is null → condition (exp !== null && ...) is false → returns true + expect(result).toBe(true); + expect(mockOfetch).not.toHaveBeenCalled(); + }); + }); + + describe("visibilitychange", () => { + function makeToken(expMs: number): string { + const payload = { exp: expMs / 1000 }; + return `header.${btoa(JSON.stringify(payload))}.signature`; + } + + it("should_trigger_refresh_when_tab_becomes_visible_with_expired_token", async () => { + // Setup: authenticated user with an expired token + store.accessToken = makeToken(Date.now() - 5_000); // expired 5s ago + store.refreshToken = "valid-refresh"; + const newToken = makeToken(Date.now() + 300_000); + mockOfetch.mockResolvedValueOnce({ access: newToken }); + + // Simulate tab becoming visible + document.dispatchEvent(new Event("visibilitychange")); + // scheduleProactiveRefresh sees delay<=0 and calls refreshAccessToken synchronously + // but refreshAccessToken is async, so flush + await vi.waitFor(() => { + expect(mockOfetch).toHaveBeenCalledWith( + "/api/v1/auth/token/refresh/", + expect.objectContaining({ method: "POST" }), + ); + }); + }); + + it("should_not_trigger_refresh_when_tab_becomes_visible_without_token", () => { + store.accessToken = null; + store.refreshToken = null; + mockOfetch.mockClear(); + + // Simulate tab becoming visible + document.dispatchEvent(new Event("visibilitychange")); + + expect(mockOfetch).not.toHaveBeenCalled(); + }); + + it("should_schedule_timer_when_tab_becomes_visible_with_fresh_token", () => { + vi.useFakeTimers(); + // Token with plenty of time remaining + store.accessToken = makeToken(Date.now() + 300_000); + store.refreshToken = "valid-refresh"; + mockOfetch.mockClear(); + + document.dispatchEvent(new Event("visibilitychange")); + + // Should NOT immediately call refresh (token is still fresh) + expect(mockOfetch).not.toHaveBeenCalled(); + // But a timer should be scheduled (we can verify by advancing time) + vi.useRealTimers(); + }); + }); + describe("initialization", () => { it("should_load_tokens_from_sessionStorage_on_creation", () => { // Arrange diff --git a/frontend/src/pages/DataUploadPage.vue b/frontend/src/pages/DataUploadPage.vue index a1638908..a6ed7a75 100644 --- a/frontend/src/pages/DataUploadPage.vue +++ b/frontend/src/pages/DataUploadPage.vue @@ -115,7 +115,9 @@ function uploadWithProgress(file: File): Promise { detail = body.detail ?? body.file?.[0] ?? body.project?.[0] ?? detail; } } catch { /* ignore parse errors */ } - reject(new Error(detail)); + const err = new Error(detail) as Error & { status: number }; + err.status = xhr.status; + reject(err); } }); @@ -135,10 +137,25 @@ async function doUpload(file: File) { lastFile.value = file; try { + await authStore.ensureFreshToken(); await uploadWithProgress(file); notify.success(t("data.uploadSuccess")); setTimeout(() => router.push(`/projects/${projectId}/data`), 500); } catch (e) { + if ((e as any).status === 401) { + // Token expired during upload — refresh and retry once + try { + await authStore.refreshAccessToken(); + if (!authStore.accessToken) throw e; + progress.value = 0; + await uploadWithProgress(file); + notify.success(t("data.uploadSuccess")); + setTimeout(() => router.push(`/projects/${projectId}/data`), 500); + return; + } catch { + // Retry also failed + } + } uploadError.value = (e as Error).message || t("data.uploadFailed"); } finally { uploading.value = false; diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index f4f2bfc1..8ca7e61a 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -137,6 +137,23 @@ export const useAuthStore = defineStore("auth", () => { } } + async function ensureFreshToken(): Promise { + if (!accessToken.value || !refreshToken.value) return false; + const exp = getTokenExpiry(accessToken.value); + if (exp !== null && exp - Date.now() < 60_000) { + await refreshAccessToken(); + return !!accessToken.value; + } + return true; + } + + function handleVisibilityChange() { + if (document.visibilityState === "visible" && accessToken.value) { + scheduleProactiveRefresh(); + } + } + document.addEventListener("visibilitychange", handleVisibilityChange); + // Schedule initial proactive refresh if already authenticated scheduleProactiveRefresh(); @@ -149,5 +166,6 @@ export const useAuthStore = defineStore("auth", () => { logout, fetchUser, refreshAccessToken, + ensureFreshToken, }; });