From 9770833eb593e21165051daa0cbb2ed07fb57470 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Mon, 13 Apr 2026 22:49:34 +0000 Subject: [PATCH 01/25] Adding readManyByPartitionKey API --- .../azure/cosmos/spark/CosmosConstants.scala | 1 + .../cosmos/spark/CosmosItemsDataSource.scala | 109 ++++- .../spark/CosmosPartitionKeyHelper.scala | 46 +++ .../CosmosReadManyByPartitionKeyReader.scala | 153 +++++++ ...tionReaderWithReadManyByPartitionKey.scala | 233 +++++++++++ .../udf/GetCosmosPartitionKeyValue.scala | 25 ++ .../spark/CosmosPartitionKeyHelperSpec.scala | 81 ++++ ...eaderWithReadManyByPartitionKeyITest.scala | 158 ++++++++ .../cosmos/ReadManyByPartitionKeyTest.java | 381 ++++++++++++++++++ ...ReadManyByPartitionKeyQueryHelperTest.java | 349 ++++++++++++++++ .../azure/cosmos/CosmosAsyncContainer.java | 111 +++++ .../com/azure/cosmos/CosmosContainer.java | 67 +++ .../implementation/AsyncDocumentClient.java | 21 + .../azure/cosmos/implementation/Configs.java | 19 + .../ReadManyByPartitionKeyQueryHelper.java | 199 +++++++++ .../implementation/RxDocumentClientImpl.java | 227 +++++++++++ .../DocumentQueryExecutionContextFactory.java | 11 + .../docs/readManyByPartitionKey-design.md | 133 ++++++ 18 files changed, 2323 insertions(+), 1 deletion(-) create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala create mode 100644 
sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java create mode 100644 sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java create mode 100644 sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java create mode 100644 sdk/cosmos/docs/readManyByPartitionKey-design.md diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala index 9ece47416526..00761f23d399 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala @@ -45,6 +45,7 @@ private[cosmos] object CosmosConstants { val Id = "id" val ETag = "_etag" val ItemIdentity = "_itemIdentity" + val PartitionKeyIdentity = "_partitionKeyIdentity" } object StatusCodes { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index a35cff27af68..2fbc036724e7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -2,9 +2,10 @@ // Licensed under the MIT License. 
package com.azure.cosmos.spark -import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey} +import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey, PartitionKeyBuilder} import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.azure.cosmos.{SparkBridgeInternal} import org.apache.spark.sql.{DataFrame, Row, SparkSession} import java.util @@ -112,4 +113,110 @@ object CosmosItemsDataSource { readManyReader.readMany(df.rdd, readManyFilterExtraction) } + + def readManyByPartitionKey(df: DataFrame, userConfig: java.util.Map[String, String]): DataFrame = { + readManyByPartitionKey(df, userConfig, null) + } + + def readManyByPartitionKey( + df: DataFrame, + userConfig: java.util.Map[String, String], + userProvidedSchema: StructType): DataFrame = { + + val readManyReader = new CosmosReadManyByPartitionKeyReader( + userProvidedSchema, + userConfig.asScala.toMap) + + // Option 1: Look for the _partitionKeyIdentity column (produced by GetCosmosPartitionKeyValue UDF) + val pkIdentityFieldExtraction = df + .schema + .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) + .map(field => (row: Row) => + CosmosPartitionKeyHelper.tryParsePartitionKey(row.getString(row.fieldIndex(field.name))).get) + + // Option 2: Detect PK columns by matching the container's partition key paths against the DataFrame schema + val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { + None // no need to resolve PK paths - _partitionKeyIdentity column takes precedence + } else { + val effectiveConfig = CosmosConfig.getEffectiveConfig( + databaseName = None, + containerName = None, + userConfig.asScala.toMap) + val readConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveConfig) + val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) + val sparkEnvironmentInfo = 
CosmosClientConfiguration.getSparkEnvironmentInfo(None) + val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" + + val pkPathsOpt = Loan( + List[Option[CosmosClientCacheItem]]( + Some( + CosmosClientCache( + CosmosClientConfiguration( + effectiveConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom)), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => { + val container = + ThroughputControlHelper.getContainer( + effectiveConfig, + containerConfig, + clientCacheItems(0).get, + clientCacheItems(1)) + + val pkDefinition = SparkBridgeInternal + .getContainerPropertiesFromCollectionCache(container) + .getPartitionKeyDefinition + + pkDefinition.getPaths.asScala.map(_.stripPrefix("/")).toList + }) + + // Check if ALL PK path columns exist in the DataFrame schema + val dfFieldNames = df.schema.fieldNames.toSet + val allPkColumnsPresent = pkPathsOpt.forall(path => dfFieldNames.contains(path)) + + if (allPkColumnsPresent && pkPathsOpt.nonEmpty) { + val pkPaths = pkPathsOpt + Some((row: Row) => { + if (pkPaths.size == 1) { + // Single partition key + new PartitionKey(row.getAs[Any](pkPaths.head)) + } else { + // Hierarchical partition key — build level by level + val builder = new PartitionKeyBuilder() + for (path <- pkPaths) { + val value = row.getAs[Any](path) + value match { + case s: String => builder.add(s) + case n: Number => builder.add(n.doubleValue()) + case b: Boolean => builder.add(b) + case null => builder.addNoneValue() + case other => builder.add(other.toString) + } + } + builder.build() + } + }) + } else { + None + } + } + + val pkExtraction = pkIdentityFieldExtraction + .orElse(pkColumnExtraction) + .getOrElse( + throw new IllegalArgumentException( + "Cannot determine partition key extraction from the input DataFrame. 
" + + "Either add a '_partitionKeyIdentity' column (using the GetCosmosPartitionKeyValue UDF) " + + "or ensure the DataFrame contains columns matching the container's partition key paths.")) + + readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala new file mode 100644 index 000000000000..616e1893b343 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.routing.PartitionKeyInternal +import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, Utils} +import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait + +import java.util + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { + // pattern will be recognized + // pk(partitionKeyValue) + // + // (?i) : The whole matching is case-insensitive + // pk[(](.*)[)]: partitionKey Value + private val cosmosPartitionKeyStringRegx = """(?i)pk[(](.*)[)]""".r + private val objectMapper = Utils.getSimpleObjectMapper + + def getCosmosPartitionKeyValueString(partitionKeyValue: List[Object]): String = { + s"pk(${objectMapper.writeValueAsString(partitionKeyValue.asJava)})" + } + + def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = { + cosmosPartitionKeyString match { + case cosmosPartitionKeyStringRegx(pkValue) => + val partitionKeyValue = Utils.parse(pkValue, classOf[Object]) + partitionKeyValue match { + case arrayList: 
util.ArrayList[Object] => + Some( + ImplementationBridgeHelpers + .PartitionKeyHelper + .getPartitionKeyAccessor + .toPartitionKey(PartitionKeyInternal.fromObjectArray(arrayList.toArray, false))) + case _ => Some(new PartitionKey(partitionKeyValue)) + } + case _ => None + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala new file mode 100644 index 000000000000..1d324d4855ff --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosException, ReadConsistencyStrategy} +import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, UUIDs} +import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver +import com.azure.cosmos.spark.diagnostics.{BasicLoggingTrait, DiagnosticsContext} +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types.StructType + +import java.util.UUID + +private[spark] class CosmosReadManyByPartitionKeyReader( + val userProvidedSchema: StructType, + val userConfig: Map[String, String] + ) extends BasicLoggingTrait with Serializable { + val effectiveUserConfig: Map[String, String] = CosmosConfig.getEffectiveConfig( + databaseName = None, + containerName = None, + userConfig) + + val clientConfig: CosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(effectiveUserConfig) + val readConfig: CosmosReadConfig = 
CosmosReadConfig.parseCosmosReadConfig(effectiveUserConfig) + val cosmosContainerConfig: CosmosContainerConfig = + CosmosContainerConfig.parseCosmosContainerConfig(effectiveUserConfig) + //scalastyle:off multiple.string.literals + val tableName: String = s"com.azure.cosmos.spark.items.${clientConfig.accountName}." + + s"${cosmosContainerConfig.database}.${cosmosContainerConfig.container}" + private lazy val sparkSession = { + assertOnSparkDriver() + SparkSession.active + } + val sparkEnvironmentInfo: String = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(sparkSession)) + logTrace(s"Instantiated ${this.getClass.getSimpleName} for $tableName") + + private[spark] def initializeAndBroadcastCosmosClientStatesForContainer(): Broadcast[CosmosClientMetadataCachesSnapshots] = { + val calledFrom = s"CosmosReadManyByPartitionKeyReader($tableName).initializeAndBroadcastCosmosClientStateForContainer" + Loan( + List[Option[CosmosClientCacheItem]]( + Some( + CosmosClientCache( + CosmosClientConfiguration( + effectiveUserConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom)), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveUserConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => { + val container = + ThroughputControlHelper.getContainer( + effectiveUserConfig, + cosmosContainerConfig, + clientCacheItems(0).get, + clientCacheItems(1)) + try { + container.readItem( + UUIDs.nonBlockingRandomUUID().toString, + new PartitionKey(UUIDs.nonBlockingRandomUUID().toString), + classOf[ObjectNode]) + .block() + } catch { + case _: CosmosException => None + } + + val state = new CosmosClientMetadataCachesSnapshot() + state.serialize(clientCacheItems(0).get.cosmosClient) + + var throughputControlState: Option[CosmosClientMetadataCachesSnapshot] = None + if (clientCacheItems(1).isDefined) { + throughputControlState = Some(new CosmosClientMetadataCachesSnapshot()) + 
throughputControlState.get.serialize(clientCacheItems(1).get.cosmosClient) + } + + val metadataSnapshots = CosmosClientMetadataCachesSnapshots(state, throughputControlState) + sparkSession.sparkContext.broadcast(metadataSnapshots) + }) + } + + def readManyByPartitionKey(inputRdd: RDD[Row], pkExtraction: Row => PartitionKey): DataFrame = { + val correlationActivityId = UUIDs.nonBlockingRandomUUID() + val calledFrom = s"CosmosReadManyByPartitionKeyReader.readManyByPartitionKey($correlationActivityId)" + val schema = Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration( + effectiveUserConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom + )), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveUserConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => Option.apply(userProvidedSchema).getOrElse( + CosmosTableSchemaInferrer.inferSchema( + clientCacheItems(0).get, + clientCacheItems(1), + effectiveUserConfig, + ItemsTable.defaultSchemaForInferenceDisabled))) + + val clientStates = initializeAndBroadcastCosmosClientStatesForContainer + + sparkSession.sqlContext.createDataFrame( + inputRdd.mapPartitionsWithIndex( + (partitionIndex: Int, rowIterator: Iterator[Row]) => { + val pkIterator: Iterator[PartitionKey] = rowIterator + .map(row => pkExtraction.apply(row)) + + logInfo(s"Creating an ItemsPartitionReaderWithReadManyByPartitionKey for Activity $correlationActivityId to read for " + + s"input partition [$partitionIndex] ${tableName}") + + val reader = new ItemsPartitionReaderWithReadManyByPartitionKey( + effectiveUserConfig, + CosmosReadManyHelper.FullRangeFeedRange, + schema, + DiagnosticsContext(correlationActivityId, partitionIndex.toString), + clientStates, + DiagnosticsConfig.parseDiagnosticsConfig(effectiveUserConfig), + sparkEnvironmentInfo, + TaskContext.get, + pkIterator) + + new Iterator[Row] { + 
override def hasNext: Boolean = reader.next() + + override def next(): Row = reader.getCurrentRow() + } + }, + preservesPartitioning = true + ), + schema) + } +} + +private object CosmosReadManyByPartitionKeyHelper { + val FullRangeFeedRange: NormalizedRange = NormalizedRange("", "FF") +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala new file mode 100644 index 000000000000..68e41cad3ec5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -0,0 +1,233 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosAsyncContainer, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosItemSerializerNoExceptionWrapping, SparkBridgeInternal} +import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple +import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, ObjectNodeMap, SparkRowItem, Utils} +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition} +import com.azure.cosmos.spark.BulkWriter.getThreadInfo +import com.azure.cosmos.spark.CosmosTableSchemaInferrer.IdAttributeName +import com.azure.cosmos.spark.diagnostics.{DetailedFeedDiagnosticsProvider, DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext} +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.connector.read.PartitionReader +import 
org.apache.spark.sql.types.StructType + +import java.util + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey +( + config: Map[String, String], + feedRange: NormalizedRange, + readSchema: StructType, + diagnosticsContext: DiagnosticsContext, + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + sparkEnvironmentInfo: String, + taskContext: TaskContext, + readManyPartitionKeys: Iterator[PartitionKey] +) + extends PartitionReader[InternalRow] { + + private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) + + private val readManyOptions = new CosmosReadManyRequestOptions() + private val readManyOptionsImpl = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor + .getImpl(readManyOptions) + + private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config) + ThroughputControlHelper.populateThroughputControlGroupName(readManyOptionsImpl, readConfig.throughputControlConfig) + + private val operationContext = { + assert(taskContext != null) + + SparkTaskContext(diagnosticsContext.correlationActivityId, + taskContext.stageId(), + taskContext.partitionId(), + taskContext.taskAttemptId(), + feedRange.toString) + } + + private val operationContextAndListenerTuple: Option[OperationContextAndListenerTuple] = { + if (diagnosticsConfig.mode.isDefined) { + val listener = + DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass) + + val ctxAndListener = new OperationContextAndListenerTuple(operationContext, listener) + + readManyOptionsImpl + .setOperationContextAndListenerTuple(ctxAndListener) + + Some(ctxAndListener) + } else { + None + } + } + + log.logTrace(s"Instantiated ${this.getClass.getSimpleName}, Context: ${operationContext.toString} $getThreadInfo") + + 
private val containerTargetConfig = CosmosContainerConfig.parseCosmosContainerConfig(config) + + log.logInfo(s"Using ReadManyByPartitionKey from feed range $feedRange of " + + s"container ${containerTargetConfig.database}.${containerTargetConfig.container} - " + + s"correlationActivityId ${diagnosticsContext.correlationActivityId}, " + + s"Context: ${operationContext.toString} $getThreadInfo") + + private val clientCacheItem = CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), + s"ItemsPartitionReaderWithReadManyByPartitionKey($feedRange, ${containerTargetConfig.database}.${containerTargetConfig.container})" + ) + + private val throughputControlClientCacheItemOpt = + ThroughputControlHelper.getThroughputControlClientCacheItem( + config, + clientCacheItem.context, + Some(cosmosClientStateHandles), + sparkEnvironmentInfo) + + private val cosmosAsyncContainer = + ThroughputControlHelper.getContainer( + config, + containerTargetConfig, + clientCacheItem, + throughputControlClientCacheItemOpt) + + private val partitionKeyDefinition: PartitionKeyDefinition = { + TransientErrorsRetryPolicy.executeWithRetry(() => { + SparkBridgeInternal + .getContainerPropertiesFromCollectionCache(cosmosAsyncContainer).getPartitionKeyDefinition + }) + } + + private val cosmosSerializationConfig = CosmosSerializationConfig.parseSerializationConfig(config) + private val cosmosRowConverter = CosmosRowConverter.get(cosmosSerializationConfig) + + readManyOptionsImpl + .setCustomItemSerializer( + new CosmosItemSerializerNoExceptionWrapping { + override def serialize[T](item: T): util.Map[String, AnyRef] = ??? 
+ + override def deserialize[T](jsonNodeMap: util.Map[String, AnyRef], classType: Class[T]): T = { + if (jsonNodeMap == null) { + throw new IllegalStateException("The 'jsonNodeMap' should never be null here.") + } + + if (classType != classOf[SparkRowItem]) { + throw new IllegalStateException("The 'classType' must be 'classOf[SparkRowItem])' here.") + } + + val objectNode: ObjectNode = jsonNodeMap match { + case map: ObjectNodeMap => + map.getObjectNode + case _ => + Utils.getSimpleObjectMapper.convertValue(jsonNodeMap, classOf[ObjectNode]) + } + + val partitionKey = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition) + + val row = cosmosRowConverter.fromObjectNodeToRow(readSchema, + objectNode, + readConfig.schemaConversionMode) + + SparkRowItem(row, getPartitionKeyForFeedDiagnostics(partitionKey)).asInstanceOf[T] + } + } + ) + + // Collect all PK values upfront — readManyByPartitionKey needs the full list to + // group by physical partition and issue parallel queries. + // Deduplicate by PK string representation — safe because the list size is bounded + // by the per-call limit of the readManyByPartitionKey API. + private lazy val pkList = { + val seen = new java.util.LinkedHashMap[String, PartitionKey]() + readManyPartitionKeys.foreach(pk => seen.putIfAbsent(pk.toString, pk)) + new java.util.ArrayList[PartitionKey](seen.values()) + } + + private val endToEndTimeoutPolicy = + new CosmosEndToEndOperationLatencyPolicyConfigBuilder( + java.time.Duration.ofSeconds(CosmosConstants.readOperationEndToEndTimeoutInSeconds)) + .enable(true) + .build + + // Single iterator over all PKs — the SDK handles per-physical-partition batching + // internally to avoid oversized SQL queries. 
+ private lazy val iterator = new TransientIOErrorsRetryingIterator[SparkRowItem]( + continuationToken => { + val options = new CosmosReadManyRequestOptions() + val optionsImpl = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor + .getImpl(options) + + ThroughputControlHelper.populateThroughputControlGroupName(optionsImpl, readConfig.throughputControlConfig) + + if (operationContextAndListenerTuple.isDefined) { + optionsImpl.setOperationContextAndListenerTuple(operationContextAndListenerTuple.get) + } + + optionsImpl.setCustomItemSerializer(readManyOptionsImpl.getCustomItemSerializer) + + if (pkList.isEmpty) { + cosmosAsyncContainer.readManyByPartitionKey( + new java.util.ArrayList[PartitionKey](), options, classOf[SparkRowItem]) + } else { + readConfig.customQuery match { + case Some(query) => + cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, options, classOf[SparkRowItem]) + case None => + cosmosAsyncContainer.readManyByPartitionKey(pkList, options, classOf[SparkRowItem]) + } + } + }, + readConfig.maxItemCount, + readConfig.prefetchBufferSize, + operationContextAndListenerTuple, + None + ) + + private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema) + + private def shouldLogDetailedFeedDiagnostics(): Boolean = { + diagnosticsConfig.mode.isDefined && + diagnosticsConfig.mode.get.equalsIgnoreCase(classOf[DetailedFeedDiagnosticsProvider].getName) + } + + private def getPartitionKeyForFeedDiagnostics(pkValue: PartitionKey): Option[PartitionKey] = { + if (shouldLogDetailedFeedDiagnostics()) { + Some(pkValue) + } else { + None + } + } + + override def next(): Boolean = iterator.hasNext + + override def get(): InternalRow = { + cosmosRowConverter.fromRowToInternalRow(iterator.next().row, rowSerializer) + } + + def getCurrentRow(): Row = iterator.next().row + + override def close(): Unit = { + this.iterator.close() + 
RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer) + clientCacheItem.close() + if (throughputControlClientCacheItemOpt.isDefined) { + throughputControlClientCacheItemOpt.get.close() + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala new file mode 100644 index 000000000000..a58d5b723b8b --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark.udf + +import com.azure.cosmos.spark.CosmosPartitionKeyHelper +import com.azure.cosmos.spark.CosmosPredicates.requireNotNull +import org.apache.spark.sql.api.java.UDF1 + +@SerialVersionUID(1L) +class GetCosmosPartitionKeyValue extends UDF1[Object, String] { + override def call + ( + partitionKeyValue: Object + ): String = { + requireNotNull(partitionKeyValue, "partitionKeyValue") + + partitionKeyValue match { + // for subpartitions case - Seq covers both WrappedArray (Scala 2.12) and ArraySeq (Scala 2.13) + case seq: Seq[Any] => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(seq.map(_.asInstanceOf[Object]).toList) + case _ => CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala new file mode 100644 index 000000000000..d127710da287 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} + +class CosmosPartitionKeyHelperSpec extends UnitSpec { + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + + it should "return the correct partition key value string for single PK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("pk1")) + pkString shouldEqual "pk([\"pk1\"])" + } + + it should "return the correct partition key value string for HPK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city1", "zip1")) + pkString shouldEqual "pk([\"city1\",\"zip1\"])" + } + + it should "return the correct partition key value string for 3-level HPK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("a", "b", "c")) + pkString shouldEqual "pk([\"a\",\"b\",\"c\"])" + } + + it should "parse valid single PK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"myPkValue\"])") + pk.isDefined shouldBe true + pk.get shouldEqual new PartitionKey("myPkValue") + } + + it should "parse valid HPK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"city1\",\"zip1\"])") + pk.isDefined shouldBe true + val expected = new PartitionKeyBuilder().add("city1").add("zip1").build() + pk.get shouldEqual expected + } + + it should "parse valid 3-level HPK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"a\",\"b\",\"c\"])") + pk.isDefined shouldBe true + val expected = new PartitionKeyBuilder().add("a").add("b").add("c").build() + pk.get shouldEqual expected + } + + it should "roundtrip single PK" in { + val original = "pk([\"roundtrip\"])" + val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) + parsed.isDefined shouldBe true + val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("roundtrip")) + serialized 
shouldEqual original + } + + it should "roundtrip HPK" in { + val original = "pk([\"city\",\"zip\"])" + val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) + parsed.isDefined shouldBe true + val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city", "zip")) + serialized shouldEqual original + } + + it should "return None for malformed string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("invalid_format") + pk.isDefined shouldBe false + } + + it should "return None for missing pk prefix" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("[\"value\"]") + pk.isDefined shouldBe false + } + + it should "be case-insensitive for parsing" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("PK([\"value\"])") + pk.isDefined shouldBe true + pk.get shouldEqual new PartitionKey("value") + } + + //scalastyle:on multiple.string.literals + //scalastyle:on magic.number +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala new file mode 100644 index 000000000000..5c2d7b59836d --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, TestConfigurations, Utils} +import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.spark.diagnostics.DiagnosticsContext +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.MockTaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +import java.util.UUID +import scala.collection.mutable.ListBuffer + +class ItemsPartitionReaderWithReadManyByPartitionKeyITest + extends IntegrationSpec + with Spark + with AutoCleanableCosmosContainersWithPkAsPartitionKey { + private val idProperty = "id" + private val pkProperty = "pk" + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + + it should "be able to retrieve all items for given partition keys" in { + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + + // Create items with known PK values + val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition + val allItemsByPk = scala.collection.mutable.Map[String, ListBuffer[ObjectNode]]() + val pkValues = List("pkA", "pkB", "pkC") + + for (pk <- pkValues) { + allItemsByPk(pk) = ListBuffer[ObjectNode]() + for (_ <- 1 to 5) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put(idProperty, UUID.randomUUID().toString) + objectNode.put(pkProperty, pk) + container.createItem(objectNode).block() + allItemsByPk(pk) += objectNode + } + } + + val config = Map( + "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, + "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true", + "spark.cosmos.applicationName" -> "ReadManyByPKTest" + ) + + val 
readSchema = StructType(Seq( + StructField(idProperty, StringType, false), + StructField(pkProperty, StringType, false) + )) + + val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") + val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) + val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() + + // Read items for pkA and pkB (not pkC) + val targetPks = List("pkA", "pkB") + val pkIterator = targetPks.map(pk => new PartitionKey(pk)).iterator + + val reader = ItemsPartitionReaderWithReadManyByPartitionKey( + config, + NormalizedRange("", "FF"), + readSchema, + diagnosticsContext, + cosmosClientMetadataCachesSnapshots, + diagnosticsConfig, + "", + MockTaskContext.mockTaskContext(), + pkIterator + ) + + val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) + val itemsReadFromReader = ListBuffer[ObjectNode]() + while (reader.next()) { + itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) + } + + // Should have 10 items (5 for pkA + 5 for pkB) + itemsReadFromReader.size shouldEqual 10 + + // All items should be from pkA or pkB + itemsReadFromReader.foreach(item => { + val pk = item.get(pkProperty).asText() + targetPks should contain(pk) + }) + + // Validate all expected IDs are present + val expectedIds = (allItemsByPk("pkA") ++ allItemsByPk("pkB")).map(_.get(idProperty).asText()).toSet + val actualIds = itemsReadFromReader.map(_.get(idProperty).asText()).toSet + actualIds shouldEqual expectedIds + + reader.close() + } + + it should "return empty results for non-existent partition keys" in { + val config = Map( + "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, + "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true", + 
"spark.cosmos.applicationName" -> "ReadManyByPKEmptyTest" + ) + + val readSchema = StructType(Seq( + StructField(idProperty, StringType, false), + StructField(pkProperty, StringType, false) + )) + + val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") + val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) + val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() + + val pkIterator = List(new PartitionKey("nonExistentPk")).iterator + + val reader = ItemsPartitionReaderWithReadManyByPartitionKey( + config, + NormalizedRange("", "FF"), + readSchema, + diagnosticsContext, + cosmosClientMetadataCachesSnapshots, + diagnosticsConfig, + "", + MockTaskContext.mockTaskContext(), + pkIterator + ) + + val itemsReadFromReader = ListBuffer[ObjectNode]() + val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) + while (reader.next()) { + itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) + } + + itemsReadFromReader.size shouldEqual 0 + reader.close() + } + + private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = { + val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot() + cosmosClientMetadataCachesSnapshot.serialize(cosmosClient) + + spark.sparkContext.broadcast( + CosmosClientMetadataCachesSnapshots( + cosmosClientMetadataCachesSnapshot, + Option.empty[CosmosClientMetadataCachesSnapshot])) + } + + //scalastyle:on multiple.string.literals + //scalastyle:on magic.number +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java new file mode 100644 index 000000000000..8fa657b62105 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -0,0 +1,381 @@ 
+/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +package com.azure.cosmos; + +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.FeedResponse; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyBuilder; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKeyDefinitionVersion; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; +import com.azure.cosmos.rx.TestSuiteBase; +import com.azure.cosmos.util.CosmosPagedIterable; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class ReadManyByPartitionKeyTest extends TestSuiteBase { + + private String preExistingDatabaseId = CosmosDatabaseForTest.generateId(); + private CosmosClient client; + private CosmosDatabase createdDatabase; + + // Single PK container (/mypk) + private CosmosContainer singlePkContainer; + + // HPK container (/city, /zipcode, /areaCode) + private CosmosContainer multiHashContainer; + + @Factory(dataProvider = "clientBuilders") + public ReadManyByPartitionKeyTest(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = {"emulator"}, timeOut = SETUP_TIMEOUT) + public void before_ReadManyByPartitionKeyTest() { + client = getClientBuilder().buildClient(); + createdDatabase = 
createSyncDatabase(client, preExistingDatabaseId); + + // Single PK container + String singlePkContainerName = UUID.randomUUID().toString(); + CosmosContainerProperties singlePkProps = new CosmosContainerProperties(singlePkContainerName, "/mypk"); + createdDatabase.createContainer(singlePkProps); + singlePkContainer = createdDatabase.getContainer(singlePkContainerName); + + // HPK container + String multiHashContainerName = UUID.randomUUID().toString(); + PartitionKeyDefinition hpkDef = new PartitionKeyDefinition(); + hpkDef.setKind(PartitionKind.MULTI_HASH); + hpkDef.setVersion(PartitionKeyDefinitionVersion.V2); + ArrayList paths = new ArrayList<>(); + paths.add("/city"); + paths.add("/zipcode"); + paths.add("/areaCode"); + hpkDef.setPaths(paths); + + CosmosContainerProperties hpkProps = new CosmosContainerProperties(multiHashContainerName, hpkDef); + createdDatabase.createContainer(hpkProps); + multiHashContainer = createdDatabase.getContainer(multiHashContainerName); + } + + @AfterClass(groups = {"emulator"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteSyncDatabase(createdDatabase); + safeCloseSyncClient(client); + } + + //region Single PK tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_basic() { + // Create items with different PKs + List items = createSinglePkItems("pk1", 3); + items.addAll(createSinglePkItems("pk2", 2)); + items.addAll(createSinglePkItems("pk3", 4)); + + // Read by 2 partition keys + List pkValues = Arrays.asList( + new PartitionKey("pk1"), + new PartitionKey("pk2")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(5); // 3 + 2 + resultList.forEach(item -> { + String pk = item.get("mypk").asText(); + assertThat(pk).isIn("pk1", "pk2"); + }); + + // Cleanup + 
cleanupContainer(singlePkContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withProjection() { + List items = createSinglePkItems("pkProj", 2); + + List pkValues = Collections.singletonList(new PartitionKey("pkProj")); + SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.mypk FROM c"); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(2); + // Should only have id and mypk fields (plus system properties) + resultList.forEach(item -> { + assertThat(item.has("id")).isTrue(); + assertThat(item.has("mypk")).isTrue(); + }); + + cleanupContainer(singlePkContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withAdditionalFilter() { + // Create items with different "status" values + createSinglePkItemsWithStatus("pkFilter", "active", 3); + createSinglePkItemsWithStatus("pkFilter", "inactive", 2); + + List pkValues = Collections.singletonList(new PartitionKey("pkFilter")); + SqlQuerySpec customQuery = new SqlQuerySpec( + "SELECT * FROM c WHERE c.status = @status", + Arrays.asList(new SqlParameter("@status", "active"))); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(3); + resultList.forEach(item -> { + assertThat(item.get("status").asText()).isEqualTo("active"); + }); + + cleanupContainer(singlePkContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_emptyResults() { + List pkValues = Collections.singletonList(new PartitionKey("nonExistent")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, 
ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).isEmpty(); + } + + //endregion + + //region HPK tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_fullPk() { + createHpkItems(); + + // Read by full PKs + List pkValues = Arrays.asList( + new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build(), + new PartitionKeyBuilder().add("Pittsburgh").add("15232").add(2).build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond/98053/1 has 2 items, Pittsburgh/15232/2 has 1 item + assertThat(resultList).hasSize(3); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_partialPk_singleLevel() { + createHpkItems(); + + // Read by partial PK (only city) + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond has 3 items total (2 with 98053/1 and 1 with 12345/1) + assertThat(resultList).hasSize(3); + resultList.forEach(item -> { + assertThat(item.get("city").asText()).isEqualTo("Redmond"); + }); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_partialPk_twoLevels() { + createHpkItems(); + + // Read by partial PK (city + zipcode) + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").add("98053").build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond/98053 
has 2 items + assertThat(resultList).hasSize(2); + resultList.forEach(item -> { + assertThat(item.get("city").asText()).isEqualTo("Redmond"); + assertThat(item.get("zipcode").asText()).isEqualTo("98053"); + }); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_withProjection() { + createHpkItems(); + + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build()); + + SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.city FROM c"); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(2); + + cleanupContainer(multiHashContainer); + } + + //endregion + + //region Negative/validation tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsAggregateQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec aggregateQuery = new SqlQuerySpec("SELECT COUNT(1) FROM c"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, aggregateQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for aggregate query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("aggregates"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsOrderByQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec orderByQuery = new SqlQuerySpec("SELECT * FROM c ORDER BY c.id"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, orderByQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for ORDER BY query"); + } catch (IllegalArgumentException e) { + 
assertThat(e.getMessage()).contains("ORDER BY"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsDistinctQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec distinctQuery = new SqlQuerySpec("SELECT DISTINCT c.mypk FROM c"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, distinctQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for DISTINCT query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("DISTINCT"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsGroupByQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, groupByQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for GROUP BY query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("GROUP BY"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void rejectsNullPartitionKeyList() { + singlePkContainer.readManyByPartitionKey((List) null, ObjectNode.class); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void rejectsEmptyPartitionKeyList() { + singlePkContainer.readManyByPartitionKey(new ArrayList<>(), ObjectNode.class) + .stream().collect(Collectors.toList()); + } + + //endregion + + //region helper methods + + private List createSinglePkItems(String pkValue, int count) { + List items = new ArrayList<>(); + for (int i = 0; i < count; i++) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + 
item.put("id", UUID.randomUUID().toString()); + item.put("mypk", pkValue); + singlePkContainer.createItem(item); + items.add(item); + } + return items; + } + + private List createSinglePkItemsWithStatus(String pkValue, String status, int count) { + List items = new ArrayList<>(); + for (int i = 0; i < count; i++) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("mypk", pkValue); + item.put("status", status); + singlePkContainer.createItem(item); + items.add(item); + } + return items; + } + + private void createHpkItems() { + // Same data as CosmosMultiHashTest.createItems() + createHpkItem("Redmond", "98053", 1); + createHpkItem("Redmond", "98053", 1); + createHpkItem("Pittsburgh", "15232", 2); + createHpkItem("Stonybrook", "11790", 3); + createHpkItem("Stonybrook", "11794", 3); + createHpkItem("Stonybrook", "11791", 3); + createHpkItem("Redmond", "12345", 1); + } + + private void createHpkItem(String city, String zipcode, int areaCode) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("city", city); + item.put("zipcode", zipcode); + item.put("areaCode", areaCode); + multiHashContainer.createItem(item); + } + + private void cleanupContainer(CosmosContainer container) { + CosmosPagedIterable allItems = container.queryItems( + "SELECT * FROM c", new com.azure.cosmos.models.CosmosQueryRequestOptions(), ObjectNode.class); + allItems.forEach(item -> { + try { + container.deleteItem(item, new CosmosItemRequestOptions()); + } catch (CosmosException e) { + // ignore cleanup failures + } + }); + } + + //endregion +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java 
b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java new file mode 100644 index 000000000000..9c68b3b126e4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java @@ -0,0 +1,349 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +package com.azure.cosmos.implementation; + +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyBuilder; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKeyDefinitionVersion; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +public class ReadManyByPartitionKeyQueryHelperTest { + + //region Single PK (HASH) tests + + @Test(groups = { "unit" }) + public void singlePk_defaultQuery_singleValue() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); + assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getQueryText()).contains("@pkParam0"); + assertThat(result.getParameters()).hasSize(1); + assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("pk1"); + } + + @Test(groups = { "unit" }) + public void 
singlePk_defaultQuery_multipleValues() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList( + new PartitionKey("pk1"), + new PartitionKey("pk2"), + new PartitionKey("pk3")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getQueryText()).contains("@pkParam0"); + assertThat(result.getQueryText()).contains("@pkParam1"); + assertThat(result.getQueryText()).contains("@pkParam2"); + assertThat(result.getParameters()).hasSize(3); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_noWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT c.name, c.age FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).startsWith("SELECT c.name, c.age FROM c WHERE"); + assertThat(result.getQueryText()).contains("IN ("); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_withExistingWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@minAge", 18)); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c WHERE c.age > @minAge", baseParams, pkValues, selectors, pkDef); + + // Should AND the PK filter to the existing WHERE clause + assertThat(result.getQueryText()).contains("WHERE (c.age > @minAge) AND ("); + 
assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getParameters()).hasSize(2); // @minAge + @pkParam1 + assertThat(result.getParameters().get(0).getName()).isEqualTo("@minAge"); + } + + //endregion + + //region HPK (MULTI_HASH) tests + + @Test(groups = { "unit" }) + public void hpk_fullPk_defaultQuery() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); + // Should use OR/AND pattern, not IN + assertThat(result.getQueryText()).doesNotContain("IN ("); + assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); + assertThat(result.getQueryText()).contains("AND"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); + assertThat(result.getParameters()).hasSize(2); + assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("Redmond"); + assertThat(result.getParameters().get(1).getValue(Object.class)).isEqualTo("98052"); + } + + @Test(groups = { "unit" }) + public void hpk_fullPk_multipleValues() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + List pkValues = Arrays.asList( + new PartitionKeyBuilder().add("Redmond").add("98052").build(), + new PartitionKeyBuilder().add("Seattle").add("98101").build()); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("OR"); + assertThat(result.getQueryText()).contains("c[\"city\"] = 
@pkParam0"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam2"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam3"); + assertThat(result.getParameters()).hasSize(4); + } + + @Test(groups = { "unit" }) + public void hpk_partialPk_singleLevel() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); + List selectors = createSelectors(pkDef); + + // Partial PK — only first level + PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").build(); + List pkValues = Collections.singletonList(partialPk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); + // Should NOT include zipcode or areaCode since it's partial + assertThat(result.getQueryText()).doesNotContain("zipcode"); + assertThat(result.getQueryText()).doesNotContain("areaCode"); + assertThat(result.getParameters()).hasSize(1); + } + + @Test(groups = { "unit" }) + public void hpk_partialPk_twoLevels() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); + List selectors = createSelectors(pkDef); + + // Partial PK — first two levels + PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(partialPk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).doesNotContain("areaCode"); + assertThat(result.getParameters()).hasSize(2); + } + + @Test(groups = { "unit" 
}) + public void hpk_customQuery_withWhere() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@status", "active")); + + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT c.name FROM c WHERE c.status = @status", baseParams, pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("WHERE (c.status = @status) AND ("); + assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam1"); + assertThat(result.getParameters()).hasSize(3); // @status + 2 pk params + } + + //endregion + + //region findTopLevelWhereIndex tests + + @Test(groups = { "unit" }) + public void findWhere_simpleQuery() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.ID = 1"); + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_noWhere() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C"); + assertThat(idx).isEqualTo(-1); + } + + @Test(groups = { "unit" }) + public void findWhere_whereInSubquery() { + // WHERE inside parentheses (subquery) should be ignored + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM C WHERE EXISTS(SELECT VALUE T FROM T IN C.TAGS WHERE T = 'FOO')"); + // Should find the outer WHERE, not the inner one + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_caseInsensitive() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.X = 1"); + assertThat(idx).isGreaterThan(0); + } + + @Test(groups = { "unit" }) + public void findWhere_whereNotKeyword() { + // "ELSEWHERE" should not match + int idx = 
ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM ELSEWHERE"); + assertThat(idx).isEqualTo(-1); + } + + //endregion + + //region Custom alias tests + + @Test(groups = { "unit" }) + public void singlePk_customAlias() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT x.id, x.mypk FROM x", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).startsWith("SELECT x.id, x.mypk FROM x WHERE"); + assertThat(result.getQueryText()).contains("x[\"mypk\"] IN ("); + assertThat(result.getQueryText()).doesNotContain("c[\"mypk\"]"); + } + + @Test(groups = { "unit" }) + public void singlePk_customAlias_withWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@cat", "HelloWorld")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT x.id, x.mypk FROM x WHERE x.category = @cat", baseParams, pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("WHERE (x.category = @cat) AND (x[\"mypk\"] IN ("); + } + + @Test(groups = { "unit" }) + public void hpk_customAlias() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT r.name FROM root r", new ArrayList<>(), pkValues, selectors, pkDef); + + 
assertThat(result.getQueryText()).contains("r[\"city\"] = @pkParam0"); + assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).doesNotContain("c[\""); + } + + //endregion + + //region extractTableAlias tests + + @Test(groups = { "unit" }) + public void extractAlias_defaultC() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM c")).isEqualTo("c"); + } + + @Test(groups = { "unit" }) + public void extractAlias_customX() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT x.id FROM x WHERE x.age > 5")).isEqualTo("x"); + } + + @Test(groups = { "unit" }) + public void extractAlias_rootWithAlias() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT r.name FROM root r")).isEqualTo("r"); + } + + @Test(groups = { "unit" }) + public void extractAlias_rootNoAlias() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM root")).isEqualTo("root"); + } + + @Test(groups = { "unit" }) + public void extractAlias_containerWithWhere() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM items WHERE items.status = 'active'")).isEqualTo("items"); + } + + @Test(groups = { "unit" }) + public void extractAlias_caseInsensitive() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("select * from MyContainer where MyContainer.id = '1'")).isEqualTo("MyContainer"); + } + + //endregion + + //region helpers + + private PartitionKeyDefinition createSinglePkDefinition(String path) { + PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); + pkDef.setKind(PartitionKind.HASH); + pkDef.setVersion(PartitionKeyDefinitionVersion.V2); + pkDef.setPaths(Collections.singletonList(path)); + return pkDef; + } + + private PartitionKeyDefinition createMultiHashPkDefinition(String... 
paths) { + PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); + pkDef.setKind(PartitionKind.MULTI_HASH); + pkDef.setVersion(PartitionKeyDefinitionVersion.V2); + pkDef.setPaths(Arrays.asList(paths)); + return pkDef; + } + + private List createSelectors(PartitionKeyDefinition pkDef) { + return pkDef.getPaths() + .stream() + .map(pathPart -> pathPart.substring(1)) // skip starting / + .map(pathPart -> pathPart.replace("\"", "\\")) + .map(part -> "[\"" + part + "\"]") + .collect(Collectors.toList()); + } + + //endregion +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index b3888f1bad3a..06d6b7d167bd 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -1577,6 +1577,117 @@ private Mono> readManyInternal( context); } + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, null, null, classType); + } + + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. 
+ * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, null, requestOptions, classType); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return UtilBridgeInternal.createCosmosPagedFlux( + readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); + } + + private Function>> readManyByPartitionKeyInternalFunc( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + CosmosAsyncClient client = this.getDatabase().getClient(); + + return (pagedFluxOptions -> { + CosmosQueryRequestOptions queryRequestOptions = requestOptions == null + ? 
new CosmosQueryRequestOptions() + : queryOptionsAccessor.clone(readManyOptionsAccessor.getImpl(requestOptions)); + queryRequestOptions.setMaxDegreeOfParallelism(-1); + queryRequestOptions.setQueryName("readManyByPartitionKey"); + CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor.getImpl(queryRequestOptions); + applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyItemsSpanName); + + QueryFeedOperationState state = new QueryFeedOperationState( + client, + this.readManyItemsSpanName, + database.getId(), + this.getId(), + ResourceType.Document, + OperationType.Query, + queryOptionsAccessor.getQueryNameOrDefault(queryRequestOptions, this.readManyItemsSpanName), + queryRequestOptions, + pagedFluxOptions + ); + + pagedFluxOptions.setFeedOperationState(state); + + return CosmosBridgeInternal + .getAsyncDocumentClient(this.getDatabase()) + .readManyByPartitionKey( + partitionKeys, + customQuery, + BridgeInternal.getLink(this), + state, + classType) + .map(response -> prepareFeedResponse(response, false)); + }); + } + /** * Reads all the items of a logical partition * diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java index 04a6060c1927..0bd8be5850c0 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java @@ -540,6 +540,73 @@ public FeedResponse readMany( classType)); } + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. 
+ * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, classType)); + } + + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, requestOptions, classType)); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, requestOptions, classType)); + } + /** * Reads all the items of a logical partition returning the results as {@link CosmosPagedIterable}. * diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 49e1fdf57f64..26f1ff64ea5d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -1584,6 +1584,27 @@ Mono> readMany( QueryFeedOperationState state, Class klass); + /** + * Reads many documents by partition key values. + * Unlike {@link #readMany(List, String, QueryFeedOperationState, Class)} this method does not require + * item ids - it queries all documents matching the provided partition key values. + * Partial hierarchical partition keys are supported and will fan out to multiple physical partitions. 
+ * + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query (for projections/additional filters) - null means SELECT * FROM c + * @param collectionLink link for the documentcollection/container to be queried + * @param state the query operation state + * @param klass class type + * @param the type parameter + * @return a Flux with feed response pages of documents + */ + Flux> readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + String collectionLink, + QueryFeedOperationState state, + Class klass); + /** * Read all documents of a certain logical partition. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java index 337055c6947f..162b0740f408 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java @@ -248,6 +248,11 @@ public class Configs { public static final String MIN_TARGET_BULK_MICRO_BATCH_SIZE_VARIABLE = "COSMOS_MIN_TARGET_BULK_MICRO_BATCH_SIZE"; public static final int DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE = 1; + // readManyByPartitionKey: max number of PK values per query per physical partition + private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; + private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; + private static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 1000; + public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY = "COSMOS.MAX_BULK_MICRO_BATCH_CONCURRENCY"; public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY_VARIABLE = "COSMOS_MAX_BULK_MICRO_BATCH_CONCURRENCY"; public static final int DEFAULT_MAX_BULK_MICRO_BATCH_CONCURRENCY = 1; @@ -816,6 +821,20 @@ public static int getMinTargetBulkMicroBatchSize() { return DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE; } + public static int getReadManyByPkMaxBatchSize() { + String valueFromSystemProperty = System.getProperty(READ_MANY_BY_PK_MAX_BATCH_SIZE); + if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { + return Math.max(1, Integer.parseInt(valueFromSystemProperty)); + } + + String valueFromEnvVariable = System.getenv(READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE); + if (valueFromEnvVariable != null && !valueFromEnvVariable.isEmpty()) { + return Math.max(1, Integer.parseInt(valueFromEnvVariable)); + } + + return DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE; + } + public 
static int getMaxBulkMicroBatchConcurrency() { String valueFromSystemProperty = System.getProperty(MAX_BULK_MICRO_BATCH_CONCURRENCY); if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java new file mode 100644 index 000000000000..fbe329bff844 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.implementation; + +import com.azure.cosmos.BridgeInternal; +import com.azure.cosmos.implementation.routing.PartitionKeyInternal; +import com.azure.cosmos.models.ModelBridgeInternal; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; + +import java.util.ArrayList; +import java.util.List; + +/** + * Helper for constructing SqlQuerySpec instances for readManyByPartitionKey operations. + * This class is not intended to be used directly by end-users. + */ +public class ReadManyByPartitionKeyQueryHelper { + + private static final String DEFAULT_TABLE_ALIAS = "c"; + + public static SqlQuerySpec createReadManyByPkQuerySpec( + String baseQueryText, + List baseParameters, + List pkValues, + List partitionKeySelectors, + PartitionKeyDefinition pkDefinition) { + + // Extract the table alias from the FROM clause (e.g. 
"FROM x" → "x", "FROM c" → "c") + String tableAlias = extractTableAlias(baseQueryText); + + StringBuilder pkFilter = new StringBuilder(); + List parameters = new ArrayList<>(baseParameters); + int paramCount = baseParameters.size(); + + boolean isSinglePathPk = partitionKeySelectors.size() == 1; + + if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { + // Single PK path — use IN clause: alias["pkPath"] IN (@pk0, @pk1, ...) + pkFilter.append(" "); + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(0)); + pkFilter.append(" IN ( "); + for (int i = 0; i < pkValues.size(); i++) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); + Object[] pkComponents = pkInternal.toObjectArray(); + String pkParamName = "@pkParam" + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[0])); + paramCount++; + + pkFilter.append(pkParamName); + if (i < pkValues.size() - 1) { + pkFilter.append(", "); + } + } + pkFilter.append(" )"); + } else { + // Multiple PK paths (HPK) or MULTI_HASH — use OR of AND clauses + pkFilter.append(" "); + for (int i = 0; i < pkValues.size(); i++) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); + Object[] pkComponents = pkInternal.toObjectArray(); + + pkFilter.append("("); + for (int j = 0; j < pkComponents.length; j++) { + String pkParamName = "@pkParam" + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[j])); + paramCount++; + + if (j > 0) { + pkFilter.append(" AND "); + } + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(j)); + pkFilter.append(" = "); + pkFilter.append(pkParamName); + } + pkFilter.append(")"); + + if (i < pkValues.size() - 1) { + pkFilter.append(" OR "); + } + } + } + + // Compose final query: handle existing WHERE clause in base query + String finalQuery; + int whereIndex = findTopLevelWhereIndex(baseQueryText); + if (whereIndex >= 0) { + 
// Base query has WHERE — AND our PK filter
+            String beforeWhere = baseQueryText.substring(0, whereIndex);
+            String afterWhere = baseQueryText.substring(whereIndex + 5); // skip "WHERE"
+            finalQuery = beforeWhere + "WHERE (" + afterWhere.trim() + ") AND (" + pkFilter.toString().trim() + ")";
+        } else {
+            // No WHERE — add one
+            finalQuery = baseQueryText + " WHERE" + pkFilter.toString();
+        }
+
+        return new SqlQuerySpec(finalQuery, parameters);
+    }
+
+    /**
+     * Extracts the table/collection alias from a SQL query's FROM clause.
+     * Handles: "SELECT * FROM c", "SELECT x.id FROM x WHERE ...", "SELECT * FROM root r",
+     * "SELECT * FROM root AS r", etc.
+     * Returns the explicit alias when one follows the container name, otherwise the
+     * container name itself.
+     *
+     * @param queryText the query text to inspect
+     * @return the alias to prefix the partition key selectors with; the default alias
+     *         when the query has no top-level FROM keyword
+     */
+    static String extractTableAlias(String queryText) {
+        // The keyword lookup is case-insensitive on its own, so the original text is passed as-is.
+        int fromIndex = findTopLevelKeywordIndex(queryText, "FROM");
+        if (fromIndex < 0) {
+            return DEFAULT_TABLE_ALIAS;
+        }
+
+        // First token after FROM is the container name (could be "root", "c", etc.)
+        int containerStart = skipWhitespace(queryText, fromIndex + 4);
+        int containerEnd = findTokenEnd(queryText, containerStart);
+        String containerName = queryText.substring(containerStart, containerEnd);
+
+        int nextStart = skipWhitespace(queryText, containerEnd);
+        if (nextStart >= queryText.length()) {
+            // Nothing follows the container name, so it doubles as the alias.
+            return containerName;
+        }
+
+        // A clause keyword directly after the container name means there is no explicit alias.
+        // The \b word boundary keeps aliases that merely start with a keyword
+        // ("FROM root orders", "FROM root groups") from being mistaken for that keyword.
+        // (?i) matches case-insensitively on the original text, so no upper-cased copy
+        // (whose length may differ from the original) is indexed with original offsets.
+        if (queryText.substring(nextStart).matches("(?is)(?:WHERE|ORDER|GROUP|JOIN)\\b.*")) {
+            return containerName;
+        }
+
+        // Otherwise the next token is the alias ("FROM root r" → alias is "r")
+        int aliasEnd = findTokenEnd(queryText, nextStart);
+        String alias = queryText.substring(nextStart, aliasEnd);
+        if (alias.equalsIgnoreCase("AS")) {
+            // "FROM root AS r" — the optional AS keyword precedes the actual alias
+            int explicitStart = skipWhitespace(queryText, aliasEnd);
+            alias = queryText.substring(explicitStart, findTokenEnd(queryText, explicitStart));
+        }
+
+        return alias.isEmpty() ? containerName : alias;
+    }
+
+    /** Returns the index of the first non-whitespace character at or after {@code index}. */
+    private static int skipWhitespace(String text, int index) {
+        while (index < text.length() && Character.isWhitespace(text.charAt(index))) {
+            index++;
+        }
+        return index;
+    }
+
+    /** Returns the index just past the token starting at {@code index}; whitespace and parentheses end a token. */
+    private static int findTokenEnd(String text, int index) {
+        while (index < text.length()
+            && !Character.isWhitespace(text.charAt(index))
+            && text.charAt(index) != '('
+            && text.charAt(index) != ')') {
+            index++;
+        }
+        return index;
+    }
+
+    /**
+     * Finds the index of a top-level SQL keyword in the query text (case-insensitive),
+     * ignoring occurrences inside parentheses.
+     */
+    static int findTopLevelKeywordIndex(String queryText, String keyword) {
+        int depth = 0;
+        int keyLen = keyword.length();
+        // Quote character of the string literal currently being scanned; 0 while outside a literal.
+        // Text inside a literal (e.g. c["from"] or 'where (x') is data, so neither its
+        // parentheses nor keyword-like content may influence the result.
+        char openQuote = 0;
+        for (int i = 0; i <= queryText.length() - keyLen; i++) {
+            char ch = queryText.charAt(i);
+            if (openQuote != 0) {
+                if (ch == '\\') {
+                    i++; // skip the escaped character so an escaped quote does not close the literal
+                } else if (ch == openQuote) {
+                    openQuote = 0;
+                }
+            } else if (ch == '\'' || ch == '"') {
+                openQuote = ch;
+            } else if (ch == '(') {
+                depth++;
+            } else if (ch == ')') {
+                depth--;
+            } else if (depth == 0
+                // Compare against the original text: the returned index is applied to the
+                // original query by callers, and an upper-cased copy can differ in length.
+                && queryText.regionMatches(true, i, keyword, 0, keyLen)
+                && (i == 0 || !Character.isLetterOrDigit(queryText.charAt(i - 1)))
+                && (i + keyLen >= queryText.length() || !Character.isLetterOrDigit(queryText.charAt(i + keyLen)))) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Finds the index of the top-level WHERE keyword in the query text,
+     * ignoring WHERE that appears inside parentheses (subqueries).
+     */
+    public static int findTopLevelWhereIndex(String queryTextUpper) {
+        return findTopLevelKeywordIndex(queryTextUpper, "WHERE");
+    }
+}
diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
index 5555b3da671c..17461fa56531 100644
--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java
@@ -4367,6 +4367,233 @@ private Mono>> readMany(
         );
     }
 
+    @Override
+    public Flux> readManyByPartitionKey(
+        List partitionKeys,
+        SqlQuerySpec customQuery,
+        String collectionLink,
+        QueryFeedOperationState state,
+        Class klass) {
+
+        checkNotNull(partitionKeys, "Argument 'partitionKeys' must not be null.");
+        checkArgument(!partitionKeys.isEmpty(), "Argument 'partitionKeys' must not be empty.");
+
+        final ScopedDiagnosticsFactory diagnosticsFactory = new ScopedDiagnosticsFactory(this, true);
+        state.registerDiagnosticsFactory(
+
() -> {}, // we never want to reset in readManyByPartitionKey + (ctx) -> diagnosticsFactory.merge(ctx) + ); + + String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); + RxDocumentServiceRequest request = RxDocumentServiceRequest.create(diagnosticsFactory, + OperationType.Query, + ResourceType.Document, + collectionLink, null + ); + + Mono> collectionObs = + collectionCache.resolveCollectionAsync(null, request); + + return collectionObs + .flatMapMany(documentCollectionResourceResponse -> { + final DocumentCollection collection = documentCollectionResourceResponse.v; + if (collection == null) { + return Flux.error(new IllegalStateException("Collection cannot be null")); + } + + final PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); + + Mono> valueHolderMono = partitionKeyRangeCache + .tryLookupAsync( + BridgeInternal.getMetaDataDiagnosticContext(request.requestContext.cosmosDiagnostics), + collection.getResourceId(), + null, + null); + + // Validate custom query if provided + Mono queryValidationMono; + if (customQuery != null) { + queryValidationMono = validateCustomQueryForReadManyByPartitionKey( + customQuery, resourceLink, state.getQueryOptions()); + } else { + queryValidationMono = Mono.empty(); + } + + return Mono.zip(valueHolderMono, queryValidationMono.then(Mono.just(true))) + .flatMapMany(tuple -> { + CollectionRoutingMap routingMap = tuple.getT1().v; + if (routingMap == null) { + return Flux.error(new IllegalStateException("Failed to get routing map.")); + } + + Map> partitionRangePkMap = + groupPartitionKeysByPhysicalPartition(partitionKeys, pkDefinition, routingMap); + + List partitionKeySelectors = createPkSelectors(pkDefinition); + + String baseQueryText; + List baseParameters; + if (customQuery != null) { + baseQueryText = customQuery.getQueryText(); + baseParameters = customQuery.getParameters() != null + ? 
new ArrayList<>(customQuery.getParameters()) + : new ArrayList<>(); + } else { + baseQueryText = "SELECT * FROM c"; + baseParameters = new ArrayList<>(); + } + + // Build per-physical-partition batched queries. + // Each physical partition may have many PKs — split into batches + // to avoid oversized SQL queries. Batch size is configurable via + // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 1000). + int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); + + // Build batches per partition as a list of lists (one inner list per partition). + // Then interleave in round-robin order so that concurrent execution + // prefers different physical partitions over multiple batches of the same partition. + List>> batchesPerPartition = new ArrayList<>(); + int maxBatchesPerPartition = 0; + + for (Map.Entry> entry : partitionRangePkMap.entrySet()) { + List allPks = entry.getValue(); + if (allPks.isEmpty()) { + continue; + } + List> partitionBatches = new ArrayList<>(); + for (int i = 0; i < allPks.size(); i += maxPksPerPartitionQuery) { + List batch = allPks.subList( + i, Math.min(i + maxPksPerPartitionQuery, allPks.size())); + SqlQuerySpec querySpec = ReadManyByPartitionKeyQueryHelper + .createReadManyByPkQuerySpec( + baseQueryText, baseParameters, batch, + partitionKeySelectors, pkDefinition); + partitionBatches.add( + Collections.singletonMap(entry.getKey(), querySpec)); + } + batchesPerPartition.add(partitionBatches); + maxBatchesPerPartition = Math.max(maxBatchesPerPartition, partitionBatches.size()); + } + + if (batchesPerPartition.isEmpty()) { + return Flux.empty(); + } + + // Round-robin interleave: [batch0-p1, batch0-p2, ..., batch0-pN, batch1-p1, batch1-p2, ...] + // This ensures that with bounded concurrency, different partitions are + // preferred over sequential batches of the same partition. 
+ List> interleavedBatches = new ArrayList<>(); + for (int batchIdx = 0; batchIdx < maxBatchesPerPartition; batchIdx++) { + for (List> partitionBatches : batchesPerPartition) { + if (batchIdx < partitionBatches.size()) { + interleavedBatches.add(partitionBatches.get(batchIdx)); + } + } + } + + // Execute all batches with bounded concurrency. + List>> queryFluxes = interleavedBatches + .stream() + .map(batchMap -> queryForReadMany( + diagnosticsFactory, + resourceLink, + new SqlQuerySpec(DUMMY_SQL_QUERY), + state.getQueryOptions(), + klass, + ResourceType.Document, + collection, + Collections.unmodifiableMap(batchMap))) + .collect(Collectors.toList()); + + int fluxConcurrency = Math.min(queryFluxes.size(), + Math.max(Configs.getCPUCnt(), 1)); + + return Flux.mergeSequential(queryFluxes, fluxConcurrency, 1); + }); + }); + } + + private Mono validateCustomQueryForReadManyByPartitionKey( + SqlQuerySpec customQuery, + String resourceLink, + CosmosQueryRequestOptions queryRequestOptions) { + + IDocumentQueryClient queryClient = documentQueryClientImpl( + RxDocumentClientImpl.this, getOperationContextAndListenerTuple(queryRequestOptions)); + + return DocumentQueryExecutionContextFactory + .fetchQueryPlanForValidation(this, queryClient, customQuery, resourceLink, queryRequestOptions) + .flatMap(queryPlan -> { + QueryInfo queryInfo = queryPlan.getQueryInfo(); + + if (queryInfo.hasAggregates()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain aggregates.")); + } + if (queryInfo.hasOrderBy()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain ORDER BY.")); + } + if (queryInfo.hasDistinct()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DISTINCT.")); + } + if (queryInfo.hasGroupBy()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key 
must not contain GROUP BY.")); + } + if (queryInfo.hasDCount()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DCOUNT.")); + } + if (queryInfo.hasNonStreamingOrderBy()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain non-streaming ORDER BY.")); + } + if (queryPlan.hasHybridSearchQueryInfo()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain hybrid/vector/full-text search.")); + } + + return Mono.empty(); + }); + } + + private Map> groupPartitionKeysByPhysicalPartition( + List partitionKeys, + PartitionKeyDefinition pkDefinition, + CollectionRoutingMap routingMap) { + + Map> partitionRangePkMap = new HashMap<>(); + + for (PartitionKey pk : partitionKeys) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); + int componentCount = pkInternal.getComponents().size(); + int definedPathCount = pkDefinition.getPaths().size(); + + List targetRanges; + + if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { + // Partial HPK — compute EPK prefix range and find all overlapping physical partitions + Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( + pkInternal, pkDefinition); + targetRanges = routingMap.getOverlappingRanges(epkRange); + } else { + // Full PK — maps to exactly one physical partition + String effectivePartitionKeyString = PartitionKeyInternalHelper + .getEffectivePartitionKeyString(pkInternal, pkDefinition); + PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); + targetRanges = Collections.singletonList(range); + } + + for (PartitionKeyRange range : targetRanges) { + partitionRangePkMap.computeIfAbsent(range, k -> new ArrayList<>()).add(pk); + } + } + + return partitionRangePkMap; + } + private Map getRangeQueryMap( Map> 
partitionRangeItemKeyMap, PartitionKeyDefinition partitionKeyDefinition) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java index e142f35339dd..0e8aa78ec190 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java @@ -324,6 +324,17 @@ private static List getFeedRangeEpks(List> range return feedRanges; } + public static Mono fetchQueryPlanForValidation( + DiagnosticsClientContext diagnosticsClientContext, + IDocumentQueryClient queryClient, + SqlQuerySpec sqlQuerySpec, + String resourceLink, + CosmosQueryRequestOptions queryRequestOptions) { + + return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( + diagnosticsClientContext, queryClient, sqlQuerySpec, resourceLink, queryRequestOptions); + } + public static Flux> createDocumentQueryExecutionContextAsync( DiagnosticsClientContext diagnosticsClientContext, IDocumentQueryClient client, diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md new file mode 100644 index 000000000000..f53cde1db37a --- /dev/null +++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md @@ -0,0 +1,133 @@ +# readMany by Partition Key — Design & Implementation Plan + +## Overview + +New `readManyByPartitionKey` overloads on `CosmosAsyncContainer` / `CosmosContainer` that accept a +`List<PartitionKey>` (without item-id). The SDK splits the PK values by physical +partition, generates a streaming query per physical partition, and returns results as +`CosmosPagedFlux<T>` / `CosmosPagedIterable<T>`. + +An optional `SqlQuerySpec` parameter lets callers supply a custom query for projections +and additional filters.
The SDK appends the auto-generated PK WHERE clause to it. + +## Decisions + +| Topic | Decision | +|---|---| +| API name | `readManyByPartitionKey` — dedicated method taking a `List<PartitionKey>` parameter | +| Return type | `CosmosPagedFlux<T>` (async) / `CosmosPagedIterable<T>` (sync) | +| Custom query format | `SqlQuerySpec` — full query with parameters; SDK ANDs the PK filter | +| Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | +| PK deduplication | Done at Spark layer only, not in the SDK | +| Spark UDF | New `GetCosmosPartitionKeyValue` UDF | +| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/vector/fulltext | +| Max PK list size | Enforced per invocation (same effective cap as existing readMany) | + +## Phase 1 — SDK Core (`azure-cosmos`) + +### Step 1: New public overloads in CosmosAsyncContainer + +```java +<T> CosmosPagedFlux<T> readManyByPartitionKey(List<PartitionKey> partitionKeys, Class<T> classType) +<T> CosmosPagedFlux<T> readManyByPartitionKey(List<PartitionKey> partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class<T> classType) +<T> CosmosPagedFlux<T> readManyByPartitionKey(List<PartitionKey> partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class<T> classType) +``` + +All delegate to a private `readManyByPartitionKeyInternalFunc(...)`. + +### Step 2: Sync wrappers in CosmosContainer + +Same signatures returning `CosmosPagedIterable<T>`, delegating to the async container. + +### Step 3: Internal orchestration (RxDocumentClientImpl) + +1. Resolve collection metadata + PK definition from cache. +2. Fetch routing map from `partitionKeyRangeCache`. +3. For each `PartitionKey`: + - Compute effective partition key (EPK). + - Full PK → `getRangeByEffectivePartitionKey()` (single range). + - Partial HPK → compute EPK prefix range → `getOverlappingRanges()` (multiple ranges). + **Note:** partial HPK intentionally fans out to multiple physical partitions. +4. Group PK values by `PartitionKeyRange`. +5. 
If custom `SqlQuerySpec` provided → validate via query plan (Step 4). +6. Per physical partition → build `SqlQuerySpec` with PK WHERE clause (Step 5). +7. Execute queries via `createReadManyQueryAsync()`. +8. Return results as `CosmosPagedFlux`. + +### Step 4: Custom query validation + +One-time call per invocation (existing query plan caching applies): + +- `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query. +- Reject (`IllegalArgumentException`) if: + - `queryInfo.hasAggregates()` + - `queryInfo.hasOrderBy()` + - `queryInfo.hasDistinct()` + - `queryInfo.hasGroupBy()` + - `queryInfo.hasDCount()` + - `queryInfo.hasNonStreamingOrderBy()` + - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` + +### Step 5: Query construction + +**Single PK (HASH):** +```sql +{baseQuery} WHERE c["{pkPath}"] IN (@pk0, @pk1, @pk2) +``` + +**Full HPK (MULTI_HASH):** +```sql +{baseQuery} WHERE (c["{path1}"] = @p0l1 AND c["{path2}"] = @p0l2) + OR (c["{path1}"] = @p1l1 AND c["{path2}"] = @p1l2) +``` + +**Partial HPK (prefix-only):** +```sql +{baseQuery} WHERE (c["{path1}"] = @p0l1) + OR (c["{path1}"] = @p1l1) +``` + +If the base query already has a WHERE clause: +```sql +{selectAndFrom} WHERE ({existingWhere}) AND ({pkFilter}) +``` + +### Step 6: Bridge / accessor wiring + +Expose internal method through `ImplementationBridgeHelpers`. + +## Phase 2 — Spark Connector (`azure-cosmos-spark_3`) + +### Step 7: New UDF — `GetCosmosPartitionKeyValue` + +- Input: partition key column(s) as array. +- Output: serialized PK string. + +### Step 8: PK-only serialization helper + +`CosmosPartitionKeyHelper`: +- `getCosmosPartitionKeyValueString(pkValues)` — serialize. +- `tryParsePartitionKey(serialized)` — deserialize. + +### Step 9: `CosmosItemsDataSource.readManyByPartitionKey` + +Static entry points, deduplicates PKs at Spark level, delegates to reader. 
+ +### Step 10: `CosmosReadManyByPartitionKeyReader` + +Per-Spark-partition execution, analogous to `CosmosReadManyReader`. + +### Step 11: `ItemsPartitionReaderWithReadManyByPartitionKey` + +Calls new SDK API with `Iterator[PartitionKey]`, iterates `CosmosPagedFlux` pages. + +## Phase 3 — Testing + +- Unit tests: query construction (single PK, HPK, partial HPK, custom query composition). +- Unit tests: query plan rejection (aggregates, ORDER BY, DISTINCT, etc.). +- Integration tests: end-to-end SDK + Spark UDF. From 9a5b3e96e7e6d1ad83990bfea9244a8284ccbb81 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 14 Apr 2026 16:49:23 +0200 Subject: [PATCH 02/25] Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../ItemsPartitionReaderWithReadManyByPartitionKey.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 68e41cad3ec5..9df4b79fb238 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -118,7 +118,14 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey readManyOptionsImpl .setCustomItemSerializer( new CosmosItemSerializerNoExceptionWrapping { - override def serialize[T](item: T): util.Map[String, AnyRef] = ??? 
+ override def serialize[T](item: T): util.Map[String, AnyRef] = { + throw new UnsupportedOperationException( + s"Serialization is not supported by the custom item serializer in " + + s"ItemsPartitionReaderWithReadManyByPartitionKey; this serializer is intended " + + s"for deserializing read-many responses into SparkRowItem only. " + + s"Unexpected item type: ${if (item == null) "null" else item.getClass.getName}" + ) + } override def deserialize[T](jsonNodeMap: util.Map[String, AnyRef], classType: Class[T]): T = { if (jsonNodeMap == null) { From a8720c3c9f27a4f1861ea2e59d5016541d90904b Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 14 Apr 2026 16:50:35 +0200 Subject: [PATCH 03/25] Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ...tionReaderWithReadManyByPartitionKey.scala | 89 ++++++++++++------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 9df4b79fb238..dddbb23a0050 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -170,41 +170,66 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey .enable(true) .build + private trait CloseableSparkRowItemIterator { + def hasNext: Boolean + def next(): SparkRowItem + def close(): Unit + } + + private object EmptySparkRowItemIterator extends CloseableSparkRowItemIterator { + override def hasNext: Boolean = false + + override def next(): SparkRowItem = { + throw 
new java.util.NoSuchElementException("No items available for empty partition-key list.") + } + + override def close(): Unit = {} + } + // Single iterator over all PKs — the SDK handles per-physical-partition batching // internally to avoid oversized SQL queries. - private lazy val iterator = new TransientIOErrorsRetryingIterator[SparkRowItem]( - continuationToken => { - val options = new CosmosReadManyRequestOptions() - val optionsImpl = ImplementationBridgeHelpers - .CosmosReadManyRequestOptionsHelper - .getCosmosReadManyRequestOptionsAccessor - .getImpl(options) - - ThroughputControlHelper.populateThroughputControlGroupName(optionsImpl, readConfig.throughputControlConfig) - - if (operationContextAndListenerTuple.isDefined) { - optionsImpl.setOperationContextAndListenerTuple(operationContextAndListenerTuple.get) - } - - optionsImpl.setCustomItemSerializer(readManyOptionsImpl.getCustomItemSerializer) - - if (pkList.isEmpty) { - cosmosAsyncContainer.readManyByPartitionKey( - new java.util.ArrayList[PartitionKey](), options, classOf[SparkRowItem]) - } else { - readConfig.customQuery match { - case Some(query) => - cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, options, classOf[SparkRowItem]) - case None => - cosmosAsyncContainer.readManyByPartitionKey(pkList, options, classOf[SparkRowItem]) - } + // Short-circuit empty PK lists locally because the SDK rejects empty partition-key lists. 
+ private lazy val iterator: CloseableSparkRowItemIterator = + if (pkList.isEmpty) { + EmptySparkRowItemIterator + } else { + new CloseableSparkRowItemIterator { + private val delegate = new TransientIOErrorsRetryingIterator[SparkRowItem]( + continuationToken => { + val options = new CosmosReadManyRequestOptions() + val optionsImpl = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor + .getImpl(options) + + ThroughputControlHelper.populateThroughputControlGroupName(optionsImpl, readConfig.throughputControlConfig) + + if (operationContextAndListenerTuple.isDefined) { + optionsImpl.setOperationContextAndListenerTuple(operationContextAndListenerTuple.get) + } + + optionsImpl.setCustomItemSerializer(readManyOptionsImpl.getCustomItemSerializer) + + readConfig.customQuery match { + case Some(query) => + cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, options, classOf[SparkRowItem]) + case None => + cosmosAsyncContainer.readManyByPartitionKey(pkList, options, classOf[SparkRowItem]) + } + }, + readConfig.maxItemCount, + readConfig.prefetchBufferSize, + operationContextAndListenerTuple, + None + ) + + override def hasNext: Boolean = delegate.hasNext + + override def next(): SparkRowItem = delegate.next() + + override def close(): Unit = delegate.close() } - }, - readConfig.maxItemCount, - readConfig.prefetchBufferSize, - operationContextAndListenerTuple, - None - ) + } private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema) From d499da76fb44640c0405a1ac71195aa18781c27c Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 14 Apr 2026 16:51:24 +0200 Subject: [PATCH 04/25] Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../com/azure/cosmos/implementation/RxDocumentClientImpl.java | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 17461fa56531..db667ddeb420 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4508,7 +4508,7 @@ public Flux> readManyByPartitionKey( int fluxConcurrency = Math.min(queryFluxes.size(), Math.max(Configs.getCPUCnt(), 1)); - return Flux.mergeSequential(queryFluxes, fluxConcurrency, 1); + return Flux.merge(queryFluxes, fluxConcurrency, 1); }); }); } From c3c542a33a79c189caff739f4412969c93c25b38 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 14 Apr 2026 16:52:35 +0200 Subject: [PATCH 05/25] Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index fbe329bff844..b8f9fa4ecc81 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -4,7 +4,6 @@ import com.azure.cosmos.BridgeInternal; import com.azure.cosmos.implementation.routing.PartitionKeyInternal; -import com.azure.cosmos.models.ModelBridgeInternal; import com.azure.cosmos.models.PartitionKey; import com.azure.cosmos.models.PartitionKeyDefinition; import 
com.azure.cosmos.models.PartitionKind; From 4416354e03e8973ae6a9195b05d2859301bd1561 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Tue, 14 Apr 2026 20:50:32 +0000 Subject: [PATCH 06/25] =?UTF-8?q?=C2=B4Fixing=20code=20review=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...ReadManyByPartitionKeyQueryHelperTest.java | 34 +++++++++---------- .../ReadManyByPartitionKeyQueryHelper.java | 10 +++--- .../implementation/RxDocumentClientImpl.java | 2 +- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java index 9c68b3b126e4..4f82db50bb60 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java @@ -37,7 +37,7 @@ public void singlePk_defaultQuery_singleValue() { assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); assertThat(result.getQueryText()).contains("IN ("); - assertThat(result.getQueryText()).contains("@pkParam0"); + assertThat(result.getQueryText()).contains("@__rmPk_0"); assertThat(result.getParameters()).hasSize(1); assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("pk1"); } @@ -55,9 +55,9 @@ public void singlePk_defaultQuery_multipleValues() { "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); assertThat(result.getQueryText()).contains("IN ("); - assertThat(result.getQueryText()).contains("@pkParam0"); - assertThat(result.getQueryText()).contains("@pkParam1"); - assertThat(result.getQueryText()).contains("@pkParam2"); + assertThat(result.getQueryText()).contains("@__rmPk_0"); + 
assertThat(result.getQueryText()).contains("@__rmPk_1"); + assertThat(result.getQueryText()).contains("@__rmPk_2"); assertThat(result.getParameters()).hasSize(3); } @@ -89,7 +89,7 @@ public void singlePk_customQuery_withExistingWhere() { // Should AND the PK filter to the existing WHERE clause assertThat(result.getQueryText()).contains("WHERE (c.age > @minAge) AND ("); assertThat(result.getQueryText()).contains("IN ("); - assertThat(result.getParameters()).hasSize(2); // @minAge + @pkParam1 + assertThat(result.getParameters()).hasSize(2); // @minAge + @__rmPk_0 assertThat(result.getParameters().get(0).getName()).isEqualTo("@minAge"); } @@ -111,9 +111,9 @@ public void hpk_fullPk_defaultQuery() { assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); // Should use OR/AND pattern, not IN assertThat(result.getQueryText()).doesNotContain("IN ("); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); assertThat(result.getQueryText()).contains("AND"); - assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); assertThat(result.getParameters()).hasSize(2); assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("Redmond"); assertThat(result.getParameters().get(1).getValue(Object.class)).isEqualTo("98052"); @@ -132,10 +132,10 @@ public void hpk_fullPk_multipleValues() { "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); assertThat(result.getQueryText()).contains("OR"); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); - assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam2"); - assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam3"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + 
assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_2"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_3"); assertThat(result.getParameters()).hasSize(4); } @@ -151,7 +151,7 @@ public void hpk_partialPk_singleLevel() { SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); // Should NOT include zipcode or areaCode since it's partial assertThat(result.getQueryText()).doesNotContain("zipcode"); assertThat(result.getQueryText()).doesNotContain("areaCode"); @@ -170,8 +170,8 @@ public void hpk_partialPk_twoLevels() { SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam0"); - assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); assertThat(result.getQueryText()).doesNotContain("areaCode"); assertThat(result.getParameters()).hasSize(2); } @@ -191,7 +191,7 @@ public void hpk_customQuery_withWhere() { "SELECT c.name FROM c WHERE c.status = @status", baseParams, pkValues, selectors, pkDef); assertThat(result.getQueryText()).contains("WHERE (c.status = @status) AND ("); - assertThat(result.getQueryText()).contains("c[\"city\"] = @pkParam1"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); assertThat(result.getParameters()).hasSize(3); // @status + 2 pk params } @@ -277,8 +277,8 @@ public void hpk_customAlias() { SqlQuerySpec result = 
ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( "SELECT r.name FROM root r", new ArrayList<>(), pkValues, selectors, pkDef); - assertThat(result.getQueryText()).contains("r[\"city\"] = @pkParam0"); - assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @pkParam1"); + assertThat(result.getQueryText()).contains("r[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @__rmPk_1"); assertThat(result.getQueryText()).doesNotContain("c[\""); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index b8f9fa4ecc81..36880eb99b59 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -20,6 +20,8 @@ public class ReadManyByPartitionKeyQueryHelper { private static final String DEFAULT_TABLE_ALIAS = "c"; + // Internal parameter prefix — uses double-underscore to avoid collisions with user-provided parameters + private static final String PK_PARAM_PREFIX = "@__rmPk_"; public static SqlQuerySpec createReadManyByPkQuerySpec( String baseQueryText, @@ -33,12 +35,12 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( StringBuilder pkFilter = new StringBuilder(); List parameters = new ArrayList<>(baseParameters); - int paramCount = baseParameters.size(); + int paramCount = 0; boolean isSinglePathPk = partitionKeySelectors.size() == 1; if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { - // Single PK path — use IN clause: alias["pkPath"] IN (@pk0, @pk1, ...) + // Single PK path — use IN clause: alias["pkPath"] IN (@__rmPk_0, @__rmPk_1, ...) 
pkFilter.append(" "); pkFilter.append(tableAlias); pkFilter.append(partitionKeySelectors.get(0)); @@ -46,7 +48,7 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( for (int i = 0; i < pkValues.size(); i++) { PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); Object[] pkComponents = pkInternal.toObjectArray(); - String pkParamName = "@pkParam" + paramCount; + String pkParamName = PK_PARAM_PREFIX + paramCount; parameters.add(new SqlParameter(pkParamName, pkComponents[0])); paramCount++; @@ -65,7 +67,7 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( pkFilter.append("("); for (int j = 0; j < pkComponents.length; j++) { - String pkParamName = "@pkParam" + paramCount; + String pkParamName = PK_PARAM_PREFIX + paramCount; parameters.add(new SqlParameter(pkParamName, pkComponents[j])); paramCount++; diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index db667ddeb420..5c276fb18e7e 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4508,7 +4508,7 @@ public Flux> readManyByPartitionKey( int fluxConcurrency = Math.min(queryFluxes.size(), Math.max(Configs.getCPUCnt(), 1)); - return Flux.merge(queryFluxes, fluxConcurrency, 1); + return Flux.merge(Flux.fromIterable(queryFluxes), fluxConcurrency, 1); }); }); } From 588a7550c545717f35bdb76e5e0fead1d38ac3b8 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 15 Apr 2026 09:49:34 +0000 Subject: [PATCH 07/25] Update CosmosAsyncContainer.java --- .../main/java/com/azure/cosmos/CosmosAsyncContainer.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java 
b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index b0258ec1b561..42e5dfeb0f17 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -1680,10 +1680,10 @@ private Function>> readManyByPa return (pagedFluxOptions -> { CosmosQueryRequestOptions queryRequestOptions = requestOptions == null ? new CosmosQueryRequestOptions() - : queryOptionsAccessor.clone(readManyOptionsAccessor.getImpl(requestOptions)); + : queryOptionsAccessor().clone(readManyOptionsAccessor().getImpl(requestOptions)); queryRequestOptions.setMaxDegreeOfParallelism(-1); queryRequestOptions.setQueryName("readManyByPartitionKey"); - CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor.getImpl(queryRequestOptions); + CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyItemsSpanName); QueryFeedOperationState state = new QueryFeedOperationState( @@ -1693,7 +1693,7 @@ private Function>> readManyByPa this.getId(), ResourceType.Document, OperationType.Query, - queryOptionsAccessor.getQueryNameOrDefault(queryRequestOptions, this.readManyItemsSpanName), + queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyItemsSpanName), queryRequestOptions, pagedFluxOptions ); From f5485527a9c45892c58945cee66d02f5bab463b4 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 15 Apr 2026 12:04:49 +0000 Subject: [PATCH 08/25] Update ReadManyByPartitionKeyTest.java --- .../azure/cosmos/ReadManyByPartitionKeyTest.java | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java 
b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java index 8fa657b62105..a5cb16dc2715 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -294,7 +294,7 @@ public void rejectsDistinctQuery() { @Test(groups = {"emulator"}, timeOut = TIMEOUT) public void rejectsGroupByQuery() { List pkValues = Collections.singletonList(new PartitionKey("pk1")); - SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); + SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk FROM c GROUP BY c.mypk"); try { singlePkContainer.readManyByPartitionKey(pkValues, groupByQuery, null, ObjectNode.class) @@ -305,6 +305,20 @@ public void rejectsGroupByQuery() { } } + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsGroupByWithAggregateQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec groupByWithAggregateQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, groupByWithAggregateQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for GROUP BY with aggregate query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("aggregates"); + } + } + @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) public void rejectsNullPartitionKeyList() { singlePkContainer.readManyByPartitionKey((List) null, ObjectNode.class); From f68cf02ff71cb784ccf6a5668589d1b64d4831a9 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 15 Apr 2026 20:08:33 +0000 Subject: [PATCH 09/25] Fixing test issues --- .../java/com/azure/cosmos/ReadManyByPartitionKeyTest.java | 2 +- 
.../azure/cosmos/implementation/RxDocumentClientImpl.java | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java index a5cb16dc2715..da76d795f981 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -315,7 +315,7 @@ public void rejectsGroupByWithAggregateQuery() { .stream().collect(Collectors.toList()); fail("Should have thrown IllegalArgumentException for GROUP BY with aggregate query"); } catch (IllegalArgumentException e) { - assertThat(e.getMessage()).contains("aggregates"); + assertThat(e.getMessage()).contains("GROUP BY"); } } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index dbd0d5553278..26986fb1dd8c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4524,6 +4524,10 @@ private Mono validateCustomQueryForReadManyByPartitionKey( .flatMap(queryPlan -> { QueryInfo queryInfo = queryPlan.getQueryInfo(); + if (queryInfo.hasGroupBy()) { + return Mono.error(new IllegalArgumentException( + "Custom query for readMany by partition key must not contain GROUP BY.")); + } if (queryInfo.hasAggregates()) { return Mono.error(new IllegalArgumentException( "Custom query for readMany by partition key must not contain aggregates.")); @@ -4536,10 +4540,6 @@ private Mono validateCustomQueryForReadManyByPartitionKey( return Mono.error(new IllegalArgumentException( "Custom query for readMany by partition 
key must not contain DISTINCT.")); } - if (queryInfo.hasGroupBy()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain GROUP BY.")); - } if (queryInfo.hasDCount()) { return Mono.error(new IllegalArgumentException( "Custom query for readMany by partition key must not contain DCOUNT.")); From 8b6c4b168ea2c4b5e538524a9b8419513dce9693 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 15 Apr 2026 22:36:29 +0000 Subject: [PATCH 10/25] Update CosmosAsyncContainer.java --- .../main/java/com/azure/cosmos/CosmosAsyncContainer.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index 42e5dfeb0f17..79579aaf1715 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -1665,6 +1665,13 @@ public CosmosPagedFlux readManyByPartitionKey( CosmosReadManyRequestOptions requestOptions, Class classType) { + if (partitionKeys == null) { + throw new IllegalArgumentException("Argument 'partitionKeys' must not be null."); + } + if (partitionKeys.isEmpty()) { + throw new IllegalArgumentException("Argument 'partitionKeys' must not be empty."); + } + return UtilBridgeInternal.createCosmosPagedFlux( readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); } From 56b067a93391c061108a3b8ba8bf1d5db3de4a91 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 16 Apr 2026 00:49:47 +0000 Subject: [PATCH 11/25] Reacted to code review feedback --- .../cosmos/spark/CosmosItemsDataSource.scala | 8 +- .../spark/CosmosPartitionKeyHelper.scala | 7 +- .../CosmosReadManyByPartitionKeyReader.scala | 3 - ...tionReaderWithReadManyByPartitionKey.scala | 18 +-- .../spark/CosmosPartitionKeyHelperSpec.scala | 
12 ++ .../cosmos/ReadManyByPartitionKeyTest.java | 67 ++++++++++ ...ReadManyByPartitionKeyQueryHelperTest.java | 77 +++++++++++ .../azure/cosmos/CosmosAsyncContainer.java | 8 +- .../ReadManyByPartitionKeyQueryHelper.java | 19 ++- .../implementation/RxDocumentClientImpl.java | 7 +- .../docs/readManyByPartitionKey-design.md | 124 +++++++++++------- 11 files changed, 269 insertions(+), 81 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index 2fbc036724e7..6257e96e81e5 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -147,7 +147,7 @@ object CosmosItemsDataSource { val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" - val pkPathsOpt = Loan( + val pkPaths = Loan( List[Option[CosmosClientCacheItem]]( Some( CosmosClientCache( @@ -180,10 +180,10 @@ object CosmosItemsDataSource { // Check if ALL PK path columns exist in the DataFrame schema val dfFieldNames = df.schema.fieldNames.toSet - val allPkColumnsPresent = pkPathsOpt.forall(path => dfFieldNames.contains(path)) + val allPkColumnsPresent = pkPaths.forall(path => dfFieldNames.contains(path)) - if (allPkColumnsPresent && pkPathsOpt.nonEmpty) { - val pkPaths = pkPathsOpt + if (allPkColumnsPresent && pkPaths.nonEmpty) { + // pkPaths already defined above Some((row: Row) => { if (pkPaths.size == 1) { // Single partition key diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala index 616e1893b343..27776f5c3de6 100644 --- 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -30,15 +30,14 @@ private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = { cosmosPartitionKeyString match { case cosmosPartitionKeyStringRegx(pkValue) => - val partitionKeyValue = Utils.parse(pkValue, classOf[Object]) - partitionKeyValue match { - case arrayList: util.ArrayList[Object] => + scala.util.Try(Utils.parse(pkValue, classOf[Object])).toOption.flatMap { + case arrayList: util.ArrayList[Object @unchecked] => Some( ImplementationBridgeHelpers .PartitionKeyHelper .getPartitionKeyAccessor .toPartitionKey(PartitionKeyInternal.fromObjectArray(arrayList.toArray, false))) - case _ => Some(new PartitionKey(partitionKeyValue)) + case other => Some(new PartitionKey(other)) } case _ => None } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala index 1d324d4855ff..91f3a56bc664 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -148,6 +148,3 @@ private[spark] class CosmosReadManyByPartitionKeyReader( } } -private object CosmosReadManyByPartitionKeyHelper { - val FullRangeFeedRange: NormalizedRange = NormalizedRange("", "FF") -} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 
dddbb23a0050..da3b81d951ae 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -196,25 +196,11 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey new CloseableSparkRowItemIterator { private val delegate = new TransientIOErrorsRetryingIterator[SparkRowItem]( continuationToken => { - val options = new CosmosReadManyRequestOptions() - val optionsImpl = ImplementationBridgeHelpers - .CosmosReadManyRequestOptionsHelper - .getCosmosReadManyRequestOptionsAccessor - .getImpl(options) - - ThroughputControlHelper.populateThroughputControlGroupName(optionsImpl, readConfig.throughputControlConfig) - - if (operationContextAndListenerTuple.isDefined) { - optionsImpl.setOperationContextAndListenerTuple(operationContextAndListenerTuple.get) - } - - optionsImpl.setCustomItemSerializer(readManyOptionsImpl.getCustomItemSerializer) - readConfig.customQuery match { case Some(query) => - cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, options, classOf[SparkRowItem]) + cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, readManyOptions, classOf[SparkRowItem]) case None => - cosmosAsyncContainer.readManyByPartitionKey(pkList, options, classOf[SparkRowItem]) + cosmosAsyncContainer.readManyByPartitionKey(pkList, readManyOptions, classOf[SparkRowItem]) } }, readConfig.maxItemCount, diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala index d127710da287..182f0c3cc3d3 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +++ 
b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -76,6 +76,18 @@ class CosmosPartitionKeyHelperSpec extends UnitSpec { pk.get shouldEqual new PartitionKey("value") } + + it should "return None for malformed JSON inside pk() wrapper" in { + // Invalid JSON that would cause JsonProcessingException + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk({invalid json})") + pk.isDefined shouldBe false + } + + it should "return None for truncated JSON inside pk() wrapper" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk(["unterminated)") + pk.isDefined shouldBe false + } + //scalastyle:on multiple.string.literals //scalastyle:on magic.number } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java index da76d795f981..2c26d564ed24 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -332,6 +332,73 @@ public void rejectsEmptyPartitionKeyList() { //endregion + + //region Batch size tests (#10) + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withSmallBatchSize() { + // Temporarily set batch size to 2 to exercise the batching/interleaving logic + String originalValue = System.getProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); + try { + System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", "2"); + + // Create items across 4 PKs (more than the batch size of 2) + List items = createSinglePkItems("batchPk1", 2); + items.addAll(createSinglePkItems("batchPk2", 2)); + items.addAll(createSinglePkItems("batchPk3", 2)); + items.addAll(createSinglePkItems("batchPk4", 2)); + + // Read all 4 PKs — should be split into batches of 2 + List pkValues = 
Arrays.asList( + new PartitionKey("batchPk1"), + new PartitionKey("batchPk2"), + new PartitionKey("batchPk3"), + new PartitionKey("batchPk4")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(8); // 2 items per PK * 4 PKs + resultList.forEach(item -> { + String pk = item.get("mypk").asText(); + assertThat(pk).isIn("batchPk1", "batchPk2", "batchPk3", "batchPk4"); + }); + + cleanupContainer(singlePkContainer); + } finally { + if (originalValue != null) { + System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", originalValue); + } else { + System.clearProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); + } + } + } + + //endregion + + //region Custom serializer regression tests (#5) + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withRequestOptions() { + // This test ensures that request options (like throughput control settings) + // are properly propagated through the readManyByPartitionKey path. + // It acts as a regression test for the redundant options construction bug. 
+ List items = createSinglePkItems("pkOpts", 3); + + List pkValues = Collections.singletonList(new PartitionKey("pkOpts")); + com.azure.cosmos.models.CosmosReadManyRequestOptions options = new com.azure.cosmos.models.CosmosReadManyRequestOptions(); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, options, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(3); + + cleanupContainer(singlePkContainer); + } + + //endregion + //region helper methods private List createSinglePkItems(String pkValue, int count) { diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java index 4f82db50bb60..95c109ba025f 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java @@ -318,6 +318,83 @@ public void extractAlias_caseInsensitive() { //endregion + + //region String literal handling tests (#1) + + @Test(groups = { "unit" }) + public void findWhere_ignoresWhereInsideStringLiteral() { + // WHERE inside a string literal should be ignored + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.msg = 'use WHERE clause here'"); + // Should find the outer WHERE at position 16, not the one inside the string + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_ignoresParenthesesInsideStringLiteral() { + // Parentheses inside string literal should not affect depth tracking + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.name = 'foo(bar)' AND c.x = 1"); + 
assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_handlesUnbalancedParenInStringLiteral() { + // Unbalanced paren inside string literal must not corrupt depth + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.val = 'open(' AND c.active = true"); + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_handlesStringLiteralBeforeWhere() { + // String literal in SELECT before WHERE + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT 'WHERE' as label FROM c WHERE c.id = '1'"); + // The WHERE inside quotes should be ignored; the real WHERE is further along + assertThat(idx).isGreaterThan(30); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_withStringLiteralContainingParens() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@msg", "hello")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c WHERE c.msg = 'test(value)WHERE'", baseParams, pkValues, selectors, pkDef); + + // Should correctly AND the PK filter to the real WHERE clause + assertThat(result.getQueryText()).contains("WHERE (c.msg = 'test(value)WHERE') AND ("); + } + + //endregion + + //region OFFSET/LIMIT/HAVING alias detection tests (#9) + + @Test(groups = { "unit" }) + public void extractAlias_containerWithOffset() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT * FROM c OFFSET 10 LIMIT 5")).isEqualTo("c"); + } + + @Test(groups = { "unit" }) + public void extractAlias_containerWithLimit() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT * FROM c LIMIT 10")).isEqualTo("c"); + } + + @Test(groups = { "unit" }) + public void 
extractAlias_containerWithHaving() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT c.status, COUNT(1) FROM c GROUP BY c.status HAVING COUNT(1) > 1")).isEqualTo("c"); + } + + //endregion + //region helpers private PartitionKeyDefinition createSinglePkDefinition(String path) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index 79579aaf1715..025a957a4606 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -165,6 +165,7 @@ private static ImplementationBridgeHelpers.CosmosBatchRequestOptionsHelper.Cosmo private final String createItemSpanName; private final String readAllItemsSpanName; private final String readManyItemsSpanName; + private final String readManyByPartitionKeyItemsSpanName; private final String readAllItemsOfLogicalPartitionSpanName; private final String queryItemsSpanName; private final String queryChangeFeedSpanName; @@ -198,6 +199,7 @@ protected CosmosAsyncContainer(CosmosAsyncContainer toBeWrappedContainer) { this.createItemSpanName = "createItem." + this.id; this.readAllItemsSpanName = "readAllItems." + this.id; this.readManyItemsSpanName = "readManyItems." + this.id; + this.readManyByPartitionKeyItemsSpanName = "readManyByPartitionKeyItems." + this.id; this.readAllItemsOfLogicalPartitionSpanName = "readAllItemsOfLogicalPartition." + this.id; this.queryItemsSpanName = "queryItems." + this.id; this.queryChangeFeedSpanName = "queryChangeFeed." 
+ this.id; @@ -1691,16 +1693,16 @@ private Function>> readManyByPa queryRequestOptions.setMaxDegreeOfParallelism(-1); queryRequestOptions.setQueryName("readManyByPartitionKey"); CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); - applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyItemsSpanName); + applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeyItemsSpanName); QueryFeedOperationState state = new QueryFeedOperationState( client, - this.readManyItemsSpanName, + this.readManyByPartitionKeyItemsSpanName, database.getId(), this.getId(), ResourceType.Document, OperationType.Query, - queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyItemsSpanName), + queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeyItemsSpanName), queryRequestOptions, pagedFluxOptions ); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index 36880eb99b59..d538542df1ba 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -140,12 +140,15 @@ static String extractTableAlias(String queryText) { // Check if there's an alias after the container name (before WHERE or end) if (afterFrom < queryText.length()) { char nextChar = Character.toUpperCase(queryText.charAt(afterFrom)); - // If the next token is a keyword (WHERE, ORDER, GROUP, JOIN) or end, containerName IS the alias - if (nextChar == 'W' || nextChar == 'O' || nextChar == 'G' || nextChar == 'J') { + // If the next token is a keyword (WHERE, 
ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING) or end, containerName IS the alias + if (nextChar == 'W' || nextChar == 'O' || nextChar == 'G' || nextChar == 'J' + || nextChar == 'L' || nextChar == 'H') { // Check if it's actually a keyword String remaining = upper.substring(afterFrom); if (remaining.startsWith("WHERE") || remaining.startsWith("ORDER") - || remaining.startsWith("GROUP") || remaining.startsWith("JOIN")) { + || remaining.startsWith("GROUP") || remaining.startsWith("JOIN") + || remaining.startsWith("OFFSET") || remaining.startsWith("LIMIT") + || remaining.startsWith("HAVING")) { return containerName; } } @@ -167,7 +170,7 @@ static String extractTableAlias(String queryText) { /** * Finds the index of a top-level SQL keyword in the query text (case-insensitive), - * ignoring occurrences inside parentheses. + * ignoring occurrences inside parentheses or string literals. */ static int findTopLevelKeywordIndex(String queryText, String keyword) { String queryTextUpper = queryText.toUpperCase(); @@ -176,6 +179,14 @@ static int findTopLevelKeywordIndex(String queryText, String keyword) { int keyLen = keywordUpper.length(); for (int i = 0; i <= queryTextUpper.length() - keyLen; i++) { char ch = queryTextUpper.charAt(i); + // Skip string literals enclosed in single quotes + if (queryText.charAt(i) == '\'') { + i++; + while (i < queryText.length() && queryText.charAt(i) != '\'') { + i++; + } + continue; + } if (ch == '(') { depth++; } else if (ch == ')') { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 26986fb1dd8c..b4c72532280c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4417,9 +4417,10 @@ public Flux> readManyByPartitionKey( 
queryValidationMono = Mono.empty(); } - return Mono.zip(valueHolderMono, queryValidationMono.then(Mono.just(true))) - .flatMapMany(tuple -> { - CollectionRoutingMap routingMap = tuple.getT1().v; + return valueHolderMono + .delayUntil(ignored -> queryValidationMono) + .flatMapMany(routingMapHolder -> { + CollectionRoutingMap routingMap = routingMapHolder.v; if (routingMap == null) { return Flux.error(new IllegalStateException("Failed to get routing map.")); } diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md index f53cde1db37a..95d7624f0c8b 100644 --- a/sdk/cosmos/docs/readManyByPartitionKey-design.md +++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md @@ -1,10 +1,10 @@ -# readMany by Partition Key — Design & Implementation Plan +# readManyByPartitionKey — Design & Implementation ## Overview -New `readMany` overloads on `CosmosAsyncContainer` / `CosmosContainer` that accept a +New `readManyByPartitionKey` methods on `CosmosAsyncContainer` / `CosmosContainer` that accept a `List` (without item-id). The SDK splits the PK values by physical -partition, generates a streaming query per physical partition, and returns results as +partition, generates batched streaming queries per physical partition, and returns results as `CosmosPagedFlux` / `CosmosPagedIterable`. An optional `SqlQuerySpec` parameter lets callers supply a custom query for projections @@ -14,31 +14,36 @@ and additional filters. 
The SDK appends the auto-generated PK WHERE clause to it | Topic | Decision | |---|---| -| API name | `readMany` — new overload distinguished by `List` parameter | +| API name | `readManyByPartitionKey` — distinct name to avoid ambiguity with existing `readMany(List)` | | Return type | `CosmosPagedFlux` (async) / `CosmosPagedIterable` (sync) | | Custom query format | `SqlQuerySpec` — full query with parameters; SDK ANDs the PK filter | | Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | | PK deduplication | Done at Spark layer only, not in the SDK | | Spark UDF | New `GetCosmosPartitionKeyValue` UDF | -| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/vector/fulltext | -| Max PK list size | Enforced per invocation (same effective cap as existing readMany) | +| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/non-streaming ORDER BY/vector/fulltext | +| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 1000 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | +| Eager validation | Null and empty PK list rejected eagerly (not lazily in reactive chain) | +| Telemetry | Separate span name `readManyByPartitionKeyItems.` (distinct from existing `readManyItems`) | +| Query construction | Table alias auto-detected from FROM clause; string literals and subqueries handled correctly | ## Phase 1 — SDK Core (`azure-cosmos`) ### Step 1: New public overloads in CosmosAsyncContainer ```java - CosmosPagedFlux readMany(List partitionKeys, Class classType) - CosmosPagedFlux readMany(List partitionKeys, - CosmosReadManyRequestOptions requestOptions, - Class classType) - CosmosPagedFlux readMany(List partitionKeys, - SqlQuerySpec customQuery, - CosmosReadManyRequestOptions requestOptions, - Class classType) + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, Class 
classType) + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) ``` -All delegate to a private `readManyByPartitionKeyInternal(...)`. +All delegate to a private `readManyByPartitionKeyInternalFunc(...)`. + +**Eager validation:** The 4-arg method validates `partitionKeys` is non-null and non-empty before constructing the reactive pipeline, throwing `IllegalArgumentException` synchronously. ### Step 2: Sync wrappers in CosmosContainer @@ -47,49 +52,56 @@ Same signatures returning `CosmosPagedIterable`, delegating to the async cont ### Step 3: Internal orchestration (RxDocumentClientImpl) 1. Resolve collection metadata + PK definition from cache. -2. Fetch routing map from `partitionKeyRangeCache`. +2. Fetch routing map from `partitionKeyRangeCache` **in parallel with** custom query validation (Step 4). 3. For each `PartitionKey`: - Compute effective partition key (EPK). - Full PK → `getRangeByEffectivePartitionKey()` (single range). - Partial HPK → compute EPK prefix range → `getOverlappingRanges()` (multiple ranges). **Note:** partial HPK intentionally fans out to multiple physical partitions. 4. Group PK values by `PartitionKeyRange`. -5. If custom `SqlQuerySpec` provided → validate via query plan (Step 4). -6. Per physical partition → build `SqlQuerySpec` with PK WHERE clause (Step 5). -7. Execute queries via `createReadManyQueryAsync()`. -8. Return results as `CosmosPagedFlux`. +5. Per physical partition → split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 1000). +6. Per batch → build `SqlQuerySpec` with PK WHERE clause (Step 5). +7. Interleave batches across physical partitions in round-robin order so that bounded concurrency prefers different physical partitions over sequential batches of the same partition. +8. 
Execute queries via `queryForReadMany()` with bounded concurrency (`Math.min(batchCount, cpuCount)`). +9. Return results as `CosmosPagedFlux`. ### Step 4: Custom query validation -One-time call per invocation (existing query plan caching applies): +One-time call per invocation (existing query plan caching applies). Runs **in parallel** with routing map lookup to minimize latency: - `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query. - Reject (`IllegalArgumentException`) if: + - `queryInfo.hasGroupBy()` — checked first (takes precedence over aggregates since `hasAggregates()` also returns true for GROUP BY queries) - `queryInfo.hasAggregates()` - `queryInfo.hasOrderBy()` - `queryInfo.hasDistinct()` - - `queryInfo.hasGroupBy()` - `queryInfo.hasDCount()` - `queryInfo.hasNonStreamingOrderBy()` - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` ### Step 5: Query construction +Query construction is implemented in `ReadManyByPartitionKeyQueryHelper`. The helper: +- Extracts the table alias from the FROM clause (handles `FROM c`, `FROM root r`, `FROM x WHERE ...`) +- Handles string literals in queries (parens/keywords inside `'...'` are correctly skipped) +- Recognizes SQL keywords: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING +- Uses parameterized queries (`@__rmPk_` prefix) to prevent SQL injection + **Single PK (HASH):** ```sql -{baseQuery} WHERE c["{pkPath}"] IN (@pk0, @pk1, @pk2) +{baseQuery} WHERE {alias}["{pkPath}"] IN (@__rmPk_0, @__rmPk_1, @__rmPk_2) ``` **Full HPK (MULTI_HASH):** ```sql -{baseQuery} WHERE (c["{path1}"] = @p0l1 AND c["{path2}"] = @p0l2) - OR (c["{path1}"] = @p1l1 AND c["{path2}"] = @p1l2) +{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0 AND {alias}["{path2}"] = @__rmPk_1) + OR ({alias}["{path1}"] = @__rmPk_2 AND {alias}["{path2}"] = @__rmPk_3) ``` **Partial HPK (prefix-only):** ```sql -{baseQuery} WHERE (c["{path1}"] = @p0l1) - OR (c["{path1}"] = @p1l1) +{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0) + OR 
({alias}["{path1}"] = @__rmPk_1) ``` If the base query already has a WHERE clause: @@ -97,37 +109,61 @@ If the base query already has a WHERE clause: {selectAndFrom} WHERE ({existingWhere}) AND ({pkFilter}) ``` -### Step 6: Bridge / accessor wiring +### Step 6: Interface wiring + +New method `readManyByPartitionKey` added directly to `AsyncDocumentClient` interface, implemented in `RxDocumentClientImpl`. New `fetchQueryPlanForValidation` static method added to `DocumentQueryExecutionContextFactory` for custom query validation. -Expose internal method through `ImplementationBridgeHelpers`. +### Step 7: Configuration + +New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 1000, minimum: 1). Follows existing `Configs` patterns. ## Phase 2 — Spark Connector (`azure-cosmos-spark_3`) -### Step 7: New UDF — `GetCosmosPartitionKeyValue` +### Step 8: New UDF — `GetCosmosPartitionKeyValue` -- Input: partition key column(s) as array. -- Output: serialized PK string. +- Input: partition key value (single value or Seq for hierarchical PKs). +- Output: serialized PK string in format `pk([...json...])`. +- **Null handling:** Throws on null input (Scala convention; callers should filter nulls upstream). -### Step 8: PK-only serialization helper +### Step 9: PK-only serialization helper `CosmosPartitionKeyHelper`: -- `getCosmosPartitionKeyValueString(pkValues)` — serialize. -- `tryParsePartitionKey(serialized)` — deserialize. +- `getCosmosPartitionKeyValueString(pkValues: List[Object]): String` — serialize to `pk([...])` format. +- `tryParsePartitionKey(serialized: String): Option[PartitionKey]` — deserialize; returns `None` for malformed input including invalid JSON (wrapped in `scala.util.Try`). 
+ +### Step 10: `CosmosItemsDataSource.readManyByPartitionKey` -### Step 9: `CosmosItemsDataSource.readManyByPartitionKey` +Static entry points that accept a DataFrame and Cosmos config. PK extraction supports two modes: +1. **UDF-produced column**: DataFrame contains `_partitionKeyIdentity` column (from `GetCosmosPartitionKeyValue` UDF). +2. **Schema-matched columns**: DataFrame columns match the container's PK paths. -Static entry points, deduplicates PKs at Spark level, delegates to reader. +Falls back with `IllegalArgumentException` if neither mode is possible. -### Step 10: `CosmosReadManyByPartitionKeyReader` +### Step 11: `CosmosReadManyByPartitionKeyReader` -Per-Spark-partition execution, analogous to `CosmosReadManyReader`. +Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. -### Step 11: `ItemsPartitionReaderWithReadManyByPartitionKey` +### Step 12: `ItemsPartitionReaderWithReadManyByPartitionKey` -Calls new SDK API with `Iterator[PartitionKey]`, iterates `CosmosPagedFlux` pages. +Spark `PartitionReader[InternalRow]` that: +- Deduplicates PKs via `LinkedHashMap` (by PK string representation). +- Passes the pre-built `CosmosReadManyRequestOptions` (with throughput control, diagnostics, custom serializer) to the SDK. +- Uses `TransientIOErrorsRetryingIterator` for retry handling. +- Short-circuits empty PK lists to avoid SDK rejection. ## Phase 3 — Testing -- Unit tests: query construction (single PK, HPK, partial HPK, custom query composition). -- Unit tests: query plan rejection (aggregates, ORDER BY, DISTINCT, etc.). -- Integration tests: end-to-end SDK + Spark UDF. +### Unit tests +- Query construction: single PK, HPK full/partial, custom query composition, table alias detection. +- Query plan rejection: aggregates, ORDER BY, DISTINCT, GROUP BY (with and without aggregates), DCOUNT. 
+- String literal handling: WHERE/parentheses inside string constants. +- Keyword detection: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING. +- PK serialization/deserialization roundtrip (including malformed JSON handling). +- `findTopLevelWhereIndex` edge cases: subqueries, string literals, case insensitivity. + +### Integration tests +- End-to-end SDK: single PK basic, projections, filters, empty results, HPK full/partial, request options propagation. +- Batch size validation: temporarily lowered batch size to exercise batching/interleaving logic. +- Null/empty PK list rejection (eager validation). +- Spark connector: `ItemsPartitionReaderWithReadManyByPartitionKey` with known PK values and non-existent PKs. +- `CosmosPartitionKeyHelper`: single/HPK roundtrip, case insensitivity, malformed input. From d9504c91f343d6b6c2e970586ce3c5f994c44dd8 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 16 Apr 2026 10:56:51 +0000 Subject: [PATCH 12/25] Fix build issues --- .../azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala | 2 +- sdk/cosmos/cspell.yaml | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 sdk/cosmos/cspell.yaml diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala index 182f0c3cc3d3..6528d44f5ebe 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -84,7 +84,7 @@ class CosmosPartitionKeyHelperSpec extends UnitSpec { } it should "return None for truncated JSON inside pk() wrapper" in { - val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk(["unterminated)") + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"unterminated)") pk.isDefined shouldBe false } 
diff --git a/sdk/cosmos/cspell.yaml b/sdk/cosmos/cspell.yaml new file mode 100644 index 000000000000..94a4002c2c9c --- /dev/null +++ b/sdk/cosmos/cspell.yaml @@ -0,0 +1,6 @@ +import: + - ../../.vscode/cspell.json +overrides: + - filename: "**/sdk/cosmos/*" + words: + - DCOUNT From 681830e2d4a134c1f72976cf6cdea22929f6a69e Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 16 Apr 2026 21:48:17 +0000 Subject: [PATCH 13/25] Fixing changelog --- sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + 6 files changed, 6 insertions(+) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index cbf97c610f9f..fe114462019c 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index c9097e749f03..3b2c7ce36db1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index f5eac38bdb71..2240a48b1654 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index 919d7fbfa325..20a3e3a61bd7 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index 3972ae6aeb98..d8368be6a0da 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index e8ea564fab7b..faf661ddd80f 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.80.0-beta.1 (Unreleased) #### Features Added +* Added new `readManyByPartitionKey` to bulk query by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes From 0b8905dbb011e5eae5720098ca0ddd2de9d4757a Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 16 Apr 2026 23:46:18 +0000 Subject: [PATCH 14/25] Addressing code review comments --- .../com/azure/cosmos/spark/CosmosConfig.scala | 22 +++++++++- .../cosmos/spark/CosmosItemsDataSource.scala | 5 ++- ...tionReaderWithReadManyByPartitionKey.scala | 2 + .../azure/cosmos/spark/CosmosConfigSpec.scala | 42 +++++++++++++++++++ .../spark/CosmosPartitionKeyHelperSpec.scala | 11 +++++ 5 files changed, 79 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..e1b8f0b51f8a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -92,6 +92,7 @@ private[spark] object CosmosConfigNames { val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter" val ReadRuntimeFilteringEnabled = "spark.cosmos.read.runtimeFiltering.enabled" val ReadManyFilteringEnabled = "spark.cosmos.read.readManyFiltering.enabled" + val ReadManyByPkNullHandling = "spark.cosmos.read.readManyByPk.nullHandling" val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath" val DiagnosticsMode = 
"spark.cosmos.diagnostics" val DiagnosticsSamplingMaxCount = "spark.cosmos.diagnostics.sampling.maxCount" @@ -226,6 +227,7 @@ private[spark] object CosmosConfigNames { ReadPartitioningFeedRangeFilter, ReadRuntimeFilteringEnabled, ReadManyFilteringEnabled, + ReadManyByPkNullHandling, ViewsRepositoryPath, DiagnosticsMode, DiagnosticsSamplingIntervalInSeconds, @@ -1042,7 +1044,8 @@ private case class CosmosReadConfig(readConsistencyStrategy: ReadConsistencyStra throughputControlConfig: Option[CosmosThroughputControlConfig] = None, runtimeFilteringEnabled: Boolean, readManyFilteringConfig: CosmosReadManyFilteringConfig, - responseContinuationTokenLimitInKb: Option[Int] = None) + responseContinuationTokenLimitInKb: Option[Int] = None, + readManyByPkTreatNullAsNone: Boolean = false) private object SchemaConversionModes extends Enumeration { type SchemaConversionMode = Value @@ -1136,6 +1139,18 @@ private object CosmosReadConfig { helpMessage = " Indicates whether dynamic partition pruning filters will be pushed down when applicable." ) + private val ReadManyByPkNullHandling = CosmosConfigEntry[String]( + key = CosmosConfigNames.ReadManyByPkNullHandling, + mandatory = false, + defaultValue = Some("Null"), + parseFromStringFunction = value => value, + helpMessage = "Determines how null values in hierarchical partition key components are treated " + + "for readManyByPartitionKey. 'Null' (default) maps null to a JSON null value via addNullValue(), " + + "which is appropriate when the document field exists with an explicit null value. " + + "'None' maps null to PartitionKey.NONE via addNoneValue(), which should only be used when the " + + "partition key path does not exist at all in the document." 
+ ) + def parseCosmosReadConfig(cfg: Map[String, String]): CosmosReadConfig = { val forceEventualConsistency = CosmosConfigEntry.parse(cfg, ForceEventualConsistency) val readConsistencyStrategyOverride = CosmosConfigEntry.parse(cfg, ReadConsistencyStrategyOverride) @@ -1158,6 +1173,8 @@ private object CosmosReadConfig { val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg) val runtimeFilteringEnabled = CosmosConfigEntry.parse(cfg, ReadRuntimeFilteringEnabled) val readManyFilteringConfig = CosmosReadManyFilteringConfig.parseCosmosReadManyFilterConfig(cfg) + val readManyByPkNullHandling = CosmosConfigEntry.parse(cfg, ReadManyByPkNullHandling) + val readManyByPkTreatNullAsNone = readManyByPkNullHandling.getOrElse("Null").equalsIgnoreCase("None") val effectiveReadConsistencyStrategy = if (readConsistencyStrategyOverride.getOrElse(ReadConsistencyStrategy.DEFAULT) != ReadConsistencyStrategy.DEFAULT) { readConsistencyStrategyOverride.get @@ -1189,7 +1206,8 @@ private object CosmosReadConfig { throughputControlConfigOpt, runtimeFilteringEnabled.get, readManyFilteringConfig, - responseContinuationTokenLimitInKb) + responseContinuationTokenLimitInKb, + readManyByPkTreatNullAsNone) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index 6257e96e81e5..ac2299929f6a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -146,6 +146,7 @@ object CosmosItemsDataSource { val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" + val 
treatNullAsNone = readConfig.readManyByPkTreatNullAsNone val pkPaths = Loan( List[Option[CosmosClientCacheItem]]( @@ -197,7 +198,9 @@ object CosmosItemsDataSource { case s: String => builder.add(s) case n: Number => builder.add(n.doubleValue()) case b: Boolean => builder.add(b) - case null => builder.addNoneValue() + case null => + if (treatNullAsNone) builder.addNoneValue() + else builder.addNullValue() case other => builder.add(other.toString) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index da3b81d951ae..73477b3a488d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -170,6 +170,8 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey .enable(true) .build + readManyOptionsImpl.setCosmosEndToEndOperationLatencyPolicyConfig(endToEndTimeoutPolicy) + private trait CloseableSparkRowItemIterator { def hasNext: Boolean def next(): SparkRowItem diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala index 17f75e45a746..17a298d62131 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala @@ -457,6 +457,7 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { config.runtimeFilteringEnabled shouldBe true config.readManyFilteringConfig.readManyFilteringEnabled shouldBe false 
config.readManyFilteringConfig.readManyFilterProperty shouldEqual "_itemIdentity" + config.readManyByPkTreatNullAsNone shouldBe false userConfig = Map( "spark.cosmos.read.forceEventualConsistency" -> "false", @@ -630,6 +631,47 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { config.customQuery.get.queryText shouldBe queryText } + it should "parse readManyByPk nullHandling configuration" in { + // Default (not specified) should treat null as JSON null (addNullValue) + var userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false" + ) + var config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // Explicit "Null" should treat null as JSON null (addNullValue) + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "Null" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // Case-insensitive "null" + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "null" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // "None" should treat null as PartitionKey.NONE (addNoneValue) + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "None" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe true + + // Case-insensitive "none" + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "none" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe true + } + it should "throw on invalid read configuration" in { val userConfig = Map( 
"spark.cosmos.read.schemaConversionMode" -> "not a valid value" diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala index 6528d44f5ebe..1ac40e395847 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -88,6 +88,17 @@ class CosmosPartitionKeyHelperSpec extends UnitSpec { pk.isDefined shouldBe false } + it should "produce different partition keys for addNullValue vs addNoneValue in HPK" in { + // addNullValue represents an explicit JSON null for a field that exists with value null + val pkWithNull = new PartitionKeyBuilder().add("Redmond").addNullValue().build() + + // addNoneValue represents PartitionKey.NONE, meaning the field is absent/undefined + val pkWithNone = new PartitionKeyBuilder().add("Redmond").addNoneValue().build() + + // These MUST produce different partition key hashes and route to different physical partitions + pkWithNull should not equal pkWithNone + } + //scalastyle:on multiple.string.literals //scalastyle:on magic.number } From 22abc780ed8b81da8665297b59d25abad896d97f Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 16 Apr 2026 23:54:19 +0000 Subject: [PATCH 15/25] Addressing code review feedback --- ...tionReaderWithReadManyByPartitionKey.scala | 25 +-- ...tryingReadManyByPartitionKeyIterator.scala | 175 ++++++++++++++++++ sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- 3 files changed, 186 insertions(+), 16 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala diff --git 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 73477b3a488d..8d4952b2144c 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -6,7 +6,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.{CosmosAsyncContainer, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosItemSerializerNoExceptionWrapping, SparkBridgeInternal} import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, ObjectNodeMap, SparkRowItem, Utils} -import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition} +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition, SqlQuerySpec} import com.azure.cosmos.spark.BulkWriter.getThreadInfo import com.azure.cosmos.spark.CosmosTableSchemaInferrer.IdAttributeName import com.azure.cosmos.spark.diagnostics.{DetailedFeedDiagnosticsProvider, DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext} @@ -188,27 +188,22 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey override def close(): Unit = {} } - // Single iterator over all PKs — the SDK handles per-physical-partition batching - // internally to avoid oversized SQL queries. - // Short-circuit empty PK lists locally because the SDK rejects empty partition-key lists. + // Batch partition keys and retry each batch independently on transient I/O errors. 
+ // This avoids the continuation-token problem with TransientIOErrorsRetryingIterator + // where a retry would re-read all data from scratch, causing silent data duplication. private lazy val iterator: CloseableSparkRowItemIterator = if (pkList.isEmpty) { EmptySparkRowItemIterator } else { new CloseableSparkRowItemIterator { - private val delegate = new TransientIOErrorsRetryingIterator[SparkRowItem]( - continuationToken => { - readConfig.customQuery match { - case Some(query) => - cosmosAsyncContainer.readManyByPartitionKey(pkList, query.toSqlQuerySpec, readManyOptions, classOf[SparkRowItem]) - case None => - cosmosAsyncContainer.readManyByPartitionKey(pkList, readManyOptions, classOf[SparkRowItem]) - } - }, + private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem]( + cosmosAsyncContainer, + pkList, + readConfig.customQuery.map(_.toSqlQuerySpec), + readManyOptions, readConfig.maxItemCount, - readConfig.prefetchBufferSize, operationContextAndListenerTuple, - None + classOf[SparkRowItem] ) override def hasNext: Boolean = delegate.hasNext diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala new file mode 100644 index 000000000000..dfdc380b09ce --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import com.azure.cosmos.CosmosAsyncContainer +import com.azure.cosmos.implementation.OperationCancelledException +import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, PartitionKey, SqlQuerySpec} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait + +import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException} +import scala.concurrent.{Await, ExecutionContext, Future} + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +/** + * Retry-safe iterator for readManyByPartitionKey that batches partition keys and retries + * each batch independently on transient I/O errors. This avoids the continuation-token problem + * where TransientIOErrorsRetryingIterator would re-read all data from scratch on retry, + * causing silent data duplication. + */ +private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow] +( + val container: CosmosAsyncContainer, + val partitionKeys: java.util.List[PartitionKey], + val customQuery: Option[SqlQuerySpec], + val queryOptions: CosmosReadManyRequestOptions, + val pageSize: Int, + val operationContextAndListener: Option[OperationContextAndListenerTuple], + val classType: Class[TSparkRow] +) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { + + private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration( + 5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds, + scala.concurrent.duration.SECONDS) + + private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None + private val pkBatchIterator = partitionKeys.asScala.iterator.grouped(pageSize) + + override def hasNext: Boolean = { + if (hasBufferedNext) { + true + } else { + hasNextInternal + } + } + + private def hasNextInternal: Boolean = { + var 
returnValue: Option[Boolean] = None + + while (returnValue.isEmpty) { + if (pkBatchIterator.hasNext) { + val pkBatch = pkBatchIterator.next().toList + returnValue = + TransientErrorsRetryPolicy.executeWithRetry( + () => hasNextInternalCore(pkBatch), + statusResetFuncBetweenRetry = Some(() => { currentItemIterator = None }) + ) + } else { + returnValue = Some(false) + } + } + + returnValue.get + } + + private def hasNextInternalCore(pkBatch: List[PartitionKey]): Option[Boolean] = { + val pkJavaList = new java.util.ArrayList[PartitionKey](pkBatch.asJava) + val results = try { + Await.result( + Future { + val flux = customQuery match { + case Some(query) => + container.readManyByPartitionKey(pkJavaList, query, queryOptions, classType) + case None => + container.readManyByPartitionKey(pkJavaList, queryOptions, classType) + } + + // Collect all pages for this batch into a single list + flux.collectList().block() + }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext), + maxPageRetrievalTimeout) + } catch { + case endToEndTimeoutException: OperationCancelledException => + val operationContextString = operationContextAndListener match { + case Some(o) => if (o.getOperationContext != null) { + o.getOperationContext.toString + } else { + "n/a" + } + case None => "n/a" + } + + val message = s"End-to-end timeout hit when trying to retrieve readManyByPartitionKey batch. " + + s"Batch size: ${pkBatch.size}, Context: $operationContextString" + + logError(message, throwable = endToEndTimeoutException) + + throw endToEndTimeoutException + case timeoutException: TimeoutException => + val operationContextString = operationContextAndListener match { + case Some(o) => if (o.getOperationContext != null) { + o.getOperationContext.toString + } else { + "n/a" + } + case None => "n/a" + } + + val message = s"Attempting to retrieve readManyByPartitionKey batch timed out. 
" + + s"Batch size: ${pkBatch.size}, Context: $operationContextString" + + logError(message, timeoutException) + + val exception = new OperationCancelledException( + message, + null + ) + exception.setStackTrace(timeoutException.getStackTrace) + throw exception + + case other: Throwable => throw other + } + + val iteratorCandidate = results.iterator().asScala.buffered + + if (iteratorCandidate.hasNext) { + currentItemIterator = Some(iteratorCandidate) + Some(true) + } else { + None + } + } + + private def hasBufferedNext: Boolean = { + currentItemIterator match { + case Some(iterator) => if (iterator.hasNext) { + true + } else { + currentItemIterator = None + false + } + case None => false + } + } + + override def next(): TSparkRow = { + currentItemIterator.get.next() + } + + override def head(): TSparkRow = { + currentItemIterator.get.head + } + + override def close(): Unit = {} +} + +private object TransientIOErrorsRetryingReadManyByPartitionKeyIterator extends BasicLoggingTrait { + private val maxConcurrency = SparkUtils.getNumberOfHostCPUCores + + val executorService: ExecutorService = new ThreadPoolExecutor( + maxConcurrency, + maxConcurrency, + 0L, + TimeUnit.MILLISECONDS, + new SynchronousQueue(), + SparkUtils.daemonThreadFactory(), + new ThreadPoolExecutor.CallerRunsPolicy() + ) + + val executionContext: ExecutionContext = ExecutionContext.fromExecutorService(executorService) +} diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index faf661ddd80f..904c01c3238f 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.80.0-beta.1 (Unreleased) #### Features Added -* Added new `readManyByPartitioNKey` to bulk query by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `readManyByPartitionKey` to bulk query by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes From 662b1a4b90ee6954d7467c4b707b340a2d6b446d Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 00:01:35 +0000 Subject: [PATCH 16/25] Update CosmosItemsDataSource.scala --- .../cosmos/spark/CosmosItemsDataSource.scala | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index ac2299929f6a..aac5a53d8a8d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -188,21 +188,12 @@ object CosmosItemsDataSource { Some((row: Row) => { if (pkPaths.size == 1) { // Single partition key - new PartitionKey(row.getAs[Any](pkPaths.head)) + buildPartitionKey(row.getAs[Any](pkPaths.head), treatNullAsNone) } else { // Hierarchical partition key — build level by level val builder = new PartitionKeyBuilder() for (path <- pkPaths) { - val value = row.getAs[Any](path) - value match { - case s: String => builder.add(s) - case n: Number => builder.add(n.doubleValue()) - case b: Boolean => builder.add(b) - case null => - if (treatNullAsNone) builder.addNoneValue() - else builder.addNullValue() - case other => builder.add(other.toString) - } + addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone) } builder.build() } @@ -222,4 +213,22 @@ object CosmosItemsDataSource { readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) } + + private def addPartitionKeyComponent(builder: PartitionKeyBuilder, value: Any, treatNullAsNone: Boolean): Unit = { + value match { + case s: String => builder.add(s) + case n: Number => builder.add(n.doubleValue()) + case b: Boolean => 
builder.add(b) + case null => + if (treatNullAsNone) builder.addNoneValue() + else builder.addNullValue() + case other => builder.add(other.toString) + } + } + + private def buildPartitionKey(value: Any, treatNullAsNone: Boolean): PartitionKey = { + val builder = new PartitionKeyBuilder() + addPartitionKeyComponent(builder, value, treatNullAsNone) + builder.build() + } } From c764de9de02caa44d307a31b5cb5f8a6755ca6d4 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 00:03:16 +0000 Subject: [PATCH 17/25] Update CosmosItemsDataSource.scala --- .../com/azure/cosmos/spark/CosmosItemsDataSource.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index aac5a53d8a8d..86ef865bcb83 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -131,8 +131,13 @@ object CosmosItemsDataSource { val pkIdentityFieldExtraction = df .schema .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) - .map(field => (row: Row) => - CosmosPartitionKeyHelper.tryParsePartitionKey(row.getString(row.fieldIndex(field.name))).get) + .map(field => (row: Row) => { + val rawValue = row.getString(row.fieldIndex(field.name)) + CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue) + .getOrElse(throw new IllegalArgumentException( + s"Invalid _partitionKeyIdentity value in row: '$rawValue'. 
" + + "Expected format: pk([...json...])")) + }) // Option 2: Detect PK columns by matching the container's partition key paths against the DataFrame schema val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { From 080ce4a22931755e02810d48cf6d2fd1d38da719 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 11:20:49 +0000 Subject: [PATCH 18/25] Update RxDocumentClientImpl.java --- .../implementation/RxDocumentClientImpl.java | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index b4c72532280c..e5d8248c1578 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4382,6 +4382,62 @@ public Flux> readManyByPartitionKey( (ctx) -> diagnosticsFactory.merge(ctx) ); + StaleResourceRetryPolicy staleResourceRetryPolicy = new StaleResourceRetryPolicy( + this.collectionCache, + null, + collectionLink, + queryOptionsAccessor().getProperties(state.getQueryOptions()), + queryOptionsAccessor().getHeaders(state.getQueryOptions()), + this.sessionContainer, + diagnosticsFactory, + ResourceType.Document + ); + + return ObservableHelper + .fluxInlineIfPossibleAsObs( + () -> readManyByPartitionKey( + partitionKeys, customQuery, collectionLink, state, diagnosticsFactory, klass), + staleResourceRetryPolicy + ) + .onErrorMap(throwable -> { + if (throwable instanceof CosmosException) { + CosmosException cosmosException = (CosmosException) throwable; + CosmosDiagnostics diagnostics = cosmosException.getDiagnostics(); + if (diagnostics != null) { + state.mergeDiagnosticsContext(); + CosmosDiagnosticsContext ctx = state.getDiagnosticsContextSnapshot(); + if (ctx != 
null) { + ctxAccessor().recordOperation( + ctx, + cosmosException.getStatusCode(), + cosmosException.getSubStatusCode(), + 0, + cosmosException.getRequestCharge(), + diagnostics, + throwable + ); + diagAccessor() + .setDiagnosticsContext( + diagnostics, + state.getDiagnosticsContextSnapshot()); + } + } + + return cosmosException; + } + + return throwable; + }); + } + + private Flux> readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + String collectionLink, + QueryFeedOperationState state, + ScopedDiagnosticsFactory diagnosticsFactory, + Class klass) { + String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); RxDocumentServiceRequest request = RxDocumentServiceRequest.create(diagnosticsFactory, OperationType.Query, From b01f8758eea8c870df2a746ddc0b350fe4e9cfd8 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 11:37:51 +0000 Subject: [PATCH 19/25] Fix readManyByPartitionKey retries --- ...tionReaderWithReadManyByPartitionKey.scala | 1 + ...tryingReadManyByPartitionKeyIterator.scala | 236 ++++++++++++------ 2 files changed, 161 insertions(+), 76 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 8d4952b2144c..c67cc9c10be1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -202,6 +202,7 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey readConfig.customQuery.map(_.toSqlQuerySpec), readManyOptions, readConfig.maxItemCount, + readConfig.prefetchBufferSize, operationContextAndListenerTuple, classOf[SparkRowItem] ) diff --git 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala index dfdc380b09ce..dcfdf4f93536 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala @@ -3,24 +3,29 @@ package com.azure.cosmos.spark -import com.azure.cosmos.CosmosAsyncContainer +import com.azure.cosmos.{CosmosAsyncContainer, CosmosException} import com.azure.cosmos.implementation.OperationCancelledException import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple -import com.azure.cosmos.models.{CosmosReadManyRequestOptions, PartitionKey, SqlQuerySpec} +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, FeedResponse, PartitionKey, SqlQuerySpec} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.azure.cosmos.util.CosmosPagedIterable import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException} +import java.util.concurrent.atomic.AtomicLong import scala.concurrent.{Await, ExecutionContext, Future} +import scala.util.Random +import scala.util.control.Breaks // scalastyle:off underscore.import import scala.collection.JavaConverters._ // scalastyle:on underscore.import /** - * Retry-safe iterator for readManyByPartitionKey that batches partition keys and retries - * each batch independently on transient I/O errors. This avoids the continuation-token problem - * where TransientIOErrorsRetryingIterator would re-read all data from scratch on retry, - * causing silent data duplication. 
+ * Retry-safe iterator for readManyByPartitionKey that batches partition keys and lazily + * iterates pages within each batch via CosmosPagedIterable — consistent with how + * TransientIOErrorsRetryingIterator handles normal queries. On transient I/O errors the + * current batch's flux is recreated and pages already consumed are replayed, avoiding + * the memory overhead of collectList and matching the query iterator's structure. */ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow] ( @@ -29,109 +34,139 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp val customQuery: Option[SqlQuerySpec], val queryOptions: CosmosReadManyRequestOptions, val pageSize: Int, + val pagePrefetchBufferSize: Int, val operationContextAndListener: Option[OperationContextAndListenerTuple], val classType: Class[TSparkRow] ) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { + private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs + private[spark] var maxRetryCount = CosmosConstants.maxRetryCountForTransientFailures + private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration( 5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds, scala.concurrent.duration.SECONDS) + private val rnd = Random + private val retryCount = new AtomicLong(0) + private lazy val operationContextString = operationContextAndListener match { + case Some(o) => if (o.getOperationContext != null) { + o.getOperationContext.toString + } else { + "n/a" + } + case None => "n/a" + } + + private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None + private val pkBatchIterator = partitionKeys.asScala.iterator.grouped(pageSize) + // Track the current batch so we can replay it on retry + private var currentBatch: 
Option[java.util.List[PartitionKey]] = None override def hasNext: Boolean = { - if (hasBufferedNext) { - true - } else { - hasNextInternal - } + executeWithRetry("hasNextInternal", () => hasNextInternal) } private def hasNextInternal: Boolean = { var returnValue: Option[Boolean] = None while (returnValue.isEmpty) { - if (pkBatchIterator.hasNext) { - val pkBatch = pkBatchIterator.next().toList - returnValue = - TransientErrorsRetryPolicy.executeWithRetry( - () => hasNextInternalCore(pkBatch), - statusResetFuncBetweenRetry = Some(() => { currentItemIterator = None }) - ) - } else { - returnValue = Some(false) - } + returnValue = hasNextInternalCore } returnValue.get } - private def hasNextInternalCore(pkBatch: List[PartitionKey]): Option[Boolean] = { - val pkJavaList = new java.util.ArrayList[PartitionKey](pkBatch.asJava) - val results = try { - Await.result( - Future { - val flux = customQuery match { - case Some(query) => - container.readManyByPartitionKey(pkJavaList, query, queryOptions, classType) + private def hasNextInternalCore: Option[Boolean] = { + if (hasBufferedNext) { + Some(true) + } else { + val feedResponseIterator = currentFeedResponseIterator match { + case Some(existing) => existing + case None => + // Need a new feed response iterator — either for the current batch (on retry) + // or for the next batch + val batch = currentBatch match { + case Some(b) => b // retry of current batch case None => - container.readManyByPartitionKey(pkJavaList, queryOptions, classType) + if (pkBatchIterator.hasNext) { + val nextBatch = new java.util.ArrayList[PartitionKey](pkBatchIterator.next().toList.asJava) + currentBatch = Some(nextBatch) + nextBatch + } else { + return Some(false) // no more batches + } } - // Collect all pages for this batch into a single list - flux.collectList().block() - }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext), - maxPageRetrievalTimeout) - } catch { - case endToEndTimeoutException: 
OperationCancelledException => - val operationContextString = operationContextAndListener match { - case Some(o) => if (o.getOperationContext != null) { - o.getOperationContext.toString - } else { - "n/a" + val pagedFlux = customQuery match { + case Some(query) => + container.readManyByPartitionKey(batch, query, queryOptions, classType) + case None => + container.readManyByPartitionKey(batch, queryOptions, classType) } - case None => "n/a" - } - - val message = s"End-to-end timeout hit when trying to retrieve readManyByPartitionKey batch. " + - s"Batch size: ${pkBatch.size}, Context: $operationContextString" - - logError(message, throwable = endToEndTimeoutException) - throw endToEndTimeoutException - case timeoutException: TimeoutException => - val operationContextString = operationContextAndListener match { - case Some(o) => if (o.getOperationContext != null) { - o.getOperationContext.toString - } else { - "n/a" - } - case None => "n/a" - } + currentFeedResponseIterator = Some( + new CosmosPagedIterable[TSparkRow]( + pagedFlux, + pageSize, + pagePrefetchBufferSize + ) + .iterableByPage() + .iterator + .asScala + .buffered + ) - val message = s"Attempting to retrieve readManyByPartitionKey batch timed out. " + - s"Batch size: ${pkBatch.size}, Context: $operationContextString" + currentFeedResponseIterator.get + } - logError(message, timeoutException) + val hasNext: Boolean = try { + Await.result( + Future { + feedResponseIterator.hasNext + }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext), + maxPageRetrievalTimeout) + } catch { + case endToEndTimeoutException: OperationCancelledException => + val message = s"End-to-end timeout hit when trying to retrieve the next page. 
" + + s"Context: $operationContextString" + logError(message, throwable = endToEndTimeoutException) + throw endToEndTimeoutException - val exception = new OperationCancelledException( - message, - null - ) - exception.setStackTrace(timeoutException.getStackTrace) - throw exception + case timeoutException: TimeoutException => + val message = s"Attempting to retrieve the next page timed out. " + + s"Context: $operationContextString" + logError(message, timeoutException) + val exception = new OperationCancelledException(message, null) + exception.setStackTrace(timeoutException.getStackTrace) + throw exception - case other: Throwable => throw other - } + case other: Throwable => throw other + } - val iteratorCandidate = results.iterator().asScala.buffered + if (hasNext) { + val feedResponse = feedResponseIterator.next() + if (operationContextAndListener.isDefined) { + operationContextAndListener.get.getOperationListener.feedResponseProcessedListener( + operationContextAndListener.get.getOperationContext, + feedResponse) + } + val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered - if (iteratorCandidate.hasNext) { - currentItemIterator = Some(iteratorCandidate) - Some(true) - } else { - None + if (iteratorCandidate.hasNext) { + currentItemIterator = Some(iteratorCandidate) + Some(true) + } else { + // empty page interleaved — try again + None + } + } else { + // Current batch's flux is exhausted — move to next batch + currentBatch = None + currentFeedResponseIterator = None + None + } } } @@ -155,7 +190,56 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp currentItemIterator.get.head } - override def close(): Unit = {} + private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = { + val loop = new Breaks() + var returnValue: Option[T] = None + + loop.breakable { + while (true) { + val retryIntervalInMs = rnd.nextInt(maxRetryIntervalInMs) + + try { + returnValue = Some(func()) + retryCount.set(0) 
+ loop.break + } + catch { + case cosmosException: CosmosException => + if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) { + val retryCountSnapshot = retryCount.incrementAndGet() + if (retryCountSnapshot > maxRetryCount) { + logError( + s"Too many transient failure retry attempts in " + + s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName", + cosmosException) + throw cosmosException + } else { + logWarning( + s"Transient failure handled in " + + s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName -" + + s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms", + cosmosException) + } + } else { + throw cosmosException + } + case other: Throwable => throw other + } + + // Reset iterators but keep currentBatch so the batch is replayed + currentItemIterator = None + currentFeedResponseIterator = None + Thread.sleep(retryIntervalInMs) + } + } + + returnValue.get + } + + override def close(): Unit = { + currentItemIterator = None + currentFeedResponseIterator = None + } } private object TransientIOErrorsRetryingReadManyByPartitionKeyIterator extends BasicLoggingTrait { From 7130d4aa35a228b116df9220394ab6a1ca569a9f Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 11:48:54 +0000 Subject: [PATCH 20/25] Fix PK.None --- .../azure/cosmos/CosmosAsyncContainer.java | 6 ++ .../ReadManyByPartitionKeyQueryHelper.java | 97 ++++++++++++++----- .../implementation/RxDocumentClientImpl.java | 14 ++- 3 files changed, 88 insertions(+), 29 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index 025a957a4606..4e234667c1c0 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -1673,6 +1673,12 @@ 
public CosmosPagedFlux readManyByPartitionKey( if (partitionKeys.isEmpty()) { throw new IllegalArgumentException("Argument 'partitionKeys' must not be empty."); } + for (PartitionKey pk : partitionKeys) { + if (pk == null) { + throw new IllegalArgumentException( + "Argument 'partitionKeys' must not contain null elements."); + } + } return UtilBridgeInternal.createCosmosPagedFlux( readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index d538542df1ba..4a6d2efdeee5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -40,24 +40,54 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( boolean isSinglePathPk = partitionKeySelectors.size() == 1; if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { - // Single PK path — use IN clause: alias["pkPath"] IN (@__rmPk_0, @__rmPk_1, ...) 
+ // Single PK path — use IN clause for normal values, OR NOT IS_DEFINED for NONE + // First, separate NONE PKs from normal PKs + boolean hasNone = false; + List normalPkValues = new ArrayList<>(); + for (PartitionKey pk : pkValues) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); + if (pkInternal.getComponents() == null) { + hasNone = true; + } else { + normalPkValues.add(pk); + } + } + pkFilter.append(" "); - pkFilter.append(tableAlias); - pkFilter.append(partitionKeySelectors.get(0)); - pkFilter.append(" IN ( "); - for (int i = 0; i < pkValues.size(); i++) { - PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); - Object[] pkComponents = pkInternal.toObjectArray(); - String pkParamName = PK_PARAM_PREFIX + paramCount; - parameters.add(new SqlParameter(pkParamName, pkComponents[0])); - paramCount++; + boolean hasNormalValues = !normalPkValues.isEmpty(); + if (hasNormalValues && hasNone) { + pkFilter.append("("); + } + if (hasNormalValues) { + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(0)); + pkFilter.append(" IN ( "); + for (int i = 0; i < normalPkValues.size(); i++) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(normalPkValues.get(i)); + Object[] pkComponents = pkInternal.toObjectArray(); + String pkParamName = PK_PARAM_PREFIX + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[0])); + paramCount++; - pkFilter.append(pkParamName); - if (i < pkValues.size() - 1) { - pkFilter.append(", "); + pkFilter.append(pkParamName); + if (i < normalPkValues.size() - 1) { + pkFilter.append(", "); + } } + pkFilter.append(" )"); + } + if (hasNone) { + if (hasNormalValues) { + pkFilter.append(" OR "); + } + pkFilter.append("NOT IS_DEFINED("); + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(0)); + pkFilter.append(")"); + } + if (hasNormalValues && hasNone) { + pkFilter.append(")"); } - pkFilter.append(" 
)"); } else { // Multiple PK paths (HPK) or MULTI_HASH — use OR of AND clauses pkFilter.append(" "); @@ -65,21 +95,36 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); Object[] pkComponents = pkInternal.toObjectArray(); - pkFilter.append("("); - for (int j = 0; j < pkComponents.length; j++) { - String pkParamName = PK_PARAM_PREFIX + paramCount; - parameters.add(new SqlParameter(pkParamName, pkComponents[j])); - paramCount++; + // PartitionKey.NONE — generate NOT IS_DEFINED for all PK paths + if (pkComponents == null) { + pkFilter.append("("); + for (int j = 0; j < partitionKeySelectors.size(); j++) { + if (j > 0) { + pkFilter.append(" AND "); + } + pkFilter.append("NOT IS_DEFINED("); + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(j)); + pkFilter.append(")"); + } + pkFilter.append(")"); + } else { + pkFilter.append("("); + for (int j = 0; j < pkComponents.length; j++) { + String pkParamName = PK_PARAM_PREFIX + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[j])); + paramCount++; - if (j > 0) { - pkFilter.append(" AND "); + if (j > 0) { + pkFilter.append(" AND "); + } + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(j)); + pkFilter.append(" = "); + pkFilter.append(pkParamName); } - pkFilter.append(tableAlias); - pkFilter.append(partitionKeySelectors.get(j)); - pkFilter.append(" = "); - pkFilter.append(pkParamName); + pkFilter.append(")"); } - pkFilter.append(")"); if (i < pkValues.size() - 1) { pkFilter.append(" OR "); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index e5d8248c1578..c70dedaa1f2c 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ 
b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4623,7 +4623,15 @@ private Map> groupPartitionKeysByPhysicalP for (PartitionKey pk : partitionKeys) { PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); - int componentCount = pkInternal.getComponents().size(); + + // PartitionKey.NONE wraps NonePartitionKey which has components = null. + // For routing purposes, treat NONE as UndefinedPartitionKey — documents ingested + // without a partition key path are stored with the undefined EPK. + PartitionKeyInternal effectivePkInternal = pkInternal.getComponents() == null + ? PartitionKeyInternal.UndefinedPartitionKey + : pkInternal; + + int componentCount = effectivePkInternal.getComponents().size(); int definedPathCount = pkDefinition.getPaths().size(); List targetRanges; @@ -4631,12 +4639,12 @@ private Map> groupPartitionKeysByPhysicalP if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { // Partial HPK — compute EPK prefix range and find all overlapping physical partitions Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( - pkInternal, pkDefinition); + effectivePkInternal, pkDefinition); targetRanges = routingMap.getOverlappingRanges(epkRange); } else { // Full PK — maps to exactly one physical partition String effectivePartitionKeyString = PartitionKeyInternalHelper - .getEffectivePartitionKeyString(pkInternal, pkDefinition); + .getEffectivePartitionKeyString(effectivePkInternal, pkDefinition); PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); targetRanges = Collections.singletonList(range); } From 93957f3a8442d730fe67fbc379ef5399f46f5665 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 11:56:12 +0000 Subject: [PATCH 21/25] Update ReadManyByPartitionKeyQueryHelper.java --- .../ReadManyByPartitionKeyQueryHelper.java | 11 +++++++++-- 1 file changed, 9 
insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index 4a6d2efdeee5..6d6cd084e01a 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -224,10 +224,17 @@ static int findTopLevelKeywordIndex(String queryText, String keyword) { int keyLen = keywordUpper.length(); for (int i = 0; i <= queryTextUpper.length() - keyLen; i++) { char ch = queryTextUpper.charAt(i); - // Skip string literals enclosed in single quotes + // Skip string literals enclosed in single quotes (handle '' escape) if (queryText.charAt(i) == '\'') { i++; - while (i < queryText.length() && queryText.charAt(i) != '\'') { + while (i < queryText.length()) { + if (queryText.charAt(i) == '\'') { + if (i + 1 < queryText.length() && queryText.charAt(i + 1) == '\'') { + i += 2; // escaped quote — skip both + continue; + } + break; // end of string literal + } i++; } continue; From 9200f8fe53b95c51f4bec29e6afbc3b8480ec093 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 12:50:47 +0000 Subject: [PATCH 22/25] Fix code review feedback --- ...bianm_readManyByPK-vs-origin_main-full.txt | 3444 +++++++++++++++++ ...bianm_readManyByPK-vs-origin_main-stat.txt | 76 + .../azure-cosmos-spark_3-3_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-4_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-5_2-12/CHANGELOG.md | 2 +- .../azure-cosmos-spark_3-5_2-13/CHANGELOG.md | 2 +- .../com/azure/cosmos/spark/CosmosConfig.scala | 11 +- .../cosmos/spark/CosmosItemsDataSource.scala | 23 +- .../CosmosReadManyByPartitionKeyReader.scala | 7 +- ...tionReaderWithReadManyByPartitionKey.scala | 22 +- 
...tryingReadManyByPartitionKeyIterator.scala | 127 +- .../udf/GetCosmosPartitionKeyValue.scala | 21 +- .../azure-cosmos-spark_4-0_2-13/CHANGELOG.md | 2 +- .../azure/cosmos/implementation/Configs.java | 10 +- .../ReadManyByPartitionKeyQueryHelper.java | 30 +- .../implementation/RxDocumentClientImpl.java | 16 +- 16 files changed, 3697 insertions(+), 100 deletions(-) create mode 100644 sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt create mode 100644 sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt diff --git a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt new file mode 100644 index 000000000000..5eee65a15a7b --- /dev/null +++ b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt @@ -0,0 +1,3444 @@ +===== PR #LOCAL-users_fabianm_readManyByPK ===== +Title: Branch comparison users/fabianm/readManyByPK vs origin/main +Author: Fabian Meiswinkel +Status: DIVERGED (ahead 30, behind 4) +Branch: users/fabianm/readManyByPK -> origin/main +Head SHA: 93957f3a8442d730fe67fbc379ef5399f46f5665 +Merge Base: 20313f79ba8dd0dfa97862d0c31dd4b2e44ee671 +URL: N/A (local branch comparison) + +--- Description --- +Adds readManyByPartitionKey API (sync+async) and Spark connector support for PK-only reads, with query-plan-based validation +--- End Description --- + +===== Commits in PR ===== +9770833eb59 Adding readManyByPartitionKey API +ac287bcdf00 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +9a5b3e96e7e Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +a8720c3c9f2 Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +d499da76fb4 Update 
sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +c3c542a33a7 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +4416354e03e ┬┤Fixing code review comments +3ab3f0d64f5 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +588a7550c54 Update CosmosAsyncContainer.java +8c5cdb47b31 Merge branch 'main' into users/fabianm/readManyByPK +f5485527a9c Update ReadManyByPartitionKeyTest.java +f68cf02ff71 Fixing test issues +8b6c4b168ea Update CosmosAsyncContainer.java +8ba7f4db2da Merge branch 'main' into users/fabianm/readManyByPK +56b067a9339 Reacted to code review feedback +fa430e918fa Merge branch 'main' into users/fabianm/readManyByPK +d9504c91f34 Fix build issues +73151f09e5f Merge branch 'main' into users/fabianm/readManyByPK +681830e2d4a Fixing changelog +7f745e60641 Merge branch 'main' into users/fabianm/readManyByPK +0b8905dbb01 Addressing code review comments +22abc780ed8 Addressing code review feedback +662b1a4b90e Update CosmosItemsDataSource.scala +c764de9de02 Update CosmosItemsDataSource.scala +e1e6f5a6f73 Merge branch 'main' into users/fabianm/readManyByPK +080ce4a2293 Update RxDocumentClientImpl.java +516bbf3a95a Merge branch 'users/fabianm/readManyByPK' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +b01f8758eea Fix readManyByPartitionKey retries +7130d4aa35a Fix PK.None +93957f3a844 Update ReadManyByPartitionKeyQueryHelper.java + +===== Files Changed ===== + sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala (+20 -2) + 
sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala (+1 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala (+125 -1) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala (+45 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala (+150 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala (+249 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala (+259 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala (+25 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala (+42 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala (+104 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala (+158 -0) + sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java (+462 -0) + sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java (+426 -0) + sdk/cosmos/azure-cosmos/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java (+126 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java (+67 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java (+21 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java (+19 -0) + 
sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java (+263 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java (+292 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java (+11 -0) + sdk/cosmos/cspell.yaml (+6 -0) + sdk/cosmos/docs/readManyByPartitionKey-design.md (+169 -0) + +===== Full Diff ===== +diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +index cbf97c610f9..fe114462019 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.47.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +index c9097e749f0..3b2c7ce36db 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.47.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +index f5eac38bdb7..2240a48b165 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.47.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +index 919d7fbfa32..20a3e3a61bd 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.47.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +index 951f4735444..e1b8f0b51f8 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +@@ -92,6 +92,7 @@ private[spark] object CosmosConfigNames { + val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter" + val ReadRuntimeFilteringEnabled = "spark.cosmos.read.runtimeFiltering.enabled" + val ReadManyFilteringEnabled = "spark.cosmos.read.readManyFiltering.enabled" ++ val ReadManyByPkNullHandling = "spark.cosmos.read.readManyByPk.nullHandling" + val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath" + val DiagnosticsMode = "spark.cosmos.diagnostics" + val DiagnosticsSamplingMaxCount = "spark.cosmos.diagnostics.sampling.maxCount" +@@ -226,6 +227,7 @@ private[spark] object CosmosConfigNames { + ReadPartitioningFeedRangeFilter, + ReadRuntimeFilteringEnabled, + ReadManyFilteringEnabled, ++ ReadManyByPkNullHandling, + ViewsRepositoryPath, + DiagnosticsMode, + DiagnosticsSamplingIntervalInSeconds, +@@ -1042,7 +1044,8 @@ private case class CosmosReadConfig(readConsistencyStrategy: ReadConsistencyStra + throughputControlConfig: Option[CosmosThroughputControlConfig] = None, + runtimeFilteringEnabled: Boolean, + readManyFilteringConfig: CosmosReadManyFilteringConfig, +- responseContinuationTokenLimitInKb: Option[Int] = None) ++ responseContinuationTokenLimitInKb: Option[Int] = None, ++ readManyByPkTreatNullAsNone: Boolean = false) + + private object SchemaConversionModes extends Enumeration { + type SchemaConversionMode = Value +@@ -1136,6 +1139,18 @@ private object CosmosReadConfig { + helpMessage = " Indicates 
whether dynamic partition pruning filters will be pushed down when applicable." + ) + ++ private val ReadManyByPkNullHandling = CosmosConfigEntry[String]( ++ key = CosmosConfigNames.ReadManyByPkNullHandling, ++ mandatory = false, ++ defaultValue = Some("Null"), ++ parseFromStringFunction = value => value, ++ helpMessage = "Determines how null values in hierarchical partition key components are treated " + ++ "for readManyByPartitionKey. 'Null' (default) maps null to a JSON null value via addNullValue(), " + ++ "which is appropriate when the document field exists with an explicit null value. " + ++ "'None' maps null to PartitionKey.NONE via addNoneValue(), which should only be used when the " + ++ "partition key path does not exist at all in the document." ++ ) ++ + def parseCosmosReadConfig(cfg: Map[String, String]): CosmosReadConfig = { + val forceEventualConsistency = CosmosConfigEntry.parse(cfg, ForceEventualConsistency) + val readConsistencyStrategyOverride = CosmosConfigEntry.parse(cfg, ReadConsistencyStrategyOverride) +@@ -1158,6 +1173,8 @@ private object CosmosReadConfig { + val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg) + val runtimeFilteringEnabled = CosmosConfigEntry.parse(cfg, ReadRuntimeFilteringEnabled) + val readManyFilteringConfig = CosmosReadManyFilteringConfig.parseCosmosReadManyFilterConfig(cfg) ++ val readManyByPkNullHandling = CosmosConfigEntry.parse(cfg, ReadManyByPkNullHandling) ++ val readManyByPkTreatNullAsNone = readManyByPkNullHandling.getOrElse("Null").equalsIgnoreCase("None") + + val effectiveReadConsistencyStrategy = if (readConsistencyStrategyOverride.getOrElse(ReadConsistencyStrategy.DEFAULT) != ReadConsistencyStrategy.DEFAULT) { + readConsistencyStrategyOverride.get +@@ -1189,7 +1206,8 @@ private object CosmosReadConfig { + throughputControlConfigOpt, + runtimeFilteringEnabled.get, + readManyFilteringConfig, +- responseContinuationTokenLimitInKb) ++ responseContinuationTokenLimitInKb, 
++ readManyByPkTreatNullAsNone) + } + } + +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala +index 9ece4741652..00761f23d39 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala +@@ -45,6 +45,7 @@ private[cosmos] object CosmosConstants { + val Id = "id" + val ETag = "_etag" + val ItemIdentity = "_itemIdentity" ++ val PartitionKeyIdentity = "_partitionKeyIdentity" + } + + object StatusCodes { +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +index a35cff27af6..86ef865bcb8 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +@@ -2,9 +2,10 @@ + // Licensed under the MIT License. 
+ package com.azure.cosmos.spark + +-import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey} ++import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey, PartitionKeyBuilder} + import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver + import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait ++import com.azure.cosmos.{SparkBridgeInternal} + import org.apache.spark.sql.{DataFrame, Row, SparkSession} + + import java.util +@@ -112,4 +113,127 @@ object CosmosItemsDataSource { + + readManyReader.readMany(df.rdd, readManyFilterExtraction) + } ++ ++ def readManyByPartitionKey(df: DataFrame, userConfig: java.util.Map[String, String]): DataFrame = { ++ readManyByPartitionKey(df, userConfig, null) ++ } ++ ++ def readManyByPartitionKey( ++ df: DataFrame, ++ userConfig: java.util.Map[String, String], ++ userProvidedSchema: StructType): DataFrame = { ++ ++ val readManyReader = new CosmosReadManyByPartitionKeyReader( ++ userProvidedSchema, ++ userConfig.asScala.toMap) ++ ++ // Option 1: Look for the _partitionKeyIdentity column (produced by GetCosmosPartitionKeyValue UDF) ++ val pkIdentityFieldExtraction = df ++ .schema ++ .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) ++ .map(field => (row: Row) => { ++ val rawValue = row.getString(row.fieldIndex(field.name)) ++ CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue) ++ .getOrElse(throw new IllegalArgumentException( ++ s"Invalid _partitionKeyIdentity value in row: '$rawValue'. 
" + ++ "Expected format: pk([...json...])")) ++ }) ++ ++ // Option 2: Detect PK columns by matching the container's partition key paths against the DataFrame schema ++ val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { ++ None // no need to resolve PK paths - _partitionKeyIdentity column takes precedence ++ } else { ++ val effectiveConfig = CosmosConfig.getEffectiveConfig( ++ databaseName = None, ++ containerName = None, ++ userConfig.asScala.toMap) ++ val readConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveConfig) ++ val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) ++ val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) ++ val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" ++ val treatNullAsNone = readConfig.readManyByPkTreatNullAsNone ++ ++ val pkPaths = Loan( ++ List[Option[CosmosClientCacheItem]]( ++ Some( ++ CosmosClientCache( ++ CosmosClientConfiguration( ++ effectiveConfig, ++ readConsistencyStrategy = readConfig.readConsistencyStrategy, ++ sparkEnvironmentInfo), ++ None, ++ calledFrom)), ++ ThroughputControlHelper.getThroughputControlClientCacheItem( ++ effectiveConfig, ++ calledFrom, ++ None, ++ sparkEnvironmentInfo) ++ )) ++ .to(clientCacheItems => { ++ val container = ++ ThroughputControlHelper.getContainer( ++ effectiveConfig, ++ containerConfig, ++ clientCacheItems(0).get, ++ clientCacheItems(1)) ++ ++ val pkDefinition = SparkBridgeInternal ++ .getContainerPropertiesFromCollectionCache(container) ++ .getPartitionKeyDefinition ++ ++ pkDefinition.getPaths.asScala.map(_.stripPrefix("/")).toList ++ }) ++ ++ // Check if ALL PK path columns exist in the DataFrame schema ++ val dfFieldNames = df.schema.fieldNames.toSet ++ val allPkColumnsPresent = pkPaths.forall(path => dfFieldNames.contains(path)) ++ ++ if (allPkColumnsPresent && pkPaths.nonEmpty) { ++ // pkPaths already defined above ++ Some((row: Row) => { ++ if 
(pkPaths.size == 1) { ++ // Single partition key ++ buildPartitionKey(row.getAs[Any](pkPaths.head), treatNullAsNone) ++ } else { ++ // Hierarchical partition key ΓÇö build level by level ++ val builder = new PartitionKeyBuilder() ++ for (path <- pkPaths) { ++ addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone) ++ } ++ builder.build() ++ } ++ }) ++ } else { ++ None ++ } ++ } ++ ++ val pkExtraction = pkIdentityFieldExtraction ++ .orElse(pkColumnExtraction) ++ .getOrElse( ++ throw new IllegalArgumentException( ++ "Cannot determine partition key extraction from the input DataFrame. " + ++ "Either add a '_partitionKeyIdentity' column (using the GetCosmosPartitionKeyValue UDF) " + ++ "or ensure the DataFrame contains columns matching the container's partition key paths.")) ++ ++ readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) ++ } ++ ++ private def addPartitionKeyComponent(builder: PartitionKeyBuilder, value: Any, treatNullAsNone: Boolean): Unit = { ++ value match { ++ case s: String => builder.add(s) ++ case n: Number => builder.add(n.doubleValue()) ++ case b: Boolean => builder.add(b) ++ case null => ++ if (treatNullAsNone) builder.addNoneValue() ++ else builder.addNullValue() ++ case other => builder.add(other.toString) ++ } ++ } ++ ++ private def buildPartitionKey(value: Any, treatNullAsNone: Boolean): PartitionKey = { ++ val builder = new PartitionKeyBuilder() ++ addPartitionKeyComponent(builder, value, treatNullAsNone) ++ builder.build() ++ } + } +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +new file mode 100644 +index 00000000000..27776f5c3de +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +@@ -0,0 +1,45 @@ ++// Copyright (c) Microsoft Corporation. All rights reserved. 
/**
 * Helpers for converting between partition key values and the `pk(<json>)` string
 * representation that is parsed by [[tryParsePartitionKey]] and produced by
 * [[getCosmosPartitionKeyValueString]].
 */
private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait {
  // pattern will be recognized
  // pk(partitionKeyValue)
  //
  // (?i) : The whole matching is case-insensitive
  // pk[(](.*)[)]: partitionKey Value
  private val cosmosPartitionKeyStringRegx = """(?i)pk[(](.*)[)]""".r
  private val objectMapper = Utils.getSimpleObjectMapper

  /**
   * Serializes the partition key components as a JSON array wrapped in `pk(...)`.
   *
   * @param partitionKeyValue the partition key components, one entry per level
   * @return a string of the form `pk([...])`
   */
  def getCosmosPartitionKeyValueString(partitionKeyValue: List[Object]): String = {
    s"pk(${objectMapper.writeValueAsString(partitionKeyValue.asJava)})"
  }

  /**
   * Attempts to parse a `pk(<json>)` string into a [[PartitionKey]].
   *
   * A JSON array is converted component by component; any other JSON value is treated
   * as a single partition key value.
   *
   * @param cosmosPartitionKeyString the string to parse
   * @return `Some(partitionKey)` on success; `None` when the string does not match the
   *         `pk(...)` pattern, or when parsing the JSON or building the partition key
   *         from the parsed value fails with a non-fatal exception
   */
  def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = {
    cosmosPartitionKeyString match {
      case cosmosPartitionKeyStringRegx(pkValue) =>
        // Both the JSON parse and the conversion into a PartitionKey are guarded:
        // well-formed JSON that cannot be turned into a partition key must surface as
        // None rather than as an exception escaping a "try" method.
        scala.util.Try {
          Utils.parse(pkValue, classOf[Object]) match {
            case arrayList: util.ArrayList[Object @unchecked] =>
              ImplementationBridgeHelpers
                .PartitionKeyHelper
                .getPartitionKeyAccessor
                .toPartitionKey(PartitionKeyInternal.fromObjectArray(arrayList.toArray, false))
            case other => new PartitionKey(other)
          }
        }.toOption
      case _ => None
    }
  }
}
/**
 * Reads all items for a set of partition keys and exposes the result as a DataFrame.
 *
 * The configuration is parsed eagerly in the constructor. The active SparkSession is
 * resolved lazily, after asserting that the code runs on the Spark driver.
 *
 * @param userProvidedSchema schema to use for the result; when null the schema is inferred
 * @param userConfig         the user-supplied connector configuration
 */
private[spark] class CosmosReadManyByPartitionKeyReader(
  val userProvidedSchema: StructType,
  val userConfig: Map[String, String]
) extends BasicLoggingTrait with Serializable {
  val effectiveUserConfig: Map[String, String] = CosmosConfig.getEffectiveConfig(
    databaseName = None,
    containerName = None,
    userConfig)

  val clientConfig: CosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(effectiveUserConfig)
  val readConfig: CosmosReadConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveUserConfig)
  val cosmosContainerConfig: CosmosContainerConfig =
    CosmosContainerConfig.parseCosmosContainerConfig(effectiveUserConfig)
  //scalastyle:off multiple.string.literals
  val tableName: String = s"com.azure.cosmos.spark.items.${clientConfig.accountName}." +
    s"${cosmosContainerConfig.database}.${cosmosContainerConfig.container}"
  private lazy val sparkSession = {
    assertOnSparkDriver()
    SparkSession.active
  }
  val sparkEnvironmentInfo: String = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(sparkSession))
  logTrace(s"Instantiated ${this.getClass.getSimpleName} for $tableName")

  /**
   * Serializes the metadata caches of the Cosmos client (and of the throughput control
   * client, when one is configured) and broadcasts them through the Spark context.
   *
   * @return the broadcast handle for the metadata cache snapshots
   */
  private[spark] def initializeAndBroadcastCosmosClientStatesForContainer(): Broadcast[CosmosClientMetadataCachesSnapshots] = {
    val calledFrom = s"CosmosReadManyByPartitionKeyReader($tableName).initializeAndBroadcastCosmosClientStateForContainer"
    Loan(
      List[Option[CosmosClientCacheItem]](
        Some(
          CosmosClientCache(
            CosmosClientConfiguration(
              effectiveUserConfig,
              readConsistencyStrategy = readConfig.readConsistencyStrategy,
              sparkEnvironmentInfo),
            None,
            calledFrom)),
        ThroughputControlHelper.getThroughputControlClientCacheItem(
          effectiveUserConfig,
          calledFrom,
          None,
          sparkEnvironmentInfo)
      ))
      .to(clientCacheItems => {
        val container =
          ThroughputControlHelper.getContainer(
            effectiveUserConfig,
            cosmosContainerConfig,
            clientCacheItems(0).get,
            clientCacheItems(1))

        // Point read of a random id / partition key; a CosmosException is ignored.
        // NOTE(review): presumably this only exists to populate the client's metadata
        // caches before they are snapshotted below - confirm.
        try {
          container.readItem(
            UUIDs.nonBlockingRandomUUID().toString,
            new PartitionKey(UUIDs.nonBlockingRandomUUID().toString),
            classOf[ObjectNode])
            .block()
        } catch {
          case _: CosmosException => None
        }

        val state = new CosmosClientMetadataCachesSnapshot()
        state.serialize(clientCacheItems(0).get.cosmosClient)

        // Only present when a throughput control client cache item exists.
        val throughputControlState: Option[CosmosClientMetadataCachesSnapshot] =
          clientCacheItems(1).map(throughputControlCacheItem => {
            val snapshot = new CosmosClientMetadataCachesSnapshot()
            snapshot.serialize(throughputControlCacheItem.cosmosClient)
            snapshot
          })

        val metadataSnapshots = CosmosClientMetadataCachesSnapshots(state, throughputControlState)
        sparkSession.sparkContext.broadcast(metadataSnapshots)
      })
  }

  /**
   * Builds a DataFrame with the items read for the partition keys extracted from
   * `inputRdd`.
   *
   * Each input partition gets its own ItemsPartitionReaderWithReadManyByPartitionKey.
   * The reader is closed exactly once - when its rows are exhausted or when the Spark
   * task completes, whichever comes first - so that the resources released in the
   * reader's `close()` are not leaked.
   *
   * @param inputRdd     rows from which partition keys are extracted
   * @param pkExtraction maps an input row to the partition key to read
   * @return a DataFrame using the user-provided schema, or the inferred one when the
   *         user-provided schema is null
   */
  def readManyByPartitionKey(inputRdd: RDD[Row], pkExtraction: Row => PartitionKey): DataFrame = {
    val correlationActivityId = UUIDs.nonBlockingRandomUUID()
    val calledFrom = s"CosmosReadManyByPartitionKeyReader.readManyByPartitionKey($correlationActivityId)"
    val schema = Loan(
      List[Option[CosmosClientCacheItem]](
        Some(CosmosClientCache(
          CosmosClientConfiguration(
            effectiveUserConfig,
            readConsistencyStrategy = readConfig.readConsistencyStrategy,
            sparkEnvironmentInfo),
          None,
          calledFrom
        )),
        ThroughputControlHelper.getThroughputControlClientCacheItem(
          effectiveUserConfig,
          calledFrom,
          None,
          sparkEnvironmentInfo)
      ))
      .to(clientCacheItems => Option.apply(userProvidedSchema).getOrElse(
        CosmosTableSchemaInferrer.inferSchema(
          clientCacheItems(0).get,
          clientCacheItems(1),
          effectiveUserConfig,
          ItemsTable.defaultSchemaForInferenceDisabled)))

    val clientStates = initializeAndBroadcastCosmosClientStatesForContainer()

    sparkSession.sqlContext.createDataFrame(
      inputRdd.mapPartitionsWithIndex(
        (partitionIndex: Int, rowIterator: Iterator[Row]) => {
          val pkIterator: Iterator[PartitionKey] = rowIterator
            .map(row => pkExtraction.apply(row))

          logInfo(s"Creating an ItemsPartitionReaderWithReadManyByPartitionKey for Activity $correlationActivityId to read for "
            + s"input partition [$partitionIndex] ${tableName}")

          val taskContext = TaskContext.get

          val reader = new ItemsPartitionReaderWithReadManyByPartitionKey(
            effectiveUserConfig,
            CosmosReadManyHelper.FullRangeFeedRange,
            schema,
            DiagnosticsContext(correlationActivityId, partitionIndex.toString),
            clientStates,
            DiagnosticsConfig.parseDiagnosticsConfig(effectiveUserConfig),
            sparkEnvironmentInfo,
            taskContext,
            pkIterator)

          new Iterator[Row] {
            // Guards against closing the reader more than once.
            private var readerClosed = false

            private def closeReaderOnce(): Unit = {
              if (!readerClosed) {
                readerClosed = true
                reader.close()
              }
            }

            // Covers tasks that fail or stop consuming before the rows are exhausted.
            Option(taskContext).foreach(
              _.addTaskCompletionListener[Unit](_ => closeReaderOnce()))

            override def hasNext: Boolean = {
              if (readerClosed) {
                false
              } else {
                val hasMoreRows = reader.next()
                if (!hasMoreRows) {
                  closeReaderOnce()
                }
                hasMoreRows
              }
            }

            override def next(): Row = reader.getCurrentRow()
          }
        },
        preservesPartitioning = true
      ),
      schema)
  }
}
/**
 * Partition reader that reads all items for a list of partition keys through
 * `readManyByPartitionKey` and converts them into Spark rows.
 *
 * Construction parses the configuration, acquires the client cache items and resolves
 * the container's partition key definition. The partition key list and the underlying
 * iterator are created lazily on first access.
 *
 * @param config                   effective connector configuration
 * @param feedRange                feed range used in the operation context and log messages
 * @param readSchema               schema of the rows to produce
 * @param diagnosticsContext       provides the correlation activity id
 * @param cosmosClientStateHandles broadcast client metadata cache snapshots
 * @param diagnosticsConfig        diagnostics configuration (logger and listener selection)
 * @param sparkEnvironmentInfo     passed through to the client configuration
 * @param taskContext              Spark task context; asserted to be non-null
 * @param readManyPartitionKeys    the partition keys to read; consumed once, lazily
 */
private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey
(
  config: Map[String, String],
  feedRange: NormalizedRange,
  readSchema: StructType,
  diagnosticsContext: DiagnosticsContext,
  cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
  diagnosticsConfig: DiagnosticsConfig,
  sparkEnvironmentInfo: String,
  taskContext: TaskContext,
  readManyPartitionKeys: Iterator[PartitionKey]
)
  extends PartitionReader[InternalRow] {

  private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)

  private val readManyOptions = new CosmosReadManyRequestOptions()
  private val readManyOptionsImpl = ImplementationBridgeHelpers
    .CosmosReadManyRequestOptionsHelper
    .getCosmosReadManyRequestOptionsAccessor
    .getImpl(readManyOptions)

  private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
  ThroughputControlHelper.populateThroughputControlGroupName(readManyOptionsImpl, readConfig.throughputControlConfig)

  private val operationContext = {
    assert(taskContext != null)

    SparkTaskContext(diagnosticsContext.correlationActivityId,
      taskContext.stageId(),
      taskContext.partitionId(),
      taskContext.taskAttemptId(),
      feedRange.toString)
  }

  // Registers a diagnostics listener on the request options only when a diagnostics
  // mode is configured.
  private val operationContextAndListenerTuple: Option[OperationContextAndListenerTuple] = {
    if (diagnosticsConfig.mode.isDefined) {
      val listener =
        DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass)

      val ctxAndListener = new OperationContextAndListenerTuple(operationContext, listener)

      readManyOptionsImpl
        .setOperationContextAndListenerTuple(ctxAndListener)

      Some(ctxAndListener)
    } else {
      None
    }
  }

  log.logTrace(s"Instantiated ${this.getClass.getSimpleName}, Context: ${operationContext.toString} $getThreadInfo")

  private val containerTargetConfig = CosmosContainerConfig.parseCosmosContainerConfig(config)

  log.logInfo(s"Using ReadManyByPartitionKey from feed range $feedRange of " +
    s"container ${containerTargetConfig.database}.${containerTargetConfig.container} - " +
    s"correlationActivityId ${diagnosticsContext.correlationActivityId}, " +
    s"Context: ${operationContext.toString} $getThreadInfo")

  private val clientCacheItem = CosmosClientCache(
    CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
    Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
    s"ItemsPartitionReaderWithReadManyByPartitionKey($feedRange, ${containerTargetConfig.database}.${containerTargetConfig.container})"
  )

  private val throughputControlClientCacheItemOpt =
    ThroughputControlHelper.getThroughputControlClientCacheItem(
      config,
      clientCacheItem.context,
      Some(cosmosClientStateHandles),
      sparkEnvironmentInfo)

  private val cosmosAsyncContainer =
    ThroughputControlHelper.getContainer(
      config,
      containerTargetConfig,
      clientCacheItem,
      throughputControlClientCacheItemOpt)

  private val partitionKeyDefinition: PartitionKeyDefinition = {
    TransientErrorsRetryPolicy.executeWithRetry(() => {
      SparkBridgeInternal
        .getContainerPropertiesFromCollectionCache(cosmosAsyncContainer).getPartitionKeyDefinition
    })
  }

  private val cosmosSerializationConfig = CosmosSerializationConfig.parseSerializationConfig(config)
  private val cosmosRowConverter = CosmosRowConverter.get(cosmosSerializationConfig)

  // Deserialize-only serializer: turns each returned document into a SparkRowItem.
  readManyOptionsImpl
    .setCustomItemSerializer(
      new CosmosItemSerializerNoExceptionWrapping {
        override def serialize[T](item: T): util.Map[String, AnyRef] = {
          throw new UnsupportedOperationException(
            s"Serialization is not supported by the custom item serializer in " +
              s"ItemsPartitionReaderWithReadManyByPartitionKey; this serializer is intended " +
              s"for deserializing read-many responses into SparkRowItem only. " +
              s"Unexpected item type: ${if (item == null) "null" else item.getClass.getName}"
          )
        }

        override def deserialize[T](jsonNodeMap: util.Map[String, AnyRef], classType: Class[T]): T = {
          if (jsonNodeMap == null) {
            throw new IllegalStateException("The 'jsonNodeMap' should never be null here.")
          }

          if (classType != classOf[SparkRowItem]) {
            throw new IllegalStateException("The 'classType' must be 'classOf[SparkRowItem])' here.")
          }

          // Reuse the ObjectNode when the map already wraps one; convert otherwise.
          val objectNode: ObjectNode = jsonNodeMap match {
            case map: ObjectNodeMap =>
              map.getObjectNode
            case _ =>
              Utils.getSimpleObjectMapper.convertValue(jsonNodeMap, classOf[ObjectNode])
          }

          val partitionKey = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition)

          val row = cosmosRowConverter.fromObjectNodeToRow(readSchema,
            objectNode,
            readConfig.schemaConversionMode)

          SparkRowItem(row, getPartitionKeyForFeedDiagnostics(partitionKey)).asInstanceOf[T]
        }
      }
    )

  // Collect all PK values upfront - readManyByPartitionKey needs the full list to
  // group by physical partition and issue parallel queries.
  // Deduplicate by PK string representation - safe because the list size is bounded
  // by the per-call limit of the readManyByPartitionKey API.
  // The LinkedHashMap keeps the first occurrence of each key in encounter order.
  private lazy val pkList = {
    val seen = new java.util.LinkedHashMap[String, PartitionKey]()
    readManyPartitionKeys.foreach(pk => seen.putIfAbsent(pk.toString, pk))
    new java.util.ArrayList[PartitionKey](seen.values())
  }

  private val endToEndTimeoutPolicy =
    new CosmosEndToEndOperationLatencyPolicyConfigBuilder(
      java.time.Duration.ofSeconds(CosmosConstants.readOperationEndToEndTimeoutInSeconds))
      .enable(true)
      .build

  readManyOptionsImpl.setCosmosEndToEndOperationLatencyPolicyConfig(endToEndTimeoutPolicy)

  // Minimal closeable iterator abstraction shared by the empty and the real iterator.
  private trait CloseableSparkRowItemIterator {
    def hasNext: Boolean
    def next(): SparkRowItem
    def close(): Unit
  }

  private object EmptySparkRowItemIterator extends CloseableSparkRowItemIterator {
    override def hasNext: Boolean = false

    override def next(): SparkRowItem = {
      throw new java.util.NoSuchElementException("No items available for empty partition-key list.")
    }

    override def close(): Unit = {}
  }

  // Batch partition keys and retry each batch independently on transient I/O errors.
  // This avoids the continuation-token problem with TransientIOErrorsRetryingIterator
  // where a retry would re-read all data from scratch, causing silent data duplication.
  private lazy val iterator: CloseableSparkRowItemIterator =
    if (pkList.isEmpty) {
      EmptySparkRowItemIterator
    } else {
      new CloseableSparkRowItemIterator {
        private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem](
          cosmosAsyncContainer,
          pkList,
          readConfig.customQuery.map(_.toSqlQuerySpec),
          readManyOptions,
          readConfig.maxItemCount,
          readConfig.prefetchBufferSize,
          operationContextAndListenerTuple,
          classOf[SparkRowItem]
        )

        override def hasNext: Boolean = delegate.hasNext

        override def next(): SparkRowItem = delegate.next()

        override def close(): Unit = delegate.close()
      }
    }

  private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema)

  private def shouldLogDetailedFeedDiagnostics(): Boolean = {
    diagnosticsConfig.mode.isDefined &&
      diagnosticsConfig.mode.get.equalsIgnoreCase(classOf[DetailedFeedDiagnosticsProvider].getName)
  }

  // The partition key is only attached to the row item when detailed feed diagnostics
  // are enabled.
  private def getPartitionKeyForFeedDiagnostics(pkValue: PartitionKey): Option[PartitionKey] = {
    if (shouldLogDetailedFeedDiagnostics()) {
      Some(pkValue)
    } else {
      None
    }
  }

  /** Returns whether another row is available. */
  override def next(): Boolean = iterator.hasNext

  /** Advances the iterator and returns the row as an InternalRow. */
  override def get(): InternalRow = {
    cosmosRowConverter.fromRowToInternalRow(iterator.next().row, rowSerializer)
  }

  /** Advances the iterator and returns the row without converting it to an InternalRow. */
  def getCurrentRow(): Row = iterator.next().row

  /**
   * Closes the iterator, returns the row serializer to its pool and closes the client
   * cache items. Each step runs even when an earlier one throws, so a failure while
   * closing the iterator cannot leave the serializer or the client cache items
   * unreleased.
   */
  override def close(): Unit = {
    try {
      this.iterator.close()
    } finally {
      try {
        RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer)
      } finally {
        try {
          clientCacheItem.close()
        } finally {
          throughputControlClientCacheItemOpt.foreach(_.close())
        }
      }
    }
  }
}
/**
 * Iterator over the results of readManyByPartitionKey that splits the partition keys
 * into batches of `pageSize` keys and lazily iterates the pages of each batch via
 * CosmosPagedIterable.
 *
 * When a transient failure is detected the item and page iterators are discarded while
 * the current batch is retained, so the next attempt issues the read for the same batch
 * again from its beginning.
 *
 * NOTE(review): rows of the current batch that were already handed to the caller before
 * the failure are not tracked here, so they appear to be emitted again when the batch
 * is replayed - confirm whether callers are expected to tolerate or de-duplicate them.
 *
 * @param container                   container to read from
 * @param partitionKeys               all partition keys to read
 * @param customQuery                 optional query passed along with each batch
 * @param queryOptions                request options used for every batch
 * @param pageSize                    number of partition keys per batch; also passed to
 *                                    CosmosPagedIterable as page size
 * @param pagePrefetchBufferSize      passed to CosmosPagedIterable
 * @param operationContextAndListener optional listener notified for every page processed
 * @param classType                   class of the items to return
 */
private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow]
(
  val container: CosmosAsyncContainer,
  val partitionKeys: java.util.List[PartitionKey],
  val customQuery: Option[SqlQuerySpec],
  val queryOptions: CosmosReadManyRequestOptions,
  val pageSize: Int,
  val pagePrefetchBufferSize: Int,
  val operationContextAndListener: Option[OperationContextAndListenerTuple],
  val classType: Class[TSparkRow]
) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable {

  private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs
  private[spark] var maxRetryCount = CosmosConstants.maxRetryCountForTransientFailures

  // Upper bound for waiting on a single page: the end-to-end timeout plus 5 seconds.
  private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration(
    5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds,
    scala.concurrent.duration.SECONDS)

  private val rnd = Random
  private val retryCount = new AtomicLong(0)
  private lazy val operationContextString = operationContextAndListener match {
    case Some(o) => if (o.getOperationContext != null) {
      o.getOperationContext.toString
    } else {
      "n/a"
    }
    case None => "n/a"
  }

  private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None
  private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None

  private val pkBatchIterator = partitionKeys.asScala.iterator.grouped(pageSize)
  // Track the current batch so we can replay it on retry
  private var currentBatch: Option[java.util.List[PartitionKey]] = None

  override def hasNext: Boolean = {
    executeWithRetry("hasNextInternal", () => hasNextInternal)
  }

  // Loops until hasNextInternalCore produces a definite answer; None means "try again"
  // (empty page or batch exhausted).
  private def hasNextInternal: Boolean = {
    var returnValue: Option[Boolean] = None

    while (returnValue.isEmpty) {
      returnValue = hasNextInternalCore
    }

    returnValue.get
  }

  private def hasNextInternalCore: Option[Boolean] = {
    if (hasBufferedNext) {
      Some(true)
    } else {
      val feedResponseIterator = currentFeedResponseIterator match {
        case Some(existing) => existing
        case None =>
          // Need a new feed response iterator - either for the current batch (on retry)
          // or for the next batch
          val batch = currentBatch match {
            case Some(b) => b // retry of current batch
            case None =>
              if (pkBatchIterator.hasNext) {
                val nextBatch = new java.util.ArrayList[PartitionKey](pkBatchIterator.next().toList.asJava)
                currentBatch = Some(nextBatch)
                nextBatch
              } else {
                return Some(false) // no more batches
              }
          }

          val pagedFlux = customQuery match {
            case Some(query) =>
              container.readManyByPartitionKey(batch, query, queryOptions, classType)
            case None =>
              container.readManyByPartitionKey(batch, queryOptions, classType)
          }

          currentFeedResponseIterator = Some(
            new CosmosPagedIterable[TSparkRow](
              pagedFlux,
              pageSize,
              pagePrefetchBufferSize
            )
              .iterableByPage()
              .iterator
              .asScala
              .buffered
          )

          currentFeedResponseIterator.get
      }

      // The page lookup runs on the shared executor so it can be bounded by
      // maxPageRetrievalTimeout.
      val hasNext: Boolean = try {
        Await.result(
          Future {
            feedResponseIterator.hasNext
          }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext),
          maxPageRetrievalTimeout)
      } catch {
        case endToEndTimeoutException: OperationCancelledException =>
          val message = s"End-to-end timeout hit when trying to retrieve the next page. " +
            s"Context: $operationContextString"
          logError(message, throwable = endToEndTimeoutException)
          throw endToEndTimeoutException

        case timeoutException: TimeoutException =>
          // Surface the local wait timeout as an OperationCancelledException.
          val message = s"Attempting to retrieve the next page timed out. " +
            s"Context: $operationContextString"
          logError(message, timeoutException)
          val exception = new OperationCancelledException(message, null)
          exception.setStackTrace(timeoutException.getStackTrace)
          throw exception
      }

      if (hasNext) {
        val feedResponse = feedResponseIterator.next()
        if (operationContextAndListener.isDefined) {
          operationContextAndListener.get.getOperationListener.feedResponseProcessedListener(
            operationContextAndListener.get.getOperationContext,
            feedResponse)
        }
        val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered

        if (iteratorCandidate.hasNext) {
          currentItemIterator = Some(iteratorCandidate)
          Some(true)
        } else {
          // empty page interleaved - try again
          None
        }
      } else {
        // Current batch's flux is exhausted - move to next batch
        currentBatch = None
        currentFeedResponseIterator = None
        None
      }
    }
  }

  // True when the current page still has items; clears the item iterator once it is
  // drained.
  private def hasBufferedNext: Boolean = {
    currentItemIterator match {
      case Some(iterator) => if (iterator.hasNext) {
        true
      } else {
        currentItemIterator = None
        false
      }
      case None => false
    }
  }

  // Primes the item iterator through hasNext so that next()/head() also work when the
  // caller did not invoke hasNext first, and fails with a descriptive
  // NoSuchElementException when nothing is left.
  private def primedItemIterator(): BufferedIterator[TSparkRow] = {
    if (!hasNext) {
      throw new java.util.NoSuchElementException(
        "No more items available in TransientIOErrorsRetryingReadManyByPartitionKeyIterator.")
    }

    currentItemIterator.getOrElse(
      throw new java.util.NoSuchElementException(
        "No more items available in TransientIOErrorsRetryingReadManyByPartitionKeyIterator."))
  }

  override def next(): TSparkRow = {
    primedItemIterator().next()
  }

  override def head(): TSparkRow = {
    primedItemIterator().head
  }

  /**
   * Runs `func`, retrying when it fails with a CosmosException that
   * `Exceptions.canBeTransientFailure` classifies as transient. Each retry is preceded
   * by a random delay below `maxRetryIntervalInMs`; after more than `maxRetryCount`
   * consecutive transient failures the exception is rethrown. All other exceptions
   * propagate unchanged.
   *
   * @param methodName name used in log messages
   * @param func       the operation to execute
   * @return the result of `func`
   */
  private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = {
    val loop = new Breaks()
    var returnValue: Option[T] = None

    loop.breakable {
      while (true) {
        val retryIntervalInMs = rnd.nextInt(maxRetryIntervalInMs)

        try {
          returnValue = Some(func())
          retryCount.set(0)
          loop.break
        }
        catch {
          case cosmosException: CosmosException =>
            if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) {
              val retryCountSnapshot = retryCount.incrementAndGet()
              if (retryCountSnapshot > maxRetryCount) {
                logError(
                  s"Too many transient failure retry attempts in " +
                    s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName",
                  cosmosException)
                throw cosmosException
              } else {
                logWarning(
                  s"Transient failure handled in " +
                    s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName -" +
                    s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms",
                  cosmosException)
              }
            } else {
              throw cosmosException
            }
        }

        // Reset iterators but keep currentBatch so the batch is replayed
        currentItemIterator = None
        currentFeedResponseIterator = None
        Thread.sleep(retryIntervalInMs)
      }
    }

    returnValue.get
  }

  /** Drops the references to the current item and page iterators. */
  override def close(): Unit = {
    currentItemIterator = None
    currentFeedResponseIterator = None
  }
}

// Shared thread pool used to bound the wait for the next page.
private object TransientIOErrorsRetryingReadManyByPartitionKeyIterator extends BasicLoggingTrait {
  private val maxConcurrency = SparkUtils.getNumberOfHostCPUCores

  // With a SynchronousQueue and CallerRunsPolicy, work that no pool thread can take
  // immediately runs on the submitting thread.
  val executorService: ExecutorService = new ThreadPoolExecutor(
    maxConcurrency,
    maxConcurrency,
    0L,
    TimeUnit.MILLISECONDS,
    new SynchronousQueue(),
    SparkUtils.daemonThreadFactory(),
    new ThreadPoolExecutor.CallerRunsPolicy()
  )

  val executionContext: ExecutionContext = ExecutionContext.fromExecutorService(executorService)
}
++ ++package com.azure.cosmos.spark.udf ++ ++import com.azure.cosmos.spark.CosmosPartitionKeyHelper ++import com.azure.cosmos.spark.CosmosPredicates.requireNotNull ++import org.apache.spark.sql.api.java.UDF1 ++ ++@SerialVersionUID(1L) ++class GetCosmosPartitionKeyValue extends UDF1[Object, String] { ++ override def call ++ ( ++ partitionKeyValue: Object ++ ): String = { ++ requireNotNull(partitionKeyValue, "partitionKeyValue") ++ ++ partitionKeyValue match { ++ // for subpartitions case - Seq covers both WrappedArray (Scala 2.12) and ArraySeq (Scala 2.13) ++ case seq: Seq[Any] => ++ CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(seq.map(_.asInstanceOf[Object]).toList) ++ case _ => CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) ++ } ++ } ++} +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala +index 17f75e45a74..17a298d6213 100644 +--- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala +@@ -457,6 +457,7 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { + config.runtimeFilteringEnabled shouldBe true + config.readManyFilteringConfig.readManyFilteringEnabled shouldBe false + config.readManyFilteringConfig.readManyFilterProperty shouldEqual "_itemIdentity" ++ config.readManyByPkTreatNullAsNone shouldBe false + + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", +@@ -630,6 +631,47 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { + config.customQuery.get.queryText shouldBe queryText + } + ++ it should "parse readManyByPk nullHandling configuration" in { ++ // Default (not specified) should treat null as JSON null (addNullValue) ++ var userConfig = Map( ++ 
"spark.cosmos.read.forceEventualConsistency" -> "false" ++ ) ++ var config = CosmosReadConfig.parseCosmosReadConfig(userConfig) ++ config.readManyByPkTreatNullAsNone shouldBe false ++ ++ // Explicit "Null" should treat null as JSON null (addNullValue) ++ userConfig = Map( ++ "spark.cosmos.read.forceEventualConsistency" -> "false", ++ "spark.cosmos.read.readManyByPk.nullHandling" -> "Null" ++ ) ++ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) ++ config.readManyByPkTreatNullAsNone shouldBe false ++ ++ // Case-insensitive "null" ++ userConfig = Map( ++ "spark.cosmos.read.forceEventualConsistency" -> "false", ++ "spark.cosmos.read.readManyByPk.nullHandling" -> "null" ++ ) ++ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) ++ config.readManyByPkTreatNullAsNone shouldBe false ++ ++ // "None" should treat null as PartitionKey.NONE (addNoneValue) ++ userConfig = Map( ++ "spark.cosmos.read.forceEventualConsistency" -> "false", ++ "spark.cosmos.read.readManyByPk.nullHandling" -> "None" ++ ) ++ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) ++ config.readManyByPkTreatNullAsNone shouldBe true ++ ++ // Case-insensitive "none" ++ userConfig = Map( ++ "spark.cosmos.read.forceEventualConsistency" -> "false", ++ "spark.cosmos.read.readManyByPk.nullHandling" -> "none" ++ ) ++ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) ++ config.readManyByPkTreatNullAsNone shouldBe true ++ } ++ + it should "throw on invalid read configuration" in { + val userConfig = Map( + "spark.cosmos.read.schemaConversionMode" -> "not a valid value" +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +new file mode 100644 +index 00000000000..1ac40e39584 +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +@@ -0,0 
+1,104 @@ ++// Copyright (c) Microsoft Corporation. All rights reserved. ++// Licensed under the MIT License. ++ ++package com.azure.cosmos.spark ++ ++import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} ++ ++class CosmosPartitionKeyHelperSpec extends UnitSpec { ++ //scalastyle:off multiple.string.literals ++ //scalastyle:off magic.number ++ ++ it should "return the correct partition key value string for single PK" in { ++ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("pk1")) ++ pkString shouldEqual "pk([\"pk1\"])" ++ } ++ ++ it should "return the correct partition key value string for HPK" in { ++ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city1", "zip1")) ++ pkString shouldEqual "pk([\"city1\",\"zip1\"])" ++ } ++ ++ it should "return the correct partition key value string for 3-level HPK" in { ++ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("a", "b", "c")) ++ pkString shouldEqual "pk([\"a\",\"b\",\"c\"])" ++ } ++ ++ it should "parse valid single PK string" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"myPkValue\"])") ++ pk.isDefined shouldBe true ++ pk.get shouldEqual new PartitionKey("myPkValue") ++ } ++ ++ it should "parse valid HPK string" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"city1\",\"zip1\"])") ++ pk.isDefined shouldBe true ++ val expected = new PartitionKeyBuilder().add("city1").add("zip1").build() ++ pk.get shouldEqual expected ++ } ++ ++ it should "parse valid 3-level HPK string" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"a\",\"b\",\"c\"])") ++ pk.isDefined shouldBe true ++ val expected = new PartitionKeyBuilder().add("a").add("b").add("c").build() ++ pk.get shouldEqual expected ++ } ++ ++ it should "roundtrip single PK" in { ++ val original = "pk([\"roundtrip\"])" ++ val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) ++ parsed.isDefined 
shouldBe true ++ val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("roundtrip")) ++ serialized shouldEqual original ++ } ++ ++ it should "roundtrip HPK" in { ++ val original = "pk([\"city\",\"zip\"])" ++ val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) ++ parsed.isDefined shouldBe true ++ val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city", "zip")) ++ serialized shouldEqual original ++ } ++ ++ it should "return None for malformed string" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("invalid_format") ++ pk.isDefined shouldBe false ++ } ++ ++ it should "return None for missing pk prefix" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("[\"value\"]") ++ pk.isDefined shouldBe false ++ } ++ ++ it should "be case-insensitive for parsing" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("PK([\"value\"])") ++ pk.isDefined shouldBe true ++ pk.get shouldEqual new PartitionKey("value") ++ } ++ ++ ++ it should "return None for malformed JSON inside pk() wrapper" in { ++ // Invalid JSON that would cause JsonProcessingException ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk({invalid json})") ++ pk.isDefined shouldBe false ++ } ++ ++ it should "return None for truncated JSON inside pk() wrapper" in { ++ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"unterminated)") ++ pk.isDefined shouldBe false ++ } ++ ++ it should "produce different partition keys for addNullValue vs addNoneValue in HPK" in { ++ // addNullValue represents an explicit JSON null for a field that exists with value null ++ val pkWithNull = new PartitionKeyBuilder().add("Redmond").addNullValue().build() ++ ++ // addNoneValue represents PartitionKey.NONE, meaning the field is absent/undefined ++ val pkWithNone = new PartitionKeyBuilder().add("Redmond").addNoneValue().build() ++ ++ // These MUST produce different partition key hashes and route to different 
physical partitions ++ pkWithNull should not equal pkWithNone ++ } ++ ++ //scalastyle:on multiple.string.literals ++ //scalastyle:on magic.number ++} +diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala +new file mode 100644 +index 00000000000..5c2d7b59836 +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala +@@ -0,0 +1,158 @@ ++// Copyright (c) Microsoft Corporation. All rights reserved. ++// Licensed under the MIT License. ++ ++package com.azure.cosmos.spark ++ ++import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, TestConfigurations, Utils} ++import com.azure.cosmos.models.PartitionKey ++import com.azure.cosmos.spark.diagnostics.DiagnosticsContext ++import com.fasterxml.jackson.databind.node.ObjectNode ++import org.apache.spark.MockTaskContext ++import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.sql.types.{StringType, StructField, StructType} ++ ++import java.util.UUID ++import scala.collection.mutable.ListBuffer ++ ++class ItemsPartitionReaderWithReadManyByPartitionKeyITest ++ extends IntegrationSpec ++ with Spark ++ with AutoCleanableCosmosContainersWithPkAsPartitionKey { ++ private val idProperty = "id" ++ private val pkProperty = "pk" ++ ++ //scalastyle:off multiple.string.literals ++ //scalastyle:off magic.number ++ ++ it should "be able to retrieve all items for given partition keys" in { ++ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) ++ ++ // Create items with known PK values ++ val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition ++ val allItemsByPk = scala.collection.mutable.Map[String, 
ListBuffer[ObjectNode]]() ++ val pkValues = List("pkA", "pkB", "pkC") ++ ++ for (pk <- pkValues) { ++ allItemsByPk(pk) = ListBuffer[ObjectNode]() ++ for (_ <- 1 to 5) { ++ val objectNode = Utils.getSimpleObjectMapper.createObjectNode() ++ objectNode.put(idProperty, UUID.randomUUID().toString) ++ objectNode.put(pkProperty, pk) ++ container.createItem(objectNode).block() ++ allItemsByPk(pk) += objectNode ++ } ++ } ++ ++ val config = Map( ++ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, ++ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, ++ "spark.cosmos.database" -> cosmosDatabase, ++ "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, ++ "spark.cosmos.read.inferSchema.enabled" -> "true", ++ "spark.cosmos.applicationName" -> "ReadManyByPKTest" ++ ) ++ ++ val readSchema = StructType(Seq( ++ StructField(idProperty, StringType, false), ++ StructField(pkProperty, StringType, false) ++ )) ++ ++ val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") ++ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) ++ val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() ++ ++ // Read items for pkA and pkB (not pkC) ++ val targetPks = List("pkA", "pkB") ++ val pkIterator = targetPks.map(pk => new PartitionKey(pk)).iterator ++ ++ val reader = ItemsPartitionReaderWithReadManyByPartitionKey( ++ config, ++ NormalizedRange("", "FF"), ++ readSchema, ++ diagnosticsContext, ++ cosmosClientMetadataCachesSnapshots, ++ diagnosticsConfig, ++ "", ++ MockTaskContext.mockTaskContext(), ++ pkIterator ++ ) ++ ++ val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) ++ val itemsReadFromReader = ListBuffer[ObjectNode]() ++ while (reader.next()) { ++ itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) ++ } ++ ++ // Should have 10 items (5 for pkA + 5 for pkB) ++ itemsReadFromReader.size shouldEqual 10 ++ ++ // 
All items should be from pkA or pkB ++ itemsReadFromReader.foreach(item => { ++ val pk = item.get(pkProperty).asText() ++ targetPks should contain(pk) ++ }) ++ ++ // Validate all expected IDs are present ++ val expectedIds = (allItemsByPk("pkA") ++ allItemsByPk("pkB")).map(_.get(idProperty).asText()).toSet ++ val actualIds = itemsReadFromReader.map(_.get(idProperty).asText()).toSet ++ actualIds shouldEqual expectedIds ++ ++ reader.close() ++ } ++ ++ it should "return empty results for non-existent partition keys" in { ++ val config = Map( ++ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, ++ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, ++ "spark.cosmos.database" -> cosmosDatabase, ++ "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, ++ "spark.cosmos.read.inferSchema.enabled" -> "true", ++ "spark.cosmos.applicationName" -> "ReadManyByPKEmptyTest" ++ ) ++ ++ val readSchema = StructType(Seq( ++ StructField(idProperty, StringType, false), ++ StructField(pkProperty, StringType, false) ++ )) ++ ++ val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") ++ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) ++ val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() ++ ++ val pkIterator = List(new PartitionKey("nonExistentPk")).iterator ++ ++ val reader = ItemsPartitionReaderWithReadManyByPartitionKey( ++ config, ++ NormalizedRange("", "FF"), ++ readSchema, ++ diagnosticsContext, ++ cosmosClientMetadataCachesSnapshots, ++ diagnosticsConfig, ++ "", ++ MockTaskContext.mockTaskContext(), ++ pkIterator ++ ) ++ ++ val itemsReadFromReader = ListBuffer[ObjectNode]() ++ val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) ++ while (reader.next()) { ++ itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) ++ } ++ ++ itemsReadFromReader.size shouldEqual 0 ++ reader.close() ++ } ++ ++ 
private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = { ++ val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot() ++ cosmosClientMetadataCachesSnapshot.serialize(cosmosClient) ++ ++ spark.sparkContext.broadcast( ++ CosmosClientMetadataCachesSnapshots( ++ cosmosClientMetadataCachesSnapshot, ++ Option.empty[CosmosClientMetadataCachesSnapshot])) ++ } ++ ++ //scalastyle:on multiple.string.literals ++ //scalastyle:on magic.number ++} +diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +index 3972ae6aeb9..d8368be6a0d 100644 +--- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.47.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +new file mode 100644 +index 00000000000..2c26d564ed2 +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +@@ -0,0 +1,462 @@ ++/* ++ * Copyright (c) Microsoft Corporation. All rights reserved. ++ * Licensed under the MIT License. 
++ */ ++ ++package com.azure.cosmos; ++ ++import com.azure.cosmos.models.CosmosContainerProperties; ++import com.azure.cosmos.models.CosmosItemRequestOptions; ++import com.azure.cosmos.models.FeedResponse; ++import com.azure.cosmos.models.PartitionKey; ++import com.azure.cosmos.models.PartitionKeyBuilder; ++import com.azure.cosmos.models.PartitionKeyDefinition; ++import com.azure.cosmos.models.PartitionKeyDefinitionVersion; ++import com.azure.cosmos.models.PartitionKind; ++import com.azure.cosmos.models.SqlParameter; ++import com.azure.cosmos.models.SqlQuerySpec; ++import com.azure.cosmos.rx.TestSuiteBase; ++import com.azure.cosmos.util.CosmosPagedIterable; ++import com.fasterxml.jackson.databind.node.ObjectNode; ++import org.testng.annotations.AfterClass; ++import org.testng.annotations.BeforeClass; ++import org.testng.annotations.Factory; ++import org.testng.annotations.Test; ++ ++import java.util.ArrayList; ++import java.util.Arrays; ++import java.util.Collections; ++import java.util.List; ++import java.util.UUID; ++import java.util.stream.Collectors; ++ ++import static org.assertj.core.api.Assertions.assertThat; ++import static org.assertj.core.api.Assertions.fail; ++ ++public class ReadManyByPartitionKeyTest extends TestSuiteBase { ++ ++ private String preExistingDatabaseId = CosmosDatabaseForTest.generateId(); ++ private CosmosClient client; ++ private CosmosDatabase createdDatabase; ++ ++ // Single PK container (/mypk) ++ private CosmosContainer singlePkContainer; ++ ++ // HPK container (/city, /zipcode, /areaCode) ++ private CosmosContainer multiHashContainer; ++ ++ @Factory(dataProvider = "clientBuilders") ++ public ReadManyByPartitionKeyTest(CosmosClientBuilder clientBuilder) { ++ super(clientBuilder); ++ } ++ ++ @BeforeClass(groups = {"emulator"}, timeOut = SETUP_TIMEOUT) ++ public void before_ReadManyByPartitionKeyTest() { ++ client = getClientBuilder().buildClient(); ++ createdDatabase = createSyncDatabase(client, preExistingDatabaseId); ++ ++ // 
Single PK container ++ String singlePkContainerName = UUID.randomUUID().toString(); ++ CosmosContainerProperties singlePkProps = new CosmosContainerProperties(singlePkContainerName, "/mypk"); ++ createdDatabase.createContainer(singlePkProps); ++ singlePkContainer = createdDatabase.getContainer(singlePkContainerName); ++ ++ // HPK container ++ String multiHashContainerName = UUID.randomUUID().toString(); ++ PartitionKeyDefinition hpkDef = new PartitionKeyDefinition(); ++ hpkDef.setKind(PartitionKind.MULTI_HASH); ++ hpkDef.setVersion(PartitionKeyDefinitionVersion.V2); ++ ArrayList paths = new ArrayList<>(); ++ paths.add("/city"); ++ paths.add("/zipcode"); ++ paths.add("/areaCode"); ++ hpkDef.setPaths(paths); ++ ++ CosmosContainerProperties hpkProps = new CosmosContainerProperties(multiHashContainerName, hpkDef); ++ createdDatabase.createContainer(hpkProps); ++ multiHashContainer = createdDatabase.getContainer(multiHashContainerName); ++ } ++ ++ @AfterClass(groups = {"emulator"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) ++ public void afterClass() { ++ safeDeleteSyncDatabase(createdDatabase); ++ safeCloseSyncClient(client); ++ } ++ ++ //region Single PK tests ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_basic() { ++ // Create items with different PKs ++ List items = createSinglePkItems("pk1", 3); ++ items.addAll(createSinglePkItems("pk2", 2)); ++ items.addAll(createSinglePkItems("pk3", 4)); ++ ++ // Read by 2 partition keys ++ List pkValues = Arrays.asList( ++ new PartitionKey("pk1"), ++ new PartitionKey("pk2")); ++ ++ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(5); // 3 + 2 ++ resultList.forEach(item -> { ++ String pk = item.get("mypk").asText(); ++ assertThat(pk).isIn("pk1", "pk2"); ++ }); ++ ++ // Cleanup ++ cleanupContainer(singlePkContainer); ++ 
} ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_withProjection() { ++ List items = createSinglePkItems("pkProj", 2); ++ ++ List pkValues = Collections.singletonList(new PartitionKey("pkProj")); ++ SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.mypk FROM c"); ++ ++ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( ++ pkValues, customQuery, null, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(2); ++ // Should only have id and mypk fields (plus system properties) ++ resultList.forEach(item -> { ++ assertThat(item.has("id")).isTrue(); ++ assertThat(item.has("mypk")).isTrue(); ++ }); ++ ++ cleanupContainer(singlePkContainer); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_withAdditionalFilter() { ++ // Create items with different "status" values ++ createSinglePkItemsWithStatus("pkFilter", "active", 3); ++ createSinglePkItemsWithStatus("pkFilter", "inactive", 2); ++ ++ List pkValues = Collections.singletonList(new PartitionKey("pkFilter")); ++ SqlQuerySpec customQuery = new SqlQuerySpec( ++ "SELECT * FROM c WHERE c.status = @status", ++ Arrays.asList(new SqlParameter("@status", "active"))); ++ ++ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( ++ pkValues, customQuery, null, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(3); ++ resultList.forEach(item -> { ++ assertThat(item.get("status").asText()).isEqualTo("active"); ++ }); ++ ++ cleanupContainer(singlePkContainer); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_emptyResults() { ++ List pkValues = Collections.singletonList(new PartitionKey("nonExistent")); ++ ++ CosmosPagedIterable results = 
singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).isEmpty(); ++ } ++ ++ //endregion ++ ++ //region HPK tests ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void hpk_readManyByPartitionKey_fullPk() { ++ createHpkItems(); ++ ++ // Read by full PKs ++ List pkValues = Arrays.asList( ++ new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build(), ++ new PartitionKeyBuilder().add("Pittsburgh").add("15232").add(2).build()); ++ ++ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ // Redmond/98053/1 has 2 items, Pittsburgh/15232/2 has 1 item ++ assertThat(resultList).hasSize(3); ++ ++ cleanupContainer(multiHashContainer); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void hpk_readManyByPartitionKey_partialPk_singleLevel() { ++ createHpkItems(); ++ ++ // Read by partial PK (only city) ++ List pkValues = Collections.singletonList( ++ new PartitionKeyBuilder().add("Redmond").build()); ++ ++ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ // Redmond has 3 items total (2 with 98053/1 and 1 with 12345/1) ++ assertThat(resultList).hasSize(3); ++ resultList.forEach(item -> { ++ assertThat(item.get("city").asText()).isEqualTo("Redmond"); ++ }); ++ ++ cleanupContainer(multiHashContainer); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void hpk_readManyByPartitionKey_partialPk_twoLevels() { ++ createHpkItems(); ++ ++ // Read by partial PK (city + zipcode) ++ List pkValues = Collections.singletonList( ++ new PartitionKeyBuilder().add("Redmond").add("98053").build()); ++ ++ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, 
ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ // Redmond/98053 has 2 items ++ assertThat(resultList).hasSize(2); ++ resultList.forEach(item -> { ++ assertThat(item.get("city").asText()).isEqualTo("Redmond"); ++ assertThat(item.get("zipcode").asText()).isEqualTo("98053"); ++ }); ++ ++ cleanupContainer(multiHashContainer); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void hpk_readManyByPartitionKey_withProjection() { ++ createHpkItems(); ++ ++ List pkValues = Collections.singletonList( ++ new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build()); ++ ++ SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.city FROM c"); ++ ++ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey( ++ pkValues, customQuery, null, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(2); ++ ++ cleanupContainer(multiHashContainer); ++ } ++ ++ //endregion ++ ++ //region Negative/validation tests ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void rejectsAggregateQuery() { ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ SqlQuerySpec aggregateQuery = new SqlQuerySpec("SELECT COUNT(1) FROM c"); ++ ++ try { ++ singlePkContainer.readManyByPartitionKey(pkValues, aggregateQuery, null, ObjectNode.class) ++ .stream().collect(Collectors.toList()); ++ fail("Should have thrown IllegalArgumentException for aggregate query"); ++ } catch (IllegalArgumentException e) { ++ assertThat(e.getMessage()).contains("aggregates"); ++ } ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void rejectsOrderByQuery() { ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ SqlQuerySpec orderByQuery = new SqlQuerySpec("SELECT * FROM c ORDER BY c.id"); ++ ++ try { ++ singlePkContainer.readManyByPartitionKey(pkValues, orderByQuery, null, ObjectNode.class) ++ 
.stream().collect(Collectors.toList()); ++ fail("Should have thrown IllegalArgumentException for ORDER BY query"); ++ } catch (IllegalArgumentException e) { ++ assertThat(e.getMessage()).contains("ORDER BY"); ++ } ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void rejectsDistinctQuery() { ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ SqlQuerySpec distinctQuery = new SqlQuerySpec("SELECT DISTINCT c.mypk FROM c"); ++ ++ try { ++ singlePkContainer.readManyByPartitionKey(pkValues, distinctQuery, null, ObjectNode.class) ++ .stream().collect(Collectors.toList()); ++ fail("Should have thrown IllegalArgumentException for DISTINCT query"); ++ } catch (IllegalArgumentException e) { ++ assertThat(e.getMessage()).contains("DISTINCT"); ++ } ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void rejectsGroupByQuery() { ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk FROM c GROUP BY c.mypk"); ++ ++ try { ++ singlePkContainer.readManyByPartitionKey(pkValues, groupByQuery, null, ObjectNode.class) ++ .stream().collect(Collectors.toList()); ++ fail("Should have thrown IllegalArgumentException for GROUP BY query"); ++ } catch (IllegalArgumentException e) { ++ assertThat(e.getMessage()).contains("GROUP BY"); ++ } ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void rejectsGroupByWithAggregateQuery() { ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ SqlQuerySpec groupByWithAggregateQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); ++ ++ try { ++ singlePkContainer.readManyByPartitionKey(pkValues, groupByWithAggregateQuery, null, ObjectNode.class) ++ .stream().collect(Collectors.toList()); ++ fail("Should have thrown IllegalArgumentException for GROUP BY with aggregate query"); ++ } catch (IllegalArgumentException e) { ++ 
assertThat(e.getMessage()).contains("GROUP BY"); ++ } ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) ++ public void rejectsNullPartitionKeyList() { ++ singlePkContainer.readManyByPartitionKey((List) null, ObjectNode.class); ++ } ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) ++ public void rejectsEmptyPartitionKeyList() { ++ singlePkContainer.readManyByPartitionKey(new ArrayList<>(), ObjectNode.class) ++ .stream().collect(Collectors.toList()); ++ } ++ ++ //endregion ++ ++ ++ //region Batch size tests (#10) ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_withSmallBatchSize() { ++ // Temporarily set batch size to 2 to exercise the batching/interleaving logic ++ String originalValue = System.getProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); ++ try { ++ System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", "2"); ++ ++ // Create items across 4 PKs (more than the batch size of 2) ++ List items = createSinglePkItems("batchPk1", 2); ++ items.addAll(createSinglePkItems("batchPk2", 2)); ++ items.addAll(createSinglePkItems("batchPk3", 2)); ++ items.addAll(createSinglePkItems("batchPk4", 2)); ++ ++ // Read all 4 PKs — should be split into batches of 2 ++ List pkValues = Arrays.asList( ++ new PartitionKey("batchPk1"), ++ new PartitionKey("batchPk2"), ++ new PartitionKey("batchPk3"), ++ new PartitionKey("batchPk4")); ++ ++ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(8); // 2 items per PK * 4 PKs ++ resultList.forEach(item -> { ++ String pk = item.get("mypk").asText(); ++ assertThat(pk).isIn("batchPk1", "batchPk2", "batchPk3", "batchPk4"); ++ }); ++ ++ cleanupContainer(singlePkContainer); ++ } finally { ++ if (originalValue != null) {
++ System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", originalValue); ++ } else { ++ System.clearProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); ++ } ++ } ++ } ++ ++ //endregion ++ ++ //region Custom serializer regression tests (#5) ++ ++ @Test(groups = {"emulator"}, timeOut = TIMEOUT) ++ public void singlePk_readManyByPartitionKey_withRequestOptions() { ++ // This test ensures that request options (like throughput control settings) ++ // are properly propagated through the readManyByPartitionKey path. ++ // It acts as a regression test for the redundant options construction bug. ++ List items = createSinglePkItems("pkOpts", 3); ++ ++ List pkValues = Collections.singletonList(new PartitionKey("pkOpts")); ++ com.azure.cosmos.models.CosmosReadManyRequestOptions options = new com.azure.cosmos.models.CosmosReadManyRequestOptions(); ++ ++ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( ++ pkValues, options, ObjectNode.class); ++ List resultList = results.stream().collect(Collectors.toList()); ++ ++ assertThat(resultList).hasSize(3); ++ ++ cleanupContainer(singlePkContainer); ++ } ++ ++ //endregion ++ ++ //region helper methods ++ ++ private List createSinglePkItems(String pkValue, int count) { ++ List items = new ArrayList<>(); ++ for (int i = 0; i < count; i++) { ++ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); ++ item.put("id", UUID.randomUUID().toString()); ++ item.put("mypk", pkValue); ++ singlePkContainer.createItem(item); ++ items.add(item); ++ } ++ return items; ++ } ++ ++ private List createSinglePkItemsWithStatus(String pkValue, String status, int count) { ++ List items = new ArrayList<>(); ++ for (int i = 0; i < count; i++) { ++ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); ++ item.put("id", UUID.randomUUID().toString()); ++ item.put("mypk", pkValue); ++ item.put("status", status); ++ 
singlePkContainer.createItem(item); ++ items.add(item); ++ } ++ return items; ++ } ++ ++ private void createHpkItems() { ++ // Same data as CosmosMultiHashTest.createItems() ++ createHpkItem("Redmond", "98053", 1); ++ createHpkItem("Redmond", "98053", 1); ++ createHpkItem("Pittsburgh", "15232", 2); ++ createHpkItem("Stonybrook", "11790", 3); ++ createHpkItem("Stonybrook", "11794", 3); ++ createHpkItem("Stonybrook", "11791", 3); ++ createHpkItem("Redmond", "12345", 1); ++ } ++ ++ private void createHpkItem(String city, String zipcode, int areaCode) { ++ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); ++ item.put("id", UUID.randomUUID().toString()); ++ item.put("city", city); ++ item.put("zipcode", zipcode); ++ item.put("areaCode", areaCode); ++ multiHashContainer.createItem(item); ++ } ++ ++ private void cleanupContainer(CosmosContainer container) { ++ CosmosPagedIterable allItems = container.queryItems( ++ "SELECT * FROM c", new com.azure.cosmos.models.CosmosQueryRequestOptions(), ObjectNode.class); ++ allItems.forEach(item -> { ++ try { ++ container.deleteItem(item, new CosmosItemRequestOptions()); ++ } catch (CosmosException e) { ++ // ignore cleanup failures ++ } ++ }); ++ } ++ ++ //endregion ++} +diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java +new file mode 100644 +index 00000000000..95c109ba025 +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java +@@ -0,0 +1,426 @@ ++/* ++ * Copyright (c) Microsoft Corporation. All rights reserved. ++ * Licensed under the MIT License. 
++ */ ++ ++package com.azure.cosmos.implementation; ++ ++import com.azure.cosmos.models.PartitionKey; ++import com.azure.cosmos.models.PartitionKeyBuilder; ++import com.azure.cosmos.models.PartitionKeyDefinition; ++import com.azure.cosmos.models.PartitionKeyDefinitionVersion; ++import com.azure.cosmos.models.PartitionKind; ++import com.azure.cosmos.models.SqlParameter; ++import com.azure.cosmos.models.SqlQuerySpec; ++import org.testng.annotations.Test; ++ ++import java.util.ArrayList; ++import java.util.Arrays; ++import java.util.Collections; ++import java.util.List; ++import java.util.stream.Collectors; ++ ++import static org.assertj.core.api.Assertions.assertThat; ++ ++public class ReadManyByPartitionKeyQueryHelperTest { ++ ++ //region Single PK (HASH) tests ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_defaultQuery_singleValue() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); ++ assertThat(result.getQueryText()).contains("IN ("); ++ assertThat(result.getQueryText()).contains("@__rmPk_0"); ++ assertThat(result.getParameters()).hasSize(1); ++ assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("pk1"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_defaultQuery_multipleValues() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Arrays.asList( ++ new PartitionKey("pk1"), ++ new PartitionKey("pk2"), ++ new PartitionKey("pk3")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, 
selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("IN ("); ++ assertThat(result.getQueryText()).contains("@__rmPk_0"); ++ assertThat(result.getQueryText()).contains("@__rmPk_1"); ++ assertThat(result.getQueryText()).contains("@__rmPk_2"); ++ assertThat(result.getParameters()).hasSize(3); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_customQuery_noWhere() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT c.name, c.age FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).startsWith("SELECT c.name, c.age FROM c WHERE"); ++ assertThat(result.getQueryText()).contains("IN ("); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_customQuery_withExistingWhere() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ ++ List baseParams = new ArrayList<>(); ++ baseParams.add(new SqlParameter("@minAge", 18)); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c WHERE c.age > @minAge", baseParams, pkValues, selectors, pkDef); ++ ++ // Should AND the PK filter to the existing WHERE clause ++ assertThat(result.getQueryText()).contains("WHERE (c.age > @minAge) AND ("); ++ assertThat(result.getQueryText()).contains("IN ("); ++ assertThat(result.getParameters()).hasSize(2); // @minAge + @__rmPk_0 ++ assertThat(result.getParameters().get(0).getName()).isEqualTo("@minAge"); ++ } ++ ++ //endregion ++ ++ //region HPK (MULTI_HASH) tests ++ ++ @Test(groups = { "unit" }) ++ public void hpk_fullPk_defaultQuery() { ++ PartitionKeyDefinition pkDef = 
createMultiHashPkDefinition("/city", "/zipcode"); ++ List selectors = createSelectors(pkDef); ++ ++ PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); ++ List pkValues = Collections.singletonList(pk); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); ++ // Should use OR/AND pattern, not IN ++ assertThat(result.getQueryText()).doesNotContain("IN ("); ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); ++ assertThat(result.getQueryText()).contains("AND"); ++ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); ++ assertThat(result.getParameters()).hasSize(2); ++ assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("Redmond"); ++ assertThat(result.getParameters().get(1).getValue(Object.class)).isEqualTo("98052"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void hpk_fullPk_multipleValues() { ++ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); ++ List selectors = createSelectors(pkDef); ++ ++ List pkValues = Arrays.asList( ++ new PartitionKeyBuilder().add("Redmond").add("98052").build(), ++ new PartitionKeyBuilder().add("Seattle").add("98101").build()); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("OR"); ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); ++ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_2"); ++ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_3"); ++ assertThat(result.getParameters()).hasSize(4); ++ } ++ ++ @Test(groups = { "unit" }) ++ public 
void hpk_partialPk_singleLevel() { ++ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); ++ List selectors = createSelectors(pkDef); ++ ++ // Partial PK ΓÇö only first level ++ PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").build(); ++ List pkValues = Collections.singletonList(partialPk); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); ++ // Should NOT include zipcode or areaCode since it's partial ++ assertThat(result.getQueryText()).doesNotContain("zipcode"); ++ assertThat(result.getQueryText()).doesNotContain("areaCode"); ++ assertThat(result.getParameters()).hasSize(1); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void hpk_partialPk_twoLevels() { ++ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); ++ List selectors = createSelectors(pkDef); ++ ++ // Partial PK ΓÇö first two levels ++ PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); ++ List pkValues = Collections.singletonList(partialPk); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); ++ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); ++ assertThat(result.getQueryText()).doesNotContain("areaCode"); ++ assertThat(result.getParameters()).hasSize(2); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void hpk_customQuery_withWhere() { ++ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); ++ List selectors = createSelectors(pkDef); ++ ++ List baseParams = new ArrayList<>(); ++ baseParams.add(new SqlParameter("@status", "active")); ++ ++ 
PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); ++ List pkValues = Collections.singletonList(pk); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT c.name FROM c WHERE c.status = @status", baseParams, pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("WHERE (c.status = @status) AND ("); ++ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); ++ assertThat(result.getParameters()).hasSize(3); // @status + 2 pk params ++ } ++ ++ //endregion ++ ++ //region findTopLevelWhereIndex tests ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_simpleQuery() { ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.ID = 1"); ++ assertThat(idx).isEqualTo(16); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_noWhere() { ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C"); ++ assertThat(idx).isEqualTo(-1); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_whereInSubquery() { ++ // WHERE inside parentheses (subquery) should be ignored ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( ++ "SELECT * FROM C WHERE EXISTS(SELECT VALUE T FROM T IN C.TAGS WHERE T = 'FOO')"); ++ // Should find the outer WHERE, not the inner one ++ assertThat(idx).isEqualTo(16); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_caseInsensitive() { ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.X = 1"); ++ assertThat(idx).isGreaterThan(0); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_whereNotKeyword() { ++ // "ELSEWHERE" should not match ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM ELSEWHERE"); ++ assertThat(idx).isEqualTo(-1); ++ } ++ ++ //endregion ++ ++ //region Custom alias tests ++ ++ @Test(groups = { "unit" }) ++ public void 
singlePk_customAlias() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT x.id, x.mypk FROM x", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).startsWith("SELECT x.id, x.mypk FROM x WHERE"); ++ assertThat(result.getQueryText()).contains("x[\"mypk\"] IN ("); ++ assertThat(result.getQueryText()).doesNotContain("c[\"mypk\"]"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_customAlias_withWhere() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ ++ List baseParams = new ArrayList<>(); ++ baseParams.add(new SqlParameter("@cat", "HelloWorld")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT x.id, x.mypk FROM x WHERE x.category = @cat", baseParams, pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("WHERE (x.category = @cat) AND (x[\"mypk\"] IN ("); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void hpk_customAlias() { ++ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); ++ List selectors = createSelectors(pkDef); ++ ++ PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); ++ List pkValues = Collections.singletonList(pk); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT r.name FROM root r", new ArrayList<>(), pkValues, selectors, pkDef); ++ ++ assertThat(result.getQueryText()).contains("r[\"city\"] = @__rmPk_0"); ++ assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @__rmPk_1"); ++ 
assertThat(result.getQueryText()).doesNotContain("c[\""); ++ } ++ ++ //endregion ++ ++ //region extractTableAlias tests ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_defaultC() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM c")).isEqualTo("c"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_customX() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT x.id FROM x WHERE x.age > 5")).isEqualTo("x"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_rootWithAlias() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT r.name FROM root r")).isEqualTo("r"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_rootNoAlias() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM root")).isEqualTo("root"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_containerWithWhere() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM items WHERE items.status = 'active'")).isEqualTo("items"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_caseInsensitive() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("select * from MyContainer where MyContainer.id = '1'")).isEqualTo("MyContainer"); ++ } ++ ++ //endregion ++ ++ ++ //region String literal handling tests (#1) ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_ignoresWhereInsideStringLiteral() { ++ // WHERE inside a string literal should be ignored ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( ++ "SELECT * FROM c WHERE c.msg = 'use WHERE clause here'"); ++ // Should find the outer WHERE at position 16, not the one inside the string ++ assertThat(idx).isEqualTo(16); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_ignoresParenthesesInsideStringLiteral() { ++ // Parentheses inside string literal should not affect depth tracking ++ int 
idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( ++ "SELECT * FROM c WHERE c.name = 'foo(bar)' AND c.x = 1"); ++ assertThat(idx).isEqualTo(16); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_handlesUnbalancedParenInStringLiteral() { ++ // Unbalanced paren inside string literal must not corrupt depth ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( ++ "SELECT * FROM c WHERE c.val = 'open(' AND c.active = true"); ++ assertThat(idx).isEqualTo(16); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void findWhere_handlesStringLiteralBeforeWhere() { ++ // String literal in SELECT before WHERE ++ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( ++ "SELECT 'WHERE' as label FROM c WHERE c.id = '1'"); ++ // The WHERE inside quotes should be ignored; the real WHERE is further along ++ assertThat(idx).isGreaterThan(30); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void singlePk_customQuery_withStringLiteralContainingParens() { ++ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); ++ List selectors = createSelectors(pkDef); ++ List pkValues = Collections.singletonList(new PartitionKey("pk1")); ++ ++ List baseParams = new ArrayList<>(); ++ baseParams.add(new SqlParameter("@msg", "hello")); ++ ++ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( ++ "SELECT * FROM c WHERE c.msg = 'test(value)WHERE'", baseParams, pkValues, selectors, pkDef); ++ ++ // Should correctly AND the PK filter to the real WHERE clause ++ assertThat(result.getQueryText()).contains("WHERE (c.msg = 'test(value)WHERE') AND ("); ++ } ++ ++ //endregion ++ ++ //region OFFSET/LIMIT/HAVING alias detection tests (#9) ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_containerWithOffset() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( ++ "SELECT * FROM c OFFSET 10 LIMIT 5")).isEqualTo("c"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void 
extractAlias_containerWithLimit() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( ++ "SELECT * FROM c LIMIT 10")).isEqualTo("c"); ++ } ++ ++ @Test(groups = { "unit" }) ++ public void extractAlias_containerWithHaving() { ++ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( ++ "SELECT c.status, COUNT(1) FROM c GROUP BY c.status HAVING COUNT(1) > 1")).isEqualTo("c"); ++ } ++ ++ //endregion ++ ++ //region helpers ++ ++ private PartitionKeyDefinition createSinglePkDefinition(String path) { ++ PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); ++ pkDef.setKind(PartitionKind.HASH); ++ pkDef.setVersion(PartitionKeyDefinitionVersion.V2); ++ pkDef.setPaths(Collections.singletonList(path)); ++ return pkDef; ++ } ++ ++ private PartitionKeyDefinition createMultiHashPkDefinition(String... paths) { ++ PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); ++ pkDef.setKind(PartitionKind.MULTI_HASH); ++ pkDef.setVersion(PartitionKeyDefinitionVersion.V2); ++ pkDef.setPaths(Arrays.asList(paths)); ++ return pkDef; ++ } ++ ++ private List createSelectors(PartitionKeyDefinition pkDef) { ++ return pkDef.getPaths() ++ .stream() ++ .map(pathPart -> pathPart.substring(1)) // skip starting / ++ .map(pathPart -> pathPart.replace("\"", "\\")) ++ .map(part -> "[\"" + part + "\"]") ++ .collect(Collectors.toList()); ++ } ++ ++ //endregion ++} +diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md +index e8ea564fab7..904c01c3238 100644 +--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md ++++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md +@@ -3,6 +3,7 @@ + ### 4.80.0-beta.1 (Unreleased) + + #### Features Added ++* Added new `readManyByPartitionKey` to bulk query by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) + + #### Breaking Changes + +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +index ad871bb97c0..4e234667c1c 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +@@ -165,6 +165,7 @@ public class CosmosAsyncContainer { + private final String createItemSpanName; + private final String readAllItemsSpanName; + private final String readManyItemsSpanName; ++ private final String readManyByPartitionKeyItemsSpanName; + private final String readAllItemsOfLogicalPartitionSpanName; + private final String queryItemsSpanName; + private final String queryChangeFeedSpanName; +@@ -198,6 +199,7 @@ public class CosmosAsyncContainer { + this.createItemSpanName = "createItem." + this.id; + this.readAllItemsSpanName = "readAllItems." + this.id; + this.readManyItemsSpanName = "readManyItems." + this.id; ++ this.readManyByPartitionKeyItemsSpanName = "readManyByPartitionKeyItems." + this.id; + this.readAllItemsOfLogicalPartitionSpanName = "readAllItemsOfLogicalPartition." + this.id; + this.queryItemsSpanName = "queryItems." + this.id; + this.queryChangeFeedSpanName = "queryChangeFeed." + this.id; +@@ -1601,6 +1603,130 @@ public class CosmosAsyncContainer { + context); + } + ++ /** ++ * Reads many documents matching the provided partition key values. ++ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries ++ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} ++ * as the base query. 
++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param classType class type ++ * @return a {@link CosmosPagedFlux} containing one or several feed response pages ++ */ ++ public CosmosPagedFlux readManyByPartitionKey( ++ List partitionKeys, ++ Class classType) { ++ ++ return this.readManyByPartitionKey(partitionKeys, null, null, classType); ++ } ++ ++ /** ++ * Reads many documents matching the provided partition key values. ++ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries ++ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} ++ * as the base query. ++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param requestOptions the optional request options ++ * @param classType class type ++ * @return a {@link CosmosPagedFlux} containing one or several feed response pages ++ */ ++ public CosmosPagedFlux readManyByPartitionKey( ++ List partitionKeys, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) { ++ ++ return this.readManyByPartitionKey(partitionKeys, null, requestOptions, classType); ++ } ++ ++ /** ++ * Reads many documents matching the provided partition key values with a custom query. ++ * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) ++ * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). ++ * The SDK will automatically append partition key filtering to the custom query. ++ *

++ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, ++ * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be ++ * rejected. ++ *

++ * Partial hierarchical partition keys are supported and will fan out to multiple ++ * physical partitions. ++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) ++ * @param requestOptions the optional request options ++ * @param classType class type ++ * @return a {@link CosmosPagedFlux} containing one or several feed response pages ++ */ ++ public CosmosPagedFlux readManyByPartitionKey( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) { ++ ++ if (partitionKeys == null) { ++ throw new IllegalArgumentException("Argument 'partitionKeys' must not be null."); ++ } ++ if (partitionKeys.isEmpty()) { ++ throw new IllegalArgumentException("Argument 'partitionKeys' must not be empty."); ++ } ++ for (PartitionKey pk : partitionKeys) { ++ if (pk == null) { ++ throw new IllegalArgumentException( ++ "Argument 'partitionKeys' must not contain null elements."); ++ } ++ } ++ ++ return UtilBridgeInternal.createCosmosPagedFlux( ++ readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); ++ } ++ ++ private Function>> readManyByPartitionKeyInternalFunc( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) { ++ ++ CosmosAsyncClient client = this.getDatabase().getClient(); ++ ++ return (pagedFluxOptions -> { ++ CosmosQueryRequestOptions queryRequestOptions = requestOptions == null ++ ? 
new CosmosQueryRequestOptions() ++ : queryOptionsAccessor().clone(readManyOptionsAccessor().getImpl(requestOptions)); ++ queryRequestOptions.setMaxDegreeOfParallelism(-1); ++ queryRequestOptions.setQueryName("readManyByPartitionKey"); ++ CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); ++ applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeyItemsSpanName); ++ ++ QueryFeedOperationState state = new QueryFeedOperationState( ++ client, ++ this.readManyByPartitionKeyItemsSpanName, ++ database.getId(), ++ this.getId(), ++ ResourceType.Document, ++ OperationType.Query, ++ queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeyItemsSpanName), ++ queryRequestOptions, ++ pagedFluxOptions ++ ); ++ ++ pagedFluxOptions.setFeedOperationState(state); ++ ++ return CosmosBridgeInternal ++ .getAsyncDocumentClient(this.getDatabase()) ++ .readManyByPartitionKey( ++ partitionKeys, ++ customQuery, ++ BridgeInternal.getLink(this), ++ state, ++ classType) ++ .map(response -> prepareFeedResponse(response, false)); ++ }); ++ } ++ + /** + * Reads all the items of a logical partition + * +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java +index 04a6060c192..0bd8be5850c 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java +@@ -540,6 +540,73 @@ public class CosmosContainer { + classType)); + } + ++ /** ++ * Reads many documents matching the provided partition key values. ++ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries ++ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} ++ * as the base query. 
++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param classType class type ++ * @return a {@link CosmosPagedIterable} containing the results ++ */ ++ public CosmosPagedIterable readManyByPartitionKey( ++ List partitionKeys, ++ Class classType) { ++ ++ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, classType)); ++ } ++ ++ /** ++ * Reads many documents matching the provided partition key values. ++ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries ++ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} ++ * as the base query. ++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param requestOptions the optional request options ++ * @param classType class type ++ * @return a {@link CosmosPagedIterable} containing the results ++ */ ++ public CosmosPagedIterable readManyByPartitionKey( ++ List partitionKeys, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) { ++ ++ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, requestOptions, classType)); ++ } ++ ++ /** ++ * Reads many documents matching the provided partition key values with a custom query. ++ * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) ++ * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). ++ * The SDK will automatically append partition key filtering to the custom query. ++ *

++ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, ++ * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be ++ * rejected. ++ *

++ * Partial hierarchical partition keys are supported and will fan out to multiple ++ * physical partitions. ++ * ++ * @param the type parameter ++ * @param partitionKeys list of partition key values to read documents for ++ * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) ++ * @param requestOptions the optional request options ++ * @param classType class type ++ * @return a {@link CosmosPagedIterable} containing the results ++ */ ++ public CosmosPagedIterable readManyByPartitionKey( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) { ++ ++ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, requestOptions, classType)); ++ } ++ + /** + * Reads all the items of a logical partition returning the results as {@link CosmosPagedIterable}. + * +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +index 945e768a82f..8e2499c9039 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +@@ -1584,6 +1584,27 @@ public interface AsyncDocumentClient { + QueryFeedOperationState state, + Class klass); + ++ /** ++ * Reads many documents by partition key values. ++ * Unlike {@link #readMany(List, String, QueryFeedOperationState, Class)} this method does not require ++ * item ids - it queries all documents matching the provided partition key values. ++ * Partial hierarchical partition keys are supported and will fan out to multiple physical partitions. 
++ * ++ * @param partitionKeys list of partition key values to read documents for ++ * @param customQuery optional custom query (for projections/additional filters) - null means SELECT * FROM c ++ * @param collectionLink link for the documentcollection/container to be queried ++ * @param state the query operation state ++ * @param klass class type ++ * @param the type parameter ++ * @return a Flux with feed response pages of documents ++ */ ++ Flux> readManyByPartitionKey( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ String collectionLink, ++ QueryFeedOperationState state, ++ Class klass); ++ + /** + * Read all documents of a certain logical partition. + *

+diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java +index 337055c6947..162b0740f40 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java +@@ -248,6 +248,11 @@ public class Configs { + public static final String MIN_TARGET_BULK_MICRO_BATCH_SIZE_VARIABLE = "COSMOS_MIN_TARGET_BULK_MICRO_BATCH_SIZE"; + public static final int DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE = 1; + ++ // readManyByPartitionKey: max number of PK values per query per physical partition ++ private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; ++ private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; ++ private static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 1000; ++ + public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY = "COSMOS.MAX_BULK_MICRO_BATCH_CONCURRENCY"; + public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY_VARIABLE = "COSMOS_MAX_BULK_MICRO_BATCH_CONCURRENCY"; + public static final int DEFAULT_MAX_BULK_MICRO_BATCH_CONCURRENCY = 1; +@@ -816,6 +821,20 @@ public class Configs { + return DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE; + } + ++ public static int getReadManyByPkMaxBatchSize() { ++ String valueFromSystemProperty = System.getProperty(READ_MANY_BY_PK_MAX_BATCH_SIZE); ++ if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { ++ return Math.max(1, Integer.parseInt(valueFromSystemProperty)); ++ } ++ ++ String valueFromEnvVariable = System.getenv(READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE); ++ if (valueFromEnvVariable != null && !valueFromEnvVariable.isEmpty()) { ++ return Math.max(1, Integer.parseInt(valueFromEnvVariable)); ++ } ++ ++ return DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE; ++ 
} ++ + public static int getMaxBulkMicroBatchConcurrency() { + String valueFromSystemProperty = System.getProperty(MAX_BULK_MICRO_BATCH_CONCURRENCY); + if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +new file mode 100644 +index 00000000000..6d6cd084e01 +--- /dev/null ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +@@ -0,0 +1,263 @@ ++// Copyright (c) Microsoft Corporation. All rights reserved. ++// Licensed under the MIT License. ++package com.azure.cosmos.implementation; ++ ++import com.azure.cosmos.BridgeInternal; ++import com.azure.cosmos.implementation.routing.PartitionKeyInternal; ++import com.azure.cosmos.models.PartitionKey; ++import com.azure.cosmos.models.PartitionKeyDefinition; ++import com.azure.cosmos.models.PartitionKind; ++import com.azure.cosmos.models.SqlParameter; ++import com.azure.cosmos.models.SqlQuerySpec; ++ ++import java.util.ArrayList; ++import java.util.List; ++ ++/** ++ * Helper for constructing SqlQuerySpec instances for readManyByPartitionKey operations. ++ * This class is not intended to be used directly by end-users. ++ */ ++public class ReadManyByPartitionKeyQueryHelper { ++ ++ private static final String DEFAULT_TABLE_ALIAS = "c"; ++ // Internal parameter prefix ΓÇö uses double-underscore to avoid collisions with user-provided parameters ++ private static final String PK_PARAM_PREFIX = "@__rmPk_"; ++ ++ public static SqlQuerySpec createReadManyByPkQuerySpec( ++ String baseQueryText, ++ List baseParameters, ++ List pkValues, ++ List partitionKeySelectors, ++ PartitionKeyDefinition pkDefinition) { ++ ++ // Extract the table alias from the FROM clause (e.g. 
"FROM x" ΓåÆ "x", "FROM c" ΓåÆ "c") ++ String tableAlias = extractTableAlias(baseQueryText); ++ ++ StringBuilder pkFilter = new StringBuilder(); ++ List parameters = new ArrayList<>(baseParameters); ++ int paramCount = 0; ++ ++ boolean isSinglePathPk = partitionKeySelectors.size() == 1; ++ ++ if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { ++ // Single PK path ΓÇö use IN clause for normal values, OR NOT IS_DEFINED for NONE ++ // First, separate NONE PKs from normal PKs ++ boolean hasNone = false; ++ List normalPkValues = new ArrayList<>(); ++ for (PartitionKey pk : pkValues) { ++ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); ++ if (pkInternal.getComponents() == null) { ++ hasNone = true; ++ } else { ++ normalPkValues.add(pk); ++ } ++ } ++ ++ pkFilter.append(" "); ++ boolean hasNormalValues = !normalPkValues.isEmpty(); ++ if (hasNormalValues && hasNone) { ++ pkFilter.append("("); ++ } ++ if (hasNormalValues) { ++ pkFilter.append(tableAlias); ++ pkFilter.append(partitionKeySelectors.get(0)); ++ pkFilter.append(" IN ( "); ++ for (int i = 0; i < normalPkValues.size(); i++) { ++ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(normalPkValues.get(i)); ++ Object[] pkComponents = pkInternal.toObjectArray(); ++ String pkParamName = PK_PARAM_PREFIX + paramCount; ++ parameters.add(new SqlParameter(pkParamName, pkComponents[0])); ++ paramCount++; ++ ++ pkFilter.append(pkParamName); ++ if (i < normalPkValues.size() - 1) { ++ pkFilter.append(", "); ++ } ++ } ++ pkFilter.append(" )"); ++ } ++ if (hasNone) { ++ if (hasNormalValues) { ++ pkFilter.append(" OR "); ++ } ++ pkFilter.append("NOT IS_DEFINED("); ++ pkFilter.append(tableAlias); ++ pkFilter.append(partitionKeySelectors.get(0)); ++ pkFilter.append(")"); ++ } ++ if (hasNormalValues && hasNone) { ++ pkFilter.append(")"); ++ } ++ } else { ++ // Multiple PK paths (HPK) or MULTI_HASH ΓÇö use OR of AND clauses ++ pkFilter.append(" "); ++ for (int i 
= 0; i < pkValues.size(); i++) { ++ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); ++ Object[] pkComponents = pkInternal.toObjectArray(); ++ ++ // PartitionKey.NONE ΓÇö generate NOT IS_DEFINED for all PK paths ++ if (pkComponents == null) { ++ pkFilter.append("("); ++ for (int j = 0; j < partitionKeySelectors.size(); j++) { ++ if (j > 0) { ++ pkFilter.append(" AND "); ++ } ++ pkFilter.append("NOT IS_DEFINED("); ++ pkFilter.append(tableAlias); ++ pkFilter.append(partitionKeySelectors.get(j)); ++ pkFilter.append(")"); ++ } ++ pkFilter.append(")"); ++ } else { ++ pkFilter.append("("); ++ for (int j = 0; j < pkComponents.length; j++) { ++ String pkParamName = PK_PARAM_PREFIX + paramCount; ++ parameters.add(new SqlParameter(pkParamName, pkComponents[j])); ++ paramCount++; ++ ++ if (j > 0) { ++ pkFilter.append(" AND "); ++ } ++ pkFilter.append(tableAlias); ++ pkFilter.append(partitionKeySelectors.get(j)); ++ pkFilter.append(" = "); ++ pkFilter.append(pkParamName); ++ } ++ pkFilter.append(")"); ++ } ++ ++ if (i < pkValues.size() - 1) { ++ pkFilter.append(" OR "); ++ } ++ } ++ } ++ ++ // Compose final query: handle existing WHERE clause in base query ++ String finalQuery; ++ int whereIndex = findTopLevelWhereIndex(baseQueryText); ++ if (whereIndex >= 0) { ++ // Base query has WHERE ΓÇö AND our PK filter ++ String beforeWhere = baseQueryText.substring(0, whereIndex); ++ String afterWhere = baseQueryText.substring(whereIndex + 5); // skip "WHERE" ++ finalQuery = beforeWhere + "WHERE (" + afterWhere.trim() + ") AND (" + pkFilter.toString().trim() + ")"; ++ } else { ++ // No WHERE ΓÇö add one ++ finalQuery = baseQueryText + " WHERE" + pkFilter.toString(); ++ } ++ ++ return new SqlQuerySpec(finalQuery, parameters); ++ } ++ ++ /** ++ * Extracts the table/collection alias from a SQL query's FROM clause. ++ * Handles: "SELECT * FROM c", "SELECT x.id FROM x WHERE ...", "SELECT * FROM root r", etc. 
++ * Returns the alias used after FROM (last token before WHERE or end of FROM clause). ++ */ ++ static String extractTableAlias(String queryText) { ++ String upper = queryText.toUpperCase(); ++ int fromIndex = findTopLevelKeywordIndex(upper, "FROM"); ++ if (fromIndex < 0) { ++ return DEFAULT_TABLE_ALIAS; ++ } ++ ++ // Start scanning after "FROM" ++ int afterFrom = fromIndex + 4; ++ // Skip whitespace ++ while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { ++ afterFrom++; ++ } ++ ++ // Collect the container name token (could be "root", "c", etc.) ++ int tokenStart = afterFrom; ++ while (afterFrom < queryText.length() ++ && !Character.isWhitespace(queryText.charAt(afterFrom)) ++ && queryText.charAt(afterFrom) != '(' ++ && queryText.charAt(afterFrom) != ')') { ++ afterFrom++; ++ } ++ String containerName = queryText.substring(tokenStart, afterFrom); ++ ++ // Skip whitespace after container name ++ while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { ++ afterFrom++; ++ } ++ ++ // Check if there's an alias after the container name (before WHERE or end) ++ if (afterFrom < queryText.length()) { ++ char nextChar = Character.toUpperCase(queryText.charAt(afterFrom)); ++ // If the next token is a keyword (WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING) or end, containerName IS the alias ++ if (nextChar == 'W' || nextChar == 'O' || nextChar == 'G' || nextChar == 'J' ++ || nextChar == 'L' || nextChar == 'H') { ++ // Check if it's actually a keyword ++ String remaining = upper.substring(afterFrom); ++ if (remaining.startsWith("WHERE") || remaining.startsWith("ORDER") ++ || remaining.startsWith("GROUP") || remaining.startsWith("JOIN") ++ || remaining.startsWith("OFFSET") || remaining.startsWith("LIMIT") ++ || remaining.startsWith("HAVING")) { ++ return containerName; ++ } ++ } ++ // Otherwise the next token is the alias ("FROM root r" ΓåÆ alias is "r") ++ int aliasStart = afterFrom; ++ while 
(afterFrom < queryText.length() ++ && !Character.isWhitespace(queryText.charAt(afterFrom)) ++ && queryText.charAt(afterFrom) != '(' ++ && queryText.charAt(afterFrom) != ')') { ++ afterFrom++; ++ } ++ if (afterFrom > aliasStart) { ++ return queryText.substring(aliasStart, afterFrom); ++ } ++ } ++ ++ return containerName; ++ } ++ ++ /** ++ * Finds the index of a top-level SQL keyword in the query text (case-insensitive), ++ * ignoring occurrences inside parentheses or string literals. ++ */ ++ static int findTopLevelKeywordIndex(String queryText, String keyword) { ++ String queryTextUpper = queryText.toUpperCase(); ++ String keywordUpper = keyword.toUpperCase(); ++ int depth = 0; ++ int keyLen = keywordUpper.length(); ++ for (int i = 0; i <= queryTextUpper.length() - keyLen; i++) { ++ char ch = queryTextUpper.charAt(i); ++ // Skip string literals enclosed in single quotes (handle '' escape) ++ if (queryText.charAt(i) == '\'') { ++ i++; ++ while (i < queryText.length()) { ++ if (queryText.charAt(i) == '\'') { ++ if (i + 1 < queryText.length() && queryText.charAt(i + 1) == '\'') { ++ i += 2; // escaped quote ΓÇö skip both ++ continue; ++ } ++ break; // end of string literal ++ } ++ i++; ++ } ++ continue; ++ } ++ if (ch == '(') { ++ depth++; ++ } else if (ch == ')') { ++ depth--; ++ } else if (depth == 0 && ch == keywordUpper.charAt(0) ++ && queryTextUpper.startsWith(keywordUpper, i) ++ && (i == 0 || !Character.isLetterOrDigit(queryTextUpper.charAt(i - 1))) ++ && (i + keyLen >= queryTextUpper.length() || !Character.isLetterOrDigit(queryTextUpper.charAt(i + keyLen)))) { ++ return i; ++ } ++ } ++ return -1; ++ } ++ ++ /** ++ * Finds the index of the top-level WHERE keyword in the query text, ++ * ignoring WHERE that appears inside parentheses (subqueries). 
++ */ ++ public static int findTopLevelWhereIndex(String queryTextUpper) { ++ return findTopLevelKeywordIndex(queryTextUpper, "WHERE"); ++ } ++} +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +index 11121bca033..c70dedaa1f2 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +@@ -4365,6 +4365,298 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization + ); + } + ++ @Override ++ public Flux> readManyByPartitionKey( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ String collectionLink, ++ QueryFeedOperationState state, ++ Class klass) { ++ ++ checkNotNull(partitionKeys, "Argument 'partitionKeys' must not be null."); ++ checkArgument(!partitionKeys.isEmpty(), "Argument 'partitionKeys' must not be empty."); ++ ++ final ScopedDiagnosticsFactory diagnosticsFactory = new ScopedDiagnosticsFactory(this, true); ++ state.registerDiagnosticsFactory( ++ () -> {}, // we never want to reset in readManyByPartitionKey ++ (ctx) -> diagnosticsFactory.merge(ctx) ++ ); ++ ++ StaleResourceRetryPolicy staleResourceRetryPolicy = new StaleResourceRetryPolicy( ++ this.collectionCache, ++ null, ++ collectionLink, ++ queryOptionsAccessor().getProperties(state.getQueryOptions()), ++ queryOptionsAccessor().getHeaders(state.getQueryOptions()), ++ this.sessionContainer, ++ diagnosticsFactory, ++ ResourceType.Document ++ ); ++ ++ return ObservableHelper ++ .fluxInlineIfPossibleAsObs( ++ () -> readManyByPartitionKey( ++ partitionKeys, customQuery, collectionLink, state, diagnosticsFactory, klass), ++ staleResourceRetryPolicy ++ ) ++ .onErrorMap(throwable -> { ++ if (throwable instanceof CosmosException) { ++ CosmosException cosmosException = 
(CosmosException) throwable; ++ CosmosDiagnostics diagnostics = cosmosException.getDiagnostics(); ++ if (diagnostics != null) { ++ state.mergeDiagnosticsContext(); ++ CosmosDiagnosticsContext ctx = state.getDiagnosticsContextSnapshot(); ++ if (ctx != null) { ++ ctxAccessor().recordOperation( ++ ctx, ++ cosmosException.getStatusCode(), ++ cosmosException.getSubStatusCode(), ++ 0, ++ cosmosException.getRequestCharge(), ++ diagnostics, ++ throwable ++ ); ++ diagAccessor() ++ .setDiagnosticsContext( ++ diagnostics, ++ state.getDiagnosticsContextSnapshot()); ++ } ++ } ++ ++ return cosmosException; ++ } ++ ++ return throwable; ++ }); ++ } ++ ++ private Flux> readManyByPartitionKey( ++ List partitionKeys, ++ SqlQuerySpec customQuery, ++ String collectionLink, ++ QueryFeedOperationState state, ++ ScopedDiagnosticsFactory diagnosticsFactory, ++ Class klass) { ++ ++ String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); ++ RxDocumentServiceRequest request = RxDocumentServiceRequest.create(diagnosticsFactory, ++ OperationType.Query, ++ ResourceType.Document, ++ collectionLink, null ++ ); ++ ++ Mono> collectionObs = ++ collectionCache.resolveCollectionAsync(null, request); ++ ++ return collectionObs ++ .flatMapMany(documentCollectionResourceResponse -> { ++ final DocumentCollection collection = documentCollectionResourceResponse.v; ++ if (collection == null) { ++ return Flux.error(new IllegalStateException("Collection cannot be null")); ++ } ++ ++ final PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); ++ ++ Mono> valueHolderMono = partitionKeyRangeCache ++ .tryLookupAsync( ++ BridgeInternal.getMetaDataDiagnosticContext(request.requestContext.cosmosDiagnostics), ++ collection.getResourceId(), ++ null, ++ null); ++ ++ // Validate custom query if provided ++ Mono queryValidationMono; ++ if (customQuery != null) { ++ queryValidationMono = validateCustomQueryForReadManyByPartitionKey( ++ customQuery, resourceLink, 
state.getQueryOptions()); ++ } else { ++ queryValidationMono = Mono.empty(); ++ } ++ ++ return valueHolderMono ++ .delayUntil(ignored -> queryValidationMono) ++ .flatMapMany(routingMapHolder -> { ++ CollectionRoutingMap routingMap = routingMapHolder.v; ++ if (routingMap == null) { ++ return Flux.error(new IllegalStateException("Failed to get routing map.")); ++ } ++ ++ Map> partitionRangePkMap = ++ groupPartitionKeysByPhysicalPartition(partitionKeys, pkDefinition, routingMap); ++ ++ List partitionKeySelectors = createPkSelectors(pkDefinition); ++ ++ String baseQueryText; ++ List baseParameters; ++ if (customQuery != null) { ++ baseQueryText = customQuery.getQueryText(); ++ baseParameters = customQuery.getParameters() != null ++ ? new ArrayList<>(customQuery.getParameters()) ++ : new ArrayList<>(); ++ } else { ++ baseQueryText = "SELECT * FROM c"; ++ baseParameters = new ArrayList<>(); ++ } ++ ++ // Build per-physical-partition batched queries. ++ // Each physical partition may have many PKs ΓÇö split into batches ++ // to avoid oversized SQL queries. Batch size is configurable via ++ // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 1000). ++ int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); ++ ++ // Build batches per partition as a list of lists (one inner list per partition). ++ // Then interleave in round-robin order so that concurrent execution ++ // prefers different physical partitions over multiple batches of the same partition. 
++ List>> batchesPerPartition = new ArrayList<>(); ++ int maxBatchesPerPartition = 0; ++ ++ for (Map.Entry> entry : partitionRangePkMap.entrySet()) { ++ List allPks = entry.getValue(); ++ if (allPks.isEmpty()) { ++ continue; ++ } ++ List> partitionBatches = new ArrayList<>(); ++ for (int i = 0; i < allPks.size(); i += maxPksPerPartitionQuery) { ++ List batch = allPks.subList( ++ i, Math.min(i + maxPksPerPartitionQuery, allPks.size())); ++ SqlQuerySpec querySpec = ReadManyByPartitionKeyQueryHelper ++ .createReadManyByPkQuerySpec( ++ baseQueryText, baseParameters, batch, ++ partitionKeySelectors, pkDefinition); ++ partitionBatches.add( ++ Collections.singletonMap(entry.getKey(), querySpec)); ++ } ++ batchesPerPartition.add(partitionBatches); ++ maxBatchesPerPartition = Math.max(maxBatchesPerPartition, partitionBatches.size()); ++ } ++ ++ if (batchesPerPartition.isEmpty()) { ++ return Flux.empty(); ++ } ++ ++ // Round-robin interleave: [batch0-p1, batch0-p2, ..., batch0-pN, batch1-p1, batch1-p2, ...] ++ // This ensures that with bounded concurrency, different partitions are ++ // preferred over sequential batches of the same partition. ++ List> interleavedBatches = new ArrayList<>(); ++ for (int batchIdx = 0; batchIdx < maxBatchesPerPartition; batchIdx++) { ++ for (List> partitionBatches : batchesPerPartition) { ++ if (batchIdx < partitionBatches.size()) { ++ interleavedBatches.add(partitionBatches.get(batchIdx)); ++ } ++ } ++ } ++ ++ // Execute all batches with bounded concurrency. 
++ List>> queryFluxes = interleavedBatches ++ .stream() ++ .map(batchMap -> queryForReadMany( ++ diagnosticsFactory, ++ resourceLink, ++ new SqlQuerySpec(DUMMY_SQL_QUERY), ++ state.getQueryOptions(), ++ klass, ++ ResourceType.Document, ++ collection, ++ Collections.unmodifiableMap(batchMap))) ++ .collect(Collectors.toList()); ++ ++ int fluxConcurrency = Math.min(queryFluxes.size(), ++ Math.max(Configs.getCPUCnt(), 1)); ++ ++ return Flux.merge(Flux.fromIterable(queryFluxes), fluxConcurrency, 1); ++ }); ++ }); ++ } ++ ++ private Mono validateCustomQueryForReadManyByPartitionKey( ++ SqlQuerySpec customQuery, ++ String resourceLink, ++ CosmosQueryRequestOptions queryRequestOptions) { ++ ++ IDocumentQueryClient queryClient = documentQueryClientImpl( ++ RxDocumentClientImpl.this, getOperationContextAndListenerTuple(queryRequestOptions)); ++ ++ return DocumentQueryExecutionContextFactory ++ .fetchQueryPlanForValidation(this, queryClient, customQuery, resourceLink, queryRequestOptions) ++ .flatMap(queryPlan -> { ++ QueryInfo queryInfo = queryPlan.getQueryInfo(); ++ ++ if (queryInfo.hasGroupBy()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain GROUP BY.")); ++ } ++ if (queryInfo.hasAggregates()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain aggregates.")); ++ } ++ if (queryInfo.hasOrderBy()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain ORDER BY.")); ++ } ++ if (queryInfo.hasDistinct()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain DISTINCT.")); ++ } ++ if (queryInfo.hasDCount()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain DCOUNT.")); ++ } ++ if (queryInfo.hasNonStreamingOrderBy()) { ++ return Mono.error(new 
IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain non-streaming ORDER BY.")); ++ } ++ if (queryPlan.hasHybridSearchQueryInfo()) { ++ return Mono.error(new IllegalArgumentException( ++ "Custom query for readMany by partition key must not contain hybrid/vector/full-text search.")); ++ } ++ ++ return Mono.empty(); ++ }); ++ } ++ ++ private Map> groupPartitionKeysByPhysicalPartition( ++ List partitionKeys, ++ PartitionKeyDefinition pkDefinition, ++ CollectionRoutingMap routingMap) { ++ ++ Map> partitionRangePkMap = new HashMap<>(); ++ ++ for (PartitionKey pk : partitionKeys) { ++ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); ++ ++ // PartitionKey.NONE wraps NonePartitionKey which has components = null. ++ // For routing purposes, treat NONE as UndefinedPartitionKey ΓÇö documents ingested ++ // without a partition key path are stored with the undefined EPK. ++ PartitionKeyInternal effectivePkInternal = pkInternal.getComponents() == null ++ ? 
PartitionKeyInternal.UndefinedPartitionKey ++ : pkInternal; ++ ++ int componentCount = effectivePkInternal.getComponents().size(); ++ int definedPathCount = pkDefinition.getPaths().size(); ++ ++ List targetRanges; ++ ++ if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { ++ // Partial HPK ΓÇö compute EPK prefix range and find all overlapping physical partitions ++ Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( ++ effectivePkInternal, pkDefinition); ++ targetRanges = routingMap.getOverlappingRanges(epkRange); ++ } else { ++ // Full PK ΓÇö maps to exactly one physical partition ++ String effectivePartitionKeyString = PartitionKeyInternalHelper ++ .getEffectivePartitionKeyString(effectivePkInternal, pkDefinition); ++ PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); ++ targetRanges = Collections.singletonList(range); ++ } ++ ++ for (PartitionKeyRange range : targetRanges) { ++ partitionRangePkMap.computeIfAbsent(range, k -> new ArrayList<>()).add(pk); ++ } ++ } ++ ++ return partitionRangePkMap; ++ } ++ + private Map getRangeQueryMap( + Map> partitionRangeItemKeyMap, + PartitionKeyDefinition partitionKeyDefinition) { +diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java +index e62d8ed3d75..d8f9614343c 100644 +--- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java ++++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java +@@ -318,6 +318,17 @@ public class DocumentQueryExecutionContextFactory { + return feedRanges; + } + ++ public static Mono fetchQueryPlanForValidation( ++ DiagnosticsClientContext 
diagnosticsClientContext, ++ IDocumentQueryClient queryClient, ++ SqlQuerySpec sqlQuerySpec, ++ String resourceLink, ++ CosmosQueryRequestOptions queryRequestOptions) { ++ ++ return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( ++ diagnosticsClientContext, queryClient, sqlQuerySpec, resourceLink, queryRequestOptions); ++ } ++ + public static Flux> createDocumentQueryExecutionContextAsync( + DiagnosticsClientContext diagnosticsClientContext, + IDocumentQueryClient client, +diff --git a/sdk/cosmos/cspell.yaml b/sdk/cosmos/cspell.yaml +new file mode 100644 +index 00000000000..94a4002c2c9 +--- /dev/null ++++ b/sdk/cosmos/cspell.yaml +@@ -0,0 +1,6 @@ ++import: ++ - ../../.vscode/cspell.json ++overrides: ++ - filename: "**/sdk/cosmos/*" ++ words: ++ - DCOUNT +diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md +new file mode 100644 +index 00000000000..95d7624f0c8 +--- /dev/null ++++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md +@@ -0,0 +1,169 @@ ++# readManyByPartitionKey ΓÇö Design & Implementation ++ ++## Overview ++ ++New `readManyByPartitionKey` methods on `CosmosAsyncContainer` / `CosmosContainer` that accept a ++`List` (without item-id). The SDK splits the PK values by physical ++partition, generates batched streaming queries per physical partition, and returns results as ++`CosmosPagedFlux` / `CosmosPagedIterable`. ++ ++An optional `SqlQuerySpec` parameter lets callers supply a custom query for projections ++and additional filters. The SDK appends the auto-generated PK WHERE clause to it. 
++ ++## Decisions ++ ++| Topic | Decision | ++|---|---| ++| API name | `readManyByPartitionKey` — distinct name to avoid ambiguity with existing `readMany(List)` | ++| Return type | `CosmosPagedFlux` (async) / `CosmosPagedIterable` (sync) | ++| Custom query format | `SqlQuerySpec` — full query with parameters; SDK ANDs the PK filter | ++| Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | ++| PK deduplication | Done at Spark layer only, not in the SDK | ++| Spark UDF | New `GetCosmosPartitionKeyValue` UDF | ++| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/non-streaming ORDER BY/vector/fulltext | ++| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 1000 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | ++| Eager validation | Null and empty PK list rejected eagerly (not lazily in reactive chain) | ++| Telemetry | Separate span name `readManyByPartitionKeyItems.` (distinct from existing `readManyItems`) | ++| Query construction | Table alias auto-detected from FROM clause; string literals and subqueries handled correctly | ++ ++## Phase 1 — SDK Core (`azure-cosmos`) ++ ++### Step 1: New public overloads in CosmosAsyncContainer ++ ++```java ++ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, Class classType) ++ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) ++ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, ++ SqlQuerySpec customQuery, ++ CosmosReadManyRequestOptions requestOptions, ++ Class classType) ++``` ++ ++All delegate to a private `readManyByPartitionKeyInternalFunc(...)`. ++ ++**Eager validation:** The 4-arg method validates `partitionKeys` is non-null and non-empty before constructing the reactive pipeline, throwing `IllegalArgumentException` synchronously.
++ ++### Step 2: Sync wrappers in CosmosContainer ++ ++Same signatures returning `CosmosPagedIterable`, delegating to the async container. ++ ++### Step 3: Internal orchestration (RxDocumentClientImpl) ++ ++1. Resolve collection metadata + PK definition from cache. ++2. Fetch routing map from `partitionKeyRangeCache`, then run custom query validation (Step 4) once the routing map has resolved (chained via `Mono.delayUntil`). ++3. For each `PartitionKey`: ++ - Compute effective partition key (EPK). ++ - Full PK → `getRangeByEffectivePartitionKey()` (single range). ++ - Partial HPK → compute EPK prefix range → `getOverlappingRanges()` (multiple ranges). ++ **Note:** partial HPK intentionally fans out to multiple physical partitions. ++4. Group PK values by `PartitionKeyRange`. ++5. Per physical partition → split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 1000). ++6. Per batch → build `SqlQuerySpec` with PK WHERE clause (Step 5). ++7. Interleave batches across physical partitions in round-robin order so that bounded concurrency prefers different physical partitions over sequential batches of the same partition. ++8. Execute queries via `queryForReadMany()` with bounded concurrency (`Math.min(batchCount, cpuCount)`). ++9. Return results as `CosmosPagedFlux`. ++ ++### Step 4: Custom query validation ++ ++One-time call per invocation (existing query plan caching applies). Runs after the routing map lookup resolves and before any per-partition query is built: ++ ++- `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query.
++- Reject (`IllegalArgumentException`) if: ++ - `queryInfo.hasGroupBy()` — checked first (takes precedence over aggregates since `hasAggregates()` also returns true for GROUP BY queries) ++ - `queryInfo.hasAggregates()` ++ - `queryInfo.hasOrderBy()` ++ - `queryInfo.hasDistinct()` ++ - `queryInfo.hasDCount()` ++ - `queryInfo.hasNonStreamingOrderBy()` ++ - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` ++ ++### Step 5: Query construction ++ ++Query construction is implemented in `ReadManyByPartitionKeyQueryHelper`. The helper: ++- Extracts the table alias from the FROM clause (handles `FROM c`, `FROM root r`, `FROM x WHERE ...`) ++- Handles string literals in queries (parens/keywords inside `'...'` are correctly skipped) ++- Recognizes SQL keywords: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING ++- Uses parameterized queries (`@__rmPk_` prefix) to prevent SQL injection ++ ++**Single PK (HASH):** ++```sql ++{baseQuery} WHERE {alias}["{pkPath}"] IN (@__rmPk_0, @__rmPk_1, @__rmPk_2) ++``` ++ ++**Full HPK (MULTI_HASH):** ++```sql ++{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0 AND {alias}["{path2}"] = @__rmPk_1) ++ OR ({alias}["{path1}"] = @__rmPk_2 AND {alias}["{path2}"] = @__rmPk_3) ++``` ++ ++**Partial HPK (prefix-only):** ++```sql ++{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0) ++ OR ({alias}["{path1}"] = @__rmPk_1) ++``` ++ ++If the base query already has a WHERE clause: ++```sql ++{selectAndFrom} WHERE ({existingWhere}) AND ({pkFilter}) ++``` ++ ++### Step 6: Interface wiring ++ ++New method `readManyByPartitionKey` added directly to `AsyncDocumentClient` interface, implemented in `RxDocumentClientImpl`. New `fetchQueryPlanForValidation` static method added to `DocumentQueryExecutionContextFactory` for custom query validation.
++ ++### Step 7: Configuration ++ ++New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 1000, minimum: 1). Follows existing `Configs` patterns. ++ ++## Phase 2 — Spark Connector (`azure-cosmos-spark_3`) ++ ++### Step 8: New UDF — `GetCosmosPartitionKeyValue` ++ ++- Input: partition key value (single value or Seq for hierarchical PKs). ++- Output: serialized PK string in format `pk([...json...])`. ++- **Null handling:** Throws on null input (Scala convention; callers should filter nulls upstream). ++ ++### Step 9: PK-only serialization helper ++ ++`CosmosPartitionKeyHelper`: ++- `getCosmosPartitionKeyValueString(pkValues: List[Object]): String` — serialize to `pk([...])` format. ++- `tryParsePartitionKey(serialized: String): Option[PartitionKey]` — deserialize; returns `None` for malformed input including invalid JSON (wrapped in `scala.util.Try`). ++ ++### Step 10: `CosmosItemsDataSource.readManyByPartitionKey` ++ ++Static entry points that accept a DataFrame and Cosmos config. PK extraction supports two modes: ++1. **UDF-produced column**: DataFrame contains `_partitionKeyIdentity` column (from `GetCosmosPartitionKeyValue` UDF). ++2. **Schema-matched columns**: DataFrame columns match the container's PK paths. ++ ++Falls back with `IllegalArgumentException` if neither mode is possible. ++ ++### Step 11: `CosmosReadManyByPartitionKeyReader` ++ ++Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. ++ ++### Step 12: `ItemsPartitionReaderWithReadManyByPartitionKey` ++ ++Spark `PartitionReader[InternalRow]` that: ++- Deduplicates PKs via `LinkedHashMap` (by PK string representation). ++- Passes the pre-built `CosmosReadManyRequestOptions` (with throughput control, diagnostics, custom serializer) to the SDK.
++- Uses `TransientIOErrorsRetryingIterator` for retry handling. ++- Short-circuits empty PK lists to avoid SDK rejection. ++ ++## Phase 3 — Testing ++ ++### Unit tests ++- Query construction: single PK, HPK full/partial, custom query composition, table alias detection. ++- Query plan rejection: aggregates, ORDER BY, DISTINCT, GROUP BY (with and without aggregates), DCOUNT. ++- String literal handling: WHERE/parentheses inside string constants. ++- Keyword detection: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING. ++- PK serialization/deserialization roundtrip (including malformed JSON handling). ++- `findTopLevelWhereIndex` edge cases: subqueries, string literals, case insensitivity. ++ ++### Integration tests ++- End-to-end SDK: single PK basic, projections, filters, empty results, HPK full/partial, request options propagation. ++- Batch size validation: temporarily lowered batch size to exercise batching/interleaving logic. ++- Null/empty PK list rejection (eager validation). ++- Spark connector: `ItemsPartitionReaderWithReadManyByPartitionKey` with known PK values and non-existent PKs. ++- `CosmosPartitionKeyHelper`: single/HPK roundtrip, case insensitivity, malformed input.
+ diff --git a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt new file mode 100644 index 000000000000..360eb15accb5 --- /dev/null +++ b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt @@ -0,0 +1,76 @@ +===== PR #LOCAL-users_fabianm_readManyByPK ===== +Title: Branch comparison users/fabianm/readManyByPK vs origin/main +Author: Fabian Meiswinkel +Status: DIVERGED (ahead 30, behind 4) +Branch: users/fabianm/readManyByPK -> origin/main +Head SHA: 93957f3a8442d730fe67fbc379ef5399f46f5665 +Merge Base: 20313f79ba8dd0dfa97862d0c31dd4b2e44ee671 +URL: N/A (local branch comparison) + +--- Description --- +Adds readManyByPartitionKey API (sync+async) and Spark connector support for PK-only reads, with query-plan-based validation +--- End Description --- + +===== Commits in PR ===== +9770833eb59 Adding readManyByPartitionKey API +ac287bcdf00 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +9a5b3e96e7e Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +a8720c3c9f2 Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +d499da76fb4 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +c3c542a33a7 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +4416354e03e ┬┤Fixing code review comments +3ab3f0d64f5 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +588a7550c54 Update CosmosAsyncContainer.java +8c5cdb47b31 Merge branch 'main' into users/fabianm/readManyByPK +f5485527a9c Update ReadManyByPartitionKeyTest.java +f68cf02ff71 Fixing test issues +8b6c4b168ea Update CosmosAsyncContainer.java 
+8ba7f4db2da Merge branch 'main' into users/fabianm/readManyByPK +56b067a9339 Reacted to code review feedback +fa430e918fa Merge branch 'main' into users/fabianm/readManyByPK +d9504c91f34 Fix build issues +73151f09e5f Merge branch 'main' into users/fabianm/readManyByPK +681830e2d4a Fixing changelog +7f745e60641 Merge branch 'main' into users/fabianm/readManyByPK +0b8905dbb01 Addressing code review comments +22abc780ed8 Addressing code review feedback +662b1a4b90e Update CosmosItemsDataSource.scala +c764de9de02 Update CosmosItemsDataSource.scala +e1e6f5a6f73 Merge branch 'main' into users/fabianm/readManyByPK +080ce4a2293 Update RxDocumentClientImpl.java +516bbf3a95a Merge branch 'users/fabianm/readManyByPK' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK +b01f8758eea Fix readManyByPartitionKey retries +7130d4aa35a Fix PK.None +93957f3a844 Update ReadManyByPartitionKeyQueryHelper.java + +===== Files Changed ===== + sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala (+20 -2) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala (+1 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala (+125 -1) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala (+45 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala (+150 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala (+249 -0) + 
sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala (+259 -0) + sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala (+25 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala (+42 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala (+104 -0) + sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala (+158 -0) + sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java (+462 -0) + sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java (+426 -0) + sdk/cosmos/azure-cosmos/CHANGELOG.md (+1 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java (+126 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java (+67 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java (+21 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java (+19 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java (+263 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java (+292 -0) + sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java (+11 -0) + sdk/cosmos/cspell.yaml (+6 -0) + sdk/cosmos/docs/readManyByPartitionKey-design.md (+169 -0) + + diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index fe114462019c..2b9b9cda9361 100644 --- 
a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added -* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. Configure null handling via `spark.cosmos.read.readManyByPk.nullHandling` - default `Null` treats a null PK column as JSON null (`addNullValue`), `None` treats it as `PartitionKey.NONE` (`addNoneValue` / `NOT IS_DEFINED`). These route to different physical partitions - picking the wrong mode silently returns zero rows. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index 3b2c7ce36db1..771f5974be70 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added -* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. Configure null handling via `spark.cosmos.read.readManyByPk.nullHandling` - default `Null` treats a null PK column as JSON null (`addNullValue`), `None` treats it as `PartitionKey.NONE` (`addNoneValue` / `NOT IS_DEFINED`). These route to different physical partitions - picking the wrong mode silently returns zero rows. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 2240a48b1654..021d2e66929d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added -* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. Configure null handling via `spark.cosmos.read.readManyByPk.nullHandling` - default `Null` treats a null PK column as JSON null (`addNullValue`), `None` treats it as `PartitionKey.NONE` (`addNoneValue` / `NOT IS_DEFINED`). These route to different physical partitions - picking the wrong mode silently returns zero rows. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md index 20a3e3a61bd7..f512a65dc491 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added -* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
Configure null handling via `spark.cosmos.read.readManyByPk.nullHandling` - default `Null` treats a null PK column as JSON null (`addNullValue`), `None` treats it as `PartitionKey.NONE` (`addNoneValue` / `NOT IS_DEFINED`). These route to different physical partitions - picking the wrong mode silently returns zero rows. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index e1b8f0b51f8a..4a483ef38a6b 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -1144,11 +1144,12 @@ private object CosmosReadConfig { mandatory = false, defaultValue = Some("Null"), parseFromStringFunction = value => value, - helpMessage = "Determines how null values in hierarchical partition key components are treated " + - "for readManyByPartitionKey. 'Null' (default) maps null to a JSON null value via addNullValue(), " + - "which is appropriate when the document field exists with an explicit null value. " + - "'None' maps null to PartitionKey.NONE via addNoneValue(), which should only be used when the " + - "partition key path does not exist at all in the document." + helpMessage = "Determines how null values in partition key columns are treated for " + + "readManyByPartitionKey. 'Null' (default) maps null to a JSON null via addNullValue(), which " + + "is appropriate when the document field exists with an explicit null value. 'None' maps null " + + "to PartitionKey.NONE via addNoneValue(), which should only be used when the partition key " + + "path does not exist at all in the document. 
These two semantics hash to DIFFERENT physical " + + "partitions - picking the wrong mode for your data will silently return zero rows." ) def parseCosmosReadConfig(cfg: Map[String, String]): CosmosReadConfig = { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index 86ef865bcb83..82e3158d765d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -5,7 +5,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey, PartitionKeyBuilder} import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -import com.azure.cosmos.{SparkBridgeInternal} +import com.azure.cosmos.SparkBridgeInternal import org.apache.spark.sql.{DataFrame, Row, SparkSession} import java.util @@ -184,6 +184,15 @@ object CosmosItemsDataSource { pkDefinition.getPaths.asScala.map(_.stripPrefix("/")).toList }) + // Nested PK paths (containing /) cannot be resolved from top-level DataFrame columns. + // Surface an explicit error so users know to use the UDF-produced _partitionKeyIdentity column. + if (pkPaths.exists(_.contains("/"))) { + throw new IllegalArgumentException( + "Container has nested partition key path(s) " + pkPaths.mkString("[", ",", "]") + ". 
" + + "Nested paths cannot be resolved from DataFrame columns automatically - add a " + + "'_partitionKeyIdentity' column produced by the GetCosmosPartitionKeyValue UDF.") + } + // Check if ALL PK path columns exist in the DataFrame schema val dfFieldNames = df.schema.fieldNames.toSet val allPkColumnsPresent = pkPaths.forall(path => dfFieldNames.contains(path)) @@ -195,7 +204,7 @@ object CosmosItemsDataSource { // Single partition key buildPartitionKey(row.getAs[Any](pkPaths.head), treatNullAsNone) } else { - // Hierarchical partition key — build level by level + // Hierarchical partition key - build level by level val builder = new PartitionKeyBuilder() for (path <- pkPaths) { addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone) @@ -227,7 +236,15 @@ object CosmosItemsDataSource { case null => if (treatNullAsNone) builder.addNoneValue() else builder.addNullValue() - case other => builder.add(other.toString) + case other => + // Reject unknown types rather than silently .toString-ing them - the document field + // was stored with its original type and a stringified value will never match. + // Supported types: String, Number (Byte/Short/Int/Long/Float/Double/BigDecimal), Boolean, null. + throw new IllegalArgumentException( + s"Unsupported partition key column type '${other.getClass.getName}' with value '$other'. " + + "Supported types are String, Number (integral or floating-point), Boolean, and null. 
" + + "For other source types, convert the column before calling readManyByPartitionKey or use " + + "the GetCosmosPartitionKeyValue UDF to produce a '_partitionKeyIdentity' column.") } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala index 91f3a56bc664..207f77eb4766 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -71,7 +71,12 @@ private[spark] class CosmosReadManyByPartitionKeyReader( classOf[ObjectNode]) .block() } catch { - case _: CosmosException => None + // The warm-up readItem is only used to hydrate the collection/routing-map caches. + // A 404 (item not found) is expected, but we log other CosmosExceptions at debug to + // aid diagnosis (auth failures, throttling, etc.) while not failing reader setup. 
+ case ex: CosmosException => + logDebug(s"Warm-up readItem for metadata caches completed with exception: ${ex.getMessage}", ex) + None } val state = new CosmosClientMetadataCachesSnapshot() diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index c67cc9c10be1..7fb6e0eb0aba 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -6,6 +6,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.{CosmosAsyncContainer, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosItemSerializerNoExceptionWrapping, SparkBridgeInternal} import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, ObjectNodeMap, SparkRowItem, Utils} +import com.azure.cosmos.BridgeInternal import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition, SqlQuerySpec} import com.azure.cosmos.spark.BulkWriter.getThreadInfo import com.azure.cosmos.spark.CosmosTableSchemaInferrer.IdAttributeName @@ -154,13 +155,17 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey } ) - // Collect all PK values upfront — readManyByPartitionKey needs the full list to - // group by physical partition and issue parallel queries. - // Deduplicate by PK string representation — safe because the list size is bounded - // by the per-call limit of the readManyByPartitionKey API. 
+ // Collect all PK values upfront - readManyByPartitionKey needs the full list to + // group by physical partition (the SDK batches internally per physical partition). + // Deduplicate using the canonical PartitionKeyInternal JSON representation so that + // equivalent PKs built from different runtime types (Int vs Long vs Double) are + // collapsed, and distinct PKs that happen to toString() identically are not. private lazy val pkList = { val seen = new java.util.LinkedHashMap[String, PartitionKey]() - readManyPartitionKeys.foreach(pk => seen.putIfAbsent(pk.toString, pk)) + readManyPartitionKeys.foreach(pk => { + val key = BridgeInternal.getPartitionKeyInternal(pk).toJson + seen.putIfAbsent(key, pk) + }) new java.util.ArrayList[PartitionKey](seen.values()) } @@ -188,9 +193,10 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey override def close(): Unit = {} } - // Batch partition keys and retry each batch independently on transient I/O errors. - // This avoids the continuation-token problem with TransientIOErrorsRetryingIterator - // where a retry would re-read all data from scratch, causing silent data duplication. + // Pass the full PK list to the SDK (which batches per physical partition internally). + // On transient I/O failures the retry iterator tracks pages already emitted upstream + // and skips them on replay; if a failure occurs mid-page (after items from that page + // have been emitted) the task fails rather than risking row duplication. 
private lazy val iterator: CloseableSparkRowItemIterator = if (pkList.isEmpty) { EmptySparkRowItemIterator diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala index dcfdf4f93536..f82855218cce 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala @@ -16,16 +16,12 @@ import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.Random import scala.util.control.Breaks -// scalastyle:off underscore.import -import scala.collection.JavaConverters._ -// scalastyle:on underscore.import - /** - * Retry-safe iterator for readManyByPartitionKey that batches partition keys and lazily - * iterates pages within each batch via CosmosPagedIterable — consistent with how - * TransientIOErrorsRetryingIterator handles normal queries. On transient I/O errors the - * current batch's flux is recreated and pages already consumed are replayed, avoiding - * the memory overhead of collectList and matching the query iterator's structure. + * Retry-safe iterator for readManyByPartitionKey. The full partition-key list is passed to the + * SDK in a single call - the SDK is responsible for fan-out and per-physical-partition batching + * (see Configs.getReadManyByPkMaxBatchSize()). This iterator therefore wraps a single + * CosmosPagedIterable and, on transient I/O failures, re-creates the underlying flux and + * skips the pages that were already emitted upstream so no row is delivered twice. 
*/ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow] ( @@ -57,13 +53,20 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp case None => "n/a" } + // Number of pages that have been fully emitted upstream. On retry, we recreate the flux + // and skip this many pages before emitting any item, so already-delivered rows are not + // re-emitted. A page is counted only once it was empty or has been fully drained - a + // transient failure while a page is only partially consumed cannot be recovered from + // without risking duplication, so we fail fast in that case. + private var pagesCommitted: Long = 0 + // Whether the currently-buffered page has emitted at least one item. If true, we have + // passed the point of no return for this page: any transient failure here must surface, + // because we cannot partially-skip within a page on retry. + private var currentPagePartiallyConsumed: Boolean = false + private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None - private val pkBatchIterator = partitionKeys.asScala.iterator.grouped(pageSize) - // Track the current batch so we can replay it on retry - private var currentBatch: Option[java.util.List[PartitionKey]] = None - override def hasNext: Boolean = { executeWithRetry("hasNextInternal", () => hasNextInternal) } @@ -85,38 +88,39 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp val feedResponseIterator = currentFeedResponseIterator match { case Some(existing) => existing case None => - // Need a new feed response iterator — either for the current batch (on retry) - // or for the next batch - val batch = currentBatch match { - case Some(b) => b // retry of current batch - case None => - if (pkBatchIterator.hasNext) { - val nextBatch = new
java.util.ArrayList[PartitionKey](pkBatchIterator.next().toList.asJava) - currentBatch = Some(nextBatch) - nextBatch - } else { - return Some(false) // no more batches - } - } - val pagedFlux = customQuery match { case Some(query) => - container.readManyByPartitionKey(batch, query, queryOptions, classType) + container.readManyByPartitionKey(partitionKeys, query, queryOptions, classType) case None => - container.readManyByPartitionKey(batch, queryOptions, classType) + container.readManyByPartitionKey(partitionKeys, queryOptions, classType) } - currentFeedResponseIterator = Some( - new CosmosPagedIterable[TSparkRow]( - pagedFlux, - pageSize, - pagePrefetchBufferSize - ) - .iterableByPage() - .iterator - .asScala - .buffered + val rawIterator = new CosmosPagedIterable[TSparkRow]( + pagedFlux, + pageSize, + pagePrefetchBufferSize ) + .iterableByPage() + .iterator + + // Skip pages already emitted upstream (replay-safe retry). + var skipped: Long = 0 + while (skipped < pagesCommitted && rawIterator.hasNext) { + rawIterator.next() + skipped += 1 + } + if (skipped < pagesCommitted) { + // The server returned fewer pages than before - cannot safely replay. + // Surface a clean error rather than silently emitting a truncated result. + throw new IllegalStateException( + s"readManyByPartitionKey retry replay failed: expected to skip $pagesCommitted " + + s"already-emitted pages but only $skipped were available. 
Context: $operationContextString") + } + + // scalastyle:off underscore.import + import scala.collection.JavaConverters._ + // scalastyle:on underscore.import + currentFeedResponseIterator = Some(rawIterator.asScala.buffered) currentFeedResponseIterator.get } @@ -152,20 +156,24 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp operationContextAndListener.get.getOperationContext, feedResponse) } + // scalastyle:off underscore.import + import scala.collection.JavaConverters._ + // scalastyle:on underscore.import val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered if (iteratorCandidate.hasNext) { currentItemIterator = Some(iteratorCandidate) + currentPagePartiallyConsumed = false Some(true) } else { - // empty page interleaved — try again + // empty page - count it as committed (no items to replay) and try again + pagesCommitted += 1 None } } else { - // Current batch's flux is exhausted — move to next batch - currentBatch = None + // Flux exhausted currentFeedResponseIterator = None - None + Some(false) } } } @@ -175,6 +183,9 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp case Some(iterator) => if (iterator.hasNext) { true } else { + // Entire page drained -> it is now committed for replay-skipping purposes. 
+ pagesCommitted += 1 + currentPagePartiallyConsumed = false currentItemIterator = None false } @@ -183,11 +194,15 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp } override def next(): TSparkRow = { - currentItemIterator.get.next() + executeWithRetry("next", () => { + val value = currentItemIterator.get.next() + currentPagePartiallyConsumed = true + value + }) } - override def head(): TSparkRow = { - currentItemIterator.get.head + override def head: TSparkRow = { + executeWithRetry("head", () => currentItemIterator.get.head) } private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = { @@ -206,6 +221,17 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp catch { case cosmosException: CosmosException => if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) { + if (currentPagePartiallyConsumed) { + // We have already emitted items from the current page upstream. Replaying + // the flux would re-skip only completed pages, not items within a page - + // which would cause silent duplication. Fail the task instead. + logError( + s"Transient failure in TransientIOErrorsRetryingReadManyByPartitionKeyIterator." 
+ + s"$methodName after items from the current page were already emitted - " + + s"cannot safely retry without duplicating rows.", + cosmosException) + throw cosmosException + } val retryCountSnapshot = retryCount.incrementAndGet() if (retryCountSnapshot > maxRetryCount) { logError( @@ -217,7 +243,8 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp logWarning( s"Transient failure handled in " + s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName -" + - s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms", + s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms " + + s"(pagesCommitted=$pagesCommitted)", cosmosException) } } else { @@ -226,7 +253,7 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp case other: Throwable => throw other } - // Reset iterators but keep currentBatch so the batch is replayed + // Reset iterators; pagesCommitted is intentionally preserved so replay can skip them. currentItemIterator = None currentFeedResponseIterator = None Thread.sleep(retryIntervalInMs) @@ -236,6 +263,10 @@ private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSp returnValue.get } + // Clean up iterator references - the underlying Reactor subscription from + // CosmosPagedIterable.iterator will be cleaned up when the iterator is GC'd. + // This matches the behavior of TransientIOErrorsRetryingIterator; any still-prefetched + // pages are discarded with the iterator. 
override def close(): Unit = { currentItemIterator = None currentFeedResponseIterator = None diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala index a58d5b723b8b..fc038861a19a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala @@ -4,22 +4,25 @@ package com.azure.cosmos.spark.udf import com.azure.cosmos.spark.CosmosPartitionKeyHelper -import com.azure.cosmos.spark.CosmosPredicates.requireNotNull import org.apache.spark.sql.api.java.UDF1 @SerialVersionUID(1L) class GetCosmosPartitionKeyValue extends UDF1[Object, String] { - override def call - ( - partitionKeyValue: Object - ): String = { - requireNotNull(partitionKeyValue, "partitionKeyValue") - + // Null is a valid partition-key value (JSON null). A null input is serialized as a + // single-level partition key with a JSON null component; parsing that string back via + // CosmosPartitionKeyHelper.tryParsePartitionKey yields a PartitionKey built with + // addNullValue(). If the caller instead wants PartitionKey.NONE semantics (absent PK + // field) they should filter the null row before calling this UDF and use the + // schema-matched readManyByPartitionKey path with readManyByPk.nullHandling=None. 
+ override def call(partitionKeyValue: Object): String = { partitionKeyValue match { + case null => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(null)) // for subpartitions case - Seq covers both WrappedArray (Scala 2.12) and ArraySeq (Scala 2.13) case seq: Seq[Any] => CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(seq.map(_.asInstanceOf[Object]).toList) - case _ => CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) + case _ => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) } } -} +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md index d8368be6a0da..243863e8de65 100644 --- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.47.0-beta.1 (Unreleased) #### Features Added -* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) +* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. Configure null handling via `spark.cosmos.read.readManyByPk.nullHandling` - default `Null` treats a null PK column as JSON null (`addNullValue`), `None` treats it as `PartitionKey.NONE` (`addNoneValue` / `NOT IS_DEFINED`). These route to different physical partitions - picking the wrong mode silently returns zero rows. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java index 162b0740f408..651019d91652 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java @@ -58,7 +58,7 @@ public class Configs { private static final String NETTY_HTTP_CLIENT_METRICS_ENABLED = "COSMOS.NETTY_HTTP_CLIENT_METRICS_ENABLED"; private static final String NETTY_HTTP_CLIENT_METRICS_ENABLED_VARIABLE = "COSMOS_NETTY_HTTP_CLIENT_METRICS_ENABLED"; - // Thin client connect/acquire timeout — controls CONNECT_TIMEOUT_MILLIS for Gateway V2 data plane endpoints. + // Thin client connect/acquire timeout - controls CONNECT_TIMEOUT_MILLIS for Gateway V2 data plane endpoints. // Data plane requests are routed to the thin client regional endpoint (from RegionalRoutingContext) // which uses a non-443 port. These get a shorter 5s connect/acquire timeout. // Metadata requests target Gateway V1 endpoint (port 443) and retain the full 45s/60s timeout (unchanged). 
@@ -249,9 +249,9 @@ public class Configs { public static final int DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE = 1; // readManyByPartitionKey: max number of PK values per query per physical partition - private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; - private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; - private static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 1000; + public static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; + public static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; + public static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 100; public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY = "COSMOS.MAX_BULK_MICRO_BATCH_CONCURRENCY"; public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY_VARIABLE = "COSMOS_MAX_BULK_MICRO_BATCH_CONCURRENCY"; @@ -683,7 +683,7 @@ public static int getThinClientConnectionTimeoutInMs() { } } - // Guard against invalid values — timeout must be at least 500ms + // Guard against invalid values - timeout must be at least 500ms if (value < 500) { logger.warn( "Invalid thin client connection timeout: {}ms. Must be >= 500. 
Falling back to default: {}ms.", diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index 6d6cd084e01a..24ac8ea19666 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -20,7 +20,7 @@ public class ReadManyByPartitionKeyQueryHelper { private static final String DEFAULT_TABLE_ALIAS = "c"; - // Internal parameter prefix — uses double-underscore to avoid collisions with user-provided parameters + // Internal parameter prefix - uses double-underscore to avoid collisions with user-provided parameters private static final String PK_PARAM_PREFIX = "@__rmPk_"; public static SqlQuerySpec createReadManyByPkQuerySpec( @@ -30,7 +30,19 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( List partitionKeySelectors, PartitionKeyDefinition pkDefinition) { - // Extract the table alias from the FROM clause (e.g. "FROM x" → "x", "FROM c" → "c") + // Guard against collisions with our internal parameter names - callers cannot realistically + // use the @__rmPk_ prefix for their own parameters, but if they do we surface a clear error + // rather than letting the server reject a SqlQuerySpec with duplicate parameter names. + for (SqlParameter baseParam : baseParameters) { + String name = baseParam.getName(); + if (name != null && name.startsWith(PK_PARAM_PREFIX)) { + throw new IllegalArgumentException( + "Custom query parameter name '" + name + "' collides with the reserved " + + "readManyByPartitionKey internal prefix '" + PK_PARAM_PREFIX + "'. Rename the parameter."); + } + } + + // Extract the table alias from the FROM clause (e.g. 
"FROM x" -> "x", "FROM c" -> "c") String tableAlias = extractTableAlias(baseQueryText); StringBuilder pkFilter = new StringBuilder(); @@ -40,7 +52,7 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( boolean isSinglePathPk = partitionKeySelectors.size() == 1; if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { - // Single PK path — use IN clause for normal values, OR NOT IS_DEFINED for NONE + // Single PK path - use IN clause for normal values, OR NOT IS_DEFINED for NONE // First, separate NONE PKs from normal PKs boolean hasNone = false; List normalPkValues = new ArrayList<>(); @@ -89,13 +101,13 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( pkFilter.append(")"); } } else { - // Multiple PK paths (HPK) or MULTI_HASH — use OR of AND clauses + // Multiple PK paths (HPK) or MULTI_HASH - use OR of AND clauses pkFilter.append(" "); for (int i = 0; i < pkValues.size(); i++) { PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); Object[] pkComponents = pkInternal.toObjectArray(); - // PartitionKey.NONE — generate NOT IS_DEFINED for all PK paths + // PartitionKey.NONE - generate NOT IS_DEFINED for all PK paths if (pkComponents == null) { pkFilter.append("("); for (int j = 0; j < partitionKeySelectors.size(); j++) { @@ -136,12 +148,12 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( String finalQuery; int whereIndex = findTopLevelWhereIndex(baseQueryText); if (whereIndex >= 0) { - // Base query has WHERE — AND our PK filter + // Base query has WHERE - AND our PK filter String beforeWhere = baseQueryText.substring(0, whereIndex); String afterWhere = baseQueryText.substring(whereIndex + 5); // skip "WHERE" finalQuery = beforeWhere + "WHERE (" + afterWhere.trim() + ") AND (" + pkFilter.toString().trim() + ")"; } else { - // No WHERE — add one + // No WHERE - add one finalQuery = baseQueryText + " WHERE" + pkFilter.toString(); } @@ -197,7 +209,7 @@ static String extractTableAlias(String 
queryText) { return containerName; } } - // Otherwise the next token is the alias ("FROM root r" → alias is "r") + // Otherwise the next token is the alias ("FROM root r" -> alias is "r") int aliasStart = afterFrom; while (afterFrom < queryText.length() && !Character.isWhitespace(queryText.charAt(afterFrom)) @@ -230,7 +242,7 @@ static int findTopLevelKeywordIndex(String queryText, String keyword) { while (i < queryText.length()) { if (queryText.charAt(i) == '\'') { if (i + 1 < queryText.length() && queryText.charAt(i + 1) == '\'') { - i += 2; // escaped quote — skip both + i += 2; // escaped quote - skip both continue; } break; // end of string literal diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index c70dedaa1f2c..3d9082a25bcc 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4499,7 +4499,7 @@ private Flux> readManyByPartitionKey( } // Build per-physical-partition batched queries. - // Each physical partition may have many PKs — split into batches + // Each physical partition may have many PKs - split into batches // to avoid oversized SQL queries. Batch size is configurable via // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 1000). int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); @@ -4534,9 +4534,11 @@ private Flux> readManyByPartitionKey( return Flux.empty(); } - // Round-robin interleave: [batch0-p1, batch0-p2, ..., batch0-pN, batch1-p1, batch1-p2, ...] - // This ensures that with bounded concurrency, different partitions are - // preferred over sequential batches of the same partition. 
+ // Interleave batches across physical partitions so that the first batch for + // each partition is kicked off before the second batch of any partition. With bounded + // concurrency this spreads the initial wave of work across the cluster; note that + // skewed distributions (one partition with N batches, another with 1) will eventually + // fall back to sequential execution once the short partitions are drained. List> interleavedBatches = new ArrayList<>(); for (int batchIdx = 0; batchIdx < maxBatchesPerPartition; batchIdx++) { for (List> partitionBatches : batchesPerPartition) { @@ -4625,7 +4627,7 @@ private Map> groupPartitionKeysByPhysicalP PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); // PartitionKey.NONE wraps NonePartitionKey which has components = null. - // For routing purposes, treat NONE as UndefinedPartitionKey — documents ingested + // For routing purposes, treat NONE as UndefinedPartitionKey - documents ingested // without a partition key path are stored with the undefined EPK. PartitionKeyInternal effectivePkInternal = pkInternal.getComponents() == null ? 
PartitionKeyInternal.UndefinedPartitionKey @@ -4637,12 +4639,12 @@ private Map> groupPartitionKeysByPhysicalP List targetRanges; if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { - // Partial HPK — compute EPK prefix range and find all overlapping physical partitions + // Partial HPK - compute EPK prefix range and find all overlapping physical partitions Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( effectivePkInternal, pkDefinition); targetRanges = routingMap.getOverlappingRanges(epkRange); } else { - // Full PK — maps to exactly one physical partition + // Full PK - maps to exactly one physical partition String effectivePartitionKeyString = PartitionKeyInternalHelper .getEffectivePartitionKeyString(effectivePkInternal, pkDefinition); PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); From c96b6f6935913e3e3358fa74d591f5ebe1ff0663 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 13:38:57 +0000 Subject: [PATCH 23/25] Reacting to code review feedback --- sdk/cosmos/.gitignore | 2 + ...bianm_readManyByPK-vs-origin_main-full.txt | 3444 ----------------- ...bianm_readManyByPK-vs-origin_main-stat.txt | 76 - .../cosmos/spark/CosmosItemsDataSource.scala | 17 +- .../spark/CosmosPartitionKeyHelper.scala | 43 +- .../CosmosReadManyByPartitionKeyReader.scala | 8 + .../cosmos/ReadManyByPartitionKeyTest.java | 37 + .../azure/cosmos/CosmosAsyncContainer.java | 46 +- .../com/azure/cosmos/CosmosContainer.java | 29 +- .../ReadManyByPartitionKeyQueryHelper.java | 72 +- .../implementation/RxDocumentClientImpl.java | 8 +- 11 files changed, 218 insertions(+), 3564 deletions(-) delete mode 100644 sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt delete mode 100644 sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt diff --git a/sdk/cosmos/.gitignore b/sdk/cosmos/.gitignore index 
1ea74182f6bb..81d278c1728f 100644 --- a/sdk/cosmos/.gitignore +++ b/sdk/cosmos/.gitignore @@ -2,3 +2,5 @@ metastore_db/* spark-warehouse/* + +.temp/ diff --git a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt deleted file mode 100644 index 5eee65a15a7b..000000000000 --- a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-full.txt +++ /dev/null @@ -1,3444 +0,0 @@ -===== PR #LOCAL-users_fabianm_readManyByPK ===== -Title: Branch comparison users/fabianm/readManyByPK vs origin/main -Author: Fabian Meiswinkel -Status: DIVERGED (ahead 30, behind 4) -Branch: users/fabianm/readManyByPK -> origin/main -Head SHA: 93957f3a8442d730fe67fbc379ef5399f46f5665 -Merge Base: 20313f79ba8dd0dfa97862d0c31dd4b2e44ee671 -URL: N/A (local branch comparison) - ---- Description --- -Adds readManyByPartitionKey API (sync+async) and Spark connector support for PK-only reads, with query-plan-based validation ---- End Description --- - -===== Commits in PR ===== -9770833eb59 Adding readManyByPartitionKey API -ac287bcdf00 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -9a5b3e96e7e Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -a8720c3c9f2 Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -d499da76fb4 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java -c3c542a33a7 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java -4416354e03e ┬┤Fixing code review comments -3ab3f0d64f5 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -588a7550c54 Update CosmosAsyncContainer.java -8c5cdb47b31 Merge branch 'main' into 
users/fabianm/readManyByPK -f5485527a9c Update ReadManyByPartitionKeyTest.java -f68cf02ff71 Fixing test issues -8b6c4b168ea Update CosmosAsyncContainer.java -8ba7f4db2da Merge branch 'main' into users/fabianm/readManyByPK -56b067a9339 Reacted to code review feedback -fa430e918fa Merge branch 'main' into users/fabianm/readManyByPK -d9504c91f34 Fix build issues -73151f09e5f Merge branch 'main' into users/fabianm/readManyByPK -681830e2d4a Fixing changelog -7f745e60641 Merge branch 'main' into users/fabianm/readManyByPK -0b8905dbb01 Addressing code review comments -22abc780ed8 Addressing code review feedback -662b1a4b90e Update CosmosItemsDataSource.scala -c764de9de02 Update CosmosItemsDataSource.scala -e1e6f5a6f73 Merge branch 'main' into users/fabianm/readManyByPK -080ce4a2293 Update RxDocumentClientImpl.java -516bbf3a95a Merge branch 'users/fabianm/readManyByPK' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -b01f8758eea Fix readManyByPartitionKey retries -7130d4aa35a Fix PK.None -93957f3a844 Update ReadManyByPartitionKeyQueryHelper.java - -===== Files Changed ===== - sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala (+20 -2) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala (+1 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala (+125 -1) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala (+45 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala (+150 -0) - 
sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala (+249 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala (+259 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala (+25 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala (+42 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala (+104 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala (+158 -0) - sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java (+462 -0) - sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java (+426 -0) - sdk/cosmos/azure-cosmos/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java (+126 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java (+67 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java (+21 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java (+19 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java (+263 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java (+292 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java (+11 -0) - sdk/cosmos/cspell.yaml (+6 -0) - sdk/cosmos/docs/readManyByPartitionKey-design.md (+169 -0) - -===== Full Diff ===== -diff --git 
a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md -index cbf97c610f9..fe114462019 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.47.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md -index c9097e749f0..3b2c7ce36db 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.47.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md -index f5eac38bdb7..2240a48b165 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.47.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md -index 919d7fbfa32..20a3e3a61bd 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.47.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala -index 951f4735444..e1b8f0b51f8 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala -@@ -92,6 +92,7 @@ private[spark] object CosmosConfigNames { - val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter" - val ReadRuntimeFilteringEnabled = "spark.cosmos.read.runtimeFiltering.enabled" - val ReadManyFilteringEnabled = "spark.cosmos.read.readManyFiltering.enabled" -+ val ReadManyByPkNullHandling = "spark.cosmos.read.readManyByPk.nullHandling" - val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath" - val DiagnosticsMode = "spark.cosmos.diagnostics" - val DiagnosticsSamplingMaxCount = "spark.cosmos.diagnostics.sampling.maxCount" -@@ -226,6 +227,7 @@ private[spark] object CosmosConfigNames { - ReadPartitioningFeedRangeFilter, - ReadRuntimeFilteringEnabled, - ReadManyFilteringEnabled, -+ ReadManyByPkNullHandling, - ViewsRepositoryPath, - DiagnosticsMode, - DiagnosticsSamplingIntervalInSeconds, -@@ -1042,7 +1044,8 @@ 
private case class CosmosReadConfig(readConsistencyStrategy: ReadConsistencyStra - throughputControlConfig: Option[CosmosThroughputControlConfig] = None, - runtimeFilteringEnabled: Boolean, - readManyFilteringConfig: CosmosReadManyFilteringConfig, -- responseContinuationTokenLimitInKb: Option[Int] = None) -+ responseContinuationTokenLimitInKb: Option[Int] = None, -+ readManyByPkTreatNullAsNone: Boolean = false) - - private object SchemaConversionModes extends Enumeration { - type SchemaConversionMode = Value -@@ -1136,6 +1139,18 @@ private object CosmosReadConfig { - helpMessage = " Indicates whether dynamic partition pruning filters will be pushed down when applicable." - ) - -+ private val ReadManyByPkNullHandling = CosmosConfigEntry[String]( -+ key = CosmosConfigNames.ReadManyByPkNullHandling, -+ mandatory = false, -+ defaultValue = Some("Null"), -+ parseFromStringFunction = value => value, -+ helpMessage = "Determines how null values in hierarchical partition key components are treated " + -+ "for readManyByPartitionKey. 'Null' (default) maps null to a JSON null value via addNullValue(), " + -+ "which is appropriate when the document field exists with an explicit null value. " + -+ "'None' maps null to PartitionKey.NONE via addNoneValue(), which should only be used when the " + -+ "partition key path does not exist at all in the document." 
-+ ) -+ - def parseCosmosReadConfig(cfg: Map[String, String]): CosmosReadConfig = { - val forceEventualConsistency = CosmosConfigEntry.parse(cfg, ForceEventualConsistency) - val readConsistencyStrategyOverride = CosmosConfigEntry.parse(cfg, ReadConsistencyStrategyOverride) -@@ -1158,6 +1173,8 @@ private object CosmosReadConfig { - val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg) - val runtimeFilteringEnabled = CosmosConfigEntry.parse(cfg, ReadRuntimeFilteringEnabled) - val readManyFilteringConfig = CosmosReadManyFilteringConfig.parseCosmosReadManyFilterConfig(cfg) -+ val readManyByPkNullHandling = CosmosConfigEntry.parse(cfg, ReadManyByPkNullHandling) -+ val readManyByPkTreatNullAsNone = readManyByPkNullHandling.getOrElse("Null").equalsIgnoreCase("None") - - val effectiveReadConsistencyStrategy = if (readConsistencyStrategyOverride.getOrElse(ReadConsistencyStrategy.DEFAULT) != ReadConsistencyStrategy.DEFAULT) { - readConsistencyStrategyOverride.get -@@ -1189,7 +1206,8 @@ private object CosmosReadConfig { - throughputControlConfigOpt, - runtimeFilteringEnabled.get, - readManyFilteringConfig, -- responseContinuationTokenLimitInKb) -+ responseContinuationTokenLimitInKb, -+ readManyByPkTreatNullAsNone) - } - } - -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala -index 9ece4741652..00761f23d39 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala -@@ -45,6 +45,7 @@ private[cosmos] object CosmosConstants { - val Id = "id" - val ETag = "_etag" - val ItemIdentity = "_itemIdentity" -+ val PartitionKeyIdentity = "_partitionKeyIdentity" - } - - object StatusCodes { -diff --git 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala -index a35cff27af6..86ef865bcb8 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala -@@ -2,9 +2,10 @@ - // Licensed under the MIT License. - package com.azure.cosmos.spark - --import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey} -+import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey, PartitionKeyBuilder} - import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver - import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -+import com.azure.cosmos.{SparkBridgeInternal} - import org.apache.spark.sql.{DataFrame, Row, SparkSession} - - import java.util -@@ -112,4 +113,127 @@ object CosmosItemsDataSource { - - readManyReader.readMany(df.rdd, readManyFilterExtraction) - } -+ -+ def readManyByPartitionKey(df: DataFrame, userConfig: java.util.Map[String, String]): DataFrame = { -+ readManyByPartitionKey(df, userConfig, null) -+ } -+ -+ def readManyByPartitionKey( -+ df: DataFrame, -+ userConfig: java.util.Map[String, String], -+ userProvidedSchema: StructType): DataFrame = { -+ -+ val readManyReader = new CosmosReadManyByPartitionKeyReader( -+ userProvidedSchema, -+ userConfig.asScala.toMap) -+ -+ // Option 1: Look for the _partitionKeyIdentity column (produced by GetCosmosPartitionKeyValue UDF) -+ val pkIdentityFieldExtraction = df -+ .schema -+ .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) -+ .map(field => (row: Row) => { -+ val rawValue = row.getString(row.fieldIndex(field.name)) -+ CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue) -+ .getOrElse(throw new IllegalArgumentException( -+ s"Invalid 
_partitionKeyIdentity value in row: '$rawValue'. " + -+ "Expected format: pk([...json...])")) -+ }) -+ -+ // Option 2: Detect PK columns by matching the container's partition key paths against the DataFrame schema -+ val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { -+ None // no need to resolve PK paths - _partitionKeyIdentity column takes precedence -+ } else { -+ val effectiveConfig = CosmosConfig.getEffectiveConfig( -+ databaseName = None, -+ containerName = None, -+ userConfig.asScala.toMap) -+ val readConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveConfig) -+ val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) -+ val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) -+ val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" -+ val treatNullAsNone = readConfig.readManyByPkTreatNullAsNone -+ -+ val pkPaths = Loan( -+ List[Option[CosmosClientCacheItem]]( -+ Some( -+ CosmosClientCache( -+ CosmosClientConfiguration( -+ effectiveConfig, -+ readConsistencyStrategy = readConfig.readConsistencyStrategy, -+ sparkEnvironmentInfo), -+ None, -+ calledFrom)), -+ ThroughputControlHelper.getThroughputControlClientCacheItem( -+ effectiveConfig, -+ calledFrom, -+ None, -+ sparkEnvironmentInfo) -+ )) -+ .to(clientCacheItems => { -+ val container = -+ ThroughputControlHelper.getContainer( -+ effectiveConfig, -+ containerConfig, -+ clientCacheItems(0).get, -+ clientCacheItems(1)) -+ -+ val pkDefinition = SparkBridgeInternal -+ .getContainerPropertiesFromCollectionCache(container) -+ .getPartitionKeyDefinition -+ -+ pkDefinition.getPaths.asScala.map(_.stripPrefix("/")).toList -+ }) -+ -+ // Check if ALL PK path columns exist in the DataFrame schema -+ val dfFieldNames = df.schema.fieldNames.toSet -+ val allPkColumnsPresent = pkPaths.forall(path => dfFieldNames.contains(path)) -+ -+ if (allPkColumnsPresent && pkPaths.nonEmpty) { -+ // pkPaths already 
defined above -+ Some((row: Row) => { -+ if (pkPaths.size == 1) { -+ // Single partition key -+ buildPartitionKey(row.getAs[Any](pkPaths.head), treatNullAsNone) -+ } else { -+ // Hierarchical partition key ΓÇö build level by level -+ val builder = new PartitionKeyBuilder() -+ for (path <- pkPaths) { -+ addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone) -+ } -+ builder.build() -+ } -+ }) -+ } else { -+ None -+ } -+ } -+ -+ val pkExtraction = pkIdentityFieldExtraction -+ .orElse(pkColumnExtraction) -+ .getOrElse( -+ throw new IllegalArgumentException( -+ "Cannot determine partition key extraction from the input DataFrame. " + -+ "Either add a '_partitionKeyIdentity' column (using the GetCosmosPartitionKeyValue UDF) " + -+ "or ensure the DataFrame contains columns matching the container's partition key paths.")) -+ -+ readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) -+ } -+ -+ private def addPartitionKeyComponent(builder: PartitionKeyBuilder, value: Any, treatNullAsNone: Boolean): Unit = { -+ value match { -+ case s: String => builder.add(s) -+ case n: Number => builder.add(n.doubleValue()) -+ case b: Boolean => builder.add(b) -+ case null => -+ if (treatNullAsNone) builder.addNoneValue() -+ else builder.addNullValue() -+ case other => builder.add(other.toString) -+ } -+ } -+ -+ private def buildPartitionKey(value: Any, treatNullAsNone: Boolean): PartitionKey = { -+ val builder = new PartitionKeyBuilder() -+ addPartitionKeyComponent(builder, value, treatNullAsNone) -+ builder.build() -+ } - } -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala -new file mode 100644 -index 00000000000..27776f5c3de ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala -@@ -0,0 +1,45 @@ -+// Copyright (c) Microsoft 
Corporation. All rights reserved. -+// Licensed under the MIT License. -+ -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.implementation.routing.PartitionKeyInternal -+import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, Utils} -+import com.azure.cosmos.models.PartitionKey -+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -+ -+import java.util -+ -+// scalastyle:off underscore.import -+import scala.collection.JavaConverters._ -+// scalastyle:on underscore.import -+ -+private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { -+ // pattern will be recognized -+ // pk(partitionKeyValue) -+ // -+ // (?i) : The whole matching is case-insensitive -+ // pk[(](.*)[)]: partitionKey Value -+ private val cosmosPartitionKeyStringRegx = """(?i)pk[(](.*)[)]""".r -+ private val objectMapper = Utils.getSimpleObjectMapper -+ -+ def getCosmosPartitionKeyValueString(partitionKeyValue: List[Object]): String = { -+ s"pk(${objectMapper.writeValueAsString(partitionKeyValue.asJava)})" -+ } -+ -+ def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = { -+ cosmosPartitionKeyString match { -+ case cosmosPartitionKeyStringRegx(pkValue) => -+ scala.util.Try(Utils.parse(pkValue, classOf[Object])).toOption.flatMap { -+ case arrayList: util.ArrayList[Object @unchecked] => -+ Some( -+ ImplementationBridgeHelpers -+ .PartitionKeyHelper -+ .getPartitionKeyAccessor -+ .toPartitionKey(PartitionKeyInternal.fromObjectArray(arrayList.toArray, false))) -+ case other => Some(new PartitionKey(other)) -+ } -+ case _ => None -+ } -+ } -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala -new file mode 100644 -index 00000000000..91f3a56bc66 ---- /dev/null -+++ 
b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala -@@ -0,0 +1,150 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.{CosmosException, ReadConsistencyStrategy} -+import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, UUIDs} -+import com.azure.cosmos.models.PartitionKey -+import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver -+import com.azure.cosmos.spark.diagnostics.{BasicLoggingTrait, DiagnosticsContext} -+import com.fasterxml.jackson.databind.node.ObjectNode -+import org.apache.spark.TaskContext -+import org.apache.spark.broadcast.Broadcast -+import org.apache.spark.rdd.RDD -+import org.apache.spark.sql.{DataFrame, Row, SparkSession} -+import org.apache.spark.sql.types.StructType -+ -+import java.util.UUID -+ -+private[spark] class CosmosReadManyByPartitionKeyReader( -+ val userProvidedSchema: StructType, -+ val userConfig: Map[String, String] -+ ) extends BasicLoggingTrait with Serializable { -+ val effectiveUserConfig: Map[String, String] = CosmosConfig.getEffectiveConfig( -+ databaseName = None, -+ containerName = None, -+ userConfig) -+ -+ val clientConfig: CosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(effectiveUserConfig) -+ val readConfig: CosmosReadConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveUserConfig) -+ val cosmosContainerConfig: CosmosContainerConfig = -+ CosmosContainerConfig.parseCosmosContainerConfig(effectiveUserConfig) -+ //scalastyle:off multiple.string.literals -+ val tableName: String = s"com.azure.cosmos.spark.items.${clientConfig.accountName}." 
+ -+ s"${cosmosContainerConfig.database}.${cosmosContainerConfig.container}" -+ private lazy val sparkSession = { -+ assertOnSparkDriver() -+ SparkSession.active -+ } -+ val sparkEnvironmentInfo: String = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(sparkSession)) -+ logTrace(s"Instantiated ${this.getClass.getSimpleName} for $tableName") -+ -+ private[spark] def initializeAndBroadcastCosmosClientStatesForContainer(): Broadcast[CosmosClientMetadataCachesSnapshots] = { -+ val calledFrom = s"CosmosReadManyByPartitionKeyReader($tableName).initializeAndBroadcastCosmosClientStateForContainer" -+ Loan( -+ List[Option[CosmosClientCacheItem]]( -+ Some( -+ CosmosClientCache( -+ CosmosClientConfiguration( -+ effectiveUserConfig, -+ readConsistencyStrategy = readConfig.readConsistencyStrategy, -+ sparkEnvironmentInfo), -+ None, -+ calledFrom)), -+ ThroughputControlHelper.getThroughputControlClientCacheItem( -+ effectiveUserConfig, -+ calledFrom, -+ None, -+ sparkEnvironmentInfo) -+ )) -+ .to(clientCacheItems => { -+ val container = -+ ThroughputControlHelper.getContainer( -+ effectiveUserConfig, -+ cosmosContainerConfig, -+ clientCacheItems(0).get, -+ clientCacheItems(1)) -+ try { -+ container.readItem( -+ UUIDs.nonBlockingRandomUUID().toString, -+ new PartitionKey(UUIDs.nonBlockingRandomUUID().toString), -+ classOf[ObjectNode]) -+ .block() -+ } catch { -+ case _: CosmosException => None -+ } -+ -+ val state = new CosmosClientMetadataCachesSnapshot() -+ state.serialize(clientCacheItems(0).get.cosmosClient) -+ -+ var throughputControlState: Option[CosmosClientMetadataCachesSnapshot] = None -+ if (clientCacheItems(1).isDefined) { -+ throughputControlState = Some(new CosmosClientMetadataCachesSnapshot()) -+ throughputControlState.get.serialize(clientCacheItems(1).get.cosmosClient) -+ } -+ -+ val metadataSnapshots = CosmosClientMetadataCachesSnapshots(state, throughputControlState) -+ sparkSession.sparkContext.broadcast(metadataSnapshots) -+ }) -+ } -+ -+ def 
readManyByPartitionKey(inputRdd: RDD[Row], pkExtraction: Row => PartitionKey): DataFrame = { -+ val correlationActivityId = UUIDs.nonBlockingRandomUUID() -+ val calledFrom = s"CosmosReadManyByPartitionKeyReader.readManyByPartitionKey($correlationActivityId)" -+ val schema = Loan( -+ List[Option[CosmosClientCacheItem]]( -+ Some(CosmosClientCache( -+ CosmosClientConfiguration( -+ effectiveUserConfig, -+ readConsistencyStrategy = readConfig.readConsistencyStrategy, -+ sparkEnvironmentInfo), -+ None, -+ calledFrom -+ )), -+ ThroughputControlHelper.getThroughputControlClientCacheItem( -+ effectiveUserConfig, -+ calledFrom, -+ None, -+ sparkEnvironmentInfo) -+ )) -+ .to(clientCacheItems => Option.apply(userProvidedSchema).getOrElse( -+ CosmosTableSchemaInferrer.inferSchema( -+ clientCacheItems(0).get, -+ clientCacheItems(1), -+ effectiveUserConfig, -+ ItemsTable.defaultSchemaForInferenceDisabled))) -+ -+ val clientStates = initializeAndBroadcastCosmosClientStatesForContainer -+ -+ sparkSession.sqlContext.createDataFrame( -+ inputRdd.mapPartitionsWithIndex( -+ (partitionIndex: Int, rowIterator: Iterator[Row]) => { -+ val pkIterator: Iterator[PartitionKey] = rowIterator -+ .map(row => pkExtraction.apply(row)) -+ -+ logInfo(s"Creating an ItemsPartitionReaderWithReadManyByPartitionKey for Activity $correlationActivityId to read for " -+ + s"input partition [$partitionIndex] ${tableName}") -+ -+ val reader = new ItemsPartitionReaderWithReadManyByPartitionKey( -+ effectiveUserConfig, -+ CosmosReadManyHelper.FullRangeFeedRange, -+ schema, -+ DiagnosticsContext(correlationActivityId, partitionIndex.toString), -+ clientStates, -+ DiagnosticsConfig.parseDiagnosticsConfig(effectiveUserConfig), -+ sparkEnvironmentInfo, -+ TaskContext.get, -+ pkIterator) -+ -+ new Iterator[Row] { -+ override def hasNext: Boolean = reader.next() -+ -+ override def next(): Row = reader.getCurrentRow() -+ } -+ }, -+ preservesPartitioning = true -+ ), -+ schema) -+ } -+} -+ -diff --git 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -new file mode 100644 -index 00000000000..c67cc9c10be ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -@@ -0,0 +1,249 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+ -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.{CosmosAsyncContainer, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosItemSerializerNoExceptionWrapping, SparkBridgeInternal} -+import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple -+import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, ObjectNodeMap, SparkRowItem, Utils} -+import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition, SqlQuerySpec} -+import com.azure.cosmos.spark.BulkWriter.getThreadInfo -+import com.azure.cosmos.spark.CosmosTableSchemaInferrer.IdAttributeName -+import com.azure.cosmos.spark.diagnostics.{DetailedFeedDiagnosticsProvider, DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext} -+import com.fasterxml.jackson.databind.node.ObjectNode -+import org.apache.spark.TaskContext -+import org.apache.spark.broadcast.Broadcast -+import org.apache.spark.sql.Row -+import org.apache.spark.sql.catalyst.InternalRow -+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -+import org.apache.spark.sql.connector.read.PartitionReader -+import org.apache.spark.sql.types.StructType -+ -+import java.util -+ -+// scalastyle:off underscore.import -+import scala.collection.JavaConverters._ -+// scalastyle:on underscore.import -+ -+private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey -+( -+ 
config: Map[String, String], -+ feedRange: NormalizedRange, -+ readSchema: StructType, -+ diagnosticsContext: DiagnosticsContext, -+ cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], -+ diagnosticsConfig: DiagnosticsConfig, -+ sparkEnvironmentInfo: String, -+ taskContext: TaskContext, -+ readManyPartitionKeys: Iterator[PartitionKey] -+) -+ extends PartitionReader[InternalRow] { -+ -+ private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) -+ -+ private val readManyOptions = new CosmosReadManyRequestOptions() -+ private val readManyOptionsImpl = ImplementationBridgeHelpers -+ .CosmosReadManyRequestOptionsHelper -+ .getCosmosReadManyRequestOptionsAccessor -+ .getImpl(readManyOptions) -+ -+ private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config) -+ ThroughputControlHelper.populateThroughputControlGroupName(readManyOptionsImpl, readConfig.throughputControlConfig) -+ -+ private val operationContext = { -+ assert(taskContext != null) -+ -+ SparkTaskContext(diagnosticsContext.correlationActivityId, -+ taskContext.stageId(), -+ taskContext.partitionId(), -+ taskContext.taskAttemptId(), -+ feedRange.toString) -+ } -+ -+ private val operationContextAndListenerTuple: Option[OperationContextAndListenerTuple] = { -+ if (diagnosticsConfig.mode.isDefined) { -+ val listener = -+ DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass) -+ -+ val ctxAndListener = new OperationContextAndListenerTuple(operationContext, listener) -+ -+ readManyOptionsImpl -+ .setOperationContextAndListenerTuple(ctxAndListener) -+ -+ Some(ctxAndListener) -+ } else { -+ None -+ } -+ } -+ -+ log.logTrace(s"Instantiated ${this.getClass.getSimpleName}, Context: ${operationContext.toString} $getThreadInfo") -+ -+ private val containerTargetConfig = CosmosContainerConfig.parseCosmosContainerConfig(config) -+ -+ log.logInfo(s"Using ReadManyByPartitionKey from feed range $feedRange of " + -+ s"container 
${containerTargetConfig.database}.${containerTargetConfig.container} - " + -+ s"correlationActivityId ${diagnosticsContext.correlationActivityId}, " + -+ s"Context: ${operationContext.toString} $getThreadInfo") -+ -+ private val clientCacheItem = CosmosClientCache( -+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), -+ Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), -+ s"ItemsPartitionReaderWithReadManyByPartitionKey($feedRange, ${containerTargetConfig.database}.${containerTargetConfig.container})" -+ ) -+ -+ private val throughputControlClientCacheItemOpt = -+ ThroughputControlHelper.getThroughputControlClientCacheItem( -+ config, -+ clientCacheItem.context, -+ Some(cosmosClientStateHandles), -+ sparkEnvironmentInfo) -+ -+ private val cosmosAsyncContainer = -+ ThroughputControlHelper.getContainer( -+ config, -+ containerTargetConfig, -+ clientCacheItem, -+ throughputControlClientCacheItemOpt) -+ -+ private val partitionKeyDefinition: PartitionKeyDefinition = { -+ TransientErrorsRetryPolicy.executeWithRetry(() => { -+ SparkBridgeInternal -+ .getContainerPropertiesFromCollectionCache(cosmosAsyncContainer).getPartitionKeyDefinition -+ }) -+ } -+ -+ private val cosmosSerializationConfig = CosmosSerializationConfig.parseSerializationConfig(config) -+ private val cosmosRowConverter = CosmosRowConverter.get(cosmosSerializationConfig) -+ -+ readManyOptionsImpl -+ .setCustomItemSerializer( -+ new CosmosItemSerializerNoExceptionWrapping { -+ override def serialize[T](item: T): util.Map[String, AnyRef] = { -+ throw new UnsupportedOperationException( -+ s"Serialization is not supported by the custom item serializer in " + -+ s"ItemsPartitionReaderWithReadManyByPartitionKey; this serializer is intended " + -+ s"for deserializing read-many responses into SparkRowItem only. 
" + -+ s"Unexpected item type: ${if (item == null) "null" else item.getClass.getName}" -+ ) -+ } -+ -+ override def deserialize[T](jsonNodeMap: util.Map[String, AnyRef], classType: Class[T]): T = { -+ if (jsonNodeMap == null) { -+ throw new IllegalStateException("The 'jsonNodeMap' should never be null here.") -+ } -+ -+ if (classType != classOf[SparkRowItem]) { -+ throw new IllegalStateException("The 'classType' must be 'classOf[SparkRowItem])' here.") -+ } -+ -+ val objectNode: ObjectNode = jsonNodeMap match { -+ case map: ObjectNodeMap => -+ map.getObjectNode -+ case _ => -+ Utils.getSimpleObjectMapper.convertValue(jsonNodeMap, classOf[ObjectNode]) -+ } -+ -+ val partitionKey = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition) -+ -+ val row = cosmosRowConverter.fromObjectNodeToRow(readSchema, -+ objectNode, -+ readConfig.schemaConversionMode) -+ -+ SparkRowItem(row, getPartitionKeyForFeedDiagnostics(partitionKey)).asInstanceOf[T] -+ } -+ } -+ ) -+ -+ // Collect all PK values upfront ΓÇö readManyByPartitionKey needs the full list to -+ // group by physical partition and issue parallel queries. -+ // Deduplicate by PK string representation ΓÇö safe because the list size is bounded -+ // by the per-call limit of the readManyByPartitionKey API. 
-+ private lazy val pkList = { -+ val seen = new java.util.LinkedHashMap[String, PartitionKey]() -+ readManyPartitionKeys.foreach(pk => seen.putIfAbsent(pk.toString, pk)) -+ new java.util.ArrayList[PartitionKey](seen.values()) -+ } -+ -+ private val endToEndTimeoutPolicy = -+ new CosmosEndToEndOperationLatencyPolicyConfigBuilder( -+ java.time.Duration.ofSeconds(CosmosConstants.readOperationEndToEndTimeoutInSeconds)) -+ .enable(true) -+ .build -+ -+ readManyOptionsImpl.setCosmosEndToEndOperationLatencyPolicyConfig(endToEndTimeoutPolicy) -+ -+ private trait CloseableSparkRowItemIterator { -+ def hasNext: Boolean -+ def next(): SparkRowItem -+ def close(): Unit -+ } -+ -+ private object EmptySparkRowItemIterator extends CloseableSparkRowItemIterator { -+ override def hasNext: Boolean = false -+ -+ override def next(): SparkRowItem = { -+ throw new java.util.NoSuchElementException("No items available for empty partition-key list.") -+ } -+ -+ override def close(): Unit = {} -+ } -+ -+ // Batch partition keys and retry each batch independently on transient I/O errors. -+ // This avoids the continuation-token problem with TransientIOErrorsRetryingIterator -+ // where a retry would re-read all data from scratch, causing silent data duplication. 
-+ private lazy val iterator: CloseableSparkRowItemIterator = -+ if (pkList.isEmpty) { -+ EmptySparkRowItemIterator -+ } else { -+ new CloseableSparkRowItemIterator { -+ private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem]( -+ cosmosAsyncContainer, -+ pkList, -+ readConfig.customQuery.map(_.toSqlQuerySpec), -+ readManyOptions, -+ readConfig.maxItemCount, -+ readConfig.prefetchBufferSize, -+ operationContextAndListenerTuple, -+ classOf[SparkRowItem] -+ ) -+ -+ override def hasNext: Boolean = delegate.hasNext -+ -+ override def next(): SparkRowItem = delegate.next() -+ -+ override def close(): Unit = delegate.close() -+ } -+ } -+ -+ private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema) -+ -+ private def shouldLogDetailedFeedDiagnostics(): Boolean = { -+ diagnosticsConfig.mode.isDefined && -+ diagnosticsConfig.mode.get.equalsIgnoreCase(classOf[DetailedFeedDiagnosticsProvider].getName) -+ } -+ -+ private def getPartitionKeyForFeedDiagnostics(pkValue: PartitionKey): Option[PartitionKey] = { -+ if (shouldLogDetailedFeedDiagnostics()) { -+ Some(pkValue) -+ } else { -+ None -+ } -+ } -+ -+ override def next(): Boolean = iterator.hasNext -+ -+ override def get(): InternalRow = { -+ cosmosRowConverter.fromRowToInternalRow(iterator.next().row, rowSerializer) -+ } -+ -+ def getCurrentRow(): Row = iterator.next().row -+ -+ override def close(): Unit = { -+ this.iterator.close() -+ RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer) -+ clientCacheItem.close() -+ if (throughputControlClientCacheItemOpt.isDefined) { -+ throughputControlClientCacheItemOpt.get.close() -+ } -+ } -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala -new 
file mode 100644 -index 00000000000..dcfdf4f9353 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala -@@ -0,0 +1,259 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+ -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.{CosmosAsyncContainer, CosmosException} -+import com.azure.cosmos.implementation.OperationCancelledException -+import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple -+import com.azure.cosmos.models.{CosmosReadManyRequestOptions, FeedResponse, PartitionKey, SqlQuerySpec} -+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait -+import com.azure.cosmos.util.CosmosPagedIterable -+ -+import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException} -+import java.util.concurrent.atomic.AtomicLong -+import scala.concurrent.{Await, ExecutionContext, Future} -+import scala.util.Random -+import scala.util.control.Breaks -+ -+// scalastyle:off underscore.import -+import scala.collection.JavaConverters._ -+// scalastyle:on underscore.import -+ -+/** -+ * Retry-safe iterator for readManyByPartitionKey that batches partition keys and lazily -+ * iterates pages within each batch via CosmosPagedIterable ΓÇö consistent with how -+ * TransientIOErrorsRetryingIterator handles normal queries. On transient I/O errors the -+ * current batch's flux is recreated and pages already consumed are replayed, avoiding -+ * the memory overhead of collectList and matching the query iterator's structure. 
-+ */ -+private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow] -+( -+ val container: CosmosAsyncContainer, -+ val partitionKeys: java.util.List[PartitionKey], -+ val customQuery: Option[SqlQuerySpec], -+ val queryOptions: CosmosReadManyRequestOptions, -+ val pageSize: Int, -+ val pagePrefetchBufferSize: Int, -+ val operationContextAndListener: Option[OperationContextAndListenerTuple], -+ val classType: Class[TSparkRow] -+) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { -+ -+ private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs -+ private[spark] var maxRetryCount = CosmosConstants.maxRetryCountForTransientFailures -+ -+ private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration( -+ 5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds, -+ scala.concurrent.duration.SECONDS) -+ -+ private val rnd = Random -+ private val retryCount = new AtomicLong(0) -+ private lazy val operationContextString = operationContextAndListener match { -+ case Some(o) => if (o.getOperationContext != null) { -+ o.getOperationContext.toString -+ } else { -+ "n/a" -+ } -+ case None => "n/a" -+ } -+ -+ private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None -+ private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None -+ -+ private val pkBatchIterator = partitionKeys.asScala.iterator.grouped(pageSize) -+ // Track the current batch so we can replay it on retry -+ private var currentBatch: Option[java.util.List[PartitionKey]] = None -+ -+ override def hasNext: Boolean = { -+ executeWithRetry("hasNextInternal", () => hasNextInternal) -+ } -+ -+ private def hasNextInternal: Boolean = { -+ var returnValue: Option[Boolean] = None -+ -+ while (returnValue.isEmpty) { -+ returnValue = hasNextInternalCore -+ } -+ -+ returnValue.get -+ } -+ -+ private def hasNextInternalCore: Option[Boolean] = { -+ 
if (hasBufferedNext) { -+ Some(true) -+ } else { -+ val feedResponseIterator = currentFeedResponseIterator match { -+ case Some(existing) => existing -+ case None => -+ // Need a new feed response iterator ΓÇö either for the current batch (on retry) -+ // or for the next batch -+ val batch = currentBatch match { -+ case Some(b) => b // retry of current batch -+ case None => -+ if (pkBatchIterator.hasNext) { -+ val nextBatch = new java.util.ArrayList[PartitionKey](pkBatchIterator.next().toList.asJava) -+ currentBatch = Some(nextBatch) -+ nextBatch -+ } else { -+ return Some(false) // no more batches -+ } -+ } -+ -+ val pagedFlux = customQuery match { -+ case Some(query) => -+ container.readManyByPartitionKey(batch, query, queryOptions, classType) -+ case None => -+ container.readManyByPartitionKey(batch, queryOptions, classType) -+ } -+ -+ currentFeedResponseIterator = Some( -+ new CosmosPagedIterable[TSparkRow]( -+ pagedFlux, -+ pageSize, -+ pagePrefetchBufferSize -+ ) -+ .iterableByPage() -+ .iterator -+ .asScala -+ .buffered -+ ) -+ -+ currentFeedResponseIterator.get -+ } -+ -+ val hasNext: Boolean = try { -+ Await.result( -+ Future { -+ feedResponseIterator.hasNext -+ }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext), -+ maxPageRetrievalTimeout) -+ } catch { -+ case endToEndTimeoutException: OperationCancelledException => -+ val message = s"End-to-end timeout hit when trying to retrieve the next page. " + -+ s"Context: $operationContextString" -+ logError(message, throwable = endToEndTimeoutException) -+ throw endToEndTimeoutException -+ -+ case timeoutException: TimeoutException => -+ val message = s"Attempting to retrieve the next page timed out. 
" + -+ s"Context: $operationContextString" -+ logError(message, timeoutException) -+ val exception = new OperationCancelledException(message, null) -+ exception.setStackTrace(timeoutException.getStackTrace) -+ throw exception -+ -+ case other: Throwable => throw other -+ } -+ -+ if (hasNext) { -+ val feedResponse = feedResponseIterator.next() -+ if (operationContextAndListener.isDefined) { -+ operationContextAndListener.get.getOperationListener.feedResponseProcessedListener( -+ operationContextAndListener.get.getOperationContext, -+ feedResponse) -+ } -+ val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered -+ -+ if (iteratorCandidate.hasNext) { -+ currentItemIterator = Some(iteratorCandidate) -+ Some(true) -+ } else { -+ // empty page interleaved ΓÇö try again -+ None -+ } -+ } else { -+ // Current batch's flux is exhausted ΓÇö move to next batch -+ currentBatch = None -+ currentFeedResponseIterator = None -+ None -+ } -+ } -+ } -+ -+ private def hasBufferedNext: Boolean = { -+ currentItemIterator match { -+ case Some(iterator) => if (iterator.hasNext) { -+ true -+ } else { -+ currentItemIterator = None -+ false -+ } -+ case None => false -+ } -+ } -+ -+ override def next(): TSparkRow = { -+ currentItemIterator.get.next() -+ } -+ -+ override def head(): TSparkRow = { -+ currentItemIterator.get.head -+ } -+ -+ private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = { -+ val loop = new Breaks() -+ var returnValue: Option[T] = None -+ -+ loop.breakable { -+ while (true) { -+ val retryIntervalInMs = rnd.nextInt(maxRetryIntervalInMs) -+ -+ try { -+ returnValue = Some(func()) -+ retryCount.set(0) -+ loop.break -+ } -+ catch { -+ case cosmosException: CosmosException => -+ if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) { -+ val retryCountSnapshot = retryCount.incrementAndGet() -+ if (retryCountSnapshot > maxRetryCount) { -+ logError( -+ s"Too many transient failure retry 
attempts in " + -+ s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName", -+ cosmosException) -+ throw cosmosException -+ } else { -+ logWarning( -+ s"Transient failure handled in " + -+ s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName -" + -+ s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms", -+ cosmosException) -+ } -+ } else { -+ throw cosmosException -+ } -+ case other: Throwable => throw other -+ } -+ -+ // Reset iterators but keep currentBatch so the batch is replayed -+ currentItemIterator = None -+ currentFeedResponseIterator = None -+ Thread.sleep(retryIntervalInMs) -+ } -+ } -+ -+ returnValue.get -+ } -+ -+ override def close(): Unit = { -+ currentItemIterator = None -+ currentFeedResponseIterator = None -+ } -+} -+ -+private object TransientIOErrorsRetryingReadManyByPartitionKeyIterator extends BasicLoggingTrait { -+ private val maxConcurrency = SparkUtils.getNumberOfHostCPUCores -+ -+ val executorService: ExecutorService = new ThreadPoolExecutor( -+ maxConcurrency, -+ maxConcurrency, -+ 0L, -+ TimeUnit.MILLISECONDS, -+ new SynchronousQueue(), -+ SparkUtils.daemonThreadFactory(), -+ new ThreadPoolExecutor.CallerRunsPolicy() -+ ) -+ -+ val executionContext: ExecutionContext = ExecutionContext.fromExecutorService(executorService) -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala -new file mode 100644 -index 00000000000..a58d5b723b8 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala -@@ -0,0 +1,25 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. 
-+ -+package com.azure.cosmos.spark.udf -+ -+import com.azure.cosmos.spark.CosmosPartitionKeyHelper -+import com.azure.cosmos.spark.CosmosPredicates.requireNotNull -+import org.apache.spark.sql.api.java.UDF1 -+ -+@SerialVersionUID(1L) -+class GetCosmosPartitionKeyValue extends UDF1[Object, String] { -+ override def call -+ ( -+ partitionKeyValue: Object -+ ): String = { -+ requireNotNull(partitionKeyValue, "partitionKeyValue") -+ -+ partitionKeyValue match { -+ // for subpartitions case - Seq covers both WrappedArray (Scala 2.12) and ArraySeq (Scala 2.13) -+ case seq: Seq[Any] => -+ CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(seq.map(_.asInstanceOf[Object]).toList) -+ case _ => CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) -+ } -+ } -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala -index 17f75e45a74..17a298d6213 100644 ---- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala -@@ -457,6 +457,7 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { - config.runtimeFilteringEnabled shouldBe true - config.readManyFilteringConfig.readManyFilteringEnabled shouldBe false - config.readManyFilteringConfig.readManyFilterProperty shouldEqual "_itemIdentity" -+ config.readManyByPkTreatNullAsNone shouldBe false - - userConfig = Map( - "spark.cosmos.read.forceEventualConsistency" -> "false", -@@ -630,6 +631,47 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { - config.customQuery.get.queryText shouldBe queryText - } - -+ it should "parse readManyByPk nullHandling configuration" in { -+ // Default (not specified) should treat null as JSON null (addNullValue) -+ var userConfig = Map( -+ 
"spark.cosmos.read.forceEventualConsistency" -> "false" -+ ) -+ var config = CosmosReadConfig.parseCosmosReadConfig(userConfig) -+ config.readManyByPkTreatNullAsNone shouldBe false -+ -+ // Explicit "Null" should treat null as JSON null (addNullValue) -+ userConfig = Map( -+ "spark.cosmos.read.forceEventualConsistency" -> "false", -+ "spark.cosmos.read.readManyByPk.nullHandling" -> "Null" -+ ) -+ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) -+ config.readManyByPkTreatNullAsNone shouldBe false -+ -+ // Case-insensitive "null" -+ userConfig = Map( -+ "spark.cosmos.read.forceEventualConsistency" -> "false", -+ "spark.cosmos.read.readManyByPk.nullHandling" -> "null" -+ ) -+ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) -+ config.readManyByPkTreatNullAsNone shouldBe false -+ -+ // "None" should treat null as PartitionKey.NONE (addNoneValue) -+ userConfig = Map( -+ "spark.cosmos.read.forceEventualConsistency" -> "false", -+ "spark.cosmos.read.readManyByPk.nullHandling" -> "None" -+ ) -+ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) -+ config.readManyByPkTreatNullAsNone shouldBe true -+ -+ // Case-insensitive "none" -+ userConfig = Map( -+ "spark.cosmos.read.forceEventualConsistency" -> "false", -+ "spark.cosmos.read.readManyByPk.nullHandling" -> "none" -+ ) -+ config = CosmosReadConfig.parseCosmosReadConfig(userConfig) -+ config.readManyByPkTreatNullAsNone shouldBe true -+ } -+ - it should "throw on invalid read configuration" in { - val userConfig = Map( - "spark.cosmos.read.schemaConversionMode" -> "not a valid value" -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala -new file mode 100644 -index 00000000000..1ac40e39584 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala -@@ -0,0 
+1,104 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+ -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} -+ -+class CosmosPartitionKeyHelperSpec extends UnitSpec { -+ //scalastyle:off multiple.string.literals -+ //scalastyle:off magic.number -+ -+ it should "return the correct partition key value string for single PK" in { -+ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("pk1")) -+ pkString shouldEqual "pk([\"pk1\"])" -+ } -+ -+ it should "return the correct partition key value string for HPK" in { -+ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city1", "zip1")) -+ pkString shouldEqual "pk([\"city1\",\"zip1\"])" -+ } -+ -+ it should "return the correct partition key value string for 3-level HPK" in { -+ val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("a", "b", "c")) -+ pkString shouldEqual "pk([\"a\",\"b\",\"c\"])" -+ } -+ -+ it should "parse valid single PK string" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"myPkValue\"])") -+ pk.isDefined shouldBe true -+ pk.get shouldEqual new PartitionKey("myPkValue") -+ } -+ -+ it should "parse valid HPK string" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"city1\",\"zip1\"])") -+ pk.isDefined shouldBe true -+ val expected = new PartitionKeyBuilder().add("city1").add("zip1").build() -+ pk.get shouldEqual expected -+ } -+ -+ it should "parse valid 3-level HPK string" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"a\",\"b\",\"c\"])") -+ pk.isDefined shouldBe true -+ val expected = new PartitionKeyBuilder().add("a").add("b").add("c").build() -+ pk.get shouldEqual expected -+ } -+ -+ it should "roundtrip single PK" in { -+ val original = "pk([\"roundtrip\"])" -+ val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) -+ parsed.isDefined 
shouldBe true -+ val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("roundtrip")) -+ serialized shouldEqual original -+ } -+ -+ it should "roundtrip HPK" in { -+ val original = "pk([\"city\",\"zip\"])" -+ val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) -+ parsed.isDefined shouldBe true -+ val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city", "zip")) -+ serialized shouldEqual original -+ } -+ -+ it should "return None for malformed string" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("invalid_format") -+ pk.isDefined shouldBe false -+ } -+ -+ it should "return None for missing pk prefix" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("[\"value\"]") -+ pk.isDefined shouldBe false -+ } -+ -+ it should "be case-insensitive for parsing" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("PK([\"value\"])") -+ pk.isDefined shouldBe true -+ pk.get shouldEqual new PartitionKey("value") -+ } -+ -+ -+ it should "return None for malformed JSON inside pk() wrapper" in { -+ // Invalid JSON that would cause JsonProcessingException -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk({invalid json})") -+ pk.isDefined shouldBe false -+ } -+ -+ it should "return None for truncated JSON inside pk() wrapper" in { -+ val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"unterminated)") -+ pk.isDefined shouldBe false -+ } -+ -+ it should "produce different partition keys for addNullValue vs addNoneValue in HPK" in { -+ // addNullValue represents an explicit JSON null for a field that exists with value null -+ val pkWithNull = new PartitionKeyBuilder().add("Redmond").addNullValue().build() -+ -+ // addNoneValue represents PartitionKey.NONE, meaning the field is absent/undefined -+ val pkWithNone = new PartitionKeyBuilder().add("Redmond").addNoneValue().build() -+ -+ // These MUST produce different partition key hashes and route to different 
physical partitions -+ pkWithNull should not equal pkWithNone -+ } -+ -+ //scalastyle:on multiple.string.literals -+ //scalastyle:on magic.number -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala -new file mode 100644 -index 00000000000..5c2d7b59836 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala -@@ -0,0 +1,158 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+ -+package com.azure.cosmos.spark -+ -+import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, TestConfigurations, Utils} -+import com.azure.cosmos.models.PartitionKey -+import com.azure.cosmos.spark.diagnostics.DiagnosticsContext -+import com.fasterxml.jackson.databind.node.ObjectNode -+import org.apache.spark.MockTaskContext -+import org.apache.spark.broadcast.Broadcast -+import org.apache.spark.sql.types.{StringType, StructField, StructType} -+ -+import java.util.UUID -+import scala.collection.mutable.ListBuffer -+ -+class ItemsPartitionReaderWithReadManyByPartitionKeyITest -+ extends IntegrationSpec -+ with Spark -+ with AutoCleanableCosmosContainersWithPkAsPartitionKey { -+ private val idProperty = "id" -+ private val pkProperty = "pk" -+ -+ //scalastyle:off multiple.string.literals -+ //scalastyle:off magic.number -+ -+ it should "be able to retrieve all items for given partition keys" in { -+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) -+ -+ // Create items with known PK values -+ val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition -+ val allItemsByPk = scala.collection.mutable.Map[String, 
ListBuffer[ObjectNode]]() -+ val pkValues = List("pkA", "pkB", "pkC") -+ -+ for (pk <- pkValues) { -+ allItemsByPk(pk) = ListBuffer[ObjectNode]() -+ for (_ <- 1 to 5) { -+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode() -+ objectNode.put(idProperty, UUID.randomUUID().toString) -+ objectNode.put(pkProperty, pk) -+ container.createItem(objectNode).block() -+ allItemsByPk(pk) += objectNode -+ } -+ } -+ -+ val config = Map( -+ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, -+ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, -+ "spark.cosmos.database" -> cosmosDatabase, -+ "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, -+ "spark.cosmos.read.inferSchema.enabled" -> "true", -+ "spark.cosmos.applicationName" -> "ReadManyByPKTest" -+ ) -+ -+ val readSchema = StructType(Seq( -+ StructField(idProperty, StringType, false), -+ StructField(pkProperty, StringType, false) -+ )) -+ -+ val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") -+ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) -+ val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() -+ -+ // Read items for pkA and pkB (not pkC) -+ val targetPks = List("pkA", "pkB") -+ val pkIterator = targetPks.map(pk => new PartitionKey(pk)).iterator -+ -+ val reader = ItemsPartitionReaderWithReadManyByPartitionKey( -+ config, -+ NormalizedRange("", "FF"), -+ readSchema, -+ diagnosticsContext, -+ cosmosClientMetadataCachesSnapshots, -+ diagnosticsConfig, -+ "", -+ MockTaskContext.mockTaskContext(), -+ pkIterator -+ ) -+ -+ val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) -+ val itemsReadFromReader = ListBuffer[ObjectNode]() -+ while (reader.next()) { -+ itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) -+ } -+ -+ // Should have 10 items (5 for pkA + 5 for pkB) -+ itemsReadFromReader.size shouldEqual 10 -+ -+ // 
All items should be from pkA or pkB -+ itemsReadFromReader.foreach(item => { -+ val pk = item.get(pkProperty).asText() -+ targetPks should contain(pk) -+ }) -+ -+ // Validate all expected IDs are present -+ val expectedIds = (allItemsByPk("pkA") ++ allItemsByPk("pkB")).map(_.get(idProperty).asText()).toSet -+ val actualIds = itemsReadFromReader.map(_.get(idProperty).asText()).toSet -+ actualIds shouldEqual expectedIds -+ -+ reader.close() -+ } -+ -+ it should "return empty results for non-existent partition keys" in { -+ val config = Map( -+ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, -+ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, -+ "spark.cosmos.database" -> cosmosDatabase, -+ "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, -+ "spark.cosmos.read.inferSchema.enabled" -> "true", -+ "spark.cosmos.applicationName" -> "ReadManyByPKEmptyTest" -+ ) -+ -+ val readSchema = StructType(Seq( -+ StructField(idProperty, StringType, false), -+ StructField(pkProperty, StringType, false) -+ )) -+ -+ val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") -+ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) -+ val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() -+ -+ val pkIterator = List(new PartitionKey("nonExistentPk")).iterator -+ -+ val reader = ItemsPartitionReaderWithReadManyByPartitionKey( -+ config, -+ NormalizedRange("", "FF"), -+ readSchema, -+ diagnosticsContext, -+ cosmosClientMetadataCachesSnapshots, -+ diagnosticsConfig, -+ "", -+ MockTaskContext.mockTaskContext(), -+ pkIterator -+ ) -+ -+ val itemsReadFromReader = ListBuffer[ObjectNode]() -+ val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) -+ while (reader.next()) { -+ itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) -+ } -+ -+ itemsReadFromReader.size shouldEqual 0 -+ reader.close() -+ } -+ -+ 
private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = { -+ val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot() -+ cosmosClientMetadataCachesSnapshot.serialize(cosmosClient) -+ -+ spark.sparkContext.broadcast( -+ CosmosClientMetadataCachesSnapshots( -+ cosmosClientMetadataCachesSnapshot, -+ Option.empty[CosmosClientMetadataCachesSnapshot])) -+ } -+ -+ //scalastyle:on multiple.string.literals -+ //scalastyle:on magic.number -+} -diff --git a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md -index 3972ae6aeb9..d8368be6a0d 100644 ---- a/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.47.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `CosmosItemsDataSource.readManyByPartitionKey` Spark function to execute bulk queries by a list of pk-values with better efficiency. See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java -new file mode 100644 -index 00000000000..2c26d564ed2 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java -@@ -0,0 +1,462 @@ -+/* -+ * Copyright (c) Microsoft Corporation. All rights reserved. -+ * Licensed under the MIT License. 
-+ */ -+ -+package com.azure.cosmos; -+ -+import com.azure.cosmos.models.CosmosContainerProperties; -+import com.azure.cosmos.models.CosmosItemRequestOptions; -+import com.azure.cosmos.models.FeedResponse; -+import com.azure.cosmos.models.PartitionKey; -+import com.azure.cosmos.models.PartitionKeyBuilder; -+import com.azure.cosmos.models.PartitionKeyDefinition; -+import com.azure.cosmos.models.PartitionKeyDefinitionVersion; -+import com.azure.cosmos.models.PartitionKind; -+import com.azure.cosmos.models.SqlParameter; -+import com.azure.cosmos.models.SqlQuerySpec; -+import com.azure.cosmos.rx.TestSuiteBase; -+import com.azure.cosmos.util.CosmosPagedIterable; -+import com.fasterxml.jackson.databind.node.ObjectNode; -+import org.testng.annotations.AfterClass; -+import org.testng.annotations.BeforeClass; -+import org.testng.annotations.Factory; -+import org.testng.annotations.Test; -+ -+import java.util.ArrayList; -+import java.util.Arrays; -+import java.util.Collections; -+import java.util.List; -+import java.util.UUID; -+import java.util.stream.Collectors; -+ -+import static org.assertj.core.api.Assertions.assertThat; -+import static org.assertj.core.api.Assertions.fail; -+ -+public class ReadManyByPartitionKeyTest extends TestSuiteBase { -+ -+ private String preExistingDatabaseId = CosmosDatabaseForTest.generateId(); -+ private CosmosClient client; -+ private CosmosDatabase createdDatabase; -+ -+ // Single PK container (/mypk) -+ private CosmosContainer singlePkContainer; -+ -+ // HPK container (/city, /zipcode, /areaCode) -+ private CosmosContainer multiHashContainer; -+ -+ @Factory(dataProvider = "clientBuilders") -+ public ReadManyByPartitionKeyTest(CosmosClientBuilder clientBuilder) { -+ super(clientBuilder); -+ } -+ -+ @BeforeClass(groups = {"emulator"}, timeOut = SETUP_TIMEOUT) -+ public void before_ReadManyByPartitionKeyTest() { -+ client = getClientBuilder().buildClient(); -+ createdDatabase = createSyncDatabase(client, preExistingDatabaseId); -+ -+ // 
Single PK container -+ String singlePkContainerName = UUID.randomUUID().toString(); -+ CosmosContainerProperties singlePkProps = new CosmosContainerProperties(singlePkContainerName, "/mypk"); -+ createdDatabase.createContainer(singlePkProps); -+ singlePkContainer = createdDatabase.getContainer(singlePkContainerName); -+ -+ // HPK container -+ String multiHashContainerName = UUID.randomUUID().toString(); -+ PartitionKeyDefinition hpkDef = new PartitionKeyDefinition(); -+ hpkDef.setKind(PartitionKind.MULTI_HASH); -+ hpkDef.setVersion(PartitionKeyDefinitionVersion.V2); -+ ArrayList paths = new ArrayList<>(); -+ paths.add("/city"); -+ paths.add("/zipcode"); -+ paths.add("/areaCode"); -+ hpkDef.setPaths(paths); -+ -+ CosmosContainerProperties hpkProps = new CosmosContainerProperties(multiHashContainerName, hpkDef); -+ createdDatabase.createContainer(hpkProps); -+ multiHashContainer = createdDatabase.getContainer(multiHashContainerName); -+ } -+ -+ @AfterClass(groups = {"emulator"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) -+ public void afterClass() { -+ safeDeleteSyncDatabase(createdDatabase); -+ safeCloseSyncClient(client); -+ } -+ -+ //region Single PK tests -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_basic() { -+ // Create items with different PKs -+ List items = createSinglePkItems("pk1", 3); -+ items.addAll(createSinglePkItems("pk2", 2)); -+ items.addAll(createSinglePkItems("pk3", 4)); -+ -+ // Read by 2 partition keys -+ List pkValues = Arrays.asList( -+ new PartitionKey("pk1"), -+ new PartitionKey("pk2")); -+ -+ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(5); // 3 + 2 -+ resultList.forEach(item -> { -+ String pk = item.get("mypk").asText(); -+ assertThat(pk).isIn("pk1", "pk2"); -+ }); -+ -+ // Cleanup -+ cleanupContainer(singlePkContainer); -+ 
} -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_withProjection() { -+ List items = createSinglePkItems("pkProj", 2); -+ -+ List pkValues = Collections.singletonList(new PartitionKey("pkProj")); -+ SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.mypk FROM c"); -+ -+ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( -+ pkValues, customQuery, null, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(2); -+ // Should only have id and mypk fields (plus system properties) -+ resultList.forEach(item -> { -+ assertThat(item.has("id")).isTrue(); -+ assertThat(item.has("mypk")).isTrue(); -+ }); -+ -+ cleanupContainer(singlePkContainer); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_withAdditionalFilter() { -+ // Create items with different "status" values -+ createSinglePkItemsWithStatus("pkFilter", "active", 3); -+ createSinglePkItemsWithStatus("pkFilter", "inactive", 2); -+ -+ List pkValues = Collections.singletonList(new PartitionKey("pkFilter")); -+ SqlQuerySpec customQuery = new SqlQuerySpec( -+ "SELECT * FROM c WHERE c.status = @status", -+ Arrays.asList(new SqlParameter("@status", "active"))); -+ -+ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( -+ pkValues, customQuery, null, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(3); -+ resultList.forEach(item -> { -+ assertThat(item.get("status").asText()).isEqualTo("active"); -+ }); -+ -+ cleanupContainer(singlePkContainer); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_emptyResults() { -+ List pkValues = Collections.singletonList(new PartitionKey("nonExistent")); -+ -+ CosmosPagedIterable results = 
singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).isEmpty(); -+ } -+ -+ //endregion -+ -+ //region HPK tests -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void hpk_readManyByPartitionKey_fullPk() { -+ createHpkItems(); -+ -+ // Read by full PKs -+ List pkValues = Arrays.asList( -+ new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build(), -+ new PartitionKeyBuilder().add("Pittsburgh").add("15232").add(2).build()); -+ -+ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ // Redmond/98053/1 has 2 items, Pittsburgh/15232/2 has 1 item -+ assertThat(resultList).hasSize(3); -+ -+ cleanupContainer(multiHashContainer); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void hpk_readManyByPartitionKey_partialPk_singleLevel() { -+ createHpkItems(); -+ -+ // Read by partial PK (only city) -+ List pkValues = Collections.singletonList( -+ new PartitionKeyBuilder().add("Redmond").build()); -+ -+ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ // Redmond has 3 items total (2 with 98053/1 and 1 with 12345/1) -+ assertThat(resultList).hasSize(3); -+ resultList.forEach(item -> { -+ assertThat(item.get("city").asText()).isEqualTo("Redmond"); -+ }); -+ -+ cleanupContainer(multiHashContainer); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void hpk_readManyByPartitionKey_partialPk_twoLevels() { -+ createHpkItems(); -+ -+ // Read by partial PK (city + zipcode) -+ List pkValues = Collections.singletonList( -+ new PartitionKeyBuilder().add("Redmond").add("98053").build()); -+ -+ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, 
ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ // Redmond/98053 has 2 items -+ assertThat(resultList).hasSize(2); -+ resultList.forEach(item -> { -+ assertThat(item.get("city").asText()).isEqualTo("Redmond"); -+ assertThat(item.get("zipcode").asText()).isEqualTo("98053"); -+ }); -+ -+ cleanupContainer(multiHashContainer); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void hpk_readManyByPartitionKey_withProjection() { -+ createHpkItems(); -+ -+ List pkValues = Collections.singletonList( -+ new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build()); -+ -+ SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.city FROM c"); -+ -+ CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey( -+ pkValues, customQuery, null, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(2); -+ -+ cleanupContainer(multiHashContainer); -+ } -+ -+ //endregion -+ -+ //region Negative/validation tests -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void rejectsAggregateQuery() { -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ SqlQuerySpec aggregateQuery = new SqlQuerySpec("SELECT COUNT(1) FROM c"); -+ -+ try { -+ singlePkContainer.readManyByPartitionKey(pkValues, aggregateQuery, null, ObjectNode.class) -+ .stream().collect(Collectors.toList()); -+ fail("Should have thrown IllegalArgumentException for aggregate query"); -+ } catch (IllegalArgumentException e) { -+ assertThat(e.getMessage()).contains("aggregates"); -+ } -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void rejectsOrderByQuery() { -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ SqlQuerySpec orderByQuery = new SqlQuerySpec("SELECT * FROM c ORDER BY c.id"); -+ -+ try { -+ singlePkContainer.readManyByPartitionKey(pkValues, orderByQuery, null, ObjectNode.class) -+ 
.stream().collect(Collectors.toList()); -+ fail("Should have thrown IllegalArgumentException for ORDER BY query"); -+ } catch (IllegalArgumentException e) { -+ assertThat(e.getMessage()).contains("ORDER BY"); -+ } -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void rejectsDistinctQuery() { -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ SqlQuerySpec distinctQuery = new SqlQuerySpec("SELECT DISTINCT c.mypk FROM c"); -+ -+ try { -+ singlePkContainer.readManyByPartitionKey(pkValues, distinctQuery, null, ObjectNode.class) -+ .stream().collect(Collectors.toList()); -+ fail("Should have thrown IllegalArgumentException for DISTINCT query"); -+ } catch (IllegalArgumentException e) { -+ assertThat(e.getMessage()).contains("DISTINCT"); -+ } -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void rejectsGroupByQuery() { -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk FROM c GROUP BY c.mypk"); -+ -+ try { -+ singlePkContainer.readManyByPartitionKey(pkValues, groupByQuery, null, ObjectNode.class) -+ .stream().collect(Collectors.toList()); -+ fail("Should have thrown IllegalArgumentException for GROUP BY query"); -+ } catch (IllegalArgumentException e) { -+ assertThat(e.getMessage()).contains("GROUP BY"); -+ } -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void rejectsGroupByWithAggregateQuery() { -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ SqlQuerySpec groupByWithAggregateQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); -+ -+ try { -+ singlePkContainer.readManyByPartitionKey(pkValues, groupByWithAggregateQuery, null, ObjectNode.class) -+ .stream().collect(Collectors.toList()); -+ fail("Should have thrown IllegalArgumentException for GROUP BY with aggregate query"); -+ } catch (IllegalArgumentException e) { -+ 
assertThat(e.getMessage()).contains("GROUP BY"); -+ } -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) -+ public void rejectsNullPartitionKeyList() { -+ singlePkContainer.readManyByPartitionKey((List) null, ObjectNode.class); -+ } -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) -+ public void rejectsEmptyPartitionKeyList() { -+ singlePkContainer.readManyByPartitionKey(new ArrayList<>(), ObjectNode.class) -+ .stream().collect(Collectors.toList()); -+ } -+ -+ //endregion -+ -+ -+ //region Batch size tests (#10) -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_withSmallBatchSize() { -+ // Temporarily set batch size to 2 to exercise the batching/interleaving logic -+ String originalValue = System.getProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); -+ try { -+ System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", "2"); -+ -+ // Create items across 4 PKs (more than the batch size of 2) -+ List items = createSinglePkItems("batchPk1", 2); -+ items.addAll(createSinglePkItems("batchPk2", 2)); -+ items.addAll(createSinglePkItems("batchPk3", 2)); -+ items.addAll(createSinglePkItems("batchPk4", 2)); -+ -+ // Read all 4 PKs — should be split into batches of 2 -+ List pkValues = Arrays.asList( -+ new PartitionKey("batchPk1"), -+ new PartitionKey("batchPk2"), -+ new PartitionKey("batchPk3"), -+ new PartitionKey("batchPk4")); -+ -+ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(8); // 2 items per PK * 4 PKs -+ resultList.forEach(item -> { -+ String pk = item.get("mypk").asText(); -+ assertThat(pk).isIn("batchPk1", "batchPk2", "batchPk3", "batchPk4"); -+ }); -+ -+ cleanupContainer(singlePkContainer); -+ } finally { -+ if (originalValue != null) {
-+ System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", originalValue); -+ } else { -+ System.clearProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); -+ } -+ } -+ } -+ -+ //endregion -+ -+ //region Custom serializer regression tests (#5) -+ -+ @Test(groups = {"emulator"}, timeOut = TIMEOUT) -+ public void singlePk_readManyByPartitionKey_withRequestOptions() { -+ // This test ensures that request options (like throughput control settings) -+ // are properly propagated through the readManyByPartitionKey path. -+ // It acts as a regression test for the redundant options construction bug. -+ List items = createSinglePkItems("pkOpts", 3); -+ -+ List pkValues = Collections.singletonList(new PartitionKey("pkOpts")); -+ com.azure.cosmos.models.CosmosReadManyRequestOptions options = new com.azure.cosmos.models.CosmosReadManyRequestOptions(); -+ -+ CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( -+ pkValues, options, ObjectNode.class); -+ List resultList = results.stream().collect(Collectors.toList()); -+ -+ assertThat(resultList).hasSize(3); -+ -+ cleanupContainer(singlePkContainer); -+ } -+ -+ //endregion -+ -+ //region helper methods -+ -+ private List createSinglePkItems(String pkValue, int count) { -+ List items = new ArrayList<>(); -+ for (int i = 0; i < count; i++) { -+ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); -+ item.put("id", UUID.randomUUID().toString()); -+ item.put("mypk", pkValue); -+ singlePkContainer.createItem(item); -+ items.add(item); -+ } -+ return items; -+ } -+ -+ private List createSinglePkItemsWithStatus(String pkValue, String status, int count) { -+ List items = new ArrayList<>(); -+ for (int i = 0; i < count; i++) { -+ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); -+ item.put("id", UUID.randomUUID().toString()); -+ item.put("mypk", pkValue); -+ item.put("status", status); -+ 
singlePkContainer.createItem(item); -+ items.add(item); -+ } -+ return items; -+ } -+ -+ private void createHpkItems() { -+ // Same data as CosmosMultiHashTest.createItems() -+ createHpkItem("Redmond", "98053", 1); -+ createHpkItem("Redmond", "98053", 1); -+ createHpkItem("Pittsburgh", "15232", 2); -+ createHpkItem("Stonybrook", "11790", 3); -+ createHpkItem("Stonybrook", "11794", 3); -+ createHpkItem("Stonybrook", "11791", 3); -+ createHpkItem("Redmond", "12345", 1); -+ } -+ -+ private void createHpkItem(String city, String zipcode, int areaCode) { -+ ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); -+ item.put("id", UUID.randomUUID().toString()); -+ item.put("city", city); -+ item.put("zipcode", zipcode); -+ item.put("areaCode", areaCode); -+ multiHashContainer.createItem(item); -+ } -+ -+ private void cleanupContainer(CosmosContainer container) { -+ CosmosPagedIterable allItems = container.queryItems( -+ "SELECT * FROM c", new com.azure.cosmos.models.CosmosQueryRequestOptions(), ObjectNode.class); -+ allItems.forEach(item -> { -+ try { -+ container.deleteItem(item, new CosmosItemRequestOptions()); -+ } catch (CosmosException e) { -+ // ignore cleanup failures -+ } -+ }); -+ } -+ -+ //endregion -+} -diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java -new file mode 100644 -index 00000000000..95c109ba025 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java -@@ -0,0 +1,426 @@ -+/* -+ * Copyright (c) Microsoft Corporation. All rights reserved. -+ * Licensed under the MIT License. 
-+ */ -+ -+package com.azure.cosmos.implementation; -+ -+import com.azure.cosmos.models.PartitionKey; -+import com.azure.cosmos.models.PartitionKeyBuilder; -+import com.azure.cosmos.models.PartitionKeyDefinition; -+import com.azure.cosmos.models.PartitionKeyDefinitionVersion; -+import com.azure.cosmos.models.PartitionKind; -+import com.azure.cosmos.models.SqlParameter; -+import com.azure.cosmos.models.SqlQuerySpec; -+import org.testng.annotations.Test; -+ -+import java.util.ArrayList; -+import java.util.Arrays; -+import java.util.Collections; -+import java.util.List; -+import java.util.stream.Collectors; -+ -+import static org.assertj.core.api.Assertions.assertThat; -+ -+public class ReadManyByPartitionKeyQueryHelperTest { -+ -+ //region Single PK (HASH) tests -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_defaultQuery_singleValue() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); -+ assertThat(result.getQueryText()).contains("IN ("); -+ assertThat(result.getQueryText()).contains("@__rmPk_0"); -+ assertThat(result.getParameters()).hasSize(1); -+ assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("pk1"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_defaultQuery_multipleValues() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Arrays.asList( -+ new PartitionKey("pk1"), -+ new PartitionKey("pk2"), -+ new PartitionKey("pk3")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, 
selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("IN ("); -+ assertThat(result.getQueryText()).contains("@__rmPk_0"); -+ assertThat(result.getQueryText()).contains("@__rmPk_1"); -+ assertThat(result.getQueryText()).contains("@__rmPk_2"); -+ assertThat(result.getParameters()).hasSize(3); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_customQuery_noWhere() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT c.name, c.age FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).startsWith("SELECT c.name, c.age FROM c WHERE"); -+ assertThat(result.getQueryText()).contains("IN ("); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_customQuery_withExistingWhere() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ -+ List baseParams = new ArrayList<>(); -+ baseParams.add(new SqlParameter("@minAge", 18)); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c WHERE c.age > @minAge", baseParams, pkValues, selectors, pkDef); -+ -+ // Should AND the PK filter to the existing WHERE clause -+ assertThat(result.getQueryText()).contains("WHERE (c.age > @minAge) AND ("); -+ assertThat(result.getQueryText()).contains("IN ("); -+ assertThat(result.getParameters()).hasSize(2); // @minAge + @__rmPk_0 -+ assertThat(result.getParameters().get(0).getName()).isEqualTo("@minAge"); -+ } -+ -+ //endregion -+ -+ //region HPK (MULTI_HASH) tests -+ -+ @Test(groups = { "unit" }) -+ public void hpk_fullPk_defaultQuery() { -+ PartitionKeyDefinition pkDef = 
createMultiHashPkDefinition("/city", "/zipcode"); -+ List selectors = createSelectors(pkDef); -+ -+ PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); -+ List pkValues = Collections.singletonList(pk); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); -+ // Should use OR/AND pattern, not IN -+ assertThat(result.getQueryText()).doesNotContain("IN ("); -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); -+ assertThat(result.getQueryText()).contains("AND"); -+ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); -+ assertThat(result.getParameters()).hasSize(2); -+ assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("Redmond"); -+ assertThat(result.getParameters().get(1).getValue(Object.class)).isEqualTo("98052"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void hpk_fullPk_multipleValues() { -+ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); -+ List selectors = createSelectors(pkDef); -+ -+ List pkValues = Arrays.asList( -+ new PartitionKeyBuilder().add("Redmond").add("98052").build(), -+ new PartitionKeyBuilder().add("Seattle").add("98101").build()); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("OR"); -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); -+ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_2"); -+ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_3"); -+ assertThat(result.getParameters()).hasSize(4); -+ } -+ -+ @Test(groups = { "unit" }) -+ public 
void hpk_partialPk_singleLevel() { -+ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); -+ List selectors = createSelectors(pkDef); -+ -+ // Partial PK — only first level -+ PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").build(); -+ List pkValues = Collections.singletonList(partialPk); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); -+ // Should NOT include zipcode or areaCode since it's partial -+ assertThat(result.getQueryText()).doesNotContain("zipcode"); -+ assertThat(result.getQueryText()).doesNotContain("areaCode"); -+ assertThat(result.getParameters()).hasSize(1); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void hpk_partialPk_twoLevels() { -+ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); -+ List selectors = createSelectors(pkDef); -+ -+ // Partial PK — first two levels -+ PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); -+ List pkValues = Collections.singletonList(partialPk); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); -+ assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); -+ assertThat(result.getQueryText()).doesNotContain("areaCode"); -+ assertThat(result.getParameters()).hasSize(2); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void hpk_customQuery_withWhere() { -+ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); -+ List selectors = createSelectors(pkDef); -+ -+ List baseParams = new ArrayList<>(); -+ baseParams.add(new SqlParameter("@status", "active")); -+ -+
PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); -+ List pkValues = Collections.singletonList(pk); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT c.name FROM c WHERE c.status = @status", baseParams, pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("WHERE (c.status = @status) AND ("); -+ assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); -+ assertThat(result.getParameters()).hasSize(3); // @status + 2 pk params -+ } -+ -+ //endregion -+ -+ //region findTopLevelWhereIndex tests -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_simpleQuery() { -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.ID = 1"); -+ assertThat(idx).isEqualTo(16); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_noWhere() { -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C"); -+ assertThat(idx).isEqualTo(-1); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_whereInSubquery() { -+ // WHERE inside parentheses (subquery) should be ignored -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( -+ "SELECT * FROM C WHERE EXISTS(SELECT VALUE T FROM T IN C.TAGS WHERE T = 'FOO')"); -+ // Should find the outer WHERE, not the inner one -+ assertThat(idx).isEqualTo(16); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_caseInsensitive() { -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.X = 1"); -+ assertThat(idx).isGreaterThan(0); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_whereNotKeyword() { -+ // "ELSEWHERE" should not match -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM ELSEWHERE"); -+ assertThat(idx).isEqualTo(-1); -+ } -+ -+ //endregion -+ -+ //region Custom alias tests -+ -+ @Test(groups = { "unit" }) -+ public void 
singlePk_customAlias() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT x.id, x.mypk FROM x", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).startsWith("SELECT x.id, x.mypk FROM x WHERE"); -+ assertThat(result.getQueryText()).contains("x[\"mypk\"] IN ("); -+ assertThat(result.getQueryText()).doesNotContain("c[\"mypk\"]"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_customAlias_withWhere() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ -+ List baseParams = new ArrayList<>(); -+ baseParams.add(new SqlParameter("@cat", "HelloWorld")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT x.id, x.mypk FROM x WHERE x.category = @cat", baseParams, pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("WHERE (x.category = @cat) AND (x[\"mypk\"] IN ("); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void hpk_customAlias() { -+ PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); -+ List selectors = createSelectors(pkDef); -+ -+ PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); -+ List pkValues = Collections.singletonList(pk); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT r.name FROM root r", new ArrayList<>(), pkValues, selectors, pkDef); -+ -+ assertThat(result.getQueryText()).contains("r[\"city\"] = @__rmPk_0"); -+ assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @__rmPk_1"); -+ 
assertThat(result.getQueryText()).doesNotContain("c[\""); -+ } -+ -+ //endregion -+ -+ //region extractTableAlias tests -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_defaultC() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM c")).isEqualTo("c"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_customX() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT x.id FROM x WHERE x.age > 5")).isEqualTo("x"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_rootWithAlias() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT r.name FROM root r")).isEqualTo("r"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_rootNoAlias() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM root")).isEqualTo("root"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_containerWithWhere() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM items WHERE items.status = 'active'")).isEqualTo("items"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_caseInsensitive() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("select * from MyContainer where MyContainer.id = '1'")).isEqualTo("MyContainer"); -+ } -+ -+ //endregion -+ -+ -+ //region String literal handling tests (#1) -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_ignoresWhereInsideStringLiteral() { -+ // WHERE inside a string literal should be ignored -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( -+ "SELECT * FROM c WHERE c.msg = 'use WHERE clause here'"); -+ // Should find the outer WHERE at position 16, not the one inside the string -+ assertThat(idx).isEqualTo(16); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_ignoresParenthesesInsideStringLiteral() { -+ // Parentheses inside string literal should not affect depth tracking -+ int 
idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( -+ "SELECT * FROM c WHERE c.name = 'foo(bar)' AND c.x = 1"); -+ assertThat(idx).isEqualTo(16); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_handlesUnbalancedParenInStringLiteral() { -+ // Unbalanced paren inside string literal must not corrupt depth -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( -+ "SELECT * FROM c WHERE c.val = 'open(' AND c.active = true"); -+ assertThat(idx).isEqualTo(16); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void findWhere_handlesStringLiteralBeforeWhere() { -+ // String literal in SELECT before WHERE -+ int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( -+ "SELECT 'WHERE' as label FROM c WHERE c.id = '1'"); -+ // The WHERE inside quotes should be ignored; the real WHERE is further along -+ assertThat(idx).isGreaterThan(30); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void singlePk_customQuery_withStringLiteralContainingParens() { -+ PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); -+ List selectors = createSelectors(pkDef); -+ List pkValues = Collections.singletonList(new PartitionKey("pk1")); -+ -+ List baseParams = new ArrayList<>(); -+ baseParams.add(new SqlParameter("@msg", "hello")); -+ -+ SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( -+ "SELECT * FROM c WHERE c.msg = 'test(value)WHERE'", baseParams, pkValues, selectors, pkDef); -+ -+ // Should correctly AND the PK filter to the real WHERE clause -+ assertThat(result.getQueryText()).contains("WHERE (c.msg = 'test(value)WHERE') AND ("); -+ } -+ -+ //endregion -+ -+ //region OFFSET/LIMIT/HAVING alias detection tests (#9) -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_containerWithOffset() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( -+ "SELECT * FROM c OFFSET 10 LIMIT 5")).isEqualTo("c"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void 
extractAlias_containerWithLimit() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( -+ "SELECT * FROM c LIMIT 10")).isEqualTo("c"); -+ } -+ -+ @Test(groups = { "unit" }) -+ public void extractAlias_containerWithHaving() { -+ assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( -+ "SELECT c.status, COUNT(1) FROM c GROUP BY c.status HAVING COUNT(1) > 1")).isEqualTo("c"); -+ } -+ -+ //endregion -+ -+ //region helpers -+ -+ private PartitionKeyDefinition createSinglePkDefinition(String path) { -+ PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); -+ pkDef.setKind(PartitionKind.HASH); -+ pkDef.setVersion(PartitionKeyDefinitionVersion.V2); -+ pkDef.setPaths(Collections.singletonList(path)); -+ return pkDef; -+ } -+ -+ private PartitionKeyDefinition createMultiHashPkDefinition(String... paths) { -+ PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); -+ pkDef.setKind(PartitionKind.MULTI_HASH); -+ pkDef.setVersion(PartitionKeyDefinitionVersion.V2); -+ pkDef.setPaths(Arrays.asList(paths)); -+ return pkDef; -+ } -+ -+ private List createSelectors(PartitionKeyDefinition pkDef) { -+ return pkDef.getPaths() -+ .stream() -+ .map(pathPart -> pathPart.substring(1)) // skip starting / -+ .map(pathPart -> pathPart.replace("\"", "\\")) -+ .map(part -> "[\"" + part + "\"]") -+ .collect(Collectors.toList()); -+ } -+ -+ //endregion -+} -diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md -index e8ea564fab7..904c01c3238 100644 ---- a/sdk/cosmos/azure-cosmos/CHANGELOG.md -+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md -@@ -3,6 +3,7 @@ - ### 4.80.0-beta.1 (Unreleased) - - #### Features Added -+* Added new `readManyByPartitionKey` to bulk query by a list of pk-values with better efficiency. 
See [PR 48801](https://github.com/Azure/azure-sdk-for-java/pull/48801) - - #### Breaking Changes - -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java -index ad871bb97c0..4e234667c1c 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java -@@ -165,6 +165,7 @@ public class CosmosAsyncContainer { - private final String createItemSpanName; - private final String readAllItemsSpanName; - private final String readManyItemsSpanName; -+ private final String readManyByPartitionKeyItemsSpanName; - private final String readAllItemsOfLogicalPartitionSpanName; - private final String queryItemsSpanName; - private final String queryChangeFeedSpanName; -@@ -198,6 +199,7 @@ public class CosmosAsyncContainer { - this.createItemSpanName = "createItem." + this.id; - this.readAllItemsSpanName = "readAllItems." + this.id; - this.readManyItemsSpanName = "readManyItems." + this.id; -+ this.readManyByPartitionKeyItemsSpanName = "readManyByPartitionKeyItems." + this.id; - this.readAllItemsOfLogicalPartitionSpanName = "readAllItemsOfLogicalPartition." + this.id; - this.queryItemsSpanName = "queryItems." + this.id; - this.queryChangeFeedSpanName = "queryChangeFeed." + this.id; -@@ -1601,6 +1603,130 @@ public class CosmosAsyncContainer { - context); - } - -+ /** -+ * Reads many documents matching the provided partition key values. -+ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries -+ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} -+ * as the base query. 
-+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param classType class type -+ * @return a {@link CosmosPagedFlux} containing one or several feed response pages -+ */ -+ public CosmosPagedFlux readManyByPartitionKey( -+ List partitionKeys, -+ Class classType) { -+ -+ return this.readManyByPartitionKey(partitionKeys, null, null, classType); -+ } -+ -+ /** -+ * Reads many documents matching the provided partition key values. -+ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries -+ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} -+ * as the base query. -+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param requestOptions the optional request options -+ * @param classType class type -+ * @return a {@link CosmosPagedFlux} containing one or several feed response pages -+ */ -+ public CosmosPagedFlux readManyByPartitionKey( -+ List partitionKeys, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) { -+ -+ return this.readManyByPartitionKey(partitionKeys, null, requestOptions, classType); -+ } -+ -+ /** -+ * Reads many documents matching the provided partition key values with a custom query. -+ * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) -+ * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). -+ * The SDK will automatically append partition key filtering to the custom query. -+ *

-+ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, -+ * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be -+ * rejected. -+ *

-+ * Partial hierarchical partition keys are supported and will fan out to multiple -+ * physical partitions. -+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) -+ * @param requestOptions the optional request options -+ * @param classType class type -+ * @return a {@link CosmosPagedFlux} containing one or several feed response pages -+ */ -+ public CosmosPagedFlux readManyByPartitionKey( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) { -+ -+ if (partitionKeys == null) { -+ throw new IllegalArgumentException("Argument 'partitionKeys' must not be null."); -+ } -+ if (partitionKeys.isEmpty()) { -+ throw new IllegalArgumentException("Argument 'partitionKeys' must not be empty."); -+ } -+ for (PartitionKey pk : partitionKeys) { -+ if (pk == null) { -+ throw new IllegalArgumentException( -+ "Argument 'partitionKeys' must not contain null elements."); -+ } -+ } -+ -+ return UtilBridgeInternal.createCosmosPagedFlux( -+ readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); -+ } -+ -+ private Function>> readManyByPartitionKeyInternalFunc( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) { -+ -+ CosmosAsyncClient client = this.getDatabase().getClient(); -+ -+ return (pagedFluxOptions -> { -+ CosmosQueryRequestOptions queryRequestOptions = requestOptions == null -+ ? 
new CosmosQueryRequestOptions() -+ : queryOptionsAccessor().clone(readManyOptionsAccessor().getImpl(requestOptions)); -+ queryRequestOptions.setMaxDegreeOfParallelism(-1); -+ queryRequestOptions.setQueryName("readManyByPartitionKey"); -+ CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); -+ applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeyItemsSpanName); -+ -+ QueryFeedOperationState state = new QueryFeedOperationState( -+ client, -+ this.readManyByPartitionKeyItemsSpanName, -+ database.getId(), -+ this.getId(), -+ ResourceType.Document, -+ OperationType.Query, -+ queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeyItemsSpanName), -+ queryRequestOptions, -+ pagedFluxOptions -+ ); -+ -+ pagedFluxOptions.setFeedOperationState(state); -+ -+ return CosmosBridgeInternal -+ .getAsyncDocumentClient(this.getDatabase()) -+ .readManyByPartitionKey( -+ partitionKeys, -+ customQuery, -+ BridgeInternal.getLink(this), -+ state, -+ classType) -+ .map(response -> prepareFeedResponse(response, false)); -+ }); -+ } -+ - /** - * Reads all the items of a logical partition - * -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java -index 04a6060c192..0bd8be5850c 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java -@@ -540,6 +540,73 @@ public class CosmosContainer { - classType)); - } - -+ /** -+ * Reads many documents matching the provided partition key values. -+ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries -+ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} -+ * as the base query. 
-+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param classType class type -+ * @return a {@link CosmosPagedIterable} containing the results -+ */ -+ public CosmosPagedIterable readManyByPartitionKey( -+ List partitionKeys, -+ Class classType) { -+ -+ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, classType)); -+ } -+ -+ /** -+ * Reads many documents matching the provided partition key values. -+ * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries -+ * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} -+ * as the base query. -+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param requestOptions the optional request options -+ * @param classType class type -+ * @return a {@link CosmosPagedIterable} containing the results -+ */ -+ public CosmosPagedIterable readManyByPartitionKey( -+ List partitionKeys, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) { -+ -+ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, requestOptions, classType)); -+ } -+ -+ /** -+ * Reads many documents matching the provided partition key values with a custom query. -+ * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) -+ * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). -+ * The SDK will automatically append partition key filtering to the custom query. -+ *

-+ * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, -+ * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be -+ * rejected. -+ *

-+ * Partial hierarchical partition keys are supported and will fan out to multiple -+ * physical partitions. -+ * -+ * @param the type parameter -+ * @param partitionKeys list of partition key values to read documents for -+ * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) -+ * @param requestOptions the optional request options -+ * @param classType class type -+ * @return a {@link CosmosPagedIterable} containing the results -+ */ -+ public CosmosPagedIterable readManyByPartitionKey( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) { -+ -+ return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, requestOptions, classType)); -+ } -+ - /** - * Reads all the items of a logical partition returning the results as {@link CosmosPagedIterable}. - * -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java -index 945e768a82f..8e2499c9039 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java -@@ -1584,6 +1584,27 @@ public interface AsyncDocumentClient { - QueryFeedOperationState state, - Class klass); - -+ /** -+ * Reads many documents by partition key values. -+ * Unlike {@link #readMany(List, String, QueryFeedOperationState, Class)} this method does not require -+ * item ids - it queries all documents matching the provided partition key values. -+ * Partial hierarchical partition keys are supported and will fan out to multiple physical partitions. 
-+ * -+ * @param partitionKeys list of partition key values to read documents for -+ * @param customQuery optional custom query (for projections/additional filters) - null means SELECT * FROM c -+ * @param collectionLink link for the documentcollection/container to be queried -+ * @param state the query operation state -+ * @param klass class type -+ * @param the type parameter -+ * @return a Flux with feed response pages of documents -+ */ -+ Flux> readManyByPartitionKey( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ String collectionLink, -+ QueryFeedOperationState state, -+ Class klass); -+ - /** - * Read all documents of a certain logical partition. - *

-diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java -index 337055c6947..162b0740f40 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java -@@ -248,6 +248,11 @@ public class Configs { - public static final String MIN_TARGET_BULK_MICRO_BATCH_SIZE_VARIABLE = "COSMOS_MIN_TARGET_BULK_MICRO_BATCH_SIZE"; - public static final int DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE = 1; - -+ // readManyByPartitionKey: max number of PK values per query per physical partition -+ private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; -+ private static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; -+ private static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 1000; -+ - public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY = "COSMOS.MAX_BULK_MICRO_BATCH_CONCURRENCY"; - public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY_VARIABLE = "COSMOS_MAX_BULK_MICRO_BATCH_CONCURRENCY"; - public static final int DEFAULT_MAX_BULK_MICRO_BATCH_CONCURRENCY = 1; -@@ -816,6 +821,20 @@ public class Configs { - return DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE; - } - -+ public static int getReadManyByPkMaxBatchSize() { -+ String valueFromSystemProperty = System.getProperty(READ_MANY_BY_PK_MAX_BATCH_SIZE); -+ if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { -+ return Math.max(1, Integer.parseInt(valueFromSystemProperty)); -+ } -+ -+ String valueFromEnvVariable = System.getenv(READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE); -+ if (valueFromEnvVariable != null && !valueFromEnvVariable.isEmpty()) { -+ return Math.max(1, Integer.parseInt(valueFromEnvVariable)); -+ } -+ -+ return DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE; -+ 
} -+ - public static int getMaxBulkMicroBatchConcurrency() { - String valueFromSystemProperty = System.getProperty(MAX_BULK_MICRO_BATCH_CONCURRENCY); - if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java -new file mode 100644 -index 00000000000..6d6cd084e01 ---- /dev/null -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java -@@ -0,0 +1,263 @@ -+// Copyright (c) Microsoft Corporation. All rights reserved. -+// Licensed under the MIT License. -+package com.azure.cosmos.implementation; -+ -+import com.azure.cosmos.BridgeInternal; -+import com.azure.cosmos.implementation.routing.PartitionKeyInternal; -+import com.azure.cosmos.models.PartitionKey; -+import com.azure.cosmos.models.PartitionKeyDefinition; -+import com.azure.cosmos.models.PartitionKind; -+import com.azure.cosmos.models.SqlParameter; -+import com.azure.cosmos.models.SqlQuerySpec; -+ -+import java.util.ArrayList; -+import java.util.List; -+ -+/** -+ * Helper for constructing SqlQuerySpec instances for readManyByPartitionKey operations. -+ * This class is not intended to be used directly by end-users. -+ */ -+public class ReadManyByPartitionKeyQueryHelper { -+ -+ private static final String DEFAULT_TABLE_ALIAS = "c"; -+ // Internal parameter prefix ΓÇö uses double-underscore to avoid collisions with user-provided parameters -+ private static final String PK_PARAM_PREFIX = "@__rmPk_"; -+ -+ public static SqlQuerySpec createReadManyByPkQuerySpec( -+ String baseQueryText, -+ List baseParameters, -+ List pkValues, -+ List partitionKeySelectors, -+ PartitionKeyDefinition pkDefinition) { -+ -+ // Extract the table alias from the FROM clause (e.g. 
"FROM x" ΓåÆ "x", "FROM c" ΓåÆ "c") -+ String tableAlias = extractTableAlias(baseQueryText); -+ -+ StringBuilder pkFilter = new StringBuilder(); -+ List parameters = new ArrayList<>(baseParameters); -+ int paramCount = 0; -+ -+ boolean isSinglePathPk = partitionKeySelectors.size() == 1; -+ -+ if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { -+ // Single PK path ΓÇö use IN clause for normal values, OR NOT IS_DEFINED for NONE -+ // First, separate NONE PKs from normal PKs -+ boolean hasNone = false; -+ List normalPkValues = new ArrayList<>(); -+ for (PartitionKey pk : pkValues) { -+ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); -+ if (pkInternal.getComponents() == null) { -+ hasNone = true; -+ } else { -+ normalPkValues.add(pk); -+ } -+ } -+ -+ pkFilter.append(" "); -+ boolean hasNormalValues = !normalPkValues.isEmpty(); -+ if (hasNormalValues && hasNone) { -+ pkFilter.append("("); -+ } -+ if (hasNormalValues) { -+ pkFilter.append(tableAlias); -+ pkFilter.append(partitionKeySelectors.get(0)); -+ pkFilter.append(" IN ( "); -+ for (int i = 0; i < normalPkValues.size(); i++) { -+ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(normalPkValues.get(i)); -+ Object[] pkComponents = pkInternal.toObjectArray(); -+ String pkParamName = PK_PARAM_PREFIX + paramCount; -+ parameters.add(new SqlParameter(pkParamName, pkComponents[0])); -+ paramCount++; -+ -+ pkFilter.append(pkParamName); -+ if (i < normalPkValues.size() - 1) { -+ pkFilter.append(", "); -+ } -+ } -+ pkFilter.append(" )"); -+ } -+ if (hasNone) { -+ if (hasNormalValues) { -+ pkFilter.append(" OR "); -+ } -+ pkFilter.append("NOT IS_DEFINED("); -+ pkFilter.append(tableAlias); -+ pkFilter.append(partitionKeySelectors.get(0)); -+ pkFilter.append(")"); -+ } -+ if (hasNormalValues && hasNone) { -+ pkFilter.append(")"); -+ } -+ } else { -+ // Multiple PK paths (HPK) or MULTI_HASH ΓÇö use OR of AND clauses -+ pkFilter.append(" "); -+ for (int i 
= 0; i < pkValues.size(); i++) { -+ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); -+ Object[] pkComponents = pkInternal.toObjectArray(); -+ -+ // PartitionKey.NONE ΓÇö generate NOT IS_DEFINED for all PK paths -+ if (pkComponents == null) { -+ pkFilter.append("("); -+ for (int j = 0; j < partitionKeySelectors.size(); j++) { -+ if (j > 0) { -+ pkFilter.append(" AND "); -+ } -+ pkFilter.append("NOT IS_DEFINED("); -+ pkFilter.append(tableAlias); -+ pkFilter.append(partitionKeySelectors.get(j)); -+ pkFilter.append(")"); -+ } -+ pkFilter.append(")"); -+ } else { -+ pkFilter.append("("); -+ for (int j = 0; j < pkComponents.length; j++) { -+ String pkParamName = PK_PARAM_PREFIX + paramCount; -+ parameters.add(new SqlParameter(pkParamName, pkComponents[j])); -+ paramCount++; -+ -+ if (j > 0) { -+ pkFilter.append(" AND "); -+ } -+ pkFilter.append(tableAlias); -+ pkFilter.append(partitionKeySelectors.get(j)); -+ pkFilter.append(" = "); -+ pkFilter.append(pkParamName); -+ } -+ pkFilter.append(")"); -+ } -+ -+ if (i < pkValues.size() - 1) { -+ pkFilter.append(" OR "); -+ } -+ } -+ } -+ -+ // Compose final query: handle existing WHERE clause in base query -+ String finalQuery; -+ int whereIndex = findTopLevelWhereIndex(baseQueryText); -+ if (whereIndex >= 0) { -+ // Base query has WHERE ΓÇö AND our PK filter -+ String beforeWhere = baseQueryText.substring(0, whereIndex); -+ String afterWhere = baseQueryText.substring(whereIndex + 5); // skip "WHERE" -+ finalQuery = beforeWhere + "WHERE (" + afterWhere.trim() + ") AND (" + pkFilter.toString().trim() + ")"; -+ } else { -+ // No WHERE ΓÇö add one -+ finalQuery = baseQueryText + " WHERE" + pkFilter.toString(); -+ } -+ -+ return new SqlQuerySpec(finalQuery, parameters); -+ } -+ -+ /** -+ * Extracts the table/collection alias from a SQL query's FROM clause. -+ * Handles: "SELECT * FROM c", "SELECT x.id FROM x WHERE ...", "SELECT * FROM root r", etc. 
-+ * Returns the alias used after FROM (last token before WHERE or end of FROM clause). -+ */ -+ static String extractTableAlias(String queryText) { -+ String upper = queryText.toUpperCase(); -+ int fromIndex = findTopLevelKeywordIndex(upper, "FROM"); -+ if (fromIndex < 0) { -+ return DEFAULT_TABLE_ALIAS; -+ } -+ -+ // Start scanning after "FROM" -+ int afterFrom = fromIndex + 4; -+ // Skip whitespace -+ while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { -+ afterFrom++; -+ } -+ -+ // Collect the container name token (could be "root", "c", etc.) -+ int tokenStart = afterFrom; -+ while (afterFrom < queryText.length() -+ && !Character.isWhitespace(queryText.charAt(afterFrom)) -+ && queryText.charAt(afterFrom) != '(' -+ && queryText.charAt(afterFrom) != ')') { -+ afterFrom++; -+ } -+ String containerName = queryText.substring(tokenStart, afterFrom); -+ -+ // Skip whitespace after container name -+ while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { -+ afterFrom++; -+ } -+ -+ // Check if there's an alias after the container name (before WHERE or end) -+ if (afterFrom < queryText.length()) { -+ char nextChar = Character.toUpperCase(queryText.charAt(afterFrom)); -+ // If the next token is a keyword (WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING) or end, containerName IS the alias -+ if (nextChar == 'W' || nextChar == 'O' || nextChar == 'G' || nextChar == 'J' -+ || nextChar == 'L' || nextChar == 'H') { -+ // Check if it's actually a keyword -+ String remaining = upper.substring(afterFrom); -+ if (remaining.startsWith("WHERE") || remaining.startsWith("ORDER") -+ || remaining.startsWith("GROUP") || remaining.startsWith("JOIN") -+ || remaining.startsWith("OFFSET") || remaining.startsWith("LIMIT") -+ || remaining.startsWith("HAVING")) { -+ return containerName; -+ } -+ } -+ // Otherwise the next token is the alias ("FROM root r" ΓåÆ alias is "r") -+ int aliasStart = afterFrom; -+ while 
(afterFrom < queryText.length() -+ && !Character.isWhitespace(queryText.charAt(afterFrom)) -+ && queryText.charAt(afterFrom) != '(' -+ && queryText.charAt(afterFrom) != ')') { -+ afterFrom++; -+ } -+ if (afterFrom > aliasStart) { -+ return queryText.substring(aliasStart, afterFrom); -+ } -+ } -+ -+ return containerName; -+ } -+ -+ /** -+ * Finds the index of a top-level SQL keyword in the query text (case-insensitive), -+ * ignoring occurrences inside parentheses or string literals. -+ */ -+ static int findTopLevelKeywordIndex(String queryText, String keyword) { -+ String queryTextUpper = queryText.toUpperCase(); -+ String keywordUpper = keyword.toUpperCase(); -+ int depth = 0; -+ int keyLen = keywordUpper.length(); -+ for (int i = 0; i <= queryTextUpper.length() - keyLen; i++) { -+ char ch = queryTextUpper.charAt(i); -+ // Skip string literals enclosed in single quotes (handle '' escape) -+ if (queryText.charAt(i) == '\'') { -+ i++; -+ while (i < queryText.length()) { -+ if (queryText.charAt(i) == '\'') { -+ if (i + 1 < queryText.length() && queryText.charAt(i + 1) == '\'') { -+ i += 2; // escaped quote ΓÇö skip both -+ continue; -+ } -+ break; // end of string literal -+ } -+ i++; -+ } -+ continue; -+ } -+ if (ch == '(') { -+ depth++; -+ } else if (ch == ')') { -+ depth--; -+ } else if (depth == 0 && ch == keywordUpper.charAt(0) -+ && queryTextUpper.startsWith(keywordUpper, i) -+ && (i == 0 || !Character.isLetterOrDigit(queryTextUpper.charAt(i - 1))) -+ && (i + keyLen >= queryTextUpper.length() || !Character.isLetterOrDigit(queryTextUpper.charAt(i + keyLen)))) { -+ return i; -+ } -+ } -+ return -1; -+ } -+ -+ /** -+ * Finds the index of the top-level WHERE keyword in the query text, -+ * ignoring WHERE that appears inside parentheses (subqueries). 
-+ */ -+ public static int findTopLevelWhereIndex(String queryTextUpper) { -+ return findTopLevelKeywordIndex(queryTextUpper, "WHERE"); -+ } -+} -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java -index 11121bca033..c70dedaa1f2 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java -@@ -4365,6 +4365,298 @@ public class RxDocumentClientImpl implements AsyncDocumentClient, IAuthorization - ); - } - -+ @Override -+ public Flux> readManyByPartitionKey( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ String collectionLink, -+ QueryFeedOperationState state, -+ Class klass) { -+ -+ checkNotNull(partitionKeys, "Argument 'partitionKeys' must not be null."); -+ checkArgument(!partitionKeys.isEmpty(), "Argument 'partitionKeys' must not be empty."); -+ -+ final ScopedDiagnosticsFactory diagnosticsFactory = new ScopedDiagnosticsFactory(this, true); -+ state.registerDiagnosticsFactory( -+ () -> {}, // we never want to reset in readManyByPartitionKey -+ (ctx) -> diagnosticsFactory.merge(ctx) -+ ); -+ -+ StaleResourceRetryPolicy staleResourceRetryPolicy = new StaleResourceRetryPolicy( -+ this.collectionCache, -+ null, -+ collectionLink, -+ queryOptionsAccessor().getProperties(state.getQueryOptions()), -+ queryOptionsAccessor().getHeaders(state.getQueryOptions()), -+ this.sessionContainer, -+ diagnosticsFactory, -+ ResourceType.Document -+ ); -+ -+ return ObservableHelper -+ .fluxInlineIfPossibleAsObs( -+ () -> readManyByPartitionKey( -+ partitionKeys, customQuery, collectionLink, state, diagnosticsFactory, klass), -+ staleResourceRetryPolicy -+ ) -+ .onErrorMap(throwable -> { -+ if (throwable instanceof CosmosException) { -+ CosmosException cosmosException = 
(CosmosException) throwable; -+ CosmosDiagnostics diagnostics = cosmosException.getDiagnostics(); -+ if (diagnostics != null) { -+ state.mergeDiagnosticsContext(); -+ CosmosDiagnosticsContext ctx = state.getDiagnosticsContextSnapshot(); -+ if (ctx != null) { -+ ctxAccessor().recordOperation( -+ ctx, -+ cosmosException.getStatusCode(), -+ cosmosException.getSubStatusCode(), -+ 0, -+ cosmosException.getRequestCharge(), -+ diagnostics, -+ throwable -+ ); -+ diagAccessor() -+ .setDiagnosticsContext( -+ diagnostics, -+ state.getDiagnosticsContextSnapshot()); -+ } -+ } -+ -+ return cosmosException; -+ } -+ -+ return throwable; -+ }); -+ } -+ -+ private Flux> readManyByPartitionKey( -+ List partitionKeys, -+ SqlQuerySpec customQuery, -+ String collectionLink, -+ QueryFeedOperationState state, -+ ScopedDiagnosticsFactory diagnosticsFactory, -+ Class klass) { -+ -+ String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); -+ RxDocumentServiceRequest request = RxDocumentServiceRequest.create(diagnosticsFactory, -+ OperationType.Query, -+ ResourceType.Document, -+ collectionLink, null -+ ); -+ -+ Mono> collectionObs = -+ collectionCache.resolveCollectionAsync(null, request); -+ -+ return collectionObs -+ .flatMapMany(documentCollectionResourceResponse -> { -+ final DocumentCollection collection = documentCollectionResourceResponse.v; -+ if (collection == null) { -+ return Flux.error(new IllegalStateException("Collection cannot be null")); -+ } -+ -+ final PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); -+ -+ Mono> valueHolderMono = partitionKeyRangeCache -+ .tryLookupAsync( -+ BridgeInternal.getMetaDataDiagnosticContext(request.requestContext.cosmosDiagnostics), -+ collection.getResourceId(), -+ null, -+ null); -+ -+ // Validate custom query if provided -+ Mono queryValidationMono; -+ if (customQuery != null) { -+ queryValidationMono = validateCustomQueryForReadManyByPartitionKey( -+ customQuery, resourceLink, 
state.getQueryOptions()); -+ } else { -+ queryValidationMono = Mono.empty(); -+ } -+ -+ return valueHolderMono -+ .delayUntil(ignored -> queryValidationMono) -+ .flatMapMany(routingMapHolder -> { -+ CollectionRoutingMap routingMap = routingMapHolder.v; -+ if (routingMap == null) { -+ return Flux.error(new IllegalStateException("Failed to get routing map.")); -+ } -+ -+ Map> partitionRangePkMap = -+ groupPartitionKeysByPhysicalPartition(partitionKeys, pkDefinition, routingMap); -+ -+ List partitionKeySelectors = createPkSelectors(pkDefinition); -+ -+ String baseQueryText; -+ List baseParameters; -+ if (customQuery != null) { -+ baseQueryText = customQuery.getQueryText(); -+ baseParameters = customQuery.getParameters() != null -+ ? new ArrayList<>(customQuery.getParameters()) -+ : new ArrayList<>(); -+ } else { -+ baseQueryText = "SELECT * FROM c"; -+ baseParameters = new ArrayList<>(); -+ } -+ -+ // Build per-physical-partition batched queries. -+ // Each physical partition may have many PKs ΓÇö split into batches -+ // to avoid oversized SQL queries. Batch size is configurable via -+ // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 1000). -+ int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); -+ -+ // Build batches per partition as a list of lists (one inner list per partition). -+ // Then interleave in round-robin order so that concurrent execution -+ // prefers different physical partitions over multiple batches of the same partition. 
-+ List>> batchesPerPartition = new ArrayList<>(); -+ int maxBatchesPerPartition = 0; -+ -+ for (Map.Entry> entry : partitionRangePkMap.entrySet()) { -+ List allPks = entry.getValue(); -+ if (allPks.isEmpty()) { -+ continue; -+ } -+ List> partitionBatches = new ArrayList<>(); -+ for (int i = 0; i < allPks.size(); i += maxPksPerPartitionQuery) { -+ List batch = allPks.subList( -+ i, Math.min(i + maxPksPerPartitionQuery, allPks.size())); -+ SqlQuerySpec querySpec = ReadManyByPartitionKeyQueryHelper -+ .createReadManyByPkQuerySpec( -+ baseQueryText, baseParameters, batch, -+ partitionKeySelectors, pkDefinition); -+ partitionBatches.add( -+ Collections.singletonMap(entry.getKey(), querySpec)); -+ } -+ batchesPerPartition.add(partitionBatches); -+ maxBatchesPerPartition = Math.max(maxBatchesPerPartition, partitionBatches.size()); -+ } -+ -+ if (batchesPerPartition.isEmpty()) { -+ return Flux.empty(); -+ } -+ -+ // Round-robin interleave: [batch0-p1, batch0-p2, ..., batch0-pN, batch1-p1, batch1-p2, ...] -+ // This ensures that with bounded concurrency, different partitions are -+ // preferred over sequential batches of the same partition. -+ List> interleavedBatches = new ArrayList<>(); -+ for (int batchIdx = 0; batchIdx < maxBatchesPerPartition; batchIdx++) { -+ for (List> partitionBatches : batchesPerPartition) { -+ if (batchIdx < partitionBatches.size()) { -+ interleavedBatches.add(partitionBatches.get(batchIdx)); -+ } -+ } -+ } -+ -+ // Execute all batches with bounded concurrency. 
-+ List>> queryFluxes = interleavedBatches -+ .stream() -+ .map(batchMap -> queryForReadMany( -+ diagnosticsFactory, -+ resourceLink, -+ new SqlQuerySpec(DUMMY_SQL_QUERY), -+ state.getQueryOptions(), -+ klass, -+ ResourceType.Document, -+ collection, -+ Collections.unmodifiableMap(batchMap))) -+ .collect(Collectors.toList()); -+ -+ int fluxConcurrency = Math.min(queryFluxes.size(), -+ Math.max(Configs.getCPUCnt(), 1)); -+ -+ return Flux.merge(Flux.fromIterable(queryFluxes), fluxConcurrency, 1); -+ }); -+ }); -+ } -+ -+ private Mono validateCustomQueryForReadManyByPartitionKey( -+ SqlQuerySpec customQuery, -+ String resourceLink, -+ CosmosQueryRequestOptions queryRequestOptions) { -+ -+ IDocumentQueryClient queryClient = documentQueryClientImpl( -+ RxDocumentClientImpl.this, getOperationContextAndListenerTuple(queryRequestOptions)); -+ -+ return DocumentQueryExecutionContextFactory -+ .fetchQueryPlanForValidation(this, queryClient, customQuery, resourceLink, queryRequestOptions) -+ .flatMap(queryPlan -> { -+ QueryInfo queryInfo = queryPlan.getQueryInfo(); -+ -+ if (queryInfo.hasGroupBy()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain GROUP BY.")); -+ } -+ if (queryInfo.hasAggregates()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain aggregates.")); -+ } -+ if (queryInfo.hasOrderBy()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain ORDER BY.")); -+ } -+ if (queryInfo.hasDistinct()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain DISTINCT.")); -+ } -+ if (queryInfo.hasDCount()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain DCOUNT.")); -+ } -+ if (queryInfo.hasNonStreamingOrderBy()) { -+ return Mono.error(new 
IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain non-streaming ORDER BY.")); -+ } -+ if (queryPlan.hasHybridSearchQueryInfo()) { -+ return Mono.error(new IllegalArgumentException( -+ "Custom query for readMany by partition key must not contain hybrid/vector/full-text search.")); -+ } -+ -+ return Mono.empty(); -+ }); -+ } -+ -+ private Map> groupPartitionKeysByPhysicalPartition( -+ List partitionKeys, -+ PartitionKeyDefinition pkDefinition, -+ CollectionRoutingMap routingMap) { -+ -+ Map> partitionRangePkMap = new HashMap<>(); -+ -+ for (PartitionKey pk : partitionKeys) { -+ PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); -+ -+ // PartitionKey.NONE wraps NonePartitionKey which has components = null. -+ // For routing purposes, treat NONE as UndefinedPartitionKey ΓÇö documents ingested -+ // without a partition key path are stored with the undefined EPK. -+ PartitionKeyInternal effectivePkInternal = pkInternal.getComponents() == null -+ ? 
PartitionKeyInternal.UndefinedPartitionKey -+ : pkInternal; -+ -+ int componentCount = effectivePkInternal.getComponents().size(); -+ int definedPathCount = pkDefinition.getPaths().size(); -+ -+ List targetRanges; -+ -+ if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { -+ // Partial HPK ΓÇö compute EPK prefix range and find all overlapping physical partitions -+ Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( -+ effectivePkInternal, pkDefinition); -+ targetRanges = routingMap.getOverlappingRanges(epkRange); -+ } else { -+ // Full PK ΓÇö maps to exactly one physical partition -+ String effectivePartitionKeyString = PartitionKeyInternalHelper -+ .getEffectivePartitionKeyString(effectivePkInternal, pkDefinition); -+ PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); -+ targetRanges = Collections.singletonList(range); -+ } -+ -+ for (PartitionKeyRange range : targetRanges) { -+ partitionRangePkMap.computeIfAbsent(range, k -> new ArrayList<>()).add(pk); -+ } -+ } -+ -+ return partitionRangePkMap; -+ } -+ - private Map getRangeQueryMap( - Map> partitionRangeItemKeyMap, - PartitionKeyDefinition partitionKeyDefinition) { -diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java -index e62d8ed3d75..d8f9614343c 100644 ---- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java -+++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java -@@ -318,6 +318,17 @@ public class DocumentQueryExecutionContextFactory { - return feedRanges; - } - -+ public static Mono fetchQueryPlanForValidation( -+ DiagnosticsClientContext 
diagnosticsClientContext, -+ IDocumentQueryClient queryClient, -+ SqlQuerySpec sqlQuerySpec, -+ String resourceLink, -+ CosmosQueryRequestOptions queryRequestOptions) { -+ -+ return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( -+ diagnosticsClientContext, queryClient, sqlQuerySpec, resourceLink, queryRequestOptions); -+ } -+ - public static Flux> createDocumentQueryExecutionContextAsync( - DiagnosticsClientContext diagnosticsClientContext, - IDocumentQueryClient client, -diff --git a/sdk/cosmos/cspell.yaml b/sdk/cosmos/cspell.yaml -new file mode 100644 -index 00000000000..94a4002c2c9 ---- /dev/null -+++ b/sdk/cosmos/cspell.yaml -@@ -0,0 +1,6 @@ -+import: -+ - ../../.vscode/cspell.json -+overrides: -+ - filename: "**/sdk/cosmos/*" -+ words: -+ - DCOUNT -diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md -new file mode 100644 -index 00000000000..95d7624f0c8 ---- /dev/null -+++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md -@@ -0,0 +1,169 @@ -+# readManyByPartitionKey ΓÇö Design & Implementation -+ -+## Overview -+ -+New `readManyByPartitionKey` methods on `CosmosAsyncContainer` / `CosmosContainer` that accept a -+`List` (without item-id). The SDK splits the PK values by physical -+partition, generates batched streaming queries per physical partition, and returns results as -+`CosmosPagedFlux` / `CosmosPagedIterable`. -+ -+An optional `SqlQuerySpec` parameter lets callers supply a custom query for projections -+and additional filters. The SDK appends the auto-generated PK WHERE clause to it. 
-+ -+## Decisions -+ -+| Topic | Decision | -+|---|---| -+| API name | `readManyByPartitionKey` ΓÇö distinct name to avoid ambiguity with existing `readMany(List)` | -+| Return type | `CosmosPagedFlux` (async) / `CosmosPagedIterable` (sync) | -+| Custom query format | `SqlQuerySpec` ΓÇö full query with parameters; SDK ANDs the PK filter | -+| Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | -+| PK deduplication | Done at Spark layer only, not in the SDK | -+| Spark UDF | New `GetCosmosPartitionKeyValue` UDF | -+| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/non-streaming ORDER BY/vector/fulltext | -+| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 1000 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | -+| Eager validation | Null and empty PK list rejected eagerly (not lazily in reactive chain) | -+| Telemetry | Separate span name `readManyByPartitionKeyItems.` (distinct from existing `readManyItems`) | -+| Query construction | Table alias auto-detected from FROM clause; string literals and subqueries handled correctly | -+ -+## Phase 1 ΓÇö SDK Core (`azure-cosmos`) -+ -+### Step 1: New public overloads in CosmosAsyncContainer -+ -+```java -+ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, Class classType) -+ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) -+ CosmosPagedFlux readManyByPartitionKey(List partitionKeys, -+ SqlQuerySpec customQuery, -+ CosmosReadManyRequestOptions requestOptions, -+ Class classType) -+``` -+ -+All delegate to a private `readManyByPartitionKeyInternalFunc(...)`. -+ -+**Eager validation:** The 4-arg method validates `partitionKeys` is non-null and non-empty before constructing the reactive pipeline, throwing `IllegalArgumentException` synchronously. 
-+ -+### Step 2: Sync wrappers in CosmosContainer -+ -+Same signatures returning `CosmosPagedIterable`, delegating to the async container. -+ -+### Step 3: Internal orchestration (RxDocumentClientImpl) -+ -+1. Resolve collection metadata + PK definition from cache. -+2. Fetch routing map from `partitionKeyRangeCache` **in parallel with** custom query validation (Step 4). -+3. For each `PartitionKey`: -+ - Compute effective partition key (EPK). -+ - Full PK ΓåÆ `getRangeByEffectivePartitionKey()` (single range). -+ - Partial HPK ΓåÆ compute EPK prefix range ΓåÆ `getOverlappingRanges()` (multiple ranges). -+ **Note:** partial HPK intentionally fans out to multiple physical partitions. -+4. Group PK values by `PartitionKeyRange`. -+5. Per physical partition ΓåÆ split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 1000). -+6. Per batch ΓåÆ build `SqlQuerySpec` with PK WHERE clause (Step 5). -+7. Interleave batches across physical partitions in round-robin order so that bounded concurrency prefers different physical partitions over sequential batches of the same partition. -+8. Execute queries via `queryForReadMany()` with bounded concurrency (`Math.min(batchCount, cpuCount)`). -+9. Return results as `CosmosPagedFlux`. -+ -+### Step 4: Custom query validation -+ -+One-time call per invocation (existing query plan caching applies). Runs **in parallel** with routing map lookup to minimize latency: -+ -+- `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query. 
-+- Reject (`IllegalArgumentException`) if: -+ - `queryInfo.hasGroupBy()` ΓÇö checked first (takes precedence over aggregates since `hasAggregates()` also returns true for GROUP BY queries) -+ - `queryInfo.hasAggregates()` -+ - `queryInfo.hasOrderBy()` -+ - `queryInfo.hasDistinct()` -+ - `queryInfo.hasDCount()` -+ - `queryInfo.hasNonStreamingOrderBy()` -+ - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` -+ -+### Step 5: Query construction -+ -+Query construction is implemented in `ReadManyByPartitionKeyQueryHelper`. The helper: -+- Extracts the table alias from the FROM clause (handles `FROM c`, `FROM root r`, `FROM x WHERE ...`) -+- Handles string literals in queries (parens/keywords inside `'...'` are correctly skipped) -+- Recognizes SQL keywords: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING -+- Uses parameterized queries (`@__rmPk_` prefix) to prevent SQL injection -+ -+**Single PK (HASH):** -+```sql -+{baseQuery} WHERE {alias}["{pkPath}"] IN (@__rmPk_0, @__rmPk_1, @__rmPk_2) -+``` -+ -+**Full HPK (MULTI_HASH):** -+```sql -+{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0 AND {alias}["{path2}"] = @__rmPk_1) -+ OR ({alias}["{path1}"] = @__rmPk_2 AND {alias}["{path2}"] = @__rmPk_3) -+``` -+ -+**Partial HPK (prefix-only):** -+```sql -+{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0) -+ OR ({alias}["{path1}"] = @__rmPk_1) -+``` -+ -+If the base query already has a WHERE clause: -+```sql -+{selectAndFrom} WHERE ({existingWhere}) AND ({pkFilter}) -+``` -+ -+### Step 6: Interface wiring -+ -+New method `readManyByPartitionKey` added directly to `AsyncDocumentClient` interface, implemented in `RxDocumentClientImpl`. New `fetchQueryPlanForValidation` static method added to `DocumentQueryExecutionContextFactory` for custom query validation. 
-+ -+### Step 7: Configuration -+ -+New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 1000, minimum: 1). Follows existing `Configs` patterns. -+ -+## Phase 2 ΓÇö Spark Connector (`azure-cosmos-spark_3`) -+ -+### Step 8: New UDF ΓÇö `GetCosmosPartitionKeyValue` -+ -+- Input: partition key value (single value or Seq for hierarchical PKs). -+- Output: serialized PK string in format `pk([...json...])`. -+- **Null handling:** Throws on null input (Scala convention; callers should filter nulls upstream). -+ -+### Step 9: PK-only serialization helper -+ -+`CosmosPartitionKeyHelper`: -+- `getCosmosPartitionKeyValueString(pkValues: List[Object]): String` ΓÇö serialize to `pk([...])` format. -+- `tryParsePartitionKey(serialized: String): Option[PartitionKey]` ΓÇö deserialize; returns `None` for malformed input including invalid JSON (wrapped in `scala.util.Try`). -+ -+### Step 10: `CosmosItemsDataSource.readManyByPartitionKey` -+ -+Static entry points that accept a DataFrame and Cosmos config. PK extraction supports two modes: -+1. **UDF-produced column**: DataFrame contains `_partitionKeyIdentity` column (from `GetCosmosPartitionKeyValue` UDF). -+2. **Schema-matched columns**: DataFrame columns match the container's PK paths. -+ -+Falls back with `IllegalArgumentException` if neither mode is possible. -+ -+### Step 11: `CosmosReadManyByPartitionKeyReader` -+ -+Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. -+ -+### Step 12: `ItemsPartitionReaderWithReadManyByPartitionKey` -+ -+Spark `PartitionReader[InternalRow]` that: -+- Deduplicates PKs via `LinkedHashMap` (by PK string representation). -+- Passes the pre-built `CosmosReadManyRequestOptions` (with throughput control, diagnostics, custom serializer) to the SDK. 
-+- Uses `TransientIOErrorsRetryingIterator` for retry handling. -+- Short-circuits empty PK lists to avoid SDK rejection. -+ -+## Phase 3 ΓÇö Testing -+ -+### Unit tests -+- Query construction: single PK, HPK full/partial, custom query composition, table alias detection. -+- Query plan rejection: aggregates, ORDER BY, DISTINCT, GROUP BY (with and without aggregates), DCOUNT. -+- String literal handling: WHERE/parentheses inside string constants. -+- Keyword detection: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING. -+- PK serialization/deserialization roundtrip (including malformed JSON handling). -+- `findTopLevelWhereIndex` edge cases: subqueries, string literals, case insensitivity. -+ -+### Integration tests -+- End-to-end SDK: single PK basic, projections, filters, empty results, HPK full/partial, request options propagation. -+- Batch size validation: temporarily lowered batch size to exercise batching/interleaving logic. -+- Null/empty PK list rejection (eager validation). -+- Spark connector: `ItemsPartitionReaderWithReadManyByPartitionKey` with known PK values and non-existent PKs. -+- `CosmosPartitionKeyHelper`: single/HPK roundtrip, case insensitivity, malformed input. 
- diff --git a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt b/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt deleted file mode 100644 index 360eb15accb5..000000000000 --- a/sdk/cosmos/.temp/branch-users_fabianm_readManyByPK-vs-origin_main-stat.txt +++ /dev/null @@ -1,76 +0,0 @@ -===== PR #LOCAL-users_fabianm_readManyByPK ===== -Title: Branch comparison users/fabianm/readManyByPK vs origin/main -Author: Fabian Meiswinkel -Status: DIVERGED (ahead 30, behind 4) -Branch: users/fabianm/readManyByPK -> origin/main -Head SHA: 93957f3a8442d730fe67fbc379ef5399f46f5665 -Merge Base: 20313f79ba8dd0dfa97862d0c31dd4b2e44ee671 -URL: N/A (local branch comparison) - ---- Description --- -Adds readManyByPartitionKey API (sync+async) and Spark connector support for PK-only reads, with query-plan-based validation ---- End Description --- - -===== Commits in PR ===== -9770833eb59 Adding readManyByPartitionKey API -ac287bcdf00 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -9a5b3e96e7e Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -a8720c3c9f2 Update sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala -d499da76fb4 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java -c3c542a33a7 Update sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java -4416354e03e ┬┤Fixing code review comments -3ab3f0d64f5 Merge branch 'main' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -588a7550c54 Update CosmosAsyncContainer.java -8c5cdb47b31 Merge branch 'main' into users/fabianm/readManyByPK -f5485527a9c Update ReadManyByPartitionKeyTest.java -f68cf02ff71 Fixing test issues -8b6c4b168ea Update 
CosmosAsyncContainer.java -8ba7f4db2da Merge branch 'main' into users/fabianm/readManyByPK -56b067a9339 Reacted to code review feedback -fa430e918fa Merge branch 'main' into users/fabianm/readManyByPK -d9504c91f34 Fix build issues -73151f09e5f Merge branch 'main' into users/fabianm/readManyByPK -681830e2d4a Fixing changelog -7f745e60641 Merge branch 'main' into users/fabianm/readManyByPK -0b8905dbb01 Addressing code review comments -22abc780ed8 Addressing code review feedback -662b1a4b90e Update CosmosItemsDataSource.scala -c764de9de02 Update CosmosItemsDataSource.scala -e1e6f5a6f73 Merge branch 'main' into users/fabianm/readManyByPK -080ce4a2293 Update RxDocumentClientImpl.java -516bbf3a95a Merge branch 'users/fabianm/readManyByPK' of https://github.com/Azure/azure-sdk-for-java into users/fabianm/readManyByPK -b01f8758eea Fix readManyByPartitionKey retries -7130d4aa35a Fix PK.None -93957f3a844 Update ReadManyByPartitionKeyQueryHelper.java - -===== Files Changed ===== - sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3-5_2-13/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala (+20 -2) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala (+1 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala (+125 -1) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala (+45 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala (+150 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala (+249 -0) - 
sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala (+259 -0) - sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala (+25 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala (+42 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala (+104 -0) - sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala (+158 -0) - sdk/cosmos/azure-cosmos-spark_4-0_2-13/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java (+462 -0) - sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java (+426 -0) - sdk/cosmos/azure-cosmos/CHANGELOG.md (+1 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java (+126 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java (+67 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java (+21 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java (+19 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java (+263 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java (+292 -0) - sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java (+11 -0) - sdk/cosmos/cspell.yaml (+6 -0) - sdk/cosmos/docs/readManyByPartitionKey-design.md (+169 -0) - - diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala 
b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index 82e3158d765d..202a31c97bd2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -127,13 +127,21 @@ object CosmosItemsDataSource { userProvidedSchema, userConfig.asScala.toMap) + // Resolve the null-handling config up front so both the UDF path and the PK-column path honor it. + val sharedEffectiveConfig = CosmosConfig.getEffectiveConfig( + databaseName = None, + containerName = None, + userConfig.asScala.toMap) + val sharedReadConfig = CosmosReadConfig.parseCosmosReadConfig(sharedEffectiveConfig) + val sharedTreatNullAsNone = sharedReadConfig.readManyByPkTreatNullAsNone + // Option 1: Look for the _partitionKeyIdentity column (produced by GetCosmosPartitionKeyValue UDF) val pkIdentityFieldExtraction = df .schema .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) .map(field => (row: Row) => { val rawValue = row.getString(row.fieldIndex(field.name)) - CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue) + CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue, sharedTreatNullAsNone) .getOrElse(throw new IllegalArgumentException( s"Invalid _partitionKeyIdentity value in row: '$rawValue'. 
" + "Expected format: pk([...json...])")) @@ -143,11 +151,8 @@ object CosmosItemsDataSource { val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { None // no need to resolve PK paths - _partitionKeyIdentity column takes precedence } else { - val effectiveConfig = CosmosConfig.getEffectiveConfig( - databaseName = None, - containerName = None, - userConfig.asScala.toMap) - val readConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveConfig) + val effectiveConfig = sharedEffectiveConfig + val readConfig = sharedReadConfig val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala index 27776f5c3de6..afea6c26c89d 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -5,7 +5,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.implementation.routing.PartitionKeyInternal import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, Utils} -import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait import java.util @@ -27,16 +27,45 @@ private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { s"pk(${objectMapper.writeValueAsString(partitionKeyValue.asJava)})" } - def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = { + def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = + 
tryParsePartitionKey(cosmosPartitionKeyString, treatNullAsNone = false) + + /** + * Parses a pk(...) string into a [[PartitionKey]]. + * + * When treatNullAsNone is true, any JSON null components in the serialized array are mapped to + * [[PartitionKeyBuilder.addNoneValue()]] (meaning the document field is absent/undefined). + * When false, they are mapped to [[PartitionKeyBuilder.addNullValue()]] (JSON null value). + * This matches the spark.cosmos.read.readManyByPk.nullHandling config for the non-UDF column path. + */ + def tryParsePartitionKey( + cosmosPartitionKeyString: String, + treatNullAsNone: Boolean): Option[PartitionKey] = { cosmosPartitionKeyString match { case cosmosPartitionKeyStringRegx(pkValue) => scala.util.Try(Utils.parse(pkValue, classOf[Object])).toOption.flatMap { case arrayList: util.ArrayList[Object @unchecked] => - Some( - ImplementationBridgeHelpers - .PartitionKeyHelper - .getPartitionKeyAccessor - .toPartitionKey(PartitionKeyInternal.fromObjectArray(arrayList.toArray, false))) + val components = arrayList.toArray + if (components.exists(_ == null)) { + // Build via PartitionKeyBuilder so nulls can be disambiguated between + // JSON-null (addNullValue) and undefined (addNoneValue) based on config. 
+ val builder = new PartitionKeyBuilder() + components.foreach { + case null => + if (treatNullAsNone) builder.addNoneValue() else builder.addNullValue() + case s: String => builder.add(s) + case n: java.lang.Number => builder.add(n.doubleValue()) + case b: java.lang.Boolean => builder.add(b.booleanValue()) + case other => builder.add(other.toString) + } + Some(builder.build()) + } else { + Some( + ImplementationBridgeHelpers + .PartitionKeyHelper + .getPartitionKeyAccessor + .toPartitionKey(PartitionKeyInternal.fromObjectArray(components, false))) + } case other => Some(new PartitionKey(other)) } case _ => None diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala index 207f77eb4766..82225a9039ee 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -64,6 +64,14 @@ private[spark] class CosmosReadManyByPartitionKeyReader( cosmosContainerConfig, clientCacheItems(0).get, clientCacheItems(1)) + // Warm-up readItem: intentionally issues a lookup for a random id/partition-key pair + // on the driver so that the collection/routing-map caches are populated before we serialize + // the client state and broadcast it to executors. This costs ~1 RU + 1 RTT per broadcast build + // (expected 404) but avoids every executor doing the same lookup in parallel on first use. + // Warm-up readItem: intentionally issues a lookup for a random id/partition-key pair + // on the driver so that the collection/routing-map caches are populated before we serialize + // the client state and broadcast it to executors. 
This costs ~1 RU + 1 RTT per broadcast build + // (expected 404) but avoids every executor doing the same lookup in parallel on first use. try { container.readItem( UUIDs.nonBlockingRandomUUID().toString, diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java index 2c26d564ed24..1eff5f7f18e7 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -227,6 +227,43 @@ public void hpk_readManyByPartitionKey_partialPk_twoLevels() { cleanupContainer(multiHashContainer); } + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + @SuppressWarnings("deprecation") + public void hpk_readManyByPartitionKey_withNoneComponent() { + // Regression test for hierarchical partition key routing with PartitionKey.NONE / addNoneValue() + // at a trailing position. Some documents omit the last PK path (areaCode); they must be + // routed via the NOT IS_DEFINED(c["areaCode"]) predicate and returned only when the caller + // requests that slice via addNoneValue(). 
+ createHpkItems(); + // Insert 3 documents where areaCode is undefined (NONE) under Redmond/98053 + for (int i = 0; i < 3; i++) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("city", "Redmond"); + item.put("zipcode", "98053"); + // deliberately omit areaCode + multiHashContainer.createItem(item); + } + + // Request the NONE slice: Redmond/98053/ + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").add("98053").addNoneValue().build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Only the 3 documents without areaCode should come back — the pre-existing items in + // createHpkItems() all have areaCode defined and live in a different physical partition slice. + assertThat(resultList).hasSize(3); + resultList.forEach(item -> { + assertThat(item.get("city").asText()).isEqualTo("Redmond"); + assertThat(item.get("zipcode").asText()).isEqualTo("98053"); + assertThat(item.has("areaCode")).isFalse(); + }); + + cleanupContainer(multiHashContainer); + } + @Test(groups = {"emulator"}, timeOut = TIMEOUT) public void hpk_readManyByPartitionKey_withProjection() { createHpkItems(); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index 4e234667c1c0..8260ed95b19d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -165,7 +165,7 @@ private static ImplementationBridgeHelpers.CosmosBatchRequestOptionsHelper.Cosmo private final String createItemSpanName; private final String readAllItemsSpanName; private final String readManyItemsSpanName; - private 
final String readManyByPartitionKeyItemsSpanName; + private final String readManyByPartitionKeySpanName; private final String readAllItemsOfLogicalPartitionSpanName; private final String queryItemsSpanName; private final String queryChangeFeedSpanName; @@ -199,7 +199,7 @@ protected CosmosAsyncContainer(CosmosAsyncContainer toBeWrappedContainer) { this.createItemSpanName = "createItem." + this.id; this.readAllItemsSpanName = "readAllItems." + this.id; this.readManyItemsSpanName = "readManyItems." + this.id; - this.readManyByPartitionKeyItemsSpanName = "readManyByPartitionKeyItems." + this.id; + this.readManyByPartitionKeySpanName = "readManyByPartitionKey." + this.id; this.readAllItemsOfLogicalPartitionSpanName = "readAllItemsOfLogicalPartition." + this.id; this.queryItemsSpanName = "queryItems." + this.id; this.queryChangeFeedSpanName = "queryChangeFeed." + this.id; @@ -1647,7 +1647,34 @@ public CosmosPagedFlux readManyByPartitionKey( * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). * The SDK will automatically append partition key filtering to the custom query. *

- * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, + * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, customQuery, null, classType); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be * rejected. *

@@ -1696,19 +1723,24 @@ private Function>> readManyByPa CosmosQueryRequestOptions queryRequestOptions = requestOptions == null ? new CosmosQueryRequestOptions() : queryOptionsAccessor().clone(readManyOptionsAccessor().getImpl(requestOptions)); - queryRequestOptions.setMaxDegreeOfParallelism(-1); + // Honor any caller-provided MaxDegreeOfParallelism; only default to the "unbounded" sentinel + // (-1) when the value is still at the default (0). CosmosReadManyRequestOptions currently does not + // expose MDOP, so this branch is defensive in case it is plumbed through in the future. + if (queryRequestOptions.getMaxDegreeOfParallelism() == 0) { + queryRequestOptions.setMaxDegreeOfParallelism(-1); + } queryRequestOptions.setQueryName("readManyByPartitionKey"); CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); - applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeyItemsSpanName); + applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeySpanName); QueryFeedOperationState state = new QueryFeedOperationState( client, - this.readManyByPartitionKeyItemsSpanName, + this.readManyByPartitionKeySpanName, database.getId(), this.getId(), ResourceType.Document, OperationType.Query, - queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeyItemsSpanName), + queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeySpanName), queryRequestOptions, pagedFluxOptions ); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java index 0bd8be5850c0..33149a498390 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java 
@@ -584,7 +584,34 @@ public CosmosPagedIterable readManyByPartitionKey( * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). * The SDK will automatically append partition key filtering to the custom query. *

- * The custom query must be a simple streamable query — aggregates, ORDER BY, DISTINCT, + * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, classType)); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be * rejected. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index 24ac8ea19666..965df746768b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -196,17 +196,19 @@ static String extractTableAlias(String queryText) { // Check if there's an alias after the container name (before WHERE or end) if (afterFrom < queryText.length()) { - char nextChar = Character.toUpperCase(queryText.charAt(afterFrom)); - // If the next token is a keyword (WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING) or end, containerName IS the alias - if (nextChar == 'W' || nextChar == 'O' || nextChar == 'G' || nextChar == 'J' - || nextChar == 'L' || nextChar == 'H') { - // Check if it's actually a keyword - String remaining = upper.substring(afterFrom); - if (remaining.startsWith("WHERE") || remaining.startsWith("ORDER") - || remaining.startsWith("GROUP") || remaining.startsWith("JOIN") - || remaining.startsWith("OFFSET") || remaining.startsWith("LIMIT") - || remaining.startsWith("HAVING")) { - return containerName; + String remaining = upper.substring(afterFrom); + // Reserved keywords that terminate the FROM clause - when the next token is one of these, + // containerName itself IS the alias used throughout the rest of the query. 
+ if (isFollowedByReservedKeyword(remaining)) { + return containerName; + } + // Handle optional AS: "FROM root AS r" -> alias is "r" + if (remaining.startsWith("AS") + && (remaining.length() == 2 || !Character.isLetterOrDigit(remaining.charAt(2)))) { + afterFrom += 2; // skip AS + while (afterFrom < queryText.length() + && Character.isWhitespace(queryText.charAt(afterFrom))) { + afterFrom++; } } // Otherwise the next token is the alias ("FROM root r" -> alias is "r") @@ -225,6 +227,18 @@ static String extractTableAlias(String queryText) { return containerName; } + private static boolean isFollowedByReservedKeyword(String remainingUpper) { + String[] keywords = { "WHERE", "ORDER", "GROUP", "JOIN", "OFFSET", "LIMIT", "HAVING" }; + for (String kw : keywords) { + if (remainingUpper.startsWith(kw) + && (remainingUpper.length() == kw.length() + || !Character.isLetterOrDigit(remainingUpper.charAt(kw.length())))) { + return true; + } + } + return false; + } + /** * Finds the index of a top-level SQL keyword in the query text (case-insensitive), * ignoring occurrences inside parentheses or string literals. @@ -234,14 +248,33 @@ static int findTopLevelKeywordIndex(String queryText, String keyword) { String keywordUpper = keyword.toUpperCase(); int depth = 0; int keyLen = keywordUpper.length(); - for (int i = 0; i <= queryTextUpper.length() - keyLen; i++) { - char ch = queryTextUpper.charAt(i); + int len = queryTextUpper.length(); + for (int i = 0; i <= len - keyLen; i++) { + char ch = queryText.charAt(i); + // Skip single-line comments: -- ... end-of-line + if (ch == '-' && i + 1 < len && queryText.charAt(i + 1) == '-') { + i += 2; + while (i < len && queryText.charAt(i) != '\n' && queryText.charAt(i) != '\r') { + i++; + } + continue; + } + // Skip block comments: /* ... 
*/ + if (ch == '/' && i + 1 < len && queryText.charAt(i + 1) == '*') { + i += 2; + while (i + 1 < len + && !(queryText.charAt(i) == '*' && queryText.charAt(i + 1) == '/')) { + i++; + } + i++; // position on the '/'; loop post-increment moves past it + continue; + } // Skip string literals enclosed in single quotes (handle '' escape) - if (queryText.charAt(i) == '\'') { + if (ch == '\'') { i++; - while (i < queryText.length()) { + while (i < len) { if (queryText.charAt(i) == '\'') { - if (i + 1 < queryText.length() && queryText.charAt(i + 1) == '\'') { + if (i + 1 < len && queryText.charAt(i + 1) == '\'') { i += 2; // escaped quote - skip both continue; } @@ -251,11 +284,12 @@ static int findTopLevelKeywordIndex(String queryText, String keyword) { } continue; } - if (ch == '(') { + char upperCh = queryTextUpper.charAt(i); + if (upperCh == '(') { depth++; - } else if (ch == ')') { + } else if (upperCh == ')') { depth--; - } else if (depth == 0 && ch == keywordUpper.charAt(0) + } else if (depth == 0 && upperCh == keywordUpper.charAt(0) && queryTextUpper.startsWith(keywordUpper, i) && (i == 0 || !Character.isLetterOrDigit(queryTextUpper.charAt(i - 1))) && (i + keyLen >= queryTextUpper.length() || !Character.isLetterOrDigit(queryTextUpper.charAt(i + keyLen)))) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 3d9082a25bcc..46767c9e9dd5 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -121,6 +121,7 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -4512,9 +4513,6 @@ private Flux> readManyByPartitionKey( for 
(Map.Entry> entry : partitionRangePkMap.entrySet()) { List allPks = entry.getValue(); - if (allPks.isEmpty()) { - continue; - } List> partitionBatches = new ArrayList<>(); for (int i = 0; i < allPks.size(); i += maxPksPerPartitionQuery) { List batch = allPks.subList( @@ -4621,7 +4619,9 @@ private Map> groupPartitionKeysByPhysicalP PartitionKeyDefinition pkDefinition, CollectionRoutingMap routingMap) { - Map> partitionRangePkMap = new HashMap<>(); + // Use LinkedHashMap so the downstream round-robin interleave is deterministic and the iteration + // order follows insertion order of partition keys (i.e. the order the caller provided). + Map> partitionRangePkMap = new LinkedHashMap<>(); for (PartitionKey pk : partitionKeys) { PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); From e306faef6c378cba267123b87b51f12c45395089 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 17:27:07 +0000 Subject: [PATCH 24/25] React to code review feedback --- .../cosmos/spark/SparkE2EQueryITest.scala | 122 +++++++++++++++++- .../com/azure/cosmos/spark/CosmosConfig.scala | 5 +- .../cosmos/spark/CosmosItemsDataSource.scala | 13 +- .../spark/CosmosPartitionKeyHelper.scala | 17 ++- .../CosmosReadManyByPartitionKeyReader.scala | 42 +++++- ...tionReaderWithReadManyByPartitionKey.scala | 82 +++++++----- .../udf/GetCosmosPartitionKeyValue.scala | 5 +- .../spark/CosmosPartitionKeyHelperSpec.scala | 27 +++- .../cosmos/ReadManyByPartitionKeyTest.java | 98 +++++++++----- ...ByPartitionKeyQueryPlanValidationTest.java | 81 ++++++++++++ .../ReadManyByPartitionKeyQueryHelper.java | 11 ++ .../implementation/RxDocumentClientImpl.java | 104 ++++++++------- .../DocumentQueryExecutionContextFactory.java | 73 +++++++---- .../docs/readManyByPartitionKey-design.md | 23 ++-- 14 files changed, 536 insertions(+), 167 deletions(-) create mode 100644 
sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java diff --git a/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala b/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala index 5f9cb1dbdbc8..8476e4aa2026 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala @@ -4,13 +4,20 @@ package com.azure.cosmos.spark import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.models.{CosmosContainerProperties, CosmosItemRequestOptions, PartitionKey, PartitionKeyDefinition, PartitionKeyDefinitionVersion, PartitionKind, ThroughputProperties} +import com.azure.cosmos.spark.udf.GetCosmosPartitionKeyValue import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.types.StringType -import java.util.UUID +import java.util.{ArrayList, UUID} + +import scala.collection.JavaConverters._ class SparkE2EQueryITest extends SparkE2EQueryITestBase { + // scalastyle:off multiple.string.literals "spark query" can "return proper Cosmos specific query plan on explain with nullable properties" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY @@ -67,4 +74,115 @@ class SparkE2EQueryITest val item = rowsArray(0) item.getAs[String]("id") shouldEqual id } -} + + "spark readManyByPartitionKey" can "use a matching top-level partition key column without the UDF" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + val requestOptions = new CosmosItemRequestOptions() + + Seq("pkA", 
"pkB").foreach { pkValue => + val item = objectMapper.createObjectNode() + item.put("id", s"item-$pkValue") + item.put("pk", pkValue) + item.put("payload", s"value-$pkValue") + + container.createItem(item, new PartitionKey(pkValue), requestOptions).block() + } + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true" + ) + + val sparkSession = spark + import sparkSession.implicits._ + + val rows = CosmosItemsDataSource + .readManyByPartitionKey(Seq("pkA", "pkB").toDF("pk"), cfg.asJava) + .selectExpr("id", "pk", "payload") + .collect() + + rows should have size 2 + rows.map(_.getAs[String]("id")).toSet shouldEqual Set("item-pkA", "item-pkB") + rows.map(_.getAs[String]("pk")).toSet shouldEqual Set("pkA", "pkB") + rows.map(_.getAs[String]("payload")).toSet shouldEqual Set("value-pkA", "value-pkB") + } + "spark readManyByPartitionKey" can "require the UDF for nested partition key paths and succeed with it" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val containerName = s"nested-pk-${UUID.randomUUID()}" + + val pkPaths = new ArrayList[String]() + pkPaths.add("/tenant/id") + + val pkDefinition = new PartitionKeyDefinition() + pkDefinition.setPaths(pkPaths) + pkDefinition.setKind(PartitionKind.HASH) + pkDefinition.setVersion(PartitionKeyDefinitionVersion.V2) + + val containerProperties = new CosmosContainerProperties(containerName, pkDefinition) + cosmosClient + .getDatabase(cosmosDatabase) + .createContainerIfNotExists(containerProperties, ThroughputProperties.createManualThroughput(400)) + .block() + + try { + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(containerName) + val requestOptions = new CosmosItemRequestOptions() + + Seq("tenantA", "tenantB").foreach { 
tenantId => + val item = objectMapper.createObjectNode() + item.put("id", s"item-$tenantId") + item.put("payload", s"value-$tenantId") + item.putObject("tenant").put("id", tenantId) + + container.createItem(item, new PartitionKey(tenantId), requestOptions).block() + } + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> containerName, + "spark.cosmos.read.inferSchema.enabled" -> "true" + ) + + val sparkSession = spark + import sparkSession.implicits._ + + val missingUdfError = the[IllegalArgumentException] thrownBy { + CosmosItemsDataSource.readManyByPartitionKey(Seq("tenantA").toDF("tenantId"), cfg.asJava) + } + + missingUdfError.getMessage should include("Nested paths cannot be resolved from DataFrame columns automatically") + missingUdfError.getMessage should include("_partitionKeyIdentity") + + spark.udf.register("GetCosmosPartitionKeyValue", new GetCosmosPartitionKeyValue(), StringType) + + val inputDf = Seq("tenantA", "tenantB") + .toDF("tenantId") + .withColumn("_partitionKeyIdentity", expr("GetCosmosPartitionKeyValue(tenantId)")) + + val rows = CosmosItemsDataSource + .readManyByPartitionKey(inputDf, cfg.asJava) + .selectExpr("id", "tenant.id as tenantId") + .collect() + + rows should have size 2 + rows.map(_.getAs[String]("id")).toSet shouldEqual Set("item-tenantA", "item-tenantB") + rows.map(_.getAs[String]("tenantId")).toSet shouldEqual Set("tenantA", "tenantB") + } finally { + cosmosClient + .getDatabase(cosmosDatabase) + .getContainer(containerName) + .delete() + .block() + } + } + + // scalastyle:on multiple.string.literals +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 4a483ef38a6b..8312e34bf0c1 100644 --- 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -1147,8 +1147,9 @@ private object CosmosReadConfig { helpMessage = "Determines how null values in partition key columns are treated for " + "readManyByPartitionKey. 'Null' (default) maps null to a JSON null via addNullValue(), which " + "is appropriate when the document field exists with an explicit null value. 'None' maps null " + - "to PartitionKey.NONE via addNoneValue(), which should only be used when the partition key " + - "path does not exist at all in the document. These two semantics hash to DIFFERENT physical " + + "to PartitionKey.NONE via addNoneValue(), which is only supported for single-path partition keys " + + "and should only be used when the partition key path does not exist at all in the document. " + + "Hierarchical partition keys reject this mode. These two semantics hash to DIFFERENT physical " + "partitions - picking the wrong mode for your data will silently return zero rows." 
) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index 202a31c97bd2..3433e352870b 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -212,7 +212,7 @@ object CosmosItemsDataSource { // Hierarchical partition key - build level by level val builder = new PartitionKeyBuilder() for (path <- pkPaths) { - addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone) + addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone, pkPaths.size) } builder.build() } @@ -233,12 +233,19 @@ object CosmosItemsDataSource { readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) } - private def addPartitionKeyComponent(builder: PartitionKeyBuilder, value: Any, treatNullAsNone: Boolean): Unit = { + private def addPartitionKeyComponent( + builder: PartitionKeyBuilder, + value: Any, + treatNullAsNone: Boolean, + partitionKeyComponentCount: Int): Unit = { value match { case s: String => builder.add(s) case n: Number => builder.add(n.doubleValue()) case b: Boolean => builder.add(b) case null => + CosmosPartitionKeyHelper.validateNoneHandlingForPartitionKeyComponentCount( + partitionKeyComponentCount, + treatNullAsNone) if (treatNullAsNone) builder.addNoneValue() else builder.addNullValue() case other => @@ -255,7 +262,7 @@ object CosmosItemsDataSource { private def buildPartitionKey(value: Any, treatNullAsNone: Boolean): PartitionKey = { val builder = new PartitionKeyBuilder() - addPartitionKeyComponent(builder, value, treatNullAsNone) + addPartitionKeyComponent(builder, value, treatNullAsNone, 1) builder.build() } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala 
b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala index afea6c26c89d..29831e9a6fed 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -15,6 +15,20 @@ import scala.collection.JavaConverters._ // scalastyle:on underscore.import private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { + private[spark] val HierarchicalPartitionKeyNoneHandlingErrorMessage = + s"The configuration '${CosmosConfigNames.ReadManyByPkNullHandling}=None' is not supported for " + + "hierarchical partition keys because PartitionKey.NONE can't be used with multiple paths. " + + "Use 'Null' for explicit JSON null values, filter out rows with missing partition key " + + "components, or provide fully-defined hierarchical partition keys." + + private[spark] def validateNoneHandlingForPartitionKeyComponentCount( + componentCount: Int, + treatNullAsNone: Boolean): Unit = { + if (treatNullAsNone && componentCount > 1) { + throw new IllegalArgumentException(HierarchicalPartitionKeyNoneHandlingErrorMessage) + } + } + // pattern will be recognized // pk(partitionKeyValue) // @@ -47,6 +61,7 @@ private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { case arrayList: util.ArrayList[Object @unchecked] => val components = arrayList.toArray if (components.exists(_ == null)) { + validateNoneHandlingForPartitionKeyComponentCount(components.length, treatNullAsNone) // Build via PartitionKeyBuilder so nulls can be disambiguated between // JSON-null (addNullValue) and undefined (addNoneValue) based on config. 
val builder = new PartitionKeyBuilder() @@ -71,4 +86,4 @@ private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { case _ => None } } -} +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala index 82225a9039ee..bd6fa5be8bdf 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -138,6 +138,7 @@ private[spark] class CosmosReadManyByPartitionKeyReader( logInfo(s"Creating an ItemsPartitionReaderWithReadManyByPartitionKey for Activity $correlationActivityId to read for " + s"input partition [$partitionIndex] ${tableName}") + val taskContext = TaskContext.get val reader = new ItemsPartitionReaderWithReadManyByPartitionKey( effectiveUserConfig, CosmosReadManyHelper.FullRangeFeedRange, @@ -146,13 +147,46 @@ private[spark] class CosmosReadManyByPartitionKeyReader( clientStates, DiagnosticsConfig.parseDiagnosticsConfig(effectiveUserConfig), sparkEnvironmentInfo, - TaskContext.get, + taskContext, pkIterator) new Iterator[Row] { - override def hasNext: Boolean = reader.next() - - override def next(): Row = reader.getCurrentRow() + private var isClosed = false + + private def closeReader(): Unit = { + if (!isClosed) { + isClosed = true + reader.close() + } + } + + if (taskContext != null) { + taskContext.addTaskCompletionListener[Unit](_ => closeReader()) + } + + override def hasNext: Boolean = { + try { + val hasMore = reader.next() + if (!hasMore) { + closeReader() + } + hasMore + } catch { + case error: Throwable => + closeReader() + throw error + } + } + + override def next(): Row = { + try { + reader.getCurrentRow() + } catch { + case error: Throwable => + 
closeReader() + throw error + } + } } }, preservesPartitioning = true diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala index 7fb6e0eb0aba..5c994cc2f7f8 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType import java.util +import java.util.concurrent.atomic.AtomicBoolean // scalastyle:off underscore.import import scala.collection.JavaConverters._ @@ -195,31 +196,41 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey // Pass the full PK list to the SDK (which batches per physical partition internally). // On transient I/O failures the retry iterator tracks pages already emitted upstream - // and skips them on replay; if a failure occurs mid-page (after items from that page - // have been emitted) the task fails rather than risking row duplication. 
- private lazy val iterator: CloseableSparkRowItemIterator = - if (pkList.isEmpty) { - EmptySparkRowItemIterator - } else { - new CloseableSparkRowItemIterator { - private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem]( - cosmosAsyncContainer, - pkList, - readConfig.customQuery.map(_.toSqlQuerySpec), - readManyOptions, - readConfig.maxItemCount, - readConfig.prefetchBufferSize, - operationContextAndListenerTuple, - classOf[SparkRowItem] - ) - - override def hasNext: Boolean = delegate.hasNext - - override def next(): SparkRowItem = delegate.next() - - override def close(): Unit = delegate.close() - } - } + // and skips them on replay; if a failure occurs mid-page (after items from that page have been + // emitted) the task fails rather than risking row duplication. + private val isClosed = new AtomicBoolean(false) + private var iteratorOpt: Option[CloseableSparkRowItemIterator] = None + + private def getOrCreateIterator: CloseableSparkRowItemIterator = iteratorOpt match { + case Some(existing) => existing + case None => + val created = + if (pkList.isEmpty) { + EmptySparkRowItemIterator + } else { + new CloseableSparkRowItemIterator { + private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem]( + cosmosAsyncContainer, + pkList, + readConfig.customQuery.map(_.toSqlQuerySpec), + readManyOptions, + readConfig.maxItemCount, + readConfig.prefetchBufferSize, + operationContextAndListenerTuple, + classOf[SparkRowItem] + ) + + override def hasNext: Boolean = delegate.hasNext + + override def next(): SparkRowItem = delegate.next() + + override def close(): Unit = delegate.close() + } + } + + iteratorOpt = Some(created) + created + } private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema) @@ -236,20 +247,23 @@ private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey } } - override def next(): Boolean = iterator.hasNext + 
override def next(): Boolean = getOrCreateIterator.hasNext override def get(): InternalRow = { - cosmosRowConverter.fromRowToInternalRow(iterator.next().row, rowSerializer) + cosmosRowConverter.fromRowToInternalRow(getOrCreateIterator.next().row, rowSerializer) } - def getCurrentRow(): Row = iterator.next().row + def getCurrentRow(): Row = getOrCreateIterator.next().row override def close(): Unit = { - this.iterator.close() - RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer) - clientCacheItem.close() - if (throughputControlClientCacheItemOpt.isDefined) { - throughputControlClientCacheItemOpt.get.close() + if (isClosed.compareAndSet(false, true)) { + iteratorOpt.foreach(_.close()) + iteratorOpt = None + RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer) + clientCacheItem.close() + if (throughputControlClientCacheItemOpt.isDefined) { + throughputControlClientCacheItemOpt.get.close() + } } } -} +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala index fc038861a19a..4dcc812f3122 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala @@ -12,8 +12,9 @@ class GetCosmosPartitionKeyValue extends UDF1[Object, String] { // single-level partition key with a JSON null component; parsing that string back via // CosmosPartitionKeyHelper.tryParsePartitionKey yields a PartitionKey built with // addNullValue(). If the caller instead wants PartitionKey.NONE semantics (absent PK - // field) they should filter the null row before calling this UDF and use the - // schema-matched readManyByPartitionKey path with readManyByPk.nullHandling=None. 
+ // field) they should filter the null row before calling this UDF and use the schema-matched + // readManyByPartitionKey path with readManyByPk.nullHandling=None. That None mode is only + // supported for single-path partition keys; hierarchical partition keys reject it. override def call(partitionKeyValue: Object): String = { partitionKeyValue match { case null => diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala index 1ac40e395847..c81113d3a190 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -88,15 +88,28 @@ class CosmosPartitionKeyHelperSpec extends UnitSpec { pk.isDefined shouldBe false } - it should "produce different partition keys for addNullValue vs addNoneValue in HPK" in { - // addNullValue represents an explicit JSON null for a field that exists with value null - val pkWithNull = new PartitionKeyBuilder().add("Redmond").addNullValue().build() + it should "parse single-path null as PartitionKey.NONE when treatNullAsNone is true" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([null])", treatNullAsNone = true) - // addNoneValue represents PartitionKey.NONE, meaning the field is absent/undefined - val pkWithNone = new PartitionKeyBuilder().add("Redmond").addNoneValue().build() + pk.isDefined shouldBe true + pk.get shouldEqual PartitionKey.NONE + } + + it should "throw a clear error when None nullHandling is used for hierarchical partition keys" in { + val error = the[IllegalArgumentException] thrownBy { + CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"Redmond\",null])", treatNullAsNone = true) + } + + error.getMessage should include(CosmosConfigNames.ReadManyByPkNullHandling) + 
error.getMessage should include("hierarchical partition keys") + } + + it should "reject addNoneValue in hierarchical partition keys" in { + val error = the[IllegalStateException] thrownBy { + new PartitionKeyBuilder().add("Redmond").addNoneValue().build() + } - // These MUST produce different partition key hashes and route to different physical partitions - pkWithNull should not equal pkWithNone + error.getMessage should include("PartitionKey.None can't be used with multiple paths") } //scalastyle:on multiple.string.literals diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java index 1eff5f7f18e7..ca684f1abf9e 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -27,7 +27,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -230,38 +232,30 @@ public void hpk_readManyByPartitionKey_partialPk_twoLevels() { @Test(groups = {"emulator"}, timeOut = TIMEOUT) @SuppressWarnings("deprecation") public void hpk_readManyByPartitionKey_withNoneComponent() { - // Regression test for hierarchical partition key routing with PartitionKey.NONE / addNoneValue() - // at a trailing position. Some documents omit the last PK path (areaCode); they must be - // routed via the NOT IS_DEFINED(c["areaCode"]) predicate and returned only when the caller - // requests that slice via addNoneValue(). 
- createHpkItems(); - // Insert 3 documents where areaCode is undefined (NONE) under Redmond/98053 - for (int i = 0; i < 3; i++) { + try { + createHpkItems(); + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); item.put("id", UUID.randomUUID().toString()); item.put("city", "Redmond"); item.put("zipcode", "98053"); - // deliberately omit areaCode - multiHashContainer.createItem(item); - } - - // Request the NONE slice: Redmond/98053/ - List pkValues = Collections.singletonList( - new PartitionKeyBuilder().add("Redmond").add("98053").addNoneValue().build()); - CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); - List resultList = results.stream().collect(Collectors.toList()); - - // Only the 3 documents without areaCode should come back — the pre-existing items in - // createHpkItems() all have areaCode defined and live in a different physical partition slice. - assertThat(resultList).hasSize(3); - resultList.forEach(item -> { - assertThat(item.get("city").asText()).isEqualTo("Redmond"); - assertThat(item.get("zipcode").asText()).isEqualTo("98053"); - assertThat(item.has("areaCode")).isFalse(); - }); + try { + multiHashContainer.createItem(item); + fail("Should have thrown CosmosException for HPK item with missing trailing partition key component"); + } catch (CosmosException e) { + assertThat(e.getMessage()).contains("wrong-pk-value"); + } - cleanupContainer(multiHashContainer); + try { + new PartitionKeyBuilder().add("Redmond").add("98053").addNoneValue().build(); + fail("Should have thrown IllegalStateException for HPK addNoneValue"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).contains("PartitionKey.None can't be used with multiple paths"); + } + } finally { + cleanupContainer(multiHashContainer); + } } @Test(groups = {"emulator"}, timeOut = TIMEOUT) @@ -367,6 +361,22 @@ public void rejectsEmptyPartitionKeyList() { 
.stream().collect(Collectors.toList()); } + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsOffsetQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec offsetQuery = new SqlQuerySpec("SELECT * FROM c OFFSET 0 LIMIT 10"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, offsetQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for OFFSET query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("OFFSET"); + } + } + + + //endregion @@ -393,9 +403,14 @@ public void singlePk_readManyByPartitionKey_withSmallBatchSize() { new PartitionKey("batchPk4")); CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); - List resultList = results.stream().collect(Collectors.toList()); + List> pages = new ArrayList<>(); + results.iterableByPage().forEach(pages::add); + List resultList = pages.stream() + .flatMap(page -> page.getResults().stream()) + .collect(Collectors.toList()); assertThat(resultList).hasSize(8); // 2 items per PK * 4 PKs + assertThat(pages.size()).isGreaterThan(1); resultList.forEach(item -> { String pk = item.get("mypk").asText(); assertThat(pk).isIn("batchPk1", "batchPk2", "batchPk3", "batchPk4"); @@ -417,19 +432,31 @@ public void singlePk_readManyByPartitionKey_withSmallBatchSize() { @Test(groups = {"emulator"}, timeOut = TIMEOUT) public void singlePk_readManyByPartitionKey_withRequestOptions() { - // This test ensures that request options (like throughput control settings) - // are properly propagated through the readManyByPartitionKey path. - // It acts as a regression test for the redundant options construction bug. 
List items = createSinglePkItems("pkOpts", 3); List pkValues = Collections.singletonList(new PartitionKey("pkOpts")); com.azure.cosmos.models.CosmosReadManyRequestOptions options = new com.azure.cosmos.models.CosmosReadManyRequestOptions(); + AtomicInteger deserializeCount = new AtomicInteger(); + options.setCustomItemSerializer(new CosmosItemSerializerNoExceptionWrapping() { + @Override + public Map serialize(T item) { + return CosmosItemSerializer.DEFAULT_SERIALIZER.serialize(item); + } - CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( - pkValues, options, ObjectNode.class); - List resultList = results.stream().collect(Collectors.toList()); + @Override + public T deserialize(Map jsonNodeMap, Class classType) { + deserializeCount.incrementAndGet(); + return CosmosItemSerializer.DEFAULT_SERIALIZER.deserialize(jsonNodeMap, classType); + } + }); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, options, ReadManyByPartitionKeyPojo.class); + List resultList = results.stream().collect(Collectors.toList()); assertThat(resultList).hasSize(3); + assertThat(deserializeCount.get()).isEqualTo(3); + assertThat(resultList.stream().map(item -> item.mypk)).containsOnly("pkOpts"); cleanupContainer(singlePkContainer); } @@ -495,5 +522,10 @@ private void cleanupContainer(CosmosContainer container) { }); } + private static class ReadManyByPartitionKeyPojo { + public String id; + public String mypk; + } + //endregion } diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java new file mode 100644 index 000000000000..429d1fd8f89c --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) 
Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ + +package com.azure.cosmos.implementation; + +import com.azure.cosmos.implementation.query.DCountInfo; +import com.azure.cosmos.implementation.query.PartitionedQueryExecutionInfo; +import com.azure.cosmos.implementation.query.QueryInfo; +import com.azure.cosmos.implementation.query.hybridsearch.HybridSearchQueryInfo; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ReadManyByPartitionKeyQueryPlanValidationTest { + + @Test(groups = { "unit" }) + public void rejectsDCountQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + DCountInfo dCountInfo = new DCountInfo(); + dCountInfo.setDCountAlias("countAlias"); + queryInfo.set("dCountInfo", dCountInfo); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("DCOUNT"); + } + + @Test(groups = { "unit" }) + public void rejectsOffsetQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + queryInfo.set("offset", 10); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("OFFSET"); + } + + @Test(groups = { "unit" }) + public void rejectsLimitQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + queryInfo.set("limit", 10); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("LIMIT"); + } + + @Test(groups = { "unit" }) + public void rejectsHybridSearchQueryPlanWithoutDereferencingNullQueryInfo() { + 
assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey( + createQueryPlan(null, new HybridSearchQueryInfo()))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("hybrid/vector/full-text"); + } + + @Test(groups = { "unit" }) + public void acceptsSimpleQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + + assertThatCode(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .doesNotThrowAnyException(); + } + + private PartitionedQueryExecutionInfo createQueryPlan(QueryInfo queryInfo, HybridSearchQueryInfo hybridSearchQueryInfo) { + ObjectNode content = Utils.getSimpleObjectMapper().createObjectNode(); + content.put("partitionedQueryExecutionInfoVersion", Constants.PartitionedQueryExecutionInfo.VERSION_1); + + if (queryInfo != null) { + content.set("queryInfo", Utils.getSimpleObjectMapper().valueToTree(queryInfo.getMap())); + } + if (hybridSearchQueryInfo != null) { + content.set("hybridSearchQueryInfo", Utils.getSimpleObjectMapper().createObjectNode()); + } + + return new PartitionedQueryExecutionInfo(content, null); + } +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java index 965df746768b..0bdb867dc3ee 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; /** * Helper for constructing SqlQuerySpec instances for readManyByPartitionKey operations. 
@@ -160,6 +161,16 @@ public static SqlQuerySpec createReadManyByPkQuerySpec( return new SqlQuerySpec(finalQuery, parameters); } + static List createPkSelectors(PartitionKeyDefinition partitionKeyDefinition) { + return partitionKeyDefinition.getPaths() + .stream() + .map(PathParser::getPathParts) + .map(pathParts -> pathParts.stream() + .map(pathPart -> "[\"" + pathPart.replace("\"", "\\") + "\"]") + .collect(Collectors.joining())) + .collect(Collectors.toList()); + } + /** * Extracts the table/collection alias from a SQL query's FROM clause. * Handles: "SELECT * FROM c", "SELECT x.id FROM x WHERE ...", "SELECT * FROM root r", etc. diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 46767c9e9dd5..1126008946a3 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4485,7 +4485,7 @@ private Flux> readManyByPartitionKey( Map> partitionRangePkMap = groupPartitionKeysByPhysicalPartition(partitionKeys, pkDefinition, routingMap); - List partitionKeySelectors = createPkSelectors(pkDefinition); + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(pkDefinition); String baseQueryText; List baseParameters; @@ -4502,7 +4502,7 @@ private Flux> readManyByPartitionKey( // Build per-physical-partition batched queries. // Each physical partition may have many PKs - split into batches // to avoid oversized SQL queries. Batch size is configurable via - // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 1000). + // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 100). 
int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); // Build batches per partition as a list of lists (one inner list per partition). @@ -4577,41 +4577,62 @@ private Mono validateCustomQueryForReadManyByPartitionKey( RxDocumentClientImpl.this, getOperationContextAndListenerTuple(queryRequestOptions)); return DocumentQueryExecutionContextFactory - .fetchQueryPlanForValidation(this, queryClient, customQuery, resourceLink, queryRequestOptions) - .flatMap(queryPlan -> { - QueryInfo queryInfo = queryPlan.getQueryInfo(); + .fetchQueryPlanForValidation( + this, + queryClient, + customQuery, + resourceLink, + queryRequestOptions, + Configs.isQueryPlanCachingEnabled(), + this.getQueryPlanCache()) + .doOnNext(RxDocumentClientImpl::validateQueryPlanForReadManyByPartitionKey) + .then(); + } - if (queryInfo.hasGroupBy()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain GROUP BY.")); - } - if (queryInfo.hasAggregates()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain aggregates.")); - } - if (queryInfo.hasOrderBy()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain ORDER BY.")); - } - if (queryInfo.hasDistinct()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain DISTINCT.")); - } - if (queryInfo.hasDCount()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain DCOUNT.")); - } - if (queryInfo.hasNonStreamingOrderBy()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain non-streaming ORDER BY.")); - } - if (queryPlan.hasHybridSearchQueryInfo()) { - return Mono.error(new IllegalArgumentException( - "Custom query for readMany by partition key must not contain hybrid/vector/full-text 
search.")); - } + static void validateQueryPlanForReadManyByPartitionKey(PartitionedQueryExecutionInfo queryPlan) { + if (queryPlan.hasHybridSearchQueryInfo()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain hybrid/vector/full-text search."); + } - return Mono.empty(); - }); + QueryInfo queryInfo = queryPlan.getQueryInfo(); + if (queryInfo == null) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key is not supported because query plan details are unavailable."); + } + + if (queryInfo.hasGroupBy()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain GROUP BY."); + } + if (queryInfo.hasAggregates()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain aggregates."); + } + if (queryInfo.hasOrderBy()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain ORDER BY."); + } + if (queryInfo.hasDistinct()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DISTINCT."); + } + if (queryInfo.hasDCount()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DCOUNT."); + } + if (queryInfo.hasOffset()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain OFFSET."); + } + if (queryInfo.hasLimit()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain LIMIT."); + } + if (queryInfo.hasNonStreamingOrderBy()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain non-streaming ORDER BY."); + } } private Map> groupPartitionKeysByPhysicalPartition( @@ -4665,7 +4686,7 @@ private Map getRangeQueryMap( //TODO: Optimise this to include all types of partitionkeydefinitions. 
ex: c["prop1./ab"]["key1"] Map rangeQueryMap = new HashMap<>(); - List partitionKeySelectors = createPkSelectors(partitionKeyDefinition); + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(partitionKeyDefinition); for(Map.Entry> entry: partitionRangeItemKeyMap.entrySet()) { SqlQuerySpec sqlQuerySpec; @@ -4759,15 +4780,6 @@ private SqlQuerySpec createReadManyQuerySpec( return new SqlQuerySpec(queryStringBuilder.toString(), parameters); } - private List createPkSelectors(PartitionKeyDefinition partitionKeyDefinition) { - return partitionKeyDefinition.getPaths() - .stream() - .map(pathPart -> StringUtils.substring(pathPart, 1)) // skip starting / - .map(pathPart -> StringUtils.replace(pathPart, "\"", "\\")) // escape quote - .map(part -> "[\"" + part + "\"]") - .collect(Collectors.toList()); - } - private Flux> queryForReadMany( ScopedDiagnosticsFactory diagnosticsFactory, String parentResourceLink, @@ -5289,7 +5301,7 @@ public Flux> readAllDocuments( } PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); - List partitionKeySelectors = createPkSelectors(pkDefinition); + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(pkDefinition); SqlQuerySpec querySpec = createLogicalPartitionScanQuerySpec(partitionKey, partitionKeySelectors); String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java index d8f9614343cb..82506ae361e1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java @@ -107,7 +107,6 @@ private static Mono 
getPartitionKeyRangesAn } Instant startTime = Instant.now(); - Mono queryExecutionInfoMono; if (queryRequestOptionsAccessor() .isQueryPlanRetrievalDisallowed(cosmosQueryRequestOptions)) { @@ -122,37 +121,53 @@ private static Mono getPartitionKeyRangesAn endTime); } + return fetchQueryPlan( + diagnosticsClientContext, + client, + query, + resourceLink, + cosmosQueryRequestOptions, + queryPlanCachingEnabled, + queryPlanCache) + .flatMap( + partitionedQueryExecutionInfo -> { + + Instant endTime = Instant.now(); + + return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, + partitionedQueryExecutionInfo, startTime, endTime); + }); + } + + private static Mono fetchQueryPlan( + DiagnosticsClientContext diagnosticsClientContext, + IDocumentQueryClient client, + SqlQuerySpec query, + String resourceLink, + CosmosQueryRequestOptions cosmosQueryRequestOptions, + boolean queryPlanCachingEnabled, + Map queryPlanCache) { + if (queryPlanCachingEnabled && - isScopedToSinglePartition(cosmosQueryRequestOptions) && - queryPlanCache.containsKey(query.getQueryText())) { - Instant endTime = Instant.now(); // endTime for query plan diagnostics + isScopedToSinglePartition(cosmosQueryRequestOptions) && + queryPlanCache.containsKey(query.getQueryText())) { PartitionedQueryExecutionInfo partitionedQueryExecutionInfo = queryPlanCache.get(query.getQueryText()); if (partitionedQueryExecutionInfo != null) { logger.debug("Skipping query plan round trip by using the cached plan"); - return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, - partitionedQueryExecutionInfo, startTime, endTime); + return Mono.just(partitionedQueryExecutionInfo); } } - queryExecutionInfoMono = - QueryPlanRetriever.getQueryPlanThroughGatewayAsync( - diagnosticsClientContext, - client, - query, - resourceLink, - cosmosQueryRequestOptions); - - return queryExecutionInfoMono.flatMap( - partitionedQueryExecutionInfo -> { - - Instant endTime = 
Instant.now(); - + return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( + diagnosticsClientContext, + client, + query, + resourceLink, + cosmosQueryRequestOptions) + .doOnNext(partitionedQueryExecutionInfo -> { if (queryPlanCachingEnabled && isScopedToSinglePartition(cosmosQueryRequestOptions)) { tryCacheQueryPlan(query, partitionedQueryExecutionInfo, queryPlanCache); } - - return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, - partitionedQueryExecutionInfo, startTime, endTime); }); } @@ -323,10 +338,18 @@ public static Mono fetchQueryPlanForValidation( IDocumentQueryClient queryClient, SqlQuerySpec sqlQuerySpec, String resourceLink, - CosmosQueryRequestOptions queryRequestOptions) { + CosmosQueryRequestOptions queryRequestOptions, + boolean queryPlanCachingEnabled, + Map queryPlanCache) { - return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( - diagnosticsClientContext, queryClient, sqlQuerySpec, resourceLink, queryRequestOptions); + return fetchQueryPlan( + diagnosticsClientContext, + queryClient, + sqlQuerySpec, + resourceLink, + queryRequestOptions, + queryPlanCachingEnabled, + queryPlanCache); } public static Flux> createDocumentQueryExecutionContextAsync( diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md index 95d7624f0c8b..c193630d4064 100644 --- a/sdk/cosmos/docs/readManyByPartitionKey-design.md +++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md @@ -20,8 +20,8 @@ and additional filters. 
The SDK appends the auto-generated PK WHERE clause to it | Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | | PK deduplication | Done at Spark layer only, not in the SDK | | Spark UDF | New `GetCosmosPartitionKeyValue` UDF | -| Custom query validation | Gateway query plan; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/non-streaming ORDER BY/vector/fulltext | -| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 1000 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | +| Custom query validation | Gateway query plan via the standard SDK query-plan retrieval path; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/OFFSET/LIMIT/non-streaming ORDER BY/vector/fulltext | +| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 100 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | | Eager validation | Null and empty PK list rejected eagerly (not lazily in reactive chain) | | Telemetry | Separate span name `readManyByPartitionKeyItems.` (distinct from existing `readManyItems`) | | Query construction | Table alias auto-detected from FROM clause; string literals and subqueries handled correctly | @@ -52,14 +52,14 @@ Same signatures returning `CosmosPagedIterable`, delegating to the async cont ### Step 3: Internal orchestration (RxDocumentClientImpl) 1. Resolve collection metadata + PK definition from cache. -2. Fetch routing map from `partitionKeyRangeCache` **in parallel with** custom query validation (Step 4). +2. Fetch routing map from `partitionKeyRangeCache`. 3. For each `PartitionKey`: - Compute effective partition key (EPK). - Full PK → `getRangeByEffectivePartitionKey()` (single range). - Partial HPK → compute EPK prefix range → `getOverlappingRanges()` (multiple ranges). **Note:** partial HPK intentionally fans out to multiple physical partitions. 4. 
Group PK values by `PartitionKeyRange`. -5. Per physical partition → split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 1000). +5. Per physical partition → split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 100). 6. Per batch → build `SqlQuerySpec` with PK WHERE clause (Step 5). 7. Interleave batches across physical partitions in round-robin order so that bounded concurrency prefers different physical partitions over sequential batches of the same partition. 8. Execute queries via `queryForReadMany()` with bounded concurrency (`Math.min(batchCount, cpuCount)`). @@ -67,7 +67,7 @@ Same signatures returning `CosmosPagedIterable`, delegating to the async cont ### Step 4: Custom query validation -One-time call per invocation (existing query plan caching applies). Runs **in parallel** with routing map lookup to minimize latency: +One-time call per invocation using the same query-plan retrieval path and cacheability rules as regular SDK queries. - `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query. - Reject (`IllegalArgumentException`) if: @@ -76,8 +76,11 @@ One-time call per invocation (existing query plan caching applies). Runs **in pa - `queryInfo.hasOrderBy()` - `queryInfo.hasDistinct()` - `queryInfo.hasDCount()` + - `queryInfo.hasOffset()` + - `queryInfo.hasLimit()` - `queryInfo.hasNonStreamingOrderBy()` - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` + - query plan details are unavailable (`queryInfo == null`) ### Step 5: Query construction @@ -115,7 +118,7 @@ New method `readManyByPartitionKey` added directly to `AsyncDocumentClient` inte ### Step 7: Configuration -New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 1000, minimum: 1). Follows existing `Configs` patterns. 
+New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 100, minimum: 1). Follows existing `Configs` patterns. ## Phase 2 — Spark Connector (`azure-cosmos-spark_3`) @@ -123,13 +126,14 @@ New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATC - Input: partition key value (single value or Seq for hierarchical PKs). - Output: serialized PK string in format `pk([...json...])`. -- **Null handling:** Throws on null input (Scala convention; callers should filter nulls upstream). +- **Null handling:** Null input is serialized as a JSON-null partition key component. If callers need `PartitionKey.NONE` semantics they must use the schema-matched path with `spark.cosmos.read.readManyByPk.nullHandling=None`, which is only supported for single-path partition keys. ### Step 9: PK-only serialization helper `CosmosPartitionKeyHelper`: - `getCosmosPartitionKeyValueString(pkValues: List[Object]): String` — serialize to `pk([...])` format. - `tryParsePartitionKey(serialized: String): Option[PartitionKey]` — deserialize; returns `None` for malformed input including invalid JSON (wrapped in `scala.util.Try`). +- When `spark.cosmos.read.readManyByPk.nullHandling=None` is used, hierarchical partition keys with null components are rejected with a clear error because `PartitionKey.NONE` cannot be used with multiple paths. ### Step 10: `CosmosItemsDataSource.readManyByPartitionKey` @@ -137,11 +141,13 @@ Static entry points that accept a DataFrame and Cosmos config. PK extraction sup 1. **UDF-produced column**: DataFrame contains `_partitionKeyIdentity` column (from `GetCosmosPartitionKeyValue` UDF). 2. **Schema-matched columns**: DataFrame columns match the container's PK paths. +Nested partition key paths are not resolved automatically from DataFrame columns and must use the UDF-produced `_partitionKeyIdentity` column. 
+ Falls back with `IllegalArgumentException` if neither mode is possible. ### Step 11: `CosmosReadManyByPartitionKeyReader` -Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. +Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. The wrapper iterator closes the reader deterministically on exhaustion, on failures, and via Spark task-completion callbacks. ### Step 12: `ItemsPartitionReaderWithReadManyByPartitionKey` @@ -166,4 +172,5 @@ Spark `PartitionReader[InternalRow]` that: - Batch size validation: temporarily lowered batch size to exercise batching/interleaving logic. - Null/empty PK list rejection (eager validation). - Spark connector: `ItemsPartitionReaderWithReadManyByPartitionKey` with known PK values and non-existent PKs. +- Spark public API: nested partition key containers require `_partitionKeyIdentity` and succeed when populated via `GetCosmosPartitionKeyValue`. - `CosmosPartitionKeyHelper`: single/HPK roundtrip, case insensitivity, malformed input. 
From f34270dc97237a8b423806e2d43fc4d6e4319de9 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 17 Apr 2026 18:24:40 +0000 Subject: [PATCH 25/25] Addressing code review comments --- .../spark/CosmosPartitionKeyHelper.scala | 5 +++- .../CosmosReadManyByPartitionKeyReader.scala | 5 +--- .../spark/CosmosPartitionKeyHelperSpec.scala | 8 +++++++ .../com/azure/cosmos/CosmosTracerTest.java | 23 +++++++++++++++++++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala index 29831e9a6fed..84c3f2fadaf2 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -71,7 +71,10 @@ private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { case s: String => builder.add(s) case n: java.lang.Number => builder.add(n.doubleValue()) case b: java.lang.Boolean => builder.add(b.booleanValue()) - case other => builder.add(other.toString) + case other => + throw new IllegalArgumentException( + s"Unsupported partition key component type '${other.getClass.getName}' with value '$other'. 
" + + "Supported types are String, Number (integral or floating-point), Boolean, and null.") } Some(builder.build()) } else { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala index bd6fa5be8bdf..796aeea52b91 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -64,10 +64,7 @@ private[spark] class CosmosReadManyByPartitionKeyReader( cosmosContainerConfig, clientCacheItems(0).get, clientCacheItems(1)) - // Warm-up readItem: intentionally issues a lookup for a random id/partition-key pair - // on the driver so that the collection/routing-map caches are populated before we serialize - // the client state and broadcast it to executors. This costs ~1 RU + 1 RTT per broadcast build - // (expected 404) but avoids every executor doing the same lookup in parallel on first use. + // Warm-up readItem: intentionally issues a lookup for a random id/partition-key pair // on the driver so that the collection/routing-map caches are populated before we serialize // the client state and broadcast it to executors. 
This costs ~1 RU + 1 RTT per broadcast build diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala index c81113d3a190..ba7b37a27065 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -95,6 +95,14 @@ class CosmosPartitionKeyHelperSpec extends UnitSpec { pk.get shouldEqual PartitionKey.NONE } + it should "throw for unsupported component types in the null-handling builder path" in { + val error = the[IllegalArgumentException] thrownBy { + CosmosPartitionKeyHelper.tryParsePartitionKey("pk([null,{\"nested\":\"value\"}])", treatNullAsNone = false) + } + + error.getMessage should include("Unsupported partition key component type") + error.getMessage should include("java.util.LinkedHashMap") + } it should "throw a clear error when None nullHandling is used for hierarchical partition keys" in { val error = the[IllegalArgumentException] thrownBy { CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"Redmond\",null])", treatNullAsNone = true) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java index 06de8524bcbd..9a59e7af8d7c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java @@ -920,6 +920,29 @@ public void cosmosAsyncContainer( "readMany", samplingRate); mockTracer.reset(); + List partitionKeys = createdDocs + .stream() + .map(CosmosItemIdentity::getPartitionKey) + .collect(Collectors.toList()); + feedItemResponse = cosmosAsyncContainer + 
.readManyByPartitionKey(partitionKeys, ObjectNode.class) + .byPage(1) + .blockFirst(); + assertThat(feedItemResponse).isNotNull(); + assertThat(feedItemResponse.getResults()).isNotEmpty(); + verifyTracerAttributes( + mockTracer, + "readManyByPartitionKey." + cosmosAsyncContainer.getId(), + cosmosAsyncDatabase.getId(), + cosmosAsyncContainer.getId(), + feedItemResponse.getCosmosDiagnostics(), + null, + useLegacyTracing, + enableRequestLevelTracing, + forceThresholdViolations, + "readManyByPartitionKey", + samplingRate); + mockTracer.reset(); } @Test(groups = { "fast", "simple" }, timeOut = 10 * TIMEOUT)