forked from microsoft/semantic-kernel
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathChatHistorySummarizationReducer.cs
More file actions
186 lines (160 loc) · 7.59 KB
/
Copy pathChatHistorySummarizationReducer.cs
File metadata and controls
186 lines (160 loc) · 7.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.SemanticKernel.ChatCompletion;
/// <summary>
/// Reduces the chat history by summarizing messages beyond the target message count.
/// </summary>
/// <remarks>
/// Summarization will always avoid orphaning function-content, as the presence of
/// a function-call _must_ be followed by a function-result. When a threshold count
/// is provided (recommended), reduction will scan within the threshold window in an attempt to
/// avoid orphaning a user message from an assistant response.
/// </remarks>
public class ChatHistorySummarizationReducer : IChatHistoryReducer
{
    /// <summary>
    /// Metadata key to indicate a summary message.
    /// </summary>
    public const string SummaryMetadataKey = "__summary__";

    /// <summary>
    /// The default summarization system instructions.
    /// </summary>
    public const string DefaultSummarizationPrompt =
        """
        Provide a concise and complete summarization of the entire dialog that does not exceed 5 sentences
        This summary must always:
        - Consider both user and assistant interactions
        - Maintain continuity for the purpose of further dialog
        - Include details from any existing summary
        - Focus on the most significant aspects of the dialog
        This summary must never:
        - Critique, correct, interpret, presume, or assume
        - Identify faults, mistakes, misunderstanding, or correctness
        - Analyze what has not occurred
        - Exclude details from any existing summary
        """;

    /// <summary>
    /// System instructions for summarization. Defaults to <see cref="DefaultSummarizationPrompt"/>.
    /// </summary>
    public string SummarizationInstructions { get; init; } = DefaultSummarizationPrompt;

    /// <summary>
    /// Flag to indicate if an exception should be thrown if summarization fails.
    /// </summary>
    /// <remarks>
    /// When <c>false</c>, a failed summarization makes <see cref="ReduceAsync"/> return <c>null</c>
    /// (no reduction). Cancellation requested through the caller's token is always propagated,
    /// regardless of this flag.
    /// </remarks>
    public bool FailOnError { get; init; } = true;

    /// <summary>
    /// Flag to indicate whether summarization is maintained in a single message, or whether a series of
    /// summaries is generated over time.
    /// </summary>
    /// <remarks>
    /// Not using a single summary may ultimately result in a chat history that exceeds the token limit.
    /// </remarks>
    public bool UseSingleSummary { get; init; } = true;

    /// <summary>
    /// Initializes a new instance of the <see cref="ChatHistorySummarizationReducer"/> class.
    /// </summary>
    /// <param name="service">A <see cref="IChatCompletionService"/> instance to be used for summarization.</param>
    /// <param name="targetCount">The desired number of target messages after reduction.</param>
    /// <param name="thresholdCount">An optional number of messages beyond the 'targetCount' that must be present in order to trigger reduction.</param>
    /// <remarks>
    /// While the 'thresholdCount' is optional, it is recommended to be provided so that reduction is not triggered
    /// for every incremental addition to the chat history beyond the 'targetCount'.
    /// </remarks>
    public ChatHistorySummarizationReducer(IChatCompletionService service, int targetCount, int? thresholdCount = null)
    {
        Verify.NotNull(service, nameof(service));
        Verify.True(targetCount > 0, "Target message count must be greater than zero.");
        Verify.True(!thresholdCount.HasValue || thresholdCount > 0, "The reduction threshold length must be greater than zero.");

        this._service = service;
        this._targetCount = targetCount;
        this._thresholdCount = thresholdCount ?? 0;
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Returns <c>null</c> when no safe reduction index is found, or when summarization fails and
    /// <see cref="FailOnError"/> is <c>false</c>. Otherwise returns a materialized list that no longer
    /// depends on <paramref name="chatHistory"/>, so the caller may freely mutate its history afterwards.
    /// </remarks>
    public async Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> chatHistory, CancellationToken cancellationToken = default)
    {
        var systemMessage = chatHistory.FirstOrDefault(l => l.Role == AuthorRole.System);

        // Identify where summary messages end and regular history begins.
        int insertionPoint = chatHistory.LocateSummarizationBoundary(SummaryMetadataKey);

        // First pass to determine the truncation index.
        int truncationIndex = chatHistory.LocateSafeReductionIndex(
            this._targetCount,
            this._thresholdCount,
            insertionPoint,
            hasSystemMessage: systemMessage is not null);

        if (truncationIndex < 0)
        {
            // Nothing to reduce.
            return null;
        }

        // Second pass to extract history for summarization. When a single summary is maintained,
        // prior summaries are folded into the new one; otherwise only non-summary history is summarized.
        IEnumerable<ChatMessageContent> summarizedHistory =
            chatHistory.Extract(
                this.UseSingleSummary ? 0 : insertionPoint,
                truncationIndex,
                filter: (m) => m.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent));

        try
        {
            // Summarize: the instructions are appended after the history being summarized.
            ChatHistory summarizationRequest = [.. summarizedHistory, new ChatMessageContent(AuthorRole.System, this.SummarizationInstructions)];
            ChatMessageContent summaryMessage = await this._service.GetChatMessageContentAsync(summarizationRequest, cancellationToken: cancellationToken).ConfigureAwait(false);

            // Tag the message so LocateSummarizationBoundary recognizes it on later reductions.
            summaryMessage.Metadata = new Dictionary<string, object?> { { SummaryMetadataKey, true } };

            return AssembleSummarizedHistory(summaryMessage);
        }
        catch (Exception exception) when (
            !this.FailOnError &&
            // Best-effort mode swallows summarization failures, but never a cancellation the caller asked for.
            !(exception is OperationCanceledException && cancellationToken.IsCancellationRequested))
        {
            return null;
        }

        // Builds the reduced history eagerly (snapshot of chatHistory), in order:
        // optional system message, prior summaries (only when not using a single summary),
        // the new summary, then the retained tail starting at truncationIndex.
        List<ChatMessageContent> AssembleSummarizedHistory(ChatMessageContent summary)
        {
            List<ChatMessageContent> reducedHistory = [];

            if (systemMessage is not null)
            {
                reducedHistory.Add(systemMessage);
            }

            if (!this.UseSingleSummary)
            {
                for (int index = 0; index < insertionPoint; ++index)
                {
                    reducedHistory.Add(chatHistory[index]);
                }
            }

            reducedHistory.Add(summary);

            for (int index = truncationIndex; index < chatHistory.Count; ++index)
            {
                reducedHistory.Add(chatHistory[index]);
            }

            return reducedHistory;
        }
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Equality considers the threshold count, target count, <see cref="UseSingleSummary"/> and
    /// <see cref="SummarizationInstructions"/> (ordinal); the service and <see cref="FailOnError"/> are not compared.
    /// </remarks>
    public override bool Equals(object? obj) =>
        obj is ChatHistorySummarizationReducer other &&
        this._thresholdCount == other._thresholdCount &&
        this._targetCount == other._targetCount &&
        this.UseSingleSummary == other.UseSingleSummary &&
        string.Equals(this.SummarizationInstructions, other.SummarizationInstructions, StringComparison.Ordinal);

    /// <inheritdoc/>
    public override int GetHashCode()
    {
        // Hashes the same members that Equals compares.
#if UNITY
        return HashCodeSlim.Combine(nameof(ChatHistorySummarizationReducer), this._thresholdCount, this._targetCount, this.SummarizationInstructions, this.UseSingleSummary);
#else
        return HashCode.Combine(nameof(ChatHistorySummarizationReducer), this._thresholdCount, this._targetCount, this.SummarizationInstructions, this.UseSingleSummary);
#endif
    }

    private readonly IChatCompletionService _service;
    private readonly int _thresholdCount;
    private readonly int _targetCount;
}