Skip to content

Commit 3bf4b3d

Browse files
committed
Select the fastest (Kdc/Kpasswd) instance instead of at random
The implementation that chooses at random is fine as long as all kdcs advertised are working. But in case some of them are not reachable then choosing at random will randomly fail. The new implementation pings all the kdcs and selects the server that replies first.
1 parent 77e10aa commit 3bf4b3d

2 files changed

Lines changed: 80 additions & 13 deletions

File tree

Kerberos.NET/Client/Transport/KerberosTransportBase.cs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System;
77
using System.Collections.Generic;
88
using System.Linq;
9+
using System.Net.NetworkInformation;
910
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Kerberos.NET.Asn1;
@@ -19,15 +20,15 @@ namespace Kerberos.NET.Transport
1920
{
2021
public abstract class KerberosTransportBase : IKerberosTransport2, IDisposable
2122
{
22-
private static readonly Random Random = new Random();
23-
2423
protected KerberosTransportBase(ILoggerFactory logger)
2524
{
2625
this.ClientRealmService = new ClientDomainService(logger);
2726
}
2827

2928
private bool disposedValue;
3029

30+
private DnsRecord fastest;
31+
3132
public virtual bool TransportFailed { get; set; }
3233

3334
public virtual KerberosTransportException LastError { get; set; }
@@ -165,34 +166,58 @@ public void Dispose()
165166
protected virtual async Task<DnsRecord> LocatePreferredKdc(string domain, string servicePrefix)
166167
{
167168
var results = await this.LocateKdc(domain, servicePrefix);
168-
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
169+
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
169170
}
170171

171172
protected virtual async Task<DnsRecord> LocatePreferredKpasswd(string domain, string servicePrefix)
172173
{
173174
var results = await this.LocateKpasswd(domain, servicePrefix);
174-
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
175+
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
175176
}
176177

177-
protected virtual DnsRecord SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
178+
protected virtual async Task<DnsRecord> SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
178179
{
179-
results = results.Where(r => r.Name.StartsWith(servicePrefix));
180+
if (results.Contains(fastest, DnsRecordComparer.Instance))
181+
{
182+
return fastest;
183+
}
180184

181-
var rand = Random.Next(0, results?.Count() ?? 0);
185+
fastest = await results.Where(r => r.Name.StartsWith(servicePrefix)).GetFastestAsync(PingAsync);
186+
return fastest ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
187+
}
182188

183-
var srv = results?.ElementAtOrDefault(rand);
189+
private async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
190+
{
191+
using var ping = new Ping();
192+
cancellationToken.Register(() => ping.SendAsyncCancel());
193+
var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(ConnectTimeout.TotalMilliseconds));
194+
return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
195+
}
184196

185-
if (srv == null)
197+
private class DnsRecordComparer : IEqualityComparer<DnsRecord>
198+
{
199+
public static readonly DnsRecordComparer Instance = new();
200+
201+
private DnsRecordComparer()
186202
{
187-
throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
188203
}
189204

190-
if (srv.Port <= 0)
205+
public bool Equals(DnsRecord x, DnsRecord y)
191206
{
192-
srv.Port = defaultPort;
207+
if (ReferenceEquals(x, y)) return true;
208+
if (x is null) return false;
209+
if (y is null) return false;
210+
if (x.GetType() != y.GetType()) return false;
211+
return x.Target == y.Target && x.Port == y.Port;
193212
}
194213

195-
return srv;
214+
public int GetHashCode(DnsRecord obj)
215+
{
216+
unchecked
217+
{
218+
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
219+
}
220+
}
196221
}
197222
}
198223
}

Kerberos.NET/TaskExtensions.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// -----------------------------------------------------------------------
2+
// Licensed to The .NET Foundation under one or more agreements.
3+
// The .NET Foundation licenses this file to you under the MIT license.
4+
// -----------------------------------------------------------------------
5+
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
12+
internal static class TaskExtensions
13+
{
14+
public static async Task<TResult> GetFastestAsync<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, CancellationToken, Task<TResult>> task, CancellationToken cancellationToken = default)
15+
{
16+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
17+
var tasks = new HashSet<Task<TResult>>(source.Select(e => task(e, cts.Token)));
18+
if (tasks.Count == 0)
19+
{
20+
return default;
21+
}
22+
23+
var exceptions = new List<Exception>();
24+
do
25+
{
26+
var completedTask = await Task.WhenAny(tasks);
27+
if (completedTask.Status == TaskStatus.RanToCompletion)
28+
{
29+
cts.Cancel();
30+
return completedTask.Result;
31+
}
32+
33+
if (completedTask.Exception != null)
34+
{
35+
exceptions.AddRange(completedTask.Exception.InnerExceptions);
36+
}
37+
tasks.Remove(completedTask);
38+
} while (tasks.Count > 0);
39+
40+
throw new AggregateException(exceptions);
41+
}
42+
}

0 commit comments

Comments
 (0)