-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAthenaApiClient.cs
More file actions
132 lines (115 loc) · 5.1 KB
/
AthenaApiClient.cs
File metadata and controls
132 lines (115 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
using System.Threading.Channels;
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.Options;
using Resolver.Athena.Client.ApiClient.Interfaces;
using Resolver.Athena.Grpc;
namespace Resolver.Athena.Client.ApiClient;
/// <summary>
/// Default implementation of <see cref="IAthenaApiClient"/>.
/// </summary>
/// <remarks>
/// Every outgoing call carries an <c>Authorization: Bearer</c> header whose token is obtained from the
/// supplied <see cref="ITokenManager"/>. The bidirectional classify stream may be opened only once per instance.
/// </remarks>
public sealed class AthenaApiClient : IAthenaApiClient
{
    private readonly ClassifierService.ClassifierServiceClient _client;

    // Background pumps of the single classify stream; both become non-null once ClassifyAsync has opened it.
    private Task? _senderTask;
    private Task? _receiverTask;

    /// <summary>
    /// Initializes a new instance of the <see cref="AthenaApiClient"/> class.
    /// </summary>
    /// <param name="tokenManager">The token manager.</param>
    /// <param name="options">The client configuration options.</param>
    /// <param name="clientFactory">Factory used to create the underlying gRPC client.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public AthenaApiClient(ITokenManager tokenManager, IOptions<AthenaApiClientConfiguration> options, IAthenaClassifierServiceClientFactory clientFactory)
    {
        ArgumentNullException.ThrowIfNull(tokenManager);
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(clientFactory);

        var configuration = options.Value;

        // UnsafeAllowInsecure selects plain-text transport credentials instead of TLS. The same flag also sets
        // UnsafeUseInsecureChannelCallCredentials so the bearer-token call credentials may be used on that channel.
        var transportCredentials = configuration.UnsafeAllowInsecure
            ? ChannelCredentials.Insecure
            : new SslCredentials();

        // Fetch a token for each call and attach it as a bearer Authorization header.
        var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
        {
            var token = await tokenManager.GetTokenAsync(context.CancellationToken).ConfigureAwait(false);
            metadata.Add("Authorization", $"Bearer {token}");
        });

        var channelOptions = new GrpcChannelOptions
        {
            Credentials = ChannelCredentials.Create(transportCredentials, callCredentials),
            UnsafeUseInsecureChannelCallCredentials = configuration.UnsafeAllowInsecure,
        };

        _client = clientFactory.Create(configuration.Endpoint, channelOptions);
    }

    /// <inheritdoc />
    public async Task<ClassificationOutput> ClassifySingleAsync(ClassificationInput input, CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(input);
        return await _client.ClassifySingleAsync(input, cancellationToken: cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public async Task<ListDeploymentsResponse> ListDeploymentsAsync(CancellationToken cancellationToken)
    {
        return await _client.ListDeploymentsAsync(new Google.Protobuf.WellKnownTypes.Empty(), cancellationToken: cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc />
    public Task<Channel<ClassifyResponse>> ClassifyAsync(ChannelReader<ClassifyRequest> requestChannel, int responseChannelCapacity, CancellationToken cancellationToken)
    {
        // Validate before opening the RPC. Otherwise a null reader would only fault the background sender task silently.
        ArgumentNullException.ThrowIfNull(requestChannel);

        if (_senderTask != null || _receiverTask != null)
        {
            throw new InvalidOperationException("ClassifyAsync can only be called once per AthenaApiClient instance.");
        }

        // Build the response channel before opening the call. If the capacity is invalid
        // (BoundedChannelOptions throws for it), no gRPC call has been started yet.
        var responseChannel = Channel.CreateBounded<ClassifyResponse>(new BoundedChannelOptions(responseChannelCapacity)
        {
            SingleReader = true,
            SingleWriter = true,
            FullMode = BoundedChannelFullMode.Wait,
        });

        var call = _client.Classify(cancellationToken: cancellationToken);

        _senderTask = RequestLoopAsync(requestChannel, call.RequestStream, cancellationToken);
        _receiverTask = ResponseLoopAsync(responseChannel, call.ResponseStream, cancellationToken);

        // The call is disposable. Release it once both pumps have finished, instead of leaking it.
        _ = DisposeWhenCompletedAsync(call, _senderTask, _receiverTask);

        return Task.FromResult(responseChannel);
    }

    /// <summary>
    /// Forwards every request read from <paramref name="requestChannel"/> to the gRPC request stream,
    /// then completes the request stream.
    /// </summary>
    /// <param name="requestChannel">Source of outgoing requests; forwarding ends when it is completed.</param>
    /// <param name="requestStream">The request side of the duplex call.</param>
    /// <param name="cancellationToken">Token that stops forwarding.</param>
    /// <remarks>
    /// If <paramref name="cancellationToken"/> is cancelled, the loop exits without throwing. Any other failure
    /// is rethrown after a best-effort attempt to complete the request stream.
    /// </remarks>
    private static async Task RequestLoopAsync(ChannelReader<ClassifyRequest> requestChannel, IClientStreamWriter<ClassifyRequest> requestStream, CancellationToken cancellationToken)
    {
        try
        {
            await foreach (var req in requestChannel.ReadAllAsync(cancellationToken).ConfigureAwait(false))
            {
                await requestStream.WriteAsync(req, cancellationToken).ConfigureAwait(false);
            }

            await requestStream.CompleteAsync().ConfigureAwait(false);
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            await TryCompleteRequestStreamAsync(requestStream).ConfigureAwait(false);
        }
        catch (Exception)
        {
            // Completion is best-effort here, so a second failure cannot replace the original exception.
            await TryCompleteRequestStreamAsync(requestStream).ConfigureAwait(false);
            throw;
        }
    }

    /// <summary>
    /// Tries to complete <paramref name="requestStream"/> and ignores any exception this raises.
    /// </summary>
    /// <param name="requestStream">The request side of the duplex call.</param>
    private static async Task TryCompleteRequestStreamAsync(IClientStreamWriter<ClassifyRequest> requestStream)
    {
        try
        {
            await requestStream.CompleteAsync().ConfigureAwait(false);
        }
        catch (Exception)
        {
            // Deliberately ignored: callers reach this on cancellation or while handling another failure,
            // and that original condition matters more than a secondary completion error.
        }
    }

    /// <summary>
    /// Copies every response from the gRPC response stream into <paramref name="responseChannel"/>.
    /// The channel is completed when the stream ends.
    /// </summary>
    /// <param name="responseChannel">Channel handed to the consumer by <see cref="ClassifyAsync"/>.</param>
    /// <param name="responseStream">The response side of the duplex call.</param>
    /// <param name="cancellationToken">Token that stops copying.</param>
    /// <remarks>
    /// On any failure, the channel is completed with that exception so the consumer observes it.
    /// The exception is then rethrown.
    /// </remarks>
    private static async Task ResponseLoopAsync(Channel<ClassifyResponse> responseChannel, IAsyncStreamReader<ClassifyResponse> responseStream, CancellationToken cancellationToken)
    {
        try
        {
            await foreach (var resp in responseStream.ReadAllAsync(cancellationToken).ConfigureAwait(false))
            {
                await responseChannel.Writer.WriteAsync(resp, cancellationToken).ConfigureAwait(false);
            }

            responseChannel.Writer.TryComplete();
        }
        catch (Exception ex)
        {
            responseChannel.Writer.TryComplete(ex);
            throw;
        }
    }

    /// <summary>
    /// Waits for both stream pumps to finish, then disposes the duplex call.
    /// </summary>
    /// <param name="call">The duplex streaming call opened by <see cref="ClassifyAsync"/>.</param>
    /// <param name="senderTask">Task returned by <see cref="RequestLoopAsync"/>.</param>
    /// <param name="receiverTask">Task returned by <see cref="ResponseLoopAsync"/>.</param>
    private static async Task DisposeWhenCompletedAsync(IDisposable call, Task senderTask, Task receiverTask)
    {
        try
        {
            await Task.WhenAll(senderTask, receiverTask).ConfigureAwait(false);
        }
        catch (Exception)
        {
            // Nothing awaits this task, so there is no caller to rethrow to. Receiver failures already reach
            // the consumer because ResponseLoopAsync completes the response channel with the exception.
            // NOTE(review): sender failures are not reported anywhere; consider surfacing them.
        }
        finally
        {
            call.Dispose();
        }
    }
}