Skip to content

Commit 71cbc22

Browse files
committed
[Milky] Support Bearer token authentication for WebSocket
1 parent d6ef45c commit 71cbc22

2 files changed

Lines changed: 105 additions & 102 deletions

File tree

Lagrange.Milky/Api/MilkyHttpApiService.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public Task StartAsync(CancellationToken token)
3333
_listener.Prefixes.Add($"http://{_host}:{_port}{_prefix}/");
3434
_listener.Start();
3535

36-
foreach (var prefix in _listener.Prefixes) _logger.LogServerRunning(prefix);
36+
foreach (string prefix in _listener.Prefixes) _logger.LogServerRunning(prefix);
3737

3838
_cts = CancellationTokenSource.CreateLinkedTokenSource(token);
3939
_task = GetHttpContextLoopAsync(_cts.Token);
@@ -65,8 +65,8 @@ private async Task HandleHttpContextAsync(HttpListenerContext context, Cancellat
6565
var request = context.Request;
6666
var identifier = request.RequestTraceIdentifier;
6767
var remote = request.RemoteEndPoint;
68-
var method = request.HttpMethod;
69-
var rawUrl = request.RawUrl;
68+
string method = request.HttpMethod;
69+
string? rawUrl = request.RawUrl;
7070

7171
try
7272
{
@@ -77,10 +77,10 @@ private async Task HandleHttpContextAsync(HttpListenerContext context, Cancellat
7777
var handler = await GetApiHandlerAsync(context, token);
7878
if (handler == null) return;
7979

80-
var parameter = await GetParameterAsync(context, handler.ParameterType, token);
80+
object? parameter = await GetParameterAsync(context, handler.ParameterType, token);
8181
if (parameter == null) return;
8282

83-
var result = await GetResultAsync(context, handler, parameter, token);
83+
object? result = await GetResultAsync(context, handler, parameter, token);
8484
if (result == null) return;
8585

8686
await SendWithLoggerAsync(context, result, token);

Lagrange.Milky/Event/MilkyWebSocketEventService.cs

Lines changed: 100 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,45 @@ public class MilkyWebSocketEventService(ILogger<MilkyWebSocketEventService> logg
2222

2323
private readonly HttpListener _listener = new();
2424
private readonly ConcurrentDictionary<ConnectionContext, object?> _connections = [];
25-
private CancellationTokenSource? _cts;
25+
2626
private Task? _task;
27+
private CancellationTokenSource? _cts;
2728

28-
public Task StartAsync(CancellationToken token)
29+
public Task StartAsync(CancellationToken ct)
2930
{
3031
_listener.Prefixes.Add($"http://{_host}:{_port}{_path}/");
3132
_listener.Start();
3233

33-
foreach (var prefix in _listener.Prefixes) _logger.LogServerRunning(prefix);
34+
foreach (string prefix in _listener.Prefixes) _logger.LogServerRunning(prefix);
3435

35-
_cts = CancellationTokenSource.CreateLinkedTokenSource(token);
36+
_cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
3637
_task = GetHttpContextLoopAsync(_cts.Token);
3738

3839
_event.Register(HandleEventAsync);
3940

4041
return Task.CompletedTask;
4142
}
4243

43-
private async Task GetHttpContextLoopAsync(CancellationToken token)
44+
public async Task StopAsync(CancellationToken ct)
45+
{
46+
_event.Unregister(HandleEventAsync);
47+
48+
_cts?.Cancel();
49+
if (_task != null) await _task.WaitAsync(ct).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
50+
await Task.WhenAll(_connections.Keys.Select(connection => connection.Tcs.Task));
51+
52+
_listener.Stop();
53+
}
54+
55+
private async Task? GetHttpContextLoopAsync(CancellationToken ct)
4456
{
4557
try
4658
{
4759
while (true)
4860
{
49-
_ = HandleHttpContextAsync(await _listener.GetContextAsync().WaitAsync(token), token);
61+
_ = HandleHttpContextAsync(await _listener.GetContextAsync().WaitAsync(ct), ct);
5062

51-
token.ThrowIfCancellationRequested();
63+
ct.ThrowIfCancellationRequested();
5264
}
5365
}
5466
catch (OperationCanceledException) { throw; }
@@ -59,9 +71,9 @@ private async Task GetHttpContextLoopAsync(CancellationToken token)
5971
}
6072
}
6173

62-
private async Task HandleHttpContextAsync(HttpListenerContext httpContext, CancellationToken token)
74+
private async Task HandleHttpContextAsync(HttpListenerContext context, CancellationToken ct)
6375
{
64-
var request = httpContext.Request;
76+
var request = context.Request;
6577
var identifier = request.RequestTraceIdentifier;
6678
var remote = request.RemoteEndPoint;
6779
string method = request.HttpMethod;
@@ -71,26 +83,26 @@ private async Task HandleHttpContextAsync(HttpListenerContext httpContext, Cance
7183
{
7284
_logger.LogHttpContext(identifier, remote, method, rawUrl);
7385

74-
if (!await ValidateHttpContextAsync(httpContext, token)) return;
86+
if (!await ValidateHttpContextAsync(context, ct)) return;
7587

76-
var connection = await GetConnectionContextAsync(httpContext, token);
88+
var connection = await GetConnectionContextAsync(context, ct);
7789
if (connection == null) return;
7890

7991
_ = WaitConnectionCloseLoopAsync(connection, connection.Cts.Token);
8092
}
8193
catch (OperationCanceledException)
8294
{
83-
await SendWithLoggerAsync(httpContext, HttpStatusCode.InternalServerError, token);
95+
await SendWithLoggerAsync(context, HttpStatusCode.InternalServerError, ct);
8496
throw;
8597
}
8698
catch (Exception e)
8799
{
88100
_logger.LogHandleHttpContextException(identifier, remote, e);
89-
await SendWithLoggerAsync(httpContext, HttpStatusCode.InternalServerError, token);
101+
await SendWithLoggerAsync(context, HttpStatusCode.InternalServerError, ct);
90102
}
91103
}
92104

93-
private async Task WaitConnectionCloseLoopAsync(ConnectionContext connection, CancellationToken token)
105+
private async Task WaitConnectionCloseLoopAsync(ConnectionContext connection, CancellationToken ct)
94106
{
95107
var identifier = connection.HttpContext.Request.RequestTraceIdentifier;
96108
var remote = connection.HttpContext.Request.RemoteEndPoint;
@@ -100,20 +112,20 @@ private async Task WaitConnectionCloseLoopAsync(ConnectionContext connection, Ca
100112
byte[] buffer = new byte[1024];
101113
while (true)
102114
{
103-
ValueTask<ValueWebSocketReceiveResult> resultTask = connection.WsContext.WebSocket
115+
var resultTask = connection.WsContext.WebSocket
104116
.ReceiveAsync(buffer.AsMemory(), default);
105117

106-
ValueWebSocketReceiveResult result = !resultTask.IsCompleted ?
107-
await resultTask.AsTask().WaitAsync(token) :
118+
var result = !resultTask.IsCompleted ?
119+
await resultTask.AsTask().WaitAsync(ct) :
108120
resultTask.Result;
109121

110122
if (result.MessageType == WebSocketMessageType.Close)
111123
{
112-
await CloseConnectionAsync(connection, WebSocketCloseStatus.NormalClosure, token);
124+
await CloseConnectionAsync(connection, WebSocketCloseStatus.NormalClosure, ct);
113125
return;
114126
}
115127

116-
token.ThrowIfCancellationRequested();
128+
ct.ThrowIfCancellationRequested();
117129
}
118130
}
119131
catch (OperationCanceledException)
@@ -124,11 +136,11 @@ await resultTask.AsTask().WaitAsync(token) :
124136
{
125137
_logger.LogWaitWebSocketCloseException(identifier, remote, e);
126138

127-
await CloseConnectionAsync(connection, WebSocketCloseStatus.InternalServerError, token);
139+
await CloseConnectionAsync(connection, WebSocketCloseStatus.InternalServerError, ct);
128140
}
129141
}
130142

131-
private async Task CloseConnectionAsync(ConnectionContext connection, WebSocketCloseStatus status, CancellationToken token)
143+
private async Task CloseConnectionAsync(ConnectionContext connection, WebSocketCloseStatus status, CancellationToken ct)
132144
{
133145
var identifier = connection.HttpContext.Request.RequestTraceIdentifier;
134146
var remote = connection.HttpContext.Request.RemoteEndPoint;
@@ -137,7 +149,7 @@ private async Task CloseConnectionAsync(ConnectionContext connection, WebSocketC
137149
{
138150
_connections.Remove(connection, out _);
139151

140-
await connection.WsContext.WebSocket.CloseAsync(status, null, token);
152+
await connection.WsContext.WebSocket.CloseAsync(status, null, ct);
141153
connection.HttpContext.Response.Close();
142154

143155
_logger.LogWebSocketClosed(identifier, remote);
@@ -152,117 +164,79 @@ private async Task CloseConnectionAsync(ConnectionContext connection, WebSocketC
152164
}
153165
}
154166

155-
private async void HandleEventAsync(Memory<byte> payload)
167+
private async Task<bool> ValidateHttpContextAsync(HttpListenerContext context, CancellationToken ct)
156168
{
157-
if (_connections.IsEmpty) return;
158-
159-
_logger.LogSend(payload.Span);
160-
foreach (var connection in _connections.Keys)
161-
{
162-
var identifier = connection.HttpContext.Request.RequestTraceIdentifier;
163-
var remote = connection.HttpContext.Request.RemoteEndPoint;
164-
var ws = connection.WsContext.WebSocket;
165-
166-
try
167-
{
168-
await connection.SendSemaphoreSlim.WaitAsync(connection.Cts.Token);
169-
try
170-
{
171-
await ws.SendAsync(payload, WebSocketMessageType.Text, true, connection.Cts.Token);
172-
}
173-
finally
174-
{
175-
connection.SendSemaphoreSlim.Release();
176-
}
177-
}
178-
catch (Exception e)
179-
{
180-
_logger.LogSendException(identifier, remote, e);
181-
182-
await CloseConnectionAsync(connection, WebSocketCloseStatus.InternalServerError, connection.Cts.Token);
183-
}
184-
}
185-
}
186-
187-
public async Task StopAsync(CancellationToken token)
188-
{
189-
_event.Unregister(HandleEventAsync);
190-
191-
_cts?.Cancel();
192-
if (_task != null) await _task.WaitAsync(token).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
193-
await Task.WhenAll(_connections.Keys.Select(connection => connection.Tcs.Task));
194-
195-
_listener.Stop();
196-
}
197169

198-
private async Task<bool> ValidateHttpContextAsync(HttpListenerContext httpContext, CancellationToken token)
199-
{
200-
var request = httpContext.Request;
170+
var request = context.Request;
201171
var identifier = request.RequestTraceIdentifier;
202172
var remote = request.RemoteEndPoint;
203173

204174
if (request.Url?.LocalPath != _path)
205175
{
206-
await SendWithLoggerAsync(httpContext, HttpStatusCode.NotFound, token);
176+
await SendWithLoggerAsync(context, HttpStatusCode.NotFound, ct);
207177
}
208178

209-
if (!httpContext.Request.HttpMethod.Equals("GET", StringComparison.OrdinalIgnoreCase))
179+
if (!context.Request.HttpMethod.Equals("GET", StringComparison.OrdinalIgnoreCase))
210180
{
211-
await SendWithLoggerAsync(httpContext, HttpStatusCode.MethodNotAllowed, token);
181+
await SendWithLoggerAsync(context, HttpStatusCode.MethodNotAllowed, ct);
212182
return false;
213183
}
214184

215-
if (!ValidateApiAccessToken(httpContext))
185+
if (!request.IsWebSocketRequest)
216186
{
217-
_logger.LogValidateAccessTokenFailed(identifier, remote);
218-
await SendWithLoggerAsync(httpContext, HttpStatusCode.Unauthorized, token);
187+
await SendWithLoggerAsync(context, HttpStatusCode.BadRequest, ct);
219188
return false;
220189
}
221190

222-
if (!request.IsWebSocketRequest)
191+
if (!ValidateAccessToken(context))
223192
{
224-
await SendWithLoggerAsync(httpContext, HttpStatusCode.BadRequest, token);
193+
_logger.LogValidateAccessTokenFailed(identifier, remote);
194+
await SendWithLoggerAsync(context, HttpStatusCode.Unauthorized, ct);
225195
return false;
226196
}
227197

228198
return true;
229199
}
230200

231-
private bool ValidateApiAccessToken(HttpListenerContext httpContext)
201+
private bool ValidateAccessToken(HttpListenerContext context)
232202
{
233-
if (_token == null) return true;
203+
if (string.IsNullOrEmpty(_token)) return true;
234204

235-
string? authorization = httpContext.Request.QueryString["access_token"];
236-
if (authorization == null) return false;
205+
string? authorization = context.Request.Headers["Authorization"];
206+
if (authorization != null)
207+
{
208+
if (!authorization.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) return false;
209+
return authorization.AsSpan(7..).Equals(_token);
210+
}
211+
212+
string? accessToken = context.Request.QueryString["access_token"];
213+
if (accessToken != null)
214+
{
215+
return accessToken.Equals(_token);
216+
}
237217

238-
return authorization == _token;
218+
return false;
239219
}
240220

241-
private async Task<ConnectionContext?> GetConnectionContextAsync(HttpListenerContext httpContext, CancellationToken token)
221+
private async Task<ConnectionContext?> GetConnectionContextAsync(HttpListenerContext context, CancellationToken ct)
242222
{
243-
var request = httpContext.Request;
244-
var identifier = request.RequestTraceIdentifier;
245-
var remote = request.RemoteEndPoint;
246-
247223
try
248224
{
249-
var wsContext = await httpContext.AcceptWebSocketAsync(null).WaitAsync(token);
250-
var cts = CancellationTokenSource.CreateLinkedTokenSource(token);
251-
var connection = new ConnectionContext(httpContext, wsContext, cts);
225+
var wsContext = await context.AcceptWebSocketAsync(null).WaitAsync(ct);
226+
var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
227+
var connection = new ConnectionContext(context, wsContext, cts);
252228
_connections.TryAdd(connection, null);
253229
return connection;
254230
}
255231
catch (OperationCanceledException) { throw; }
256-
catch (Exception e)
232+
catch (Exception)
257233
{
258-
_logger.LogUpgradeWebSocketException(identifier, remote, e);
259-
await SendWithLoggerAsync(httpContext, HttpStatusCode.InternalServerError, token);
234+
await SendWithLoggerAsync(context, HttpStatusCode.InternalServerError, ct);
235+
throw;
260236
}
261-
262-
return null;
263237
}
264238

265-
private async Task SendWithLoggerAsync(HttpListenerContext context, HttpStatusCode status, CancellationToken token)
239+
private async Task SendWithLoggerAsync(HttpListenerContext context, HttpStatusCode status, CancellationToken ct)
266240
{
267241
var request = context.Request;
268242
var identifier = request.RequestTraceIdentifier;
@@ -276,7 +250,7 @@ private async Task SendWithLoggerAsync(HttpListenerContext context, HttpStatusCo
276250
int code = (int)status;
277251

278252
response.StatusCode = code;
279-
await output.WriteAsync(Encoding.UTF8.GetBytes($"{code} {status}"), token);
253+
await output.WriteAsync(Encoding.UTF8.GetBytes($"{code} {status}"), ct);
280254
response.Close();
281255

282256
_logger.LogSend(identifier, remote, status);
@@ -287,6 +261,38 @@ private async Task SendWithLoggerAsync(HttpListenerContext context, HttpStatusCo
287261
}
288262
}
289263

264+
private async void HandleEventAsync(Memory<byte> payload)
265+
{
266+
if (_connections.IsEmpty) return;
267+
268+
_logger.LogSend(payload.Span);
269+
foreach (var connection in _connections.Keys)
270+
{
271+
var identifier = connection.HttpContext.Request.RequestTraceIdentifier;
272+
var remote = connection.HttpContext.Request.RemoteEndPoint;
273+
var ws = connection.WsContext.WebSocket;
274+
275+
try
276+
{
277+
await connection.SendSemaphoreSlim.WaitAsync(connection.Cts.Token);
278+
try
279+
{
280+
await ws.SendAsync(payload, WebSocketMessageType.Text, true, connection.Cts.Token);
281+
}
282+
finally
283+
{
284+
connection.SendSemaphoreSlim.Release();
285+
}
286+
}
287+
catch (Exception e)
288+
{
289+
_logger.LogSendException(identifier, remote, e);
290+
291+
await CloseConnectionAsync(connection, WebSocketCloseStatus.InternalServerError, connection.Cts.Token);
292+
}
293+
}
294+
}
295+
290296
private class ConnectionContext(HttpListenerContext httpContext, WebSocketContext wsContext, CancellationTokenSource cts)
291297
{
292298
public HttpListenerContext HttpContext { get; } = httpContext;
@@ -333,9 +339,6 @@ public static void LogSend(this ILogger<MilkyWebSocketEventService> logger, Span
333339
[LoggerMessage(LogLevel.Error, "{identifier} {remote} <!!> Handle http context failed")]
334340
public static partial void LogHandleHttpContextException(this ILogger<MilkyWebSocketEventService> logger, Guid identifier, IPEndPoint remote, Exception e);
335341

336-
[LoggerMessage(LogLevel.Error, "{identifier} {remote} <!!> Upgrade websocket failed")]
337-
public static partial void LogUpgradeWebSocketException(this ILogger<MilkyWebSocketEventService> logger, Guid identifier, IPEndPoint remote, Exception e);
338-
339342
[LoggerMessage(LogLevel.Error, "{identifier} {remote} <!!> Validate access token failed")]
340343
public static partial void LogValidateAccessTokenFailed(this ILogger<MilkyWebSocketEventService> logger, Guid identifier, IPEndPoint remote);
341344

0 commit comments

Comments
 (0)