88from bertopic .vectorizers import ClassTfidfTransformer
99from sklearn .feature_extraction .text import CountVectorizer
1010
11+ from persistence .AssesmentInfoDbConnector import AssessmentInfoDbConnector
1112from persistence .MediaRecordInfoDbConnector import MediaRecordInfoDbConnector
1213from persistence .SegmentDbConnector import SegmentDbConnector
13- from persistence .entities import DocumentSegmentEntity , VideoSegmentEntity
14+ from persistence .entities import DocumentSegmentEntity , VideoSegmentEntity , AssessmentSegmentEntity
1415
1516_logger = logging .getLogger (__name__ )
1617
1718
1819class TopicModel :
1920 model = BERTopic ()
2021
21- def __init__ (self , record_segments : list [DocumentSegmentEntity | VideoSegmentEntity ], media_records ):
22+ def __init__ (self , record_segments : list [VideoSegmentEntity | DocumentSegmentEntity | AssessmentSegmentEntity ] ):
2223 self .record_segments = []
23- self .media_records = {}
2424 self .docs = []
2525 self .record_segments = record_segments
26- self .media_records = media_records
2726 self .docs = []
2827
2928 def create_topic_model (self ):
@@ -36,6 +35,9 @@ def create_topic_model(self):
3635 if isinstance (entity , VideoSegmentEntity ):
3736 self .docs .append (entity .transcript )
3837 embeddings .append (entity .embedding )
38+ if isinstance (entity , AssessmentSegmentEntity ):
39+ self .docs .append (entity .textual_representation )
40+ embeddings .append (entity .embedding )
3941
4042 if len (self .docs ) < 11 :
4143 _logger .info ("More documents needed to create topic model." )
@@ -46,7 +48,6 @@ def create_topic_model(self):
4648 ctfidf_model = ClassTfidfTransformer (reduce_frequent_words = True , bm25_weighting = True )
4749 mmr = MaximalMarginalRelevance (diversity = 0.3 )
4850
49-
5051 representation_models = mmr
5152
5253 self .model = BERTopic (
@@ -58,41 +59,74 @@ def create_topic_model(self):
5859
5960 self .model .fit_transform (self .docs , embeddings )
6061
61- def add_tags_to_media_records (self , record_segments , media_records ):
62+ def add_tags_to_media_records (self , segments ):
6263 if len (self .docs ) < 11 :
6364 _logger .info ("Topic model wasn't created. More documents needed." )
6465 return
6566 document_info = self .model .get_document_info (self .docs )
6667 mediarecords_with_tags = {}
6768
6869 i = 0
69- for record in media_records :
70- mediarecords_with_tags .update ({record .get (id ): set ()})
70+ while i < len (segments ):
71+ if isinstance (segments [i ], AssessmentSegmentEntity ):
72+ i += 1
73+ continue
7174
72- while i < len (record_segments ):
73- mediarecord_id = record_segments [i ].media_record_id
75+ mediarecord_id = segments [i ].media_record_id
7476
75- if isinstance (record_segments [i ], DocumentSegmentEntity ):
76- if record_segments [i ].text != document_info ['Document' ].iat [i ]:
77+ if isinstance (segments [i ], DocumentSegmentEntity ):
78+ if segments [i ].text != document_info ['Document' ].iat [i ]:
79+ i += 1
7780 continue
78-
79- elif isinstance ( record_segments [i ], VideoSegmentEntity ) :
80- if record_segments [ i ]. transcript != document_info [ 'Document' ]. iat [ i ]:
81+ elif isinstance ( segments [ i ], VideoSegmentEntity ):
82+ if segments [i ]. transcript != document_info [ 'Document' ]. iat [ i ] :
83+ i += 1
8184 continue
8285
8386 tags = set ()
8487 if mediarecords_with_tags .get (mediarecord_id ) is not None :
85- tags = mediarecords_with_tags .get (mediarecord_id )
88+ tags = mediarecords_with_tags .get (mediarecord_id )
8689 tags .update (set (document_info ['Representation' ].iat [i ]))
8790
8891 mediarecords_with_tags .update ({mediarecord_id : tags })
8992 i += 1
9093
9194 return mediarecords_with_tags
9295
93- if __name__ == "__main__" :
96+ def add_tags_to_assessments (self , segments ):
97+ if len (self .docs ) < 11 :
98+ _logger .info ("Topic model wasn't created. More documents needed." )
99+ return
100+ document_info = self .model .get_document_info (self .docs )
101+ assesments_with_tags = {}
102+
103+ i = 0
104+
105+ while i < len (segments ):
106+ if isinstance (segments [i ], DocumentSegmentEntity ) or isinstance (segments [i ], VideoSegmentEntity ):
107+ i += 1
108+ continue
94109
95- star = time .time ()
110+ assessment_id = segments [i ].assessment_id
111+
112+ if isinstance (segments [i ], AssessmentSegmentEntity ):
113+ if segments [i ].textual_representation != document_info ['Document' ].iat [i ]:
114+ i += 1
115+ continue
116+
117+ tags = set ()
118+ if assesments_with_tags .get (assessment_id ) is not None :
119+ tags = assesments_with_tags .get (assessment_id )
120+ tags .update (set (document_info ['Representation' ].iat [i ]))
121+
122+ assesments_with_tags .update ({assessment_id : tags })
123+ i += 1
124+
125+ return assesments_with_tags
126+
127+
128+ if __name__ == "__main__" :
129+ start = time .time ()
96130
97131 print ("Connecting to DB" )
98132 database_connection = psycopg .connect (
@@ -103,24 +137,29 @@ def add_tags_to_media_records(self, record_segments, media_records):
103137
104138 segment_database = SegmentDbConnector (database_connection )
105139 media_record_info_database = MediaRecordInfoDbConnector (database_connection )
140+ assessment_database = AssessmentInfoDbConnector (database_connection )
106141
107142 print ("Loading segments and media records" )
108143
109- record_segments = segment_database .get_all_media_record_segments ()
144+ segments = segment_database .get_all_entity_segments ()
110145 media_records = media_record_info_database .get_all_media_records ()
146+ assessments = assessment_database .get_all_assessments ()
111147
112- topic_model = TopicModel (record_segments , media_records )
148+ topic_model = TopicModel (segments )
113149
114150 print ("Running Topic model" )
115151 topic_model .create_topic_model ()
152+ print ("Topic model created" )
116153
117- media_records_with_tags = topic_model .add_tags_to_media_records (record_segments , media_records )
154+ print ("Adding tags" )
155+ media_records_with_tags = topic_model .add_tags_to_media_records (segments )
156+ assessments_with_tags = topic_model .add_tags_to_assessments (segments )
118157 if media_records_with_tags is not None :
119158 for mrid , tags in media_records_with_tags .items ():
120159 media_record_info_database .update_media_record_tags (mrid , list (tags ))
121- end = time .time ()
122- print ("Done in " + str (end - star ) + " seconds" )
123-
124-
125-
126160
161+ if assessments_with_tags is not None :
162+ for aid , tags in assessments_with_tags .items ():
163+ assessment_database .update_assessment_tags (aid , list (tags ))
164+ end = time .time ()
165+ print ("Done in " + str (end - start ) + " seconds" )
0 commit comments