diff --git a/examples/Resolver.Athena.CliClient/ClassifyCommand.cs b/examples/Resolver.Athena.CliClient/ClassifyCommand.cs index 5928287..e8a8874 100644 --- a/examples/Resolver.Athena.CliClient/ClassifyCommand.cs +++ b/examples/Resolver.Athena.CliClient/ClassifyCommand.cs @@ -23,7 +23,12 @@ public static async Task DoClassifyCommand(ParseResult parseResult, Cancell CliUtilities.LoadDotEnv(parseResult); var svcs = new ServiceCollection() - .AddAthenaClient(CliUtilities.ConfigureAthenaClientFromEnv, CliUtilities.ConfigureOAuthTokenManagerFromEnv) + .AddAthenaClient(o => + { + CliUtilities.ConfigureAthenaClientFromEnv(o); + o.UnsafeAllowInsecure = parseResult.GetValue(CliUtilities.UnsafeAllowInsecure); + }, + CliUtilities.ConfigureOAuthTokenManagerFromEnv) .BuildServiceProvider(); var athenaClient = svcs.GetRequiredService(); diff --git a/examples/Resolver.Athena.CliClient/ClassifyDataflowCommand.cs b/examples/Resolver.Athena.CliClient/ClassifyDataflowCommand.cs index 0dbc8da..0d71b09 100644 --- a/examples/Resolver.Athena.CliClient/ClassifyDataflowCommand.cs +++ b/examples/Resolver.Athena.CliClient/ClassifyDataflowCommand.cs @@ -23,7 +23,12 @@ public static async Task DoClassifyDataflowCommand(ParseResult parseResult, CliUtilities.LoadDotEnv(parseResult); var svcs = new ServiceCollection() - .AddAthenaDataflowClient(CliUtilities.ConfigureAthenaClientFromEnv, CliUtilities.ConfigureOAuthTokenManagerFromEnv) + .AddAthenaDataflowClient(o => + { + CliUtilities.ConfigureAthenaClientFromEnv(o); + o.UnsafeAllowInsecure = parseResult.GetValue(CliUtilities.UnsafeAllowInsecure); + }, + CliUtilities.ConfigureOAuthTokenManagerFromEnv) .BuildServiceProvider(); var athenaClient = svcs.GetRequiredService(); diff --git a/examples/Resolver.Athena.CliClient/ClassifySingleCommand.cs b/examples/Resolver.Athena.CliClient/ClassifySingleCommand.cs index 4ada2d5..3c9a6e9 100644 --- a/examples/Resolver.Athena.CliClient/ClassifySingleCommand.cs +++ b/examples/Resolver.Athena.CliClient/ClassifySingleCommand.cs @@ -18,7 +18,12 @@ public static async Task DoClassifySingleCommand(ParseResult parseResult, C { CliUtilities.LoadDotEnv(parseResult); var svcs = new ServiceCollection() - .AddAthenaClient(CliUtilities.ConfigureAthenaClientFromEnv, CliUtilities.ConfigureOAuthTokenManagerFromEnv) + .AddAthenaClient(o => + { + CliUtilities.ConfigureAthenaClientFromEnv(o); + o.UnsafeAllowInsecure = parseResult.GetValue(CliUtilities.UnsafeAllowInsecure); + }, + CliUtilities.ConfigureOAuthTokenManagerFromEnv) .BuildServiceProvider(); var athenaClient = svcs.GetRequiredService(); diff --git a/examples/Resolver.Athena.CliClient/CliUtilities.cs b/examples/Resolver.Athena.CliClient/CliUtilities.cs index 8b23a61..cef2c57 100644 --- a/examples/Resolver.Athena.CliClient/CliUtilities.cs +++ b/examples/Resolver.Athena.CliClient/CliUtilities.cs @@ -33,6 +33,13 @@ public static partial class CliUtilities DefaultValueFactory = _ => 0, }; + public static readonly Option UnsafeAllowInsecure = new("--unsafe-insecure") + { + Description = "If set, allows insecure connections to the Athena endpoint. For development use only.", + DefaultValueFactory = _ => false, + Recursive = true, + }; + [GeneratedRegex("[^A-Za-z0-9_-]")] private static partial Regex CorrelationIdRegex(); diff --git a/examples/Resolver.Athena.CliClient/ListDeploymentsCommand.cs b/examples/Resolver.Athena.CliClient/ListDeploymentsCommand.cs index c7c3bbb..3edc368 100644 --- a/examples/Resolver.Athena.CliClient/ListDeploymentsCommand.cs +++ b/examples/Resolver.Athena.CliClient/ListDeploymentsCommand.cs @@ -17,7 +17,12 @@ public static async Task DoListDeploymentsCommand(ParseResult parseResult, CliUtilities.LoadDotEnv(parseResult); var svcs = new ServiceCollection() - .AddAthenaClient(CliUtilities.ConfigureAthenaClientFromEnv, CliUtilities.ConfigureOAuthTokenManagerFromEnv) + .AddAthenaClient(o => + { + CliUtilities.ConfigureAthenaClientFromEnv(o); + o.UnsafeAllowInsecure = parseResult.GetValue(CliUtilities.UnsafeAllowInsecure); + }, + CliUtilities.ConfigureOAuthTokenManagerFromEnv) .BuildServiceProvider(); var athenaClient = svcs.GetRequiredService(); diff --git a/examples/Resolver.Athena.CliClient/Program.cs b/examples/Resolver.Athena.CliClient/Program.cs index 792b4f2..e261be8 100644 --- a/examples/Resolver.Athena.CliClient/Program.cs +++ b/examples/Resolver.Athena.CliClient/Program.cs @@ -27,6 +27,7 @@ public static async Task Main(string[] args) // reuse the pre-defined static option so LoadDotEnv can access the same option rootCommand.Options.Add(CliUtilities.DotenvPathOption); + rootCommand.Options.Add(CliUtilities.UnsafeAllowInsecure); TokenTestCommand.RegisterCommand(rootCommand); ListDeploymentsCommand.RegisterCommand(rootCommand); diff --git a/src/Resolver.Athena.Client/ApiClient/AthenaApiClient.cs b/src/Resolver.Athena.Client/ApiClient/AthenaApiClient.cs index c6ba297..f1d213c 100644 --- a/src/Resolver.Athena.Client/ApiClient/AthenaApiClient.cs +++ b/src/Resolver.Athena.Client/ApiClient/AthenaApiClient.cs @@ -29,15 +29,20 @@ public AthenaApiClient(ITokenManager tokenManager, IOptions { var token = await tokenManager.GetTokenAsync(context.CancellationToken).ConfigureAwait(false); metadata.Add("Authorization", $"Bearer {token}"); - })) + })), + UnsafeUseInsecureChannelCallCredentials = options.Value.UnsafeAllowInsecure }; _client = clientFactory.Create(options.Value.Endpoint, channelOptions); diff --git a/src/Resolver.Athena.Client/ApiClient/AthenaApiClientConfiguration.cs b/src/Resolver.Athena.Client/ApiClient/AthenaApiClientConfiguration.cs index ad4d68e..2e9140d 100644 --- a/src/Resolver.Athena.Client/ApiClient/AthenaApiClientConfiguration.cs +++ b/src/Resolver.Athena.Client/ApiClient/AthenaApiClientConfiguration.cs @@ -24,4 +24,13 @@ public class AthenaApiClientConfiguration /// Indicates whether to send SHA1 hashes of images. /// public bool SendSha1Hash { get; set; } + + /// + /// Indicates whether to allow insecure connections. + /// + /// + /// This is obviously unsuitable for production use, and is intended for + /// local development and testing only. + /// + public bool UnsafeAllowInsecure { get; set; } } diff --git a/test/Resolver.Athena.Tests/Client/AthenaApiClientTests.cs b/test/Resolver.Athena.Tests/Client/AthenaApiClientTests.cs index ca14dad..fa23be5 100644 --- a/test/Resolver.Athena.Tests/Client/AthenaApiClientTests.cs +++ b/test/Resolver.Athena.Tests/Client/AthenaApiClientTests.cs @@ -14,6 +14,7 @@ namespace Resolver.Athena.Tests.Client; public class AthenaApiClientTests() { private readonly List _sentRequests = []; + private GrpcChannelOptions? _capturedOptions = null; [Fact] public async Task SingleSendAndReceiveAsync() @@ -161,6 +162,62 @@ public async Task MultipleImagesInMultipleRequestsSendAndReceiveAsync() } } + [Fact] + public void ChannelIsSecure_ByDefault() + { + var opts = new AthenaApiClientConfiguration + { + Endpoint = "https://mock-endpoint", + Affiliate = "test-affiliate", + SendMd5Hash = true, + SendSha1Hash = true, + }; + + CreateTestApiClient([], opts); + + Assert.NotNull(_capturedOptions); + Assert.NotNull(_capturedOptions.Credentials); + Assert.False(_capturedOptions.UnsafeUseInsecureChannelCallCredentials); + } + + [Fact] + public void ChannelIsSecure_WhenExplicitlySet() + { + var opts = new AthenaApiClientConfiguration + { + Endpoint = "https://mock-endpoint", + Affiliate = "test-affiliate", + SendMd5Hash = true, + SendSha1Hash = true, + UnsafeAllowInsecure = false + }; + + CreateTestApiClient([], opts); + + Assert.NotNull(_capturedOptions); + Assert.NotNull(_capturedOptions.Credentials); + Assert.False(_capturedOptions.UnsafeUseInsecureChannelCallCredentials); + } + + [Fact] + public void ChannelIsInsecure_WhenExplicitlySet() + { + var opts = new AthenaApiClientConfiguration + { + Endpoint = "https://mock-endpoint", + Affiliate = "test-affiliate", + SendMd5Hash = true, + SendSha1Hash = true, + UnsafeAllowInsecure = true + }; + + CreateTestApiClient([], opts); + + Assert.NotNull(_capturedOptions); + Assert.NotNull(_capturedOptions.Credentials); + Assert.True(_capturedOptions.UnsafeUseInsecureChannelCallCredentials); + } + private static async Task> CreateChannelWithDataAsync(IEnumerable requests) { var channel = Channel.CreateBounded(new BoundedChannelOptions(1024) @@ -179,6 +236,17 @@ private static async Task> CreateChannelWithDataAsync(I } private AthenaApiClient CreateTestApiClient(IEnumerable responses) + { + return CreateTestApiClient(responses, new AthenaApiClientConfiguration + { + Endpoint = "https://mock-endpoint", + Affiliate = "test-affiliate", + SendMd5Hash = true, + SendSha1Hash = true, + }); + } + + private AthenaApiClient CreateTestApiClient(IEnumerable responses, AthenaApiClientConfiguration config) { var fakeRequestStream = new Mock>(); fakeRequestStream.Setup(s => s.WriteAsync(It.IsAny(), It.IsAny())) @@ -198,19 +266,14 @@ private AthenaApiClient CreateTestApiClient(IEnumerable respon mockClient.Setup(c => c.Classify(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(duplex); - var opts = new OptionsWrapper(new AthenaApiClientConfiguration - { - Endpoint = "https://mock-endpoint", - Affiliate = "test-affiliate", - SendMd5Hash = true, - SendSha1Hash = true, - }); + var opts = new OptionsWrapper(config); var tokenManagerMock = new Mock(); tokenManagerMock.Setup(tm => tm.GetTokenAsync(It.IsAny())) .ReturnsAsync("mock-token"); var factory = new Mock(); factory.Setup(f => f.Create(It.IsAny(), It.IsAny())) + .Callback((string _, GrpcChannelOptions o) => _capturedOptions = o) .Returns(mockClient.Object); return new AthenaApiClient(tokenManagerMock.Object, opts, factory.Object);