-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathClassifyDataflowCommand.cs
More file actions
148 lines (119 loc) · 5.56 KB
/
ClassifyDataflowCommand.cs
File metadata and controls
148 lines (119 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
using System.CommandLine;
using System.Threading.Tasks.Dataflow;
using Microsoft.Extensions.DependencyInjection;
using Resolver.Athena.Client.HighLevelClient.Images;
using Resolver.Athena.Client.HighLevelClient.Models;
using Resolver.Athena.Client.TPL.DependencyInjection;
using Resolver.Athena.Client.TPL.Interfaces;
namespace Resolver.Athena.CliClient;
/// <summary>
/// CLI command <c>classify-dataflow</c>: reads image files and sends them for classification
/// through the pipeline created by <see cref="IAthenaDataflowClient"/>, printing every result.
/// </summary>
public static class ClassifyDataflowCommand
{
    /// <summary>
    /// Adds the <c>classify-dataflow</c> command, with its repeat option and its deployment-id
    /// and image-path arguments, to <paramref name="rootCommand"/>.
    /// </summary>
    /// <param name="rootCommand">The root command the new command is attached to.</param>
    public static void RegisterCommand(RootCommand rootCommand)
    {
        var cmd = rootCommand.AddAthenaCommand("classify-dataflow", "Classify images using the TPL Dataflow Client", DoClassifyDataflowCommand);
        cmd.Options.Add(CliUtilities.RepeatOption);
        cmd.Arguments.Add(CliUtilities.DeploymentIdArgument);
        cmd.Arguments.Add(CliUtilities.ImagePathArgument);
    }

    /// <summary>
    /// Runs the command: gathers image paths, loads each file, posts the resulting
    /// <see cref="ClassificationRequest"/>s into the Athena dataflow pipeline and prints each
    /// <see cref="ClassificationResult"/> that comes back.
    /// </summary>
    /// <remarks>
    /// When the repeat option is greater than zero the image path is re-posted every
    /// <c>repeat</c> seconds until <paramref name="cancellationToken"/> is cancelled.
    /// Otherwise the path is posted once and the command finishes when as many results have
    /// been received as requests were queued, or as soon as one of the dataflow blocks stops
    /// (fault, cancellation or completion).
    /// </remarks>
    /// <param name="parseResult">Parsed command line carrying the options and arguments.</param>
    /// <param name="cancellationToken">Token that stops the blocks and the repeat loop.</param>
    /// <returns><c>0</c> on success; <c>1</c> when an exception was caught while classifying.</returns>
    /// <exception cref="InvalidOperationException">The image path argument has no value.</exception>
    public static async Task<int> DoClassifyDataflowCommand(ParseResult parseResult, CancellationToken cancellationToken)
    {
        CliUtilities.LoadDotEnv(parseResult);

        // Dispose the provider (and the services it owns) when the command finishes.
        await using var svcs = new ServiceCollection()
            .AddAthenaDataflowClient(o =>
            {
                CliUtilities.ConfigureAthenaClientFromEnv(o);
                o.UnsafeAllowInsecure = parseResult.GetValue(CliUtilities.UnsafeAllowInsecure);
            },
            CliUtilities.ConfigureOAuthTokenManagerFromEnv)
            .BuildServiceProvider();

        var athenaClient = svcs.GetRequiredService<IAthenaDataflowClient>();
        var pipeline = await athenaClient.CreatePipelineAsync(cancellationToken);

        var deploymentId = CliUtilities.GetDeploymentId(parseResult);
        var path = parseResult.GetValue(CliUtilities.ImagePathArgument) ?? throw new InvalidOperationException("Image path argument is required.");
        var repeatSeconds = parseResult.GetValue(CliUtilities.RepeatOption);

        // Any non-positive repeat value means "run once"; deciding this in one place keeps the
        // completion check below consistent with the branch that posts the work.
        var repeating = repeatSeconds > 0;

        try
        {
            // Completion and faults flow downstream across every link.
            var blockLinkOptions = new DataflowLinkOptions
            {
                PropagateCompletion = true,
            };

            // Counters are written by block delegates and read here, hence Interlocked/Volatile.
            var queuedRequests = 0;
            var consumedResponses = 0;
            var errorResponses = 0;

            // Signalled once every queued request of a single-shot run has been answered.
            var allResponsesReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

            // Stage 1: expand the posted path into the image files to send.
            var filepathGathererBlock = new TransformManyBlock<string, string>(dirPath =>
            {
                var files = CliUtilities.GatherImagePaths(dirPath);
                Interlocked.Add(ref queuedRequests, files.Count);

                if (!repeating && files.Count == 0)
                {
                    // Nothing will ever reach the consumer, so finish the run here.
                    allResponsesReceived.TrySetResult();
                }

                return files;
            }, new ExecutionDataflowBlockOptions
            {
                CancellationToken = cancellationToken,
            });

            // Stage 2: read each file and wrap it in a classification request.
            var loaderBlock = new TransformBlock<string, ClassificationRequest>(async f =>
            {
                var data = await File.ReadAllBytesAsync(f, cancellationToken).ConfigureAwait(false);
                var image = new AthenaImageEncoded(data);
                var rawCorrelation = Path.GetFileNameWithoutExtension(f) ?? string.Empty;
                var sanitized = CliUtilities.SanitizeCorrelationId(rawCorrelation);

                Console.WriteLine($"Sending file: {f} (correlation: {sanitized})");
                return new ClassificationRequest(deploymentId, image, sanitized);
            }, new ExecutionDataflowBlockOptions
            {
                CancellationToken = cancellationToken,
                MaxDegreeOfParallelism = Environment.ProcessorCount,
            });

            filepathGathererBlock.LinkTo(loaderBlock, blockLinkOptions);
            loaderBlock.LinkTo(pipeline.Input, blockLinkOptions);

            // Stage 3: print every result and keep the tallies.
            var loggerBlock = new ActionBlock<ClassificationResult>(result =>
            {
                Console.WriteLine("[consumer] received classification result");
                var received = Interlocked.Increment(ref consumedResponses);
                Console.WriteLine(result.ToPrettyString());

                if (result.ErrorDetails is not null)
                {
                    Interlocked.Increment(ref errorResponses);
                }

                if (!repeating && received >= Volatile.Read(ref queuedRequests))
                {
                    allResponsesReceived.TrySetResult();
                }
            }, new ExecutionDataflowBlockOptions
            {
                CancellationToken = cancellationToken,
            });

            pipeline.Output.LinkTo(loggerBlock, blockLinkOptions);

            // Completes as soon as any local block stops. Complete() is never called on these
            // blocks, so in practice this means a fault or a cancellation; watching it prevents
            // waiting forever on results that can no longer arrive.
            var pipelineStopped = Task.WhenAny(
                filepathGathererBlock.Completion,
                loaderBlock.Completion,
                loggerBlock.Completion);

            if (repeating)
            {
                Console.WriteLine($"Repeating every {repeatSeconds} seconds. Press Ctrl+C to stop.");
                while (!cancellationToken.IsCancellationRequested)
                {
                    filepathGathererBlock.Post(path);

                    Console.WriteLine($"Waiting {repeatSeconds} seconds before next cycle...");
                    var delay = Task.Delay(TimeSpan.FromSeconds(repeatSeconds), cancellationToken);
                    var first = await Task.WhenAny(delay, pipelineStopped).ConfigureAwait(false);

                    if (cancellationToken.IsCancellationRequested)
                    {
                        // Ctrl+C is the normal way to end repeat mode.
                        break;
                    }

                    if (first != delay)
                    {
                        // A block stopped without cancellation: rethrow its fault, if any.
                        var stoppedBlock = await pipelineStopped.ConfigureAwait(false);
                        await stoppedBlock.ConfigureAwait(false);
                        break;
                    }
                }
            }
            else
            {
                filepathGathererBlock.Post(path);
            }

            Console.WriteLine("[consumer] starting dataflow consume");

            if (!repeating)
            {
                var first = await Task.WhenAny(allResponsesReceived.Task, pipelineStopped).ConfigureAwait(false);
                if (first != allResponsesReceived.Task)
                {
                    // Surface the fault/cancellation of the block that stopped early.
                    var stoppedBlock = await pipelineStopped.ConfigureAwait(false);
                    await stoppedBlock.ConfigureAwait(false);
                }
            }

            Console.WriteLine(CliUtilities.GenerateStreamSummary(
                Volatile.Read(ref queuedRequests),
                Volatile.Read(ref consumedResponses),
                Volatile.Read(ref errorResponses)));

            return 0;
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Failed to classify images (stream): {ex.Message}");
            return 1;
        }
    }
}