11from rest_framework import serializers
22from problem .services import FracasData , SNLIData , SickData
3- from problem .models import Problem , KnowledgeBase
3+ from problem .models import Problem , KnowledgeBase , Sentence
44
55
66class KnowledgeBaseSerializer (serializers .ModelSerializer ):
@@ -21,6 +21,23 @@ def validate_id(self, value):
2121 )
2222 return value
2323
24+ def create_for_problem (
25+ self , validated_data : dict , problem : Problem
26+ ) -> KnowledgeBase :
27+ """Create a new KnowledgeBase item for a problem."""
28+ return KnowledgeBase .objects .create (
29+ ** validated_data ,
30+ problem = problem ,
31+ )
32+
33+ def update (self , instance : KnowledgeBase , validated_data : dict ) -> KnowledgeBase :
34+ """Update an existing KnowledgeBase item."""
35+ instance .entity1 = validated_data ["entity1" ]
36+ instance .relationship = validated_data ["relationship" ]
37+ instance .entity2 = validated_data ["entity2" ]
38+ instance .save ()
39+ return instance
40+
2441
2542class ProblemSerializer (serializers .ModelSerializer ):
2643 """
@@ -71,6 +88,85 @@ def get_kbItems(self, problem):
7188 kb_items = problem .knowledge_bases .all ()
7289 return KnowledgeBaseSerializer (kb_items , many = True ).data
7390
91+ def create (self , validated_data : dict ) -> Problem :
92+ """
93+ Create a new Problem instance from validated input data.
94+ Handles creation of related Sentence and KnowledgeBase objects.
95+ """
96+ premise_sentences = [
97+ Sentence .objects .get_or_create (text = premise )[0 ]
98+ for premise in validated_data ["premises" ]
99+ ]
100+
101+ hypothesis_sentence = Sentence .objects .get_or_create (
102+ text = validated_data ["hypothesis" ]
103+ )[0 ]
104+
105+ problem = Problem .objects .create (
106+ hypothesis = hypothesis_sentence ,
107+ dataset = Problem .Dataset .USER ,
108+ # TODO: Determine entailment label based on LangPro parser output.
109+ entailment_label = Problem .EntailmentLabel .UNKNOWN ,
110+ extra_data = {},
111+ )
112+
113+ problem .premises .set (premise_sentences )
114+
115+ kb_items = validated_data .get ("kbItems" , [])
116+ if kb_items :
117+ self ._update_or_create_kb_items (problem , kb_items )
118+
119+ return problem
120+
121+ def update (self , instance : Problem , validated_data : dict ) -> Problem :
122+ """
123+ Update an existing Problem instance from validated input data.
124+ Handles updating of related Sentence and KnowledgeBase objects.
125+ """
126+ if instance .dataset != Problem .Dataset .USER :
127+ raise serializers .ValidationError (
128+ "Cannot update a problem that is not a user-created problem."
129+ )
130+
131+ instance .hypothesis = Sentence .objects .get_or_create (
132+ text = validated_data ["hypothesis" ],
133+ )[0 ]
134+ instance .save ()
135+
136+ premise_sentences = [
137+ Sentence .objects .get_or_create (text = premise )[0 ]
138+ for premise in validated_data ["premises" ]
139+ ]
140+ instance .premises .set (premise_sentences )
141+
142+ self ._update_or_create_kb_items (instance , validated_data .get ("kbItems" , []))
143+
144+ return instance
145+
146+ def _update_or_create_kb_items (
147+ self , problem : Problem , kb_items : list [dict ]
148+ ) -> None :
149+ """Create or update KnowledgeBase items for a problem."""
150+ kb_ids : list [int ] = []
151+ kb_serializer = KnowledgeBaseSerializer ()
152+
153+ for item in kb_items :
154+ kb_id = item .get ("id" , None )
155+
156+ if kb_id is None :
157+ kb = kb_serializer .create_for_problem (item , problem = problem ) # type: ignore
158+ else :
159+ kb_instance = KnowledgeBase .objects .get (id = kb_id , problem_id = problem .pk )
160+ kb = kb_serializer .update (kb_instance , item )
161+
162+ kb_ids .append (kb .pk )
163+
164+ # Delete existing knowledge bases associated to this problem that are
165+ # not included in the input.
166+ KnowledgeBase .objects .filter (problem_id = problem .pk ).exclude (
167+ id__in = kb_ids
168+ ).delete ()
169+
74170
75171class ProblemInputSerializer (serializers .Serializer ):
76172 """
@@ -91,16 +187,3 @@ class ProblemInputSerializer(serializers.Serializer):
91187 many = True , allow_empty = True , help_text = "List of knowledge base items"
92188 )
93189
94- def validate_id (self , value ):
95- """
96- Validate that the problem ID exists and belongs to a user problem.
97- Users are not allowed to modify non-user problems.
98- """
99- if value is not None :
100- if not Problem .objects .filter (
101- id = value , dataset = Problem .Dataset .USER
102- ).exists ():
103- raise serializers .ValidationError (
104- f"Problem with ID { value } does not exist or is not a user problem."
105- )
106- return value
0 commit comments