Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ print(explanation);
- [x] Hybrid Retrieval (Dense + Isar Filter).
- [x] Sync & Privacy (Encryption with AES-256-GCM, LWW conflict resolution).
- [x] HiRAG Phase 1 (Layer-based organization, summary nodes, relationship types).
- [ ] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval).
- [ ] Cross-device sync backend (Firebase/WebSocket integration).
- [ ] Re-ranking and advanced retrieval strategies.
- [x] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval).
- [x] Cross-device sync backend (Firebase/WebSocket integration).
- [x] Re-ranking and advanced retrieval strategies.

---

Expand Down
103 changes: 102 additions & 1 deletion lib/src/hierarchical_graph.dart
Original file line number Diff line number Diff line change
@@ -1,17 +1,65 @@
import 'package:isar/isar.dart';
import 'package:isar_agent_memory/isar_agent_memory.dart';
import 'llm_adapter.dart';

/// Extension for HiRAG (Hierarchical RAG) capabilities.
///
/// This extension adds methods to manage hierarchical layers of knowledge.
extension HierarchicalMemoryGraph on MemoryGraph {

/// Relation label for summary edges, directed Child -> Summary.
///
/// The stored relation string is `'summary_of'`.
static const String relationSummaryOf = 'summary_of';

/// Relation label for part-of edges, directed Part -> Whole.
///
/// The stored relation string is `'part_of'`.
static const String relationPartOf = 'part_of';

/// Condenses every node of one layer into a single LLM-written summary node.
///
/// All nodes whose layer equals [layerIndex] are loaded, their contents are
/// joined with a `---` separator line, and [llmAdapter] is asked to summarize
/// the result. The summary is stored through [createSummaryNode] in layer
/// `layerIndex + 1`, and one [relationPartOf] edge is stored per summarized
/// node.
///
/// [layerIndex]: The layer to summarize.
/// [llmAdapter]: The adapter to the Large Language Model that writes the
/// summary.
/// [promptTemplate]: Turns the joined content into the LLM prompt. When null,
/// a built-in "summarize into a coherent paragraph" prompt is used.
///
/// Returns the ID of the newly created summary node.
///
/// Throws an [Exception] when the layer holds no nodes.
Future<int> autoSummarizeLayer({
  required int layerIndex,
  required LLMAdapter llmAdapter,
  String Function(String content)? promptTemplate,
}) async {
  // Load the layer; an empty layer has nothing to summarize.
  final layerNodes = await getNodesByLayer(layerIndex);
  if (layerNodes.isEmpty) {
    throw Exception('No nodes found in layer $layerIndex to summarize.');
  }

  // Build the prompt from the concatenated node contents.
  final contentToSummarize =
      [for (final member in layerNodes) member.content].join('\n---\n');
  final String prompt;
  if (promptTemplate == null) {
    prompt =
        'Summarize the following content into a coherent paragraph: \n\n$contentToSummarize';
  } else {
    prompt = promptTemplate(contentToSummarize);
  }

  // Ask the model for the summary text.
  final summaryContent = await llmAdapter.generate(prompt);

  // Persist the summary one layer above the nodes it covers.
  final childNodeIds = [for (final member in layerNodes) member.id];
  final summaryNodeId = await createSummaryNode(
    summaryContent: summaryContent,
    childNodeIds: childNodeIds,
    layer: layerIndex + 1,
  );

  // Link the summary to each summarized node, one edge at a time.
  // NOTE(review): these edges run summary -> child, while the doc on
  // [relationPartOf] describes the direction as Part -> Whole — confirm which
  // direction is intended before relying on it.
  for (final partId in childNodeIds) {
    final partEdge = MemoryEdge(
      fromNodeId: summaryNodeId,
      toNodeId: partId,
      relation: relationPartOf,
    );
    await storeEdge(partEdge);
  }

  return summaryNodeId;
}

/// Creates a summary node for a list of [childNodeIds].
///
/// [summaryContent]: The summarized text.
Expand Down Expand Up @@ -56,4 +104,57 @@ extension HierarchicalMemoryGraph on MemoryGraph {
Future<List<MemoryNode>> getNodesByLayer(int layer) async {
  // Restrict the node collection to entries whose layer equals [layer].
  final layerQuery = isar.memoryNodes.filter().layerEqualTo(layer);
  final matches = await layerQuery.findAll();
  return matches;
}

/// Performs a multi-hop search, enriching results with hierarchical context.
///
/// A semantic search is run for [queryEmbedding] restricted to layer 0. For
/// each hit, the graph is walked upwards by following the first
/// [relationSummaryOf] edge that leaves the current node, collecting each
/// parent (summary) node that is reached.
///
/// [queryEmbedding]: The embedding of the search query.
/// [maxHops]: The maximum number of upward traversals (default is 2). A value
/// of zero or less yields empty contexts.
/// [topK]: The number of initial results to request from the base layer.
///
/// Returns one record per base-layer hit, in search order, holding the base
/// node and its parent nodes ordered from nearest to farthest. Hits that are
/// not in layer 0 are dropped, so fewer than [topK] records may be returned.
Future<List<({MemoryNode node, List<MemoryNode> context})>> multiHopSearch({
  required List<double> queryEmbedding,
  int maxHops = 2,
  int topK = 5,
}) async {
  // 1. Semantic search on the base layer (layer 0).
  final initialResults =
      await semanticSearch(queryEmbedding, topK: topK, layer: 0);

  final enrichedResults =
      <({MemoryNode node, List<MemoryNode> context})>[];

  // 2. Traverse upwards for each result.
  for (final result in initialResults) {
    final baseNode = result.node;

    // Defensive guard: semanticSearch may not honor `layer` on every code
    // path (NOTE(review): its indexed path is not visible here — confirm), so
    // enforce the documented base-layer contract locally.
    if (baseNode.layer != 0) continue;

    final context = <MemoryNode>[];
    // Ids already on this upward path; prevents a self-loop or a cycle of
    // summary edges from adding the same node to the context twice.
    final visitedIds = {baseNode.id};
    var currentNode = baseNode;

    for (var hops = 0; hops < maxHops; hops++) {
      // Edges leaving the current node with the 'summary_of' relation point
      // at its parent summary.
      final edges = await isar.memoryEdges
          .filter()
          .fromNodeIdEqualTo(currentNode.id)
          .relationEqualTo(relationSummaryOf)
          .findAll();
      if (edges.isEmpty) break;

      // Follow the first summary edge to the parent.
      final parentNode = await getNode(edges.first.toNodeId);
      if (parentNode == null) break;

      // Stop as soon as the walk would revisit a node.
      if (!visitedIds.add(parentNode.id)) break;

      context.add(parentNode);
      currentNode = parentNode;
    }

    enrichedResults.add((node: baseNode, context: context));
  }

  return enrichedResults;
}
}
9 changes: 9 additions & 0 deletions lib/src/llm_adapter.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/// Abstract interface for a generic Large Language Model (LLM).
///
/// Implementations supply the text-generation backend; callers only depend on
/// the single [generate] method declared here.
abstract class LLMAdapter {
/// Generates text content for the given [prompt].
///
/// Returns a [Future] that completes with the generated [String].
Future<String> generate(String prompt);
}
41 changes: 40 additions & 1 deletion lib/src/memory_graph.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'models/memory_edge.dart';
import 'models/memory_embedding.dart';
import 'vector_index.dart';
import 'vector_index_objectbox.dart';
import 'reranking_strategy.dart';

/// Main API for interacting with the universal agent memory graph.
///
Expand Down Expand Up @@ -181,6 +182,7 @@ class MemoryGraph {
semanticSearch(
List<double> queryEmbedding, {
int topK = 5,
int? layer,
}) async {
// Gracefully return empty list if dimensions mismatch, as tests expect.
if (queryEmbedding.length != embeddingsAdapter.dimension) {
Expand Down Expand Up @@ -216,7 +218,11 @@ class MemoryGraph {
}

// Fallback to linear scan if the index returns no results or fails.
final allNodes = await isar.memoryNodes.where().findAll();
var query = isar.memoryNodes.where();
if (layer != null) {
query = query.filter().layerEqualTo(layer);
}
final allNodes = await query.findAll();

final distances = allNodes
.map((n) => (n.embedding != null)
Expand Down Expand Up @@ -419,4 +425,37 @@ class MemoryGraph {
Future<void> clearVectorCollection() async {
await _index.clear();
}
/// Runs a semantic search and passes the hits through a re-ranking strategy.
///
/// Twice [topK] candidates are requested from [semanticSearch] so that the
/// [reranker] has a wider pool to reorder. Each hit's distance is converted
/// to a score as `1.0 - distance` before re-ranking.
///
/// [queryEmbedding] is the embedding of the search query.
/// [reranker] is the re-ranking strategy to apply.
/// [topK] is the number of results to return.
Future<List<({MemoryNode node, double score})>> semanticSearchWithReRanking(
  List<double> queryEmbedding, {
  required ReRankingStrategy reranker,
  int topK = 5,
}) async {
  final candidates = await semanticSearch(queryEmbedding, topK: topK * 2);

  // Turn distances into scores for the reranker.
  final scored = <({MemoryNode node, double score})>[
    for (final hit in candidates) (node: hit.node, score: 1.0 - hit.distance),
  ];

  final reordered = reranker.reRank(scored);
  return reordered.take(topK).toList();
}

/// Runs a hybrid search and passes the hits through a re-ranking strategy.
///
/// Twice [topK] candidates are requested from [hybridSearch]; the [reranker]
/// receives them together with the original [query] text, and the first
/// [topK] entries of its output are returned.
///
/// [query] is the text to search for.
/// [reranker] is the re-ranking strategy to apply.
/// [alpha] controls the weight of the vector search vs. text search.
/// [topK] is the number of results to return.
Future<List<({MemoryNode node, double score})>> hybridSearchWithReRanking(
  String query, {
  required ReRankingStrategy reranker,
  int topK = 5,
  double alpha = 0.5,
}) async {
  final candidates = await hybridSearch(query, topK: topK * 2, alpha: alpha);
  final reordered = reranker.reRank(candidates, query: query);
  return [...reordered.take(topK)];
}
}
4 changes: 4 additions & 0 deletions lib/src/models/memory_node.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class MemoryNode {
this.modifiedAt,
this.layer = 0,
this.uuid,
this.accessCount = 0,
}) : createdAt = DateTime.now() {
this.degree = degree ?? Degree();
if (modifiedAt == null) {
Expand Down Expand Up @@ -71,6 +72,9 @@ class MemoryNode {
/// Can be used to track recency and relevance.
DateTime? updatedAt;

/// The number of times this node has been accessed.
int accessCount;

/// The timestamp when this record was last modified (system-level sync).
///
/// Used for Last-Write-Wins (LWW) conflict resolution.
Expand Down
68 changes: 68 additions & 0 deletions lib/src/rerankers/bm25_reranker.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import 'dart:math';
import 'package:isar_agent_memory/isar_agent_memory.dart';

/// A re-ranking strategy based on the BM25 algorithm.
///
/// Results are reordered by the BM25 score of each node's content against the
/// query text. Document statistics (IDF, average length) are computed over
/// the result list itself. The `score` carried by each result is passed
/// through unchanged; only the order changes.
class BM25ReRanker implements ReRankingStrategy {
  /// Term-frequency saturation parameter of the BM25 formula.
  final double k1;

  /// Weight of the document-length normalization (`doc.length / avgdl`).
  final double b;

  BM25ReRanker({this.k1 = 1.5, this.b = 0.75});

  /// Matches runs of characters that are not letters, digits, or `_`.
  ///
  /// Unicode-aware so that non-ASCII words stay in one piece.
  static final RegExp _separator =
      RegExp(r'[^\p{L}\p{N}_]+', unicode: true);

  /// Returns [results] reordered by descending BM25 score against [query].
  ///
  /// When [query] is null or [results] is empty, [results] is returned as is.
  /// Otherwise a new list is returned and [results] is left unmodified.
  /// Results with equal BM25 scores keep their incoming relative order.
  @override
  List<({MemoryNode node, double score})> reRank(
    List<({MemoryNode node, double score})> results, {
    String? query,
  }) {
    if (query == null || results.isEmpty) {
      return results;
    }

    final queryTerms = _tokenize(query);
    final documents = [for (final r in results) _tokenize(r.node.content)];
    final idf = _calculateIdf(queryTerms, documents);
    final totalLength =
        documents.fold<int>(0, (sum, doc) => sum + doc.length);
    final avgdl = totalLength / documents.length;

    // Score every document exactly once instead of inside the comparator.
    final bm25 = [
      for (final doc in documents) _calculateBm25(queryTerms, doc, idf, avgdl),
    ];

    // Sort positions rather than the caller's list; the index tie-break makes
    // the ordering deterministic regardless of the sort's stability.
    final order = List<int>.generate(results.length, (i) => i)
      ..sort((x, y) {
        final byScore = bm25[y].compareTo(bm25[x]);
        return byScore != 0 ? byScore : x.compareTo(y);
      });

    return [for (final i in order) results[i]];
  }

  /// Lower-cases [text] and splits it into non-empty word tokens.
  List<String> _tokenize(String text) => text
      .toLowerCase()
      .split(_separator)
      .where((token) => token.isNotEmpty)
      .toList();

  /// Computes the BM25 inverse document frequency of each query term over
  /// [documents].
  Map<String, double> _calculateIdf(
      List<String> queryTerms, List<List<String>> documents) {
    final idf = <String, double>{};
    for (final term in queryTerms) {
      final docCount = documents.where((d) => d.contains(term)).length;
      idf[term] =
          log((documents.length - docCount + 0.5) / (docCount + 0.5) + 1);
    }
    return idf;
  }

  /// Computes the BM25 score of one tokenized [doc] for [queryTerms].
  ///
  /// [avgdl] is only used for terms that occur in [doc], so it is always
  /// positive when the division happens.
  double _calculateBm25(List<String> queryTerms, List<String> doc,
      Map<String, double> idf, double avgdl) {
    var score = 0.0;
    for (final term in queryTerms) {
      final tf = doc.where((t) => t == term).length;
      if (tf == 0) {
        continue;
      }
      final termIdf = idf[term] ?? 0.0;
      score += termIdf *
          (tf * (k1 + 1)) /
          (tf + k1 * (1 - b + b * (doc.length / avgdl)));
    }
    return score;
  }
}
66 changes: 66 additions & 0 deletions lib/src/rerankers/diversity_reranker.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import 'dart:math';
import 'package:isar_agent_memory/isar_agent_memory.dart';

/// A re-ranking strategy that maximizes the diversity of the results.
///
/// The first result keeps its position. Every following slot is filled with
/// the remaining result whose highest cosine similarity to the results picked
/// so far is the lowest, so near-duplicates are pushed towards the end.
class DiversityReRanker implements ReRankingStrategy {
  /// Returns a new list holding [results] in diversity-first order.
  ///
  /// [query] is not used by this strategy. The input list is not modified.
  /// When several candidates are equally dissimilar, the one that appears
  /// first in [results] is picked. Results without an embedding count as
  /// similarity `0.0` to everything.
  @override
  List<({MemoryNode node, double score})> reRank(
    List<({MemoryNode node, double score})> results, {
    String? query,
  }) {
    if (results.isEmpty) {
      return [];
    }

    final remaining = List.of(results);
    final reranked = <({MemoryNode node, double score})>[
      remaining.removeAt(0),
    ];

    // closeness[i] is the highest similarity between remaining[i] and any
    // already selected result. It is updated incrementally after each pick
    // instead of being recomputed against the whole selection every round.
    final seedVector = reranked.first.node.embedding?.vector;
    final closeness = [
      for (final candidate in remaining)
        _cosineSimilarity(seedVector, candidate.node.embedding?.vector),
    ];

    while (remaining.isNotEmpty) {
      // Pick the first candidate with the strictly smallest closeness.
      var bestIndex = 0;
      var bestScore = double.infinity;
      for (var i = 0; i < remaining.length; i++) {
        if (closeness[i] < bestScore) {
          bestScore = closeness[i];
          bestIndex = i;
        }
      }

      final picked = remaining.removeAt(bestIndex);
      closeness.removeAt(bestIndex);
      reranked.add(picked);

      // Fold the newly selected result into each candidate's running maximum.
      final pickedVector = picked.node.embedding?.vector;
      for (var i = 0; i < remaining.length; i++) {
        closeness[i] = max(
          closeness[i],
          _cosineSimilarity(pickedVector, remaining[i].node.embedding?.vector),
        );
      }
    }

    return reranked;
  }

  /// Returns the cosine similarity of [a] and [b].
  ///
  /// Yields `0.0` when either vector is null, when the lengths differ, or
  /// when either vector has zero magnitude.
  double _cosineSimilarity(List<double>? a, List<double>? b) {
    if (a == null || b == null || a.length != b.length) {
      return 0.0;
    }

    var dotProduct = 0.0;
    var normA = 0.0;
    var normB = 0.0;

    for (var i = 0; i < a.length; i++) {
      dotProduct += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }

    if (normA == 0.0 || normB == 0.0) {
      return 0.0;
    }

    return dotProduct / (sqrt(normA) * sqrt(normB));
  }
}
Loading
Loading