Skip to content

Commit 0846c38

Browse files
committed
fix: address review findings -- security, correctness, and code quality
- Sanitize connection string in constructor exceptions to avoid leaking credentials (strip userinfo, show only scheme/host/port/path) - Guard DeliveryDelay overflow: fail with clear ArgumentOutOfRangeException instead of unhelpful OverflowException when delay > Int32.MaxValue ms - Handle IPv6 bracket notation in ParseHostEndpoint ([::1]:port) - Extract duplicated publisher channel creation into CreatePublisherChannelAsync - Remove stale empty XML doc tags on PublishImplAsync and CreateConnectionAsync - Add PERF comments at ToArray() call sites and publish lock documenting allocation/serialization trade-offs (tracked in FoundatioFx/Foundatio#512)
1 parent e90c803 commit 0846c38

1 file changed

Lines changed: 65 additions & 39 deletions

File tree

src/Foundatio.RabbitMQ/Messaging/RabbitMQMessageBus.cs

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ public RabbitMQMessageBus(RabbitMQMessageBusOptions options) : base(options)
4040
ArgumentException.ThrowIfNullOrWhiteSpace(options?.ConnectionString, nameof(options.ConnectionString));
4141

4242
if (!Uri.TryCreate(options.ConnectionString, UriKind.Absolute, out var primaryUri))
43-
throw new ArgumentException($"ConnectionString is not a valid URI: {options.ConnectionString}");
43+
throw new ArgumentException($"ConnectionString is not a valid URI: {SanitizeUri(options.ConnectionString)}");
4444

4545
if (!primaryUri.Scheme.Equals("amqp", StringComparison.OrdinalIgnoreCase) &&
4646
!primaryUri.Scheme.Equals("amqps", StringComparison.OrdinalIgnoreCase))
47-
throw new ArgumentException($"ConnectionString must use amqp:// or amqps:// scheme: {options.ConnectionString}");
47+
throw new ArgumentException($"ConnectionString must use amqp:// or amqps:// scheme: {SanitizeUri(primaryUri)}");
4848

4949
_isQuorumQueue = options.Arguments is not null && options.Arguments.TryGetValue("x-queue-type", out object? queueType) && queueType is string type && String.Equals(type, "quorum", StringComparison.OrdinalIgnoreCase);
5050

@@ -357,6 +357,7 @@ private async Task RepublishMessageWithIncrementedDeliveryCountAsync(BasicDelive
357357
try
358358
{
359359
await EnsureTopicCreatedAsync(envelope.CancellationToken).AnyContext();
360+
// PERF: ToArray() allocates; PublishMessageAsync takes byte[] until Foundatio base supports ReadOnlyMemory<byte>
360361
await PublishMessageAsync(envelope.Exchange, envelope.RoutingKey, envelope.Body.ToArray(), properties, envelope.CancellationToken).AnyContext();
361362
await subscriberChannel.BasicAckAsync(envelope.DeliveryTag, false).AnyContext();
362363

@@ -424,6 +425,8 @@ await _publisherReady.WaitAsync(cancellationToken)
424425
}
425426
}
426427

428+
// PERF: Single lock serializes all publishes; with publisher confirms this limits throughput to 1 RTT.
429+
// Consider channel pooling or batch publishing for high-throughput scenarios.
427430
using (await _lock.LockAsync(cancellationToken).AnyContext())
428431
{
429432
if (_publisherChannel is not { IsOpen: true } channel)
@@ -442,6 +445,7 @@ await _publisherReady.WaitAsync(cancellationToken)
442445

443446
protected virtual IMessage ConvertToMessage(BasicDeliverEventArgs envelope)
444447
{
448+
// PERF: ToArray() allocates a copy; Message ctor requires byte[] until Foundatio supports ReadOnlyMemory<byte>
445449
var message = new Message(envelope.Body.ToArray(), DeserializeMessageBody)
446450
{
447451
Type = envelope.BasicProperties.Type,
@@ -483,17 +487,7 @@ protected override async Task EnsureTopicCreatedAsync(CancellationToken cancella
483487
_isPublisherBlocked = false;
484488
_publisherBlockedReason = null;
485489

486-
if (_options.PublisherConfirmsEnabled)
487-
{
488-
var channelOptions = new CreateChannelOptions(
489-
publisherConfirmationsEnabled: true,
490-
publisherConfirmationTrackingEnabled: true);
491-
_publisherChannel = await _publisherConnection.CreateChannelAsync(channelOptions, cancellationToken).AnyContext();
492-
}
493-
else
494-
{
495-
_publisherChannel = await _publisherConnection.CreateChannelAsync(cancellationToken: cancellationToken).AnyContext();
496-
}
490+
_publisherChannel = await CreatePublisherChannelAsync(cancellationToken).AnyContext();
497491

498492
// We first attempt to create "x-delayed-type". For this the rabbitmq_delayed_message_exchange plugin should be installed.
499493
// However, if the plugin is not installed this will throw an exception. In that case
@@ -521,17 +515,7 @@ protected override async Task EnsureTopicCreatedAsync(CancellationToken cancella
521515
_isPublisherBlocked = false;
522516
_publisherBlockedReason = null;
523517

524-
if (_options.PublisherConfirmsEnabled)
525-
{
526-
var channelOptions = new CreateChannelOptions(
527-
publisherConfirmationsEnabled: true,
528-
publisherConfirmationTrackingEnabled: true);
529-
_publisherChannel = await _publisherConnection.CreateChannelAsync(channelOptions, cancellationToken).AnyContext();
530-
}
531-
else
532-
{
533-
_publisherChannel = await _publisherConnection.CreateChannelAsync(cancellationToken: cancellationToken).AnyContext();
534-
}
518+
_publisherChannel = await CreatePublisherChannelAsync(cancellationToken).AnyContext();
535519
await CreateRegularExchangeAsync(_publisherChannel).AnyContext();
536520
}
537521

@@ -599,13 +583,6 @@ private Task OnPublisherConnectionOnRecoverySucceededAsync(object sender, AsyncE
599583
return Task.CompletedTask;
600584
}
601585

602-
/// <summary>
603-
/// Publish the message
604-
/// </summary>
605-
/// <param name="messageType"></param>
606-
/// <param name="message"></param>
607-
/// <param name="options">Message options</param>
608-
/// <param name="cancellationToken"></param>
609586
protected override async Task PublishImplAsync(string messageType, object message, MessageOptions options, CancellationToken cancellationToken)
610587
{
611588
byte[] data = SerializeMessageBody(messageType, message);
@@ -648,7 +625,10 @@ protected override async Task PublishImplAsync(string messageType, object messag
648625
// data back as signed (using BinaryReader#ReadInt64). You will see the value to be negative
649626
// and the data will be delivered immediately.
650627
basicProperties.Headers ??= new Dictionary<string, object?>();
651-
basicProperties.Headers["x-delay"] = Convert.ToInt32(options.DeliveryDelay.Value.TotalMilliseconds);
628+
double delayMs = options.DeliveryDelay.Value.TotalMilliseconds;
629+
if (delayMs > Int32.MaxValue)
630+
throw new ArgumentOutOfRangeException(nameof(options), $"DeliveryDelay ({options.DeliveryDelay.Value}) exceeds the maximum supported by RabbitMQ delayed exchange plugin ({Int32.MaxValue}ms).");
631+
basicProperties.Headers["x-delay"] = (int)delayMs;
652632
_logger.LogTrace("Schedule delayed message: {MessageType} ({Delay}ms)", messageType, options.DeliveryDelay.Value.TotalMilliseconds);
653633
}
654634
else
@@ -660,16 +640,27 @@ protected override async Task PublishImplAsync(string messageType, object messag
660640
_logger.LogDebug("Done publishing type {MessageType} {MessageId}", messageType, basicProperties.MessageId);
661641
}
662642

663-
/// <summary>
664-
/// Connect to a broker - RabbitMQ
665-
/// </summary>
666-
/// <returns></returns>
667643
private Task<IConnection> CreateConnectionAsync()
668644
{
669-
// Use multiple endpoints for failover support
670645
return _factory.CreateConnectionAsync(_endpoints);
671646
}
672647

648+
private async Task<IChannel> CreatePublisherChannelAsync(CancellationToken cancellationToken)
649+
{
650+
if (_publisherConnection is null)
651+
throw new MessageBusException("Publisher connection must be initialized before creating a channel.");
652+
653+
if (_options.PublisherConfirmsEnabled)
654+
{
655+
var channelOptions = new CreateChannelOptions(
656+
publisherConfirmationsEnabled: true,
657+
publisherConfirmationTrackingEnabled: true);
658+
return await _publisherConnection.CreateChannelAsync(channelOptions, cancellationToken).AnyContext();
659+
}
660+
661+
return await _publisherConnection.CreateChannelAsync(cancellationToken: cancellationToken).AnyContext();
662+
}
663+
673664
private void DetectServerVersion(IConnection connection)
674665
{
675666
if (_serverVersion is not null)
@@ -830,6 +821,22 @@ private async Task CloseSubscriberConnectionAsync(CancellationToken cancellation
830821
}
831822
}
832823

824+
private static string SanitizeUri(Uri uri)
825+
{
826+
if (String.IsNullOrEmpty(uri.UserInfo))
827+
return uri.ToString();
828+
829+
string portSuffix = uri.IsDefaultPort ? "" : $":{uri.Port}";
830+
return $"{uri.Scheme}://***@{uri.Host}{portSuffix}{uri.AbsolutePath}";
831+
}
832+
833+
private static string SanitizeUri(string connectionString)
834+
{
835+
return Uri.TryCreate(connectionString, UriKind.Absolute, out var uri)
836+
? SanitizeUri(uri)
837+
: "***";
838+
}
839+
833840
/// <summary>
834841
/// Parses a host string in format "hostname" or "hostname:port" into an AmqpTcpEndpoint.
835842
/// </summary>
@@ -839,13 +846,32 @@ private async Task CloseSubscriberConnectionAsync(CancellationToken cancellation
839846
return null;
840847

841848
string trimmed = host.Trim();
849+
850+
// Handle IPv6 bracket notation: [::1] or [::1]:5672
851+
if (trimmed.StartsWith('['))
852+
{
853+
int closeBracket = trimmed.IndexOf(']');
854+
if (closeBracket < 0)
855+
return new AmqpTcpEndpoint(trimmed, defaultPort);
856+
857+
string ipv6Host = trimmed[1..closeBracket];
858+
if (closeBracket + 1 < trimmed.Length && trimmed[closeBracket + 1] == ':')
859+
{
860+
return Int32.TryParse(trimmed[(closeBracket + 2)..], out int port)
861+
? new AmqpTcpEndpoint(ipv6Host, port)
862+
: new AmqpTcpEndpoint(ipv6Host, defaultPort);
863+
}
864+
865+
return new AmqpTcpEndpoint(ipv6Host, defaultPort);
866+
}
867+
842868
int colonIndex = trimmed.LastIndexOf(':');
843869
if (colonIndex < 0)
844870
return new AmqpTcpEndpoint(trimmed, defaultPort);
845871

846872
string hostname = trimmed[..colonIndex];
847-
return Int32.TryParse(trimmed[(colonIndex + 1)..], out int port)
848-
? new AmqpTcpEndpoint(hostname, port)
873+
return Int32.TryParse(trimmed[(colonIndex + 1)..], out int parsedPort)
874+
? new AmqpTcpEndpoint(hostname, parsedPort)
849875
: new AmqpTcpEndpoint(trimmed, defaultPort);
850876
}
851877

0 commit comments

Comments
 (0)