Skip to content

Commit 6f24522

Browse files
authored
Fix RESP3 client handshakes on non-RESP3 servers (#3037)
* failing tests for #2783 - `HELLO 3` with `3` response: pass - `HELLO 3` with `-ERR` response: fail - `HELLO 3` with `2` response: fail * more test permutations * don't run unnecessary RESP2-client permutations * fix RESP3 handshakes * words * words
1 parent f9af64f commit 6f24522

File tree

6 files changed

+150
-12
lines changed

6 files changed

+150
-12
lines changed

src/StackExchange.Redis/ResultProcessor.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2894,10 +2894,18 @@ public override bool SetResult(PhysicalConnection connection, Message message, i
28942894

28952895
if (connection.Protocol is null)
28962896
{
2897-
// if we didn't get a valid response from HELLO, then we have to assume RESP2 at some point
2897+
// If we didn't get a valid response from HELLO, then we have to assume RESP2 at some point.
2898+
// We need the protocol assigned before OnFullyEstablished so that the
2899+
// protocol is reliably known *before* we do next-steps.
28982900
connection.SetProtocol(RedisProtocol.Resp2);
28992901
}
29002902

2903+
if (final & establishConnection)
2904+
{
2905+
// This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection.
2906+
connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}");
2907+
}
2908+
29012909
return final;
29022910
}
29032911

@@ -2939,11 +2947,6 @@ protected override bool SetResultCore(PhysicalConnection connection, Message mes
29392947
}
29402948
if (happy)
29412949
{
2942-
if (establishConnection)
2943-
{
2944-
// This is what ultimately brings us to complete a connection, by advancing the state forward from a successful tracer after connection.
2945-
connection.BridgeCouldBeNull?.OnFullyEstablished(connection, $"From command: {message.Command}");
2946-
}
29472950
SetResult(message, happy);
29482951
return true;
29492952
}

src/StackExchange.Redis/ServerEndPoint.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,14 +696,20 @@ internal void OnFullyEstablished(PhysicalConnection connection, string source)
696696
// Clear the unselectable flag ASAP since we are open for business
697697
ClearUnselectable(UnselectableFlags.DidNotRespond);
698698

699-
bool isResp3 = KnowOrAssumeResp3();
699+
// is *this specific* connection using RESP3? (without reference to config preferences)
700+
bool isResp3 = connection?.Protocol is >= RedisProtocol.Resp3;
700701
if (bridge == subscription || isResp3)
701702
{
702703
// Note: this MUST be fire and forget, because we might be in the middle of a Sync processing
703704
// TracerProcessor which is executing this line inside a SetResultCore().
704705
// Since we're issuing commands inside a SetResult path in a message, we'd create a deadlock by waiting.
705706
Multiplexer.EnsureSubscriptions(CommandFlags.FireAndForget);
706707
}
708+
else if (SupportsSubscriptions && Multiplexer.RawConfig.Protocol > RedisProtocol.Resp2)
709+
{
710+
// interactive, and we wanted RESP3+, but we didn't get it; spin up pub/sub
711+
Activate(ConnectionType.Subscription, null);
712+
}
707713
if (IsConnected && (IsSubscriberConnected || !SupportsSubscriptions || isResp3))
708714
{
709715
// Only connect on the second leg - we can accomplish this by checking both

tests/StackExchange.Redis.Tests/InProcessTestServer.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace StackExchange.Redis.Tests;
1717
public class InProcessTestServer : MemoryCacheRedisServer
1818
{
1919
private readonly ITestOutputHelper? _log;
20-
public InProcessTestServer(ITestOutputHelper? log = null)
20+
public InProcessTestServer(ITestOutputHelper? log = null, EndPoint? endpoint = null)
21+
: base(endpoint)
2122
{
2223
RedisVersion = RedisFeatures.v6_0_0; // for client to expect RESP3
2324
_log = log;
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Net;
5+
using System.Threading.Tasks;
6+
using StackExchange.Redis.Server;
7+
using Xunit;
8+
9+
namespace StackExchange.Redis.Tests;
10+
11+
public class Resp3HandshakeTests(ITestOutputHelper log)
12+
{
13+
public enum ServerResponse
14+
{
15+
Resp3, // up-level server style
16+
Resp2, // DMC hybrid style, i.e. we know about it, but: "no, you'll take RESP2"
17+
UnknownCommand, // down-level server style
18+
}
19+
20+
[Flags]
21+
public enum HandshakeFlags
22+
{
23+
None = 0,
24+
Authenticated = 1 << 0,
25+
TieBreaker = 1 << 1,
26+
ConfigChannel = 1 << 2,
27+
UsePubSub = 1 << 3,
28+
UseDatabase = 1 << 4,
29+
}
30+
31+
private static readonly int HandshakeFlagsCount = Enum.GetValues(typeof(HandshakeFlags)).Length - 1;
32+
public static IEnumerable<object[]> GetHandshakeParameters()
33+
{
34+
// all client protocols, all server-response modes; all flag permutations
35+
var clients = (RedisProtocol[])Enum.GetValues(typeof(RedisProtocol));
36+
var servers = (ServerResponse[])Enum.GetValues(typeof(ServerResponse));
37+
foreach (var client in clients)
38+
{
39+
foreach (var server in servers)
40+
{
41+
if (client is RedisProtocol.Resp2 & server is not ServerResponse.Resp2)
42+
{
43+
// we don't issue HELLO for this, nothing to test
44+
}
45+
else
46+
{
47+
int count = 1 << HandshakeFlagsCount;
48+
for (int i = 0; i < count; i++)
49+
{
50+
yield return [client, server, (HandshakeFlags)i];
51+
}
52+
}
53+
}
54+
}
55+
}
56+
57+
[Theory]
58+
[MemberData(nameof(GetHandshakeParameters))]
59+
public async Task Handshake(RedisProtocol client, ServerResponse server, HandshakeFlags flags)
60+
{
61+
using var serverObj = new HandshakeServer(server, log);
62+
serverObj.Password = (flags & HandshakeFlags.Authenticated) == 0 ? null : "mypassword";
63+
var config = serverObj.GetClientConfig();
64+
config.Protocol = client;
65+
config.TieBreaker = (flags & HandshakeFlags.TieBreaker) == 0 ? "" : "tiebreaker_key";
66+
config.ConfigurationChannel = (flags & HandshakeFlags.ConfigChannel) == 0 ? "" : "broadcast_channel";
67+
68+
using var clientObj = await ConnectionMultiplexer.ConnectAsync(config);
69+
70+
var sub = clientObj.GetSubscriber();
71+
var db = clientObj.GetDatabase();
72+
ConcurrentBag<string> received = [];
73+
RedisChannel channel = RedisChannel.Literal("mychannel");
74+
RedisKey key = "mykey";
75+
bool useDatabase = (flags & HandshakeFlags.UseDatabase) != 0;
76+
bool usePubSub = (flags & HandshakeFlags.UsePubSub) != 0;
77+
78+
if (usePubSub)
79+
{
80+
await sub.SubscribeAsync(channel, (x, y) => received.Add(y!));
81+
}
82+
if (useDatabase)
83+
{
84+
await db.StringSetAsync(key, "myvalue");
85+
}
86+
if (usePubSub)
87+
{
88+
await sub.PublishAsync(channel, "msg payload");
89+
for (int i = 0; i < 5 && received.IsEmpty; i++)
90+
{
91+
await Task.Delay(10, TestContext.Current.CancellationToken);
92+
await sub.PingAsync();
93+
}
94+
Assert.Equal("msg payload", Assert.Single(received));
95+
}
96+
97+
if (useDatabase)
98+
{
99+
Assert.Equal("myvalue", await db.StringGetAsync(key));
100+
}
101+
}
102+
103+
private static readonly EndPoint EP = new DnsEndPoint("home", 8000);
104+
private sealed class HandshakeServer(ServerResponse response, ITestOutputHelper log)
105+
: InProcessTestServer(log, EP)
106+
{
107+
protected override RedisProtocol MaxProtocol => response switch
108+
{
109+
ServerResponse.Resp3 => RedisProtocol.Resp3,
110+
_ => RedisProtocol.Resp2,
111+
};
112+
113+
protected override TypedRedisValue Hello(RedisClient client, in RedisRequest request)
114+
=> response is ServerResponse.UnknownCommand
115+
? request.CommandNotFound()
116+
: base.Hello(client, in request);
117+
}
118+
}

toys/StackExchange.Redis.Server/RedisClient.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ internal bool ShouldSkipResponse()
118118
public int Id { get; internal set; }
119119
public bool IsAuthenticated { get; internal set; }
120120
public RedisProtocol Protocol { get; internal set; } = RedisProtocol.Resp2;
121-
public long ProtocolVersion => Protocol is RedisProtocol.Resp2 ? 2 : 3;
122121

123122
private readonly CancellationTokenSource _lifetime = CancellationTokenSource.CreateLinkedTokenSource(node.Server.Lifetime);
124123

toys/StackExchange.Redis.Server/RedisServer.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Linq;
77
using System.Net;
8+
using System.Runtime.InteropServices;
89
using System.Text;
910
using System.Threading;
1011
using RESPite;
@@ -158,7 +159,7 @@ protected override void AppendStats(StringBuilder sb)
158159
public override TypedRedisValue Execute(RedisClient client, in RedisRequest request)
159160
{
160161
var pw = Password;
161-
if (pw.Length != 0 & !client.IsAuthenticated)
162+
if (!string.IsNullOrEmpty(pw) & !client.IsAuthenticated)
162163
{
163164
if (!IsAuthCommand(request.KnownCommand))
164165
return TypedRedisValue.Error("NOAUTH Authentication required.");
@@ -190,6 +191,8 @@ protected virtual TypedRedisValue Auth(RedisClient client, in RedisRequest reque
190191
return TypedRedisValue.Error("ERR invalid password");
191192
}
192193

194+
protected virtual RedisProtocol MaxProtocol => RedisProtocol.Resp3;
195+
193196
[RedisCommand(-1)]
194197
protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest request)
195198
{
@@ -204,12 +207,14 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
204207
case 2:
205208
protocol = RedisProtocol.Resp2;
206209
break;
207-
case 3: // this client does not currently support RESP3
210+
case 3:
208211
protocol = RedisProtocol.Resp3;
209212
break;
210213
default:
211214
return TypedRedisValue.Error("NOPROTO unsupported protocol version");
212215
}
216+
protocol = (RedisProtocol)Math.Min((int)protocol, (int)MaxProtocol);
217+
213218
static TypedRedisValue ArgFail(in RespReader reader) => TypedRedisValue.Error($"ERR Syntax error in HELLO option '{reader.ReadString()}'\"");
214219

215220
for (int i = 2; i < request.Count; i++)
@@ -246,6 +251,12 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
246251
}
247252

248253
// all good, update client
254+
long proto32 = protocol switch
255+
{
256+
>= RedisProtocol.Resp3 => 3,
257+
>= RedisProtocol.Resp2 => 2,
258+
_ => throw new InvalidOperationException($"Unexpected protocol: {protocol}"),
259+
};
249260
client.Protocol = protocol;
250261
client.IsAuthenticated = isAuthed;
251262
client.Name = name;
@@ -256,7 +267,7 @@ protected virtual TypedRedisValue Hello(RedisClient client, in RedisRequest requ
256267
span[2] = TypedRedisValue.BulkString("version");
257268
span[3] = TypedRedisValue.BulkString(VersionString);
258269
span[4] = TypedRedisValue.BulkString("proto");
259-
span[5] = TypedRedisValue.Integer(client.ProtocolVersion);
270+
span[5] = TypedRedisValue.Integer(proto32);
260271
span[6] = TypedRedisValue.BulkString("id");
261272
span[7] = TypedRedisValue.Integer(client.Id);
262273
span[8] = TypedRedisValue.BulkString("mode");

0 commit comments

Comments
 (0)