-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Expand file tree
/
Copy pathFrameworkShims.Sockets.cs
More file actions
161 lines (136 loc) · 6.09 KB
/
FrameworkShims.Sockets.cs
File metadata and controls
161 lines (136 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#if !NET
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// ReSharper disable once CheckNamespace
namespace System.Net.Sockets;
/// <summary>
/// Awaitable, cancellable <see cref="Socket"/> connect/send/receive helpers built on <see cref="SocketAsyncEventArgs"/>,
/// compiled only for non-<c>NET</c> targets.
/// </summary>
internal static class SocketExtensions
{
    /// <summary>
    /// Connects <paramref name="socket"/> to <paramref name="remoteEP"/>.
    /// </summary>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already canceled on entry.</exception>
    /// <exception cref="SocketException">
    /// The connect failed, or <paramref name="cancellationToken"/> was canceled before it completed
    /// (reported as <see cref="SocketError.TimedOut"/>). Cancellation does not stop the pending connect itself.
    /// </exception>
    internal static async ValueTask ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken = default)
    {
        // this API is only used during handshake, *not* core IO, so: we're not concerned about alloc overhead
        using var args = new SocketAwaitableEventArgs(SocketFlags.None, cancellationToken);
        args.RemoteEndPoint = remoteEP;
        if (!socket.ConnectAsync(args))
        {
            // completed synchronously: the completion callback will not fire, so signal it ourselves
            args.Complete();
        }
        await args;
    }

    /// <summary>
    /// Sends <paramref name="buffer"/> (which must be array-backed) on <paramref name="socket"/>.
    /// </summary>
    /// <returns>The number of bytes transferred.</returns>
    /// <exception cref="NotSupportedException"><paramref name="buffer"/> is not backed by an array.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already canceled on entry.</exception>
    /// <exception cref="SocketException">
    /// The send failed, or <paramref name="cancellationToken"/> was canceled before it completed
    /// (reported as <see cref="SocketError.TimedOut"/>). Cancellation does not stop the pending send itself.
    /// </exception>
    internal static async ValueTask<int> SendAsync(this Socket socket, ReadOnlyMemory<byte> buffer, SocketFlags socketFlags, CancellationToken cancellationToken = default)
    {
        // this API is only used during handshake, *not* core IO, so: we're not concerned about alloc overhead
        using var args = new SocketAwaitableEventArgs(socketFlags, cancellationToken);
        args.SetBuffer(buffer);
        if (!socket.SendAsync(args))
        {
            // completed synchronously: the completion callback will not fire, so signal it ourselves
            args.Complete();
        }
        return await args;
    }

    /// <summary>
    /// Receives into <paramref name="buffer"/> (which must be array-backed) from <paramref name="socket"/>.
    /// </summary>
    /// <returns>The number of bytes transferred.</returns>
    /// <exception cref="NotSupportedException"><paramref name="buffer"/> is not backed by an array.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already canceled on entry.</exception>
    /// <exception cref="SocketException">
    /// The receive failed, or <paramref name="cancellationToken"/> was canceled before it completed
    /// (reported as <see cref="SocketError.TimedOut"/>). Cancellation does not stop the pending receive itself,
    /// so it may presumably still write into <paramref name="buffer"/> later unless the socket is closed.
    /// </exception>
    internal static async ValueTask<int> ReceiveAsync(this Socket socket, Memory<byte> buffer, SocketFlags socketFlags, CancellationToken cancellationToken = default)
    {
        // this API is only used during handshake, *not* core IO, so: we're not concerned about alloc overhead
        using var args = new SocketAwaitableEventArgs(socketFlags, cancellationToken);
        args.SetBuffer(buffer);
        if (!socket.ReceiveAsync(args))
        {
            // completed synchronously: the completion callback will not fire, so signal it ourselves
            args.Complete();
        }
        return await args;
    }

    /// <summary>
    /// Awaitable SocketAsyncEventArgs, where awaiting the args yields either the BytesTransferred or throws the relevant socket exception,
    /// plus support for cancellation via <see cref="SocketError.TimedOut"/>.
    /// </summary>
    /// <remarks>
    /// Instances are single-use: one operation, one await, then dispose. Whichever of the operation's own completion
    /// and <see cref="Abort"/> happens first decides the outcome; the later one is ignored, so a cancellation that
    /// arrives after the operation finished cannot turn a successful result into a failure.
    /// </remarks>
    private sealed class SocketAwaitableEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion, IDisposable
    {
        // ReSharper disable once InconsistentNaming
        private static readonly Action _callbackCompleted = () => { };

        private CancellationTokenRegistration _cancelRegistration;
        private volatile SocketError _forcedError; // Success = 0, no field init required
        private Action? _callback;
        private int _completionClaimed; // 0 until the first completion (natural or forced) claims the result, then 1

        public SocketAwaitableEventArgs(SocketFlags socketFlags, CancellationToken cancellationToken)
        {
            SocketFlags = socketFlags;
            if (cancellationToken.CanBeCanceled)
            {
                cancellationToken.ThrowIfCancellationRequested();
                // a cancellation from here on completes the await with SocketError.TimedOut (unless the op finished first)
                _cancelRegistration = cancellationToken.Register(Timeout);
            }
        }

        /// <summary>
        /// Releases the cancellation registration, then the underlying <see cref="SocketAsyncEventArgs"/>.
        /// </summary>
        /// <remarks>
        /// <see cref="IDisposable"/> is re-declared on this class so that interface-based calls (as made by <c>using</c>)
        /// map to this method rather than only to the base <c>Dispose</c>.
        /// </remarks>
        public new void Dispose()
        {
            _cancelRegistration.Dispose();
            base.Dispose();
        }

        /// <summary>
        /// Points the operation at the array behind <paramref name="buffer"/>.
        /// </summary>
        /// <exception cref="NotSupportedException"><paramref name="buffer"/> is not backed by an array.</exception>
        public void SetBuffer(ReadOnlyMemory<byte> buffer)
        {
            if (!MemoryMarshal.TryGetArray(buffer, out var segment)) ThrowNotSupported();
            SetBuffer(segment.Array ?? [], segment.Offset, segment.Count);

            [DoesNotReturn]
            static void ThrowNotSupported() => throw new NotSupportedException("Only array-backed buffers are supported");
        }

        /// <summary>
        /// Cancellation callback: fails the await with <see cref="SocketError.TimedOut"/> if the operation has not yet completed.
        /// </summary>
        public void Timeout() => Abort(SocketError.TimedOut);

        /// <summary>
        /// Completes the await with <paramref name="error"/>, unless the operation has already completed,
        /// in which case its real result is kept. This does not stop the underlying socket operation.
        /// </summary>
        public void Abort(SocketError error)
        {
            // the operation already completed (or was already aborted): keep that outcome
            if (!TryClaimCompletion()) return;
            // written before the completion is published, so GetResult is guaranteed to observe it
            _forcedError = error;
            SignalCompleted();
        }

        public SocketAwaitableEventArgs GetAwaiter() => this;

        /// <summary>
        /// Indicates whether the current operation is complete; used as part of "await".
        /// </summary>
        public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted);

        /// <summary>
        /// Gets the result of the completed operation; used as part of "await".
        /// </summary>
        /// <returns><see cref="SocketAsyncEventArgs.BytesTransferred"/>.</returns>
        /// <exception cref="SocketException">The forced error from <see cref="Abort"/>, otherwise the operation's own non-success error.</exception>
        public int GetResult()
        {
            Debug.Assert(ReferenceEquals(_callback, _callbackCompleted));
            _callback = null;
            var error = _forcedError;
            if (error is SocketError.Success) error = SocketError;
            if (error is not SocketError.Success) ThrowSocketException(error);
            return BytesTransferred;

            [DoesNotReturn]
            static void ThrowSocketException(SocketError e) => throw new SocketException((int)e);
        }

        /// <summary>
        /// Schedules a continuation for this operation; used as part of "await". Only a single awaiter is supported.
        /// </summary>
        public void OnCompleted(Action continuation)
        {
            if (ReferenceEquals(Volatile.Read(ref _callback), _callbackCompleted)
                || ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted))
            {
                // this is the rare "kinda already complete" case (completed between the IsCompleted check and now);
                // push to the thread-pool to prevent a possible stack dive
                RunOnThreadPool(continuation);
            }
        }

        /// <summary>
        /// Schedules a continuation for this operation; used as part of "await".
        /// </summary>
        public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation);

        /// <summary>
        /// Marks the operation as complete - this should be invoked whenever a SocketAsyncEventArgs operation returns false.
        /// </summary>
        public void Complete() => OnCompleted(this);

        private static void RunOnThreadPool(Action action)
            => ThreadPool.QueueUserWorkItem(static state => ((Action)state).Invoke(), action);

        /// <summary>
        /// Invoked when the operation completes (asynchronously, or via <see cref="Complete"/> for synchronous completion).
        /// </summary>
        protected override void OnCompleted(SocketAsyncEventArgs e)
        {
            // a completion arriving after Abort already decided the outcome is ignored
            if (TryClaimCompletion()) SignalCompleted();
        }

        /// <summary>
        /// Atomically claims the right to publish the result; true for exactly one caller over the instance's lifetime.
        /// </summary>
        private bool TryClaimCompletion() => Interlocked.Exchange(ref _completionClaimed, 1) == 0;

        /// <summary>
        /// Publishes completion and, if an awaiter is already registered, runs its continuation on the thread-pool.
        /// </summary>
        private void SignalCompleted()
        {
            var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted);
            if (continuation is not null)
            {
                // continue on the thread-pool
                RunOnThreadPool(continuation);
            }
        }
    }
}
#endif