Skip to content

Commit e1b2aa2

Browse files
authored
Merge pull request #30 from iberi22/feature/hirag-sync-reranking
Implement HiRAG, Sync, and Re-ranking Features
2 parents 37cf6ae + a97fffa commit e1b2aa2

22 files changed

Lines changed: 1140 additions & 5 deletions

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,9 @@ print(explanation);
447447
- [x] Hybrid Retrieval (Dense + Isar Filter).
448448
- [x] Sync & Privacy (Encryption with AES-256-GCM, LWW conflict resolution).
449449
- [x] HiRAG Phase 1 (Layer-based organization, summary nodes, relationship types).
450-
- [ ] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval).
451-
- [ ] Cross-device sync backend (Firebase/WebSocket integration).
452-
- [ ] Re-ranking and advanced retrieval strategies.
450+
- [x] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval).
451+
- [x] Cross-device sync backend (Firebase/WebSocket integration).
452+
- [x] Re-ranking and advanced retrieval strategies.
453453

454454
---
455455

lib/src/hierarchical_graph.dart

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,65 @@
11
import 'package:isar/isar.dart';
22
import 'package:isar_agent_memory/isar_agent_memory.dart';
3+
import 'llm_adapter.dart';
34

45
/// Extension for HiRAG (Hierarchical RAG) capabilities.
56
///
67
/// This extension adds methods to manage hierarchical layers of knowledge.
78
extension HierarchicalMemoryGraph on MemoryGraph {
8-
99
/// Relation type for summary edges (Child -> Summary).
1010
static const String relationSummaryOf = 'summary_of';
1111

1212
/// Relation type for part-of edges (Part -> Whole).
1313
static const String relationPartOf = 'part_of';
1414

15+
/// Automatically summarizes a given layer by grouping nodes and using an LLM.
///
/// The content of every node in [layerIndex] is joined (separated by `---`
/// lines) into a single prompt, sent to [llmAdapter], and the generated text
/// is stored as a new summary node in layer `layerIndex + 1` through
/// [createSummaryNode]. Each summarized node is then linked to the summary
/// node with a [relationPartOf] edge.
///
/// [layerIndex]: The layer to summarize.
/// [llmAdapter]: The adapter to the Large Language Model for summary generation.
/// [promptTemplate]: A function to format the content into a prompt for the LLM.
/// If null, a default template is used.
///
/// Returns the ID of the newly created summary node.
///
/// Throws an [Exception] when the layer contains no nodes.
Future<int> autoSummarizeLayer({
  required int layerIndex,
  required LLMAdapter llmAdapter,
  String Function(String content)? promptTemplate,
}) async {
  // 1. Get all nodes in the target layer.
  final nodes = await getNodesByLayer(layerIndex);
  if (nodes.isEmpty) {
    throw Exception('No nodes found in layer $layerIndex to summarize.');
  }

  // 2. Combine content for the LLM prompt.
  final contentToSummarize = nodes.map((n) => n.content).join('\n---\n');
  final prompt = promptTemplate?.call(contentToSummarize) ??
      'Summarize the following content into a coherent paragraph: '
          '\n\n$contentToSummarize';

  // 3. Call the LLM to generate the summary.
  final summaryContent = await llmAdapter.generate(prompt);

  // 4. Create the summary node in the next layer.
  final childNodeIds = nodes.map((n) => n.id).toList();
  final summaryNodeId = await createSummaryNode(
    summaryContent: summaryContent,
    childNodeIds: childNodeIds,
    layer: layerIndex + 1,
  );

  // 5. Link every part to its whole. [relationPartOf] is documented as
  //    "Part -> Whole", so the edge runs from the child to the summary node
  //    (the previous implementation stored it the other way round).
  //    Edges are written one at a time, in child order.
  for (final childId in childNodeIds) {
    await storeEdge(MemoryEdge(
      fromNodeId: childId,
      toNodeId: summaryNodeId,
      relation: relationPartOf,
    ));
  }

  return summaryNodeId;
}
62+
1563
/// Creates a summary node for a list of [childNodeIds].
1664
///
1765
/// [summaryContent]: The summarized text.
@@ -56,4 +104,57 @@ extension HierarchicalMemoryGraph on MemoryGraph {
56104
Future<List<MemoryNode>> getNodesByLayer(int layer) async {
  // Build the layer filter first, then run it and hand back every match.
  final layerQuery = isar.memoryNodes.filter().layerEqualTo(layer);
  final nodesInLayer = await layerQuery.findAll();
  return nodesInLayer;
}
107+
108+
/// Runs a semantic search on layer 0 and attaches hierarchical context.
///
/// For every hit, the graph is walked upwards along [relationSummaryOf]
/// edges that start at the current node. Only the first such edge is
/// followed at each step. The walk stops after [maxHops] steps, when a node
/// has no outgoing summary edge, or when the edge's target node cannot be
/// loaded.
///
/// [queryEmbedding]: The embedding of the search query.
/// [maxHops]: The maximum number of upward traversals (default is 2).
/// [topK]: The number of initial results to fetch from the base layer.
///
/// Returns one record per base-layer hit: the hit itself as `node`, and the
/// visited parent (summary) nodes as `context`, nearest parent first.
Future<List<({MemoryNode node, List<MemoryNode> context})>> multiHopSearch({
  required List<double> queryEmbedding,
  int maxHops = 2,
  int topK = 5,
}) async {
  // Seed results come from the base layer only.
  final seeds = await semanticSearch(queryEmbedding, topK: topK, layer: 0);
  final enriched = <({MemoryNode node, List<MemoryNode> context})>[];

  for (final seed in seeds) {
    final ancestors = <MemoryNode>[];
    var cursor = seed.node;

    for (var hop = 0; hop < maxHops; hop++) {
      // Outgoing 'summary_of' edges of the node we are currently standing on.
      final summaryEdges = await isar.memoryEdges
          .filter()
          .fromNodeIdEqualTo(cursor.id)
          .relationEqualTo(relationSummaryOf)
          .findAll();
      if (summaryEdges.isEmpty) break;

      // Follow only the first edge; a dangling target ends the walk.
      final parent = await getNode(summaryEdges.first.toNodeId);
      if (parent == null) break;

      ancestors.add(parent);
      cursor = parent;
    }

    enriched.add((node: seed.node, context: ancestors));
  }

  return enriched;
}
59160
}

lib/src/llm_adapter.dart

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/// Abstract interface for a generic Large Language Model (LLM).
2+
///
3+
/// Defines the contract for generating text from a prompt. Implement this
/// class to plug a concrete model into any API that accepts an [LLMAdapter].
4+
abstract class LLMAdapter {
5+
/// Generates content based on the given [prompt].
6+
///
7+
/// Returns a [Future] that completes with the generated [String].
/// How failures are reported is not specified by this interface and is
/// left to the implementation.
8+
Future<String> generate(String prompt);
9+
}

lib/src/memory_graph.dart

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import 'models/memory_edge.dart';
66
import 'models/memory_embedding.dart';
77
import 'vector_index.dart';
88
import 'vector_index_objectbox.dart';
9+
import 'reranking_strategy.dart';
910

1011
/// Main API for interacting with the universal agent memory graph.
1112
///
@@ -181,6 +182,7 @@ class MemoryGraph {
181182
semanticSearch(
182183
List<double> queryEmbedding, {
183184
int topK = 5,
185+
int? layer,
184186
}) async {
185187
// Gracefully return empty list if dimensions mismatch, as tests expect.
186188
if (queryEmbedding.length != embeddingsAdapter.dimension) {
@@ -216,7 +218,11 @@ class MemoryGraph {
216218
}
217219

218220
// Fallback to linear scan if the index returns no results or fails.
219-
final allNodes = await isar.memoryNodes.where().findAll();
221+
var query = isar.memoryNodes.where();
222+
if (layer != null) {
223+
query = query.filter().layerEqualTo(layer);
224+
}
225+
final allNodes = await query.findAll();
220226

221227
final distances = allNodes
222228
.map((n) => (n.embedding != null)
@@ -419,4 +425,37 @@ class MemoryGraph {
419425
Future<void> clearVectorCollection() async {
420426
await _index.clear();
421427
}
428+
/// Runs [semanticSearch] and lets [reranker] reorder the candidates.
///
/// Twice as many candidates as requested (`topK * 2`) are fetched so the
/// re-ranker has a wider pool to choose from. Each candidate's distance is
/// converted to a score as `1.0 - distance` before re-ranking. No query text
/// is passed to the re-ranker.
///
/// [queryEmbedding] is the embedding of the search query.
/// [reranker] is the re-ranking strategy to apply.
/// [topK] is the number of results to return.
Future<List<({MemoryNode node, double score})>> semanticSearchWithReRanking(
  List<double> queryEmbedding, {
  required ReRankingStrategy reranker,
  int topK = 5,
}) async {
  final candidates = await semanticSearch(queryEmbedding, topK: topK * 2);

  final scored = <({MemoryNode node, double score})>[
    for (final hit in candidates) (node: hit.node, score: 1.0 - hit.distance),
  ];

  final reordered = reranker.reRank(scored);
  return reordered.take(topK).toList();
}
444+
445+
/// Runs [hybridSearch] and lets [reranker] reorder the candidates.
///
/// Twice as many candidates as requested (`topK * 2`) are fetched, handed to
/// the re-ranker together with the original [query] text, and the first
/// [topK] entries of the re-ranked list are returned.
///
/// [query] is the text to search for.
/// [reranker] is the re-ranking strategy to apply.
/// [alpha] controls the weight of the vector search vs. text search.
/// [topK] is the number of results to return.
Future<List<({MemoryNode node, double score})>> hybridSearchWithReRanking(
  String query, {
  required ReRankingStrategy reranker,
  int topK = 5,
  double alpha = 0.5,
}) async {
  final candidatePoolSize = topK * 2;
  final candidates =
      await hybridSearch(query, topK: candidatePoolSize, alpha: alpha);

  final reordered = reranker.reRank(candidates, query: query);
  return reordered.take(topK).toList();
}
422461
}

lib/src/models/memory_node.dart

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class MemoryNode {
3131
this.modifiedAt,
3232
this.layer = 0,
3333
this.uuid,
34+
this.accessCount = 0,
3435
}) : createdAt = DateTime.now() {
3536
this.degree = degree ?? Degree();
3637
if (modifiedAt == null) {
@@ -71,6 +72,9 @@ class MemoryNode {
7172
/// Can be used to track recency and relevance.
7273
DateTime? updatedAt;
7374

75+
/// The number of times this node has been accessed.
76+
int accessCount;
77+
7478
/// The timestamp when this record was last modified (system-level sync).
7579
///
7680
/// Used for Last-Write-Wins (LWW) conflict resolution.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import 'dart:math';
2+
import 'package:isar_agent_memory/isar_agent_memory.dart';
3+
4+
/// A re-ranking strategy based on the BM25 algorithm.
///
/// Results are reordered by the BM25 relevance of each node's content to the
/// query text, computed over the result set itself (the result set is the
/// "corpus" for document frequency and average length). The `score` field of
/// every record is passed through unchanged; only the order changes.
class BM25ReRanker implements ReRankingStrategy {
  /// Term-frequency saturation parameter of the BM25 formula.
  final double k1;

  /// Weight of the document-length normalisation in the BM25 formula.
  final double b;

  BM25ReRanker({this.k1 = 1.5, this.b = 0.75});

  /// Returns [results] ordered by descending BM25 score against [query].
  ///
  /// When [query] is null, contains no word characters, or [results] is
  /// empty, [results] is returned as-is. Otherwise a new list is returned and
  /// the caller's list is left untouched. Entries with equal BM25 scores keep
  /// their original relative order.
  @override
  List<({MemoryNode node, double score})> reRank(
    List<({MemoryNode node, double score})> results, {
    String? query,
  }) {
    if (query == null || results.isEmpty) {
      return results;
    }

    final queryTerms = _tokenize(query);
    if (queryTerms.isEmpty) {
      // Nothing to match on: every document would score 0.
      return results;
    }

    // Tokenise and count each document exactly once, instead of on every
    // comparator call.
    final docLengths = <int>[];
    final docTermCounts = <Map<String, int>>[];
    for (final result in results) {
      final tokens = _tokenize(result.node.content);
      docLengths.add(tokens.length);
      docTermCounts.add(_countTerms(tokens));
    }

    final totalLength = docLengths.fold<int>(0, (sum, len) => sum + len);
    final avgdl = totalLength / results.length;
    final idf = _calculateIdf(queryTerms, docTermCounts);

    final bm25Scores = <double>[
      for (var i = 0; i < results.length; i++)
        _calculateBm25(queryTerms, docTermCounts[i], docLengths[i], idf, avgdl),
    ];

    // Sort indices rather than the caller's list; the index tie-break makes
    // the ordering deterministic even though List.sort is not stable.
    final order = List<int>.generate(results.length, (i) => i)
      ..sort((i, j) {
        final byScore = bm25Scores[j].compareTo(bm25Scores[i]);
        return byScore != 0 ? byScore : i.compareTo(j);
      });

    return [for (final i in order) results[i]];
  }

  /// Lower-cases [text] and splits it on runs of non-word characters.
  ///
  /// Empty fragments (produced by leading/trailing separators or an empty
  /// input) are dropped so they neither count as terms nor inflate lengths.
  List<String> _tokenize(String text) {
    return text
        .toLowerCase()
        .split(RegExp(r'\W+'))
        .where((token) => token.isNotEmpty)
        .toList();
  }

  /// Counts how often each token occurs in [tokens].
  Map<String, int> _countTerms(List<String> tokens) {
    final counts = <String, int>{};
    for (final token in tokens) {
      counts[token] = (counts[token] ?? 0) + 1;
    }
    return counts;
  }

  /// Computes the inverse document frequency of every query term over the
  /// documents described by [docTermCounts].
  Map<String, double> _calculateIdf(
    List<String> queryTerms,
    List<Map<String, int>> docTermCounts,
  ) {
    final idf = <String, double>{};
    final docTotal = docTermCounts.length;
    for (final term in queryTerms) {
      final docCount =
          docTermCounts.where((counts) => counts.containsKey(term)).length;
      idf[term] = log((docTotal - docCount + 0.5) / (docCount + 0.5) + 1);
    }
    return idf;
  }

  /// Computes the BM25 score of one document for [queryTerms].
  ///
  /// [termCounts] and [docLength] describe the document; [avgdl] is the
  /// average document length of the result set. A term only contributes when
  /// it occurs in the document, which implies `docLength >= 1` and therefore
  /// `avgdl > 0`, so the length ratio never divides by zero.
  double _calculateBm25(
    List<String> queryTerms,
    Map<String, int> termCounts,
    int docLength,
    Map<String, double> idf,
    double avgdl,
  ) {
    var score = 0.0;
    for (final term in queryTerms) {
      final tf = termCounts[term];
      final weight = idf[term];
      if (tf == null || weight == null) {
        continue;
      }
      score += weight *
          (tf * (k1 + 1)) /
          (tf + k1 * (1 - b + b * (docLength / avgdl)));
    }
    return score;
  }
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import 'dart:math';
2+
import 'package:isar_agent_memory/isar_agent_memory.dart';
3+
4+
/// A re-ranking strategy that maximizes the diversity of the results.
///
/// The first result is kept in place. After that, the next pick is always
/// the remaining candidate whose highest cosine similarity to any
/// already-selected result is lowest, so near-duplicates of earlier picks are
/// pushed towards the end. The `score` field of every record is passed
/// through unchanged.
class DiversityReRanker implements ReRankingStrategy {
  /// Returns a new list containing every entry of [results], reordered for
  /// diversity. [query] is ignored. An empty input yields an empty list.
  ///
  /// Nodes without an embedding, or with an embedding of a different length,
  /// are treated as having similarity 0 to everything.
  @override
  List<({MemoryNode node, double score})> reRank(
    List<({MemoryNode node, double score})> results, {
    String? query,
  }) {
    if (results.isEmpty) {
      return [];
    }

    final remaining = List.of(results);
    var lastPicked = remaining.removeAt(0);
    final reranked = <({MemoryNode node, double score})>[lastPicked];

    // closeness[i] is the highest similarity seen so far between remaining[i]
    // and any selected result. It is updated with only the most recent pick
    // each round, which gives the same maximum as recomputing against the
    // whole selected list but without the repeated cosine computations.
    final closeness = List<double>.filled(
      remaining.length,
      double.negativeInfinity,
      growable: true,
    );

    while (remaining.isNotEmpty) {
      final pickedVector = lastPicked.node.embedding?.vector;
      var bestIndex = 0;
      var bestScore = double.infinity;

      for (var i = 0; i < remaining.length; i++) {
        final similarity = _cosineSimilarity(
          pickedVector,
          remaining[i].node.embedding?.vector,
        );
        closeness[i] = max(closeness[i], similarity);

        // Strict '<' keeps the earliest candidate on ties.
        if (closeness[i] < bestScore) {
          bestScore = closeness[i];
          bestIndex = i;
        }
      }

      // Remove by position so progress never depends on record equality.
      lastPicked = remaining.removeAt(bestIndex);
      closeness.removeAt(bestIndex);
      reranked.add(lastPicked);
    }

    return reranked;
  }

  /// Returns the cosine similarity of [a] and [b].
  ///
  /// Returns 0.0 when either vector is null, the lengths differ, or either
  /// vector has zero magnitude.
  double _cosineSimilarity(List<double>? a, List<double>? b) {
    if (a == null || b == null || a.length != b.length) {
      return 0.0;
    }

    var dotProduct = 0.0;
    var normA = 0.0;
    var normB = 0.0;

    for (var i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    if (normA == 0.0 || normB == 0.0) {
      return 0.0;
    }

    return dotProduct / (sqrt(normA) * sqrt(normB));
  }
}

0 commit comments

Comments
 (0)