forked from modelcontextprotocol/csharp-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDelegatingChannelReader.cs
More file actions
149 lines (127 loc) · 4.57 KB
/
DelegatingChannelReader.cs
File metadata and controls
149 lines (127 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Channels;
namespace ModelContextProtocol.Client;
/// <summary>
/// A <see cref="ChannelReader{T}"/> implementation that delegates to another reader
/// after a connection has been established.
/// </summary>
/// <remarks>
/// Until <see cref="SetConnected"/> is called, <see cref="TryRead"/> and <see cref="TryPeek"/> report that
/// no item is available, while the asynchronous operations wait for the connection (observing their
/// <see cref="CancellationToken"/>). Once connected, operations are forwarded to the <c>MessageReader</c>
/// of the parent's <c>ActiveTransport</c>. If <see cref="SetError"/> is called instead, the asynchronous
/// operations fault with the supplied exception.
/// </remarks>
/// <typeparam name="T">The type of data in the channel.</typeparam>
internal sealed class DelegatingChannelReader<T> : ChannelReader<T>
{
    // Completed by SetConnected or faulted by SetError. Continuations are forced to run asynchronously so
    // that the thread signalling the connection does not synchronously execute pending reader code.
    private readonly TaskCompletionSource<bool> _connectionEstablished;
    private readonly AutoDetectingClientSessionTransport _parent;

    /// <summary>
    /// Initializes a new reader that forwards to <paramref name="parent"/>'s active transport once connected.
    /// </summary>
    /// <param name="parent">The transport whose active transport's message reader is delegated to.</param>
    public DelegatingChannelReader(AutoDetectingClientSessionTransport parent)
    {
        _parent = parent;
        _connectionEstablished = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
    }

    /// <summary>
    /// Signals that the transport has been established and operations can proceed.
    /// </summary>
    public void SetConnected()
    {
        _connectionEstablished.TrySetResult(true);
    }

    /// <summary>
    /// Sets the error if connection couldn't be established.
    /// </summary>
    /// <param name="exception">The exception surfaced to pending and future asynchronous operations.</param>
    public void SetError(Exception exception)
    {
        _connectionEstablished.TrySetException(exception);
    }

    /// <summary>Gets whether <see cref="SetConnected"/> has completed the connection successfully.</summary>
    private bool IsConnected => _connectionEstablished.Task.Status == TaskStatus.RanToCompletion;

    /// <summary>
    /// Gets the channel reader to delegate to.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The connection has not been established, or the active transport does not expose a
    /// <see cref="ChannelReader{T}"/> of the expected element type.
    /// </exception>
    private ChannelReader<T> GetReader()
    {
        if (!IsConnected)
        {
            throw new InvalidOperationException("Transport connection not yet established.");
        }

        // Fail with a descriptive error rather than handing a null reader to the caller.
        return _parent.ActiveTransport?.MessageReader as ChannelReader<T>
            ?? throw new InvalidOperationException("The active transport does not expose a compatible message reader.");
    }

    /// <summary>
    /// Waits for the connection to be established, observing <paramref name="cancellationToken"/> while waiting.
    /// </summary>
    /// <exception cref="OperationCanceledException">The token was cancelled before the connection resolved.</exception>
    private async Task WaitForConnectionAsync(CancellationToken cancellationToken)
    {
        Task connectionTask = _connectionEstablished.Task;

        if (!connectionTask.IsCompleted && cancellationToken.CanBeCanceled)
        {
            // Race the connection against cancellation; the registration is disposed once either wins.
            var cancellationSignal = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
            using (cancellationToken.Register(
                static state => ((TaskCompletionSource<bool>)state!).TrySetResult(true),
                cancellationSignal))
            {
                Task first = await Task.WhenAny(connectionTask, cancellationSignal.Task).ConfigureAwait(false);
                if (first != connectionTask)
                {
                    cancellationToken.ThrowIfCancellationRequested();
                }
            }
        }

        // Rethrows the exception passed to SetError, if any.
        await connectionTask.ConfigureAwait(false);
    }

#if !NETSTANDARD2_0
    /// <inheritdoc/>
    public override bool CanCount => GetReader().CanCount;

    /// <inheritdoc/>
    public override bool CanPeek => GetReader().CanPeek;

    /// <inheritdoc/>
    public override int Count => GetReader().Count;
#endif

    /// <inheritdoc/>
    /// <remarks>Completes once the underlying reader completes; faults if the connection fails.</remarks>
    public override Task Completion =>
        IsConnected ? GetReader().Completion : CompletionAfterConnectionAsync();

    private async Task CompletionAfterConnectionAsync()
    {
        await _connectionEstablished.Task.ConfigureAwait(false);
        await GetReader().Completion.ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public override bool TryPeek(out T item)
    {
        if (!IsConnected)
        {
            item = default!;
            return false;
        }

        return GetReader().TryPeek(out item!);
    }

    /// <inheritdoc/>
    public override bool TryRead(out T item)
    {
        if (!IsConnected)
        {
            item = default!;
            return false;
        }

        return GetReader().TryRead(out item!);
    }

    /// <inheritdoc/>
    public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
    {
        // Fast path: delegate directly once connected; otherwise wait for the connection first.
        return IsConnected
            ? GetReader().WaitToReadAsync(cancellationToken)
            : new ValueTask<bool>(WaitForConnectionAndThenReadAsync(cancellationToken));
    }

    private async Task<bool> WaitForConnectionAndThenReadAsync(CancellationToken cancellationToken)
    {
        await WaitForConnectionAsync(cancellationToken).ConfigureAwait(false);
        return await GetReader().WaitToReadAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public override ValueTask<T> ReadAsync(CancellationToken cancellationToken = default)
    {
        // Fast path: delegate directly once connected; otherwise wait for the connection first.
        return IsConnected
            ? GetReader().ReadAsync(cancellationToken)
            : new ValueTask<T>(WaitForConnectionAndThenGetItemAsync(cancellationToken));
    }

    private async Task<T> WaitForConnectionAndThenGetItemAsync(CancellationToken cancellationToken)
    {
        await WaitForConnectionAsync(cancellationToken).ConfigureAwait(false);
        return await GetReader().ReadAsync(cancellationToken).ConfigureAwait(false);
    }

#if NETSTANDARD2_0
    /// <summary>
    /// Creates an <see cref="IAsyncEnumerable{T}"/> that reads all items from the channel until it completes.
    /// </summary>
    /// <param name="cancellationToken">
    /// Token observed while waiting for data; a token supplied via <c>WithCancellation</c> is also honored.
    /// </param>
    public async IAsyncEnumerable<T> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        while (await WaitToReadAsync(cancellationToken).ConfigureAwait(false))
        {
            while (TryRead(out var item))
            {
                yield return item;
            }
        }
    }
#endif
}