diff --git a/backend/annotation/migrations/0003_labelannotation_and_more.py b/backend/annotation/migrations/0003_labelannotation_and_more.py new file mode 100644 index 0000000..a3398d0 --- /dev/null +++ b/backend/annotation/migrations/0003_labelannotation_and_more.py @@ -0,0 +1,183 @@ +# Generated by Django 4.2.27 on 2026-02-01 13:27 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("problem", "0007_alter_problem_options"), + ("annotation", "0002_label_labeling"), + ] + + operations = [ + migrations.CreateModel( + name="LabelAnnotation", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ( + "removed_at", + models.DateTimeField( + blank=True, + help_text="When this annotation was removed from the problem.", + null=True, + ), + ), + ( + "notes", + models.TextField( + blank=True, + help_text="Optional notes explaining why this annotation was added or removed.", + ), + ), + ( + "created_by", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss_created", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "label", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="label_annotations", + to="annotation.label", + ), + ), + ( + "problem", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss", + to="problem.problem", + ), + ), + ( + "removed_by", + models.ForeignKey( + blank=True, + help_text="User who removed this annotation.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss_removed", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "session", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss", + to="annotation.annotationsession", + ), + ), + ], + options={ + "ordering": ["-created_at"], + "permissions": [ + ( + "delete_own_labelannotation", + "Can remove own label annotation from problems", + ), + ( + "delete_any_labelannotation", + "Can remove any label annotation from problems", + ), + ], + "abstract": False, + }, + ), + migrations.RemoveField( + model_name="knowledgebaseannotation", + name="knowledge_base", + ), + migrations.AddField( + model_name="knowledgebaseannotation", + name="created_by", + field=models.ForeignKey( + default=0, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss_created", + to=settings.AUTH_USER_MODEL, + ), + preserve_default=False, + ), + migrations.AddField( + model_name="knowledgebaseannotation", + name="notes", + field=models.TextField( + blank=True, + help_text="Optional notes explaining why this annotation was added or removed.", + ), + ), + migrations.AddField( + model_name="knowledgebaseannotation", + name="problem", + field=models.ForeignKey( + default=0, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss", + to="problem.problem", + ), + preserve_default=False, + ), + migrations.AddField( + model_name="knowledgebaseannotation", + name="removed_at", + field=models.DateTimeField( + blank=True, + help_text="When this annotation was removed from the problem.", + null=True, + ), + ), + migrations.AddField( + model_name="knowledgebaseannotation", + name="removed_by", + field=models.ForeignKey( + blank=True, + help_text="User who removed this annotation.", + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss_removed", + to=settings.AUTH_USER_MODEL, + ), + ), + migrations.AlterField( + model_name="knowledgebaseannotation", + name="session", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="%(class)ss", + to="annotation.annotationsession", + ), + ), + migrations.DeleteModel( + name="Labeling", + ), + migrations.AddIndex( + model_name="labelannotation", + index=models.Index( + fields=["problem", "removed_at"], name="annotation__problem_7b6d4f_idx" + ), + ), + migrations.AddIndex( + model_name="labelannotation", + index=models.Index( + fields=["label", "removed_at"], name="annotation__label_i_243aa0_idx" + ), + ), + ] diff --git a/backend/annotation/models.py b/backend/annotation/models.py index 0286386..b3b9dc5 100644 --- a/backend/annotation/models.py +++ b/backend/annotation/models.py @@ -1,7 +1,7 @@ from django.conf import settings from django.db import models -from problem.models import KnowledgeBase, Problem, Sentence +from problem.models import Problem, Sentence class AnnotationSession(models.Model): @@ -57,30 +57,73 @@ class ProblemAnnotation(models.Model): created_at = models.DateTimeField(auto_now_add=True) -class KnowledgeBaseAnnotation(models.Model): +class BaseAnnotation(models.Model): session = models.ForeignKey( AnnotationSession, on_delete=models.CASCADE, - related_name="kb_annotations", + related_name="%(class)ss", ) - knowledge_base = models.ForeignKey( - KnowledgeBase, + problem = models.ForeignKey( + Problem, + on_delete=models.CASCADE, + related_name="%(class)ss", + ) + + created_at = models.DateTimeField(auto_now_add=True) + created_by = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name="%(class)ss_created", + ) + + removed_at = models.DateTimeField( + null=True, + blank=True, + help_text="When this annotation was removed from the problem.", + ) + removed_by = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="kb_annotations", + related_name="%(class)ss_removed", + null=True, + blank=True, + help_text="User who removed this annotation.", + ) + + notes = models.TextField( + blank=True, + help_text="Optional notes explaining why this annotation was added or removed.", ) + class Meta: + abstract = True + + def is_active(self) -> bool: + """Check if this annotation is currently active (not removed).""" + return self.removed_at is None + + +class KnowledgeBaseAnnotation(BaseAnnotation): + class Relationship(models.TextChoices): + EQUAL = "equal", "Equal" + NOT_EQUAL = "not_equal", "Not Equal" + SUBSET = "subset", "Subset" + SUPERSET = "superset", "Superset" + entity1 = models.CharField(max_length=255) entity2 = models.CharField(max_length=255) relationship = models.CharField( max_length=255, - choices=KnowledgeBase.Relationship.choices, - default=KnowledgeBase.Relationship.EQUAL, + choices=Relationship.choices, + default=Relationship.EQUAL, ) - created_at = models.DateTimeField(auto_now_add=True) + def __str__(self): + status = "active" if self.is_active() else f"removed at {self.removed_at}" + return f"KB Annotation ({self.entity1} {self.relationship} {self.entity2}) on Problem {self.problem.pk} ({status})" class Label(models.Model): @@ -96,68 +139,41 @@ def __str__(self): return self.text -class Labeling(models.Model): +class LabelAnnotation(BaseAnnotation): """ The attachment of a label to a problem. - Each time a label is attached to a problem, a new Labeling record is created. - When removed, the record is marked as removed (not deleted), so the history of labelings is preserved. + Each time a label is attached to a problem, a new LabelAnnotation record + is created. When removed, the record is marked as removed (not deleted), + so the history of labelings is preserved. """ - - problem = models.ForeignKey( - Problem, - on_delete=models.CASCADE, - related_name="labelings", - ) - label = models.ForeignKey( Label, on_delete=models.CASCADE, - related_name="labelings", - ) - - attached_at = models.DateTimeField(auto_now_add=True) - attached_by = models.ForeignKey( - settings.AUTH_USER_MODEL, - on_delete=models.CASCADE, - related_name="labelings_attached", + related_name="label_annotations", ) - removed_at = models.DateTimeField( - null=True, - blank=True, - help_text="When this label was removed from the problem.", - ) - removed_by = models.ForeignKey( - settings.AUTH_USER_MODEL, - on_delete=models.CASCADE, - related_name="labelings_removed", - null=True, - blank=True, - help_text="User who removed this label.", - ) - - notes = models.TextField( - blank=True, - help_text="Optional notes explaining why this label was added or removed.", - ) - - class Meta: - ordering = ["-attached_at"] + class Meta(BaseAnnotation.Meta): + ordering = ["-created_at"] indexes = [ models.Index(fields=["problem", "removed_at"]), models.Index(fields=["label", "removed_at"]), ] permissions = [ - ("delete_own_labeling", "Can remove own labeling from problems"), - ("delete_any_labeling", "Can remove any labeling from problems"), + ( + "delete_own_labelannotation", + "Can remove own label annotation from problems", + ), + ( + "delete_any_labelannotation", + "Can remove any label annotation from problems", + ), ] - - def is_active(self) -> bool: - """Check if this labeling is currently active (not removed).""" - return self.removed_at is None - def __str__(self): status = "active" if self.is_active() else f"removed at {self.removed_at}" return f"Label '{self.label.text}' on Problem {self.problem.pk} ({status})" + + def is_attached_by_user(self, user) -> bool: + """Check if this label annotation was created by the given user.""" + return self.created_by == user diff --git a/backend/annotation/serializers.py b/backend/annotation/serializers.py index 02f8f93..34f66a3 100644 --- a/backend/annotation/serializers.py +++ b/backend/annotation/serializers.py @@ -2,98 +2,159 @@ from django.contrib.auth.models import AnonymousUser -from annotation.models import Label, Labeling +from annotation.models import ( + KnowledgeBaseAnnotation, + Label, + LabelAnnotation, +) from problem.models import Problem from user.models import User -class LabelSerializer(serializers.ModelSerializer): +class AnnotationBaseSerializer(serializers.ModelSerializer): """ - Serializer for Label model. + Base serializer for AnnotationBase model. """ + createdAt = serializers.DateTimeField(source="created_at", read_only=True) + createdBy = serializers.PrimaryKeyRelatedField(source="created_by", read_only=True) + removedAt = serializers.DateTimeField( + source="removed_at", allow_null=True, read_only=True + ) + removedBy = serializers.PrimaryKeyRelatedField( + source="removed_by", allow_null=True, read_only=True + ) + removable = serializers.SerializerMethodField(read_only=True) + class Meta: - model = Label - fields = ["id", "text", "description"] + model = None # To be set in subclasses + fields = [ + "id", + "session", + "problem", + "createdAt", + "createdBy", + "removedAt", + "removedBy", + "notes", + "removable", + ] + abstract = True + def get_removable(self, annotation) -> bool: + """This should be overridden in subclasses.""" + raise NotImplementedError("Subclasses must implement get_removable method.") -class ActiveLabelSerializer(serializers.Serializer): - """ - Serializer for active labels attached to a problem. - Includes attachedInfo and removable status based on current user. - """ - id = serializers.IntegerField(source="label.id") - text = serializers.CharField(source="label.text") - description = serializers.CharField(source="label.description") - attachedInfo = serializers.SerializerMethodField() - removable = serializers.SerializerMethodField() +class KnowledgeBaseAnnotationSerializer(AnnotationBaseSerializer): + """ + Serializer for the KnowledgeBaseAnnotation model. - def get_attachedInfo(self, labeling: Labeling) -> dict: - """Get attachment information for the label.""" - request = self.context.get("request") - user: User | AnonymousUser | None = request.user if request else None + Requires context to be set with the current user for determining + removability, e.g. KnowledgeBaseAnnotationSerializer(annotation, context={"user": request.user}) + """ - if user and user.is_anonymous is False: - attached_by_current_user = labeling.attached_by.pk == user.pk - else: - attached_by_current_user = False + id = serializers.IntegerField(required=False, allow_null=True) + # Mark relationship as required. DRF thinks it is optional because it has a + # default value in the model. + relationship = serializers.ChoiceField( + choices=KnowledgeBaseAnnotation.Relationship.choices, + required=True, + ) + session = serializers.PrimaryKeyRelatedField(read_only=True) + problem = serializers.PrimaryKeyRelatedField(read_only=True) - return { - "userName": labeling.attached_by.username, - "date": labeling.attached_at.isoformat(), - "attachedByCurrentUser": attached_by_current_user, - } + class Meta(AnnotationBaseSerializer.Meta): + model = KnowledgeBaseAnnotation + fields = [ + "id", + "entity1", + "entity2", + "relationship", + ] + AnnotationBaseSerializer.Meta.fields - def get_removable(self, labeling: Labeling) -> bool: - """Determine if the label is removable by the current user.""" - request = self.context.get("request") - user: User | AnonymousUser | None = request.user if request else None + def get_removable(self, annotation: KnowledgeBaseAnnotation) -> bool: + """Determine if the KB annotation is removable by the current user.""" + user: User | AnonymousUser | None = self.context.get("user", None) if user is None or user.is_anonymous: return False - if user.is_superuser or user.has_perm("annotation.delete_any_labeling"): - return True + return user.has_perm("annotation.delete_knowledgebaseannotation") - if user.has_perm("annotation.delete_own_labeling"): - return labeling.attached_by.pk == user.pk + def validate_id(self, value): + """Validate that the KnowledgeBaseAnnotation ID exists if provided.""" + if value is None or KnowledgeBaseAnnotation.objects.filter(id=value).exists(): + return value + raise serializers.ValidationError( + f"KnowledgeBaseAnnotation item with ID {value} does not exist." + ) - return False +class LabelSerializer(serializers.ModelSerializer): + """ + Serializer for Label model. + """ + + class Meta: + model = Label + fields = ["id", "text", "description"] -class LabelingSerializer(serializers.ModelSerializer): +class LabelAnnotationSerializer(AnnotationBaseSerializer): """ - Serializer for Labeling model, including the full label details. + Serializer for LabelAnnotation model. + + Requires context to be set with the current user for determining + removability, e.g. LabelAnnotationSerializer(annotation, context={"user": request.user}) """ label = LabelSerializer(read_only=True) - attachedAt = serializers.DateTimeField(source="attached_at") - attachedBy = serializers.PrimaryKeyRelatedField( - source="attached_by", read_only=True - ) - removedAt = serializers.DateTimeField(source="removed_at", allow_null=True) - removedBy = serializers.PrimaryKeyRelatedField( - source="removed_by", allow_null=True, read_only=True + label_id = serializers.PrimaryKeyRelatedField( + queryset=Label.objects.all(), + source="label", + write_only=True, + required=False, ) + attachedByCurrentUser = serializers.SerializerMethodField(read_only=True) + removable = serializers.SerializerMethodField(read_only=True) - class Meta: - model = Labeling + class Meta(AnnotationBaseSerializer.Meta): + model = LabelAnnotation fields = [ "id", "label", - "attached_at", - "attached_by", - "removed_at", - "removed_by", - "notes", - ] + "label_id", + "attachedByCurrentUser", + "removable", + ] + AnnotationBaseSerializer.Meta.fields + def get_attachedByCurrentUser(self, annotation: LabelAnnotation) -> bool: + """Determine if the label was attached by the current user.""" + user: User | AnonymousUser | None = self.context.get("user", None) + + if user and user.is_anonymous is False: + return annotation.is_attached_by_user(user) + return False + + def get_removable(self, annotation: LabelAnnotation) -> bool: + """Determine if the label annotation is removable by the current user.""" + user: User | AnonymousUser | None = self.context.get("user", None) + + if user is None or user.is_anonymous: + return False + + if user.is_superuser or user.has_perm("annotation.delete_any_labelannotation"): + return True + + if user.has_perm("annotation.delete_own_labelannotation"): + return annotation.is_attached_by_user(user) + + return False class SelectedLabelSerializer(serializers.Serializer): """Serializer for a selected label in the save labels input.""" - id = serializers.IntegerField() + id = serializers.IntegerField(required=False) def validate_id(self, value): """Validate that the label exists.""" @@ -110,7 +171,6 @@ class SaveLabelsInputSerializer(serializers.Serializer): problemId = serializers.IntegerField() selectedLabels = SelectedLabelSerializer(many=True, allow_empty=True) - remarks = serializers.CharField(required=False, allow_blank=True, default="") def validate_problemId(self, value): """Validate that the problem exists.""" diff --git a/backend/annotation/serializers_test.py b/backend/annotation/serializers_test.py new file mode 100644 index 0000000..c688f08 --- /dev/null +++ b/backend/annotation/serializers_test.py @@ -0,0 +1,212 @@ +import pytest +from typing import Any + +from django.utils import timezone +from django.contrib.auth.models import Permission +from rest_framework.test import APIRequestFactory + +from annotation.serializers import ( + KnowledgeBaseAnnotationSerializer, + LabelSerializer, + LabelAnnotationSerializer, + SaveLabelsInputSerializer, +) +from problem.serializers import ProblemSerializer + + +@pytest.mark.django_db +def test_invalid_kb_item_id(annotator_session, sample_problem): + """Test that a non-existent kbItem ID is invalid.""" + data = { + "id": 9999, + "relationship": "equal", + "entity1": "e1", + "entity2": "e2", + "session": annotator_session.pk, + "problem": sample_problem.pk, + } + serializer = KnowledgeBaseAnnotationSerializer(data=data) + assert not serializer.is_valid() + assert "id" in serializer.errors + + +@pytest.mark.django_db +def test_valid_kb_annotation_data(annotator_session, sample_problem): + """Test that valid KB annotation data is accepted.""" + data = { + "relationship": "equal", + "entity1": "cat", + "entity2": "feline", + "session": annotator_session.pk, + "problem": sample_problem.pk, + } + serializer = KnowledgeBaseAnnotationSerializer(data=data) + assert serializer.is_valid(), serializer.errors + + +@pytest.mark.django_db +def test_kb_annotation_update(kb_annotation): + """Test updating an existing KB annotation.""" + serializer = KnowledgeBaseAnnotationSerializer(kb_annotation) + updated_data = { + "entity1": "updated_e1", + "entity2": "updated_e2", + "relationship": "subset", + } + + updated = serializer.update(kb_annotation, updated_data) + + assert updated.entity1 == updated_data["entity1"] + assert updated.entity2 == updated_data["entity2"] + assert updated.relationship == updated_data["relationship"] + + +@pytest.mark.django_db +def test_kb_annotation_get_removable_with_permission(kb_annotation, annotator): + """Test that get_removable returns True when user has permission.""" + permission = Permission.objects.get( + codename="delete_knowledgebaseannotation", content_type__app_label="annotation" + ) + annotator.user_permissions.add(permission) + annotator.refresh_from_db() + + factory = APIRequestFactory() + request = factory.get("/") + request.user = annotator + + serializer = KnowledgeBaseAnnotationSerializer( + kb_annotation, context={"user": request.user} + ) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is True + + +@pytest.mark.django_db +def test_kb_annotation_get_removable_without_permission(kb_annotation, visitor): + """Test that get_removable returns False when user lacks permission.""" + factory = APIRequestFactory() + request = factory.get("/") + request.user = visitor + + serializer = KnowledgeBaseAnnotationSerializer( + kb_annotation, context={"user": request.user} + ) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is False + + +@pytest.mark.django_db +def test_label_serialization(sample_label): + """Test serializing a label.""" + serializer = LabelSerializer(sample_label) + data: dict[str, Any] = serializer.data # type: ignore + + assert data["id"] == sample_label.pk + assert data["text"] == sample_label.text + assert data["description"] == sample_label.description + + +@pytest.mark.django_db +def test_label_annotation_serialization(label_annotation): + """Test serializing a label annotation.""" + serializer = LabelAnnotationSerializer( + label_annotation, context={"user": label_annotation.created_by} + ) + data: dict[str, Any] = serializer.data # type: ignore + + assert data["id"] == label_annotation.pk + assert data["label"]["text"] == label_annotation.label.text + assert data["attachedByCurrentUser"] is True + + +@pytest.mark.django_db +def test_label_annotation_attached_by_different_user(label_annotation, visitor): + """Test attachedByCurrentUser is False for different user.""" + serializer = LabelAnnotationSerializer(label_annotation, context={"user": visitor}) + data: dict[str, Any] = serializer.data # type: ignore + + assert data["attachedByCurrentUser"] is False + + +@pytest.mark.django_db +def test_label_annotation_removable_by_creator(label_annotation, annotator): + """Test that creator can remove their own annotation with proper permission.""" + perm = Permission.objects.get(codename="delete_own_labelannotation") + annotator.user_permissions.add(perm) + + serializer = LabelAnnotationSerializer( + label_annotation, context={"user": annotator} + ) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is True + + +@pytest.mark.django_db +def test_label_annotation_not_removable_by_non_creator( + label_annotation, annotator, master_annotator +): + """Test that a regular Annotator user cannot remove annotation they did not create.""" + perm = Permission.objects.get(codename="delete_own_labelannotation") + annotator.user_permissions.add(perm) + + # Change the creator to a different user + label_annotation.created_by = master_annotator + label_annotation.save() + + serializer = LabelAnnotationSerializer( + label_annotation, context={"user": annotator} + ) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is False + + +@pytest.mark.django_db +def test_label_annotation_removable_by_master(label_annotation, master_annotator): + """Test that master annotator can remove any annotation.""" + serializer = LabelAnnotationSerializer( + label_annotation, context={"user": master_annotator} + ) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is True + + +@pytest.mark.django_db +def test_label_annotation_not_removable_without_permission(label_annotation, visitor): + """Test that user without permission cannot remove annotation.""" + serializer = LabelAnnotationSerializer(label_annotation, context={"user": visitor}) + data: dict[str, Any] = serializer.data # type: ignore + assert data["removable"] is False + + +@pytest.mark.django_db +def test_save_labels_input_valid(sample_problem, sample_label): + """Test valid save labels input data.""" + data = { + "problemId": sample_problem.pk, + "selectedLabels": [{"id": sample_label.pk}], + } + serializer = SaveLabelsInputSerializer(data=data) + assert serializer.is_valid(), serializer.errors + + +@pytest.mark.django_db +def test_save_labels_input_empty_labels(sample_problem): + """Test save labels input with empty label list.""" + data = { + "problemId": sample_problem.pk, + "selectedLabels": [], + } + serializer = SaveLabelsInputSerializer(data=data) + assert serializer.is_valid() + + +@pytest.mark.django_db +def test_save_labels_input_invalid_problem(): + """Test that non-existent problem ID is invalid.""" + data = { + "problemId": 9999, + "selectedLabels": [], + } + serializer = SaveLabelsInputSerializer(data=data) + assert not serializer.is_valid() + assert "problemId" in serializer.errors diff --git a/backend/annotation/views.py b/backend/annotation/views.py index aefa18e..013608a 100644 --- a/backend/annotation/views.py +++ b/backend/annotation/views.py @@ -10,35 +10,44 @@ SAFE_METHODS, ) -from annotation.models import Label, Labeling +from annotation.models import AnnotationSession, Label, LabelAnnotation from annotation.serializers import ( + LabelAnnotationSerializer, LabelSerializer, SaveLabelsInputSerializer, ) +from django.contrib.auth.models import AnonymousUser from problem.models import Problem from user.models import User from langpro_annotator.logger import logger -class SaveLabelingsPermission(IsAuthenticated): - """Permission class for saving labelings.""" +class SaveLabelAnnotationPermission(IsAuthenticated): + """Permission class for saving label annotations.""" def has_permission(self, request, view): if not super().has_permission(request, view): return False - if request.user.is_superuser: + user: User | AnonymousUser | None = request.user + + if user is None or user.is_anonymous: + return False + + if user.is_superuser: return True - return request.user.has_perm("annotation.add_labeling") + return user.has_perm("annotation.add_labelannotation") or user.has_perm( + "annotation.change_labelannotation" + ) -class LabelView(ModelViewSet): +class LabelAnnotationView(ModelViewSet): """ - ViewSet for Label model. + ViewSet for the Label and LabelAnnotation models. GET: All users can list and retrieve labels. - POST: Only selected users can save labelings (attach/remove labels from problems). + POST: Only selected users can save label annotations (attach/remove labels to/from problems). """ queryset = Label.objects.all().order_by("text") @@ -48,16 +57,15 @@ class LabelView(ModelViewSet): def get_permissions(self): if self.request.method in SAFE_METHODS: return [AllowAny()] - return [SaveLabelingsPermission()] + return [SaveLabelAnnotationPermission()] def create(self, request: Request) -> Response: """ - Save labelings for a problem (attach/remove labels). + Create annotations by attaching/removing labels to/from a problem. Expects a payload with: - problemId: ID of the problem - selectedLabels: List of labels to be attached (with at least 'id' field) - - remarks: Optional notes """ serializer = SaveLabelsInputSerializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -65,74 +73,65 @@ def create(self, request: Request) -> Response: problem_id = validated_data["problemId"] selected_labels = validated_data["selectedLabels"] - remarks = validated_data.get("remarks", "") problem = Problem.objects.get(id=problem_id) user: User = request.user # type: ignore selected_label_ids = {label["id"] for label in selected_labels} - self._update_labelings(problem, user, selected_label_ids, remarks) + self._update_label_annotations(problem, user, selected_label_ids) return Response({"ok": True}) - def _update_labelings( + def _update_label_annotations( self, problem: Problem, user: User, selected_label_ids: set[int], - remarks: str, ) -> None: - """Update labelings for a problem based on selected labels.""" + """Update label annotations for a problem based on selected labels.""" with transaction.atomic(): - active_labelings = Labeling.objects.filter( + session = AnnotationSession.objects.create(user=user) + + active_annotations = LabelAnnotation.objects.filter( problem=problem, removed_at__isnull=True - ).select_related("label", "attached_by") + ).select_related("label", "created_by") - current_label_ids = {labeling.label.pk for labeling in active_labelings} + current_label_ids = { + annotation.label.pk for annotation in active_annotations + } labels_to_remove = current_label_ids - selected_label_ids labels_to_add = selected_label_ids - current_label_ids - for labeling in active_labelings: - if labeling.label.pk in labels_to_remove: - self._remove_labeling( - labeling=labeling, + for annotation in active_annotations: + if annotation.label.pk in labels_to_remove: + self._mark_as_removed( + label_annotation=annotation, user=user, - remarks=remarks, ) for label_id in labels_to_add: - self._create_labeling( - label_id=label_id, - problem=problem, - user=user, - remarks=remarks, + serializer = LabelAnnotationSerializer( + context={"user": user}, + data={ + "problem": problem.pk, + "label_id": label_id, + "session": session.pk, + }, ) + serializer.is_valid(raise_exception=True) + serializer.save(created_by=user) - def _create_labeling( - self, label_id: int, problem: Problem, user: User, remarks: str - ) -> None: - """Create a new labeling.""" - - Labeling.objects.create( - problem=problem, - label_id=label_id, - attached_by=user, - notes=remarks, - ) - - def _remove_labeling(self, labeling: Labeling, user: User, remarks: str) -> None: - """Mark a labeling as removed.""" + def _mark_as_removed(self, label_annotation: LabelAnnotation, user: User) -> None: + """Mark a label annotation as removed.""" - if not user.can_remove_labeling(labeling): + if not user.can_remove_label(label_annotation): logger.warning( - f"User {user.username} attempted to remove label {labeling.label.pk} " - f"attached by {labeling.attached_by.username}" + f"User {user.username} attempted to remove label {label_annotation.label.pk} " + f"attached by {label_annotation.created_by.username}" ) raise PermissionDenied("You can only remove labels you attached yourself.") - labeling.removed_at = timezone.now() - labeling.removed_by = user - if remarks: - labeling.notes = remarks - labeling.save() + label_annotation.removed_at = timezone.now() + label_annotation.removed_by = user + label_annotation.save() diff --git a/backend/annotation/views_test.py b/backend/annotation/views_test.py index 4daeea8..b64d645 100644 --- a/backend/annotation/views_test.py +++ b/backend/annotation/views_test.py @@ -1,7 +1,7 @@ import pytest from rest_framework import status -from annotation.models import Label, Labeling +from annotation.models import AnnotationSession, Label, LabelAnnotation @pytest.fixture @@ -23,22 +23,28 @@ def another_label(db): @pytest.fixture -def labeling_by_annotator(db, sample_problem, sample_label, annotator): - """Creates a labeling attached by an annotator.""" - return Labeling.objects.create( +def label_annotation_by_annotator( + db, sample_problem, sample_label, annotator, annotator_session +): + """Creates a label_annotation attached by an annotator.""" + return LabelAnnotation.objects.create( problem=sample_problem, label=sample_label, - attached_by=annotator, + session=annotator_session, + created_by=annotator, ) @pytest.fixture -def labeling_by_master(db, sample_problem, another_label, master_annotator): - """Creates a labeling attached by a master annotator.""" - return Labeling.objects.create( +def label_annotation_by_master( + db, sample_problem, another_label, master_annotator, master_annotator_session +): + """Creates a label_annotation attached by a master annotator.""" + return LabelAnnotation.objects.create( problem=sample_problem, label=another_label, - attached_by=master_annotator, + session=master_annotator_session, + created_by=master_annotator, ) @@ -82,160 +88,157 @@ def test_visitor_can_retrieve_label(self, api_client, visitor, sample_label): assert response.status_code == status.HTTP_200_OK -class TestSaveLabelingsPermissions: - """Tests for POST permissions (saving labelings) on the Label endpoint.""" +class TestSavelabel_annotationsPermissions: + """Tests for POST permissions (saving label_annotations) on the Label endpoint.""" - def test_unauthenticated_user_cannot_save_labelings( + def test_unauthenticated_user_cannot_save_label_annotations( self, api_client, sample_problem, sample_label ): - """Unauthenticated users should not be able to save labelings.""" + """Unauthenticated users should not be able to save label_annotations.""" data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_401_UNAUTHORIZED - def test_visitor_cannot_save_labelings( + def test_visitor_cannot_save_label_annotations( self, api_client, visitor, sample_problem, sample_label ): - """Visitors should not be able to save labelings.""" + """Visitors should not be able to save label_annotations.""" api_client.force_authenticate(user=visitor) data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_403_FORBIDDEN - def test_annotator_can_save_labelings( + def test_annotator_can_save_label_annotations( self, api_client, annotator, sample_problem, sample_label ): - """Annotators should be able to save labelings.""" + """Annotators should be able to save label_annotations.""" api_client.force_authenticate(user=annotator) data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}], - "remarks": "Test remark", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK assert response.data["ok"] is True - # Verify labeling was created - labeling = Labeling.objects.get(problem=sample_problem, label=sample_label) - assert labeling.attached_by == annotator - assert labeling.notes == "Test remark" + # Verify label_annotation was created + label_annotation = LabelAnnotation.objects.get( + problem=sample_problem, label=sample_label + ) + assert label_annotation.created_by == annotator - def test_master_annotator_can_save_labelings( + def test_master_annotator_can_save_label_annotations( self, api_client, master_annotator, sample_problem, sample_label ): - """Master annotators should be able to save labelings.""" + """Master annotators should be able to save label_annotations.""" api_client.force_authenticate(user=master_annotator) data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK assert response.data["ok"] is True -class TestRemoveLabelingPermissions: - """Tests for labeling removal permissions.""" +class TestRemoveLabelAnnotationPermissions: + """Tests for label_annotation removal permissions.""" - def test_annotator_can_remove_own_labeling( - self, api_client, annotator, sample_problem, labeling_by_annotator + def test_annotator_can_remove_own_label_annotation( + self, api_client, annotator, sample_problem, label_annotation_by_annotator ): """Annotators should be able to remove labels they attached.""" api_client.force_authenticate(user=annotator) data = { "problemId": sample_problem.id, "selectedLabels": [], # Empty = remove the existing label - "remarks": "Removing my own label", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - # Verify labeling was marked as removed - labeling_by_annotator.refresh_from_db() - assert labeling_by_annotator.removed_at is not None - assert labeling_by_annotator.removed_by == annotator + # Verify label_annotation was marked as removed + label_annotation_by_annotator.refresh_from_db() + assert label_annotation_by_annotator.removed_at is not None + assert label_annotation_by_annotator.removed_by == annotator - def test_annotator_cannot_remove_others_labeling( - self, api_client, annotator, sample_problem, labeling_by_master + def test_annotator_cannot_remove_others_label_annotation( + self, api_client, annotator, sample_problem, label_annotation_by_master ): """Annotators should not be able to remove labels attached by others.""" api_client.force_authenticate(user=annotator) data = { "problemId": sample_problem.id, "selectedLabels": [], # Empty = try to remove the existing label - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_403_FORBIDDEN assert ( - "You can only remove labels you attached yourself" in response.data["detail"] + "You can only remove labels you attached yourself" + in response.data["detail"] ) - # Verify labeling was NOT removed - labeling_by_master.refresh_from_db() - assert labeling_by_master.removed_at is None + # Verify label_annotation was NOT removed + label_annotation_by_master.refresh_from_db() + assert label_annotation_by_master.removed_at is None - def test_master_annotator_can_remove_own_labeling( - self, api_client, master_annotator, sample_problem, labeling_by_master + def test_master_annotator_can_remove_own_label_annotation( + self, api_client, master_annotator, sample_problem, label_annotation_by_master ): """Master annotators should be able to remove labels they attached.""" api_client.force_authenticate(user=master_annotator) data = { "problemId": sample_problem.id, "selectedLabels": [], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - labeling_by_master.refresh_from_db() - assert labeling_by_master.removed_at is not None + label_annotation_by_master.refresh_from_db() + assert label_annotation_by_master.removed_at is not None - def test_master_annotator_can_remove_others_labeling( - self, api_client, master_annotator, sample_problem, labeling_by_annotator + def test_master_annotator_can_remove_others_label_annotation( + self, + api_client, + master_annotator, + sample_problem, + label_annotation_by_annotator, ): """Master annotators should be able to remove labels attached by others.""" api_client.force_authenticate(user=master_annotator) data = { "problemId": sample_problem.id, "selectedLabels": [], - "remarks": "Removing annotator's label", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - labeling_by_annotator.refresh_from_db() - assert labeling_by_annotator.removed_at is not None - assert labeling_by_annotator.removed_by == master_annotator + label_annotation_by_annotator.refresh_from_db() + assert label_annotation_by_annotator.removed_at is not None + assert label_annotation_by_annotator.removed_by == master_annotator -class TestLabelingAddAndRemove: - """Tests for adding / removing labelings.""" +class TestLabelAnnotationAddAndRemove: + """Tests for adding / removing label_annotations.""" - def test_adding_label_creates_labeling( + def test_adding_label_creates_label_annotation( self, api_client, annotator, sample_problem, sample_label ): - """Adding a label should create a new Labeling record.""" + """Adding a label should create a new label_annotation record.""" api_client.force_authenticate(user=annotator) data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - assert Labeling.objects.filter( + assert LabelAnnotation.objects.filter( problem=sample_problem, label=sample_label, removed_at__isnull=True ).exists() @@ -247,44 +250,42 @@ def test_adding_multiple_labels( data = { "problemId": sample_problem.id, "selectedLabels": [{"id": sample_label.id}, {"id": another_label.id}], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - active_labelings = Labeling.objects.filter( + active_label_annotations = LabelAnnotation.objects.filter( problem=sample_problem, removed_at__isnull=True ) - assert active_labelings.count() == 2 + assert active_label_annotations.count() == 2 def test_keeping_existing_labels_unchanged( self, api_client, annotator, sample_problem, - labeling_by_annotator, + label_annotation_by_annotator, another_label, ): """Labels already attached should remain if still in selectedLabels.""" - original_labeling_id = labeling_by_annotator.id + original_label_annotation_id = label_annotation_by_annotator.id api_client.force_authenticate(user=annotator) data = { "problemId": sample_problem.id, "selectedLabels": [ - {"id": labeling_by_annotator.label.id}, + {"id": label_annotation_by_annotator.label.id}, {"id": another_label.id}, ], - "remarks": "", } response = api_client.post("/api/label/", data, format="json") assert response.status_code == status.HTTP_200_OK - # Original labeling should still be active - labeling_by_annotator.refresh_from_db() - assert labeling_by_annotator.id == original_labeling_id - assert labeling_by_annotator.removed_at is None + # Original label_annotation should still be active + label_annotation_by_annotator.refresh_from_db() + assert label_annotation_by_annotator.id == original_label_annotation_id + assert label_annotation_by_annotator.removed_at is None - # New labeling should be created - assert Labeling.objects.filter( + # New label_annotation should be created + assert LabelAnnotation.objects.filter( problem=sample_problem, label=another_label, removed_at__isnull=True ).exists() diff --git a/backend/conftest.py b/backend/conftest.py index 94f7c5b..176bf9f 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -5,6 +5,12 @@ from django.contrib.auth.models import Group, Permission from rest_framework.test import APIClient as DRFAPIClient +from annotation.models import ( + AnnotationSession, + KnowledgeBaseAnnotation, + Label, + LabelAnnotation, +) from user.models import User, GroupName from user.permissions import ANNOTATOR_PERMISSIONS, MASTER_ANNOTATOR_PERMISSIONS from problem.models import Problem, Sentence @@ -118,3 +124,55 @@ def sample_problem(db): ) problem.premises.add(premise) return problem + + +@pytest.fixture +def annotator_session(db, annotator): + """Annotation session for a user with the 'Annotator' role.""" + return AnnotationSession.objects.create( + user=annotator, + created_at="2024-01-01T00:00:00Z", + updated_at="2024-01-01T00:00:00Z", + ) + + +@pytest.fixture +def master_annotator_session(db, master_annotator): + """Creates an annotation session for the master annotator.""" + return AnnotationSession.objects.create( + user=master_annotator, + created_at="2024-01-01T00:00:00Z", + updated_at="2024-01-01T00:00:00Z", + ) + + +@pytest.fixture +def kb_annotation(db, sample_problem, annotator_session): + return KnowledgeBaseAnnotation.objects.create( + problem=sample_problem, + entity1="e1", + entity2="e2", + relationship=KnowledgeBaseAnnotation.Relationship.EQUAL, + session=annotator_session, + created_at="2024-01-01T00:00:00Z", + created_by=annotator_session.user, + ) + + +@pytest.fixture +def sample_label(db): + """Create a sample label for testing.""" + return Label.objects.create( + text="Ambiguous", description="This problem contains ambiguous language." + ) + + +@pytest.fixture +def label_annotation(db, sample_problem, annotator_session, sample_label): + """Create a label annotation for testing.""" + return LabelAnnotation.objects.create( + problem=sample_problem, + label=sample_label, + session=annotator_session, + created_by=annotator_session.user, + ) diff --git a/backend/langpro_annotator/urls.py b/backend/langpro_annotator/urls.py index 97e77ca..e4808da 100644 --- a/backend/langpro_annotator/urls.py +++ b/backend/langpro_annotator/urls.py @@ -21,7 +21,7 @@ from rest_framework import routers -from annotation.views import LabelView +from annotation.views import LabelAnnotationView from problem.views.problem import ProblemView from .index import index @@ -30,7 +30,7 @@ api_router = routers.DefaultRouter() # register viewsets with this router api_router.register(r"problem", ProblemView, basename="problem") -api_router.register(r"label", LabelView, basename="labels") +api_router.register(r"label", LabelAnnotationView, basename="labels") if settings.PROXY_FRONTEND: diff --git a/backend/problem/migrations/0008_delete_knowledgebase.py b/backend/problem/migrations/0008_delete_knowledgebase.py new file mode 100644 index 0000000..ab7dc87 --- /dev/null +++ b/backend/problem/migrations/0008_delete_knowledgebase.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.27 on 2026-02-01 13:27 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("annotation", "0003_labelannotation_and_more"), + ("problem", "0007_alter_problem_options"), + ] + + operations = [ + migrations.DeleteModel( + name="KnowledgeBase", + ), + ] diff --git a/backend/problem/models.py b/backend/problem/models.py index 22b9a7b..862e86d 100644 --- a/backend/problem/models.py +++ b/backend/problem/models.py @@ -74,27 +74,3 @@ def get_index(self, qs: QuerySet) -> int | None: except Exception as e: logger.exception(f"Error getting index for problem {self.pk}: {e}") return None - - -class KnowledgeBase(models.Model): - class Relationship(models.TextChoices): - EQUAL = "equal", "Equal" - NOT_EQUAL = "not_equal", "Not Equal" - SUBSET = "subset", "Subset" - SUPERSET = "superset", "Superset" - - entity1 = models.CharField(max_length=255) - - entity2 = models.CharField(max_length=255) - - relationship = models.CharField( - max_length=255, - choices=Relationship.choices, - default=Relationship.EQUAL, - ) - - problem = models.ForeignKey( - Problem, - on_delete=models.CASCADE, - related_name="knowledge_bases", - ) diff --git a/backend/problem/serializers.py b/backend/problem/serializers.py index 3f1ce4c..778bfe5 100644 --- a/backend/problem/serializers.py +++ b/backend/problem/serializers.py @@ -1,62 +1,24 @@ from rest_framework import serializers -from django.contrib.auth.models import AnonymousUser -from annotation.serializers import ActiveLabelSerializer -from user.models import User -from annotation.models import Label, Labeling +from annotation.serializers import KnowledgeBaseAnnotationSerializer +from annotation.models import ( + AnnotationSession, + KnowledgeBaseAnnotation, +) from problem.services import FracasData, SNLIData, SickData -from problem.models import Problem, KnowledgeBase, Sentence - - -class KnowledgeBaseSerializer(serializers.ModelSerializer): - - class Meta: - model = KnowledgeBase - fields = ["id", "entity1", "entity2", "relationship"] - extra_kwargs = { - # Without this, the relationship field is not required during validation. - "relationship": {"required": True}, - } - - def validate_id(self, value): - """Validate that the KnowledgeBase ID exists if provided.""" - if value is not None: - if not KnowledgeBase.objects.filter(id=value).exists(): - raise serializers.ValidationError( - f"KnowledgeBase item with ID {value} does not exist." - ) - return value - - def create_for_problem( - self, validated_data: dict, problem: Problem - ) -> KnowledgeBase: - """Create a new KnowledgeBase item for a problem.""" - return KnowledgeBase.objects.create( - **validated_data, - problem=problem, - ) - - def update(self, instance: KnowledgeBase, validated_data: dict) -> KnowledgeBase: - """Update an existing KnowledgeBase item.""" - instance.entity1 = validated_data["entity1"] - instance.relationship = validated_data["relationship"] - instance.entity2 = validated_data["entity2"] - instance.save() - return instance +from problem.models import Problem, Sentence class ProblemSerializer(serializers.ModelSerializer): """ Serializer for Problem model output. - Handles serialization of problems with all related data including labels. """ + id = serializers.IntegerField(read_only=True) premises = serializers.SerializerMethodField() hypothesis = serializers.SerializerMethodField() entailmentLabel = serializers.CharField(source="entailment_label") extraData = serializers.SerializerMethodField() - kbItems = serializers.SerializerMethodField() - labels = serializers.SerializerMethodField() class Meta: model = Problem @@ -67,20 +29,18 @@ class Meta: "hypothesis", "entailmentLabel", "extraData", - "kbItems", "base", - "labels", ] - def get_premises(self, problem): + def get_premises(self, problem: Problem): """Get list of premise texts.""" return [premise.text for premise in problem.premises.all()] - def get_hypothesis(self, problem): + def get_hypothesis(self, problem: Problem): """Get hypothesis text.""" return problem.hypothesis.text - def get_extraData(self, problem): + def get_extraData(self, problem: Problem): """Get dataset-specific extra data.""" match problem.dataset: case Problem.Dataset.SICK: @@ -92,17 +52,49 @@ def get_extraData(self, problem): case _: return {} - def get_kbItems(self, problem): - """Get knowledge base items.""" - kb_items = problem.knowledge_bases.all() - return KnowledgeBaseSerializer(kb_items, many=True).data - def get_labels(self, problem): - """Get active labels with attachment info and removability.""" - active_labelings = problem.labelings.filter(removed_at__isnull=True) - return ActiveLabelSerializer( - active_labelings, many=True, context=self.context - ).data +class ProblemInputSerializer(serializers.Serializer): + """ + Serializer for validating problem input data. + This is used for both creating and updating user-created problems. + """ + + id = serializers.IntegerField(required=False, allow_null=True) + premises = serializers.ListField( + child=serializers.CharField(allow_blank=False), + allow_empty=False, + help_text="List of premise sentence texts", + ) + hypothesis = serializers.CharField( + allow_blank=False, help_text="Hypothesis sentence text" + ) + kbItems = KnowledgeBaseAnnotationSerializer( + many=True, + help_text="List of knowledge base annotations", + required=False, + ) + + base = serializers.IntegerField(required=False, allow_null=True) + + def validate_id(self, value): + """Validate that the Problem ID, if provided, exists and belongs to a user-created problem.""" + if value is not None: + if not Problem.objects.filter( + id=value, dataset=Problem.Dataset.USER + ).exists(): + raise serializers.ValidationError( + f"Problem with ID {value} does not exist." + ) + return value + + def validate_base(self, value): + """Validate that the base problem ID exists if provided.""" + if value is not None: + if not Problem.objects.filter(id=value).exists(): + raise serializers.ValidationError( + f"Base problem with ID {value} does not exist." + ) + return value def create(self, validated_data: dict) -> Problem: """ @@ -131,19 +123,74 @@ def create(self, validated_data: dict) -> Problem: kb_items = validated_data.get("kbItems", []) if kb_items: - self._update_or_create_kb_items(problem, kb_items) + self._create_update_kb_annotations(problem, kb_items) return problem + def _create_update_kb_annotation( + self, kb_item: dict, problem: Problem, session: AnnotationSession + ) -> None: + kb_id = kb_item.get("id", None) + + update_data = { + **kb_item, + "problem_id": problem.pk, + "session_id": session.pk, + } + + if kb_id is None: + # Create new KnowledgeBaseAnnotation + serializer = KnowledgeBaseAnnotationSerializer(data=update_data) + else: + # Update existing KnowledgeBaseAnnotation + try: + kb_instance = KnowledgeBaseAnnotation.objects.get( + id=kb_id, problem_id=problem.pk + ) + serializer = KnowledgeBaseAnnotationSerializer( + kb_instance, data=update_data + ) + except KnowledgeBaseAnnotation.DoesNotExist: + raise serializers.ValidationError( + f"KnowledgeBaseAnnotation with ID {kb_id} does not exist " + f"for this problem and session." + ) + + serializer.is_valid(raise_exception=True) + serializer.save(problem=problem, session=session, created_by=session.user) + + def _create_update_kb_annotations( + self, problem: Problem, kb_items: list[dict] + ) -> None: + """ + Creates or update KnowledgeBase and Label annotations for a problem. + Creates an annotation session if it does not exist. + + TODO: handle deletions! + """ + request = self.context.get("request", None) + if not request or not request.user.is_authenticated: + return + + session = AnnotationSession.objects.create(user=request.user) + + for kb_item in kb_items: + self._create_update_kb_annotation(kb_item, problem, session) + def update(self, instance: Problem, validated_data: dict) -> Problem: """ Update an existing Problem instance from validated input data. Handles updating of related Sentence and KnowledgeBase objects. """ + + # KB annotations can be made for all problems. + kb_items = validated_data.get("kbItems", []) + if kb_items: + self._create_update_kb_annotations(instance, kb_items) + + # Other fields can only be updated for user-created problems. if instance.dataset != Problem.Dataset.USER: - raise serializers.ValidationError( - "Cannot update a problem that is not a user-created problem." - ) + return instance instance.hypothesis = Sentence.objects.get_or_create( text=validated_data["hypothesis"], @@ -169,72 +216,4 @@ def update(self, instance: Problem, validated_data: dict) -> Problem: ] instance.premises.set(premise_sentences) - self._update_or_create_kb_items(instance, validated_data.get("kbItems", [])) - return instance - - def _update_or_create_kb_items( - self, problem: Problem, kb_items: list[dict] - ) -> None: - """Create or update KnowledgeBase items for a problem.""" - kb_ids: list[int] = [] - kb_serializer = KnowledgeBaseSerializer() - - for item in kb_items: - kb_id = item.get("id", None) - - if kb_id is None: - kb = kb_serializer.create_for_problem(item, problem=problem) # type: ignore - else: - kb_instance = KnowledgeBase.objects.get(id=kb_id, problem_id=problem.pk) - kb = kb_serializer.update(kb_instance, item) - - kb_ids.append(kb.pk) - - # Delete existing knowledge bases associated to this problem that are - # not included in the input. - KnowledgeBase.objects.filter(problem_id=problem.pk).exclude( - id__in=kb_ids - ).delete() - - -class ProblemInputSerializer(serializers.Serializer): - """ - Serializer for validating problem input data. - This is used for both creating and updating user-created problems. - """ - - id = serializers.IntegerField(required=False, allow_null=True) - premises = serializers.ListField( - child=serializers.CharField(allow_blank=False), - allow_empty=False, - help_text="List of premise sentence texts", - ) - hypothesis = serializers.CharField( - allow_blank=False, help_text="Hypothesis sentence text" - ) - kbItems = KnowledgeBaseSerializer( - many=True, allow_empty=True, help_text="List of knowledge base items" - ) - - base = serializers.IntegerField(required=False, allow_null=True) - - def validate_id(self, value): - """Validate that the Problem ID, if provided, exists and belongs to a user-created problem.""" - if value is not None: - if not Problem.objects.filter( - id=value, dataset=Problem.Dataset.USER - ).exists(): - raise serializers.ValidationError( - f"Problem with ID {value} does not exist." - ) - return value - - def validate_base(self, value): - """Validate that the base problem ID exists if provided.""" - if value is not None: - if not Problem.objects.filter(id=value).exists(): - raise serializers.ValidationError( - f"Base problem with ID {value} does not exist." - ) - return value diff --git a/backend/problem/serializers_test.py b/backend/problem/serializers_test.py index ad7e489..6699d08 100644 --- a/backend/problem/serializers_test.py +++ b/backend/problem/serializers_test.py @@ -1,8 +1,10 @@ import pytest from rest_framework.exceptions import ValidationError -from .serializers import ProblemInputSerializer -from .models import Problem, Sentence, KnowledgeBase +from annotation.models import AnnotationSession, KnowledgeBaseAnnotation + +from .serializers import ProblemInputSerializer, ProblemSerializer +from .models import Problem, Sentence @pytest.fixture @@ -37,16 +39,6 @@ def non_user_problem(db, hypothesis_sentence, premise_sentence): return problem -@pytest.fixture -def kb_item(db, user_problem): - return KnowledgeBase.objects.create( - problem=user_problem, - entity1="e1", - entity2="e2", - relationship=KnowledgeBase.Relationship.EQUAL, - ) - - @pytest.mark.django_db def test_valid_create_data(): """Test valid data for creating a problem.""" @@ -60,21 +52,12 @@ def test_valid_create_data(): @pytest.mark.django_db -def test_valid_update_data(user_problem, kb_item): +def test_valid_update_data(user_problem): """Test valid data for updating a user problem.""" data = { "id": user_problem.pk, "premises": ["The cat is on the mat."], "hypothesis": "A cat is on a mat.", - "kbItems": [ - { - "id": kb_item.pk, - "entity1": "e1", - "entity2": "e2", - "relationship": "equal", - }, - {"entity1": "new_e1", "entity2": "new_e2", "relationship": "subset"}, - ], } serializer = ProblemInputSerializer(data=data) assert serializer.is_valid(raise_exception=True) @@ -148,28 +131,3 @@ def test_blank_hypothesis_invalid(): assert not serializer.is_valid() assert "hypothesis" in serializer.errors - -@pytest.mark.django_db -def test_invalid_kb_item_id(): - """Test that a non-existent kbItem ID is invalid.""" - data = { - "premises": ["premise"], - "hypothesis": "hypothesis", - "kbItems": [{"id": 9999, "relationship": "equal"}], - } - serializer = ProblemInputSerializer(data=data) - assert not serializer.is_valid() - assert "kbItems" in serializer.errors - - -@pytest.mark.django_db -def test_kb_item_missing_relationship(): - """Test that a kbItem missing a relationship is invalid.""" - data = { - "premises": ["premise"], - "hypothesis": "hypothesis", - "kbItems": [{"entity1": "e1", "entity2": "e2"}], - } - serializer = ProblemInputSerializer(data=data) - assert not serializer.is_valid() - assert "kbItems" in serializer.errors diff --git a/backend/problem/views/problem.py b/backend/problem/views/problem.py index 3b37b3d..39f9a0e 100644 --- a/backend/problem/views/problem.py +++ b/backend/problem/views/problem.py @@ -3,6 +3,7 @@ from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet from rest_framework.status import HTTP_201_CREATED, HTTP_200_OK +from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly from django.shortcuts import get_object_or_404 @@ -12,7 +13,9 @@ ) from problem.models import Problem from problem.serializers import ProblemInputSerializer, ProblemSerializer -from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly + +from annotation.models import KnowledgeBaseAnnotation, LabelAnnotation +from annotation.serializers import KnowledgeBaseAnnotationSerializer, LabelAnnotationSerializer class CreateProblemPermission(IsAuthenticated): @@ -26,10 +29,7 @@ def has_permission(self, request, view): class ProblemView(ModelViewSet): - queryset = Problem.objects.prefetch_related( - "labelings__label", - "labelings__attached_by", - ) + queryset = Problem.objects.all() serializer_class = ProblemSerializer def get_permissions(self): @@ -95,9 +95,28 @@ def _get_problem_response(self, request: Request, pk: int | None) -> Response: serializer = self.get_serializer(problem) + kb_annotations = KnowledgeBaseAnnotation.objects.filter( + problem=problem, removed_at__isnull=True + ) + label_annotations = LabelAnnotation.objects.filter( + problem=problem, removed_at__isnull=True + ) + + # kbAnnotations and labelAnnotations are not included in the + # ProblemSerializer because they require additional context for + # determining removability, so we serialize them separately here with + # the proper context. return Response( { - "problem": serializer.data, + "problem": { + **serializer.data, + "kbAnnotations": KnowledgeBaseAnnotationSerializer( + kb_annotations, context={"user": request.user}, many=True + ).data, + "labelAnnotations": LabelAnnotationSerializer( + label_annotations, context={"user": request.user}, many=True + ).data, + }, "index": problem_index, "first": related_problem_ids.first, "previous": related_problem_ids.previous, @@ -125,22 +144,20 @@ def _handle_update_create_problem( ) -> Response: input_data = request.data - input_serializer = ProblemInputSerializer(data=input_data) - input_serializer.is_valid(raise_exception=True) - validated_input: dict = input_serializer.validated_data # type: ignore - - problem_serializer = ProblemSerializer() + serializer = ProblemInputSerializer( + data=input_data, context={"request": request} + ) + serializer.is_valid(raise_exception=True) + validated_input: dict = serializer.validated_data # type: ignore if problem_id is None: - problem = problem_serializer.create(validated_input) # type: ignore + problem = serializer.create(validated_input) # type: ignore status = HTTP_201_CREATED else: problem_instance = get_object_or_404( Problem, id=problem_id, dataset=Problem.Dataset.USER ) - problem: Problem = problem_serializer.update( - problem_instance, validated_input - ) + problem: Problem = serializer.update(problem_instance, validated_input) status = HTTP_200_OK return Response({"id": problem.pk}, status=status) diff --git a/backend/user/models.py b/backend/user/models.py index d9a4627..d85489b 100644 --- a/backend/user/models.py +++ b/backend/user/models.py @@ -1,7 +1,7 @@ from enum import StrEnum import django.contrib.auth.models as django_auth_models -from annotation.models import Labeling +from annotation.models import LabelAnnotation class GroupName(StrEnum): @@ -58,14 +58,14 @@ def can_create_problem(self) -> bool: """ return self.has_perm("problem.add_problem") - def can_remove_labeling(self, labeling: Labeling) -> bool: + def can_remove_label(self, label_annotation: LabelAnnotation) -> bool: """ - Determines whether the user can remove a specific labeling. + Determines whether the user can remove a specific label (as part of an annotation). """ - if self.is_superuser or self.has_perm("annotation.delete_any_labeling"): + if self.is_superuser or self.has_perm("annotation.delete_any_labelannotation"): return True - if self.has_perm("annotation.delete_own_labeling"): - return labeling.attached_by.pk == self.pk + if self.has_perm("annotation.delete_own_labelannotation"): + return label_annotation.created_by.pk == self.pk return False diff --git a/backend/user/permissions.py b/backend/user/permissions.py index 0b6e93c..1e4a03d 100644 --- a/backend/user/permissions.py +++ b/backend/user/permissions.py @@ -1,12 +1,12 @@ # Django permissions are uniquely identified by their combination of a `content_type__app_label` and a `codename`. ANNOTATOR_PERMISSIONS = [ ("problem", "view_silver_problems"), - ("problem", "add_knowledgebase"), - ("problem", "change_knowledgebase"), - ("problem", "delete_knowledgebase"), - ("problem", "view_knowledgebase"), - ("annotation", "add_labeling"), - ("annotation", "delete_own_labeling"), + ("annotation", "add_knowledgebaseannotation"), + ("annotation", "add_labelannotation"), + ("annotation", "change_knowledgebaseannotation"), + ("annotation", "change_labelannotation"), + ("annotation", "delete_knowledgebaseannotation"), + ("annotation", "delete_own_labelannotation"), ] MASTER_ANNOTATOR_PERMISSIONS = ANNOTATOR_PERMISSIONS + [ @@ -21,5 +21,5 @@ ("annotation", "add_label"), ("annotation", "change_label"), ("annotation", "delete_label"), - ("annotation", "delete_any_labeling"), + ("annotation", "delete_any_labelannotation"), ] diff --git a/frontend/src/app/annotate/annotation-input/annotation-input.component.spec.ts b/frontend/src/app/annotate/annotation-input/annotation-input.component.spec.ts index 7ca60aa..38db77f 100644 --- a/frontend/src/app/annotate/annotation-input/annotation-input.component.spec.ts +++ b/frontend/src/app/annotate/annotation-input/annotation-input.component.spec.ts @@ -53,21 +53,35 @@ describe("AnnotationInputComponent", () => { premises: ["First premise", "Second premise"], hypothesis: "Test hypothesis", entailmentLabel: EntailmentLabel.ENTAILMENT, - kbItems: [ + kbAnnotations: [ { id: 456, entity1: "cat", entity2: "animal", - relationship: KnowledgeBaseRelationship.SUBSET + relationship: KnowledgeBaseRelationship.SUBSET, + createdAt: "", + createdBy: "", + removedAt: null, + removedBy: null, + notes: "", + session: null, + removable: true }, { id: 789, entity1: "dog", entity2: "pet", - relationship: KnowledgeBaseRelationship.EQUAL + relationship: KnowledgeBaseRelationship.EQUAL, + createdAt: "", + createdBy: "", + removedAt: null, + removedBy: null, + notes: "", + session: null, + removable: true } ], - labels: [], + labelAnnotations: [], dataset: Dataset.USER, extraData: null }; @@ -110,8 +124,8 @@ describe("AnnotationInputComponent", () => { premises: [], hypothesis: "Empty test hypothesis", entailmentLabel: EntailmentLabel.NEUTRAL, - kbItems: [], - labels: [], + kbAnnotations: [], + labelAnnotations: [], dataset: Dataset.USER, extraData: null }; @@ -123,9 +137,6 @@ describe("AnnotationInputComponent", () => { const premisesArray = form.get('premises') as FormArray; expect(premisesArray.length).toBe(0); - - const kbItemsArray = form.get('kbItems') as FormArray; - expect(kbItemsArray.length).toBe(0); }); it('should create form controls with required validators', () => { @@ -135,10 +146,10 @@ describe("AnnotationInputComponent", () => { premises: ["Test premise"], hypothesis: "Test hypothesis", entailmentLabel: EntailmentLabel.CONTRADICTION, - kbItems: [], - labels: [], dataset: Dataset.USER, - extraData: null + extraData: null, + kbAnnotations: [], + labelAnnotations: [], }; const form = component['buildForm'](mockProblem); @@ -158,10 +169,10 @@ describe("AnnotationInputComponent", () => { premises: [], hypothesis: "", entailmentLabel: EntailmentLabel.UNKNOWN, - kbItems: [], - labels: [], dataset: Dataset.USER, - extraData: null + extraData: null, + kbAnnotations: [], + labelAnnotations: [], }; component['navigateToNewProblem'](mockProblem); @@ -179,10 +190,10 @@ describe("AnnotationInputComponent", () => { premises: [], hypothesis: "", entailmentLabel: EntailmentLabel.UNKNOWN, - kbItems: [], - labels: [], dataset: Dataset.USER, - extraData: null + extraData: null, + kbAnnotations: [], + labelAnnotations: [], }; component['navigateToNewProblem'](mockProblem); diff --git a/frontend/src/app/annotate/annotation-input/annotation-input.component.ts b/frontend/src/app/annotate/annotation-input/annotation-input.component.ts index 3c3a690..bb4d828 100644 --- a/frontend/src/app/annotate/annotation-input/annotation-input.component.ts +++ b/frontend/src/app/annotate/annotation-input/annotation-input.component.ts @@ -11,7 +11,7 @@ import { import { PremisesFormComponent } from "./premises-form/premises-form.component"; import { KnowledgeBaseFormComponent } from "./knowledge-base-form/knowledge-base-form.component"; import { takeUntilDestroyed } from "@angular/core/rxjs-interop"; -import { Dataset, KnowledgeBaseRelationship, Problem } from "../../types"; +import { Dataset, KnowledgeBaseAnnotation, KnowledgeBaseRelationship, Problem } from "../../types"; import { faCheck, faCopy, faExclamationCircle, faFloppyDisk, faTrash, faTree, faWrench } from "@fortawesome/free-solid-svg-icons"; import { ProblemDetailsComponent } from "./problem-details/problem-details.component"; import { map, Subject } from "rxjs"; @@ -28,10 +28,11 @@ export type ParseInputForm = FormGroup<{ base: FormControl; premises: FormArray>; hypothesis: FormControl; - kbItems: FormArray; + kbItems: FormArray; }>; -type KnowledgeBaseItemsForm = FormGroup<{ + +type KBItemForm = FormGroup<{ id: FormControl; entity1: FormControl; relationship: FormControl; @@ -161,7 +162,7 @@ export class AnnotationInputComponent implements OnInit { } private buildForm(problem: Problem): ParseInputForm { - const kbItems = this.buildKbForms(problem.kbItems); + const kbItems = this.buildKbForms(problem.kbAnnotations ?? []); return new FormGroup({ id: new FormControl(problem.id, { @@ -183,10 +184,10 @@ export class AnnotationInputComponent implements OnInit { validators: [Validators.required], nonNullable: true, }), - kbItems: new FormArray(kbItems), + kbItems: new FormArray(kbItems), }); } - private buildKbForms(inputKbItems: Problem['kbItems']): KnowledgeBaseItemsForm[] { + private buildKbForms(inputKbItems: KnowledgeBaseAnnotation[]): KBItemForm[] { return inputKbItems.map(item => new FormGroup({ id: new FormControl(item.id, { nonNullable: true diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html index f0dd55d..b5f98c2 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.html @@ -61,7 +61,7 @@ diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts index f00a200..436fc1a 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.spec.ts @@ -16,9 +16,9 @@ const createMockProblem = ( entailmentLabel, premises: ["premise"], hypothesis: "hypothesis", - kbItems: [], - labels: [], extraData, + kbAnnotations: [], + labelAnnotations: [], }); describe("ProblemDetailsComponent", () => { @@ -66,7 +66,7 @@ describe("ProblemDetailsComponent", () => { section: null, subsection: null, comment: null, - labels: [], + labelAnnotations: [], }); }); @@ -100,7 +100,7 @@ describe("ProblemDetailsComponent", () => { section: "Quantifiers", subsection: "Some", comment: "A test note", - labels: [], + labelAnnotations: [], }); }); diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts index 9f1110a..78c4785 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-details.component.ts @@ -1,4 +1,4 @@ -import { Dataset, EntailmentLabel, Problem, ProblemLabel } from "../../../types"; +import { Dataset, EntailmentLabel, LabelAnnotation, Problem } from "../../../types"; import { Component, computed, inject, input } from "@angular/core"; import { EntailmentLabelBadgeComponent } from "./entailment-label-badge/entailment-label-badge.component"; import { faArrowUpRightFromSquare, faQuestionCircle } from "@fortawesome/free-solid-svg-icons"; @@ -18,7 +18,7 @@ export interface ProblemDetails { section: string | null; subsection: string | null; comment: string | null; - labels: ProblemLabel[]; + labelAnnotations: LabelAnnotation[]; } @Component({ @@ -73,13 +73,13 @@ export class ProblemDetailsComponent { private extractDetails(problem: Problem): ProblemDetails | null { const shared: Pick< ProblemDetails, - "problemId" | "dataset" | "entailmentLabel" | "labels" | "baseProblemId" + "problemId" | "dataset" | "entailmentLabel" | "labelAnnotations" | "baseProblemId" > = { problemId: problem.id?.toString() ?? $localize`new`, baseProblemId: problem.base?.toString() ?? null, dataset: problem.dataset, entailmentLabel: problem.entailmentLabel, - labels: problem.labels + labelAnnotations: problem.labelAnnotations ?? [], }; switch (problem.dataset) { diff --git a/frontend/src/app/annotate/annotation-input/problem-details/problem-labels/manage-labels-modal/manage-labels-modal.component.html b/frontend/src/app/annotate/annotation-input/problem-details/problem-labels/manage-labels-modal/manage-labels-modal.component.html index 5858d28..1b60f39 100644 --- a/frontend/src/app/annotate/annotation-input/problem-details/problem-labels/manage-labels-modal/manage-labels-modal.component.html +++ b/frontend/src/app/annotate/annotation-input/problem-details/problem-labels/manage-labels-modal/manage-labels-modal.component.html @@ -19,26 +19,33 @@
Attached labels
- @for (label of selectedLabels$ | async; track label.id) { + @for (annotation of shownLabels$ | async; track annotation.label.id) + {
- {{ label.text }} + {{ + annotation.label.text + }}
- {{ label.description }} - @if (getAttachedByText(label)) { + {{ annotation.label.description }} + @if (getAttachedByText(annotation)) {
- {{ getAttachedByText(label) }} + {{ getAttachedByText(annotation) }}
} - @if (label.removable) { + @if (annotation.removable) { Click to remove } @@ -66,14 +73,22 @@
Available labels
@if (loadingLabels$ | async) { -
- Loading labels... +
+ Loading labels...
} @else { @for (label of availableLabels$ | async; track label.id) { - +
{{ label.text }} @@ -96,19 +111,6 @@
Available labels
} - -
-
- - -
-