Skip to content

Commit befd348

Browse files
committed
feat(protocol): add ProtocolNegotiatingStateMachine with ALPN and preface detection
1 parent 45ba63e commit befd348

3 files changed

Lines changed: 353 additions & 0 deletions

File tree

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
using System.Net;
2+
using System.Net.Security;
3+
using System.Security.Authentication;
4+
using Akka.Actor;
5+
using Akka.Event;
6+
using Servus.Akka.Transport;
7+
using TurboHTTP.Protocol;
8+
using TurboHTTP.Protocol.Syntax.Http11.Server;
9+
using TurboHTTP.Protocol.Syntax.Http2.Server;
10+
using TurboHTTP.Server;
11+
using TurboHTTP.Streams.Stages.Server;
12+
13+
namespace TurboHTTP.Tests.Protocol;
14+
15+
public sealed class ProtocolNegotiatingStateMachineSpec
16+
{
17+
private sealed class FakeServerOps : IServerStageOperations
18+
{
19+
public List<HttpRequestMessage> EmittedRequests { get; } = [];
20+
public List<ITransportOutbound> EmittedOutbound { get; } = [];
21+
public List<string> ScheduledTimers { get; } = [];
22+
public ILoggingAdapter Log { get; } = NoLogger.Instance;
23+
public IActorRef StageActor { get; set; } = ActorRefs.Nobody;
24+
25+
public void OnRequest(HttpRequestMessage request) => EmittedRequests.Add(request);
26+
public void OnOutbound(ITransportOutbound item) => EmittedOutbound.Add(item);
27+
public void OnScheduleTimer(string name, TimeSpan delay) => ScheduledTimers.Add(name);
28+
public void OnCancelTimer(string name) { }
29+
}
30+
31+
private static TransportConnected MakeConnected(SslApplicationProtocol? alpn = null)
32+
{
33+
SecurityInfo? security = alpn is not null
34+
? new SecurityInfo(SslProtocols.Tls13, alpn.Value)
35+
: null;
36+
37+
var info = new ConnectionInfo(
38+
new IPEndPoint(IPAddress.Loopback, 443),
39+
new IPEndPoint(IPAddress.Loopback, 50000),
40+
alpn is not null ? TransportProtocol.Tls : TransportProtocol.Tcp,
41+
security);
42+
43+
return new TransportConnected(info);
44+
}
45+
46+
private static TransportData MakeData(byte[] data)
47+
{
48+
var buffer = TransportBuffer.Rent(data.Length);
49+
data.CopyTo(buffer.FullMemory.Span);
50+
buffer.Length = data.Length;
51+
return new TransportData(buffer);
52+
}
53+
54+
// Task 2: ALPN Detection Tests
55+
56+
[Fact(Timeout = 5000)]
57+
public void DecodeClientData_should_select_http2_for_alpn_h2()
58+
{
59+
var ops = new FakeServerOps();
60+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
61+
62+
sm.DecodeClientData(MakeConnected(SslApplicationProtocol.Http2));
63+
64+
Assert.True(sm.CanAcceptResponse || !sm.ShouldComplete);
65+
Assert.Contains("keep-alive-timeout", ops.ScheduledTimers);
66+
}
67+
68+
[Fact(Timeout = 5000)]
69+
public void DecodeClientData_should_select_http11_for_alpn_http11()
70+
{
71+
var ops = new FakeServerOps();
72+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
73+
74+
sm.DecodeClientData(MakeConnected(SslApplicationProtocol.Http11));
75+
76+
Assert.False(sm.CanAcceptResponse);
77+
Assert.False(sm.ShouldComplete);
78+
}
79+
80+
[Fact(Timeout = 5000)]
81+
public void DecodeClientData_should_select_http11_for_default_alpn()
82+
{
83+
var ops = new FakeServerOps();
84+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
85+
86+
sm.DecodeClientData(MakeConnected(default(SslApplicationProtocol)));
87+
88+
Assert.False(sm.CanAcceptResponse);
89+
Assert.False(sm.ShouldComplete);
90+
}
91+
92+
// Task 3: Preface Sniffing Tests
93+
94+
[Fact(Timeout = 5000)]
95+
public void DecodeClientData_should_select_http2_for_pri_preface()
96+
{
97+
var ops = new FakeServerOps();
98+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
99+
100+
sm.DecodeClientData(MakeConnected());
101+
102+
var preface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"u8.ToArray();
103+
sm.DecodeClientData(MakeData(preface));
104+
105+
Assert.Contains("keep-alive-timeout", ops.ScheduledTimers);
106+
}
107+
108+
[Fact(Timeout = 5000)]
109+
public void DecodeClientData_should_select_http11_for_get_request()
110+
{
111+
var ops = new FakeServerOps();
112+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
113+
114+
sm.DecodeClientData(MakeConnected());
115+
116+
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"u8.ToArray();
117+
sm.DecodeClientData(MakeData(request));
118+
119+
Assert.Single(ops.EmittedRequests);
120+
Assert.Equal("GET", ops.EmittedRequests[0].Method.Method);
121+
}
122+
123+
[Fact(Timeout = 5000)]
124+
public void DecodeClientData_should_select_http11_for_post_request()
125+
{
126+
var ops = new FakeServerOps();
127+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
128+
129+
sm.DecodeClientData(MakeConnected());
130+
131+
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"u8.ToArray();
132+
sm.DecodeClientData(MakeData(request));
133+
134+
Assert.Single(ops.EmittedRequests);
135+
Assert.Equal("POST", ops.EmittedRequests[0].Method.Method);
136+
}
137+
138+
[Fact(Timeout = 5000)]
139+
public void DecodeClientData_should_stay_sniffing_for_insufficient_data()
140+
{
141+
var ops = new FakeServerOps();
142+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
143+
144+
sm.DecodeClientData(MakeConnected());
145+
sm.DecodeClientData(MakeData("PR"u8.ToArray()));
146+
147+
Assert.False(sm.CanAcceptResponse);
148+
Assert.False(sm.ShouldComplete);
149+
Assert.Empty(ops.EmittedRequests);
150+
Assert.Empty(ops.ScheduledTimers);
151+
}
152+
153+
[Fact(Timeout = 5000)]
154+
public void Cleanup_should_dispose_buffered_data()
155+
{
156+
var ops = new FakeServerOps();
157+
var sm = new ProtocolNegotiatingStateMachine(new TurboServerOptions(), ops);
158+
159+
sm.DecodeClientData(MakeConnected());
160+
sm.Cleanup();
161+
162+
Assert.False(sm.ShouldComplete);
163+
}
164+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using TurboHTTP.Streams.Stages.Server;
2+
3+
namespace TurboHTTP.Protocol;
4+
5+
internal interface IProtocolSwitchCapable
6+
{
7+
void RequestProtocolSwitch(
8+
Func<IServerStageOperations, IServerStateMachine> newSmFactory);
9+
}
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
using System.Net.Security;
2+
using Akka.Actor;
3+
using Akka.Event;
4+
using Servus.Akka.Transport;
5+
using TurboHTTP.Protocol.Syntax.Http11.Server;
6+
using TurboHTTP.Protocol.Syntax.Http2.Server;
7+
using TurboHTTP.Server;
8+
using TurboHTTP.Streams.Stages.Server;
9+
10+
namespace TurboHTTP.Protocol;
11+
12+
internal sealed class ProtocolNegotiatingStateMachine : IServerStateMachine
13+
{
14+
private enum Phase { WaitingForConnect, Sniffing, Running }
15+
16+
private readonly TurboServerOptions _options;
17+
private readonly UpgradeAwareOps _wrappedOps;
18+
19+
private Phase _phase = Phase.WaitingForConnect;
20+
private IServerStateMachine? _inner;
21+
private readonly List<ITransportInbound> _buffered = [];
22+
23+
public bool CanAcceptResponse => _phase == Phase.Running && _inner!.CanAcceptResponse;
24+
public bool ShouldComplete => _phase == Phase.Running && _inner!.ShouldComplete;
25+
26+
public ProtocolNegotiatingStateMachine(TurboServerOptions options, IServerStageOperations ops)
27+
{
28+
_options = options;
29+
_wrappedOps = new UpgradeAwareOps(ops, this);
30+
}
31+
32+
public void PreStart()
33+
{
34+
if (_phase == Phase.Running)
35+
{
36+
_inner!.PreStart();
37+
}
38+
}
39+
40+
public void DecodeClientData(ITransportInbound data)
41+
{
42+
switch (_phase)
43+
{
44+
case Phase.WaitingForConnect:
45+
OnWaitingForConnect(data);
46+
break;
47+
case Phase.Sniffing:
48+
OnSniffing(data);
49+
break;
50+
case Phase.Running:
51+
_inner!.DecodeClientData(data);
52+
break;
53+
}
54+
}
55+
56+
public void OnResponse(HttpResponseMessage response) => _inner!.OnResponse(response);
57+
public void OnDownstreamFinished() => _inner?.OnDownstreamFinished();
58+
public void OnTimerFired(string name) => _inner?.OnTimerFired(name);
59+
public void OnBodyMessage(object msg) => _inner?.OnBodyMessage(msg);
60+
61+
public void Cleanup()
62+
{
63+
_inner?.Cleanup();
64+
DisposeBuffered();
65+
}
66+
67+
private void OnWaitingForConnect(ITransportInbound data)
68+
{
69+
if (data is not TransportConnected { Info.Security: var security })
70+
{
71+
return;
72+
}
73+
74+
if (security?.ApplicationProtocol == SslApplicationProtocol.Http2)
75+
{
76+
Activate(ops => new Http2ServerStateMachine(_options, ops));
77+
_inner!.DecodeClientData(data);
78+
return;
79+
}
80+
81+
if (security is not null)
82+
{
83+
Activate(ops => new Http11ServerStateMachine(_options, ops));
84+
_inner!.DecodeClientData(data);
85+
return;
86+
}
87+
88+
_buffered.Add(data);
89+
_phase = Phase.Sniffing;
90+
}
91+
92+
private void OnSniffing(ITransportInbound data)
93+
{
94+
_buffered.Add(data);
95+
96+
if (data is not TransportData { Buffer: var buffer })
97+
{
98+
return;
99+
}
100+
101+
var span = buffer.Memory.Span;
102+
if (span.Length < 4)
103+
{
104+
return;
105+
}
106+
107+
if (span[0] == 'P' && span[1] == 'R' && span[2] == 'I' && span[3] == ' ')
108+
{
109+
Activate(ops => new Http2ServerStateMachine(_options, ops));
110+
}
111+
else
112+
{
113+
Activate(ops => new Http11ServerStateMachine(_options, ops));
114+
}
115+
116+
ReplayBuffered();
117+
}
118+
119+
private void Activate(Func<IServerStageOperations, IServerStateMachine> factory)
120+
{
121+
_inner = factory(_wrappedOps);
122+
_phase = Phase.Running;
123+
_inner.PreStart();
124+
}
125+
126+
private void ReplayBuffered()
127+
{
128+
var buffered = _buffered.ToArray();
129+
_buffered.Clear();
130+
131+
foreach (var item in buffered)
132+
{
133+
_inner!.DecodeClientData(item);
134+
}
135+
}
136+
137+
private void DisposeBuffered()
138+
{
139+
foreach (var item in _buffered)
140+
{
141+
if (item is TransportData { Buffer: var buf })
142+
{
143+
buf.Dispose();
144+
}
145+
}
146+
147+
_buffered.Clear();
148+
}
149+
150+
internal void HandleUpgrade(Func<IServerStageOperations, IServerStateMachine> newSmFactory)
151+
{
152+
_inner?.Cleanup();
153+
_inner = newSmFactory(_wrappedOps);
154+
_inner.PreStart();
155+
}
156+
157+
private sealed class UpgradeAwareOps : IServerStageOperations, IProtocolSwitchCapable
158+
{
159+
private readonly IServerStageOperations _real;
160+
private readonly ProtocolNegotiatingStateMachine _parent;
161+
162+
public UpgradeAwareOps(IServerStageOperations real, ProtocolNegotiatingStateMachine parent)
163+
{
164+
_real = real;
165+
_parent = parent;
166+
}
167+
168+
public void OnRequest(HttpRequestMessage request) => _real.OnRequest(request);
169+
public void OnOutbound(ITransportOutbound item) => _real.OnOutbound(item);
170+
public void OnScheduleTimer(string name, TimeSpan delay) => _real.OnScheduleTimer(name, delay);
171+
public void OnCancelTimer(string name) => _real.OnCancelTimer(name);
172+
public ILoggingAdapter Log => _real.Log;
173+
public IActorRef StageActor => _real.StageActor;
174+
175+
public void RequestProtocolSwitch(Func<IServerStageOperations, IServerStateMachine> newSmFactory)
176+
{
177+
_parent.HandleUpgrade(newSmFactory);
178+
}
179+
}
180+
}

0 commit comments

Comments
 (0)