Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/annotation/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_removable(self, annotation: KnowledgeBaseAnnotation) -> bool:
if user is None or user.is_anonymous:
return False

return user.has_perm("annotation.delete_knowledgebaseannotation")
return user.has_perm("annotation.change_knowledgebaseannotation")

def validate_id(self, value):
"""Validate that the KnowledgeBaseAnnotation ID exists if provided."""
Expand Down
2 changes: 1 addition & 1 deletion backend/annotation/serializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_kb_annotation_update(kb_annotation):
def test_kb_annotation_get_removable_with_permission(kb_annotation, annotator):
"""Test that get_removable returns True when user has permission."""
permission = Permission.objects.get(
codename="delete_knowledgebaseannotation", content_type__app_label="annotation"
codename="change_knowledgebaseannotation", content_type__app_label="annotation"
)
annotator.user_permissions.add(permission)
annotator.refresh_from_db()
Expand Down
44 changes: 31 additions & 13 deletions backend/problem/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework import serializers
from django.utils import timezone

from annotation.serializers import KnowledgeBaseAnnotationSerializer
from annotation.models import (
Expand Down Expand Up @@ -77,11 +78,9 @@ class ProblemInputSerializer(serializers.Serializer):
base = serializers.IntegerField(required=False, allow_null=True)

def validate_id(self, value):
"""Validate that the Problem ID, if provided, exists and belongs to a user-created problem."""
"""Validate that the Problem ID, if provided, exists."""
if value is not None:
if not Problem.objects.filter(
id=value, dataset=Problem.Dataset.USER
).exists():
if not Problem.objects.filter(id=value).exists():
raise serializers.ValidationError(
f"Problem with ID {value} does not exist."
)
Expand Down Expand Up @@ -122,8 +121,7 @@ def create(self, validated_data: dict) -> Problem:
problem.premises.set(premise_sentences)

kb_items = validated_data.get("kbItems", [])
if kb_items:
self._create_update_kb_annotations(problem, kb_items)
self._handle_kb_annotations(problem, kb_items)

return problem

Expand Down Expand Up @@ -159,21 +157,42 @@ def _create_update_kb_annotation(
serializer.is_valid(raise_exception=True)
serializer.save(problem=problem, session=session, created_by=session.user)

def _create_update_kb_annotations(
def _mark_kb_not_in_input_as_removed(
self, problem: Problem, kb_items: list[dict], session: AnnotationSession
) -> None:
"""
Marks KnowledgeBase annotations for a problem that are not included in
the provided list of kb_items as removed.
"""
kb_item_ids = {kb_item.get("id") for kb_item in kb_items if kb_item.get("id") is not None}

annotations_to_delete = KnowledgeBaseAnnotation.objects.filter(
problem=problem,
removed_at__isnull=True
).exclude(id__in=kb_item_ids)

current_time = timezone.now()

for annotation in annotations_to_delete:
annotation.removed_at = current_time
annotation.removed_by = session.user
annotation.save()

def _handle_kb_annotations(
self, problem: Problem, kb_items: list[dict]
) -> None:
"""
Creates or update KnowledgeBase and Label annotations for a problem.
Creates, updates and deletes KnowledgeBase annotations for a problem.
Creates an annotation session if it does not exist.

TODO: handle deletions!
"""
request = self.context.get("request", None)
if not request or not request.user.is_authenticated:
if not request or not request.user.is_authenticated or not request.user.can_edit_kb:
return

session = AnnotationSession.objects.create(user=request.user)

self._mark_kb_not_in_input_as_removed(problem, kb_items, session)

for kb_item in kb_items:
self._create_update_kb_annotation(kb_item, problem, session)

Expand All @@ -185,8 +204,7 @@ def update(self, instance: Problem, validated_data: dict) -> Problem:

# KB annotations can be made for all problems.
kb_items = validated_data.get("kbItems", [])
if kb_items:
self._create_update_kb_annotations(instance, kb_items)
self._handle_kb_annotations(instance, kb_items)

# Other fields can only be updated for user-created problems.
if instance.dataset != Problem.Dataset.USER:
Expand Down
Loading