Skip to content

Commit 31c0e82

Browse files
committed
Refactor code and docs
Previously, the PR for this example (#1532) only had a check for building docs, which bypassed the pre-commit code format check.
1 parent 20daaa2 commit 31c0e82

File tree

11 files changed

+64
-63
lines changed

11 files changed

+64
-63
lines changed

examples/LLM_Workflows/neo4j_graph_rag/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,4 @@ neo4j_graph_rag/
202202
│ └── rag_dag.png
203203
└── data/
204204
└── README.md Dataset download and conversion instructions
205-
```
205+
```

examples/LLM_Workflows/neo4j_graph_rag/data/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ with open("tmdb_5000_credits.json", "w") as f:
5454
json.dump(credits.to_dict(orient="records"), f)
5555
```
5656

57-
Run this script once from inside the `data/` folder, then proceed with `python run.py --mode ingest`.
57+
Run this script once from inside the `data/` folder, then proceed with `python run.py --mode ingest`.

examples/LLM_Workflows/neo4j_graph_rag/data/data_refine.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
# under the License.
1717

1818

19-
import pandas as pd, json
20-
19+
import json
20+
21+
import pandas as pd
22+
2123
movies = pd.read_csv("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_movies.csv")
2224
credits = pd.read_csv("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_credits.csv")
23-
25+
2426
with open("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_movies.json", "w") as f:
2527
json.dump(movies.to_dict(orient="records"), f)
26-
28+
2729
with open("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_credits.json", "w") as f:
28-
json.dump(credits.to_dict(orient="records"), f)
30+
json.dump(credits.to_dict(orient="records"), f)

examples/LLM_Workflows/neo4j_graph_rag/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ services:
3939

4040
volumes:
4141
neo4j_data:
42-
neo4j_logs:
42+
neo4j_logs:

examples/LLM_Workflows/neo4j_graph_rag/embed_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def movie_embeddings(
9696
model=EMBEDDING_MODEL,
9797
input=[item["text"] for item in batch],
9898
)
99-
for item, emb_obj in zip(batch, response.data):
99+
for item, emb_obj in zip(batch, response.data, strict=True):
100100
results.append({"id": item["id"], "embedding": emb_obj.embedding})
101101

102102
logger.info("Embedded batch %d-%d of %d", i, min(i + BATCH_SIZE, total), total)
@@ -165,4 +165,4 @@ def embedding_summary(
165165
"dimensions": EMBEDDING_DIMENSIONS,
166166
}
167167
logger.info("Embedding complete: %s", summary)
168-
return summary
168+
return summary

examples/LLM_Workflows/neo4j_graph_rag/generation_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ def answer(prompt_messages: list[dict], openai_api_key: str) -> str:
9090
)
9191
result = response.choices[0].message.content
9292
logger.info("Generated answer (%d chars)", len(result))
93-
return result
93+
return result

examples/LLM_Workflows/neo4j_graph_rag/graph_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,4 @@ def schema_to_prompt() -> str:
105105
prop_str = f" with properties: {', '.join(props)}" if props else ""
106106
lines.append(f" (:{src})-[:{rel}]->(:{dest}){prop_str}")
107107

108-
return "\n".join(lines)
108+
return "\n".join(lines)

examples/LLM_Workflows/neo4j_graph_rag/ingest_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,4 @@ def ingestion_summary(
323323
"person_edges": write_person_nodes_and_edges,
324324
}
325325
logger.info("Ingestion complete: %s", summary)
326-
return summary
326+
return summary

examples/LLM_Workflows/neo4j_graph_rag/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
sf-hamilton>=1.73.0
1918
neo4j>=5.18.0
2019
openai>=1.30.0
2120
pandas>=2.0.0
2221
python-dotenv>=1.0.0
22+
sf-hamilton>=1.73.0
2323
tqdm>=4.66.0

examples/LLM_Workflows/neo4j_graph_rag/retrieval_module.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
import logging
3535

3636
import openai
37-
from neo4j import Driver
38-
3937
from embed_module import EMBEDDING_MODEL, VECTOR_INDEX_NAME
4038
from graph_schema import schema_to_prompt
39+
from neo4j import Driver
4140

4241
logger = logging.getLogger(__name__)
4342

@@ -127,6 +126,7 @@
127126
# 1. Classify query intent
128127
# ---------------------------------------------------------------------------
129128

129+
130130
def query_intent(user_query: str, openai_api_key: str) -> str:
131131
"""
132132
Classify the user query into one of four retrieval strategies:
@@ -175,6 +175,7 @@ def query_intent(user_query: str, openai_api_key: str) -> str:
175175
# 2. Entity extraction
176176
# ---------------------------------------------------------------------------
177177

178+
178179
def entity_extraction(
179180
user_query: str,
180181
openai_api_key: str,
@@ -248,6 +249,7 @@ def entity_extraction(
248249
# 3. Entity resolution — look up canonical forms in Neo4j
249250
# ---------------------------------------------------------------------------
250251

252+
251253
def _resolve_persons(names: list[str], session) -> dict[str, str]:
252254
"""Fuzzy-match person names against the graph, return {input: canonical}."""
253255
resolved = {}
@@ -404,24 +406,16 @@ def entity_resolution(
404406

405407
with neo4j_driver.session() as session:
406408
if entity_extraction.get("persons"):
407-
resolved["persons"] = _resolve_persons(
408-
entity_extraction["persons"], session
409-
)
409+
resolved["persons"] = _resolve_persons(entity_extraction["persons"], session)
410410

411411
if entity_extraction.get("movies"):
412-
resolved["movies"] = _resolve_movies(
413-
entity_extraction["movies"], session
414-
)
412+
resolved["movies"] = _resolve_movies(entity_extraction["movies"], session)
415413

416414
if entity_extraction.get("genres"):
417-
resolved["genres"] = _resolve_genres(
418-
entity_extraction["genres"], session
419-
)
415+
resolved["genres"] = _resolve_genres(entity_extraction["genres"], session)
420416

421417
if entity_extraction.get("companies"):
422-
resolved["companies"] = _resolve_companies(
423-
entity_extraction["companies"], session
424-
)
418+
resolved["companies"] = _resolve_companies(entity_extraction["companies"], session)
425419

426420
# Pass through numeric/date filters unchanged
427421
for key in ("year_after", "year_before", "rating_above", "rating_below"):
@@ -436,6 +430,7 @@ def entity_resolution(
436430
# 4. Vector path
437431
# ---------------------------------------------------------------------------
438432

433+
439434
def query_embedding(
440435
user_query: str,
441436
openai_api_key: str,
@@ -499,6 +494,7 @@ def vector_results(
499494
# 5. Cypher generation using resolved entities
500495
# ---------------------------------------------------------------------------
501496

497+
502498
def _build_entity_context(resolved: dict) -> str:
503499
"""
504500
Build a plain-English summary of resolved entities for the Cypher
@@ -511,32 +507,32 @@ def _build_entity_context(resolved: dict) -> str:
511507

512508
persons = resolved.get("persons", {})
513509
if persons:
514-
for original, canonical in persons.items():
510+
for _original, canonical in persons.items():
515511
lines.append(f' Person: "{canonical}"')
516512

517513
movies = resolved.get("movies", {})
518514
if movies:
519-
for original, canonical in movies.items():
515+
for _original, canonical in movies.items():
520516
lines.append(f' Movie title: "{canonical}"')
521517

522518
genres = resolved.get("genres", {})
523519
if genres:
524-
for original, canonical in genres.items():
520+
for _original, canonical in genres.items():
525521
lines.append(f' Genre: "{canonical}"')
526522

527523
companies = resolved.get("companies", {})
528524
if companies:
529-
for original, canonical in companies.items():
525+
for _original, canonical in companies.items():
530526
lines.append(f' ProductionCompany: "{canonical}"')
531527

532528
if "year_after" in resolved:
533-
lines.append(f' Date filter: m.release_date > \'{resolved["year_after"]}-01-01\'')
529+
lines.append(f" Date filter: m.release_date > '{resolved['year_after']}-01-01'")
534530
if "year_before" in resolved:
535-
lines.append(f' Date filter: m.release_date < \'{resolved["year_before"]}-12-31\'')
531+
lines.append(f" Date filter: m.release_date < '{resolved['year_before']}-12-31'")
536532
if "rating_above" in resolved:
537-
lines.append(f' Rating filter: m.vote_average > {resolved["rating_above"]}')
533+
lines.append(f" Rating filter: m.vote_average > {resolved['rating_above']}")
538534
if "rating_below" in resolved:
539-
lines.append(f' Rating filter: m.vote_average < {resolved["rating_below"]}')
535+
lines.append(f" Rating filter: m.vote_average < {resolved['rating_below']}")
540536

541537
return "\n".join(lines)
542538

@@ -656,6 +652,7 @@ def cypher_results(
656652
# 6. Enrich vector results with graph traversal
657653
# ---------------------------------------------------------------------------
658654

655+
659656
def _enrich_movie(movie_id: int, driver: Driver) -> dict | None:
660657
"""Pull directors, cast, genres, companies for a movie node."""
661658
cypher = """
@@ -684,6 +681,7 @@ def _enrich_movie(movie_id: int, driver: Driver) -> dict | None:
684681
# 7. Merge results
685682
# ---------------------------------------------------------------------------
686683

684+
687685
def merged_results(
688686
vector_results: list[dict],
689687
cypher_results: list[dict],
@@ -722,6 +720,7 @@ def merged_results(
722720
# 8. Format context
723721
# ---------------------------------------------------------------------------
724722

723+
725724
def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
726725
"""
727726
Format merged results into plain-text context for the generation DAG.
@@ -734,26 +733,26 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
734733
return "No relevant information found in the knowledge graph for this query."
735734

736735
FIELD_LABELS = {
737-
"movie": "Movie",
738-
"director": "Director",
739-
"actor": "Actor",
740-
"genre": "Genre",
741-
"company": "Production company",
742-
"film_count": "Films",
743-
"movie_count": "Count",
736+
"movie": "Movie",
737+
"director": "Director",
738+
"actor": "Actor",
739+
"genre": "Genre",
740+
"company": "Production company",
741+
"film_count": "Films",
742+
"movie_count": "Count",
744743
"action_movie_count": "Action movies",
745-
"avg_rating": "Avg rating",
746-
"average_rating": "Avg rating",
747-
"vote_average": "Rating",
748-
"release_date": "Released",
744+
"avg_rating": "Avg rating",
745+
"average_rating": "Avg rating",
746+
"vote_average": "Rating",
747+
"release_date": "Released",
749748
}
750749

751750
lines = []
752751
i = 0
753752

754753
for row in merged_results:
755754
i += 1
756-
source = row.get("_source", "unknown")
755+
_source = row.get("_source", "unknown")
757756

758757
if "directors" in row:
759758
# Enriched movie record from vector path
@@ -786,4 +785,4 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
786785

787786
context = "\n".join(lines)
788787
logger.info("Formatted context: %d chars from %d results", len(context), len(merged_results))
789-
return context
788+
return context

0 commit comments

Comments
 (0)