Skip to content

Commit 75146f9

Browse files
Use serializer create/update methods
1 parent 1bfb7f8 commit 75146f9

3 files changed

Lines changed: 116 additions & 101 deletions

File tree

backend/problem/serializers.py

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from rest_framework import serializers
22
from problem.services import FracasData, SNLIData, SickData
3-
from problem.models import Problem, KnowledgeBase
3+
from problem.models import Problem, KnowledgeBase, Sentence
44

55

66
class KnowledgeBaseSerializer(serializers.ModelSerializer):
@@ -21,6 +21,23 @@ def validate_id(self, value):
2121
)
2222
return value
2323

24+
def create_for_problem(
25+
self, validated_data: dict, problem: Problem
26+
) -> KnowledgeBase:
27+
"""Create a new KnowledgeBase item for a problem."""
28+
return KnowledgeBase.objects.create(
29+
**validated_data,
30+
problem=problem,
31+
)
32+
33+
def update(self, instance: KnowledgeBase, validated_data: dict) -> KnowledgeBase:
34+
"""Update an existing KnowledgeBase item."""
35+
instance.entity1 = validated_data["entity1"]
36+
instance.relationship = validated_data["relationship"]
37+
instance.entity2 = validated_data["entity2"]
38+
instance.save()
39+
return instance
40+
2441

2542
class ProblemSerializer(serializers.ModelSerializer):
2643
"""
@@ -71,6 +88,85 @@ def get_kbItems(self, problem):
7188
kb_items = problem.knowledge_bases.all()
7289
return KnowledgeBaseSerializer(kb_items, many=True).data
7390

91+
def create(self, validated_data: dict) -> Problem:
92+
"""
93+
Create a new Problem instance from validated input data.
94+
Handles creation of related Sentence and KnowledgeBase objects.
95+
"""
96+
premise_sentences = [
97+
Sentence.objects.get_or_create(text=premise)[0]
98+
for premise in validated_data["premises"]
99+
]
100+
101+
hypothesis_sentence = Sentence.objects.get_or_create(
102+
text=validated_data["hypothesis"]
103+
)[0]
104+
105+
problem = Problem.objects.create(
106+
hypothesis=hypothesis_sentence,
107+
dataset=Problem.Dataset.USER,
108+
# TODO: Determine entailment label based on LangPro parser output.
109+
entailment_label=Problem.EntailmentLabel.UNKNOWN,
110+
extra_data={},
111+
)
112+
113+
problem.premises.set(premise_sentences)
114+
115+
kb_items = validated_data.get("kbItems", [])
116+
if kb_items:
117+
self._update_or_create_kb_items(problem, kb_items)
118+
119+
return problem
120+
121+
def update(self, instance: Problem, validated_data: dict) -> Problem:
122+
"""
123+
Update an existing Problem instance from validated input data.
124+
Handles updating of related Sentence and KnowledgeBase objects.
125+
"""
126+
if instance.dataset != Problem.Dataset.USER:
127+
raise serializers.ValidationError(
128+
"Cannot update a problem that is not a user-created problem."
129+
)
130+
131+
instance.hypothesis = Sentence.objects.get_or_create(
132+
text=validated_data["hypothesis"],
133+
)[0]
134+
instance.save()
135+
136+
premise_sentences = [
137+
Sentence.objects.get_or_create(text=premise)[0]
138+
for premise in validated_data["premises"]
139+
]
140+
instance.premises.set(premise_sentences)
141+
142+
self._update_or_create_kb_items(instance, validated_data.get("kbItems", []))
143+
144+
return instance
145+
146+
def _update_or_create_kb_items(
147+
self, problem: Problem, kb_items: list[dict]
148+
) -> None:
149+
"""Create or update KnowledgeBase items for a problem."""
150+
kb_ids: list[int] = []
151+
kb_serializer = KnowledgeBaseSerializer()
152+
153+
for item in kb_items:
154+
kb_id = item.get("id", None)
155+
156+
if kb_id is None:
157+
kb = kb_serializer.create_for_problem(item, problem=problem) # type: ignore
158+
else:
159+
kb_instance = KnowledgeBase.objects.get(id=kb_id, problem_id=problem.pk)
160+
kb = kb_serializer.update(kb_instance, item)
161+
162+
kb_ids.append(kb.pk)
163+
164+
# Delete existing knowledge bases associated to this problem that are
165+
# not included in the input.
166+
KnowledgeBase.objects.filter(problem_id=problem.pk).exclude(
167+
id__in=kb_ids
168+
).delete()
169+
74170

75171
class ProblemInputSerializer(serializers.Serializer):
76172
"""
@@ -91,16 +187,3 @@ class ProblemInputSerializer(serializers.Serializer):
91187
many=True, allow_empty=True, help_text="List of knowledge base items"
92188
)
93189

94-
def validate_id(self, value):
95-
"""
96-
Validate that the problem ID exists and belongs to a user problem.
97-
Users are not allowed to modify non-user problems.
98-
"""
99-
if value is not None:
100-
if not Problem.objects.filter(
101-
id=value, dataset=Problem.Dataset.USER
102-
).exists():
103-
raise serializers.ValidationError(
104-
f"Problem with ID {value} does not exist or is not a user problem."
105-
)
106-
return value

backend/problem/views/problem.py

Lines changed: 16 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from rest_framework.viewsets import ModelViewSet
55
from rest_framework.status import HTTP_201_CREATED, HTTP_200_OK
66

7+
from django.shortcuts import get_object_or_404
8+
79
from problem.problem_details import (
810
get_filters,
911
get_related_problem_ids,
1012
)
11-
from problem.models import KnowledgeBase, Problem, Sentence
13+
from problem.models import Problem
1214
from problem.serializers import ProblemInputSerializer, ProblemSerializer
1315

1416

@@ -95,101 +97,30 @@ def partial_update(self, request: Request, pk: int) -> Response:
9597
"""
9698
Updates an existing user-created Problem with the provided input data.
9799
"""
98-
return self._handle_update_create_problem(request, pk)
100+
return self._handle_update_create_problem(request, problem_id=pk)
99101

100102
def _handle_update_create_problem(
101103
self, request: Request, problem_id: int | None
102104
) -> Response:
103105
input_data = request.data
104106

105-
serializer = ProblemInputSerializer(data=input_data)
106-
serializer.is_valid(raise_exception=True)
107-
validated_input: dict = serializer.validated_data # type: ignore
108-
validated_input["id"] = problem_id
107+
input_serializer = ProblemInputSerializer(data=input_data)
108+
input_serializer.is_valid(raise_exception=True)
109+
validated_input: dict = input_serializer.validated_data # type: ignore
110+
111+
problem_serializer = ProblemSerializer()
109112

110113
if problem_id is None:
111-
problem = create_problem_from_input(validated_input)
114+
problem = problem_serializer.create(validated_input)
112115
status = HTTP_201_CREATED
113116
else:
114-
problem = update_problem_from_input(validated_input)
117+
problem_instance = get_object_or_404(
118+
Problem, id=problem_id, dataset=Problem.Dataset.USER
119+
)
120+
problem: Problem = problem_serializer.update(
121+
problem_instance, validated_input
122+
)
115123
status = HTTP_200_OK
116124

117125
return Response({"id": problem.pk}, status=status)
118126

119-
120-
def create_problem_from_input(parse_input: dict) -> Problem:
121-
"""
122-
Save a new Problem instance from the given parse input data.
123-
"""
124-
125-
premise_sentences = [
126-
Sentence.objects.get_or_create(text=premise)[0]
127-
for premise in parse_input["premises"]
128-
]
129-
130-
hypothesis_sentence = Sentence.objects.get_or_create(
131-
text=parse_input["hypothesis"]
132-
)[0]
133-
134-
problem = Problem.objects.create(
135-
hypothesis=hypothesis_sentence,
136-
dataset=Problem.Dataset.USER,
137-
# TODO: Determine entailment label based on LangPro parser output.
138-
entailment_label=Problem.EntailmentLabel.UNKNOWN,
139-
extra_data={},
140-
)
141-
142-
problem.premises.set(premise_sentences)
143-
144-
update_or_create_kb_items(problem=problem, kb_items=parse_input["kbItems"])
145-
146-
return problem
147-
148-
149-
def update_or_create_kb_items(problem: Problem, kb_items: list[dict]) -> None:
150-
kb_ids: list[str] = []
151-
for item in kb_items:
152-
id = getattr(item, "id", None)
153-
entity1 = item["entity1"]
154-
relationship = item["relationship"]
155-
entity2 = item["entity2"]
156-
157-
if id is None:
158-
kb = KnowledgeBase.objects.create(
159-
entity1=entity1,
160-
relationship=relationship,
161-
entity2=entity2,
162-
problem=problem,
163-
)
164-
kb_ids.append(kb.pk)
165-
else:
166-
kb = KnowledgeBase.objects.get(id=id, problem_id=problem.pk)
167-
kb.entity1 = entity1
168-
kb.relationship = relationship
169-
kb.entity2 = entity2
170-
kb.save()
171-
kb_ids.append(kb.pk)
172-
173-
# Delete existing knowledge bases associated to this problem that are
174-
# not included in the input.
175-
KnowledgeBase.objects.filter(problem_id=problem.pk).exclude(id__in=kb_ids).delete()
176-
177-
178-
def update_problem_from_input(parse_input: dict) -> Problem:
179-
problem = Problem.objects.get(id=parse_input["id"], dataset=Problem.Dataset.USER)
180-
181-
problem.hypothesis = Sentence.objects.get_or_create(
182-
text=parse_input["hypothesis"],
183-
)[0]
184-
problem.save()
185-
186-
premises: list[Sentence] = []
187-
for input_premise in parse_input["premises"]:
188-
premise = Sentence.objects.get_or_create(text=input_premise)[0]
189-
premises.append(premise)
190-
191-
problem.premises.set(premises)
192-
193-
update_or_create_kb_items(problem, parse_input["kbItems"])
194-
195-
return problem

frontend/src/app/services/problem.service.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ export class ProblemService {
6363

6464
public saveProblem$ = this.submit$.pipe(
6565
exhaustMap((problem) => {
66-
const url = `/api/problem/${problem.id ?? ""}`;
67-
return this.http.post<SaveProblemResponse>(url, problem).pipe(
66+
const action = problem.id ? this.http.patch<SaveProblemResponse>(`/api/problem/${problem.id}/`, problem) :
67+
this.http.post<SaveProblemResponse>(`/api/problem/`, problem);
68+
return action.pipe(
6869
catchError((error) => {
6970
console.error('Error saving problem:', error);
7071
return of({ id: null, error: 'Failed to save problem' });

0 commit comments

Comments
 (0)