-
Notifications
You must be signed in to change notification settings - Fork 2k
Expand file tree
/
Copy pathAIContextProviderWorkflowTests.cs
More file actions
185 lines (160 loc) · 7.9 KB
/
Copy pathAIContextProviderWorkflowTests.cs
File metadata and controls
185 lines (160 loc) · 7.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI.Workflows.UnitTests;
/// <summary>
/// Validates that messages injected by <see cref="AIContextProvider"/> into an inner agent
/// are correctly persisted into the workflow's chat history, without leaking to downstream agents.
/// </summary>
public class AIContextProviderWorkflowTests
{
    private const string UserText = "Where is Taggia?";
    private const string ContextText = "Taggia is a city in Liguria.";
    private const string FirstAgentResponseText = "Taggia is in Liguria.";

    /// <summary>
    /// Ensures that AIContextProvider-injected messages appear in the workflow session's
    /// chat history and survive serialization (regression test for the bug where such
    /// messages were lost because WorkflowHostAgent only persisted model outputs).
    /// </summary>
    [Fact]
    public async Task Test_WorkflowAsAgent_SerializesAIContextProviderRequestMessagesAsync()
    {
        // Arrange
        ChatClientAgent innerAgent = CreateContextAwareAgent();
        AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent).AsAIAgent();
        AgentSession session = await workflowAgent.CreateSessionAsync();

        // Act
        await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), session);
        JsonElement serializedSession = await workflowAgent.SerializeSessionAsync(session);

        // Assert
        WorkflowSession workflowSession = session.Should().BeOfType<WorkflowSession>().Subject;
        string[] historyTexts =
        [
            .. workflowSession.ChatHistoryProvider
                .GetAllMessages(workflowSession)
                .Select(message => message.Text)
        ];

        historyTexts.Should().Contain(UserText);
        historyTexts.Should().Contain(ContextText);
        historyTexts.Should().Contain(FirstAgentResponseText);

        // The injected context must also appear in the serialized session, not only in the live object.
        serializedSession.GetRawText().Should().Contain(ContextText);
    }

    /// <summary>
    /// Ensures that AIContextProvider-injected messages are still persisted when inner chat history is pruned.
    /// </summary>
    [Fact]
    public async Task Test_WorkflowAsAgent_SerializesAIContextProviderRequestMessagesWhenInnerHistoryIsPrunedAsync()
    {
        // Arrange
        // The inner provider keeps at most two messages and is pre-seeded with two, so storing
        // this run's request/response messages evicts older entries from the inner history.
        RetainingChatHistoryProvider chatHistoryProvider = new(maxStoredMessages: 2);
        chatHistoryProvider.Add(new ChatMessage(ChatRole.User, "Previous question") { MessageId = "previous-user" });
        chatHistoryProvider.Add(new ChatMessage(ChatRole.Assistant, "Previous answer") { MessageId = "previous-assistant" });

        ChatClientAgent innerAgent = CreateContextAwareAgent(chatHistoryProvider);
        AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent).AsAIAgent();
        AgentSession session = await workflowAgent.CreateSessionAsync();

        // Act
        await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), session);

        // Assert
        WorkflowSession workflowSession = session.Should().BeOfType<WorkflowSession>().Subject;
        workflowSession.ChatHistoryProvider
            .GetAllMessages(workflowSession)
            .Select(message => message.Text)
            .Should()
            .Contain(ContextText);
    }

    /// <summary>
    /// Ensures that AIContextProvider-injected messages are saved to workflow history
    /// but are NOT forwarded as part of the input to subsequent agents in the workflow.
    /// </summary>
    [Fact]
    public async Task Test_WorkflowAsAgent_DoesNotForwardAIContextProviderRequestMessagesToDownstreamAgentAsync()
    {
        // Arrange
        ChatClientAgent innerAgent = CreateContextAwareAgent();
        RecordingEchoAgent downstreamAgent = new(id: "downstream", name: "downstream", prefix: "downstream:");
        AIAgent workflowAgent = AgentWorkflowBuilder.BuildSequential(innerAgent, downstreamAgent).AsAIAgent();

        // Act
        await workflowAgent.RunAsync(new ChatMessage(ChatRole.User, UserText), await workflowAgent.CreateSessionAsync());

        // Assert
        downstreamAgent.RecordedInputs.Should().ContainSingle();
        string[] downstreamTexts = [.. downstreamAgent.RecordedInputs[0].Select(message => message.Text)];
        downstreamTexts.Should().Contain(FirstAgentResponseText);
        downstreamTexts.Should().NotContain(ContextText);
    }

    /// <summary>
    /// Builds an agent named "inner" whose <see cref="IChatClient"/> always replies with
    /// <see cref="FirstAgentResponseText"/> and which has a <see cref="StaticAIContextProvider"/>
    /// injecting <see cref="ContextText"/>.
    /// </summary>
    /// <param name="chatHistoryProvider">
    /// Optional history provider for the agent; <see langword="null"/> is passed through to the options unchanged.
    /// </param>
    private static ChatClientAgent CreateContextAwareAgent(ChatHistoryProvider? chatHistoryProvider = null)
    {
        return new ChatClientAgent(
            new StubChatClient(_ => new ChatResponse([new ChatMessage(ChatRole.Assistant, FirstAgentResponseText)])),
            new ChatClientAgentOptions
            {
                Name = "inner",
                ChatHistoryProvider = chatHistoryProvider,
                AIContextProviders = [new StaticAIContextProvider(ContextText)]
            });
    }

    /// <summary>Always injects a single System message containing the configured text.</summary>
    private sealed class StaticAIContextProvider(string text) : AIContextProvider
    {
        protected override ValueTask<AIContext> ProvideAIContextAsync(InvokingContext context, CancellationToken cancellationToken = default)
        {
            return new(new AIContext
            {
                Messages = [new ChatMessage(ChatRole.System, text)]
            });
        }
    }

    /// <summary>
    /// In-memory <see cref="ChatHistoryProvider"/> that, after each store, keeps only the most recent
    /// <c>maxStoredMessages</c> messages by discarding the oldest ones.
    /// </summary>
    private sealed class RetainingChatHistoryProvider(int maxStoredMessages) : ChatHistoryProvider
    {
        private readonly List<ChatMessage> _messages = [];

        /// <summary>Seeds the history with a message. Seeding does not apply the retention cap.</summary>
        public void Add(ChatMessage message)
        {
            this._messages.Add(message);
        }

        protected override ValueTask<IEnumerable<ChatMessage>> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default)
        {
            // Return a materialized snapshot: a lazy Concat over _messages would observe (or throw
            // "Collection was modified" on) the AddRange/RemoveRange performed by StoreChatHistoryAsync
            // if the caller enumerates the sequence after history has been stored.
            return new(this._messages.Concat(context.RequestMessages).ToList());
        }

        protected override ValueTask StoreChatHistoryAsync(InvokedContext context, CancellationToken cancellationToken = default)
        {
            this._messages.AddRange(context.RequestMessages);
            if (context.ResponseMessages is not null)
            {
                this._messages.AddRange(context.ResponseMessages);
            }

            // Drop the oldest entries so at most maxStoredMessages remain.
            if (this._messages.Count > maxStoredMessages)
            {
                this._messages.RemoveRange(0, this._messages.Count - maxStoredMessages);
            }

            return default;
        }
    }

    /// <summary>Test double for <see cref="IChatClient"/> that returns deterministic responses via the supplied factory.</summary>
    private sealed class StubChatClient(Func<IEnumerable<ChatMessage>, ChatResponse> responseFactory) : IChatClient
    {
        public Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
            => Task.FromResult(responseFactory(messages));

        /// <summary>Produces the non-streaming response and replays it as a sequence of updates.</summary>
        public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
            IEnumerable<ChatMessage> messages,
            ChatOptions? options = null,
            [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            ChatResponse response = await this.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
            foreach (ChatResponseUpdate update in response.ToChatResponseUpdates())
            {
                yield return update;
            }
        }

        // The stub exposes no services.
        public object? GetService(Type serviceType, object? serviceKey = null) => null;

        public void Dispose()
        {
        }
    }
}