3434import logging
3535
3636import openai
37- from neo4j import Driver
38-
3937from embed_module import EMBEDDING_MODEL , VECTOR_INDEX_NAME
4038from graph_schema import schema_to_prompt
39+ from neo4j import Driver
4140
4241logger = logging .getLogger (__name__ )
4342
127126# 1. Classify query intent
128127# ---------------------------------------------------------------------------
129128
129+
130130def query_intent (user_query : str , openai_api_key : str ) -> str :
131131 """
132132 Classify the user query into one of four retrieval strategies:
@@ -175,6 +175,7 @@ def query_intent(user_query: str, openai_api_key: str) -> str:
175175# 2. Entity extraction
176176# ---------------------------------------------------------------------------
177177
178+
178179def entity_extraction (
179180 user_query : str ,
180181 openai_api_key : str ,
@@ -248,6 +249,7 @@ def entity_extraction(
248249# 3. Entity resolution — look up canonical forms in Neo4j
249250# ---------------------------------------------------------------------------
250251
252+
251253def _resolve_persons (names : list [str ], session ) -> dict [str , str ]:
252254 """Fuzzy-match person names against the graph, return {input: canonical}."""
253255 resolved = {}
@@ -404,24 +406,16 @@ def entity_resolution(
404406
405407 with neo4j_driver .session () as session :
406408 if entity_extraction .get ("persons" ):
407- resolved ["persons" ] = _resolve_persons (
408- entity_extraction ["persons" ], session
409- )
409+ resolved ["persons" ] = _resolve_persons (entity_extraction ["persons" ], session )
410410
411411 if entity_extraction .get ("movies" ):
412- resolved ["movies" ] = _resolve_movies (
413- entity_extraction ["movies" ], session
414- )
412+ resolved ["movies" ] = _resolve_movies (entity_extraction ["movies" ], session )
415413
416414 if entity_extraction .get ("genres" ):
417- resolved ["genres" ] = _resolve_genres (
418- entity_extraction ["genres" ], session
419- )
415+ resolved ["genres" ] = _resolve_genres (entity_extraction ["genres" ], session )
420416
421417 if entity_extraction .get ("companies" ):
422- resolved ["companies" ] = _resolve_companies (
423- entity_extraction ["companies" ], session
424- )
418+ resolved ["companies" ] = _resolve_companies (entity_extraction ["companies" ], session )
425419
426420 # Pass through numeric/date filters unchanged
427421 for key in ("year_after" , "year_before" , "rating_above" , "rating_below" ):
@@ -436,6 +430,7 @@ def entity_resolution(
436430# 4. Vector path
437431# ---------------------------------------------------------------------------
438432
433+
439434def query_embedding (
440435 user_query : str ,
441436 openai_api_key : str ,
@@ -499,6 +494,7 @@ def vector_results(
499494# 5. Cypher generation using resolved entities
500495# ---------------------------------------------------------------------------
501496
497+
502498def _build_entity_context (resolved : dict ) -> str :
503499 """
504500 Build a plain-English summary of resolved entities for the Cypher
@@ -511,32 +507,32 @@ def _build_entity_context(resolved: dict) -> str:
511507
512508 persons = resolved .get ("persons" , {})
513509 if persons :
514- for original , canonical in persons .items ():
510+ for _original , canonical in persons .items ():
515511 lines .append (f' Person: "{ canonical } "' )
516512
517513 movies = resolved .get ("movies" , {})
518514 if movies :
519- for original , canonical in movies .items ():
515+ for _original , canonical in movies .items ():
520516 lines .append (f' Movie title: "{ canonical } "' )
521517
522518 genres = resolved .get ("genres" , {})
523519 if genres :
524- for original , canonical in genres .items ():
520+ for _original , canonical in genres .items ():
525521 lines .append (f' Genre: "{ canonical } "' )
526522
527523 companies = resolved .get ("companies" , {})
528524 if companies :
529- for original , canonical in companies .items ():
525+ for _original , canonical in companies .items ():
530526 lines .append (f' ProductionCompany: "{ canonical } "' )
531527
532528 if "year_after" in resolved :
533- lines .append (f' Date filter: m.release_date > \ '{ resolved [" year_after" ]} -01-01\' ' )
529+ lines .append (f" Date filter: m.release_date > '{ resolved [' year_after' ]} -01-01'" )
534530 if "year_before" in resolved :
535- lines .append (f' Date filter: m.release_date < \ '{ resolved [" year_before" ]} -12-31\' ' )
531+ lines .append (f" Date filter: m.release_date < '{ resolved [' year_before' ]} -12-31'" )
536532 if "rating_above" in resolved :
537- lines .append (f' Rating filter: m.vote_average > { resolved [" rating_above" ] } ' )
533+ lines .append (f" Rating filter: m.vote_average > { resolved [' rating_above' ] } " )
538534 if "rating_below" in resolved :
539- lines .append (f' Rating filter: m.vote_average < { resolved [" rating_below" ] } ' )
535+ lines .append (f" Rating filter: m.vote_average < { resolved [' rating_below' ] } " )
540536
541537 return "\n " .join (lines )
542538
@@ -656,6 +652,7 @@ def cypher_results(
656652# 6. Enrich vector results with graph traversal
657653# ---------------------------------------------------------------------------
658654
655+
659656def _enrich_movie (movie_id : int , driver : Driver ) -> dict | None :
660657 """Pull directors, cast, genres, companies for a movie node."""
661658 cypher = """
@@ -684,6 +681,7 @@ def _enrich_movie(movie_id: int, driver: Driver) -> dict | None:
684681# 7. Merge results
685682# ---------------------------------------------------------------------------
686683
684+
687685def merged_results (
688686 vector_results : list [dict ],
689687 cypher_results : list [dict ],
@@ -722,6 +720,7 @@ def merged_results(
722720# 8. Format context
723721# ---------------------------------------------------------------------------
724722
723+
725724def retrieved_context (merged_results : list [dict ], query_intent : str ) -> str :
726725 """
727726 Format merged results into plain-text context for the generation DAG.
@@ -734,26 +733,26 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
734733 return "No relevant information found in the knowledge graph for this query."
735734
736735 FIELD_LABELS = {
737- "movie" : "Movie" ,
738- "director" : "Director" ,
739- "actor" : "Actor" ,
740- "genre" : "Genre" ,
741- "company" : "Production company" ,
742- "film_count" : "Films" ,
743- "movie_count" : "Count" ,
736+ "movie" : "Movie" ,
737+ "director" : "Director" ,
738+ "actor" : "Actor" ,
739+ "genre" : "Genre" ,
740+ "company" : "Production company" ,
741+ "film_count" : "Films" ,
742+ "movie_count" : "Count" ,
744743 "action_movie_count" : "Action movies" ,
745- "avg_rating" : "Avg rating" ,
746- "average_rating" : "Avg rating" ,
747- "vote_average" : "Rating" ,
748- "release_date" : "Released" ,
744+ "avg_rating" : "Avg rating" ,
745+ "average_rating" : "Avg rating" ,
746+ "vote_average" : "Rating" ,
747+ "release_date" : "Released" ,
749748 }
750749
751750 lines = []
752751 i = 0
753752
754753 for row in merged_results :
755754 i += 1
756- source = row .get ("_source" , "unknown" )
755+ _source = row .get ("_source" , "unknown" )
757756
758757 if "directors" in row :
759758 # Enriched movie record from vector path
@@ -786,4 +785,4 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
786785
787786 context = "\n " .join (lines )
788787 logger .info ("Formatted context: %d chars from %d results" , len (context ), len (merged_results ))
789- return context
788+ return context
0 commit comments