Skip to content

Commit 966b7ba

Browse files
committed
Add community clustering workflow and covariate joiner
1 parent 84a921f commit 966b7ba

File tree

13 files changed

+621
-158
lines changed

13 files changed

+621
-158
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
using System.Collections.Immutable;
2+
using System.Globalization;
3+
using GraphRag.Config;
4+
using GraphRag.Entities;
5+
using GraphRag.Relationships;
6+
7+
namespace GraphRag.Community;
8+
9+
internal static class CommunityBuilder
10+
{
11+
public static IReadOnlyList<CommunityRecord> Build(
12+
IReadOnlyList<EntityRecord> entities,
13+
IReadOnlyList<RelationshipRecord> relationships,
14+
ClusterGraphConfig? config)
15+
{
16+
ArgumentNullException.ThrowIfNull(entities);
17+
ArgumentNullException.ThrowIfNull(relationships);
18+
19+
config ??= new ClusterGraphConfig();
20+
21+
if (entities.Count == 0)
22+
{
23+
return Array.Empty<CommunityRecord>();
24+
}
25+
26+
var adjacency = BuildAdjacency(entities, relationships);
27+
var titleLookup = entities.ToDictionary(entity => entity.Title, StringComparer.OrdinalIgnoreCase);
28+
var random = new Random(config.Seed);
29+
30+
var orderedTitles = titleLookup.Keys
31+
.OrderBy(_ => random.Next())
32+
.ToList();
33+
34+
var visited = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
35+
var components = new List<List<string>>();
36+
37+
foreach (var title in orderedTitles)
38+
{
39+
if (!visited.Add(title))
40+
{
41+
continue;
42+
}
43+
44+
var component = new List<string>();
45+
var queue = new Queue<string>();
46+
queue.Enqueue(title);
47+
48+
while (queue.Count > 0)
49+
{
50+
var current = queue.Dequeue();
51+
if (!component.Contains(current))
52+
{
53+
component.Add(current);
54+
}
55+
56+
if (!adjacency.TryGetValue(current, out var neighbors) || neighbors.Count == 0)
57+
{
58+
continue;
59+
}
60+
61+
var orderedNeighbors = neighbors
62+
.OrderBy(_ => random.Next())
63+
.ToList();
64+
65+
foreach (var neighbor in orderedNeighbors)
66+
{
67+
if (visited.Add(neighbor))
68+
{
69+
queue.Enqueue(neighbor);
70+
}
71+
}
72+
}
73+
74+
components.Add(component);
75+
}
76+
77+
if (config.UseLargestConnectedComponent && components.Count > 0)
78+
{
79+
var largestSize = components.Max(component => component.Count);
80+
components = components
81+
.Where(component => component.Count == largestSize)
82+
.Take(1)
83+
.ToList();
84+
}
85+
86+
var clusters = components
87+
.SelectMany(component => SplitComponent(component, config.MaxClusterSize))
88+
.ToList();
89+
90+
if (clusters.Count == 0)
91+
{
92+
return Array.Empty<CommunityRecord>();
93+
}
94+
95+
var period = DateTime.UtcNow.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture);
96+
var communityRecords = new List<CommunityRecord>(clusters.Count);
97+
var relationshipLookup = relationships.ToList();
98+
99+
var communityIndex = 0;
100+
foreach (var cluster in clusters)
101+
{
102+
var memberTitles = cluster
103+
.Distinct(StringComparer.OrdinalIgnoreCase)
104+
.Where(titleLookup.ContainsKey)
105+
.ToList();
106+
107+
if (memberTitles.Count == 0)
108+
{
109+
continue;
110+
}
111+
112+
var members = memberTitles
113+
.Select(title => titleLookup[title])
114+
.OrderBy(entity => entity.HumanReadableId)
115+
.ToList();
116+
117+
if (members.Count == 0)
118+
{
119+
continue;
120+
}
121+
122+
communityIndex++;
123+
var communityId = communityIndex;
124+
125+
var entityIds = members
126+
.Select(member => member.Id)
127+
.ToImmutableArray();
128+
129+
var membership = new HashSet<string>(memberTitles, StringComparer.OrdinalIgnoreCase);
130+
var relationshipIds = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
131+
var textUnitIds = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
132+
133+
foreach (var relationship in relationshipLookup)
134+
{
135+
if (!membership.Contains(relationship.Source) || !membership.Contains(relationship.Target))
136+
{
137+
continue;
138+
}
139+
140+
relationshipIds.Add(relationship.Id);
141+
142+
foreach (var textUnitId in relationship.TextUnitIds)
143+
{
144+
if (!string.IsNullOrWhiteSpace(textUnitId))
145+
{
146+
textUnitIds.Add(textUnitId);
147+
}
148+
}
149+
}
150+
151+
if (textUnitIds.Count == 0)
152+
{
153+
foreach (var member in members)
154+
{
155+
foreach (var textUnitId in member.TextUnitIds)
156+
{
157+
if (!string.IsNullOrWhiteSpace(textUnitId))
158+
{
159+
textUnitIds.Add(textUnitId);
160+
}
161+
}
162+
}
163+
}
164+
165+
var record = new CommunityRecord(
166+
Id: Guid.NewGuid().ToString(),
167+
HumanReadableId: communityId,
168+
CommunityId: communityId,
169+
Level: 0,
170+
ParentId: -1,
171+
Children: ImmutableArray<int>.Empty,
172+
Title: $"Community {communityId}",
173+
EntityIds: entityIds,
174+
RelationshipIds: relationshipIds
175+
.OrderBy(id => id, StringComparer.Ordinal)
176+
.ToImmutableArray(),
177+
TextUnitIds: textUnitIds
178+
.OrderBy(id => id, StringComparer.Ordinal)
179+
.ToImmutableArray(),
180+
Period: period,
181+
Size: members.Count);
182+
183+
communityRecords.Add(record);
184+
}
185+
186+
return communityRecords;
187+
}
188+
189+
private static Dictionary<string, HashSet<string>> BuildAdjacency(
190+
IReadOnlyList<EntityRecord> entities,
191+
IReadOnlyList<RelationshipRecord> relationships)
192+
{
193+
var adjacency = new Dictionary<string, HashSet<string>>(StringComparer.OrdinalIgnoreCase);
194+
195+
foreach (var entity in entities)
196+
{
197+
adjacency.TryAdd(entity.Title, new HashSet<string>(StringComparer.OrdinalIgnoreCase));
198+
}
199+
200+
foreach (var relationship in relationships)
201+
{
202+
if (!adjacency.TryGetValue(relationship.Source, out var sourceNeighbors))
203+
{
204+
sourceNeighbors = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
205+
adjacency[relationship.Source] = sourceNeighbors;
206+
}
207+
208+
if (!adjacency.TryGetValue(relationship.Target, out var targetNeighbors))
209+
{
210+
targetNeighbors = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
211+
adjacency[relationship.Target] = targetNeighbors;
212+
}
213+
214+
sourceNeighbors.Add(relationship.Target);
215+
targetNeighbors.Add(relationship.Source);
216+
}
217+
218+
return adjacency;
219+
}
220+
221+
private static IEnumerable<List<string>> SplitComponent(List<string> component, int maxClusterSize)
222+
{
223+
if (component.Count == 0)
224+
{
225+
yield break;
226+
}
227+
228+
if (maxClusterSize <= 0 || component.Count <= maxClusterSize)
229+
{
230+
yield return component;
231+
yield break;
232+
}
233+
234+
for (var index = 0; index < component.Count; index += maxClusterSize)
235+
{
236+
var length = Math.Min(maxClusterSize, component.Count - index);
237+
yield return component.GetRange(index, length);
238+
}
239+
}
240+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
namespace GraphRag.Config;
2+
3+
/// <summary>
4+
/// Configuration settings for graph community clustering.
5+
/// </summary>
6+
public sealed class ClusterGraphConfig
7+
{
8+
/// <summary>
9+
/// Gets or sets the maximum number of entities allowed in a single community cluster.
10+
/// A value less than or equal to zero disables the limit.
11+
/// </summary>
12+
public int MaxClusterSize { get; set; } = 10;
13+
14+
/// <summary>
15+
/// Gets or sets a value indicating whether the largest connected component
16+
/// should be used when clustering.
17+
/// </summary>
18+
public bool UseLargestConnectedComponent { get; set; } = true;
19+
20+
/// <summary>
21+
/// Gets or sets the seed used when ordering traversal operations to keep
22+
/// results deterministic across runs.
23+
/// </summary>
24+
public int Seed { get; set; } = unchecked((int)0xDEADBEEF);
25+
}

src/ManagedCode.GraphRag/Config/GraphRagConfig.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ public sealed class GraphRagConfig
5050

5151
public SummarizeDescriptionsConfig SummarizeDescriptions { get; set; } = new();
5252

53+
public ClusterGraphConfig ClusterGraph { get; set; } = new();
54+
5355
public CommunityReportsConfig CommunityReports { get; set; } = new();
5456

5557
public SnapshotsConfig Snapshots { get; set; } = new();
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using GraphRag.Data;
2+
3+
namespace GraphRag.Covariates;
4+
5+
/// <summary>
6+
/// Provides helpers for attaching extracted covariates back onto text unit records.
7+
/// </summary>
8+
public static class TextUnitCovariateJoiner
9+
{
10+
public static IReadOnlyList<TextUnitRecord> Attach(
11+
IReadOnlyList<TextUnitRecord> textUnits,
12+
IReadOnlyList<CovariateRecord> covariates)
13+
{
14+
ArgumentNullException.ThrowIfNull(textUnits);
15+
ArgumentNullException.ThrowIfNull(covariates);
16+
17+
if (textUnits.Count == 0 || covariates.Count == 0)
18+
{
19+
return textUnits;
20+
}
21+
22+
var lookup = new Dictionary<string, HashSet<string>>(StringComparer.OrdinalIgnoreCase);
23+
foreach (var covariate in covariates)
24+
{
25+
if (string.IsNullOrWhiteSpace(covariate.TextUnitId))
26+
{
27+
continue;
28+
}
29+
30+
if (!lookup.TryGetValue(covariate.TextUnitId, out var ids))
31+
{
32+
ids = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
33+
lookup[covariate.TextUnitId] = ids;
34+
}
35+
36+
if (!string.IsNullOrWhiteSpace(covariate.Id))
37+
{
38+
ids.Add(covariate.Id);
39+
}
40+
}
41+
42+
if (lookup.Count == 0)
43+
{
44+
return textUnits;
45+
}
46+
47+
var results = new List<TextUnitRecord>(textUnits.Count);
48+
foreach (var unit in textUnits)
49+
{
50+
if (!lookup.TryGetValue(unit.Id, out var ids))
51+
{
52+
results.Add(unit);
53+
continue;
54+
}
55+
56+
var existing = unit.CovariateIds ?? Array.Empty<string>();
57+
var combined = new HashSet<string>(existing, StringComparer.OrdinalIgnoreCase);
58+
foreach (var id in ids)
59+
{
60+
combined.Add(id);
61+
}
62+
63+
var ordered = combined
64+
.OrderBy(value => value, StringComparer.Ordinal)
65+
.ToArray();
66+
67+
results.Add(unit with { CovariateIds = ordered });
68+
}
69+
70+
return results;
71+
}
72+
}

src/ManagedCode.GraphRag/Indexing/Runtime/IndexingPipelineDefinitions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public static class IndexingPipelineDefinitions
99
LoadInputDocumentsWorkflow.Name,
1010
CreateBaseTextUnitsWorkflow.Name,
1111
ExtractGraphWorkflow.Name,
12+
CreateCommunitiesWorkflow.Name,
1213
CommunitySummariesWorkflow.Name,
1314
CreateFinalDocumentsWorkflow.Name
1415
});

0 commit comments

Comments
 (0)