diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java index 70c21f8b610b1..38e0664914498 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java @@ -17,7 +17,6 @@ package org.apache.kafka.streams.state.internals; import org.apache.kafka.common.TopicPartition; -import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.metrics.Sensor; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.common.utils.Bytes; @@ -59,8 +58,6 @@ import static org.apache.kafka.common.utils.Utils.mkMap; import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; -// TODO: replace with new method in follow-up PR of KIP-1271 -@SuppressWarnings("deprecation") public class MeteredSessionStore extends WrappedStateStore, Windowed, V> implements SessionStore, MeteredStateStore { @@ -93,11 +90,13 @@ public class MeteredSessionStore ); - MeteredSessionStore(final SessionStore inner, - final String metricsScope, - final Serde keySerde, - final Serde valueSerde, - final Time time) { + MeteredSessionStore( + final SessionStore inner, + final String metricsScope, + final Serde keySerde, + final Serde valueSerde, + final Time time + ) { super(inner); this.metricsScope = metricsScope; this.keySerde = keySerde; @@ -106,8 +105,10 @@ public class MeteredSessionStore } @Override - public void init(final StateStoreContext stateStoreContext, - final StateStore root) { + public void init( + final StateStoreContext stateStoreContext, + final StateStore root + ) { internalContext = stateStoreContext instanceof InternalProcessorContext ? (InternalProcessorContext) stateStoreContext : null; taskId = stateStoreContext.taskId(); initStoreSerde(stateStoreContext); @@ -180,27 +181,35 @@ private void initStoreSerde(final StateStoreContext context) { @SuppressWarnings("unchecked") @Override - public boolean setFlushListener(final CacheFlushListener, V> listener, - final boolean sendOldValues) { + public boolean setFlushListener( + final CacheFlushListener, V> listener, + final boolean sendOldValues + ) { final SessionStore wrapped = wrapped(); if (wrapped instanceof CachedStateStore) { return ((CachedStateStore) wrapped).setFlushListener( - record -> listener.apply( - record.withKey(SessionKeySchema.from(record.key(), serdes.keyDeserializer(), record.headers(), serdes.topic())) - .withValue(new Change<>( - record.value().newValue != null ? serdes.valueFrom(record.value().newValue, record.headers()) : null, - record.value().oldValue != null ? serdes.valueFrom(record.value().oldValue, record.headers()) : null, - record.value().isLatest - )) - ), + record -> { + final Change change = record.value(); + listener.apply( + record + .withKey(SessionKeySchema.from(record.key(), serdes.keyDeserializer(), record.headers(), serdes.topic())) + .withValue(new Change<>( + change.newValue != null ? serdes.valueFrom(change.newValue, record.headers()) : null, + change.oldValue != null ? serdes.valueFrom(change.oldValue, record.headers()) : null, + change.isLatest + )) + ); + }, sendOldValues); } return false; } @Override - public void put(final Windowed sessionKey, - final V aggregate) { + public void put( + final Windowed sessionKey, + final V aggregate + ) { Objects.requireNonNull(sessionKey, "sessionKey can't be null"); Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null"); Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't be null"); @@ -208,8 +217,8 @@ public void put(final Windowed sessionKey, try { maybeMeasureLatency( () -> { - final Bytes key = keyBytes(sessionKey.key()); - wrapped().put(new Windowed<>(key, sessionKey.window()), serdes.rawValue(aggregate)); + final Bytes key = serializeKey(sessionKey.key()); + wrapped().put(new Windowed<>(key, sessionKey.window()), serializeValue(aggregate)); }, time, putSensor @@ -230,7 +239,7 @@ public void remove(final Windowed sessionKey) { try { maybeMeasureLatency( () -> { - final Bytes key = keyBytes(sessionKey.key()); + final Bytes key = serializeKey(sessionKey.key()); wrapped().remove(new Windowed<>(key, sessionKey.window())); }, time, @@ -246,18 +255,7 @@ public void remove(final Windowed sessionKey) { public V fetchSession(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) { Objects.requireNonNull(key, "key cannot be null"); return maybeMeasureLatency( - () -> { - final Bytes bytesKey = keyBytes(key); - final byte[] result = wrapped().fetchSession( - bytesKey, - earliestSessionEndTime, - latestSessionStartTime - ); - if (result == null) { - return null; - } - return serdes.valueFrom(result); - }, + () -> deserializeValue(wrapped().fetchSession(serializeKey(key), earliestSessionEndTime, latestSessionStartTime)), time, fetchSensor ); @@ -267,25 +265,26 @@ public V fetchSession(final K key, final long earliestSessionEndTime, final long public KeyValueIterator, V> fetch(final K key) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredWindowedKeyValueIterator<>( - wrapped().fetch(keyBytes(key)), + wrapped().fetch(serializeKey(key)), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, - openIterators); + openIterators + ); } @Override public KeyValueIterator, V> backwardFetch(final K key) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredWindowedKeyValueIterator<>( - wrapped().backwardFetch(keyBytes(key)), + wrapped().backwardFetch(serializeKey(key)), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, openIterators @@ -293,28 +292,33 @@ public KeyValueIterator, V> backwardFetch(final K key) { } @Override - public KeyValueIterator, V> fetch(final K keyFrom, - final K keyTo) { + public KeyValueIterator, V> fetch( + final K keyFrom, + final K keyTo + ) { return new MeteredWindowedKeyValueIterator<>( - wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo)), + wrapped().fetch(serializeKey(keyFrom), serializeKey(keyTo)), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, - openIterators); + openIterators + ); } @Override - public KeyValueIterator, V> backwardFetch(final K keyFrom, - final K keyTo) { + public KeyValueIterator, V> backwardFetch( + final K keyFrom, + final K keyTo + ) { return new MeteredWindowedKeyValueIterator<>( - wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo)), + wrapped().backwardFetch(serializeKey(keyFrom), serializeKey(keyTo)), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, openIterators @@ -322,11 +326,13 @@ public KeyValueIterator, V> backwardFetch(final K keyFrom, } @Override - public KeyValueIterator, V> findSessions(final K key, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, V> findSessions( + final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { Objects.requireNonNull(key, "key cannot be null"); - final Bytes bytesKey = keyBytes(key); + final Bytes bytesKey = serializeKey(key); return new MeteredWindowedKeyValueIterator<>( wrapped().findSessions( bytesKey, @@ -334,19 +340,22 @@ public KeyValueIterator, V> findSessions(final K key, latestSessionStartTime), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, - openIterators); + openIterators + ); } @Override - public KeyValueIterator, V> backwardFindSessions(final K key, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, V> backwardFindSessions( + final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { Objects.requireNonNull(key, "key cannot be null"); - final Bytes bytesKey = keyBytes(key); + final Bytes bytesKey = serializeKey(key); return new MeteredWindowedKeyValueIterator<>( wrapped().backwardFindSessions( bytesKey, @@ -355,8 +364,8 @@ public KeyValueIterator, V> backwardFindSessions(final K key, ), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, openIterators @@ -364,12 +373,14 @@ public KeyValueIterator, V> backwardFindSessions(final K key, } @Override - public KeyValueIterator, V> findSessions(final K keyFrom, - final K keyTo, - final long earliestSessionEndTime, - final long latestSessionStartTime) { - final Bytes bytesKeyFrom = keyBytes(keyFrom); - final Bytes bytesKeyTo = keyBytes(keyTo); + public KeyValueIterator, V> findSessions( + final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { + final Bytes bytesKeyFrom = serializeKey(keyFrom); + final Bytes bytesKeyTo = serializeKey(keyTo); return new MeteredWindowedKeyValueIterator<>( wrapped().findSessions( bytesKeyFrom, @@ -378,34 +389,40 @@ public KeyValueIterator, V> findSessions(final K keyFrom, latestSessionStartTime), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, - openIterators); + openIterators + ); } @Override - public KeyValueIterator, V> findSessions(final long earliestSessionEndTime, - final long latestSessionEndTime) { + public KeyValueIterator, V> findSessions( + final long earliestSessionEndTime, + final long latestSessionEndTime + ) { return new MeteredWindowedKeyValueIterator<>( - wrapped().findSessions(earliestSessionEndTime, latestSessionEndTime), - fetchSensor, - iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, - time, - numOpenIterators, - openIterators); + wrapped().findSessions(earliestSessionEndTime, latestSessionEndTime), + fetchSensor, + iteratorDurationSensor, + this::deserializeKey, + this::deserializeValue, + time, + numOpenIterators, + openIterators + ); } @Override - public KeyValueIterator, V> backwardFindSessions(final K keyFrom, - final K keyTo, - final long earliestSessionEndTime, - final long latestSessionStartTime) { - final Bytes bytesKeyFrom = keyBytes(keyFrom); - final Bytes bytesKeyTo = keyBytes(keyTo); + public KeyValueIterator, V> backwardFindSessions( + final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { + final Bytes bytesKeyFrom = serializeKey(keyFrom); + final Bytes bytesKeyTo = serializeKey(keyTo); return new MeteredWindowedKeyValueIterator<>( wrapped().backwardFindSessions( bytesKeyFrom, @@ -415,8 +432,8 @@ public KeyValueIterator, V> backwardFindSessions(final K keyFrom, ), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - serdes::valueFrom, + this::deserializeKey, + this::deserializeValue, time, numOpenIterators, openIterators @@ -477,9 +494,7 @@ private QueryResult runRangeQuery( final WindowRangeQuery typedQuery = (WindowRangeQuery) query; if (typedQuery.getKey().isPresent()) { final WindowRangeQuery rawKeyQuery = - WindowRangeQuery.withKey( - Bytes.wrap(serdes.rawKey(typedQuery.getKey().get())) - ); + WindowRangeQuery.withKey(serializeKey(typedQuery.getKey().get())); final QueryResult, byte[]>> rawResult = wrapped().query(rawKeyQuery, positionBound, config); if (rawResult.isSuccess()) { @@ -488,7 +503,7 @@ private QueryResult runRangeQuery( rawResult.getResult(), fetchSensor, iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), + this::deserializeKey, StoreQueryUtils.deserializeValue(serdes, wrapped()), time, numOpenIterators, @@ -502,7 +517,6 @@ private QueryResult runRangeQuery( result = (QueryResult) rawResult; } } else { - result = QueryResult.forFailure( FailureReason.UNKNOWN_QUERY_TYPE, "This store (" + getClass() + ") doesn't know how to" @@ -515,8 +529,20 @@ private QueryResult runRangeQuery( return result; } - private Bytes keyBytes(final K key) { - return key == null ? null : Bytes.wrap(serdes.rawKey(key, new RecordHeaders())); + private Bytes serializeKey(final K key) { + return Bytes.wrap(serdes.rawKey(key, internalContext.headers())); + } + + private K deserializeKey(final byte[] rawKey) { + return serdes.keyFrom(rawKey, internalContext.headers()); + } + + protected byte[] serializeValue(final V value) { + return value != null ? serdes.rawValue(value, internalContext.headers()) : null; + } + + protected V deserializeValue(final byte[] rawValue) { + return rawValue != null ? serdes.valueFrom(rawValue, internalContext.headers()) : null; } void maybeRecordE2ELatency() { diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java index 45a1b17ecc82f..cbddac333cc79 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java @@ -41,17 +41,18 @@ import java.util.Objects; import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; -import static org.apache.kafka.streams.state.internals.Utils.keyBytes; public class MeteredSessionStoreWithHeaders extends MeteredSessionStore> implements SessionStoreWithHeaders { - MeteredSessionStoreWithHeaders(final SessionStore inner, - final String metricsScope, - final Serde keySerde, - final Serde> aggSerde, - final Time time) { + MeteredSessionStoreWithHeaders( + final SessionStore inner, + final String metricsScope, + final Serde keySerde, + final Serde> aggSerde, + final Time time + ) { super(inner, metricsScope, keySerde, aggSerde, time); } @@ -59,13 +60,22 @@ public class MeteredSessionStoreWithHeaders @Override protected Serde> prepareValueSerdeForStore( final Serde> valueSerde, - final SerdeGetter getter) { + final SerdeGetter getter + ) { if (valueSerde == null) { return new AggregationWithHeadersSerde<>((Serde) getter.valueSerde()); } return super.prepareValueSerdeForStore(valueSerde, getter); } + private Bytes serializeKey(final K key, final Headers headers) { + return Bytes.wrap(serdes.rawKey(key, headers)); + } + + private K deserializeKey(final byte[] rawKey, final Headers headers) { + return serdes.keyFrom(rawKey, headers); + } + @Override public void put(final Windowed sessionKey, final AggregationWithHeaders aggregate) { Objects.requireNonNull(sessionKey, "sessionKey can't be null"); @@ -89,16 +99,27 @@ public void put(final Windowed sessionKey, final AggregationWithHeaders try { internalContext.setRecordContext(temporaryContext); - final Bytes key = keyBytes(sessionKey, deleteHeaders, serdes); - wrapped().put(new Windowed<>(key, sessionKey.window()), serdes.rawValue(null, deleteHeaders)); + wrapped().put( + new Windowed<>( + serializeKey(sessionKey.key(), deleteHeaders), + sessionKey.window() + ), + null + ); } finally { // Restore original context internalContext.setRecordContext(currentContext); } } else { - final Headers headers = aggregate.headers(); - final Bytes key = keyBytes(sessionKey, headers, serdes); - wrapped().put(new Windowed<>(key, sessionKey.window()), serdes.rawValue(aggregate, headers)); + // it's ok to only pass headers into `serializeKey`, because for the value case passed-in headers are + // getting ignored anyway, because the value (of type `AggregationWithHeaders`) itself carries the headers + wrapped().put( + new Windowed<>( + serializeKey(sessionKey.key(), aggregate.headers()), + sessionKey.window() + ), + serializeValue(aggregate) + ); } }, time, @@ -137,8 +158,9 @@ public void remove(final Windowed sessionKey) { try { internalContext.setRecordContext(temporaryContext); - final Bytes key = keyBytes(sessionKey, deleteHeaders, serdes); - wrapped().remove(new Windowed<>(key, sessionKey.window())); + wrapped().remove( + new Windowed<>(serializeKey(sessionKey.key(), deleteHeaders), sessionKey.window()) + ); } finally { // Restore original context internalContext.setRecordContext(currentContext); @@ -154,30 +176,19 @@ public void remove(final Windowed sessionKey) { @SuppressWarnings("unchecked") @Override - public QueryResult query(final Query query, - final PositionBound positionBound, - final QueryConfig config) { - final long start = time.nanoseconds(); + public QueryResult query( + final Query query, + final PositionBound positionBound, + final QueryConfig config + ) { + final long start = config.isCollectExecutionInfo() ? System.nanoTime() : -1L; final QueryResult result; if (query instanceof WindowRangeQuery) { - final WindowRangeQuery windowRangeQuery = (WindowRangeQuery) query; - if (windowRangeQuery.getKey().isPresent()) { - result = runRangeQuery(query, positionBound, config); - } else { - result = QueryResult.forFailure( - FailureReason.UNKNOWN_QUERY_TYPE, - "This store (" + getClass() + ") doesn't know how to" - + " execute the given query (" + query + ") because" - + " SessionStores only support WindowRangeQuery.withKey." - + " Contact the store maintainer if you need support" - + " for a new query type." - ); - } + result = runRangeQuery((WindowRangeQuery) query, positionBound, config); if (config.isCollectExecutionInfo()) { result.addExecutionInfo( - "Handled in " + getClass() + " with serdes " - + serdes + " in " + (time.nanoseconds() - start) + "ns"); + "Handled in " + getClass() + " with serdes " + serdes + " in " + (time.nanoseconds() - start) + "ns"); } } else { result = wrapped().query(query, positionBound, config); @@ -193,7 +204,7 @@ public QueryResult query(final Query query, public KeyValueIterator, AggregationWithHeaders> fetch(final K key) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredSessionStoreWithHeadersIterator( - wrapped().fetch(keyBytes(key, new RecordHeaders(), serdes)) + wrapped().fetch(serializeKey(key, internalContext.headers())) ); } @@ -201,123 +212,152 @@ public KeyValueIterator, AggregationWithHeaders> fetch(final K public KeyValueIterator, AggregationWithHeaders> backwardFetch(final K key) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredSessionStoreWithHeadersIterator( - wrapped().backwardFetch(keyBytes(key, new RecordHeaders(), serdes)) + wrapped().backwardFetch(serializeKey(key, internalContext.headers())) ); } @Override - public KeyValueIterator, AggregationWithHeaders> fetch(final K keyFrom, - final K keyTo) { + public KeyValueIterator, AggregationWithHeaders> fetch( + final K keyFrom, + final K keyTo + ) { return new MeteredSessionStoreWithHeadersIterator( wrapped().fetch( - keyBytes(keyFrom, new RecordHeaders(), serdes), - keyBytes(keyTo, new RecordHeaders(), serdes)) + serializeKey(keyFrom, internalContext.headers()), + serializeKey(keyTo, internalContext.headers()) + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> backwardFetch(final K keyFrom, - final K keyTo) { + public KeyValueIterator, AggregationWithHeaders> backwardFetch( + final K keyFrom, + final K keyTo + ) { return new MeteredSessionStoreWithHeadersIterator( wrapped().backwardFetch( - keyBytes(keyFrom, new RecordHeaders(), serdes), - keyBytes(keyTo, new RecordHeaders(), serdes)) + serializeKey(keyFrom, internalContext.headers()), + serializeKey(keyTo, internalContext.headers()) + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> findSessions(final K key, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, AggregationWithHeaders> findSessions( + final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredSessionStoreWithHeadersIterator( wrapped().findSessions( - keyBytes(key, new RecordHeaders(), serdes), + serializeKey(key, internalContext.headers()), earliestSessionEndTime, - latestSessionStartTime) + latestSessionStartTime + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> backwardFindSessions(final K key, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, AggregationWithHeaders> backwardFindSessions( + final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { Objects.requireNonNull(key, "key cannot be null"); return new MeteredSessionStoreWithHeadersIterator( wrapped().backwardFindSessions( - keyBytes(key, new RecordHeaders(), serdes), + serializeKey(key, internalContext.headers()), earliestSessionEndTime, - latestSessionStartTime) + latestSessionStartTime + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> findSessions(final K keyFrom, - final K keyTo, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, AggregationWithHeaders> findSessions( + final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { return new MeteredSessionStoreWithHeadersIterator( wrapped().findSessions( - keyBytes(keyFrom, new RecordHeaders(), serdes), - keyBytes(keyTo, new RecordHeaders(), serdes), + serializeKey(keyFrom, internalContext.headers()), + serializeKey(keyTo, internalContext.headers()), earliestSessionEndTime, - latestSessionStartTime) + latestSessionStartTime + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> backwardFindSessions(final K keyFrom, - final K keyTo, - final long earliestSessionEndTime, - final long latestSessionStartTime) { + public KeyValueIterator, AggregationWithHeaders> backwardFindSessions( + final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime + ) { return new MeteredSessionStoreWithHeadersIterator( wrapped().backwardFindSessions( - keyBytes(keyFrom, new RecordHeaders(), serdes), - keyBytes(keyTo, new RecordHeaders(), serdes), + serializeKey(keyFrom, internalContext.headers()), + serializeKey(keyTo, internalContext.headers()), earliestSessionEndTime, - latestSessionStartTime) + latestSessionStartTime + ) ); } @Override - public KeyValueIterator, AggregationWithHeaders> findSessions(final long earliestSessionEndTime, - final long latestSessionEndTime) { - return new MeteredSessionStoreWithHeadersIterator( - wrapped().findSessions(earliestSessionEndTime, latestSessionEndTime) - ); + public KeyValueIterator, AggregationWithHeaders> findSessions( + final long earliestSessionEndTime, + final long latestSessionEndTime + ) { + return new MeteredSessionStoreWithHeadersIterator(wrapped().findSessions(earliestSessionEndTime, latestSessionEndTime)); } @SuppressWarnings("unchecked") - private QueryResult runRangeQuery(final Query query, - final PositionBound positionBound, - final QueryConfig config) { - final WindowRangeQuery typedQuery = (WindowRangeQuery) query; - final WindowRangeQuery rawKeyQuery = - WindowRangeQuery.withKey( - Bytes.wrap(serdes.rawKey(typedQuery.getKey().get(), new RecordHeaders())) - ); - final QueryResult, byte[]>> rawResult = - wrapped().query(rawKeyQuery, positionBound, config); - if (rawResult.isSuccess()) { - final MeteredWindowedKeyValueIterator typedResult = - new MeteredWindowedKeyValueIterator<>( - rawResult.getResult(), - fetchSensor, - iteratorDurationSensor, - bytes -> serdes.keyFrom(bytes, new RecordHeaders()), - byteArray -> { - final AggregationWithHeaders awh = - serdes.valueDeserializer().deserialize(serdes.topic(), byteArray); - return awh == null ? null : awh.aggregation(); - }, - time, - numOpenIterators, - openIterators - ); - return (QueryResult) InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); + private QueryResult runRangeQuery( + final WindowRangeQuery query, + final PositionBound positionBound, + final QueryConfig config + ) { + final QueryResult queryResult; + + if (query.getKey().isPresent()) { + final WindowRangeQuery rawKeyQuery = + WindowRangeQuery.withKey(serializeKey(query.getKey().get(), internalContext.headers())); + final QueryResult, byte[]>> rawResult = + wrapped().query(rawKeyQuery, positionBound, config); + if (rawResult.isSuccess()) { + final MeteredWindowedKeyValueIterator typedResult = + new MeteredWindowedKeyValueWithHeadersIterator<>( + rawResult.getResult(), + fetchSensor, + iteratorDurationSensor, + this::deserializeValue, + this::deserializeKey, + AggregationWithHeaders::headers, + aggregationWithHeaders -> aggregationWithHeaders == null ? null : aggregationWithHeaders.aggregation(), + time, + numOpenIterators, + openIterators + ); + queryResult = (QueryResult) InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); + } else { + queryResult = (QueryResult) rawResult; + } } else { - return (QueryResult) rawResult; + queryResult = QueryResult.forFailure( + FailureReason.UNKNOWN_QUERY_TYPE, + "This store (" + getClass() + ") doesn't know how to" + + " execute the given query (" + query + ") because" + + " SessionStores only support WindowRangeQuery.withKey." + + " Contact the store maintainer if you need support" + + " for a new query type." + ); } + return queryResult; } private class MeteredSessionStoreWithHeadersIterator @@ -356,9 +396,9 @@ public KeyValue, AggregationWithHeaders> next() { final KeyValue, byte[]> next = iter.next(); - final AggregationWithHeaders value = serdes.valueFrom(next.value, new RecordHeaders()); + final AggregationWithHeaders value = deserializeValue(next.value); final Headers headers = value != null ? value.headers() : new RecordHeaders(); - final K key = serdes.keyFrom(next.key.key().get(), headers); + final K key = deserializeKey(next.key.key().get(), headers); final Windowed windowedKey = new Windowed<>(key, next.key.window()); return KeyValue.pair(windowedKey, value); } diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java index 6d4cf93bebad6..c1ab307001baa 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java @@ -18,7 +18,6 @@ import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.header.internals.RecordHeaders; -import org.apache.kafka.common.metrics.Sensor; import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.common.utils.Bytes; import org.apache.kafka.common.utils.Time; @@ -45,9 +44,6 @@ import org.apache.kafka.streams.state.WindowStoreIterator; import java.util.Objects; -import java.util.Set; -import java.util.concurrent.atomic.LongAdder; -import java.util.function.BiFunction; import java.util.function.Function; import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; @@ -190,10 +186,7 @@ private QueryResult runWindowKeyQuery( return vth == null ? null : ValueAndTimestamp.make(vth.value(), vth.timestamp()); } ); - - final QueryResult>> typedQueryResult = - InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); - queryResult = (QueryResult) typedQueryResult; + queryResult = (QueryResult) InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); } else { // For non-timestamped stores, return plain V final MeteredWindowStoreIterator typedResult = meteredIterator( @@ -203,10 +196,7 @@ private QueryResult runWindowKeyQuery( return vth == null ? null : vth.value(); } ); - - final QueryResult> typedQueryResult = - InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); - queryResult = (QueryResult) typedQueryResult; + queryResult = (QueryResult) InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, typedResult); } } else { queryResult = (QueryResult) rawResult; @@ -430,7 +420,9 @@ private MeteredWindowedKeyValueIterator meteredWindowe rawResult.getResult(), fetchSensor, iteratorDurationSensor, + this::deserializeValue, this::deserializeKey, + ValueTimestampHeaders::headers, valueConverter, time, numOpenIterators, @@ -438,51 +430,6 @@ private MeteredWindowedKeyValueIterator meteredWindowe ); } - private final class MeteredWindowedKeyValueWithHeadersIterator extends MeteredWindowedKeyValueIterator { - private final BiFunction deserializeKey; - private final Function, ValueType> valueConverter; - - MeteredWindowedKeyValueWithHeadersIterator( - final KeyValueIterator, byte[]> iter, - final Sensor operationSensor, - final Sensor iteratorSensor, - final BiFunction deserializeKey, - final Function, ValueType> valueConverter, - final Time time, - final LongAdder numOpenIterators, - final Set openIterators - ) { - super( - iter, - operationSensor, - iteratorSensor, - null, // should not be used in super-class - null, // should not be used in super-class - time, - numOpenIterators, - openIterators - ); - - this.deserializeKey = deserializeKey; - this.valueConverter = valueConverter; - } - - @Override - public KeyValue, ValueType> next() { - final KeyValue, byte[]> next = iter.next(); - final ValueTimestampHeaders valueTimestampHeaders = deserializeValue(next.value); - return KeyValue.pair( - windowedKey(next.key, valueTimestampHeaders.headers()), - valueConverter.apply(valueTimestampHeaders) - ); - } - - private Windowed windowedKey(final Windowed bytesKey, final Headers headers) { - final K key = deserializeKey.apply(bytesKey.key().get(), headers); - return new Windowed<>(key, bytesKey.window()); - } - } - private boolean isUnderlyingStoreTimestamped() { StateStore store = wrapped(); do { @@ -507,16 +454,11 @@ private boolean isUnderlyingStoreTimestamped() { return false; } - protected Bytes serializeKey(final K key, final Headers headers) { + private Bytes serializeKey(final K key, final Headers headers) { return Bytes.wrap(serdes.rawKey(key, headers)); } - @Override - protected K deserializeKey(final byte[] rawKey) { - throw new UnsupportedOperationException("MeteredTimestampedWindowStoreWithHeaders required to pass in Headers when deserializing a key."); - } - - protected K deserializeKey(final byte[] rawKey, final Headers headers) { + private K deserializeKey(final byte[] rawKey, final Headers headers) { return serdes.keyFrom(rawKey, headers); } diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java index 2320d48af8f23..44e7d1b4510bc 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java @@ -541,7 +541,7 @@ private Bytes serializeKey(final K key) { return Bytes.wrap(serdes.rawKey(key, internalContext.headers())); } - protected K deserializeKey(final byte[] rawKey) { + private K deserializeKey(final byte[] rawKey) { return serdes.keyFrom(rawKey, internalContext.headers()); } diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java new file mode 100644 index 0000000000000..d83a6ea99527d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.Set; +import java.util.concurrent.atomic.LongAdder; +import java.util.function.BiFunction; +import java.util.function.Function; + +final class MeteredWindowedKeyValueWithHeadersIterator extends MeteredWindowedKeyValueIterator { + private final Function deserializeValue; + private final BiFunction deserializeKey; + private final Function headersExtractor; + private final Function valueConverter; + + MeteredWindowedKeyValueWithHeadersIterator( + final KeyValueIterator, byte[]> iter, + final Sensor operationSensor, + final Sensor iteratorSensor, + final Function deserializeValue, + final BiFunction deserializeKey, + final Function headersExtractor, + final Function valueConverter, + final Time time, + final LongAdder numOpenIterators, + final Set openIterators + ) { + super( + iter, + operationSensor, + iteratorSensor, + null, // should not be used in super-class + null, // should not be used in super-class + time, + numOpenIterators, + openIterators + ); + + this.deserializeValue = deserializeValue; + this.deserializeKey = deserializeKey; + this.headersExtractor = headersExtractor; + this.valueConverter = valueConverter; + } + + @Override + public KeyValue, VOuter> next() { + final KeyValue, byte[]> next = iter.next(); + final VInner valueTimestampHeaders = deserializeValue.apply(next.value); + return KeyValue.pair( + windowedKey(next.key, headersExtractor.apply(valueTimestampHeaders)), + valueConverter.apply(valueTimestampHeaders) + ); + } + + private Windowed windowedKey(final Windowed bytesKey, final Headers headers) { + final K key = deserializeKey.apply(bytesKey.key().get(), headers); + return new Windowed<>(key, bytesKey.window()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java index 6e9ad91f8107d..6b9d382758775 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java @@ -20,9 +20,7 @@ import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.header.internals.RecordHeaders; import org.apache.kafka.common.serialization.LongDeserializer; -import org.apache.kafka.common.utils.Bytes; import org.apache.kafka.common.utils.internals.ByteUtils; -import org.apache.kafka.streams.kstream.Windowed; import org.apache.kafka.streams.state.StateSerdes; import java.nio.ByteBuffer; @@ -54,38 +52,6 @@ public static Headers headers(final byte[] valueWithHeaders) { return readHeaders(buffer); } - /** - * Serialize the key with headers into bytes - * @param key the key to serialize - * @param headers the Headers as context - * @param serdes the StateSerdes as serializer - * @return the Bytes of the key - */ - public static Bytes keyBytes(final K key, final Headers headers, final StateSerdes serdes) { - return Bytes.wrap(serdes.rawKey(key, headers)); - } - - /** - * Serialize the key into bytes - * @param key the key to serialize - * @param serdes the StateSerdes as serializer - * @return the Bytes of the key - */ - static Bytes keyBytes(final K key, final StateSerdes serdes) { - return keyBytes(key, new RecordHeaders(), serdes); - } - - /** - * Serialize the session key with headers into bytes - * @param sessionKey the Windowed session key to serialize - * @param headers the Headers as context - * @param serdes the StateSerdes as serializer - * @return the Bytes of the key - */ - static Bytes keyBytes(final Windowed sessionKey, final Headers headers, final StateSerdes serdes) { - return keyBytes(sessionKey.key(), headers, serdes); - } - /** * Extract the raw aggregation bytes from serialized AggregationWithHeaders, * stripping the headers prefix. diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java index 78b49786f75fe..21ae08edd8985 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java @@ -103,6 +103,7 @@ public class MeteredSessionStoreTest { private static final Windowed WINDOWED_KEY_BYTES = new Windowed<>(KEY_BYTES, new SessionWindow(0, 0)); private static final String VALUE = "value"; private static final byte[] VALUE_BYTES = VALUE.getBytes(); + private static final Headers HEADERS = new RecordHeaders(); private static final long START_TIMESTAMP = 24L; private static final long END_TIMESTAMP = 42L; private static final int RETENTION_PERIOD = 100; @@ -139,8 +140,7 @@ public void setUp() { setUpWithoutContext(); metrics.config().recordLevel(Sensor.RecordingLevel.DEBUG); when(context.applicationId()).thenReturn(APPLICATION_ID); - when(context.metrics()) - .thenReturn(new StreamsMetricsImpl(metrics, "test", mockTime)); + when(context.metrics()).thenReturn(new StreamsMetricsImpl(metrics, "test", mockTime)); when(context.taskId()).thenReturn(taskId); when(context.changelogFor(STORE_NAME)).thenReturn(CHANGELOG_TOPIC); when(innerStore.name()).thenReturn(STORE_NAME); @@ -187,12 +187,13 @@ private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) final Deserializer valueDeserializer = mock(Deserializer.class); final Serializer valueSerializer = mock(Serializer.class); when(keySerde.serializer()).thenReturn(keySerializer); - when(keySerializer.serialize(topic, new RecordHeaders(), KEY)).thenReturn(KEY.getBytes()); + when(keySerializer.serialize(topic, HEADERS, KEY)).thenReturn(KEY.getBytes()); when(valueSerde.deserializer()).thenReturn(valueDeserializer); - when(valueDeserializer.deserialize(topic, new RecordHeaders(), VALUE_BYTES)).thenReturn(VALUE); + when(valueDeserializer.deserialize(topic, HEADERS, VALUE_BYTES)).thenReturn(VALUE); when(valueSerde.serializer()).thenReturn(valueSerializer); - when(valueSerializer.serialize(topic, new RecordHeaders(), VALUE)).thenReturn(VALUE_BYTES); + when(valueSerializer.serialize(topic, HEADERS, VALUE)).thenReturn(VALUE_BYTES); when(innerStore.fetchSession(KEY_BYTES, START_TIMESTAMP, END_TIMESTAMP)).thenReturn(VALUE_BYTES); + when(context.headers()).thenReturn(HEADERS); store = new MeteredSessionStore<>( innerStore, STORE_TYPE, diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java index f95254237d6bb..aafa28d00b27b 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java @@ -841,6 +841,8 @@ private MeteredSessionStoreWithHeaders createStoreWithMockSerdes lenient().when(keyDeserializer.deserialize(any(), eq(HEADERS), eq(KEY.getBytes()))) .thenReturn(KEY); + when(context.headers()).thenReturn(new RecordHeaders()); + final MeteredSessionStoreWithHeaders mockStore = new MeteredSessionStoreWithHeaders<>( innerStore, STORE_TYPE,