Skip to content

Commit 13531b8

Browse files
committed
feat(record-encryption): decrypt share fetch responses
Apache Kafka is about to promote Queues for Kafka to stable. We now decrypt records fetched by Share Group consumers from encrypted topics.
If we cannot discover the name for a topic id, we:
* mark all share fetch response partitions for that topic id with the mapping error code and message.
* empty the memory records for all share fetch response partitions for that topic id.
This means that if we can obtain the topic name for some of the response topics, we will try to decode and return those topic partitions; for topics whose name we cannot obtain, the partitions will be marked as errored.
Signed-off-by: Robert Young <robertyoungnz@gmail.com>
1 parent 9dfac8a commit 13531b8

6 files changed

Lines changed: 512 additions & 38 deletions

File tree

kroxylicious-filters/kroxylicious-record-encryption/etc/module-layering.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
<allow pkg="org.apache.kafka.common.message" local-only="true"/>
2525
<allow pkg="org.apache.kafka.common.record" local-only="true"/>
2626
<allow pkg="org.apache.kafka.common.errors" local-only="true"/>
27+
<allow pkg="org.apache.kafka.common" local-only="true"/>
28+
<allow pkg="org.apache.kafka.common.utils" local-only="true"/>
2729
<allow pkg="org.apache.kafka.common.protocol" local-only="true"/>
2830
<allow pkg="io.kroxylicious.proxy.filter" local-only="true"/> <!-- the filter api -->
2931
<allow pkg="io.kroxylicious.proxy.plugin" local-only="true"/> <!-- the plugin api -->

kroxylicious-filters/kroxylicious-record-encryption/src/main/java/io/kroxylicious/filter/encryption/RecordEncryptionFilter.java

Lines changed: 119 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
package io.kroxylicious.filter.encryption;
88

99
import java.util.ArrayList;
10+
import java.util.Collection;
1011
import java.util.EnumSet;
1112
import java.util.List;
1213
import java.util.Map;
1314
import java.util.Set;
1415
import java.util.concurrent.CompletableFuture;
1516
import java.util.concurrent.CompletionStage;
17+
import java.util.function.BiFunction;
1618
import java.util.function.Function;
19+
import java.util.function.ToIntFunction;
1720
import java.util.stream.Collectors;
1821

22+
import org.apache.kafka.common.Uuid;
1923
import org.apache.kafka.common.errors.ApiException;
2024
import org.apache.kafka.common.message.ApiVersionsResponseData;
2125
import org.apache.kafka.common.message.FetchResponseData;
@@ -25,9 +29,13 @@
2529
import org.apache.kafka.common.message.ProduceRequestData.TopicProduceData;
2630
import org.apache.kafka.common.message.RequestHeaderData;
2731
import org.apache.kafka.common.message.ResponseHeaderData;
32+
import org.apache.kafka.common.message.ShareFetchResponseData;
33+
import org.apache.kafka.common.message.ShareFetchResponseData.ShareFetchableTopicResponse;
2834
import org.apache.kafka.common.protocol.ApiKeys;
2935
import org.apache.kafka.common.protocol.Errors;
36+
import org.apache.kafka.common.record.BaseRecords;
3037
import org.apache.kafka.common.record.MemoryRecords;
38+
import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
3139
import org.slf4j.Logger;
3240

3341
import io.micrometer.core.instrument.Counter;
@@ -52,6 +60,9 @@
5260
import io.kroxylicious.proxy.filter.ProduceRequestFilter;
5361
import io.kroxylicious.proxy.filter.RequestFilterResult;
5462
import io.kroxylicious.proxy.filter.ResponseFilterResult;
63+
import io.kroxylicious.proxy.filter.ShareFetchResponseFilter;
64+
import io.kroxylicious.proxy.filter.metadata.TopicNameMapping;
65+
import io.kroxylicious.proxy.filter.metadata.TopicNameMappingException;
5566

5667
import edu.umd.cs.findbugs.annotations.NonNull;
5768

@@ -62,7 +73,7 @@
6273
* @param <K> The type of KEK reference
6374
*/
6475
public class RecordEncryptionFilter<K>
65-
implements ProduceRequestFilter, FetchResponseFilter, ApiVersionsResponseFilter {
76+
implements ProduceRequestFilter, FetchResponseFilter, ShareFetchResponseFilter, ApiVersionsResponseFilter {
6677
private static final Logger log = getLogger(RecordEncryptionFilter.class);
6778
private final TopicNameBasedKekSelector<K> kekSelector;
6879

@@ -123,7 +134,7 @@ else if (isEncryptionException(throwable.getCause())) {
123134

124135
private CompletionStage<ProduceRequestData> maybeEncodeProduce(ProduceRequestData request, FilterContext context) {
125136
var plainRecordsTotal = RecordEncryptionMetrics.plainRecordsCounter(context.getVirtualClusterName());
126-
var encyptedRecordsTotal = RecordEncryptionMetrics.encryptedRecordsCounter(context.getVirtualClusterName());
137+
var encryptedRecordsTotal = RecordEncryptionMetrics.encryptedRecordsCounter(context.getVirtualClusterName());
127138
var topicNameToData = request.topicData().stream().collect(Collectors.toMap(TopicProduceData::name, Function.identity()));
128139
CompletionStage<TopicNameKekSelection<K>> keks = filterThreadExecutor.completingOnFilterThread(kekSelector.selectKek(topicNameToData.keySet()));
129140
return keks // figure out what keks we need
@@ -149,7 +160,7 @@ private CompletionStage<ProduceRequestData> maybeEncodeProduce(ProduceRequestDat
149160
context::createByteBufferOutputStream)
150161
.thenApply(ppd::setRecords)
151162
.thenApply(produceData -> {
152-
encyptedRecordsTotal
163+
encryptedRecordsTotal
153164
.withTags(RecordEncryptionMetrics.TOPIC_NAME, topicName)
154165
.increment(RecordEncryptionUtil.totalRecordsInBatches((MemoryRecords) produceData.records()));
155166
return null;
@@ -176,6 +187,40 @@ private void generatePlainRecordsMetrics(Meter.MeterProvider<Counter> plainRecor
176187
.increment(RecordEncryptionUtil.totalRecordsInBatches((MemoryRecords) produceData.records()))));
177188
}
178189

190+
/**
 * Decrypts the records carried by a ShareFetch response before forwarding it.
 * <p>
 * The topic ids present in the response are resolved to topic names through
 * {@code context.topicNames(...)}; the topics are then handed to
 * {@code maybeDecodeShareFetch}, re-linked into a fresh
 * {@code ShareFetchableTopicResponseCollection}, set back on {@code response}
 * and forwarded.
 * <p>
 * If that processing fails and the failure's cause is an
 * {@code UnknownKeyException}, the records of every partition in the response
 * are replaced with {@code MemoryRecords.EMPTY} and the response is still
 * forwarded. Any other failure is logged and turned into a failed stage.
 * <p>
 * NOTE(review): the {@code exceptionallyCompose} handler is chained inside the
 * lambda passed to {@code topicNames(...).thenCompose(...)}, so it only covers
 * failures of the decode/forward steps. A failure of the stage returned by
 * {@code context.topicNames} itself is not handled here - confirm that is intended.
 *
 * @param apiVersion version of the ShareFetch API; not read by this method
 * @param header response header, passed through to {@code forwardResponse}
 * @param response the ShareFetch response; mutated in place
 * @param context filter context used for topic name lookup and forwarding
 * @return a stage completing with the result of forwarding the response, or a failed stage
 */
@Override
191+
public CompletionStage<ResponseFilterResult> onShareFetchResponse(short apiVersion, ResponseHeaderData header, ShareFetchResponseData response,
192+
FilterContext context) {
193+
Set<Uuid> topicIds = response.responses().stream().map(ShareFetchableTopicResponse::topicId).collect(Collectors.toSet());
194+
return context.topicNames(topicIds).thenCompose(topicNameMapping -> maybeDecodeShareFetch(topicNameMapping, response.responses(), context)
195+
.thenCompose(topicResponses -> {
196+
ShareFetchResponseData.ShareFetchableTopicResponseCollection collection = new ShareFetchResponseData.ShareFetchableTopicResponseCollection();
197+
// danger, the share fetch response uses a collection where the `add` method can silently do nothing
198+
// if the element being added already has element link fields populated it will not be added.
199+
// we reset next and prev so the existing elements can be added without calling `duplicate()`
200+
// which would copy the memory records again.
201+
// this may leave the original collection in an inconsistent state, but we are finished using it
202+
topicResponses.forEach(topicResponse -> {
203+
topicResponse.setNext(ImplicitLinkedHashCollection.INVALID_INDEX);
204+
topicResponse.setPrev(ImplicitLinkedHashCollection.INVALID_INDEX);
205+
collection.mustAdd(topicResponse);
206+
});
207+
return context.forwardResponse(header, response.setResponses(collection));
208+
})
209+
.exceptionallyCompose(throwable -> {
210+
if (throwable.getCause() instanceof UnknownKeyException) {
211+
// #maybeDecodePartitions will have set the RESOURCE_NOT_FOUND error code on the partition(s) that failed to decrypt
212+
// and will have logged the affected topic-partitions.
213+
// Remove all the records from the whole fetch to avoid the possibility that the client processes an incomplete response.
214+
response.responses().forEach(topicResponse -> topicResponse.partitions().forEach(p -> p.setRecords(MemoryRecords.EMPTY)));
215+
return context.forwardResponse(header, response);
216+
}
217+
else {
218+
// returning a failed stage is effectively asking the runtime to kill the connection.
219+
return logAndCreateFailedStage(throwable);
220+
}
221+
}));
222+
}
223+
179224
@Override
180225
public CompletionStage<ResponseFilterResult> onFetchResponse(short apiVersion, ResponseHeaderData header, FetchResponseData response, FilterContext context) {
181226
return maybeDecodeFetch(response.responses(), context)
@@ -189,36 +234,83 @@ public CompletionStage<ResponseFilterResult> onFetchResponse(short apiVersion, R
189234
return context.forwardResponse(header, response);
190235
}
191236
else {
192-
log.atWarn().setMessage("Failed to process records, connection will be closed, cause message: {}")
193-
.addArgument(throwable.getMessage())
194-
.setCause(log.isDebugEnabled() ? throwable : null)
195-
.log();
196237
// returning a failed stage is effectively asking the runtime to kill the connection.
197-
return CompletableFuture.failedStage(throwable);
238+
return logAndCreateFailedStage(throwable);
198239
}
199240
});
200241
}
201242

243+
/**
 * Logs a record-processing failure at WARN level and returns a stage that has
 * already failed with the same throwable.
 * <p>
 * Only the throwable's message is always logged; the throwable itself is
 * attached to the log event (so its stack trace can be rendered) only when
 * DEBUG logging is enabled.
 *
 * @param throwable the failure to log and propagate
 * @return a stage completed exceptionally with {@code throwable}
 */
private static CompletionStage<ResponseFilterResult> logAndCreateFailedStage(Throwable throwable) {
244+
log.atWarn().setMessage("Failed to process records, connection will be closed, cause message: {}. Raise log level to DEBUG to see the stack.")
245+
.addArgument(throwable.getMessage())
246+
.setCause(log.isDebugEnabled() ? throwable : null)
247+
.log();
248+
return CompletableFuture.failedStage(throwable);
249+
}
250+
251+
/**
 * Processes every topic of a ShareFetch response, producing one stage per topic.
 * <p>
 * For each topic the name and any mapping failure are looked up by topic id in
 * {@code mapping}. When a name is available the topic's partitions are passed
 * to {@code maybeDecodePartitions} and the resulting list is set back on the
 * same topic object. When no name is available every partition of the topic
 * is marked as errored via {@code setAllPartitionsToError}.
 *
 * @param mapping result of the topic id to topic name lookup
 * @param topics the topic responses to process; mutated in place
 * @param context filter context passed on to partition decoding
 * @return the per-topic stages combined by {@code RecordEncryptionUtil.join}
 *         (presumably completing once all of them complete - defined outside this view)
 */
private CompletionStage<List<ShareFetchableTopicResponse>> maybeDecodeShareFetch(TopicNameMapping mapping, Collection<ShareFetchableTopicResponse> topics,
252+
FilterContext context) {
253+
List<CompletionStage<ShareFetchableTopicResponse>> result = new ArrayList<>(topics.size());
254+
for (ShareFetchableTopicResponse topicData : topics) {
255+
// failure is only consulted in the else branch below, i.e. when no name was resolved
TopicNameMappingException failure = mapping.failures().get(topicData.topicId());
256+
String topicName = mapping.topicNames().get(topicData.topicId());
257+
List<ShareFetchResponseData.PartitionData> partitions = topicData.partitions();
258+
if (topicName != null) {
259+
result.add(maybeDecodePartitions(topicName, partitions, context, ShareFetchResponseData.PartitionData::records,
260+
ShareFetchResponseData.PartitionData::partitionIndex,
261+
partitionData -> partitionData::setRecords, ShareFetchResponseData.PartitionData::setErrorCode).thenApply(kk -> {
262+
topicData.setPartitions(kk);
263+
return topicData;
264+
}));
265+
}
266+
else {
267+
result.add(setAllPartitionsToError(topicData, failure));
268+
}
269+
}
270+
return RecordEncryptionUtil.join(result);
271+
}
272+
273+
/**
 * Marks every partition of a topic response as errored and strips its records.
 * <p>
 * Each partition's records are replaced with {@code MemoryRecords.EMPTY} and
 * its error code and message are taken from {@code failure.getError()}, or
 * from {@code Errors.UNKNOWN_SERVER_ERROR} when {@code failure} is null.
 *
 * @param topicData the topic response whose partitions are mutated in place
 * @param failure the topic name mapping failure for this topic, may be null
 * @return an already-completed future holding the same {@code topicData} instance
 */
private static CompletableFuture<ShareFetchableTopicResponse> setAllPartitionsToError(ShareFetchableTopicResponse topicData,
274+
TopicNameMappingException failure) {
275+
for (ShareFetchResponseData.PartitionData partition : topicData.partitions()) {
276+
partition.setRecords(MemoryRecords.EMPTY);
277+
// NOTE(review): error does not depend on the partition, so this lookup could be hoisted above the loop
Errors error = failure != null ? failure.getError() : Errors.UNKNOWN_SERVER_ERROR;
278+
partition.setErrorCode(error.code());
279+
partition.setErrorMessage(error.message());
280+
}
281+
return CompletableFuture.completedFuture(topicData);
282+
}
283+
202284
private CompletionStage<List<FetchableTopicResponse>> maybeDecodeFetch(List<FetchableTopicResponse> topics, FilterContext context) {
203285
List<CompletionStage<FetchableTopicResponse>> result = new ArrayList<>(topics.size());
204286
for (FetchableTopicResponse topicData : topics) {
205-
result.add(maybeDecodePartitions(topicData.topic(), topicData.partitions(), context).thenApply(kk -> {
206-
topicData.setPartitions(kk);
207-
return topicData;
208-
}));
287+
String topicName = topicData.topic();
288+
List<PartitionData> partitions = topicData.partitions();
289+
result.add(maybeDecodePartitions(topicName, partitions, context, PartitionData::records, PartitionData::partitionIndex,
290+
partitionData -> partitionData::setRecords,
291+
PartitionData::setErrorCode).thenApply(kk -> {
292+
topicData.setPartitions(kk);
293+
return topicData;
294+
}));
209295
}
210296
return RecordEncryptionUtil.join(result);
211297
}
212298

213-
private CompletionStage<List<PartitionData>> maybeDecodePartitions(String topicName,
214-
List<PartitionData> partitions,
215-
FilterContext context) {
216-
List<CompletionStage<PartitionData>> result = new ArrayList<>(partitions.size());
217-
for (PartitionData partitionData : partitions) {
218-
if (!(partitionData.records() instanceof MemoryRecords)) {
299+
private <T> CompletionStage<List<T>> maybeDecodePartitions(String topicName,
300+
List<T> partitions,
301+
FilterContext context,
302+
Function<T, BaseRecords> recordExtractor,
303+
ToIntFunction<T> partitionIndexExtractor,
304+
Function<T, Function<MemoryRecords, T>> setRecords,
305+
BiFunction<T, Short, T> errorsConsumer) {
306+
List<CompletionStage<T>> result = new ArrayList<>(partitions.size());
307+
for (T partitionData : partitions) {
308+
BaseRecords records = recordExtractor.apply(partitionData);
309+
if (!(records instanceof MemoryRecords)) {
219310
throw new IllegalStateException();
220311
}
221-
var stage = maybeDecodeRecords(topicName, partitionData, (MemoryRecords) partitionData.records(), context)
312+
var stage = maybeDecodeRecords(topicName, (MemoryRecords) records, context, partitionIndexExtractor.applyAsInt(partitionData),
313+
setRecords.apply(partitionData))
222314
.exceptionallyCompose(t -> {
223315
var cause = t.getCause();
224316
if (cause instanceof UnknownKeyException) {
@@ -229,11 +321,11 @@ private CompletionStage<List<PartitionData>> maybeDecodePartitions(String topicN
229321
+ "Cause message: {}. "
230322
+ "Raise log level to DEBUG to see the stack.")
231323
.addArgument(topicName)
232-
.addArgument(partitionData.partitionIndex())
324+
.addArgument(() -> partitionIndexExtractor.applyAsInt(partitionData))
233325
.addArgument(cause.getMessage())
234326
.setCause(log.isDebugEnabled() ? cause : null)
235327
.log();
236-
partitionData.setErrorCode(Errors.RESOURCE_NOT_FOUND.code());
328+
errorsConsumer.apply(partitionData, Errors.RESOURCE_NOT_FOUND.code());
237329
}
238330
return CompletableFuture.failedFuture(t);
239331
});
@@ -242,16 +334,16 @@ private CompletionStage<List<PartitionData>> maybeDecodePartitions(String topicN
242334
return RecordEncryptionUtil.join(result);
243335
}
244336

245-
private CompletionStage<PartitionData> maybeDecodeRecords(String topicName,
246-
PartitionData fpr,
247-
MemoryRecords memoryRecords,
248-
FilterContext context) {
337+
private <T> CompletionStage<T> maybeDecodeRecords(String topicName,
338+
MemoryRecords memoryRecords,
339+
FilterContext context, int partition,
340+
Function<MemoryRecords, T> setRecords) {
249341
return decryptionManager.decrypt(
250342
topicName,
251-
fpr.partitionIndex(),
343+
partition,
252344
memoryRecords,
253345
context::createByteBufferOutputStream)
254-
.thenApply(fpr::setRecords);
346+
.thenApply(setRecords);
255347
}
256348

257349
private static boolean isEncryptionException(Throwable throwable) {

0 commit comments

Comments
 (0)