Skip to content

Commit 3a6aa24

Browse files
Merge pull request #53 from CentreForDigitalHumanities/feature/user-problems
Feature/user problems
2 parents eb8bafc + 2ba0c98 commit 3a6aa24

32 files changed

Lines changed: 1482 additions & 310 deletions

backend/langpro_annotator/urls.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121

2222
from rest_framework import routers
2323

24+
from problem.views.problem import ProblemView
25+
2426
from .index import index
2527
from .proxy_frontend import proxy_frontend
2628
from .i18n import i18n
2729

2830
api_router = routers.DefaultRouter() # register viewsets with this router
31+
api_router.register(r"problem", ProblemView, basename="problem")
2932

3033

3134
if settings.PROXY_FRONTEND:

backend/problem/models.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from django.db import models
2-
from django.contrib.postgres.fields import ArrayField
32
from django.db.models import QuerySet
43

54
from problem.services import FracasData, SNLIData, SickData
@@ -58,30 +57,6 @@ def get_index(self, qs: QuerySet) -> int | None:
5857
logger.exception(f"Error getting index for problem {self.pk}: {e}")
5958
return None
6059

61-
def serialize(self) -> dict:
62-
"""
63-
Serialize the Problem instance to a dictionary.
64-
"""
65-
66-
match self.dataset:
67-
case self.Dataset.SICK:
68-
serialized_extra_data = SickData.serialize(self.extra_data)
69-
case self.Dataset.FRACAS:
70-
serialized_extra_data = FracasData.serialize(self.extra_data)
71-
case self.Dataset.SNLI:
72-
serialized_extra_data = SNLIData.serialize(self.extra_data)
73-
case _:
74-
serialized_extra_data = {}
75-
76-
return {
77-
"id": self.pk,
78-
"dataset": self.dataset,
79-
"premises": [premise.text for premise in self.premises.all()],
80-
"hypothesis": self.hypothesis.text,
81-
"entailmentLabel": self.entailment_label,
82-
"extraData": serialized_extra_data,
83-
}
84-
8560

8661
class KnowledgeBase(models.Model):
8762
class Relationship(models.TextChoices):
@@ -105,3 +80,11 @@ class Relationship(models.TextChoices):
10580
on_delete=models.CASCADE,
10681
related_name="knowledge_bases",
10782
)
83+
84+
def serialize(self) -> dict:
85+
return {
86+
"id": self.pk,
87+
"entity1": self.entity1,
88+
"entity2": self.entity2,
89+
"relationship": self.relationship,
90+
}

backend/problem/problem_details.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import Optional
21
from dataclasses import dataclass
32

43
from django.http import QueryDict
@@ -10,16 +9,16 @@
109

1110
@dataclass
1211
class RelatedProblemIds:
13-
first: Optional[str] = None
14-
previous: Optional[str] = None
15-
next: Optional[str] = None
16-
last: Optional[str] = None
17-
total: Optional[int] = None
12+
first: int | None = None
13+
previous: int | None = None
14+
next: int | None = None
15+
last: int | None = None
16+
total: int | None = None
1817

1918

2019
def get_related_problem_ids(
2120
problem_qs: QuerySet[Problem],
22-
problem_id: Optional[int],
21+
problem_id: int | None = None,
2322
) -> RelatedProblemIds:
2423
"""
2524
Retrieves the IDs of surrounding problem objects
@@ -44,10 +43,10 @@ def get_related_problem_ids(
4443
problem = None
4544

4645
return RelatedProblemIds(
47-
first=str(first_problem.pk) if first_problem else None,
48-
previous=str(previous_problem.pk) if previous_problem else None,
49-
next=str(next_problem.pk) if next_problem else None,
50-
last=str(last_problem.pk) if last_problem else None,
46+
first=first_problem.pk if first_problem else None,
47+
previous=previous_problem.pk if previous_problem else None,
48+
next=next_problem.pk if next_problem else None,
49+
last=last_problem.pk if last_problem else None,
5150
total=total,
5251
)
5352

@@ -70,8 +69,9 @@ def get_filters(query_params: QueryDict) -> Q | None:
7069
filters &= Q(dataset=dataset)
7170
if entailment_label:
7271
filters &= Q(entailment_label=entailment_label)
73-
if gold is not None:
74-
raise NotImplementedError()
72+
if gold:
73+
logger.warning(f"Filtering by gold is not implemented yet.")
74+
pass
7575
if text:
7676
filters &= Q(
7777
Q(hypothesis__text__icontains=text) | Q(premises__text__icontains=text)

backend/problem/serializers.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from rest_framework import serializers
2+
from problem.services import FracasData, SNLIData, SickData
3+
from problem.models import Problem, KnowledgeBase, Sentence
4+
5+
6+
class KnowledgeBaseSerializer(serializers.ModelSerializer):
7+
8+
class Meta:
9+
model = KnowledgeBase
10+
fields = ["id", "entity1", "entity2", "relationship"]
11+
extra_kwargs = {
12+
# Without this, the relationship field is not required during validation.
13+
"relationship": {"required": True},
14+
}
15+
16+
def validate_id(self, value):
17+
"""Validate that the KnowledgeBase ID exists if provided."""
18+
if value is not None:
19+
if not KnowledgeBase.objects.filter(id=value).exists():
20+
raise serializers.ValidationError(
21+
f"KnowledgeBase item with ID {value} does not exist."
22+
)
23+
return value
24+
25+
def create_for_problem(
26+
self, validated_data: dict, problem: Problem
27+
) -> KnowledgeBase:
28+
"""Create a new KnowledgeBase item for a problem."""
29+
return KnowledgeBase.objects.create(
30+
**validated_data,
31+
problem=problem,
32+
)
33+
34+
def update(self, instance: KnowledgeBase, validated_data: dict) -> KnowledgeBase:
35+
"""Update an existing KnowledgeBase item."""
36+
instance.entity1 = validated_data["entity1"]
37+
instance.relationship = validated_data["relationship"]
38+
instance.entity2 = validated_data["entity2"]
39+
instance.save()
40+
return instance
41+
42+
43+
class ProblemSerializer(serializers.ModelSerializer):
44+
"""
45+
Serializer for Problem model output.
46+
Handles serialization of problems with all related data including labels.
47+
"""
48+
49+
premises = serializers.SerializerMethodField()
50+
hypothesis = serializers.SerializerMethodField()
51+
entailmentLabel = serializers.CharField(source="entailment_label")
52+
extraData = serializers.SerializerMethodField()
53+
kbItems = serializers.SerializerMethodField()
54+
55+
class Meta:
56+
model = Problem
57+
fields = [
58+
"id",
59+
"dataset",
60+
"premises",
61+
"hypothesis",
62+
"entailmentLabel",
63+
"extraData",
64+
"kbItems",
65+
]
66+
67+
def get_premises(self, problem):
68+
"""Get list of premise texts."""
69+
return [premise.text for premise in problem.premises.all()]
70+
71+
def get_hypothesis(self, problem):
72+
"""Get hypothesis text."""
73+
return problem.hypothesis.text
74+
75+
def get_extraData(self, problem):
76+
"""Get dataset-specific extra data."""
77+
match problem.dataset:
78+
case Problem.Dataset.SICK:
79+
return SickData.serialize(problem.extra_data)
80+
case Problem.Dataset.FRACAS:
81+
return FracasData.serialize(problem.extra_data)
82+
case Problem.Dataset.SNLI:
83+
return SNLIData.serialize(problem.extra_data)
84+
case _:
85+
return {}
86+
87+
def get_kbItems(self, problem):
88+
"""Get knowledge base items."""
89+
kb_items = problem.knowledge_bases.all()
90+
return KnowledgeBaseSerializer(kb_items, many=True).data
91+
92+
def create(self, validated_data: dict) -> Problem:
93+
"""
94+
Create a new Problem instance from validated input data.
95+
Handles creation of related Sentence and KnowledgeBase objects.
96+
"""
97+
premise_sentences = [
98+
Sentence.objects.get_or_create(text=premise)[0]
99+
for premise in validated_data["premises"]
100+
]
101+
102+
hypothesis_sentence = Sentence.objects.get_or_create(
103+
text=validated_data["hypothesis"]
104+
)[0]
105+
106+
problem = Problem.objects.create(
107+
hypothesis=hypothesis_sentence,
108+
dataset=Problem.Dataset.USER,
109+
# TODO: Determine entailment label based on LangPro parser output.
110+
entailment_label=Problem.EntailmentLabel.UNKNOWN,
111+
extra_data={},
112+
)
113+
114+
problem.premises.set(premise_sentences)
115+
116+
kb_items = validated_data.get("kbItems", [])
117+
if kb_items:
118+
self._update_or_create_kb_items(problem, kb_items)
119+
120+
return problem
121+
122+
def update(self, instance: Problem, validated_data: dict) -> Problem:
123+
"""
124+
Update an existing Problem instance from validated input data.
125+
Handles updating of related Sentence and KnowledgeBase objects.
126+
"""
127+
if instance.dataset != Problem.Dataset.USER:
128+
raise serializers.ValidationError(
129+
"Cannot update a problem that is not a user-created problem."
130+
)
131+
132+
instance.hypothesis = Sentence.objects.get_or_create(
133+
text=validated_data["hypothesis"],
134+
)[0]
135+
instance.save()
136+
137+
premise_sentences = [
138+
Sentence.objects.get_or_create(text=premise)[0]
139+
for premise in validated_data["premises"]
140+
]
141+
instance.premises.set(premise_sentences)
142+
143+
self._update_or_create_kb_items(instance, validated_data.get("kbItems", []))
144+
145+
return instance
146+
147+
def _update_or_create_kb_items(
148+
self, problem: Problem, kb_items: list[dict]
149+
) -> None:
150+
"""Create or update KnowledgeBase items for a problem."""
151+
kb_ids: list[int] = []
152+
kb_serializer = KnowledgeBaseSerializer()
153+
154+
for item in kb_items:
155+
kb_id = item.get("id", None)
156+
157+
if kb_id is None:
158+
kb = kb_serializer.create_for_problem(item, problem=problem) # type: ignore
159+
else:
160+
kb_instance = KnowledgeBase.objects.get(id=kb_id, problem_id=problem.pk)
161+
kb = kb_serializer.update(kb_instance, item)
162+
163+
kb_ids.append(kb.pk)
164+
165+
# Delete existing knowledge bases associated to this problem that are
166+
# not included in the input.
167+
KnowledgeBase.objects.filter(problem_id=problem.pk).exclude(
168+
id__in=kb_ids
169+
).delete()
170+
171+
172+
class ProblemInputSerializer(serializers.Serializer):
173+
"""
174+
Serializer for validating problem input data.
175+
This is used for both creating and updating user-created problems.
176+
"""
177+
178+
id = serializers.IntegerField(required=False, allow_null=True)
179+
premises = serializers.ListField(
180+
child=serializers.CharField(allow_blank=False),
181+
allow_empty=False,
182+
help_text="List of premise sentence texts",
183+
)
184+
hypothesis = serializers.CharField(
185+
allow_blank=False, help_text="Hypothesis sentence text"
186+
)
187+
kbItems = KnowledgeBaseSerializer(
188+
many=True, allow_empty=True, help_text="List of knowledge base items"
189+
)
190+
191+
def validate_id(self, value):
192+
"""Validate that the Problem ID, if provided, exists and belongs to a user-created problem."""
193+
if value is not None:
194+
if not Problem.objects.filter(
195+
id=value, dataset=Problem.Dataset.USER
196+
).exists():
197+
raise serializers.ValidationError(
198+
f"Problem with ID {value} does not exist."
199+
)
200+
return value

0 commit comments

Comments
 (0)