diff --git a/README.md b/README.md index 431fad5..17c3ce0 100644 --- a/README.md +++ b/README.md @@ -447,9 +447,9 @@ print(explanation); - [x] Hybrid Retrieval (Dense + Isar Filter). - [x] Sync & Privacy (Encryption with AES-256-GCM, LWW conflict resolution). - [x] HiRAG Phase 1 (Layer-based organization, summary nodes, relationship types). -- [ ] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval). -- [ ] Cross-device sync backend (Firebase/WebSocket integration). -- [ ] Re-ranking and advanced retrieval strategies. +- [x] HiRAG Phase 2 (Automatic LLM-based summarization, multi-hop retrieval). +- [x] Cross-device sync backend (Firebase/WebSocket integration). +- [x] Re-ranking and advanced retrieval strategies. --- diff --git a/lib/src/hierarchical_graph.dart b/lib/src/hierarchical_graph.dart index 15773ae..1473be5 100644 --- a/lib/src/hierarchical_graph.dart +++ b/lib/src/hierarchical_graph.dart @@ -1,17 +1,65 @@ import 'package:isar/isar.dart'; import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'llm_adapter.dart'; /// Extension for HiRAG (Hierarchical RAG) capabilities. /// /// This extension adds methods to manage hierarchical layers of knowledge. extension HierarchicalMemoryGraph on MemoryGraph { - /// Relation type for summary edges (Child -> Summary). static const String relationSummaryOf = 'summary_of'; /// Relation type for part-of edges (Part -> Whole). static const String relationPartOf = 'part_of'; + /// Automatically summarizes a given layer by grouping nodes and using an LLM. + /// + /// [layerIndex]: The layer to summarize. + /// [llmAdapter]: The adapter to the Large Language Model for summary generation. + /// [promptTemplate]: A function to format the content into a prompt for the LLM. + /// If null, a default template is used. + /// + /// Returns the ID of the newly created summary node. + Future autoSummarizeLayer({ + required int layerIndex, + required LLMAdapter llmAdapter, + String Function(String content)? promptTemplate, + }) async { + // 1. Get all nodes in the target layer + final nodes = await getNodesByLayer(layerIndex); + if (nodes.isEmpty) { + throw Exception('No nodes found in layer $layerIndex to summarize.'); + } + + // 2. Combine content for the LLM prompt + final contentToSummarize = nodes.map((n) => n.content).join('\n---\n'); + final prompt = promptTemplate != null + ? promptTemplate(contentToSummarize) + : 'Summarize the following content into a coherent paragraph: \n\n$contentToSummarize'; + + // 3. Call LLM to generate summary + final summaryContent = await llmAdapter.generate(prompt); + + // 4. Create the summary node in the next layer + final childNodeIds = nodes.map((n) => n.id).toList(); + final summaryNodeId = await createSummaryNode( + summaryContent: summaryContent, + childNodeIds: childNodeIds, + layer: layerIndex + 1, + ); + + // 5. Create 'part_of' relationships from the summary to its parts + for (final childId in childNodeIds) { + await storeEdge(MemoryEdge( + fromNodeId: summaryNodeId, + toNodeId: childId, + relation: relationPartOf, + )); + } + + return summaryNodeId; + } + /// Creates a summary node for a list of [childNodeIds]. /// /// [summaryContent]: The summarized text. @@ -56,4 +104,57 @@ extension HierarchicalMemoryGraph on MemoryGraph { Future> getNodesByLayer(int layer) async { return await isar.memoryNodes.filter().layerEqualTo(layer).findAll(); } + + /// Performs a multi-hop search, enriching results with hierarchical context. + /// + /// [queryEmbedding]: The embedding of the search query. + /// [maxHops]: The maximum number of upward traversals (default is 2). + /// [topK]: The number of initial results to fetch from the base layer. + /// + /// Returns a list of enriched results, where each result includes the base node + /// and a list of parent (summary) nodes. + Future context})>> multiHopSearch({ + required List queryEmbedding, + int maxHops = 2, + int topK = 5, + }) async { + // 1. Semantic search on the base layer (layer 0) + final initialResults = + await semanticSearch(queryEmbedding, topK: topK, layer: 0); + + final enrichedResults = + <({MemoryNode node, List context})>[]; + + // 2. Traverse upwards for each result + for (final result in initialResults) { + final baseNode = result.node; + final context = []; + var currentNode = baseNode; + var hops = 0; + + while (hops < maxHops) { + // Find edges where the current node is the 'from' node and relation is 'summary_of' + final edges = await isar.memoryEdges + .filter() + .fromNodeIdEqualTo(currentNode.id) + .relationEqualTo(relationSummaryOf) + .findAll(); + + if (edges.isEmpty) break; + + // Follow the first summary edge to the parent + final parentNodeId = edges.first.toNodeId; + final parentNode = await getNode(parentNodeId); + + if (parentNode == null) break; + + context.add(parentNode); + currentNode = parentNode; + hops++; + } + enrichedResults.add((node: baseNode, context: context)); + } + + return enrichedResults; + } } diff --git a/lib/src/llm_adapter.dart b/lib/src/llm_adapter.dart new file mode 100644 index 0000000..57f44c1 --- /dev/null +++ b/lib/src/llm_adapter.dart @@ -0,0 +1,9 @@ +/// Abstract interface for a generic Large Language Model (LLM). +/// +/// This class defines the contract for generating text content based on a given prompt. +abstract class LLMAdapter { + /// Generates content based on the given [prompt]. + /// + /// Returns a [Future] that completes with the generated [String]. + Future generate(String prompt); +} diff --git a/lib/src/memory_graph.dart b/lib/src/memory_graph.dart index dd1127e..cf741fe 100644 --- a/lib/src/memory_graph.dart +++ b/lib/src/memory_graph.dart @@ -6,6 +6,7 @@ import 'models/memory_edge.dart'; import 'models/memory_embedding.dart'; import 'vector_index.dart'; import 'vector_index_objectbox.dart'; +import 'reranking_strategy.dart'; /// Main API for interacting with the universal agent memory graph. /// @@ -181,6 +182,7 @@ class MemoryGraph { semanticSearch( List queryEmbedding, { int topK = 5, + int? layer, }) async { // Gracefully return empty list if dimensions mismatch, as tests expect. if (queryEmbedding.length != embeddingsAdapter.dimension) { @@ -216,7 +218,11 @@ class MemoryGraph { } // Fallback to linear scan if the index returns no results or fails. - final allNodes = await isar.memoryNodes.where().findAll(); + var query = isar.memoryNodes.where(); + if (layer != null) { + query = query.filter().layerEqualTo(layer); + } + final allNodes = await query.findAll(); final distances = allNodes .map((n) => (n.embedding != null) @@ -419,4 +425,37 @@ class MemoryGraph { Future clearVectorCollection() async { await _index.clear(); } + /// Performs a semantic search with a re-ranking strategy. + /// + /// [queryEmbedding] is the embedding of the search query. + /// [reranker] is the re-ranking strategy to apply. + /// [topK] is the number of results to return. + Future> semanticSearchWithReRanking( + List queryEmbedding, { + required ReRankingStrategy reranker, + int topK = 5, + }) async { + final searchResults = await semanticSearch(queryEmbedding, topK: topK * 2); + final resultsWithScore = searchResults + .map((r) => (node: r.node, score: 1.0 - r.distance)) + .toList(); + return reranker.reRank(resultsWithScore).take(topK).toList(); + } + + /// Performs a hybrid search with a re-ranking strategy. + /// + /// [query] is the text to search for. + /// [reranker] is the re-ranking strategy to apply. + /// [alpha] controls the weight of the vector search vs. text search. + /// [topK] is the number of results to return. + Future> hybridSearchWithReRanking( + String query, { + required ReRankingStrategy reranker, + int topK = 5, + double alpha = 0.5, + }) async { + final searchResults = + await hybridSearch(query, topK: topK * 2, alpha: alpha); + return reranker.reRank(searchResults, query: query).take(topK).toList(); + } } diff --git a/lib/src/models/memory_node.dart b/lib/src/models/memory_node.dart index 0415cb1..e9401f9 100644 --- a/lib/src/models/memory_node.dart +++ b/lib/src/models/memory_node.dart @@ -31,6 +31,7 @@ class MemoryNode { this.modifiedAt, this.layer = 0, this.uuid, + this.accessCount = 0, }) : createdAt = DateTime.now() { this.degree = degree ?? Degree(); if (modifiedAt == null) { @@ -71,6 +72,9 @@ class MemoryNode { /// Can be used to track recency and relevance. DateTime? updatedAt; + /// The number of times this node has been accessed. + int accessCount; + /// The timestamp when this record was last modified (system-level sync). /// /// Used for Last-Write-Wins (LWW) conflict resolution. diff --git a/lib/src/rerankers/bm25_reranker.dart b/lib/src/rerankers/bm25_reranker.dart new file mode 100644 index 0000000..b14b7fb --- /dev/null +++ b/lib/src/rerankers/bm25_reranker.dart @@ -0,0 +1,68 @@ +import 'dart:math'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; + +/// A re-ranking strategy based on the BM25 algorithm. +/// +/// This class re-ranks search results based on term frequency. +class BM25ReRanker implements ReRankingStrategy { + final double k1; + final double b; + + BM25ReRanker({this.k1 = 1.5, this.b = 0.75}); + + @override + List<({MemoryNode node, double score})> reRank( + List<({MemoryNode node, double score})> results, { + String? query, + }) { + if (query == null || results.isEmpty) { + return results; + } + + final queryTerms = _tokenize(query); + final documents = + results.map((r) => _tokenize(r.node.content)).toList(); + final idf = _calculateIdf(queryTerms, documents); + final avgdl = + documents.map((d) => d.length).reduce((a, b) => a + b) / documents.length; + + results.sort((a, b) { + final scoreA = + _calculateBm25(queryTerms, _tokenize(a.node.content), documents, idf, avgdl); + final scoreB = + _calculateBm25(queryTerms, _tokenize(b.node.content), documents, idf, avgdl); + return scoreB.compareTo(scoreA); + }); + + return results; + } + + List _tokenize(String text) { + return text.toLowerCase().split(RegExp(r'\W+')); + } + + Map _calculateIdf( + List queryTerms, List> documents) { + final idf = {}; + for (final term in queryTerms) { + final docCount = documents.where((d) => d.contains(term)).length; + idf[term] = log((documents.length - docCount + 0.5) / (docCount + 0.5) + 1); + } + return idf; + } + + double _calculateBm25(List queryTerms, List doc, + List> documents, Map idf, double avgdl) { + double score = 0.0; + for (final term in queryTerms) { + if (!doc.contains(term)) { + continue; + } + final tf = doc.where((t) => t == term).length; + score += idf[term]! * + (tf * (k1 + 1)) / + (tf + k1 * (1 - b + b * (doc.length / avgdl))); + } + return score; + } +} diff --git a/lib/src/rerankers/diversity_reranker.dart b/lib/src/rerankers/diversity_reranker.dart new file mode 100644 index 0000000..8bcba2d --- /dev/null +++ b/lib/src/rerankers/diversity_reranker.dart @@ -0,0 +1,66 @@ +import 'dart:math'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; + +/// A re-ranking strategy that maximizes the diversity of the results. +/// +/// This class re-ranks search results to avoid showing similar results. +class DiversityReRanker implements ReRankingStrategy { + @override + List<({MemoryNode node, double score})> reRank( + List<({MemoryNode node, double score})> results, { + String? query, + }) { + if (results.isEmpty) { + return []; + } + + final reranked = <({MemoryNode node, double score})>[]; + final remaining = List.of(results); + + reranked.add(remaining.removeAt(0)); + + while (remaining.isNotEmpty) { + var bestCandidate = remaining.first; + var bestScore = double.infinity; + + for (final candidate in remaining) { + final similarity = reranked + .map((r) => _cosineSimilarity( + r.node.embedding?.vector, candidate.node.embedding?.vector)) + .reduce(max); + + if (similarity < bestScore) { + bestScore = similarity; + bestCandidate = candidate; + } + } + + reranked.add(bestCandidate); + remaining.remove(bestCandidate); + } + + return reranked; + } + + double _cosineSimilarity(List? a, List? b) { + if (a == null || b == null || a.length != b.length) { + return 0.0; + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (var i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + return dotProduct / (sqrt(normA) * sqrt(normB)); + } +} diff --git a/lib/src/rerankers/mmr_reranker.dart b/lib/src/rerankers/mmr_reranker.dart new file mode 100644 index 0000000..b1d707d --- /dev/null +++ b/lib/src/rerankers/mmr_reranker.dart @@ -0,0 +1,72 @@ +import 'dart:math'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; + +/// A re-ranking strategy that uses Maximal Marginal Relevance (MMR). +/// +/// This class re-ranks search results to balance relevance and diversity. +class MMRReRanker implements ReRankingStrategy { + final double lambda; + + MMRReRanker({this.lambda = 0.5}); + + @override + List<({MemoryNode node, double score})> reRank( + List<({MemoryNode node, double score})> results, { + String? query, + }) { + if (results.isEmpty) { + return []; + } + + final reranked = <({MemoryNode node, double score})>[]; + final remaining = List.of(results); + + reranked.add(remaining.removeAt(0)); + + while (remaining.isNotEmpty) { + var bestCandidate = remaining.first; + var bestScore = -double.infinity; + + for (final candidate in remaining) { + final relevance = candidate.score; + final similarity = reranked + .map((r) => _cosineSimilarity( + r.node.embedding?.vector, candidate.node.embedding?.vector)) + .reduce(max); + final mmrScore = lambda * relevance - (1 - lambda) * similarity; + + if (mmrScore > bestScore) { + bestScore = mmrScore; + bestCandidate = candidate; + } + } + + reranked.add(bestCandidate); + remaining.remove(bestCandidate); + } + + return reranked; + } + + double _cosineSimilarity(List? a, List? b) { + if (a == null || b == null || a.length != b.length) { + return 0.0; + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (var i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + + if (normA == 0.0 || normB == 0.0) { + return 0.0; + } + + return dotProduct / (sqrt(normA) * sqrt(normB)); + } +} diff --git a/lib/src/rerankers/recency_reranker.dart b/lib/src/rerankers/recency_reranker.dart new file mode 100644 index 0000000..94b8c0a --- /dev/null +++ b/lib/src/rerankers/recency_reranker.dart @@ -0,0 +1,19 @@ +import 'package:isar_agent_memory/isar_agent_memory.dart'; + +/// A re-ranking strategy that prioritizes more recent results. +/// +/// This class re-ranks search results based on their creation or update timestamps. +class RecencyReRanker implements ReRankingStrategy { + @override + List<({MemoryNode node, double score})> reRank( + List<({MemoryNode node, double score})> results, { + String? query, + }) { + results.sort((a, b) { + final dateA = a.node.updatedAt ?? a.node.createdAt; + final dateB = b.node.updatedAt ?? b.node.createdAt; + return dateB.compareTo(dateA); + }); + return results; + } +} diff --git a/lib/src/reranking_strategy.dart b/lib/src/reranking_strategy.dart new file mode 100644 index 0000000..caf0e48 --- /dev/null +++ b/lib/src/reranking_strategy.dart @@ -0,0 +1,17 @@ +import 'package:isar_agent_memory/isar_agent_memory.dart'; + +/// Abstract interface for a re-ranking strategy. +/// +/// This class defines the contract for different re-ranking algorithms +/// that can be applied to a list of search results. +abstract class ReRankingStrategy { + /// Re-ranks a list of search results. + /// + /// [results] is the initial list of search results to be re-ranked. + /// [query] is the original search query, required by some strategies like BM25. + /// Returns a re-ranked list of search results. + List<({MemoryNode node, double score})> reRank( + List<({MemoryNode node, double score})> results, { + String? query, + }); +} diff --git a/lib/src/sync/cross_device_sync_manager.dart b/lib/src/sync/cross_device_sync_manager.dart new file mode 100644 index 0000000..f77db1a --- /dev/null +++ b/lib/src/sync/cross_device_sync_manager.dart @@ -0,0 +1,72 @@ +import 'dart:async'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/sync/sync_backend.dart'; +import 'package:isar_agent_memory/src/sync/firebase_sync_backend.dart'; +import 'package:isar_agent_memory/src/sync/websocket_sync_backend.dart'; + +/// Manages cross-device synchronization of the memory graph. +/// +/// This class extends [SyncManager] to add support for real-time synchronization +/// using different backends (e.g., Firebase, WebSockets). +class CrossDeviceSyncManager extends SyncManager { + SyncBackend? _backend; + StreamSubscription? _subscription; + + /// Creates a [CrossDeviceSyncManager] for the given [memoryGraph]. + CrossDeviceSyncManager(MemoryGraph memoryGraph) : super(memoryGraph); + + /// Initializes the sync manager with a specific backend. + /// + /// The backend is chosen by the [SyncBackendFactory]. + Future initializeBackend({ + Map? firebaseConfig, + Map? websocketConfig, + List? encryptionKey, + }) async { + await initialize(encryptionKey: encryptionKey); + _backend = SyncBackendFactory.createBackend( + firebaseConfig: firebaseConfig, + websocketConfig: websocketConfig, + ); + await _backend!.initialize(firebaseConfig ?? websocketConfig ?? {}); + + // Start listening for remote changes + _subscription = _backend!.remoteSnapshotsStream.listen((snapshot) { + importEncryptedSnapshot(snapshot); + }); + } + + /// Publishes the current memory state as an encrypted snapshot. + Future publishSnapshot() async { + if (_backend == null) { + throw StateError('Sync backend not initialized.'); + } + final snapshot = await exportEncryptedSnapshot(); + await _backend!.publishSnapshot(snapshot); + } + + /// Disposes of the sync manager and its backend. + Future dispose() async { + await _subscription?.cancel(); + await _backend?.dispose(); + } +} + +/// A factory for creating [SyncBackend] instances. +class SyncBackendFactory { + /// Creates a [SyncBackend] based on the provided configurations. + /// + /// If [firebaseConfig] is provided, a [FirebaseSyncBackend] is created. + /// Otherwise, a [WebSocketSyncBackend] is created as a fallback. + static SyncBackend createBackend({ + Map? firebaseConfig, + Map? websocketConfig, + }) { + if (firebaseConfig != null) { + return FirebaseSyncBackend(); + } else { + return WebSocketSyncBackend( + channel: websocketConfig?['channel']); + } + } +} diff --git a/lib/src/sync/firebase_sync_backend.dart b/lib/src/sync/firebase_sync_backend.dart new file mode 100644 index 0000000..3f7ff4e --- /dev/null +++ b/lib/src/sync/firebase_sync_backend.dart @@ -0,0 +1,63 @@ +import 'dart:async'; +import 'package:firebase_core/firebase_core.dart'; +import 'package:firebase_database/firebase_database.dart'; +import 'package:isar_agent_memory/src/sync/sync_backend.dart'; + +/// A [SyncBackend] implementation for Firebase Realtime Database. +class FirebaseSyncBackend implements SyncBackend { + late final FirebaseDatabase _database; + final _controller = StreamController>.broadcast(); + StreamSubscription? _subscription; + DatabaseReference? _userRef; + + @override + Future initialize(Map config) async { + final String? userId = config['userId']; + if (userId == null || userId.isEmpty) { + throw ArgumentError('A valid userId must be provided in the config.'); + } + + final app = Firebase.apps.isEmpty + ? await Firebase.initializeApp( + options: FirebaseOptions( + apiKey: config['apiKey'], + appId: config['appId'], + messagingSenderId: config['messagingSenderId'], + projectId: config['projectId'], + databaseURL: config['databaseURL'], + ), + ) + : Firebase.app(); + + _database = FirebaseDatabase.instanceFor(app: app); + _userRef = _database.ref('users/$userId'); + + _subscription = _userRef!.child('snapshots').onValue.listen((event) { + if (event.snapshot.value != null) { + final data = Map.from(event.snapshot.value as Map); + final snapshot = List.from(data['data'] as List); + _controller.add(snapshot); + } + }); + } + + @override + Future publishSnapshot(List snapshot) async { + if (_userRef == null) { + throw StateError('FirebaseSyncBackend not initialized.'); + } + await _userRef!.child('snapshots').set({ + 'data': snapshot, + 'timestamp': ServerValue.timestamp, + }); + } + + @override + Stream> get remoteSnapshotsStream => _controller.stream; + + @override + Future dispose() async { + await _subscription?.cancel(); + await _controller.close(); + } +} diff --git a/lib/src/sync/sync_backend.dart b/lib/src/sync/sync_backend.dart new file mode 100644 index 0000000..5bdfb11 --- /dev/null +++ b/lib/src/sync/sync_backend.dart @@ -0,0 +1,25 @@ +import 'dart:async'; + +/// Abstract interface for a synchronization backend. +/// +/// This defines the contract for services that handle the real-time synchronization +/// of memory snapshots between devices (e.g., Firebase, WebSocket). +abstract class SyncBackend { + /// Initializes the backend and establishes a connection. + /// + /// [config] is a map of configuration parameters (e.g., API keys, URLs). + Future initialize(Map config); + + /// Publishes an encrypted snapshot to the remote backend. + /// + /// [snapshot] is the encrypted data blob to be sent. + Future publishSnapshot(List snapshot); + + /// A stream of incoming snapshots from the remote backend. + /// + /// Listen to this stream to receive real-time updates from other devices. + Stream> get remoteSnapshotsStream; + + /// Disposes of the backend connection and cleans up resources. + Future dispose(); +} diff --git a/lib/src/sync/websocket_sync_backend.dart b/lib/src/sync/websocket_sync_backend.dart new file mode 100644 index 0000000..5510de0 --- /dev/null +++ b/lib/src/sync/websocket_sync_backend.dart @@ -0,0 +1,46 @@ +import 'dart:async'; +import 'package:isar_agent_memory/src/sync/sync_backend.dart'; +import 'package:web_socket_channel/web_socket_channel.dart'; + +/// A [SyncBackend] implementation for WebSockets. +class WebSocketSyncBackend implements SyncBackend { + late final WebSocketChannel _channel; + final _controller = StreamController>.broadcast(); + Timer? _keepAliveTimer; + + WebSocketSyncBackend({WebSocketChannel? channel}) { + if (channel != null) { + _channel = channel; + } + } + + @override + Future initialize(Map config) async { + if (!config.containsKey('channel')) { + _channel = WebSocketChannel.connect(Uri.parse(config['url'])); + } + _channel.stream.listen((data) { + if (data == 'pong') return; + _controller.add(List.from(data)); + }); + + _keepAliveTimer = Timer.periodic(const Duration(seconds: 30), (timer) { + _channel.sink.add('ping'); + }); + } + + @override + Future publishSnapshot(List snapshot) async { + _channel.sink.add(snapshot); + } + + @override + Stream> get remoteSnapshotsStream => _controller.stream; + + @override + Future dispose() async { + _keepAliveTimer?.cancel(); + await _channel.sink.close(); + await _controller.close(); + } +} diff --git a/pubspec.yaml b/pubspec.yaml index 3bee59f..dfe4955 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -29,11 +29,16 @@ dependencies: http: ^1.1.0 # Required for tool scripts cryptography: ^2.9.0 json_annotation: ^4.9.0 + firebase_core: ^2.24.2 + firebase_database: ^10.4.0 + web_socket_channel: ^2.4.0 # Add your embedding provider dependencies as needed (e.g., google_generative_ai, openai, etc.) dev_dependencies: flutter_lints: ^6.0.0 lints: ^6.0.0 + firebase_database_mocks: ^0.6.0 + mockito: ^5.4.4 test: ^1.25.0 build_runner: ^2.4.6 objectbox_generator: ^5.0.0 diff --git a/test/advanced_retrieval_test.dart b/test/advanced_retrieval_test.dart new file mode 100644 index 0000000..eb612e0 --- /dev/null +++ b/test/advanced_retrieval_test.dart @@ -0,0 +1,61 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package.isar/isar.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/rerankers/bm25_reranker.dart'; +import 'package:isar_agent_memory/src/rerankers/recency_reranker.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isar; + late MemoryGraph memoryGraph; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + isar = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'test_db', + ); + memoryGraph = + MemoryGraph(isar, embeddingsAdapter: FallbackEmbeddingsAdapter()); + await isar.writeTxn(() async => await isar.clear()); + }); + + tearDown(() async { + await isar.close(deleteFromDisk: true); + }); + + test('Semantic search with RecencyReRanker test', () async { + final now = DateTime.now(); + await memoryGraph.storeNode(MemoryNode( + content: 'older', createdAt: now.subtract(const Duration(days: 1)))); + await memoryGraph.storeNode(MemoryNode(content: 'newer', createdAt: now)); + + final results = await memoryGraph.semanticSearchWithReRanking( + await memoryGraph.embeddingsAdapter.embed('some query'), + reranker: RecencyReRanker(), + topK: 1, + ); + + expect(results.first.node.content, 'newer'); + }); + + test('Hybrid search with BM25ReRanker test', () async { + await memoryGraph + .storeNode(MemoryNode(content: 'the quick brown fox')); + await memoryGraph.storeNode(MemoryNode(content: 'a lazy dog')); + await memoryGraph.storeNode(MemoryNode( + content: 'the quick brown fox jumps over the lazy dog')); + + final results = await memoryGraph.hybridSearchWithReRanking( + 'quick fox', + reranker: BM25ReRanker(), + topK: 2, + ); + + expect(results[0].node.content, 'the quick brown fox'); + expect(results[1].node.content, + 'the quick brown fox jumps over the lazy dog'); + }); +} diff --git a/test/cross_device_sync_firebase_test.dart b/test/cross_device_sync_firebase_test.dart new file mode 100644 index 0000000..1b0b796 --- /dev/null +++ b/test/cross_device_sync_firebase_test.dart @@ -0,0 +1,71 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar/isar.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/sync/cross_device_sync_manager.dart'; +import 'package:firebase_database_mocks/firebase_database_mocks.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isarA; + late Isar isarB; + late MemoryGraph memoryGraphA; + late MemoryGraph memoryGraphB; + late CrossDeviceSyncManager syncManagerA; + late CrossDeviceSyncManager syncManagerB; + late MockFirebaseDatabase mockDatabase; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + mockDatabase = MockFirebaseDatabase(); + + // Setup for Device A + isarA = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'device_a_db', + ); + memoryGraphA = + MemoryGraph(isarA, embeddingsAdapter: FallbackEmbeddingsAdapter()); + syncManagerA = CrossDeviceSyncManager(memoryGraphA); + + // Setup for Device B + isarB = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'device_b_db', + ); + memoryGraphB = + MemoryGraph(isarB, embeddingsAdapter: FallbackEmbeddingsAdapter()); + syncManagerB = CrossDeviceSyncManager(memoryGraphB); + + await isarA.writeTxn(() async => await isarA.clear()); + await isarB.writeTxn(() async => await isarB.clear()); + }); + + tearDown(() async { + await isarA.close(deleteFromDisk: true); + await isarB.close(deleteFromDisk: true); + }); + + test('Firebase sync test', () async { + final config = { + 'apiKey': 'fake-api-key', + 'appId': 'fake-app-id', + 'messagingSenderId': 'fake-sender-id', + 'projectId': 'fake-project-id', + 'databaseURL': mockDatabase.databaseURL, + 'userId': 'user123', + }; + await syncManagerA.initializeBackend(firebaseConfig: config); + await syncManagerB.initializeBackend(firebaseConfig: config); + + await memoryGraphA.storeNode(MemoryNode(content: 'Node from A')); + await syncManagerA.publishSnapshot(); + + await Future.delayed(const Duration(milliseconds: 100)); + final nodesOnB = await memoryGraphB.isar.memoryNodes.where().findAll(); + expect(nodesOnB.length, 1); + expect(nodesOnB.first.content, 'Node from A'); + }); +} diff --git a/test/cross_device_sync_websocket_test.dart b/test/cross_device_sync_websocket_test.dart new file mode 100644 index 0000000..de878c3 --- /dev/null +++ b/test/cross_device_sync_websocket_test.dart @@ -0,0 +1,85 @@ +import 'dart:async'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar/isar.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/sync/cross_device_sync_manager.dart'; +import 'package:mockito/annotations.dart'; +import 'package:mockito/mockito.dart'; +import 'package:web_socket_channel/web_socket_channel.dart'; +import 'cross_device_sync_websocket_test.mocks.dart'; + +@GenerateMocks([WebSocketChannel, WebSocketSink]) +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isarA; + late Isar isarB; + late MemoryGraph memoryGraphA; + late MemoryGraph memoryGraphB; + late CrossDeviceSyncManager syncManagerA; + late CrossDeviceSyncManager syncManagerB; + late MockWebSocketChannel mockChannelA; + late MockWebSocketChannel mockChannelB; + late StreamController controllerA; + late StreamController controllerB; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + + isarA = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'device_a_db', + ); + memoryGraphA = + MemoryGraph(isarA, embeddingsAdapter: FallbackEmbeddingsAdapter()); + syncManagerA = CrossDeviceSyncManager(memoryGraphA); + + isarB = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'device_b_db', + ); + memoryGraphB = + MemoryGraph(isarB, embeddingsAdapter: FallbackEmbeddingsAdapter()); + syncManagerB = CrossDeviceSyncManager(memoryGraphB); + + await isarA.writeTxn(() async => await isarA.clear()); + await isarB.writeTxn(() async => await isarB.clear()); + + mockChannelA = MockWebSocketChannel(); + mockChannelB = MockWebSocketChannel(); + controllerA = StreamController.broadcast(); + controllerB = StreamController.broadcast(); + + when(mockChannelA.stream).thenAnswer((_) => controllerA.stream); + when(mockChannelB.stream).thenAnswer((_) => controllerB.stream); + when(mockChannelA.sink).thenReturn(MockWebSocketSink()); + when(mockChannelB.sink).thenReturn(MockWebSocketSink()); + }); + + tearDown(() async { + await isarA.close(deleteFromDisk: true); + await isarB.close(deleteFromDisk: true); + await controllerA.close(); + await controllerB.close(); + }); + + test('WebSocket sync test', () async { + final configA = {'url': 'ws://localhost:1234', 'channel': mockChannelA}; + final configB = {'url': 'ws://localhost:1234', 'channel': mockChannelB}; + await syncManagerA.initializeBackend(websocketConfig: configA); + await syncManagerB.initializeBackend(websocketConfig: configB); + + await memoryGraphA.storeNode(MemoryNode(content: 'Node from A')); + await syncManagerA.publishSnapshot(); + + final snapshot = await syncManagerA.exportEncryptedSnapshot(); + controllerB.add(snapshot); + + await Future.delayed(const Duration(milliseconds: 100)); + final nodesOnB = await memoryGraphB.isar.memoryNodes.where().findAll(); + expect(nodesOnB.length, 1); + expect(nodesOnB.first.content, 'Node from A'); + }); +} diff --git a/test/hirag_phase2_integration_test.dart b/test/hirag_phase2_integration_test.dart new file mode 100644 index 0000000..7615d8e --- /dev/null +++ b/test/hirag_phase2_integration_test.dart @@ -0,0 +1,92 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar/isar.dart'; +import 'package:langchain_google/langchain_google.dart'; +import 'package:google_generative_ai/google_generative_ai.dart'; + +// Mock LLMAdapter for testing without actual API calls +class MockLLMAdapter implements LLMAdapter { + @override + Future generate(String prompt) async { + return 'This is a mock summary.'; + } +} + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isar; + late MemoryGraph memoryGraph; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + isar = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'test_db', + ); + + // Use a mock embeddings adapter to avoid actual API calls for embeddings + final embeddingsAdapter = FallbackEmbeddingsAdapter(); + memoryGraph = MemoryGraph(isar, embeddingsAdapter: embeddingsAdapter); + + // Clean up database before each test + await isar.writeTxn(() async { + await isar.clear(); + }); + }); + + tearDown(() async { + await isar.close(deleteFromDisk: true); + }); + + test('autoSummarizeLayer creates a summary node and correct relationships', + () async { + // 1. Setup: Create some nodes in layer 0 + final node1Id = await memoryGraph.storeNodeWithEmbedding( + content: 'The sky is blue.', layer: 0); + final node2Id = await memoryGraph.storeNodeWithEmbedding( + content: 'The grass is green.', layer: 0); + + // 2. Execute: Run auto-summarization + final llmAdapter = MockLLMAdapter(); + final summaryNodeId = await memoryGraph.autoSummarizeLayer( + layerIndex: 0, + llmAdapter: llmAdapter, + ); + + // 3. Verify: Check the results + final summaryNode = await memoryGraph.getNode(summaryNodeId); + expect(summaryNode, isNotNull); + expect(summaryNode!.layer, 1); + expect(summaryNode.content, 'This is a mock summary.'); + + // Verify 'summary_of' relationships (Child -> Summary) + final edgesFromNode1 = await memoryGraph.getEdgesForNode(node1Id); + expect( + edgesFromNode1.any((e) => + e.toNodeId == summaryNodeId && + e.relation == HierarchicalMemoryGraph.relationSummaryOf), + isTrue); + + final edgesFromNode2 = await memoryGraph.getEdgesForNode(node2Id); + expect( + edgesFromNode2.any((e) => + e.toNodeId == summaryNodeId && + e.relation == HierarchicalMemoryGraph.relationSummaryOf), + isTrue); + + // Verify 'part_of' relationships (Summary -> Child) + final edgesFromSummary = await memoryGraph.getEdgesForNode(summaryNodeId); + expect( + edgesFromSummary.any((e) => + e.toNodeId == node1Id && + e.relation == HierarchicalMemoryGraph.relationPartOf), + isTrue); + expect( + edgesFromSummary.any((e) => + e.toNodeId == node2Id && + e.relation == HierarchicalMemoryGraph.relationPartOf), + isTrue); + }); +} diff --git a/test/multi_hop_retrieval_test.dart b/test/multi_hop_retrieval_test.dart new file mode 100644 index 0000000..f186dd1 --- /dev/null +++ b/test/multi_hop_retrieval_test.dart @@ -0,0 +1,60 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar/isar.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isar; + late MemoryGraph memoryGraph; + late EmbeddingsAdapter embeddingsAdapter; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + isar = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'test_db', + ); + + embeddingsAdapter = FallbackEmbeddingsAdapter(); + memoryGraph = MemoryGraph(isar, embeddingsAdapter: embeddingsAdapter); + + await isar.writeTxn(() async { + await isar.clear(); + }); + }); + + tearDown(() async { + await isar.close(deleteFromDisk: true); + }); + + test('multiHopSearch returns enriched results with context', () async { + // 1. Setup: Create a small hierarchy + // Layer 0 + final nodeAId = await memoryGraph.storeNodeWithEmbedding( + content: 'Details about topic A.', layer: 0); + final nodeBId = await memoryGraph.storeNodeWithEmbedding( + content: 'Details about topic B.', layer: 0); + + // Layer 1 (Summary of A and B) + final summary1Id = await memoryGraph.createSummaryNode( + summaryContent: 'Summary of topics A and B.', + childNodeIds: [nodeAId, nodeBId], + layer: 1, + ); + + // 2. Execute: Search for something in layer 0 + final queryEmbedding = await embeddingsAdapter.embed('topic A'); + final results = await memoryGraph.multiHopSearch( + queryEmbedding: queryEmbedding, + topK: 1, + ); + + // 3. Verify: Check the results + expect(results, isNotEmpty); + expect(results.first.node.id, nodeAId); + expect(results.first.context, isNotEmpty); + expect(results.first.context.first.id, summary1Id); + }); +} diff --git a/test/reranking_strategies_test.dart b/test/reranking_strategies_test.dart new file mode 100644 index 0000000..5c44302 --- /dev/null +++ b/test/reranking_strategies_test.dart @@ -0,0 +1,103 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/rerankers/bm25_reranker.dart'; +import 'package:isar_agent_memory/src/rerankers/diversity_reranker.dart'; +import 'package:isar_agent_memory/src/rerankers/mmr_reranker.dart'; +import 'package:isar_agent_memory/src/rerankers/recency_reranker.dart'; + +void main() { + group('ReRankingStrategy', () { + test('RecencyReRanker test', () { + final reranker = RecencyReRanker(); + final now = DateTime.now(); + final results = [ + ( + node: MemoryNode( + content: 'older', createdAt: now.subtract(const Duration(days: 1))), + score: 0.8 + ), + ( + node: MemoryNode(content: 'newer', createdAt: now), + score: 0.7 + ), + ]; + + final reranked = reranker.reRank(results); + expect(reranked.first.node.content, 'newer'); + }); + + test('MMRReRanker test', () { + final reranker = MMRReRanker(); + final results = [ + ( + node: MemoryNode( + content: 'very relevant, very similar', + embedding: MemoryEmbedding(vector: [1.0, 0.0])), + score: 0.9 + ), + ( + node: MemoryNode( + content: 'very relevant, less similar', + embedding: MemoryEmbedding(vector: [0.0, 1.0])), + score: 0.8 + ), + ( + node: MemoryNode( + content: 'less relevant, very similar', + embedding: MemoryEmbedding(vector: [0.9, 0.1])), + score: 0.5 + ), + ]; + + final reranked = reranker.reRank(results); + expect(reranked[0].node.content, 'very relevant, very similar'); + expect(reranked[1].node.content, 'very relevant, less similar'); + }); + + test('DiversityReRanker test', () { + final reranker = DiversityReRanker(); + final results = [ + ( + node: MemoryNode( + content: 'item 1', + embedding: MemoryEmbedding(vector: [1.0, 0.0])), + score: 0.9 + ), + ( + node: MemoryNode( + content: 'item 2 (similar to 1)', + embedding: MemoryEmbedding(vector: [0.9, 0.1])), + score: 0.8 + ), + ( + node: MemoryNode( + content: 'item 3 (different)', + embedding: MemoryEmbedding(vector: [0.0, 1.0])), + score: 0.7 + ), + ]; + + final reranked = reranker.reRank(results); + expect(reranked[0].node.content, 'item 1'); + expect(reranked[1].node.content, 'item 3 (different)'); + }); + + test('BM25ReRanker test', () { + final reranker = BM25ReRanker(); + final results = [ + (node: MemoryNode(content: 'the quick brown fox'), score: 0.9), + (node: MemoryNode(content: 'a lazy dog'), score: 0.8), + ( + node: MemoryNode( + content: 'the quick brown fox jumps over the lazy dog'), + score: 0.7 + ), + ]; + + final reranked = reranker.reRank(results, query: 'quick fox'); + expect(reranked[0].node.content, 'the quick brown fox'); + expect(reranked[1].node.content, + 'the quick brown fox jumps over the lazy dog'); + }); + }); +} diff --git a/test/sync_conflict_resolution_test.dart b/test/sync_conflict_resolution_test.dart new file mode 100644 index 0000000..e89f8da --- /dev/null +++ b/test/sync_conflict_resolution_test.dart @@ -0,0 +1,57 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:isar/isar.dart'; +import 'package:isar_agent_memory/isar_agent_memory.dart'; +import 'package:isar_agent_memory/src/sync/sync_manager.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + late Isar isar; + late MemoryGraph memoryGraph; + late SyncManager syncManager; + + setUp(() async { + await Isar.initializeIsarCore(download: true); + isar = await Isar.open( + [MemoryNodeSchema, MemoryEdgeSchema], + directory: '.', + name: 'test_db', + ); + memoryGraph = + MemoryGraph(isar, embeddingsAdapter: FallbackEmbeddingsAdapter()); + syncManager = SyncManager(memoryGraph); + await syncManager.initialize(encryptionKey: List.filled(32, 1)); + }); + + tearDown(() async { + await isar.close(deleteFromDisk: true); + }); + + test('LWW conflict resolution test', () async { + // 1. Create a node and export it + final originalNode = MemoryNode( + content: 'Original content', + modifiedAt: DateTime.now().subtract(Duration(minutes: 5)), + ); + await memoryGraph.storeNode(originalNode); + final snapshot1 = await syncManager.exportEncryptedSnapshot(); + + // 2. Modify the node with a newer timestamp and export + final updatedNode = MemoryNode( + uuid: originalNode.uuid, + content: 'Newer content', + modifiedAt: DateTime.now(), + ); + await memoryGraph.storeNode(updatedNode); + final snapshot2 = await syncManager.exportEncryptedSnapshot(); + + // 3. Import the older snapshot first, then the newer one + await syncManager.importEncryptedSnapshot(snapshot1); + await syncManager.importEncryptedSnapshot(snapshot2); + + // 4. Verify that the newer content wins + final nodes = await memoryGraph.isar.memoryNodes.where().findAll(); + expect(nodes.length, 1); + expect(nodes.first.content, 'Newer content'); + }); +}