77package io .kroxylicious .filter .encryption ;
88
99import java .util .ArrayList ;
10+ import java .util .Collection ;
1011import java .util .EnumSet ;
1112import java .util .List ;
1213import java .util .Map ;
1314import java .util .Set ;
1415import java .util .concurrent .CompletableFuture ;
1516import java .util .concurrent .CompletionStage ;
17+ import java .util .function .BiFunction ;
1618import java .util .function .Function ;
19+ import java .util .function .ToIntFunction ;
1720import java .util .stream .Collectors ;
1821
22+ import org .apache .kafka .common .Uuid ;
1923import org .apache .kafka .common .errors .ApiException ;
2024import org .apache .kafka .common .message .ApiVersionsResponseData ;
2125import org .apache .kafka .common .message .FetchResponseData ;
2529import org .apache .kafka .common .message .ProduceRequestData .TopicProduceData ;
2630import org .apache .kafka .common .message .RequestHeaderData ;
2731import org .apache .kafka .common .message .ResponseHeaderData ;
32+ import org .apache .kafka .common .message .ShareFetchResponseData ;
33+ import org .apache .kafka .common .message .ShareFetchResponseData .ShareFetchableTopicResponse ;
2834import org .apache .kafka .common .protocol .ApiKeys ;
2935import org .apache .kafka .common .protocol .Errors ;
36+ import org .apache .kafka .common .record .BaseRecords ;
3037import org .apache .kafka .common .record .MemoryRecords ;
38+ import org .apache .kafka .common .utils .ImplicitLinkedHashCollection ;
3139import org .slf4j .Logger ;
3240
3341import io .micrometer .core .instrument .Counter ;
5260import io .kroxylicious .proxy .filter .ProduceRequestFilter ;
5361import io .kroxylicious .proxy .filter .RequestFilterResult ;
5462import io .kroxylicious .proxy .filter .ResponseFilterResult ;
63+ import io .kroxylicious .proxy .filter .ShareFetchResponseFilter ;
64+ import io .kroxylicious .proxy .filter .metadata .TopicNameMapping ;
65+ import io .kroxylicious .proxy .filter .metadata .TopicNameMappingException ;
5566
5667import edu .umd .cs .findbugs .annotations .NonNull ;
5768
6273 * @param <K> The type of KEK reference
6374 */
6475public class RecordEncryptionFilter <K >
65- implements ProduceRequestFilter , FetchResponseFilter , ApiVersionsResponseFilter {
76+ implements ProduceRequestFilter , FetchResponseFilter , ShareFetchResponseFilter , ApiVersionsResponseFilter {
6677 private static final Logger log = getLogger (RecordEncryptionFilter .class );
6778 private final TopicNameBasedKekSelector <K > kekSelector ;
6879
@@ -123,7 +134,7 @@ else if (isEncryptionException(throwable.getCause())) {
123134
124135 private CompletionStage <ProduceRequestData > maybeEncodeProduce (ProduceRequestData request , FilterContext context ) {
125136 var plainRecordsTotal = RecordEncryptionMetrics .plainRecordsCounter (context .getVirtualClusterName ());
126- var encyptedRecordsTotal = RecordEncryptionMetrics .encryptedRecordsCounter (context .getVirtualClusterName ());
137+ var encryptedRecordsTotal = RecordEncryptionMetrics .encryptedRecordsCounter (context .getVirtualClusterName ());
127138 var topicNameToData = request .topicData ().stream ().collect (Collectors .toMap (TopicProduceData ::name , Function .identity ()));
128139 CompletionStage <TopicNameKekSelection <K >> keks = filterThreadExecutor .completingOnFilterThread (kekSelector .selectKek (topicNameToData .keySet ()));
129140 return keks // figure out what keks we need
@@ -149,7 +160,7 @@ private CompletionStage<ProduceRequestData> maybeEncodeProduce(ProduceRequestDat
149160 context ::createByteBufferOutputStream )
150161 .thenApply (ppd ::setRecords )
151162 .thenApply (produceData -> {
152- encyptedRecordsTotal
163+ encryptedRecordsTotal
153164 .withTags (RecordEncryptionMetrics .TOPIC_NAME , topicName )
154165 .increment (RecordEncryptionUtil .totalRecordsInBatches ((MemoryRecords ) produceData .records ()));
155166 return null ;
@@ -176,6 +187,40 @@ private void generatePlainRecordsMetrics(Meter.MeterProvider<Counter> plainRecor
176187 .increment (RecordEncryptionUtil .totalRecordsInBatches ((MemoryRecords ) produceData .records ()))));
177188 }
178189
190+ @ Override
191+ public CompletionStage <ResponseFilterResult > onShareFetchResponse (short apiVersion , ResponseHeaderData header , ShareFetchResponseData response ,
192+ FilterContext context ) {
193+ Set <Uuid > topicIds = response .responses ().stream ().map (ShareFetchableTopicResponse ::topicId ).collect (Collectors .toSet ());
194+ return context .topicNames (topicIds ).thenCompose (topicNameMapping -> maybeDecodeShareFetch (topicNameMapping , response .responses (), context )
195+ .thenCompose (topicResponses -> {
196+ ShareFetchResponseData .ShareFetchableTopicResponseCollection collection = new ShareFetchResponseData .ShareFetchableTopicResponseCollection ();
197+ // danger, the share fetch response uses a collection where the `add` method can silently do nothing
198+ // if the element being added already has element link fields populated it will not be added.
199+ // we reset next and prev so the existing elements can be added without calling `duplicate()`
200+ // which would copy the memory records again.
201+ // this may leave the original collection in an inconsistent state, but we are finished using it
202+ topicResponses .forEach (topicResponse -> {
203+ topicResponse .setNext (ImplicitLinkedHashCollection .INVALID_INDEX );
204+ topicResponse .setPrev (ImplicitLinkedHashCollection .INVALID_INDEX );
205+ collection .mustAdd (topicResponse );
206+ });
207+ return context .forwardResponse (header , response .setResponses (collection ));
208+ })
209+ .exceptionallyCompose (throwable -> {
210+ if (throwable .getCause () instanceof UnknownKeyException ) {
211+ // #maybeDecodePartitions will have set the RESOURCE_NOT_FOUND error code on the partition(s) that failed to decrypt
212+ // and will have logged the affected topic-partitions.
213+ // Remove all the records from the whole fetch to avoid the possibility that the client processes an incomplete response.
214+ response .responses ().forEach (topicResponse -> topicResponse .partitions ().forEach (p -> p .setRecords (MemoryRecords .EMPTY )));
215+ return context .forwardResponse (header , response );
216+ }
217+ else {
218+ // returning a failed stage is effectively asking the runtime to kill the connection.
219+ return logAndCreateFailedStage (throwable );
220+ }
221+ }));
222+ }
223+
179224 @ Override
180225 public CompletionStage <ResponseFilterResult > onFetchResponse (short apiVersion , ResponseHeaderData header , FetchResponseData response , FilterContext context ) {
181226 return maybeDecodeFetch (response .responses (), context )
@@ -189,36 +234,83 @@ public CompletionStage<ResponseFilterResult> onFetchResponse(short apiVersion, R
189234 return context .forwardResponse (header , response );
190235 }
191236 else {
192- log .atWarn ().setMessage ("Failed to process records, connection will be closed, cause message: {}" )
193- .addArgument (throwable .getMessage ())
194- .setCause (log .isDebugEnabled () ? throwable : null )
195- .log ();
196237 // returning a failed stage is effectively asking the runtime to kill the connection.
197- return CompletableFuture . failedStage (throwable );
238+ return logAndCreateFailedStage (throwable );
198239 }
199240 });
200241 }
201242
243+ private static CompletionStage <ResponseFilterResult > logAndCreateFailedStage (Throwable throwable ) {
244+ log .atWarn ().setMessage ("Failed to process records, connection will be closed, cause message: {}. Raise log level to DEBUG to see the stack." )
245+ .addArgument (throwable .getMessage ())
246+ .setCause (log .isDebugEnabled () ? throwable : null )
247+ .log ();
248+ return CompletableFuture .failedStage (throwable );
249+ }
250+
251+ private CompletionStage <List <ShareFetchableTopicResponse >> maybeDecodeShareFetch (TopicNameMapping mapping , Collection <ShareFetchableTopicResponse > topics ,
252+ FilterContext context ) {
253+ List <CompletionStage <ShareFetchableTopicResponse >> result = new ArrayList <>(topics .size ());
254+ for (ShareFetchableTopicResponse topicData : topics ) {
255+ TopicNameMappingException failure = mapping .failures ().get (topicData .topicId ());
256+ String topicName = mapping .topicNames ().get (topicData .topicId ());
257+ List <ShareFetchResponseData .PartitionData > partitions = topicData .partitions ();
258+ if (topicName != null ) {
259+ result .add (maybeDecodePartitions (topicName , partitions , context , ShareFetchResponseData .PartitionData ::records ,
260+ ShareFetchResponseData .PartitionData ::partitionIndex ,
261+ partitionData -> partitionData ::setRecords , ShareFetchResponseData .PartitionData ::setErrorCode ).thenApply (kk -> {
262+ topicData .setPartitions (kk );
263+ return topicData ;
264+ }));
265+ }
266+ else {
267+ result .add (setAllPartitionsToError (topicData , failure ));
268+ }
269+ }
270+ return RecordEncryptionUtil .join (result );
271+ }
272+
273+ private static CompletableFuture <ShareFetchableTopicResponse > setAllPartitionsToError (ShareFetchableTopicResponse topicData ,
274+ TopicNameMappingException failure ) {
275+ for (ShareFetchResponseData .PartitionData partition : topicData .partitions ()) {
276+ partition .setRecords (MemoryRecords .EMPTY );
277+ Errors error = failure != null ? failure .getError () : Errors .UNKNOWN_SERVER_ERROR ;
278+ partition .setErrorCode (error .code ());
279+ partition .setErrorMessage (error .message ());
280+ }
281+ return CompletableFuture .completedFuture (topicData );
282+ }
283+
202284 private CompletionStage <List <FetchableTopicResponse >> maybeDecodeFetch (List <FetchableTopicResponse > topics , FilterContext context ) {
203285 List <CompletionStage <FetchableTopicResponse >> result = new ArrayList <>(topics .size ());
204286 for (FetchableTopicResponse topicData : topics ) {
205- result .add (maybeDecodePartitions (topicData .topic (), topicData .partitions (), context ).thenApply (kk -> {
206- topicData .setPartitions (kk );
207- return topicData ;
208- }));
287+ String topicName = topicData .topic ();
288+ List <PartitionData > partitions = topicData .partitions ();
289+ result .add (maybeDecodePartitions (topicName , partitions , context , PartitionData ::records , PartitionData ::partitionIndex ,
290+ partitionData -> partitionData ::setRecords ,
291+ PartitionData ::setErrorCode ).thenApply (kk -> {
292+ topicData .setPartitions (kk );
293+ return topicData ;
294+ }));
209295 }
210296 return RecordEncryptionUtil .join (result );
211297 }
212298
213- private CompletionStage <List <PartitionData >> maybeDecodePartitions (String topicName ,
214- List <PartitionData > partitions ,
215- FilterContext context ) {
216- List <CompletionStage <PartitionData >> result = new ArrayList <>(partitions .size ());
217- for (PartitionData partitionData : partitions ) {
218- if (!(partitionData .records () instanceof MemoryRecords )) {
299+ private <T > CompletionStage <List <T >> maybeDecodePartitions (String topicName ,
300+ List <T > partitions ,
301+ FilterContext context ,
302+ Function <T , BaseRecords > recordExtractor ,
303+ ToIntFunction <T > partitionIndexExtractor ,
304+ Function <T , Function <MemoryRecords , T >> setRecords ,
305+ BiFunction <T , Short , T > errorsConsumer ) {
306+ List <CompletionStage <T >> result = new ArrayList <>(partitions .size ());
307+ for (T partitionData : partitions ) {
308+ BaseRecords records = recordExtractor .apply (partitionData );
309+ if (!(records instanceof MemoryRecords )) {
219310 throw new IllegalStateException ();
220311 }
221- var stage = maybeDecodeRecords (topicName , partitionData , (MemoryRecords ) partitionData .records (), context )
312+ var stage = maybeDecodeRecords (topicName , (MemoryRecords ) records , context , partitionIndexExtractor .applyAsInt (partitionData ),
313+ setRecords .apply (partitionData ))
222314 .exceptionallyCompose (t -> {
223315 var cause = t .getCause ();
224316 if (cause instanceof UnknownKeyException ) {
@@ -229,11 +321,11 @@ private CompletionStage<List<PartitionData>> maybeDecodePartitions(String topicN
229321 + "Cause message: {}. "
230322 + "Raise log level to DEBUG to see the stack." )
231323 .addArgument (topicName )
232- .addArgument (partitionData . partitionIndex ( ))
324+ .addArgument (() -> partitionIndexExtractor . applyAsInt ( partitionData ))
233325 .addArgument (cause .getMessage ())
234326 .setCause (log .isDebugEnabled () ? cause : null )
235327 .log ();
236- partitionData . setErrorCode ( Errors .RESOURCE_NOT_FOUND .code ());
328+ errorsConsumer . apply ( partitionData , Errors .RESOURCE_NOT_FOUND .code ());
237329 }
238330 return CompletableFuture .failedFuture (t );
239331 });
@@ -242,16 +334,16 @@ private CompletionStage<List<PartitionData>> maybeDecodePartitions(String topicN
242334 return RecordEncryptionUtil .join (result );
243335 }
244336
245- private CompletionStage <PartitionData > maybeDecodeRecords (String topicName ,
246- PartitionData fpr ,
247- MemoryRecords memoryRecords ,
248- FilterContext context ) {
337+ private < T > CompletionStage <T > maybeDecodeRecords (String topicName ,
338+ MemoryRecords memoryRecords ,
339+ FilterContext context , int partition ,
340+ Function < MemoryRecords , T > setRecords ) {
249341 return decryptionManager .decrypt (
250342 topicName ,
251- fpr . partitionIndex () ,
343+ partition ,
252344 memoryRecords ,
253345 context ::createByteBufferOutputStream )
254- .thenApply (fpr :: setRecords );
346+ .thenApply (setRecords );
255347 }
256348
257349 private static boolean isEncryptionException (Throwable throwable ) {
0 commit comments