Skip to content

Commit b2245c2

Browse files
committed
2 parents f5fd43f + 58ae94f commit b2245c2

11 files changed

Lines changed: 273 additions & 34 deletions

controller/dapr_controller.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from enum import Enum, auto
2+
13
import dapr
24
from dapr.ext.fastapi.app import DaprApp
35
from fastapi import FastAPI
46
import uuid
57

68
from dto import TaskInformationDto
9+
from controller.events import ContentChangeEvent, CrudOperation
710
from service.DocProcAiService import DocProcAiService
811

912

@@ -34,4 +37,13 @@ def assessment_content_mutated_handler(data: dict):
3437
assessment_id = uuid.UUID(data["data"]["assessmentId"])
3538
task_information: list[TaskInformationDto] = data["data"]["taskInformationList"]
3639

37-
ai_service.enqueue_generate_assessment_segments(assessment_id, task_information)
40+
ai_service.enqueue_generate_assessment_segments(assessment_id, task_information)
41+
42+
@dapr_app.subscribe(pubsub="meitrex", topic="content-changed")
def assessment_content_deleted_handler(data: dict):
    """Handle a ``content-changed`` event from the ``meitrex`` pubsub.

    The event payload is read from ``data["data"]`` and must contain the keys
    ``contentIds`` and ``operation``. Only ``DELETE`` events trigger work: the
    event is forwarded to ``ai_service.delete_entries_of_assessments``. Every
    other operation is ignored.

    :param data: envelope whose ``"data"`` entry holds the event payload.
    :raises KeyError: if ``data``, ``contentIds`` or ``operation`` is missing.
    """
    payload = data["data"]
    raw_operation = payload["operation"]

    # The payload names the operation as a string (the original code compared
    # it against "DELETE"). Map it onto the enum so the event really carries a
    # ``CrudOperation`` as its field annotation declares. Unknown names are
    # kept untouched; they can never match DELETE below.
    try:
        operation = CrudOperation[raw_operation]
    except (KeyError, TypeError):
        operation = raw_operation

    # NOTE(review): contentIds are forwarded exactly as received (presumably
    # strings) - confirm whether the service expects uuid.UUID instances.
    content_change_event = ContentChangeEvent(payload["contentIds"], operation)

    if content_change_event.crudOperation is CrudOperation.DELETE:
        ai_service.delete_entries_of_assessments(content_change_event)
48+
49+

controller/events.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import uuid
2+
from dataclasses import dataclass
3+
from enum import Enum, auto
4+
5+
6+
class CrudOperation(Enum):
    """Kind of change that a ``ContentChangeEvent`` reports."""

    # Explicit values, identical to what ``auto()`` assigns (1, 2, 3).
    CREATE = 1
    UPDATE = 2
    DELETE = 3
10+
11+
@dataclass
class ContentChangeEvent:
    """Event describing a change applied to a set of contents."""

    # IDs of the contents affected by the change. Annotated with the UUID
    # class; the previous ``list[uuid]`` parameterised the list with the
    # ``uuid`` module itself, which is not a type.
    contentIds: list[uuid.UUID]
    # Operation that was applied to the contents listed above.
    crudOperation: CrudOperation

controller/graphql_controller.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def get_media_record_summary(parent, info, mediaRecordId: UUID) -> list[str]:
8585
def get_media_record_suggested_tags(parent, info, mediaRecordId: UUID) -> list[str]:
8686
return ai_service.get_media_record_tags(mediaRecordId)
8787

88+
@query.field("_internal_noauth_getAssessmentSuggestedTags")
def get_assessment_suggested_tags(parent, info, assessmentId: UUID) -> list[str]:
    """Resolve the ``_internal_noauth_getAssessmentSuggestedTags`` query field.

    Named after the assessment it serves so that it no longer redefines (and
    shadows at module level) the media-record resolver
    ``get_media_record_suggested_tags`` declared directly above it. The
    GraphQL binding is established by the decorator's field name and is
    unaffected by the Python function name.

    :param parent: parent value passed by the GraphQL executor (unused).
    :param info: resolver info passed by the GraphQL executor (unused).
    :param assessmentId: ID of the assessment whose tags are requested.
    :return: the result of ``ai_service.get_assessment_tags(assessmentId)``.
    """
    return ai_service.get_assessment_tags(assessmentId)
91+
8892
@query.field("_internal_noauth_getMediaRecordsAiProcessingProgress")
8993
def get_media_records_ai_processing_state(parent, info, mediaRecordIds: list[UUID])\
9094
-> list[AiEntityProcessingProgressDto]:

fileextractlib/TopicModel.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,21 @@
88
from bertopic.vectorizers import ClassTfidfTransformer
99
from sklearn.feature_extraction.text import CountVectorizer
1010

11+
from persistence.AssesmentInfoDbConnector import AssessmentInfoDbConnector
1112
from persistence.MediaRecordInfoDbConnector import MediaRecordInfoDbConnector
1213
from persistence.SegmentDbConnector import SegmentDbConnector
13-
from persistence.entities import DocumentSegmentEntity, VideoSegmentEntity
14+
from persistence.entities import DocumentSegmentEntity, VideoSegmentEntity, AssessmentSegmentEntity
1415

1516
_logger = logging.getLogger(__name__)
1617

1718

1819
class TopicModel:
1920
model = BERTopic()
2021

21-
def __init__(self, record_segments: list[DocumentSegmentEntity | VideoSegmentEntity], media_records):
22+
def __init__(self, record_segments: list[VideoSegmentEntity | DocumentSegmentEntity | AssessmentSegmentEntity]):
2223
self.record_segments = []
23-
self.media_records = {}
2424
self.docs = []
2525
self.record_segments = record_segments
26-
self.media_records = media_records
2726
self.docs = []
2827

2928
def create_topic_model(self):
@@ -36,6 +35,9 @@ def create_topic_model(self):
3635
if isinstance(entity, VideoSegmentEntity):
3736
self.docs.append(entity.transcript)
3837
embeddings.append(entity.embedding)
38+
if isinstance(entity, AssessmentSegmentEntity):
39+
self.docs.append(entity.textual_representation)
40+
embeddings.append(entity.embedding)
3941

4042
if len(self.docs) < 11:
4143
_logger.info("More documents needed to create topic model.")
@@ -46,7 +48,6 @@ def create_topic_model(self):
4648
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True, bm25_weighting=True)
4749
mmr = MaximalMarginalRelevance(diversity=0.3)
4850

49-
5051
representation_models = mmr
5152

5253
self.model = BERTopic(
@@ -58,41 +59,74 @@ def create_topic_model(self):
5859

5960
self.model.fit_transform(self.docs, embeddings)
6061

61-
def add_tags_to_media_records(self, record_segments, media_records):
62+
def add_tags_to_media_records(self, segments):
6263
if len(self.docs) < 11:
6364
_logger.info("Topic model wasn't created. More documents needed.")
6465
return
6566
document_info = self.model.get_document_info(self.docs)
6667
mediarecords_with_tags = {}
6768

6869
i = 0
69-
for record in media_records:
70-
mediarecords_with_tags.update({record.get(id): set()})
70+
while i < len(segments):
71+
if isinstance(segments[i], AssessmentSegmentEntity):
72+
i += 1
73+
continue
7174

72-
while i < len(record_segments):
73-
mediarecord_id = record_segments[i].media_record_id
75+
mediarecord_id = segments[i].media_record_id
7476

75-
if isinstance(record_segments[i], DocumentSegmentEntity):
76-
if record_segments[i].text != document_info['Document'].iat[i]:
77+
if isinstance(segments[i], DocumentSegmentEntity):
78+
if segments[i].text != document_info['Document'].iat[i]:
79+
i += 1
7780
continue
78-
79-
elif isinstance(record_segments[i], VideoSegmentEntity):
80-
if record_segments[i].transcript != document_info['Document'].iat[i]:
81+
elif isinstance(segments[i], VideoSegmentEntity):
82+
if segments[i].transcript != document_info['Document'].iat[i]:
83+
i += 1
8184
continue
8285

8386
tags = set()
8487
if mediarecords_with_tags.get(mediarecord_id) is not None:
85-
tags = mediarecords_with_tags.get(mediarecord_id)
88+
tags = mediarecords_with_tags.get(mediarecord_id)
8689
tags.update(set(document_info['Representation'].iat[i]))
8790

8891
mediarecords_with_tags.update({mediarecord_id: tags})
8992
i += 1
9093

9194
return mediarecords_with_tags
9295

93-
if __name__ == "__main__":
96+
def add_tags_to_assessments(self, segments):
    """Collect topic-model tags per assessment.

    Walks ``segments`` by position and, for every entry that is neither a
    document nor a video segment, merges the topic representation found at
    the same position of the model's document info into the tag set of the
    segment's ``assessment_id``.

    :param segments: segments to inspect; positions are matched against the
        rows of ``self.model.get_document_info(self.docs)``.
    :return: dict mapping assessment id to a set of tags, or ``None`` when
        fewer than 11 documents are available (no topic model was built).
    """
    if len(self.docs) < 11:
        _logger.info("Topic model wasn't created. More documents needed.")
        return

    document_info = self.model.get_document_info(self.docs)
    tags_by_assessment = {}

    for position, segment in enumerate(segments):
        # Media-record segments are not this method's concern.
        if isinstance(segment, (DocumentSegmentEntity, VideoSegmentEntity)):
            continue

        assessment_id = segment.assessment_id

        # Skip assessment segments whose text does not match the document
        # stored at the same position of the topic model.
        if isinstance(segment, AssessmentSegmentEntity) \
                and segment.textual_representation != document_info['Document'].iat[position]:
            continue

        collected = tags_by_assessment.setdefault(assessment_id, set())
        collected.update(set(document_info['Representation'].iat[position]))

    return tags_by_assessment
126+
127+
128+
if __name__ == "__main__":
129+
start = time.time()
96130

97131
print("Connecting to DB")
98132
database_connection = psycopg.connect(
@@ -103,24 +137,29 @@ def add_tags_to_media_records(self, record_segments, media_records):
103137

104138
segment_database = SegmentDbConnector(database_connection)
105139
media_record_info_database = MediaRecordInfoDbConnector(database_connection)
140+
assessment_database = AssessmentInfoDbConnector(database_connection)
106141

107142
print("Loading segments and media records")
108143

109-
record_segments = segment_database.get_all_media_record_segments()
144+
segments = segment_database.get_all_entity_segments()
110145
media_records = media_record_info_database.get_all_media_records()
146+
assessments = assessment_database.get_all_assessments()
111147

112-
topic_model = TopicModel(record_segments, media_records)
148+
topic_model = TopicModel(segments)
113149

114150
print("Running Topic model")
115151
topic_model.create_topic_model()
152+
print("Topic model created")
116153

117-
media_records_with_tags = topic_model.add_tags_to_media_records(record_segments, media_records)
154+
print("Adding tags")
155+
media_records_with_tags = topic_model.add_tags_to_media_records(segments)
156+
assessments_with_tags = topic_model.add_tags_to_assessments(segments)
118157
if media_records_with_tags is not None:
119158
for mrid, tags in media_records_with_tags.items():
120159
media_record_info_database.update_media_record_tags(mrid, list(tags))
121-
end = time.time()
122-
print("Done in " + str(end - star) + " seconds")
123-
124-
125-
126160

161+
if assessments_with_tags is not None:
162+
for aid, tags in assessments_with_tags.items():
163+
assessment_database.update_assessment_tags(aid, list(tags))
164+
end = time.time()
165+
print("Done in " + str(end - start) + " seconds")
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from uuid import UUID
2+
3+
import psycopg
4+
from pgvector.psycopg import register_vector
5+
6+
7+
class AssessmentInfoDbConnector:
    """Database access for the ``assessments`` table.

    Each row holds an assessment ``id`` (uuid primary key) and its ``tags``
    (text array). The table is created on construction if it does not exist.
    """

    def __init__(self, db_connection: psycopg.Connection):
        """Store the connection, register pgvector and ensure the table exists.

        :param db_connection: open psycopg connection used for all statements.
        """
        self.db_connection = db_connection

        # ensure pgvector extension is installed, we need it to store text embeddings
        self.db_connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
        register_vector(self.db_connection)

        self.db_connection.execute(
            """
            CREATE TABLE IF NOT EXISTS assessments (
                id uuid PRIMARY KEY,
                tags text[]
            );
            """)

    def upsert_assessment_info(self, id: UUID):
        """Insert the assessment with an empty tag list.

        If a row with this id already exists, its tags are overwritten with
        the empty list as well (``EXCLUDED.tags``).

        :param id: ID of the assessment.
        """
        self.db_connection.execute(
            query="""
                  INSERT INTO assessments (id, tags)
                  VALUES (%s, %s)
                  ON CONFLICT (id)
                  DO UPDATE SET
                    tags = EXCLUDED.tags
                  """,
            params=(id, [])
        )

    def get_assessment_tags_by_id(self, assesment_id) -> list[str]:
        """Return the stored tags of an assessment.

        :param assesment_id: ID of the assessment to look up.
        :return: the tags, or an empty list when the assessment is unknown or
            its ``tags`` column is NULL.
        """
        query_result = self.db_connection.execute(
            "SELECT tags FROM assessments WHERE id = %s",
            (assesment_id,)).fetchone()

        if query_result is None:
            return []

        # NOTE(review): key access assumes the connection uses a dict-like row
        # factory - confirm against where the connection is created.
        # The column is nullable, so fall back to an empty list to honour the
        # declared ``list[str]`` return type.
        return query_result["tags"] or []

    def get_all_assessments(self):
        """Return every row of the ``assessments`` table."""
        # The context manager closes the cursor once the rows are fetched.
        with self.db_connection.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM assessments"
            )
            return cursor.fetchall()

    def update_assessment_tags(self, id: UUID, tags: list[str]):
        """Replace the tags of the assessment with the given id.

        :param id: ID of the assessment to update.
        :param tags: new tag list to store.
        """
        self.db_connection.execute(
            """
            UPDATE assessments
            SET tags = (%(tags)s)
            WHERE id = (%(id)s)
            """,
            {'tags': tags, 'id': id})

    def delete_assessment_by_id(self, id: UUID):
        """Delete the assessment row with the given id, if present.

        :param id: ID of the assessment to delete.
        """
        self.db_connection.execute(
            """
            DELETE FROM assessments WHERE id = (%(id)s)
            """,
            {'id': id}
        )

persistence/IngestionStateDbConnector.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,12 @@ def get_enqueued_or_processing_ingestion_entities(self) \
101101
FROM media_record_ingestion_states
102102
WHERE state IN ('ENQUEUED', 'PROCESSING');
103103
""").fetchall()
104-
return [(x["id"], x["entity_type"], x["state"]) for x in query_results]
104+
return [(x["id"], x["entity_type"], x["state"]) for x in query_results]
105+
106+
def delete_ingestion_state(self, id: UUID) -> None:
    """Remove the ingestion-state entry with the given id.

    :param id: ID whose row in ``media_record_ingestion_states`` is deleted.
    """
    statement = """
        DELETE FROM media_record_ingestion_states WHERE id = (%(id)s);
        """
    arguments = {'id': id}
    self.db_connection.execute(statement, arguments)

persistence/MediaRecordInfoDbConnector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,11 @@ def update_media_record_tags(self, id: UUID, tags: list[str]):
8181
WHERE id = (%(id)s)
8282
""",
8383
{'tags': tags, 'id': id})
84+
85+
def delete_media_record_by_id(self, id: UUID):
    """Remove the media record with the given id from ``media_records``.

    :param id: ID of the media record to delete.
    """
    statement = """
        DELETE FROM media_records WHERE id = (%(id)s)
        """
    arguments = {'id': id}
    self.db_connection.execute(statement, arguments)

persistence/SegmentDbConnector.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,19 @@ def get_all_media_record_segments(self) -> list[EntitySegmentEntity]:
290290
"""
291291
return self.__get_record_segments_with_query(query, {})
292292

293+
def get_all_entity_segments(self) -> list[EntitySegmentEntity]:
    """Load every document, video and assessment segment.

    Each of the three segment tables is selected with a ``source`` marker
    column and the results are combined via ``NATURAL FULL JOIN``; row
    conversion is delegated to ``__get_record_segments_with_query``.

    :return: the segments produced by the shared query helper.
    """
    no_parameters = {}
    return self.__get_record_segments_with_query(
        """
        SELECT * FROM (
            (SELECT *, 'document' AS source FROM document_segments) AS t1
            NATURAL FULL JOIN
            (SELECT *, 'video' AS source FROM video_segments) AS t2
            NATURAL FULL JOIN
            (SELECT *, 'assessment' AS source FROM assessment_segments) AS t3
        );
        """,
        no_parameters)
304+
305+
293306
def get_entity_segments_by_ids(self, segment_ids: list[UUID]) -> list[EntitySegmentEntity]:
294307
query = """
295308
WITH document_results AS (

persistence/entities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ def __init__(self, id: UUID, summary: list[str], tags: set):
77
self.summary = summary
88
self.tags = tags
99

10+
class AssessmentEntity:
    """In-memory representation of an assessment and its suggested tags."""

    def __init__(self, id: UUID, tags: set):
        """Create the entity.

        :param id: ID of the assessment.
        :param tags: tags associated with the assessment.
        """
        # The id was previously accepted but silently dropped, leaving the
        # entity without any way to identify its assessment.
        self.id = id
        self.tags = tags
13+
14+
1015

1116
class DocumentSegmentEntity:
1217
def __init__(self, id: UUID, media_record_id: UUID, page_index: int, text: str, thumbnail: bytes, title: str,

schema/query.graphqls

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ type Query {
8282
"""
8383
_internal_noauth_getMediaRecordSuggestedTags(mediaRecordId: UUID!): [String!]!
8484

85+
"""
86+
Gets the suggested tags of the specified assessment. Returns a list of strings
87+
where each string is a tag.
88+
89+
⚠️ This query is only accessible internally in the system and allows the caller to fetch contents without
90+
any permissions check and should not be called without any validation of the caller's permissions. ⚠️
91+
"""
92+
_internal_noauth_getAssessmentSuggestedTags(assessmentId: UUID!): [String!]!
93+
8594
"""
8695
Gets the DocProcAI ingestion processing state of the specified media records. "UNKNOWN" is returned if the specified
8796
ID is unknown to the service (either because a media record with the given ID does not exist or because the media

0 commit comments

Comments
 (0)