Skip to content

Commit 67a96f0

Browse files
committed
Add integration coverage for heuristic maintenance
1 parent 0109b41 commit 67a96f0

File tree

2 files changed

+385
-0
lines changed

2 files changed

+385
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using Microsoft.Extensions.AI;
2+
3+
namespace ManagedCode.GraphRag.Tests.Infrastructure;
4+
5+
/// <summary>
/// Deterministic in-memory <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> used by tests.
/// Each input string is looked up in a case-insensitive table of canned vectors; inputs that are
/// blank or not present in the table receive a single shared fallback vector.
/// </summary>
internal sealed class StubEmbeddingGenerator : IEmbeddingGenerator<string, Embedding<float>>
{
    private readonly Dictionary<string, float[]> _vectors;
    private readonly float[] _fallback;

    /// <summary>
    /// Creates the stub from an optional table of canned vectors.
    /// </summary>
    /// <param name="vectors">
    /// Text-to-vector mappings. Keys are compared with <see cref="StringComparer.OrdinalIgnoreCase"/>.
    /// When <see langword="null"/>, the table starts empty.
    /// </param>
    public StubEmbeddingGenerator(IDictionary<string, float[]>? vectors = null)
    {
        _vectors = vectors is null
            ? new Dictionary<string, float[]>(StringComparer.OrdinalIgnoreCase)
            : new Dictionary<string, float[]>(vectors, StringComparer.OrdinalIgnoreCase);

        // Fallback is the first vector enumerated from the table, or a fixed 3-element vector
        // when the table is empty.
        _fallback = _vectors.Values.FirstOrDefault() ?? new[] { 0.5f, 0.5f, 0.5f };
    }

    /// <summary>
    /// Produces one embedding per input value, in input order.
    /// </summary>
    /// <param name="values">Texts to embed.</param>
    /// <param name="options">Accepted for interface compatibility; not read by this stub.</param>
    /// <param name="cancellationToken">Checked before each value is resolved.</param>
    /// <returns>A completed task holding the generated embeddings.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is <see langword="null"/>.</exception>
    /// <exception cref="OperationCanceledException">Cancellation was requested during enumeration.</exception>
    public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
        IEnumerable<string> values,
        EmbeddingGenerationOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(values);

        var embeddings = new List<Embedding<float>>();

        foreach (var value in values)
        {
            cancellationToken.ThrowIfCancellationRequested();
            var vector = ResolveVector(value);

            // The embedding wraps the stored array directly; no copy is made.
            embeddings.Add(new Embedding<float>(new ReadOnlyMemory<float>(vector)));
        }

        return Task.FromResult(new GeneratedEmbeddings<Embedding<float>>(embeddings));
    }

    /// <summary>
    /// Returns this instance when an unkeyed request asks for a type this stub is an instance of;
    /// otherwise returns <see langword="null"/>.
    /// </summary>
    /// <param name="serviceType">The requested service type.</param>
    /// <param name="serviceKey">Optional service key; any non-null key yields <see langword="null"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
    public object? GetService(Type serviceType, object? serviceKey = null)
    {
        ArgumentNullException.ThrowIfNull(serviceType);

        // The stub wraps nothing, so the only service it can hand out is itself.
        return serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null;
    }

    /// <summary>No-op: the stub holds no disposable resources.</summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Looks up the canned vector for <paramref name="value"/>, falling back to the shared default
    /// when the value is null, whitespace, or absent from the table.
    /// </summary>
    private float[] ResolveVector(string? value)
    {
        if (!string.IsNullOrWhiteSpace(value) && _vectors.TryGetValue(value, out var vector))
        {
            return vector;
        }

        return _fallback;
    }
}
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
using System.Collections.Immutable;
2+
using GraphRag;
3+
using GraphRag.Callbacks;
4+
using GraphRag.Community;
5+
using GraphRag.Config;
6+
using GraphRag.Constants;
7+
using GraphRag.Data;
8+
using GraphRag.Entities;
9+
using GraphRag.Indexing.Runtime;
10+
using GraphRag.Indexing.Workflows;
11+
using GraphRag.Relationships;
12+
using GraphRag.Storage;
13+
using ManagedCode.GraphRag.Tests.Infrastructure;
14+
using Microsoft.Extensions.AI;
15+
using Microsoft.Extensions.DependencyInjection;
16+
17+
namespace ManagedCode.GraphRag.Tests.Integration;
18+
19+
public sealed class HeuristicMaintenanceIntegrationTests : IDisposable
20+
{
21+
private readonly string _rootDir;
22+
23+
/// <summary>
/// Allocates a uniquely named scratch directory under the system temp path
/// ("GraphRag/&lt;guid&gt;") and creates it on disk.
/// </summary>
public HeuristicMaintenanceIntegrationTests()
{
    var tempRoot = Path.GetTempPath();
    var runFolder = Guid.NewGuid().ToString("N");

    _rootDir = Path.Combine(tempRoot, "GraphRag", runFolder);
    Directory.CreateDirectory(_rootDir);
}
28+
29+
/// <summary>
/// Seeds four text units (two sharing the text "Alpha Beta"), runs the heuristic maintenance
/// workflow with token limits and semantic deduplication enabled, and verifies that only
/// units "a" and "b" remain, with "a" carrying both document ids.
/// </summary>
[Fact]
public async Task HeuristicMaintenanceWorkflow_AppliesBudgetsAndSemanticDeduplication()
{
    // Builds a text unit with no entity, relationship, or covariate links.
    static TextUnitRecord CreateUnit(string id, string text, int tokenCount, string documentId)
    {
        return new TextUnitRecord
        {
            Id = id,
            Text = text,
            TokenCount = tokenCount,
            DocumentIds = new[] { documentId },
            EntityIds = Array.Empty<string>(),
            RelationshipIds = Array.Empty<string>(),
            CovariateIds = Array.Empty<string>()
        };
    }

    var outputPath = PrepareDirectory("output-maintenance");
    var inputPath = PrepareDirectory("input-maintenance");
    var previousPath = PrepareDirectory("previous-maintenance");

    var seededUnits = new[]
    {
        CreateUnit("a", "Alpha Beta", 40, "doc-1"),
        CreateUnit("b", "Gamma Delta", 30, "doc-1"),
        CreateUnit("c", "Trim me", 30, "doc-1"),
        CreateUnit("d", "Alpha Beta", 35, "doc-2")
    };

    var outputStore = new FilePipelineStorage(outputPath);
    await outputStore.WriteTableAsync(PipelineTableNames.TextUnits, seededUnits);

    // "Alpha Beta" and "Gamma Delta" get orthogonal vectors; any other text resolves to the
    // stub's fallback vector.
    var cannedVectors = new Dictionary<string, float[]>
    {
        ["Alpha Beta"] = new[] { 1f, 0f },
        ["Gamma Delta"] = new[] { 0f, 1f }
    };

    var registrations = new ServiceCollection()
        .AddLogging()
        .AddSingleton<IChatClient>(new TestChatClientFactory().CreateClient())
        .AddSingleton<IEmbeddingGenerator<string, Embedding<float>>>(new StubEmbeddingGenerator(cannedVectors))
        .AddKeyedSingleton<IEmbeddingGenerator<string, Embedding<float>>>("dedupe-model", (resolver, _) => resolver.GetRequiredService<IEmbeddingGenerator<string, Embedding<float>>>())
        .AddGraphRag();

    using var serviceProvider = registrations.BuildServiceProvider();

    var heuristics = new HeuristicMaintenanceConfig
    {
        MaxTokensPerTextUnit = 50,
        MaxDocumentTokenBudget = 80,
        EnableSemanticDeduplication = true,
        SemanticDeduplicationThreshold = 0.75,
        EmbeddingModelId = "dedupe-model"
    };
    var settings = new GraphRagConfig { Heuristics = heuristics };

    var inputStore = new FilePipelineStorage(inputPath);
    var previousStore = new FilePipelineStorage(previousPath);
    var runContext = new PipelineRunContext(
        inputStorage: inputStore,
        outputStorage: outputStore,
        previousStorage: previousStore,
        cache: new StubPipelineCache(),
        callbacks: NoopWorkflowCallbacks.Instance,
        stats: new PipelineRunStats(),
        state: new PipelineState(),
        services: serviceProvider);

    var runMaintenance = HeuristicMaintenanceWorkflow.Create();
    await runMaintenance(settings, runContext, CancellationToken.None);

    var remaining = await outputStore.LoadTableAsync<TextUnitRecord>(PipelineTableNames.TextUnits);
    Assert.Equal(2, remaining.Count);

    // Unit "a" must now reference both documents and report a token count of 35.
    var mergedUnit = Assert.Single(remaining, candidate => candidate.Id == "a");
    Assert.Equal(2, mergedUnit.DocumentIds.Count);
    Assert.Contains("doc-1", mergedUnit.DocumentIds, StringComparer.OrdinalIgnoreCase);
    Assert.Contains("doc-2", mergedUnit.DocumentIds, StringComparer.OrdinalIgnoreCase);
    Assert.Equal(35, mergedUnit.TokenCount);

    // Unit "b" is untouched; "c" is gone and "d" no longer exists as a single-document unit.
    var untouchedUnit = Assert.Single(remaining, candidate => candidate.Id == "b");
    Assert.Single(untouchedUnit.DocumentIds);
    Assert.Equal("doc-1", untouchedUnit.DocumentIds[0]);
    Assert.DoesNotContain(remaining, candidate => candidate.Id == "c");
    Assert.DoesNotContain(remaining, candidate => candidate.Id == "d" && candidate.DocumentIds.Count == 1);
}
137+
138+
/// <summary>
/// Feeds two scripted extraction responses through the graph-extraction workflow with orphan
/// linking and a relationship confidence floor configured, then checks the persisted
/// relationships (Alice→Bob at weight 0.4, Charlie→Alice at weight 0.5) and the three entities.
/// </summary>
[Fact]
public async Task ExtractGraphWorkflow_LinksOrphansAndEnforcesRelationshipFloors()
{
    // Scripted chat payloads, consumed in order: one per chat request.
    const string firstExtraction = "{\"entities\": [ { \"title\": \"Alice\", \"type\": \"person\", \"description\": \"Researcher\", \"confidence\": 0.9 }, { \"title\": \"Bob\", \"type\": \"person\", \"description\": \"Engineer\", \"confidence\": 0.6 } ], \"relationships\": [ { \"source\": \"Alice\", \"target\": \"Bob\", \"type\": \"collaborates\", \"description\": \"Works together\", \"weight\": 0.1, \"bidirectional\": false } ] }";
    const string secondExtraction = "{\"entities\": [ { \"title\": \"Alice\", \"type\": \"person\", \"description\": \"Researcher\", \"confidence\": 0.8 }, { \"title\": \"Charlie\", \"type\": \"person\", \"description\": \"Analyst\", \"confidence\": 0.7 } ], \"relationships\": [] }";

    // Builds a "doc-1" text unit with no entity, relationship, or covariate links.
    static TextUnitRecord CreateUnit(string id, string text, int tokenCount)
    {
        return new TextUnitRecord
        {
            Id = id,
            Text = text,
            TokenCount = tokenCount,
            DocumentIds = new[] { "doc-1" },
            EntityIds = Array.Empty<string>(),
            RelationshipIds = Array.Empty<string>(),
            CovariateIds = Array.Empty<string>()
        };
    }

    var outputPath = PrepareDirectory("output-graph");
    var inputPath = PrepareDirectory("input-graph");
    var previousPath = PrepareDirectory("previous-graph");

    var outputStore = new FilePipelineStorage(outputPath);
    var seededUnits = new[]
    {
        CreateUnit("unit-1", "Alice collaborates with Bob on research.", 12),
        CreateUnit("unit-2", "Charlie and Alice planned a workshop.", 18)
    };
    await outputStore.WriteTableAsync(PipelineTableNames.TextUnits, seededUnits);

    var scriptedReplies = new Queue<string>(new[] { firstExtraction, secondExtraction });

    var registrations = new ServiceCollection()
        .AddLogging()
        .AddSingleton<IChatClient>(new TestChatClientFactory(_ =>
        {
            // Fail loudly if more chat calls are issued than were scripted.
            if (!scriptedReplies.TryDequeue(out var reply))
            {
                throw new InvalidOperationException("No chat responses remaining.");
            }

            return new ChatResponse(new ChatMessage(ChatRole.Assistant, reply));
        }).CreateClient())
        .AddGraphRag();

    using var serviceProvider = registrations.BuildServiceProvider();

    var heuristics = new HeuristicMaintenanceConfig
    {
        LinkOrphanEntities = true,
        OrphanLinkWeight = 0.5,
        MaxTextUnitsPerRelationship = 1,
        RelationshipConfidenceFloor = 0.4
    };
    var settings = new GraphRagConfig { Heuristics = heuristics };

    var inputStore = new FilePipelineStorage(inputPath);
    var previousStore = new FilePipelineStorage(previousPath);
    var runContext = new PipelineRunContext(
        inputStorage: inputStore,
        outputStorage: outputStore,
        previousStorage: previousStore,
        cache: new StubPipelineCache(),
        callbacks: NoopWorkflowCallbacks.Instance,
        stats: new PipelineRunStats(),
        state: new PipelineState(),
        services: serviceProvider);

    var runExtraction = ExtractGraphWorkflow.Create();
    await runExtraction(settings, runContext, CancellationToken.None);

    var storedRelationships = await outputStore.LoadTableAsync<RelationshipRecord>(PipelineTableNames.Relationships);
    Assert.Equal(2, storedRelationships.Count);

    // The extracted Alice→Bob edge arrived with weight 0.1 and must be stored at 0.4.
    var extractedEdge = Assert.Single(storedRelationships, edge => edge.Source == "Alice" && edge.Target == "Bob");
    Assert.Equal(0.4, extractedEdge.Weight, 3);
    Assert.Contains("unit-1", extractedEdge.TextUnitIds);
    Assert.False(extractedEdge.Bidirectional);

    // The Charlie→Alice edge was not in any scripted payload, so the workflow must have added it.
    var linkedEdge = Assert.Single(storedRelationships, edge => edge.Source == "Charlie" && edge.Target == "Alice");
    Assert.True(linkedEdge.Bidirectional);
    Assert.Equal(0.5, linkedEdge.Weight, 3);
    var linkedUnitId = Assert.Single(linkedEdge.TextUnitIds);
    Assert.Equal("unit-2", linkedUnitId);

    var storedEntities = await outputStore.LoadTableAsync<EntityRecord>(PipelineTableNames.Entities);
    Assert.Equal(3, storedEntities.Count);
    Assert.Contains(storedEntities, person => person.Title == "Charlie");
}
233+
234+
/// <summary>
/// Persists five entities joined by three relationships (Alice–Bob–Charlie and Diana–Eve),
/// runs community creation with the fast label propagation algorithm, and verifies that
/// exactly two communities matching those groups are written and counted in the context.
/// </summary>
[Fact]
public async Task CreateCommunitiesWorkflow_UsesFastLabelPropagationAssignments()
{
    var outputPath = PrepareDirectory("output-communities");
    var inputPath = PrepareDirectory("input-communities");
    var previousPath = PrepareDirectory("previous-communities");

    var outputStore = new FilePipelineStorage(outputPath);

    EntityRecord[] people =
    {
        new("entity-1", 0, "Alice", "Person", "Researcher", ImmutableArray.Create("unit-1"), 2, 2, 0, 0),
        new("entity-2", 1, "Bob", "Person", "Engineer", ImmutableArray.Create("unit-1"), 2, 2, 0, 0),
        new("entity-3", 2, "Charlie", "Person", "Analyst", ImmutableArray.Create("unit-2"), 2, 1, 0, 0),
        new("entity-4", 3, "Diana", "Person", "Strategist", ImmutableArray.Create("unit-3"), 2, 1, 0, 0),
        new("entity-5", 4, "Eve", "Person", "Planner", ImmutableArray.Create("unit-3"), 2, 1, 0, 0)
    };

    await outputStore.WriteTableAsync(PipelineTableNames.Entities, people);

    // Two disconnected groups: Alice–Bob–Charlie and Diana–Eve.
    RelationshipRecord[] edges =
    {
        new("rel-1", 0, "Alice", "Bob", "collaborates", "", 0.9, 2, ImmutableArray.Create("unit-1"), true),
        new("rel-2", 1, "Bob", "Charlie", "supports", "", 0.85, 2, ImmutableArray.Create("unit-2"), true),
        new("rel-3", 2, "Diana", "Eve", "partners", "", 0.95, 2, ImmutableArray.Create("unit-3"), true)
    };

    await outputStore.WriteTableAsync(PipelineTableNames.Relationships, edges);

    var registrations = new ServiceCollection()
        .AddLogging()
        .AddSingleton<IChatClient>(new TestChatClientFactory().CreateClient())
        .AddGraphRag();

    using var serviceProvider = registrations.BuildServiceProvider();

    var clustering = new ClusterGraphConfig
    {
        Algorithm = CommunityDetectionAlgorithm.FastLabelPropagation,
        MaxIterations = 8,
        MaxClusterSize = 10,
        Seed = 13,
        UseLargestConnectedComponent = false
    };
    var settings = new GraphRagConfig { ClusterGraph = clustering };

    var inputStore = new FilePipelineStorage(inputPath);
    var previousStore = new FilePipelineStorage(previousPath);
    var runContext = new PipelineRunContext(
        inputStorage: inputStore,
        outputStorage: outputStore,
        previousStorage: previousStore,
        cache: new StubPipelineCache(),
        callbacks: NoopWorkflowCallbacks.Instance,
        stats: new PipelineRunStats(),
        state: new PipelineState(),
        services: serviceProvider);

    var runClustering = CreateCommunitiesWorkflow.Create();
    await runClustering(settings, runContext, CancellationToken.None);

    var storedCommunities = await outputStore.LoadTableAsync<CommunityRecord>(PipelineTableNames.Communities);
    Assert.Equal(2, storedCommunities.Count);

    // The workflow must also publish the community count into the run context.
    Assert.Equal(storedCommunities.Count, Assert.IsType<int>(runContext.Items["create_communities:count"]));

    var titlesById = people.ToDictionary(person => person.Id, person => person.Title, StringComparer.OrdinalIgnoreCase);

    // Translate each community's entity ids to titles, sorted so the comparison is order-independent.
    var membership =
        (from community in storedCommunities
         select community.EntityIds
             .Select(entityId => titlesById[entityId])
             .OrderBy(title => title, StringComparer.OrdinalIgnoreCase)
             .ToArray())
        .ToList();

    var expectedTrio = new[] { "Alice", "Bob", "Charlie" };
    var expectedPair = new[] { "Diana", "Eve" };

    Assert.Contains(membership, titles => titles.SequenceEqual(expectedTrio));
    Assert.Contains(membership, titles => titles.SequenceEqual(expectedPair));
}
310+
311+
/// <summary>
/// Best-effort removal of the scratch directory created for this test instance.
/// Filesystem failures during cleanup are ignored so they cannot fail a test run;
/// any other exception type is allowed to surface.
/// </summary>
public void Dispose()
{
    try
    {
        if (Directory.Exists(_rootDir))
        {
            Directory.Delete(_rootDir, recursive: true);
        }
    }
    catch (IOException)
    {
        // Ignore cleanup errors in tests: the directory is in use, already gone, or not deletable.
    }
    catch (UnauthorizedAccessException)
    {
        // Ignore cleanup errors in tests: a file is read-only or permission was denied.
    }
}
325+
326+
/// <summary>
/// Creates the subdirectory <paramref name="name"/> under this instance's scratch root.
/// </summary>
/// <param name="name">Relative directory name to create.</param>
/// <returns>The combined path of the created directory.</returns>
private string PrepareDirectory(string name)
{
    var target = Path.Combine(_rootDir, name);
    _ = Directory.CreateDirectory(target);

    return target;
}
332+
}

0 commit comments

Comments
 (0)