diff --git a/sdk/cosmos/.gitignore b/sdk/cosmos/.gitignore index 1ea74182f6bb..81d278c1728f 100644 --- a/sdk/cosmos/.gitignore +++ b/sdk/cosmos/.gitignore @@ -2,3 +2,5 @@ metastore_db/* spark-warehouse/* + +.temp/ diff --git a/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala b/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala index 5f9cb1dbdbc8..8476e4aa2026 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3-5/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala @@ -4,13 +4,20 @@ package com.azure.cosmos.spark import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.models.{CosmosContainerProperties, CosmosItemRequestOptions, PartitionKey, PartitionKeyDefinition, PartitionKeyDefinitionVersion, PartitionKind, ThroughputProperties} +import com.azure.cosmos.spark.udf.GetCosmosPartitionKeyValue import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.types.StringType -import java.util.UUID +import java.util.{ArrayList, UUID} + +import scala.collection.JavaConverters._ class SparkE2EQueryITest extends SparkE2EQueryITestBase { + // scalastyle:off multiple.string.literals "spark query" can "return proper Cosmos specific query plan on explain with nullable properties" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY @@ -67,4 +74,115 @@ class SparkE2EQueryITest val item = rowsArray(0) item.getAs[String]("id") shouldEqual id } -} + + "spark readManyByPartitionKey" can "use a matching top-level partition key column without the UDF" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val container = 
cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + val requestOptions = new CosmosItemRequestOptions() + + Seq("pkA", "pkB").foreach { pkValue => + val item = objectMapper.createObjectNode() + item.put("id", s"item-$pkValue") + item.put("pk", pkValue) + item.put("payload", s"value-$pkValue") + + container.createItem(item, new PartitionKey(pkValue), requestOptions).block() + } + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true" + ) + + val sparkSession = spark + import sparkSession.implicits._ + + val rows = CosmosItemsDataSource + .readManyByPartitionKey(Seq("pkA", "pkB").toDF("pk"), cfg.asJava) + .selectExpr("id", "pk", "payload") + .collect() + + rows should have size 2 + rows.map(_.getAs[String]("id")).toSet shouldEqual Set("item-pkA", "item-pkB") + rows.map(_.getAs[String]("pk")).toSet shouldEqual Set("pkA", "pkB") + rows.map(_.getAs[String]("payload")).toSet shouldEqual Set("value-pkA", "value-pkB") + } + "spark readManyByPartitionKey" can "require the UDF for nested partition key paths and succeed with it" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val containerName = s"nested-pk-${UUID.randomUUID()}" + + val pkPaths = new ArrayList[String]() + pkPaths.add("/tenant/id") + + val pkDefinition = new PartitionKeyDefinition() + pkDefinition.setPaths(pkPaths) + pkDefinition.setKind(PartitionKind.HASH) + pkDefinition.setVersion(PartitionKeyDefinitionVersion.V2) + + val containerProperties = new CosmosContainerProperties(containerName, pkDefinition) + cosmosClient + .getDatabase(cosmosDatabase) + .createContainerIfNotExists(containerProperties, ThroughputProperties.createManualThroughput(400)) + .block() + + try { + val container = 
cosmosClient.getDatabase(cosmosDatabase).getContainer(containerName) + val requestOptions = new CosmosItemRequestOptions() + + Seq("tenantA", "tenantB").foreach { tenantId => + val item = objectMapper.createObjectNode() + item.put("id", s"item-$tenantId") + item.put("payload", s"value-$tenantId") + item.putObject("tenant").put("id", tenantId) + + container.createItem(item, new PartitionKey(tenantId), requestOptions).block() + } + + val cfg = Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> containerName, + "spark.cosmos.read.inferSchema.enabled" -> "true" + ) + + val sparkSession = spark + import sparkSession.implicits._ + + val missingUdfError = the[IllegalArgumentException] thrownBy { + CosmosItemsDataSource.readManyByPartitionKey(Seq("tenantA").toDF("tenantId"), cfg.asJava) + } + + missingUdfError.getMessage should include("Nested paths cannot be resolved from DataFrame columns automatically") + missingUdfError.getMessage should include("_partitionKeyIdentity") + + spark.udf.register("GetCosmosPartitionKeyValue", new GetCosmosPartitionKeyValue(), StringType) + + val inputDf = Seq("tenantA", "tenantB") + .toDF("tenantId") + .withColumn("_partitionKeyIdentity", expr("GetCosmosPartitionKeyValue(tenantId)")) + + val rows = CosmosItemsDataSource + .readManyByPartitionKey(inputDf, cfg.asJava) + .selectExpr("id", "tenant.id as tenantId") + .collect() + + rows should have size 2 + rows.map(_.getAs[String]("id")).toSet shouldEqual Set("item-tenantA", "item-tenantB") + rows.map(_.getAs[String]("tenantId")).toSet shouldEqual Set("tenantA", "tenantB") + } finally { + cosmosClient + .getDatabase(cosmosDatabase) + .getContainer(containerName) + .delete() + .block() + } + } + + // scalastyle:on multiple.string.literals +} \ No newline at end of file diff --git 
a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..8312e34bf0c1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -92,6 +92,7 @@ private[spark] object CosmosConfigNames { val ReadPartitioningFeedRangeFilter = "spark.cosmos.partitioning.feedRangeFilter" val ReadRuntimeFilteringEnabled = "spark.cosmos.read.runtimeFiltering.enabled" val ReadManyFilteringEnabled = "spark.cosmos.read.readManyFiltering.enabled" + val ReadManyByPkNullHandling = "spark.cosmos.read.readManyByPk.nullHandling" val ViewsRepositoryPath = "spark.cosmos.views.repositoryPath" val DiagnosticsMode = "spark.cosmos.diagnostics" val DiagnosticsSamplingMaxCount = "spark.cosmos.diagnostics.sampling.maxCount" @@ -226,6 +227,7 @@ private[spark] object CosmosConfigNames { ReadPartitioningFeedRangeFilter, ReadRuntimeFilteringEnabled, ReadManyFilteringEnabled, + ReadManyByPkNullHandling, ViewsRepositoryPath, DiagnosticsMode, DiagnosticsSamplingIntervalInSeconds, @@ -1042,7 +1044,8 @@ private case class CosmosReadConfig(readConsistencyStrategy: ReadConsistencyStra throughputControlConfig: Option[CosmosThroughputControlConfig] = None, runtimeFilteringEnabled: Boolean, readManyFilteringConfig: CosmosReadManyFilteringConfig, - responseContinuationTokenLimitInKb: Option[Int] = None) + responseContinuationTokenLimitInKb: Option[Int] = None, + readManyByPkTreatNullAsNone: Boolean = false) private object SchemaConversionModes extends Enumeration { type SchemaConversionMode = Value @@ -1136,6 +1139,20 @@ private object CosmosReadConfig { helpMessage = " Indicates whether dynamic partition pruning filters will be pushed down when applicable." 
) + private val ReadManyByPkNullHandling = CosmosConfigEntry[String]( + key = CosmosConfigNames.ReadManyByPkNullHandling, + mandatory = false, + defaultValue = Some("Null"), + parseFromStringFunction = value => value, + helpMessage = "Determines how null values in partition key columns are treated for " + + "readManyByPartitionKey. 'Null' (default) maps null to a JSON null via addNullValue(), which " + + "is appropriate when the document field exists with an explicit null value. 'None' maps null " + + "to PartitionKey.NONE via addNoneValue(), which is only supported for single-path partition keys " + + "and should only be used when the partition key path does not exist at all in the document. " + + "Hierarchical partition keys reject this mode. These two semantics hash to DIFFERENT physical " + + "partitions - picking the wrong mode for your data will silently return zero rows." + ) + def parseCosmosReadConfig(cfg: Map[String, String]): CosmosReadConfig = { val forceEventualConsistency = CosmosConfigEntry.parse(cfg, ForceEventualConsistency) val readConsistencyStrategyOverride = CosmosConfigEntry.parse(cfg, ReadConsistencyStrategyOverride) @@ -1158,6 +1175,8 @@ private object CosmosReadConfig { val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg) val runtimeFilteringEnabled = CosmosConfigEntry.parse(cfg, ReadRuntimeFilteringEnabled) val readManyFilteringConfig = CosmosReadManyFilteringConfig.parseCosmosReadManyFilterConfig(cfg) + val readManyByPkNullHandling = CosmosConfigEntry.parse(cfg, ReadManyByPkNullHandling) + val readManyByPkTreatNullAsNone = readManyByPkNullHandling.getOrElse("Null").equalsIgnoreCase("None") val effectiveReadConsistencyStrategy = if (readConsistencyStrategyOverride.getOrElse(ReadConsistencyStrategy.DEFAULT) != ReadConsistencyStrategy.DEFAULT) { readConsistencyStrategyOverride.get @@ -1189,7 +1208,8 @@ private object CosmosReadConfig { throughputControlConfigOpt, runtimeFilteringEnabled.get, 
readManyFilteringConfig, - responseContinuationTokenLimitInKb) + responseContinuationTokenLimitInKb, + readManyByPkTreatNullAsNone) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala index 9ece47416526..00761f23d399 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala @@ -45,6 +45,7 @@ private[cosmos] object CosmosConstants { val Id = "id" val ETag = "_etag" val ItemIdentity = "_itemIdentity" + val PartitionKeyIdentity = "_partitionKeyIdentity" } object StatusCodes { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala index a35cff27af68..3433e352870b 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosItemsDataSource.scala @@ -2,9 +2,10 @@ // Licensed under the MIT License. 
package com.azure.cosmos.spark -import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey} +import com.azure.cosmos.models.{CosmosItemIdentity, PartitionKey, PartitionKeyBuilder} import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.azure.cosmos.SparkBridgeInternal import org.apache.spark.sql.{DataFrame, Row, SparkSession} import java.util @@ -112,4 +113,156 @@ object CosmosItemsDataSource { readManyReader.readMany(df.rdd, readManyFilterExtraction) } + + def readManyByPartitionKey(df: DataFrame, userConfig: java.util.Map[String, String]): DataFrame = { + readManyByPartitionKey(df, userConfig, null) + } + + def readManyByPartitionKey( + df: DataFrame, + userConfig: java.util.Map[String, String], + userProvidedSchema: StructType): DataFrame = { + + val readManyReader = new CosmosReadManyByPartitionKeyReader( + userProvidedSchema, + userConfig.asScala.toMap) + + // Resolve the null-handling config up front so both the UDF path and the PK-column path honor it. + val sharedEffectiveConfig = CosmosConfig.getEffectiveConfig( + databaseName = None, + containerName = None, + userConfig.asScala.toMap) + val sharedReadConfig = CosmosReadConfig.parseCosmosReadConfig(sharedEffectiveConfig) + val sharedTreatNullAsNone = sharedReadConfig.readManyByPkTreatNullAsNone + + // Option 1: Look for the _partitionKeyIdentity column (produced by GetCosmosPartitionKeyValue UDF) + val pkIdentityFieldExtraction = df + .schema + .find(field => field.name.equals(CosmosConstants.Properties.PartitionKeyIdentity) && field.dataType.equals(StringType)) + .map(field => (row: Row) => { + val rawValue = row.getString(row.fieldIndex(field.name)) + CosmosPartitionKeyHelper.tryParsePartitionKey(rawValue, sharedTreatNullAsNone) + .getOrElse(throw new IllegalArgumentException( + s"Invalid _partitionKeyIdentity value in row: '$rawValue'. 
" + + "Expected format: pk([...json...])")) + }) + + // Option 2: Detect PK columns by matching the container's partition key paths against the DataFrame schema + val pkColumnExtraction: Option[Row => PartitionKey] = if (pkIdentityFieldExtraction.isDefined) { + None // no need to resolve PK paths - _partitionKeyIdentity column takes precedence + } else { + val effectiveConfig = sharedEffectiveConfig + val readConfig = sharedReadConfig + val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(effectiveConfig) + val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(None) + val calledFrom = s"CosmosItemsDataSource.readManyByPartitionKey" + val treatNullAsNone = readConfig.readManyByPkTreatNullAsNone + + val pkPaths = Loan( + List[Option[CosmosClientCacheItem]]( + Some( + CosmosClientCache( + CosmosClientConfiguration( + effectiveConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom)), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => { + val container = + ThroughputControlHelper.getContainer( + effectiveConfig, + containerConfig, + clientCacheItems(0).get, + clientCacheItems(1)) + + val pkDefinition = SparkBridgeInternal + .getContainerPropertiesFromCollectionCache(container) + .getPartitionKeyDefinition + + pkDefinition.getPaths.asScala.map(_.stripPrefix("/")).toList + }) + + // Nested PK paths (containing /) cannot be resolved from top-level DataFrame columns. + // Surface an explicit error so users know to use the UDF-produced _partitionKeyIdentity column. + if (pkPaths.exists(_.contains("/"))) { + throw new IllegalArgumentException( + "Container has nested partition key path(s) " + pkPaths.mkString("[", ",", "]") + ". 
" + + "Nested paths cannot be resolved from DataFrame columns automatically - add a " + + "'_partitionKeyIdentity' column produced by the GetCosmosPartitionKeyValue UDF.") + } + + // Check if ALL PK path columns exist in the DataFrame schema + val dfFieldNames = df.schema.fieldNames.toSet + val allPkColumnsPresent = pkPaths.forall(path => dfFieldNames.contains(path)) + + if (allPkColumnsPresent && pkPaths.nonEmpty) { + // pkPaths already defined above + Some((row: Row) => { + if (pkPaths.size == 1) { + // Single partition key + buildPartitionKey(row.getAs[Any](pkPaths.head), treatNullAsNone) + } else { + // Hierarchical partition key - build level by level + val builder = new PartitionKeyBuilder() + for (path <- pkPaths) { + addPartitionKeyComponent(builder, row.getAs[Any](path), treatNullAsNone, pkPaths.size) + } + builder.build() + } + }) + } else { + None + } + } + + val pkExtraction = pkIdentityFieldExtraction + .orElse(pkColumnExtraction) + .getOrElse( + throw new IllegalArgumentException( + "Cannot determine partition key extraction from the input DataFrame. 
" + + "Either add a '_partitionKeyIdentity' column (using the GetCosmosPartitionKeyValue UDF) " + + "or ensure the DataFrame contains columns matching the container's partition key paths.")) + + readManyReader.readManyByPartitionKey(df.rdd, pkExtraction) + } + + private def addPartitionKeyComponent( + builder: PartitionKeyBuilder, + value: Any, + treatNullAsNone: Boolean, + partitionKeyComponentCount: Int): Unit = { + value match { + case s: String => builder.add(s) + case n: Number => builder.add(n.doubleValue()) + case b: Boolean => builder.add(b) + case null => + CosmosPartitionKeyHelper.validateNoneHandlingForPartitionKeyComponentCount( + partitionKeyComponentCount, + treatNullAsNone) + if (treatNullAsNone) builder.addNoneValue() + else builder.addNullValue() + case other => + // Reject unknown types rather than silently .toString-ing them - the document field + // was stored with its original type and a stringified value will never match. + // Supported types: String, Number (Byte/Short/Int/Long/Float/Double/BigDecimal), Boolean, null. + throw new IllegalArgumentException( + s"Unsupported partition key column type '${other.getClass.getName}' with value '$other'. " + + "Supported types are String, Number (integral or floating-point), Boolean, and null. 
" + + "For other source types, convert the column before calling readManyByPartitionKey or use " + + "the GetCosmosPartitionKeyValue UDF to produce a '_partitionKeyIdentity' column.") + } + } + + private def buildPartitionKey(value: Any, treatNullAsNone: Boolean): PartitionKey = { + val builder = new PartitionKeyBuilder() + addPartitionKeyComponent(builder, value, treatNullAsNone, 1) + builder.build() + } } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala new file mode 100644 index 000000000000..84c3f2fadaf2 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelper.scala @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.routing.PartitionKeyInternal +import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, Utils} +import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait + +import java.util + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private[spark] object CosmosPartitionKeyHelper extends BasicLoggingTrait { + private[spark] val HierarchicalPartitionKeyNoneHandlingErrorMessage = + s"The configuration '${CosmosConfigNames.ReadManyByPkNullHandling}=None' is not supported for " + + "hierarchical partition keys because PartitionKey.NONE can't be used with multiple paths. " + + "Use 'Null' for explicit JSON null values, filter out rows with missing partition key " + + "components, or provide fully-defined hierarchical partition keys." 
+ + private[spark] def validateNoneHandlingForPartitionKeyComponentCount( + componentCount: Int, + treatNullAsNone: Boolean): Unit = { + if (treatNullAsNone && componentCount > 1) { + throw new IllegalArgumentException(HierarchicalPartitionKeyNoneHandlingErrorMessage) + } + } + + // pattern will be recognized + // pk(partitionKeyValue) + // + // (?i) : The whole matching is case-insensitive + // pk[(](.*)[)]: partitionKey Value + private val cosmosPartitionKeyStringRegx = """(?i)pk[(](.*)[)]""".r + private val objectMapper = Utils.getSimpleObjectMapper + + def getCosmosPartitionKeyValueString(partitionKeyValue: List[Object]): String = { + s"pk(${objectMapper.writeValueAsString(partitionKeyValue.asJava)})" + } + + def tryParsePartitionKey(cosmosPartitionKeyString: String): Option[PartitionKey] = + tryParsePartitionKey(cosmosPartitionKeyString, treatNullAsNone = false) + + /** + * Parses a pk(...) string into a [[PartitionKey]]. + * + * When treatNullAsNone is true, any JSON null components in the serialized array are mapped to + * [[PartitionKeyBuilder.addNoneValue()]] (meaning the document field is absent/undefined). + * When false, they are mapped to [[PartitionKeyBuilder.addNullValue()]] (JSON null value). + * This matches the spark.cosmos.read.readManyByPk.nullHandling config for the non-UDF column path. + */ + def tryParsePartitionKey( + cosmosPartitionKeyString: String, + treatNullAsNone: Boolean): Option[PartitionKey] = { + cosmosPartitionKeyString match { + case cosmosPartitionKeyStringRegx(pkValue) => + scala.util.Try(Utils.parse(pkValue, classOf[Object])).toOption.flatMap { + case arrayList: util.ArrayList[Object @unchecked] => + val components = arrayList.toArray + if (components.exists(_ == null)) { + validateNoneHandlingForPartitionKeyComponentCount(components.length, treatNullAsNone) + // Build via PartitionKeyBuilder so nulls can be disambiguated between + // JSON-null (addNullValue) and undefined (addNoneValue) based on config. 
+ val builder = new PartitionKeyBuilder() + components.foreach { + case null => + if (treatNullAsNone) builder.addNoneValue() else builder.addNullValue() + case s: String => builder.add(s) + case n: java.lang.Number => builder.add(n.doubleValue()) + case b: java.lang.Boolean => builder.add(b.booleanValue()) + case other => + throw new IllegalArgumentException( + s"Unsupported partition key component type '${other.getClass.getName}' with value '$other'. " + + "Supported types are String, Number (integral or floating-point), Boolean, and null.") + } + Some(builder.build()) + } else { + Some( + ImplementationBridgeHelpers + .PartitionKeyHelper + .getPartitionKeyAccessor + .toPartitionKey(PartitionKeyInternal.fromObjectArray(components, false))) + } + case other => Some(new PartitionKey(other)) + } + case _ => None + } + } +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala new file mode 100644 index 000000000000..796aeea52b91 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosReadManyByPartitionKeyReader.scala @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosException, ReadConsistencyStrategy} +import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, UUIDs} +import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.spark.CosmosPredicates.assertOnSparkDriver +import com.azure.cosmos.spark.diagnostics.{BasicLoggingTrait, DiagnosticsContext} +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types.StructType + +import java.util.UUID + +private[spark] class CosmosReadManyByPartitionKeyReader( + val userProvidedSchema: StructType, + val userConfig: Map[String, String] + ) extends BasicLoggingTrait with Serializable { + val effectiveUserConfig: Map[String, String] = CosmosConfig.getEffectiveConfig( + databaseName = None, + containerName = None, + userConfig) + + val clientConfig: CosmosAccountConfig = CosmosAccountConfig.parseCosmosAccountConfig(effectiveUserConfig) + val readConfig: CosmosReadConfig = CosmosReadConfig.parseCosmosReadConfig(effectiveUserConfig) + val cosmosContainerConfig: CosmosContainerConfig = + CosmosContainerConfig.parseCosmosContainerConfig(effectiveUserConfig) + //scalastyle:off multiple.string.literals + val tableName: String = s"com.azure.cosmos.spark.items.${clientConfig.accountName}." 
+ + s"${cosmosContainerConfig.database}.${cosmosContainerConfig.container}" + private lazy val sparkSession = { + assertOnSparkDriver() + SparkSession.active + } + val sparkEnvironmentInfo: String = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(sparkSession)) + logTrace(s"Instantiated ${this.getClass.getSimpleName} for $tableName") + + private[spark] def initializeAndBroadcastCosmosClientStatesForContainer(): Broadcast[CosmosClientMetadataCachesSnapshots] = { + val calledFrom = s"CosmosReadManyByPartitionKeyReader($tableName).initializeAndBroadcastCosmosClientStateForContainer" + Loan( + List[Option[CosmosClientCacheItem]]( + Some( + CosmosClientCache( + CosmosClientConfiguration( + effectiveUserConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom)), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveUserConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => { + val container = + ThroughputControlHelper.getContainer( + effectiveUserConfig, + cosmosContainerConfig, + clientCacheItems(0).get, + clientCacheItems(1)) + + // Warm-up readItem: intentionally issues a lookup for a random id/partition-key pair + // on the driver so that the collection/routing-map caches are populated before we serialize + // the client state and broadcast it to executors. This costs ~1 RU + 1 RTT per broadcast build + // (expected 404) but avoids every executor doing the same lookup in parallel on first use. + try { + container.readItem( + UUIDs.nonBlockingRandomUUID().toString, + new PartitionKey(UUIDs.nonBlockingRandomUUID().toString), + classOf[ObjectNode]) + .block() + } catch { + // The warm-up readItem is only used to hydrate the collection/routing-map caches. + // A 404 (item not found) is expected, but we log other CosmosExceptions at debug to + // aid diagnosis (auth failures, throttling, etc.) while not failing reader setup. 
+ case ex: CosmosException => + logDebug(s"Warm-up readItem for metadata caches completed with exception: ${ex.getMessage}", ex) + None + } + + val state = new CosmosClientMetadataCachesSnapshot() + state.serialize(clientCacheItems(0).get.cosmosClient) + + var throughputControlState: Option[CosmosClientMetadataCachesSnapshot] = None + if (clientCacheItems(1).isDefined) { + throughputControlState = Some(new CosmosClientMetadataCachesSnapshot()) + throughputControlState.get.serialize(clientCacheItems(1).get.cosmosClient) + } + + val metadataSnapshots = CosmosClientMetadataCachesSnapshots(state, throughputControlState) + sparkSession.sparkContext.broadcast(metadataSnapshots) + }) + } + + def readManyByPartitionKey(inputRdd: RDD[Row], pkExtraction: Row => PartitionKey): DataFrame = { + val correlationActivityId = UUIDs.nonBlockingRandomUUID() + val calledFrom = s"CosmosReadManyByPartitionKeyReader.readManyByPartitionKey($correlationActivityId)" + val schema = Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration( + effectiveUserConfig, + readConsistencyStrategy = readConfig.readConsistencyStrategy, + sparkEnvironmentInfo), + None, + calledFrom + )), + ThroughputControlHelper.getThroughputControlClientCacheItem( + effectiveUserConfig, + calledFrom, + None, + sparkEnvironmentInfo) + )) + .to(clientCacheItems => Option.apply(userProvidedSchema).getOrElse( + CosmosTableSchemaInferrer.inferSchema( + clientCacheItems(0).get, + clientCacheItems(1), + effectiveUserConfig, + ItemsTable.defaultSchemaForInferenceDisabled))) + + val clientStates = initializeAndBroadcastCosmosClientStatesForContainer + + sparkSession.sqlContext.createDataFrame( + inputRdd.mapPartitionsWithIndex( + (partitionIndex: Int, rowIterator: Iterator[Row]) => { + val pkIterator: Iterator[PartitionKey] = rowIterator + .map(row => pkExtraction.apply(row)) + + logInfo(s"Creating an ItemsPartitionReaderWithReadManyByPartitionKey for Activity $correlationActivityId to 
read for " + + s"input partition [$partitionIndex] ${tableName}") + + val taskContext = TaskContext.get + val reader = new ItemsPartitionReaderWithReadManyByPartitionKey( + effectiveUserConfig, + CosmosReadManyHelper.FullRangeFeedRange, + schema, + DiagnosticsContext(correlationActivityId, partitionIndex.toString), + clientStates, + DiagnosticsConfig.parseDiagnosticsConfig(effectiveUserConfig), + sparkEnvironmentInfo, + taskContext, + pkIterator) + + new Iterator[Row] { + private var isClosed = false + + private def closeReader(): Unit = { + if (!isClosed) { + isClosed = true + reader.close() + } + } + + if (taskContext != null) { + taskContext.addTaskCompletionListener[Unit](_ => closeReader()) + } + + override def hasNext: Boolean = { + try { + val hasMore = reader.next() + if (!hasMore) { + closeReader() + } + hasMore + } catch { + case error: Throwable => + closeReader() + throw error + } + } + + override def next(): Row = { + try { + reader.getCurrentRow() + } catch { + case error: Throwable => + closeReader() + throw error + } + } + } + }, + preservesPartitioning = true + ), + schema) + } +} + diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala new file mode 100644 index 000000000000..5c994cc2f7f8 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKey.scala @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosAsyncContainer, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosItemSerializerNoExceptionWrapping, SparkBridgeInternal} +import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple +import com.azure.cosmos.implementation.{ImplementationBridgeHelpers, ObjectNodeMap, SparkRowItem, Utils} +import com.azure.cosmos.BridgeInternal +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, ModelBridgeInternal, PartitionKey, PartitionKeyDefinition, SqlQuerySpec} +import com.azure.cosmos.spark.BulkWriter.getThreadInfo +import com.azure.cosmos.spark.CosmosTableSchemaInferrer.IdAttributeName +import com.azure.cosmos.spark.diagnostics.{DetailedFeedDiagnosticsProvider, DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext} +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType + +import java.util +import java.util.concurrent.atomic.AtomicBoolean + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private[spark] case class ItemsPartitionReaderWithReadManyByPartitionKey +( + config: Map[String, String], + feedRange: NormalizedRange, + readSchema: StructType, + diagnosticsContext: DiagnosticsContext, + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + sparkEnvironmentInfo: String, + taskContext: TaskContext, + readManyPartitionKeys: Iterator[PartitionKey] +) + extends PartitionReader[InternalRow] { + + private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) + + private val 
readManyOptions = new CosmosReadManyRequestOptions() + private val readManyOptionsImpl = ImplementationBridgeHelpers + .CosmosReadManyRequestOptionsHelper + .getCosmosReadManyRequestOptionsAccessor + .getImpl(readManyOptions) + + private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config) + ThroughputControlHelper.populateThroughputControlGroupName(readManyOptionsImpl, readConfig.throughputControlConfig) + + private val operationContext = { + assert(taskContext != null) + + SparkTaskContext(diagnosticsContext.correlationActivityId, + taskContext.stageId(), + taskContext.partitionId(), + taskContext.taskAttemptId(), + feedRange.toString) + } + + private val operationContextAndListenerTuple: Option[OperationContextAndListenerTuple] = { + if (diagnosticsConfig.mode.isDefined) { + val listener = + DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass) + + val ctxAndListener = new OperationContextAndListenerTuple(operationContext, listener) + + readManyOptionsImpl + .setOperationContextAndListenerTuple(ctxAndListener) + + Some(ctxAndListener) + } else { + None + } + } + + log.logTrace(s"Instantiated ${this.getClass.getSimpleName}, Context: ${operationContext.toString} $getThreadInfo") + + private val containerTargetConfig = CosmosContainerConfig.parseCosmosContainerConfig(config) + + log.logInfo(s"Using ReadManyByPartitionKey from feed range $feedRange of " + + s"container ${containerTargetConfig.database}.${containerTargetConfig.container} - " + + s"correlationActivityId ${diagnosticsContext.correlationActivityId}, " + + s"Context: ${operationContext.toString} $getThreadInfo") + + private val clientCacheItem = CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), + s"ItemsPartitionReaderWithReadManyByPartitionKey($feedRange, ${containerTargetConfig.database}.${containerTargetConfig.container})" + ) + + 
private val throughputControlClientCacheItemOpt = + ThroughputControlHelper.getThroughputControlClientCacheItem( + config, + clientCacheItem.context, + Some(cosmosClientStateHandles), + sparkEnvironmentInfo) + + private val cosmosAsyncContainer = + ThroughputControlHelper.getContainer( + config, + containerTargetConfig, + clientCacheItem, + throughputControlClientCacheItemOpt) + + private val partitionKeyDefinition: PartitionKeyDefinition = { + TransientErrorsRetryPolicy.executeWithRetry(() => { + SparkBridgeInternal + .getContainerPropertiesFromCollectionCache(cosmosAsyncContainer).getPartitionKeyDefinition + }) + } + + private val cosmosSerializationConfig = CosmosSerializationConfig.parseSerializationConfig(config) + private val cosmosRowConverter = CosmosRowConverter.get(cosmosSerializationConfig) + + readManyOptionsImpl + .setCustomItemSerializer( + new CosmosItemSerializerNoExceptionWrapping { + override def serialize[T](item: T): util.Map[String, AnyRef] = { + throw new UnsupportedOperationException( + s"Serialization is not supported by the custom item serializer in " + + s"ItemsPartitionReaderWithReadManyByPartitionKey; this serializer is intended " + + s"for deserializing read-many responses into SparkRowItem only. 
" + + s"Unexpected item type: ${if (item == null) "null" else item.getClass.getName}" + ) + } + + override def deserialize[T](jsonNodeMap: util.Map[String, AnyRef], classType: Class[T]): T = { + if (jsonNodeMap == null) { + throw new IllegalStateException("The 'jsonNodeMap' should never be null here.") + } + + if (classType != classOf[SparkRowItem]) { + throw new IllegalStateException("The 'classType' must be 'classOf[SparkRowItem])' here.") + } + + val objectNode: ObjectNode = jsonNodeMap match { + case map: ObjectNodeMap => + map.getObjectNode + case _ => + Utils.getSimpleObjectMapper.convertValue(jsonNodeMap, classOf[ObjectNode]) + } + + val partitionKey = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition) + + val row = cosmosRowConverter.fromObjectNodeToRow(readSchema, + objectNode, + readConfig.schemaConversionMode) + + SparkRowItem(row, getPartitionKeyForFeedDiagnostics(partitionKey)).asInstanceOf[T] + } + } + ) + + // Collect all PK values upfront - readManyByPartitionKey needs the full list to + // group by physical partition (the SDK batches internally per physical partition). + // Deduplicate using the canonical PartitionKeyInternal JSON representation so that + // equivalent PKs built from different runtime types (Int vs Long vs Double) are + // collapsed, and distinct PKs that happen to toString() identically are not. 
+ private lazy val pkList = { + val seen = new java.util.LinkedHashMap[String, PartitionKey]() + readManyPartitionKeys.foreach(pk => { + val key = BridgeInternal.getPartitionKeyInternal(pk).toJson + seen.putIfAbsent(key, pk) + }) + new java.util.ArrayList[PartitionKey](seen.values()) + } + + private val endToEndTimeoutPolicy = + new CosmosEndToEndOperationLatencyPolicyConfigBuilder( + java.time.Duration.ofSeconds(CosmosConstants.readOperationEndToEndTimeoutInSeconds)) + .enable(true) + .build + + readManyOptionsImpl.setCosmosEndToEndOperationLatencyPolicyConfig(endToEndTimeoutPolicy) + + private trait CloseableSparkRowItemIterator { + def hasNext: Boolean + def next(): SparkRowItem + def close(): Unit + } + + private object EmptySparkRowItemIterator extends CloseableSparkRowItemIterator { + override def hasNext: Boolean = false + + override def next(): SparkRowItem = { + throw new java.util.NoSuchElementException("No items available for empty partition-key list.") + } + + override def close(): Unit = {} + } + + // Pass the full PK list to the SDK (which batches per physical partition internally). + // On transient I/O failures the retry iterator tracks pages already emitted upstream + // and skips them on replay; if a failure occurs mid-page (after items from that page have been + // emitted) the task fails rather than risking row duplication. 
+ private val isClosed = new AtomicBoolean(false) + private var iteratorOpt: Option[CloseableSparkRowItemIterator] = None + + private def getOrCreateIterator: CloseableSparkRowItemIterator = iteratorOpt match { + case Some(existing) => existing + case None => + val created = + if (pkList.isEmpty) { + EmptySparkRowItemIterator + } else { + new CloseableSparkRowItemIterator { + private val delegate = new TransientIOErrorsRetryingReadManyByPartitionKeyIterator[SparkRowItem]( + cosmosAsyncContainer, + pkList, + readConfig.customQuery.map(_.toSqlQuerySpec), + readManyOptions, + readConfig.maxItemCount, + readConfig.prefetchBufferSize, + operationContextAndListenerTuple, + classOf[SparkRowItem] + ) + + override def hasNext: Boolean = delegate.hasNext + + override def next(): SparkRowItem = delegate.next() + + override def close(): Unit = delegate.close() + } + } + + iteratorOpt = Some(created) + created + } + + private val rowSerializer: ExpressionEncoder.Serializer[Row] = RowSerializerPool.getOrCreateSerializer(readSchema) + + private def shouldLogDetailedFeedDiagnostics(): Boolean = { + diagnosticsConfig.mode.isDefined && + diagnosticsConfig.mode.get.equalsIgnoreCase(classOf[DetailedFeedDiagnosticsProvider].getName) + } + + private def getPartitionKeyForFeedDiagnostics(pkValue: PartitionKey): Option[PartitionKey] = { + if (shouldLogDetailedFeedDiagnostics()) { + Some(pkValue) + } else { + None + } + } + + override def next(): Boolean = getOrCreateIterator.hasNext + + override def get(): InternalRow = { + cosmosRowConverter.fromRowToInternalRow(getOrCreateIterator.next().row, rowSerializer) + } + + def getCurrentRow(): Row = getOrCreateIterator.next().row + + override def close(): Unit = { + if (isClosed.compareAndSet(false, true)) { + iteratorOpt.foreach(_.close()) + iteratorOpt = None + RowSerializerPool.returnSerializerToPool(readSchema, rowSerializer) + clientCacheItem.close() + if (throughputControlClientCacheItemOpt.isDefined) { + 
throughputControlClientCacheItemOpt.get.close() + } + } + } +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala new file mode 100644 index 000000000000..f82855218cce --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransientIOErrorsRetryingReadManyByPartitionKeyIterator.scala @@ -0,0 +1,290 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosAsyncContainer, CosmosException} +import com.azure.cosmos.implementation.OperationCancelledException +import com.azure.cosmos.implementation.spark.OperationContextAndListenerTuple +import com.azure.cosmos.models.{CosmosReadManyRequestOptions, FeedResponse, PartitionKey, SqlQuerySpec} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.azure.cosmos.util.CosmosPagedIterable + +import java.util.concurrent.{ExecutorService, SynchronousQueue, ThreadPoolExecutor, TimeUnit, TimeoutException} +import java.util.concurrent.atomic.AtomicLong +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.util.Random +import scala.util.control.Breaks + +/** + * Retry-safe iterator for readManyByPartitionKey. The full partition-key list is passed to the + * SDK in a single call - the SDK is responsible for fan-out and per-physical-partition batching + * (see Configs.getReadManyByPkMaxBatchSize()). This iterator therefore wraps a single + * CosmosPagedIterable and, on transient I/O failures, re-creates the underlying flux and + * skips the pages that were already emitted upstream so no row is delivered twice. 
+ */ +private[spark] class TransientIOErrorsRetryingReadManyByPartitionKeyIterator[TSparkRow] +( + val container: CosmosAsyncContainer, + val partitionKeys: java.util.List[PartitionKey], + val customQuery: Option[SqlQuerySpec], + val queryOptions: CosmosReadManyRequestOptions, + val pageSize: Int, + val pagePrefetchBufferSize: Int, + val operationContextAndListener: Option[OperationContextAndListenerTuple], + val classType: Class[TSparkRow] +) extends BufferedIterator[TSparkRow] with BasicLoggingTrait with AutoCloseable { + + private[spark] var maxRetryIntervalInMs = CosmosConstants.maxRetryIntervalForTransientFailuresInMs + private[spark] var maxRetryCount = CosmosConstants.maxRetryCountForTransientFailures + + private val maxPageRetrievalTimeout = scala.concurrent.duration.FiniteDuration( + 5 + CosmosConstants.readOperationEndToEndTimeoutInSeconds, + scala.concurrent.duration.SECONDS) + + private val rnd = Random + private val retryCount = new AtomicLong(0) + private lazy val operationContextString = operationContextAndListener match { + case Some(o) => if (o.getOperationContext != null) { + o.getOperationContext.toString + } else { + "n/a" + } + case None => "n/a" + } + + // Number of pages that have been fully emitted upstream. On retry, we recreate the flux + // and skip this many pages before emitting any item, so already-delivered rows are not + // re-emitted. A page is "committed" as soon as we surface its first item to the caller - + // subsequent failures while still inside that page cannot be recovered from without + // risking duplication, so we fail fast in that case. + private var pagesCommitted: Long = 0 + // Whether the currently-buffered page has emitted at least one item. If true, we have + // passed the point of no return for this page: any transient failure here must surface, + // because we cannot partially-skip within a page on retry. 
+ private var currentPagePartiallyConsumed: Boolean = false + + private[spark] var currentFeedResponseIterator: Option[BufferedIterator[FeedResponse[TSparkRow]]] = None + private[spark] var currentItemIterator: Option[BufferedIterator[TSparkRow]] = None + + override def hasNext: Boolean = { + executeWithRetry("hasNextInternal", () => hasNextInternal) + } + + private def hasNextInternal: Boolean = { + var returnValue: Option[Boolean] = None + + while (returnValue.isEmpty) { + returnValue = hasNextInternalCore + } + + returnValue.get + } + + private def hasNextInternalCore: Option[Boolean] = { + if (hasBufferedNext) { + Some(true) + } else { + val feedResponseIterator = currentFeedResponseIterator match { + case Some(existing) => existing + case None => + val pagedFlux = customQuery match { + case Some(query) => + container.readManyByPartitionKey(partitionKeys, query, queryOptions, classType) + case None => + container.readManyByPartitionKey(partitionKeys, queryOptions, classType) + } + + val rawIterator = new CosmosPagedIterable[TSparkRow]( + pagedFlux, + pageSize, + pagePrefetchBufferSize + ) + .iterableByPage() + .iterator + + // Skip pages already emitted upstream (replay-safe retry). + var skipped: Long = 0 + while (skipped < pagesCommitted && rawIterator.hasNext) { + rawIterator.next() + skipped += 1 + } + if (skipped < pagesCommitted) { + // The server returned fewer pages than before - cannot safely replay. + // Surface a clean error rather than silently emitting a truncated result. + throw new IllegalStateException( + s"readManyByPartitionKey retry replay failed: expected to skip $pagesCommitted " + + s"already-emitted pages but only $skipped were available. 
Context: $operationContextString") + } + + // scalastyle:off underscore.import + import scala.collection.JavaConverters._ + // scalastyle:on underscore.import + currentFeedResponseIterator = Some(rawIterator.asScala.buffered) + + currentFeedResponseIterator.get + } + + val hasNext: Boolean = try { + Await.result( + Future { + feedResponseIterator.hasNext + }(TransientIOErrorsRetryingReadManyByPartitionKeyIterator.executionContext), + maxPageRetrievalTimeout) + } catch { + case endToEndTimeoutException: OperationCancelledException => + val message = s"End-to-end timeout hit when trying to retrieve the next page. " + + s"Context: $operationContextString" + logError(message, throwable = endToEndTimeoutException) + throw endToEndTimeoutException + + case timeoutException: TimeoutException => + val message = s"Attempting to retrieve the next page timed out. " + + s"Context: $operationContextString" + logError(message, timeoutException) + val exception = new OperationCancelledException(message, null) + exception.setStackTrace(timeoutException.getStackTrace) + throw exception + + case other: Throwable => throw other + } + + if (hasNext) { + val feedResponse = feedResponseIterator.next() + if (operationContextAndListener.isDefined) { + operationContextAndListener.get.getOperationListener.feedResponseProcessedListener( + operationContextAndListener.get.getOperationContext, + feedResponse) + } + // scalastyle:off underscore.import + import scala.collection.JavaConverters._ + // scalastyle:on underscore.import + val iteratorCandidate = feedResponse.getResults.iterator().asScala.buffered + + if (iteratorCandidate.hasNext) { + currentItemIterator = Some(iteratorCandidate) + currentPagePartiallyConsumed = false + Some(true) + } else { + // empty page - count it as committed (no items to replay) and try again + pagesCommitted += 1 + None + } + } else { + // Flux exhausted + currentFeedResponseIterator = None + Some(false) + } + } + } + + private def hasBufferedNext: Boolean = { + 
currentItemIterator match { + case Some(iterator) => if (iterator.hasNext) { + true + } else { + // Entire page drained -> it is now committed for replay-skipping purposes. + pagesCommitted += 1 + currentPagePartiallyConsumed = false + currentItemIterator = None + false + } + case None => false + } + } + + override def next(): TSparkRow = { + executeWithRetry("next", () => { + val value = currentItemIterator.get.next() + currentPagePartiallyConsumed = true + value + }) + } + + override def head: TSparkRow = { + executeWithRetry("head", () => currentItemIterator.get.head) + } + + private[spark] def executeWithRetry[T](methodName: String, func: () => T): T = { + val loop = new Breaks() + var returnValue: Option[T] = None + + loop.breakable { + while (true) { + val retryIntervalInMs = rnd.nextInt(maxRetryIntervalInMs) + + try { + returnValue = Some(func()) + retryCount.set(0) + loop.break + } + catch { + case cosmosException: CosmosException => + if (Exceptions.canBeTransientFailure(cosmosException.getStatusCode, cosmosException.getSubStatusCode)) { + if (currentPagePartiallyConsumed) { + // We have already emitted items from the current page upstream. Replaying + // the flux would re-skip only completed pages, not items within a page - + // which would cause silent duplication. Fail the task instead. + logError( + s"Transient failure in TransientIOErrorsRetryingReadManyByPartitionKeyIterator." 
+ + s"$methodName after items from the current page were already emitted - " + + s"cannot safely retry without duplicating rows.", + cosmosException) + throw cosmosException + } + val retryCountSnapshot = retryCount.incrementAndGet() + if (retryCountSnapshot > maxRetryCount) { + logError( + s"Too many transient failure retry attempts in " + + s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName", + cosmosException) + throw cosmosException + } else { + logWarning( + s"Transient failure handled in " + + s"TransientIOErrorsRetryingReadManyByPartitionKeyIterator.$methodName -" + + s" will be retried (attempt#$retryCountSnapshot) in ${retryIntervalInMs}ms " + + s"(pagesCommitted=$pagesCommitted)", + cosmosException) + } + } else { + throw cosmosException + } + case other: Throwable => throw other + } + + // Reset iterators; pagesCommitted is intentionally preserved so replay can skip them. + currentItemIterator = None + currentFeedResponseIterator = None + Thread.sleep(retryIntervalInMs) + } + } + + returnValue.get + } + + // Clean up iterator references - the underlying Reactor subscription from + // CosmosPagedIterable.iterator will be cleaned up when the iterator is GC'd. + // This matches the behavior of TransientIOErrorsRetryingIterator; any still-prefetched + // pages are discarded with the iterator. 
+ override def close(): Unit = { + currentItemIterator = None + currentFeedResponseIterator = None + } +} + +private object TransientIOErrorsRetryingReadManyByPartitionKeyIterator extends BasicLoggingTrait { + private val maxConcurrency = SparkUtils.getNumberOfHostCPUCores + + val executorService: ExecutorService = new ThreadPoolExecutor( + maxConcurrency, + maxConcurrency, + 0L, + TimeUnit.MILLISECONDS, + new SynchronousQueue(), + SparkUtils.daemonThreadFactory(), + new ThreadPoolExecutor.CallerRunsPolicy() + ) + + val executionContext: ExecutionContext = ExecutionContext.fromExecutorService(executorService) +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala new file mode 100644 index 000000000000..4dcc812f3122 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/udf/GetCosmosPartitionKeyValue.scala @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark.udf + +import com.azure.cosmos.spark.CosmosPartitionKeyHelper +import org.apache.spark.sql.api.java.UDF1 + +@SerialVersionUID(1L) +class GetCosmosPartitionKeyValue extends UDF1[Object, String] { + // Null is a valid partition-key value (JSON null). A null input is serialized as a + // single-level partition key with a JSON null component; parsing that string back via + // CosmosPartitionKeyHelper.tryParsePartitionKey yields a PartitionKey built with + // addNullValue(). If the caller instead wants PartitionKey.NONE semantics (absent PK + // field) they should filter the null row before calling this UDF and use the schema-matched + // readManyByPartitionKey path with readManyByPk.nullHandling=None. That None mode is only + // supported for single-path partition keys; hierarchical partition keys reject it. 
+ override def call(partitionKeyValue: Object): String = { + partitionKeyValue match { + case null => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(null)) + // for subpartitions case - Seq covers both WrappedArray (Scala 2.12) and ArraySeq (Scala 2.13) + case seq: Seq[Any] => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(seq.map(_.asInstanceOf[Object]).toList) + case _ => + CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List(partitionKeyValue)) + } + } +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala index 17f75e45a746..17a298d62131 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosConfigSpec.scala @@ -457,6 +457,7 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { config.runtimeFilteringEnabled shouldBe true config.readManyFilteringConfig.readManyFilteringEnabled shouldBe false config.readManyFilteringConfig.readManyFilterProperty shouldEqual "_itemIdentity" + config.readManyByPkTreatNullAsNone shouldBe false userConfig = Map( "spark.cosmos.read.forceEventualConsistency" -> "false", @@ -630,6 +631,47 @@ class CosmosConfigSpec extends UnitSpec with BasicLoggingTrait { config.customQuery.get.queryText shouldBe queryText } + it should "parse readManyByPk nullHandling configuration" in { + // Default (not specified) should treat null as JSON null (addNullValue) + var userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false" + ) + var config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // Explicit "Null" should treat null as JSON null (addNullValue) + userConfig = Map( + 
"spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "Null" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // Case-insensitive "null" + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "null" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe false + + // "None" should treat null as PartitionKey.NONE (addNoneValue) + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "None" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe true + + // Case-insensitive "none" + userConfig = Map( + "spark.cosmos.read.forceEventualConsistency" -> "false", + "spark.cosmos.read.readManyByPk.nullHandling" -> "none" + ) + config = CosmosReadConfig.parseCosmosReadConfig(userConfig) + config.readManyByPkTreatNullAsNone shouldBe true + } + it should "throw on invalid read configuration" in { val userConfig = Map( "spark.cosmos.read.schemaConversionMode" -> "not a valid value" diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala new file mode 100644 index 000000000000..ba7b37a27065 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/CosmosPartitionKeyHelperSpec.scala @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder} + +class CosmosPartitionKeyHelperSpec extends UnitSpec { + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + + it should "return the correct partition key value string for single PK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("pk1")) + pkString shouldEqual "pk([\"pk1\"])" + } + + it should "return the correct partition key value string for HPK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city1", "zip1")) + pkString shouldEqual "pk([\"city1\",\"zip1\"])" + } + + it should "return the correct partition key value string for 3-level HPK" in { + val pkString = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("a", "b", "c")) + pkString shouldEqual "pk([\"a\",\"b\",\"c\"])" + } + + it should "parse valid single PK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"myPkValue\"])") + pk.isDefined shouldBe true + pk.get shouldEqual new PartitionKey("myPkValue") + } + + it should "parse valid HPK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"city1\",\"zip1\"])") + pk.isDefined shouldBe true + val expected = new PartitionKeyBuilder().add("city1").add("zip1").build() + pk.get shouldEqual expected + } + + it should "parse valid 3-level HPK string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"a\",\"b\",\"c\"])") + pk.isDefined shouldBe true + val expected = new PartitionKeyBuilder().add("a").add("b").add("c").build() + pk.get shouldEqual expected + } + + it should "roundtrip single PK" in { + val original = "pk([\"roundtrip\"])" + val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) + parsed.isDefined shouldBe true + val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("roundtrip")) + serialized shouldEqual original + } + + it should 
"roundtrip HPK" in { + val original = "pk([\"city\",\"zip\"])" + val parsed = CosmosPartitionKeyHelper.tryParsePartitionKey(original) + parsed.isDefined shouldBe true + val serialized = CosmosPartitionKeyHelper.getCosmosPartitionKeyValueString(List("city", "zip")) + serialized shouldEqual original + } + + it should "return None for malformed string" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("invalid_format") + pk.isDefined shouldBe false + } + + it should "return None for missing pk prefix" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("[\"value\"]") + pk.isDefined shouldBe false + } + + it should "be case-insensitive for parsing" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("PK([\"value\"])") + pk.isDefined shouldBe true + pk.get shouldEqual new PartitionKey("value") + } + + + it should "return None for malformed JSON inside pk() wrapper" in { + // Invalid JSON that would cause JsonProcessingException + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk({invalid json})") + pk.isDefined shouldBe false + } + + it should "return None for truncated JSON inside pk() wrapper" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"unterminated)") + pk.isDefined shouldBe false + } + + it should "parse single-path null as PartitionKey.NONE when treatNullAsNone is true" in { + val pk = CosmosPartitionKeyHelper.tryParsePartitionKey("pk([null])", treatNullAsNone = true) + + pk.isDefined shouldBe true + pk.get shouldEqual PartitionKey.NONE + } + + it should "throw for unsupported component types in the null-handling builder path" in { + val error = the[IllegalArgumentException] thrownBy { + CosmosPartitionKeyHelper.tryParsePartitionKey("pk([null,{\"nested\":\"value\"}])", treatNullAsNone = false) + } + + error.getMessage should include("Unsupported partition key component type") + error.getMessage should include("java.util.LinkedHashMap") + } + it should "throw a clear error when None nullHandling is 
used for hierarchical partition keys" in { + val error = the[IllegalArgumentException] thrownBy { + CosmosPartitionKeyHelper.tryParsePartitionKey("pk([\"Redmond\",null])", treatNullAsNone = true) + } + + error.getMessage should include(CosmosConfigNames.ReadManyByPkNullHandling) + error.getMessage should include("hierarchical partition keys") + } + + it should "reject addNoneValue in hierarchical partition keys" in { + val error = the[IllegalStateException] thrownBy { + new PartitionKeyBuilder().add("Redmond").addNoneValue().build() + } + + error.getMessage should include("PartitionKey.None can't be used with multiple paths") + } + + //scalastyle:on multiple.string.literals + //scalastyle:on magic.number +} diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala new file mode 100644 index 000000000000..5c2d7b59836d --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/ItemsPartitionReaderWithReadManyByPartitionKeyITest.scala @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, TestConfigurations, Utils} +import com.azure.cosmos.models.PartitionKey +import com.azure.cosmos.spark.diagnostics.DiagnosticsContext +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.MockTaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +import java.util.UUID +import scala.collection.mutable.ListBuffer + +class ItemsPartitionReaderWithReadManyByPartitionKeyITest + extends IntegrationSpec + with Spark + with AutoCleanableCosmosContainersWithPkAsPartitionKey { + private val idProperty = "id" + private val pkProperty = "pk" + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + + it should "be able to retrieve all items for given partition keys" in { + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + + // Create items with known PK values + val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition + val allItemsByPk = scala.collection.mutable.Map[String, ListBuffer[ObjectNode]]() + val pkValues = List("pkA", "pkB", "pkC") + + for (pk <- pkValues) { + allItemsByPk(pk) = ListBuffer[ObjectNode]() + for (_ <- 1 to 5) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put(idProperty, UUID.randomUUID().toString) + objectNode.put(pkProperty, pk) + container.createItem(objectNode).block() + allItemsByPk(pk) += objectNode + } + } + + val config = Map( + "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, + "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true", + "spark.cosmos.applicationName" -> "ReadManyByPKTest" + ) + + val 
readSchema = StructType(Seq( + StructField(idProperty, StringType, false), + StructField(pkProperty, StringType, false) + )) + + val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") + val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) + val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() + + // Read items for pkA and pkB (not pkC) + val targetPks = List("pkA", "pkB") + val pkIterator = targetPks.map(pk => new PartitionKey(pk)).iterator + + val reader = ItemsPartitionReaderWithReadManyByPartitionKey( + config, + NormalizedRange("", "FF"), + readSchema, + diagnosticsContext, + cosmosClientMetadataCachesSnapshots, + diagnosticsConfig, + "", + MockTaskContext.mockTaskContext(), + pkIterator + ) + + val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) + val itemsReadFromReader = ListBuffer[ObjectNode]() + while (reader.next()) { + itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) + } + + // Should have 10 items (5 for pkA + 5 for pkB) + itemsReadFromReader.size shouldEqual 10 + + // All items should be from pkA or pkB + itemsReadFromReader.foreach(item => { + val pk = item.get(pkProperty).asText() + targetPks should contain(pk) + }) + + // Validate all expected IDs are present + val expectedIds = (allItemsByPk("pkA") ++ allItemsByPk("pkB")).map(_.get(idProperty).asText()).toSet + val actualIds = itemsReadFromReader.map(_.get(idProperty).asText()).toSet + actualIds shouldEqual expectedIds + + reader.close() + } + + it should "return empty results for non-existent partition keys" in { + val config = Map( + "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST, + "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey, + "spark.cosmos.read.inferSchema.enabled" -> "true", + 
"spark.cosmos.applicationName" -> "ReadManyByPKEmptyTest" + ) + + val readSchema = StructType(Seq( + StructField(idProperty, StringType, false), + StructField(pkProperty, StringType, false) + )) + + val diagnosticsContext = DiagnosticsContext(UUID.randomUUID(), "") + val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config) + val cosmosClientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots() + + val pkIterator = List(new PartitionKey("nonExistentPk")).iterator + + val reader = ItemsPartitionReaderWithReadManyByPartitionKey( + config, + NormalizedRange("", "FF"), + readSchema, + diagnosticsContext, + cosmosClientMetadataCachesSnapshots, + diagnosticsConfig, + "", + MockTaskContext.mockTaskContext(), + pkIterator + ) + + val itemsReadFromReader = ListBuffer[ObjectNode]() + val cosmosRowConverter = CosmosRowConverter.get(CosmosSerializationConfig.parseSerializationConfig(config)) + while (reader.next()) { + itemsReadFromReader += cosmosRowConverter.fromInternalRowToObjectNode(reader.get(), readSchema) + } + + itemsReadFromReader.size shouldEqual 0 + reader.close() + } + + private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = { + val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot() + cosmosClientMetadataCachesSnapshot.serialize(cosmosClient) + + spark.sparkContext.broadcast( + CosmosClientMetadataCachesSnapshots( + cosmosClientMetadataCachesSnapshot, + Option.empty[CosmosClientMetadataCachesSnapshot])) + } + + //scalastyle:on multiple.string.literals + //scalastyle:on magic.number +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java index 06de8524bcbd..9a59e7af8d7c 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java +++ 
b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosTracerTest.java @@ -920,6 +920,29 @@ public void cosmosAsyncContainer( "readMany", samplingRate); mockTracer.reset(); + List partitionKeys = createdDocs + .stream() + .map(CosmosItemIdentity::getPartitionKey) + .collect(Collectors.toList()); + feedItemResponse = cosmosAsyncContainer + .readManyByPartitionKey(partitionKeys, ObjectNode.class) + .byPage(1) + .blockFirst(); + assertThat(feedItemResponse).isNotNull(); + assertThat(feedItemResponse.getResults()).isNotEmpty(); + verifyTracerAttributes( + mockTracer, + "readManyByPartitionKey." + cosmosAsyncContainer.getId(), + cosmosAsyncDatabase.getId(), + cosmosAsyncContainer.getId(), + feedItemResponse.getCosmosDiagnostics(), + null, + useLegacyTracing, + enableRequestLevelTracing, + forceThresholdViolations, + "readManyByPartitionKey", + samplingRate); + mockTracer.reset(); } @Test(groups = { "fast", "simple" }, timeOut = 10 * TIMEOUT) diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java new file mode 100644 index 000000000000..ca684f1abf9e --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/ReadManyByPartitionKeyTest.java @@ -0,0 +1,531 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ + +package com.azure.cosmos; + +import com.azure.cosmos.models.CosmosContainerProperties; +import com.azure.cosmos.models.CosmosItemRequestOptions; +import com.azure.cosmos.models.FeedResponse; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyBuilder; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKeyDefinitionVersion; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; +import com.azure.cosmos.rx.TestSuiteBase; +import com.azure.cosmos.util.CosmosPagedIterable; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Factory; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; + +public class ReadManyByPartitionKeyTest extends TestSuiteBase { + + private String preExistingDatabaseId = CosmosDatabaseForTest.generateId(); + private CosmosClient client; + private CosmosDatabase createdDatabase; + + // Single PK container (/mypk) + private CosmosContainer singlePkContainer; + + // HPK container (/city, /zipcode, /areaCode) + private CosmosContainer multiHashContainer; + + @Factory(dataProvider = "clientBuilders") + public ReadManyByPartitionKeyTest(CosmosClientBuilder clientBuilder) { + super(clientBuilder); + } + + @BeforeClass(groups = {"emulator"}, timeOut = SETUP_TIMEOUT) + public void before_ReadManyByPartitionKeyTest() { + client = getClientBuilder().buildClient(); + createdDatabase = createSyncDatabase(client, 
preExistingDatabaseId); + + // Single PK container + String singlePkContainerName = UUID.randomUUID().toString(); + CosmosContainerProperties singlePkProps = new CosmosContainerProperties(singlePkContainerName, "/mypk"); + createdDatabase.createContainer(singlePkProps); + singlePkContainer = createdDatabase.getContainer(singlePkContainerName); + + // HPK container + String multiHashContainerName = UUID.randomUUID().toString(); + PartitionKeyDefinition hpkDef = new PartitionKeyDefinition(); + hpkDef.setKind(PartitionKind.MULTI_HASH); + hpkDef.setVersion(PartitionKeyDefinitionVersion.V2); + ArrayList paths = new ArrayList<>(); + paths.add("/city"); + paths.add("/zipcode"); + paths.add("/areaCode"); + hpkDef.setPaths(paths); + + CosmosContainerProperties hpkProps = new CosmosContainerProperties(multiHashContainerName, hpkDef); + createdDatabase.createContainer(hpkProps); + multiHashContainer = createdDatabase.getContainer(multiHashContainerName); + } + + @AfterClass(groups = {"emulator"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true) + public void afterClass() { + safeDeleteSyncDatabase(createdDatabase); + safeCloseSyncClient(client); + } + + //region Single PK tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_basic() { + // Create items with different PKs + List items = createSinglePkItems("pk1", 3); + items.addAll(createSinglePkItems("pk2", 2)); + items.addAll(createSinglePkItems("pk3", 4)); + + // Read by 2 partition keys + List pkValues = Arrays.asList( + new PartitionKey("pk1"), + new PartitionKey("pk2")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(5); // 3 + 2 + resultList.forEach(item -> { + String pk = item.get("mypk").asText(); + assertThat(pk).isIn("pk1", "pk2"); + }); + + // Cleanup + cleanupContainer(singlePkContainer); + } + + @Test(groups = 
{"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withProjection() { + List items = createSinglePkItems("pkProj", 2); + + List pkValues = Collections.singletonList(new PartitionKey("pkProj")); + SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.mypk FROM c"); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(2); + // Should only have id and mypk fields (plus system properties) + resultList.forEach(item -> { + assertThat(item.has("id")).isTrue(); + assertThat(item.has("mypk")).isTrue(); + }); + + cleanupContainer(singlePkContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withAdditionalFilter() { + // Create items with different "status" values + createSinglePkItemsWithStatus("pkFilter", "active", 3); + createSinglePkItemsWithStatus("pkFilter", "inactive", 2); + + List pkValues = Collections.singletonList(new PartitionKey("pkFilter")); + SqlQuerySpec customQuery = new SqlQuerySpec( + "SELECT * FROM c WHERE c.status = @status", + Arrays.asList(new SqlParameter("@status", "active"))); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(3); + resultList.forEach(item -> { + assertThat(item.get("status").asText()).isEqualTo("active"); + }); + + cleanupContainer(singlePkContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_emptyResults() { + List pkValues = Collections.singletonList(new PartitionKey("nonExistent")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = 
results.stream().collect(Collectors.toList()); + + assertThat(resultList).isEmpty(); + } + + //endregion + + //region HPK tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_fullPk() { + createHpkItems(); + + // Read by full PKs + List pkValues = Arrays.asList( + new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build(), + new PartitionKeyBuilder().add("Pittsburgh").add("15232").add(2).build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond/98053/1 has 2 items, Pittsburgh/15232/2 has 1 item + assertThat(resultList).hasSize(3); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_partialPk_singleLevel() { + createHpkItems(); + + // Read by partial PK (only city) + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond has 3 items total (2 with 98053/1 and 1 with 12345/1) + assertThat(resultList).hasSize(3); + resultList.forEach(item -> { + assertThat(item.get("city").asText()).isEqualTo("Redmond"); + }); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_partialPk_twoLevels() { + createHpkItems(); + + // Read by partial PK (city + zipcode) + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").add("98053").build()); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + // Redmond/98053 has 2 items + 
assertThat(resultList).hasSize(2); + resultList.forEach(item -> { + assertThat(item.get("city").asText()).isEqualTo("Redmond"); + assertThat(item.get("zipcode").asText()).isEqualTo("98053"); + }); + + cleanupContainer(multiHashContainer); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + @SuppressWarnings("deprecation") + public void hpk_readManyByPartitionKey_withNoneComponent() { + try { + createHpkItems(); + + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("city", "Redmond"); + item.put("zipcode", "98053"); + + try { + multiHashContainer.createItem(item); + fail("Should have thrown CosmosException for HPK item with missing trailing partition key component"); + } catch (CosmosException e) { + assertThat(e.getMessage()).contains("wrong-pk-value"); + } + + try { + new PartitionKeyBuilder().add("Redmond").add("98053").addNoneValue().build(); + fail("Should have thrown IllegalStateException for HPK addNoneValue"); + } catch (IllegalStateException e) { + assertThat(e.getMessage()).contains("PartitionKey.None can't be used with multiple paths"); + } + } finally { + cleanupContainer(multiHashContainer); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void hpk_readManyByPartitionKey_withProjection() { + createHpkItems(); + + List pkValues = Collections.singletonList( + new PartitionKeyBuilder().add("Redmond").add("98053").add(1).build()); + + SqlQuerySpec customQuery = new SqlQuerySpec("SELECT c.id, c.city FROM c"); + + CosmosPagedIterable results = multiHashContainer.readManyByPartitionKey( + pkValues, customQuery, null, ObjectNode.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(2); + + cleanupContainer(multiHashContainer); + } + + //endregion + + //region Negative/validation tests + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsAggregateQuery() 
{ + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec aggregateQuery = new SqlQuerySpec("SELECT COUNT(1) FROM c"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, aggregateQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for aggregate query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("aggregates"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsOrderByQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec orderByQuery = new SqlQuerySpec("SELECT * FROM c ORDER BY c.id"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, orderByQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for ORDER BY query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("ORDER BY"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsDistinctQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec distinctQuery = new SqlQuerySpec("SELECT DISTINCT c.mypk FROM c"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, distinctQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for DISTINCT query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("DISTINCT"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsGroupByQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec groupByQuery = new SqlQuerySpec("SELECT c.mypk FROM c GROUP BY c.mypk"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, groupByQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have 
thrown IllegalArgumentException for GROUP BY query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("GROUP BY"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsGroupByWithAggregateQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec groupByWithAggregateQuery = new SqlQuerySpec("SELECT c.mypk, COUNT(1) as cnt FROM c GROUP BY c.mypk"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, groupByWithAggregateQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for GROUP BY with aggregate query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("GROUP BY"); + } + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void rejectsNullPartitionKeyList() { + singlePkContainer.readManyByPartitionKey((List) null, ObjectNode.class); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT, expectedExceptions = IllegalArgumentException.class) + public void rejectsEmptyPartitionKeyList() { + singlePkContainer.readManyByPartitionKey(new ArrayList<>(), ObjectNode.class) + .stream().collect(Collectors.toList()); + } + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void rejectsOffsetQuery() { + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + SqlQuerySpec offsetQuery = new SqlQuerySpec("SELECT * FROM c OFFSET 0 LIMIT 10"); + + try { + singlePkContainer.readManyByPartitionKey(pkValues, offsetQuery, null, ObjectNode.class) + .stream().collect(Collectors.toList()); + fail("Should have thrown IllegalArgumentException for OFFSET query"); + } catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("OFFSET"); + } + } + + + + //endregion + + + //region Batch size tests (#10) + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void 
singlePk_readManyByPartitionKey_withSmallBatchSize() { + // Temporarily set batch size to 2 to exercise the batching/interleaving logic + String originalValue = System.getProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); + try { + System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", "2"); + + // Create items across 4 PKs (more than the batch size of 2) + List items = createSinglePkItems("batchPk1", 2); + items.addAll(createSinglePkItems("batchPk2", 2)); + items.addAll(createSinglePkItems("batchPk3", 2)); + items.addAll(createSinglePkItems("batchPk4", 2)); + + // Read all 4 PKs — should be split into batches of 2 + List pkValues = Arrays.asList( + new PartitionKey("batchPk1"), + new PartitionKey("batchPk2"), + new PartitionKey("batchPk3"), + new PartitionKey("batchPk4")); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey(pkValues, ObjectNode.class); + List> pages = new ArrayList<>(); + results.iterableByPage().forEach(pages::add); + List resultList = pages.stream() + .flatMap(page -> page.getResults().stream()) + .collect(Collectors.toList()); + + assertThat(resultList).hasSize(8); // 2 items per PK * 4 PKs + assertThat(pages.size()).isGreaterThan(1); + resultList.forEach(item -> { + String pk = item.get("mypk").asText(); + assertThat(pk).isIn("batchPk1", "batchPk2", "batchPk3", "batchPk4"); + }); + + cleanupContainer(singlePkContainer); + } finally { + if (originalValue != null) { + System.setProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE", originalValue); + } else { + System.clearProperty("COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"); + } + } + } + + //endregion + + //region Custom serializer regression tests (#5) + + @Test(groups = {"emulator"}, timeOut = TIMEOUT) + public void singlePk_readManyByPartitionKey_withRequestOptions() { + List items = createSinglePkItems("pkOpts", 3); + + List pkValues = Collections.singletonList(new PartitionKey("pkOpts")); + com.azure.cosmos.models.CosmosReadManyRequestOptions options = new 
com.azure.cosmos.models.CosmosReadManyRequestOptions(); + AtomicInteger deserializeCount = new AtomicInteger(); + options.setCustomItemSerializer(new CosmosItemSerializerNoExceptionWrapping() { + @Override + public Map serialize(T item) { + return CosmosItemSerializer.DEFAULT_SERIALIZER.serialize(item); + } + + @Override + public T deserialize(Map jsonNodeMap, Class classType) { + deserializeCount.incrementAndGet(); + return CosmosItemSerializer.DEFAULT_SERIALIZER.deserialize(jsonNodeMap, classType); + } + }); + + CosmosPagedIterable results = singlePkContainer.readManyByPartitionKey( + pkValues, options, ReadManyByPartitionKeyPojo.class); + List resultList = results.stream().collect(Collectors.toList()); + + assertThat(resultList).hasSize(3); + assertThat(deserializeCount.get()).isEqualTo(3); + assertThat(resultList.stream().map(item -> item.mypk)).containsOnly("pkOpts"); + + cleanupContainer(singlePkContainer); + } + + //endregion + + //region helper methods + + private List createSinglePkItems(String pkValue, int count) { + List items = new ArrayList<>(); + for (int i = 0; i < count; i++) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("mypk", pkValue); + singlePkContainer.createItem(item); + items.add(item); + } + return items; + } + + private List createSinglePkItemsWithStatus(String pkValue, String status, int count) { + List items = new ArrayList<>(); + for (int i = 0; i < count; i++) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("mypk", pkValue); + item.put("status", status); + singlePkContainer.createItem(item); + items.add(item); + } + return items; + } + + private void createHpkItems() { + // Same data as CosmosMultiHashTest.createItems() + createHpkItem("Redmond", "98053", 1); + createHpkItem("Redmond", "98053", 1); + 
createHpkItem("Pittsburgh", "15232", 2); + createHpkItem("Stonybrook", "11790", 3); + createHpkItem("Stonybrook", "11794", 3); + createHpkItem("Stonybrook", "11791", 3); + createHpkItem("Redmond", "12345", 1); + } + + private void createHpkItem(String city, String zipcode, int areaCode) { + ObjectNode item = com.azure.cosmos.implementation.Utils.getSimpleObjectMapper().createObjectNode(); + item.put("id", UUID.randomUUID().toString()); + item.put("city", city); + item.put("zipcode", zipcode); + item.put("areaCode", areaCode); + multiHashContainer.createItem(item); + } + + private void cleanupContainer(CosmosContainer container) { + CosmosPagedIterable allItems = container.queryItems( + "SELECT * FROM c", new com.azure.cosmos.models.CosmosQueryRequestOptions(), ObjectNode.class); + allItems.forEach(item -> { + try { + container.deleteItem(item, new CosmosItemRequestOptions()); + } catch (CosmosException e) { + // ignore cleanup failures + } + }); + } + + private static class ReadManyByPartitionKeyPojo { + public String id; + public String mypk; + } + + //endregion +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java new file mode 100644 index 000000000000..95c109ba025f --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelperTest.java @@ -0,0 +1,426 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ + +package com.azure.cosmos.implementation; + +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyBuilder; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKeyDefinitionVersion; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; +import org.testng.annotations.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +public class ReadManyByPartitionKeyQueryHelperTest { + + //region Single PK (HASH) tests + + @Test(groups = { "unit" }) + public void singlePk_defaultQuery_singleValue() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); + assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getQueryText()).contains("@__rmPk_0"); + assertThat(result.getParameters()).hasSize(1); + assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("pk1"); + } + + @Test(groups = { "unit" }) + public void singlePk_defaultQuery_multipleValues() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList( + new PartitionKey("pk1"), + new PartitionKey("pk2"), + new PartitionKey("pk3")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + 
assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getQueryText()).contains("@__rmPk_0"); + assertThat(result.getQueryText()).contains("@__rmPk_1"); + assertThat(result.getQueryText()).contains("@__rmPk_2"); + assertThat(result.getParameters()).hasSize(3); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_noWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList(new PartitionKey("pk1"), new PartitionKey("pk2")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT c.name, c.age FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).startsWith("SELECT c.name, c.age FROM c WHERE"); + assertThat(result.getQueryText()).contains("IN ("); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_withExistingWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@minAge", 18)); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c WHERE c.age > @minAge", baseParams, pkValues, selectors, pkDef); + + // Should AND the PK filter to the existing WHERE clause + assertThat(result.getQueryText()).contains("WHERE (c.age > @minAge) AND ("); + assertThat(result.getQueryText()).contains("IN ("); + assertThat(result.getParameters()).hasSize(2); // @minAge + @__rmPk_0 + assertThat(result.getParameters().get(0).getName()).isEqualTo("@minAge"); + } + + //endregion + + //region HPK (MULTI_HASH) tests + + @Test(groups = { "unit" }) + public void hpk_fullPk_defaultQuery() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); 
+ + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("SELECT * FROM c WHERE"); + // Should use OR/AND pattern, not IN + assertThat(result.getQueryText()).doesNotContain("IN ("); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("AND"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); + assertThat(result.getParameters()).hasSize(2); + assertThat(result.getParameters().get(0).getValue(Object.class)).isEqualTo("Redmond"); + assertThat(result.getParameters().get(1).getValue(Object.class)).isEqualTo("98052"); + } + + @Test(groups = { "unit" }) + public void hpk_fullPk_multipleValues() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + List pkValues = Arrays.asList( + new PartitionKeyBuilder().add("Redmond").add("98052").build(), + new PartitionKeyBuilder().add("Seattle").add("98101").build()); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("OR"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_2"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_3"); + assertThat(result.getParameters()).hasSize(4); + } + + @Test(groups = { "unit" }) + public void hpk_partialPk_singleLevel() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); + 
List selectors = createSelectors(pkDef); + + // Partial PK — only first level + PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").build(); + List pkValues = Collections.singletonList(partialPk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + // Should NOT include zipcode or areaCode since it's partial + assertThat(result.getQueryText()).doesNotContain("zipcode"); + assertThat(result.getQueryText()).doesNotContain("areaCode"); + assertThat(result.getParameters()).hasSize(1); + } + + @Test(groups = { "unit" }) + public void hpk_partialPk_twoLevels() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode", "/areaCode"); + List selectors = createSelectors(pkDef); + + // Partial PK — first two levels + PartitionKey partialPk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(partialPk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("c[\"zipcode\"] = @__rmPk_1"); + assertThat(result.getQueryText()).doesNotContain("areaCode"); + assertThat(result.getParameters()).hasSize(2); + } + + @Test(groups = { "unit" }) + public void hpk_customQuery_withWhere() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@status", "active")); + + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = 
ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT c.name FROM c WHERE c.status = @status", baseParams, pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("WHERE (c.status = @status) AND ("); + assertThat(result.getQueryText()).contains("c[\"city\"] = @__rmPk_0"); + assertThat(result.getParameters()).hasSize(3); // @status + 2 pk params + } + + //endregion + + //region findTopLevelWhereIndex tests + + @Test(groups = { "unit" }) + public void findWhere_simpleQuery() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.ID = 1"); + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_noWhere() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C"); + assertThat(idx).isEqualTo(-1); + } + + @Test(groups = { "unit" }) + public void findWhere_whereInSubquery() { + // WHERE inside parentheses (subquery) should be ignored + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM C WHERE EXISTS(SELECT VALUE T FROM T IN C.TAGS WHERE T = 'FOO')"); + // Should find the outer WHERE, not the inner one + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_caseInsensitive() { + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM C WHERE C.X = 1"); + assertThat(idx).isGreaterThan(0); + } + + @Test(groups = { "unit" }) + public void findWhere_whereNotKeyword() { + // "ELSEWHERE" should not match + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex("SELECT * FROM ELSEWHERE"); + assertThat(idx).isEqualTo(-1); + } + + //endregion + + //region Custom alias tests + + @Test(groups = { "unit" }) + public void singlePk_customAlias() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Arrays.asList(new PartitionKey("pk1"), new 
PartitionKey("pk2")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT x.id, x.mypk FROM x", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).startsWith("SELECT x.id, x.mypk FROM x WHERE"); + assertThat(result.getQueryText()).contains("x[\"mypk\"] IN ("); + assertThat(result.getQueryText()).doesNotContain("c[\"mypk\"]"); + } + + @Test(groups = { "unit" }) + public void singlePk_customAlias_withWhere() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@cat", "HelloWorld")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT x.id, x.mypk FROM x WHERE x.category = @cat", baseParams, pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("WHERE (x.category = @cat) AND (x[\"mypk\"] IN ("); + } + + @Test(groups = { "unit" }) + public void hpk_customAlias() { + PartitionKeyDefinition pkDef = createMultiHashPkDefinition("/city", "/zipcode"); + List selectors = createSelectors(pkDef); + + PartitionKey pk = new PartitionKeyBuilder().add("Redmond").add("98052").build(); + List pkValues = Collections.singletonList(pk); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT r.name FROM root r", new ArrayList<>(), pkValues, selectors, pkDef); + + assertThat(result.getQueryText()).contains("r[\"city\"] = @__rmPk_0"); + assertThat(result.getQueryText()).contains("r[\"zipcode\"] = @__rmPk_1"); + assertThat(result.getQueryText()).doesNotContain("c[\""); + } + + //endregion + + //region extractTableAlias tests + + @Test(groups = { "unit" }) + public void extractAlias_defaultC() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM c")).isEqualTo("c"); + } 
+ + @Test(groups = { "unit" }) + public void extractAlias_customX() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT x.id FROM x WHERE x.age > 5")).isEqualTo("x"); + } + + @Test(groups = { "unit" }) + public void extractAlias_rootWithAlias() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT r.name FROM root r")).isEqualTo("r"); + } + + @Test(groups = { "unit" }) + public void extractAlias_rootNoAlias() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM root")).isEqualTo("root"); + } + + @Test(groups = { "unit" }) + public void extractAlias_containerWithWhere() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("SELECT * FROM items WHERE items.status = 'active'")).isEqualTo("items"); + } + + @Test(groups = { "unit" }) + public void extractAlias_caseInsensitive() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias("select * from MyContainer where MyContainer.id = '1'")).isEqualTo("MyContainer"); + } + + //endregion + + + //region String literal handling tests (#1) + + @Test(groups = { "unit" }) + public void findWhere_ignoresWhereInsideStringLiteral() { + // WHERE inside a string literal should be ignored + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.msg = 'use WHERE clause here'"); + // Should find the outer WHERE at position 16, not the one inside the string + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_ignoresParenthesesInsideStringLiteral() { + // Parentheses inside string literal should not affect depth tracking + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.name = 'foo(bar)' AND c.x = 1"); + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_handlesUnbalancedParenInStringLiteral() { + // Unbalanced paren inside string literal must not corrupt depth + int idx = 
ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT * FROM c WHERE c.val = 'open(' AND c.active = true"); + assertThat(idx).isEqualTo(16); + } + + @Test(groups = { "unit" }) + public void findWhere_handlesStringLiteralBeforeWhere() { + // String literal in SELECT before WHERE + int idx = ReadManyByPartitionKeyQueryHelper.findTopLevelWhereIndex( + "SELECT 'WHERE' as label FROM c WHERE c.id = '1'"); + // The WHERE inside quotes should be ignored; the real WHERE is further along + assertThat(idx).isGreaterThan(30); + } + + @Test(groups = { "unit" }) + public void singlePk_customQuery_withStringLiteralContainingParens() { + PartitionKeyDefinition pkDef = createSinglePkDefinition("/mypk"); + List selectors = createSelectors(pkDef); + List pkValues = Collections.singletonList(new PartitionKey("pk1")); + + List baseParams = new ArrayList<>(); + baseParams.add(new SqlParameter("@msg", "hello")); + + SqlQuerySpec result = ReadManyByPartitionKeyQueryHelper.createReadManyByPkQuerySpec( + "SELECT * FROM c WHERE c.msg = 'test(value)WHERE'", baseParams, pkValues, selectors, pkDef); + + // Should correctly AND the PK filter to the real WHERE clause + assertThat(result.getQueryText()).contains("WHERE (c.msg = 'test(value)WHERE') AND ("); + } + + //endregion + + //region OFFSET/LIMIT/HAVING alias detection tests (#9) + + @Test(groups = { "unit" }) + public void extractAlias_containerWithOffset() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT * FROM c OFFSET 10 LIMIT 5")).isEqualTo("c"); + } + + @Test(groups = { "unit" }) + public void extractAlias_containerWithLimit() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT * FROM c LIMIT 10")).isEqualTo("c"); + } + + @Test(groups = { "unit" }) + public void extractAlias_containerWithHaving() { + assertThat(ReadManyByPartitionKeyQueryHelper.extractTableAlias( + "SELECT c.status, COUNT(1) FROM c GROUP BY c.status HAVING COUNT(1) > 1")).isEqualTo("c"); + } + + 
//endregion + + //region helpers + + private PartitionKeyDefinition createSinglePkDefinition(String path) { + PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); + pkDef.setKind(PartitionKind.HASH); + pkDef.setVersion(PartitionKeyDefinitionVersion.V2); + pkDef.setPaths(Collections.singletonList(path)); + return pkDef; + } + + private PartitionKeyDefinition createMultiHashPkDefinition(String... paths) { + PartitionKeyDefinition pkDef = new PartitionKeyDefinition(); + pkDef.setKind(PartitionKind.MULTI_HASH); + pkDef.setVersion(PartitionKeyDefinitionVersion.V2); + pkDef.setPaths(Arrays.asList(paths)); + return pkDef; + } + + private List createSelectors(PartitionKeyDefinition pkDef) { + return pkDef.getPaths() + .stream() + .map(pathPart -> pathPart.substring(1)) // skip starting / + .map(pathPart -> pathPart.replace("\"", "\\\"")) + .map(part -> "[\"" + part + "\"]") + .collect(Collectors.toList()); + } + + //endregion +} diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java new file mode 100644 index 000000000000..429d1fd8f89c --- /dev/null +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryPlanValidationTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ + +package com.azure.cosmos.implementation; + +import com.azure.cosmos.implementation.query.DCountInfo; +import com.azure.cosmos.implementation.query.PartitionedQueryExecutionInfo; +import com.azure.cosmos.implementation.query.QueryInfo; +import com.azure.cosmos.implementation.query.hybridsearch.HybridSearchQueryInfo; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.testng.annotations.Test; + +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ReadManyByPartitionKeyQueryPlanValidationTest { + + @Test(groups = { "unit" }) + public void rejectsDCountQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + DCountInfo dCountInfo = new DCountInfo(); + dCountInfo.setDCountAlias("countAlias"); + queryInfo.set("dCountInfo", dCountInfo); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("DCOUNT"); + } + + @Test(groups = { "unit" }) + public void rejectsOffsetQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + queryInfo.set("offset", 10); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("OFFSET"); + } + + @Test(groups = { "unit" }) + public void rejectsLimitQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + queryInfo.set("limit", 10); + + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("LIMIT"); + } + + @Test(groups = { "unit" }) + public void rejectsHybridSearchQueryPlanWithoutDereferencingNullQueryInfo() { + assertThatThrownBy(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey( + 
createQueryPlan(null, new HybridSearchQueryInfo()))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("hybrid/vector/full-text"); + } + + @Test(groups = { "unit" }) + public void acceptsSimpleQueryPlan() { + QueryInfo queryInfo = new QueryInfo(); + + assertThatCode(() -> RxDocumentClientImpl.validateQueryPlanForReadManyByPartitionKey(createQueryPlan(queryInfo, null))) + .doesNotThrowAnyException(); + } + + private PartitionedQueryExecutionInfo createQueryPlan(QueryInfo queryInfo, HybridSearchQueryInfo hybridSearchQueryInfo) { + ObjectNode content = Utils.getSimpleObjectMapper().createObjectNode(); + content.put("partitionedQueryExecutionInfoVersion", Constants.PartitionedQueryExecutionInfo.VERSION_1); + + if (queryInfo != null) { + content.set("queryInfo", Utils.getSimpleObjectMapper().valueToTree(queryInfo.getMap())); + } + if (hybridSearchQueryInfo != null) { + content.set("hybridSearchQueryInfo", Utils.getSimpleObjectMapper().createObjectNode()); + } + + return new PartitionedQueryExecutionInfo(content, null); + } +} \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java index ad871bb97c01..8260ed95b19d 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosAsyncContainer.java @@ -165,6 +165,7 @@ private static ImplementationBridgeHelpers.CosmosBatchRequestOptionsHelper.Cosmo private final String createItemSpanName; private final String readAllItemsSpanName; private final String readManyItemsSpanName; + private final String readManyByPartitionKeySpanName; private final String readAllItemsOfLogicalPartitionSpanName; private final String queryItemsSpanName; private final String queryChangeFeedSpanName; @@ -198,6 +199,7 @@ protected CosmosAsyncContainer(CosmosAsyncContainer 
toBeWrappedContainer) { this.createItemSpanName = "createItem." + this.id; this.readAllItemsSpanName = "readAllItems." + this.id; this.readManyItemsSpanName = "readManyItems." + this.id; + this.readManyByPartitionKeySpanName = "readManyByPartitionKey." + this.id; this.readAllItemsOfLogicalPartitionSpanName = "readAllItemsOfLogicalPartition." + this.id; this.queryItemsSpanName = "queryItems." + this.id; this.queryChangeFeedSpanName = "queryChangeFeed." + this.id; @@ -1601,6 +1603,162 @@ private Mono> readManyInternal( context); } + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, null, null, classType); + } + + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. 
+ * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, null, requestOptions, classType); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + Class classType) { + + return this.readManyByPartitionKey(partitionKeys, customQuery, null, classType); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedFlux} containing one or several feed response pages + */ + public CosmosPagedFlux readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + if (partitionKeys == null) { + throw new IllegalArgumentException("Argument 'partitionKeys' must not be null."); + } + if (partitionKeys.isEmpty()) { + throw new IllegalArgumentException("Argument 'partitionKeys' must not be empty."); + } + for (PartitionKey pk : partitionKeys) { + if (pk == null) { + throw new IllegalArgumentException( + "Argument 'partitionKeys' must not contain null elements."); + } + } + + return UtilBridgeInternal.createCosmosPagedFlux( + readManyByPartitionKeyInternalFunc(partitionKeys, customQuery, requestOptions, classType)); + } + + private Function>> readManyByPartitionKeyInternalFunc( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + CosmosAsyncClient client = this.getDatabase().getClient(); + + return (pagedFluxOptions -> { + CosmosQueryRequestOptions queryRequestOptions = requestOptions == null + ? new CosmosQueryRequestOptions() + : queryOptionsAccessor().clone(readManyOptionsAccessor().getImpl(requestOptions)); + // Honor any caller-provided MaxDegreeOfParallelism; only default to the "unbounded" sentinel + // (-1) when the value is still at the default (0). CosmosReadManyRequestOptions currently does not + // expose MDOP, so this branch is defensive in case it is plumbed through in the future. 
+ if (queryRequestOptions.getMaxDegreeOfParallelism() == 0) { + queryRequestOptions.setMaxDegreeOfParallelism(-1); + } + queryRequestOptions.setQueryName("readManyByPartitionKey"); + CosmosQueryRequestOptionsBase cosmosQueryRequestOptionsImpl = queryOptionsAccessor().getImpl(queryRequestOptions); + applyPolicies(OperationType.Query, ResourceType.Document, cosmosQueryRequestOptionsImpl, this.readManyByPartitionKeySpanName); + + QueryFeedOperationState state = new QueryFeedOperationState( + client, + this.readManyByPartitionKeySpanName, + database.getId(), + this.getId(), + ResourceType.Document, + OperationType.Query, + queryOptionsAccessor().getQueryNameOrDefault(queryRequestOptions, this.readManyByPartitionKeySpanName), + queryRequestOptions, + pagedFluxOptions + ); + + pagedFluxOptions.setFeedOperationState(state); + + return CosmosBridgeInternal + .getAsyncDocumentClient(this.getDatabase()) + .readManyByPartitionKey( + partitionKeys, + customQuery, + BridgeInternal.getLink(this), + state, + classType) + .map(response -> prepareFeedResponse(response, false)); + }); + } + /** * Reads all the items of a logical partition * diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java index 04a6060c1927..33149a498390 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/CosmosContainer.java @@ -540,6 +540,100 @@ public FeedResponse readMany( classType)); } + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. 
+ * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, classType)); + } + + /** + * Reads many documents matching the provided partition key values. + * Unlike {@link #readMany(List, Class)} this method does not require item ids - it queries + * all documents matching the provided partition key values. Uses {@code SELECT * FROM c} + * as the base query. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, requestOptions, classType)); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, classType)); + } + + /** + * Reads many documents matching the provided partition key values with a custom query. + * The custom query can be used to apply projections (e.g. {@code SELECT c.name, c.age FROM c}) + * and/or additional filters (e.g. {@code SELECT * FROM c WHERE c.status = 'active'}). + * The SDK will automatically append partition key filtering to the custom query. + *

+ * The custom query must be a simple streamable query - aggregates, ORDER BY, DISTINCT, + * GROUP BY, DCOUNT, vector search, and full-text search are not supported and will be + * rejected. + *

+ * Partial hierarchical partition keys are supported and will fan out to multiple + * physical partitions. + * + * @param the type parameter + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query for projections/additional filters (null means SELECT * FROM c) + * @param requestOptions the optional request options + * @param classType class type + * @return a {@link CosmosPagedIterable} containing the results + */ + public CosmosPagedIterable readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) { + + return getCosmosPagedIterable(this.asyncContainer.readManyByPartitionKey(partitionKeys, customQuery, requestOptions, classType)); + } + /** * Reads all the items of a logical partition returning the results as {@link CosmosPagedIterable}. * diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java index 945e768a82ff..8e2499c9039f 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/AsyncDocumentClient.java @@ -1584,6 +1584,27 @@ Mono> readMany( QueryFeedOperationState state, Class klass); + /** + * Reads many documents by partition key values. + * Unlike {@link #readMany(List, String, QueryFeedOperationState, Class)} this method does not require + * item ids - it queries all documents matching the provided partition key values. + * Partial hierarchical partition keys are supported and will fan out to multiple physical partitions. 
+ * + * @param partitionKeys list of partition key values to read documents for + * @param customQuery optional custom query (for projections/additional filters) - null means SELECT * FROM c + * @param collectionLink link for the documentcollection/container to be queried + * @param state the query operation state + * @param klass class type + * @param the type parameter + * @return a Flux with feed response pages of documents + */ + Flux> readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + String collectionLink, + QueryFeedOperationState state, + Class klass); + /** * Read all documents of a certain logical partition. *

diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java index 337055c6947f..651019d91652 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/Configs.java @@ -58,7 +58,7 @@ public class Configs { private static final String NETTY_HTTP_CLIENT_METRICS_ENABLED = "COSMOS.NETTY_HTTP_CLIENT_METRICS_ENABLED"; private static final String NETTY_HTTP_CLIENT_METRICS_ENABLED_VARIABLE = "COSMOS_NETTY_HTTP_CLIENT_METRICS_ENABLED"; - // Thin client connect/acquire timeout — controls CONNECT_TIMEOUT_MILLIS for Gateway V2 data plane endpoints. + // Thin client connect/acquire timeout - controls CONNECT_TIMEOUT_MILLIS for Gateway V2 data plane endpoints. // Data plane requests are routed to the thin client regional endpoint (from RegionalRoutingContext) // which uses a non-443 port. These get a shorter 5s connect/acquire timeout. // Metadata requests target Gateway V1 endpoint (port 443) and retain the full 45s/60s timeout (unchanged). 
@@ -248,6 +248,11 @@ public class Configs { public static final String MIN_TARGET_BULK_MICRO_BATCH_SIZE_VARIABLE = "COSMOS_MIN_TARGET_BULK_MICRO_BATCH_SIZE"; public static final int DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE = 1; + // readManyByPartitionKey: max number of PK values per query per physical partition + public static final String READ_MANY_BY_PK_MAX_BATCH_SIZE = "COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE"; + public static final String READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE = "COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE"; + public static final int DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE = 100; + public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY = "COSMOS.MAX_BULK_MICRO_BATCH_CONCURRENCY"; public static final String MAX_BULK_MICRO_BATCH_CONCURRENCY_VARIABLE = "COSMOS_MAX_BULK_MICRO_BATCH_CONCURRENCY"; public static final int DEFAULT_MAX_BULK_MICRO_BATCH_CONCURRENCY = 1; @@ -678,7 +683,7 @@ public static int getThinClientConnectionTimeoutInMs() { } } - // Guard against invalid values — timeout must be at least 500ms + // Guard against invalid values - timeout must be at least 500ms if (value < 500) { logger.warn( "Invalid thin client connection timeout: {}ms. Must be >= 500. 
Falling back to default: {}ms.", @@ -816,6 +821,20 @@ public static int getMinTargetBulkMicroBatchSize() { return DEFAULT_MIN_TARGET_BULK_MICRO_BATCH_SIZE; } + public static int getReadManyByPkMaxBatchSize() { + String valueFromSystemProperty = System.getProperty(READ_MANY_BY_PK_MAX_BATCH_SIZE); + if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { + return Math.max(1, Integer.parseInt(valueFromSystemProperty)); + } + + String valueFromEnvVariable = System.getenv(READ_MANY_BY_PK_MAX_BATCH_SIZE_VARIABLE); + if (valueFromEnvVariable != null && !valueFromEnvVariable.isEmpty()) { + return Math.max(1, Integer.parseInt(valueFromEnvVariable)); + } + + return DEFAULT_READ_MANY_BY_PK_MAX_BATCH_SIZE; + } + public static int getMaxBulkMicroBatchConcurrency() { String valueFromSystemProperty = System.getProperty(MAX_BULK_MICRO_BATCH_CONCURRENCY); if (valueFromSystemProperty != null && !valueFromSystemProperty.isEmpty()) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java new file mode 100644 index 000000000000..0bdb867dc3ee --- /dev/null +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ReadManyByPartitionKeyQueryHelper.java @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.implementation; + +import com.azure.cosmos.BridgeInternal; +import com.azure.cosmos.implementation.routing.PartitionKeyInternal; +import com.azure.cosmos.models.PartitionKey; +import com.azure.cosmos.models.PartitionKeyDefinition; +import com.azure.cosmos.models.PartitionKind; +import com.azure.cosmos.models.SqlParameter; +import com.azure.cosmos.models.SqlQuerySpec; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Helper for constructing SqlQuerySpec instances for readManyByPartitionKey operations. + * This class is not intended to be used directly by end-users. + */ +public class ReadManyByPartitionKeyQueryHelper { + + private static final String DEFAULT_TABLE_ALIAS = "c"; + // Internal parameter prefix - uses double-underscore to avoid collisions with user-provided parameters + private static final String PK_PARAM_PREFIX = "@__rmPk_"; + + public static SqlQuerySpec createReadManyByPkQuerySpec( + String baseQueryText, + List baseParameters, + List pkValues, + List partitionKeySelectors, + PartitionKeyDefinition pkDefinition) { + + // Guard against collisions with our internal parameter names - callers cannot realistically + // use the @__rmPk_ prefix for their own parameters, but if they do we surface a clear error + // rather than letting the server reject a SqlQuerySpec with duplicate parameter names. + for (SqlParameter baseParam : baseParameters) { + String name = baseParam.getName(); + if (name != null && name.startsWith(PK_PARAM_PREFIX)) { + throw new IllegalArgumentException( + "Custom query parameter name '" + name + "' collides with the reserved " + + "readManyByPartitionKey internal prefix '" + PK_PARAM_PREFIX + "'. Rename the parameter."); + } + } + + // Extract the table alias from the FROM clause (e.g. 
"FROM x" -> "x", "FROM c" -> "c") + String tableAlias = extractTableAlias(baseQueryText); + + StringBuilder pkFilter = new StringBuilder(); + List parameters = new ArrayList<>(baseParameters); + int paramCount = 0; + + boolean isSinglePathPk = partitionKeySelectors.size() == 1; + + if (isSinglePathPk && pkDefinition.getKind() != PartitionKind.MULTI_HASH) { + // Single PK path - use IN clause for normal values, OR NOT IS_DEFINED for NONE + // First, separate NONE PKs from normal PKs + boolean hasNone = false; + List normalPkValues = new ArrayList<>(); + for (PartitionKey pk : pkValues) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); + if (pkInternal.getComponents() == null) { + hasNone = true; + } else { + normalPkValues.add(pk); + } + } + + pkFilter.append(" "); + boolean hasNormalValues = !normalPkValues.isEmpty(); + if (hasNormalValues && hasNone) { + pkFilter.append("("); + } + if (hasNormalValues) { + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(0)); + pkFilter.append(" IN ( "); + for (int i = 0; i < normalPkValues.size(); i++) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(normalPkValues.get(i)); + Object[] pkComponents = pkInternal.toObjectArray(); + String pkParamName = PK_PARAM_PREFIX + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[0])); + paramCount++; + + pkFilter.append(pkParamName); + if (i < normalPkValues.size() - 1) { + pkFilter.append(", "); + } + } + pkFilter.append(" )"); + } + if (hasNone) { + if (hasNormalValues) { + pkFilter.append(" OR "); + } + pkFilter.append("NOT IS_DEFINED("); + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(0)); + pkFilter.append(")"); + } + if (hasNormalValues && hasNone) { + pkFilter.append(")"); + } + } else { + // Multiple PK paths (HPK) or MULTI_HASH - use OR of AND clauses + pkFilter.append(" "); + for (int i = 0; i < pkValues.size(); i++) { + PartitionKeyInternal pkInternal 
= BridgeInternal.getPartitionKeyInternal(pkValues.get(i)); + Object[] pkComponents = pkInternal.toObjectArray(); + + // PartitionKey.NONE - generate NOT IS_DEFINED for all PK paths + if (pkComponents == null) { + pkFilter.append("("); + for (int j = 0; j < partitionKeySelectors.size(); j++) { + if (j > 0) { + pkFilter.append(" AND "); + } + pkFilter.append("NOT IS_DEFINED("); + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(j)); + pkFilter.append(")"); + } + pkFilter.append(")"); + } else { + pkFilter.append("("); + for (int j = 0; j < pkComponents.length; j++) { + String pkParamName = PK_PARAM_PREFIX + paramCount; + parameters.add(new SqlParameter(pkParamName, pkComponents[j])); + paramCount++; + + if (j > 0) { + pkFilter.append(" AND "); + } + pkFilter.append(tableAlias); + pkFilter.append(partitionKeySelectors.get(j)); + pkFilter.append(" = "); + pkFilter.append(pkParamName); + } + pkFilter.append(")"); + } + + if (i < pkValues.size() - 1) { + pkFilter.append(" OR "); + } + } + } + + // Compose final query: handle existing WHERE clause in base query + String finalQuery; + int whereIndex = findTopLevelWhereIndex(baseQueryText); + if (whereIndex >= 0) { + // Base query has WHERE - AND our PK filter + String beforeWhere = baseQueryText.substring(0, whereIndex); + String afterWhere = baseQueryText.substring(whereIndex + 5); // skip "WHERE" + finalQuery = beforeWhere + "WHERE (" + afterWhere.trim() + ") AND (" + pkFilter.toString().trim() + ")"; + } else { + // No WHERE - add one + finalQuery = baseQueryText + " WHERE" + pkFilter.toString(); + } + + return new SqlQuerySpec(finalQuery, parameters); + } + + static List createPkSelectors(PartitionKeyDefinition partitionKeyDefinition) { + return partitionKeyDefinition.getPaths() + .stream() + .map(PathParser::getPathParts) + .map(pathParts -> pathParts.stream() + .map(pathPart -> "[\"" + pathPart.replace("\"", "\\\"") + "\"]") + .collect(Collectors.joining())) + .collect(Collectors.toList()); + } 
+ + /** + * Extracts the table/collection alias from a SQL query's FROM clause. + * Handles: "SELECT * FROM c", "SELECT x.id FROM x WHERE ...", "SELECT * FROM root r", etc. + * Returns the alias used after FROM (last token before WHERE or end of FROM clause). + */ + static String extractTableAlias(String queryText) { + String upper = queryText.toUpperCase(); + int fromIndex = findTopLevelKeywordIndex(upper, "FROM"); + if (fromIndex < 0) { + return DEFAULT_TABLE_ALIAS; + } + + // Start scanning after "FROM" + int afterFrom = fromIndex + 4; + // Skip whitespace + while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { + afterFrom++; + } + + // Collect the container name token (could be "root", "c", etc.) + int tokenStart = afterFrom; + while (afterFrom < queryText.length() + && !Character.isWhitespace(queryText.charAt(afterFrom)) + && queryText.charAt(afterFrom) != '(' + && queryText.charAt(afterFrom) != ')') { + afterFrom++; + } + String containerName = queryText.substring(tokenStart, afterFrom); + + // Skip whitespace after container name + while (afterFrom < queryText.length() && Character.isWhitespace(queryText.charAt(afterFrom))) { + afterFrom++; + } + + // Check if there's an alias after the container name (before WHERE or end) + if (afterFrom < queryText.length()) { + String remaining = upper.substring(afterFrom); + // Reserved keywords that terminate the FROM clause - when the next token is one of these, + // containerName itself IS the alias used throughout the rest of the query. 
+ if (isFollowedByReservedKeyword(remaining)) { + return containerName; + } + // Handle optional AS: "FROM root AS r" -> alias is "r" + if (remaining.startsWith("AS") + && (remaining.length() == 2 || !Character.isLetterOrDigit(remaining.charAt(2)))) { + afterFrom += 2; // skip AS + while (afterFrom < queryText.length() + && Character.isWhitespace(queryText.charAt(afterFrom))) { + afterFrom++; + } + } + // Otherwise the next token is the alias ("FROM root r" -> alias is "r") + int aliasStart = afterFrom; + while (afterFrom < queryText.length() + && !Character.isWhitespace(queryText.charAt(afterFrom)) + && queryText.charAt(afterFrom) != '(' + && queryText.charAt(afterFrom) != ')') { + afterFrom++; + } + if (afterFrom > aliasStart) { + return queryText.substring(aliasStart, afterFrom); + } + } + + return containerName; + } + + private static boolean isFollowedByReservedKeyword(String remainingUpper) { + String[] keywords = { "WHERE", "ORDER", "GROUP", "JOIN", "OFFSET", "LIMIT", "HAVING" }; + for (String kw : keywords) { + if (remainingUpper.startsWith(kw) + && (remainingUpper.length() == kw.length() + || !Character.isLetterOrDigit(remainingUpper.charAt(kw.length())))) { + return true; + } + } + return false; + } + + /** + * Finds the index of a top-level SQL keyword in the query text (case-insensitive), + * ignoring occurrences inside parentheses or string literals. + */ + static int findTopLevelKeywordIndex(String queryText, String keyword) { + String queryTextUpper = queryText.toUpperCase(); + String keywordUpper = keyword.toUpperCase(); + int depth = 0; + int keyLen = keywordUpper.length(); + int len = queryTextUpper.length(); + for (int i = 0; i <= len - keyLen; i++) { + char ch = queryText.charAt(i); + // Skip single-line comments: -- ... 
end-of-line + if (ch == '-' && i + 1 < len && queryText.charAt(i + 1) == '-') { + i += 2; + while (i < len && queryText.charAt(i) != '\n' && queryText.charAt(i) != '\r') { + i++; + } + continue; + } + // Skip block comments: /* ... */ + if (ch == '/' && i + 1 < len && queryText.charAt(i + 1) == '*') { + i += 2; + while (i + 1 < len + && !(queryText.charAt(i) == '*' && queryText.charAt(i + 1) == '/')) { + i++; + } + i++; // position on the '/'; loop post-increment moves past it + continue; + } + // Skip string literals enclosed in single quotes (handle '' escape) + if (ch == '\'') { + i++; + while (i < len) { + if (queryText.charAt(i) == '\'') { + if (i + 1 < len && queryText.charAt(i + 1) == '\'') { + i += 2; // escaped quote - skip both + continue; + } + break; // end of string literal + } + i++; + } + continue; + } + char upperCh = queryTextUpper.charAt(i); + if (upperCh == '(') { + depth++; + } else if (upperCh == ')') { + depth--; + } else if (depth == 0 && upperCh == keywordUpper.charAt(0) + && queryTextUpper.startsWith(keywordUpper, i) + && (i == 0 || !Character.isLetterOrDigit(queryTextUpper.charAt(i - 1))) + && (i + keyLen >= queryTextUpper.length() || !Character.isLetterOrDigit(queryTextUpper.charAt(i + keyLen)))) { + return i; + } + } + return -1; + } + + /** + * Finds the index of the top-level WHERE keyword in the query text, + * ignoring WHERE that appears inside parentheses (subqueries). 
+ */ + public static int findTopLevelWhereIndex(String queryTextUpper) { + return findTopLevelKeywordIndex(queryTextUpper, "WHERE"); + } +} diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index 11121bca033e..1126008946a3 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -121,6 +121,7 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -4365,13 +4366,327 @@ private Mono>> readMany( ); } + @Override + public Flux> readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + String collectionLink, + QueryFeedOperationState state, + Class klass) { + + checkNotNull(partitionKeys, "Argument 'partitionKeys' must not be null."); + checkArgument(!partitionKeys.isEmpty(), "Argument 'partitionKeys' must not be empty."); + + final ScopedDiagnosticsFactory diagnosticsFactory = new ScopedDiagnosticsFactory(this, true); + state.registerDiagnosticsFactory( + () -> {}, // we never want to reset in readManyByPartitionKey + (ctx) -> diagnosticsFactory.merge(ctx) + ); + + StaleResourceRetryPolicy staleResourceRetryPolicy = new StaleResourceRetryPolicy( + this.collectionCache, + null, + collectionLink, + queryOptionsAccessor().getProperties(state.getQueryOptions()), + queryOptionsAccessor().getHeaders(state.getQueryOptions()), + this.sessionContainer, + diagnosticsFactory, + ResourceType.Document + ); + + return ObservableHelper + .fluxInlineIfPossibleAsObs( + () -> readManyByPartitionKey( + partitionKeys, customQuery, collectionLink, state, diagnosticsFactory, klass), + staleResourceRetryPolicy + ) + 
.onErrorMap(throwable -> { + if (throwable instanceof CosmosException) { + CosmosException cosmosException = (CosmosException) throwable; + CosmosDiagnostics diagnostics = cosmosException.getDiagnostics(); + if (diagnostics != null) { + state.mergeDiagnosticsContext(); + CosmosDiagnosticsContext ctx = state.getDiagnosticsContextSnapshot(); + if (ctx != null) { + ctxAccessor().recordOperation( + ctx, + cosmosException.getStatusCode(), + cosmosException.getSubStatusCode(), + 0, + cosmosException.getRequestCharge(), + diagnostics, + throwable + ); + diagAccessor() + .setDiagnosticsContext( + diagnostics, + state.getDiagnosticsContextSnapshot()); + } + } + + return cosmosException; + } + + return throwable; + }); + } + + private Flux> readManyByPartitionKey( + List partitionKeys, + SqlQuerySpec customQuery, + String collectionLink, + QueryFeedOperationState state, + ScopedDiagnosticsFactory diagnosticsFactory, + Class klass) { + + String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); + RxDocumentServiceRequest request = RxDocumentServiceRequest.create(diagnosticsFactory, + OperationType.Query, + ResourceType.Document, + collectionLink, null + ); + + Mono> collectionObs = + collectionCache.resolveCollectionAsync(null, request); + + return collectionObs + .flatMapMany(documentCollectionResourceResponse -> { + final DocumentCollection collection = documentCollectionResourceResponse.v; + if (collection == null) { + return Flux.error(new IllegalStateException("Collection cannot be null")); + } + + final PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); + + Mono> valueHolderMono = partitionKeyRangeCache + .tryLookupAsync( + BridgeInternal.getMetaDataDiagnosticContext(request.requestContext.cosmosDiagnostics), + collection.getResourceId(), + null, + null); + + // Validate custom query if provided + Mono queryValidationMono; + if (customQuery != null) { + queryValidationMono = validateCustomQueryForReadManyByPartitionKey( 
+ customQuery, resourceLink, state.getQueryOptions()); + } else { + queryValidationMono = Mono.empty(); + } + + return valueHolderMono + .delayUntil(ignored -> queryValidationMono) + .flatMapMany(routingMapHolder -> { + CollectionRoutingMap routingMap = routingMapHolder.v; + if (routingMap == null) { + return Flux.error(new IllegalStateException("Failed to get routing map.")); + } + + Map> partitionRangePkMap = + groupPartitionKeysByPhysicalPartition(partitionKeys, pkDefinition, routingMap); + + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(pkDefinition); + + String baseQueryText; + List baseParameters; + if (customQuery != null) { + baseQueryText = customQuery.getQueryText(); + baseParameters = customQuery.getParameters() != null + ? new ArrayList<>(customQuery.getParameters()) + : new ArrayList<>(); + } else { + baseQueryText = "SELECT * FROM c"; + baseParameters = new ArrayList<>(); + } + + // Build per-physical-partition batched queries. + // Each physical partition may have many PKs - split into batches + // to avoid oversized SQL queries. Batch size is configurable via + // system property COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE (default 100). + int maxPksPerPartitionQuery = Configs.getReadManyByPkMaxBatchSize(); + + // Build batches per partition as a list of lists (one inner list per partition). + // Then interleave in round-robin order so that concurrent execution + // prefers different physical partitions over multiple batches of the same partition. 
+ List>> batchesPerPartition = new ArrayList<>(); + int maxBatchesPerPartition = 0; + + for (Map.Entry> entry : partitionRangePkMap.entrySet()) { + List allPks = entry.getValue(); + List> partitionBatches = new ArrayList<>(); + for (int i = 0; i < allPks.size(); i += maxPksPerPartitionQuery) { + List batch = allPks.subList( + i, Math.min(i + maxPksPerPartitionQuery, allPks.size())); + SqlQuerySpec querySpec = ReadManyByPartitionKeyQueryHelper + .createReadManyByPkQuerySpec( + baseQueryText, baseParameters, batch, + partitionKeySelectors, pkDefinition); + partitionBatches.add( + Collections.singletonMap(entry.getKey(), querySpec)); + } + batchesPerPartition.add(partitionBatches); + maxBatchesPerPartition = Math.max(maxBatchesPerPartition, partitionBatches.size()); + } + + if (batchesPerPartition.isEmpty()) { + return Flux.empty(); + } + + // Interleave batches across physical partitions so that the first batch for + // each partition is kicked off before the second batch of any partition. With bounded + // concurrency this spreads the initial wave of work across the cluster; note that + // skewed distributions (one partition with N batches, another with 1) will eventually + // fall back to sequential execution once the short partitions are drained. + List> interleavedBatches = new ArrayList<>(); + for (int batchIdx = 0; batchIdx < maxBatchesPerPartition; batchIdx++) { + for (List> partitionBatches : batchesPerPartition) { + if (batchIdx < partitionBatches.size()) { + interleavedBatches.add(partitionBatches.get(batchIdx)); + } + } + } + + // Execute all batches with bounded concurrency. 
+ List>> queryFluxes = interleavedBatches + .stream() + .map(batchMap -> queryForReadMany( + diagnosticsFactory, + resourceLink, + new SqlQuerySpec(DUMMY_SQL_QUERY), + state.getQueryOptions(), + klass, + ResourceType.Document, + collection, + Collections.unmodifiableMap(batchMap))) + .collect(Collectors.toList()); + + int fluxConcurrency = Math.min(queryFluxes.size(), + Math.max(Configs.getCPUCnt(), 1)); + + return Flux.merge(Flux.fromIterable(queryFluxes), fluxConcurrency, 1); + }); + }); + } + + private Mono validateCustomQueryForReadManyByPartitionKey( + SqlQuerySpec customQuery, + String resourceLink, + CosmosQueryRequestOptions queryRequestOptions) { + + IDocumentQueryClient queryClient = documentQueryClientImpl( + RxDocumentClientImpl.this, getOperationContextAndListenerTuple(queryRequestOptions)); + + return DocumentQueryExecutionContextFactory + .fetchQueryPlanForValidation( + this, + queryClient, + customQuery, + resourceLink, + queryRequestOptions, + Configs.isQueryPlanCachingEnabled(), + this.getQueryPlanCache()) + .doOnNext(RxDocumentClientImpl::validateQueryPlanForReadManyByPartitionKey) + .then(); + } + + static void validateQueryPlanForReadManyByPartitionKey(PartitionedQueryExecutionInfo queryPlan) { + if (queryPlan.hasHybridSearchQueryInfo()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain hybrid/vector/full-text search."); + } + + QueryInfo queryInfo = queryPlan.getQueryInfo(); + if (queryInfo == null) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key is not supported because query plan details are unavailable."); + } + + if (queryInfo.hasGroupBy()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain GROUP BY."); + } + if (queryInfo.hasAggregates()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain aggregates."); + } + if (queryInfo.hasOrderBy()) { + throw 
new IllegalArgumentException( + "Custom query for readMany by partition key must not contain ORDER BY."); + } + if (queryInfo.hasDistinct()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DISTINCT."); + } + if (queryInfo.hasDCount()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain DCOUNT."); + } + if (queryInfo.hasOffset()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain OFFSET."); + } + if (queryInfo.hasLimit()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain LIMIT."); + } + if (queryInfo.hasNonStreamingOrderBy()) { + throw new IllegalArgumentException( + "Custom query for readMany by partition key must not contain non-streaming ORDER BY."); + } + } + + private Map> groupPartitionKeysByPhysicalPartition( + List partitionKeys, + PartitionKeyDefinition pkDefinition, + CollectionRoutingMap routingMap) { + + // Use LinkedHashMap so the downstream round-robin interleave is deterministic and the iteration + // order follows insertion order of partition keys (i.e. the order the caller provided). + Map> partitionRangePkMap = new LinkedHashMap<>(); + + for (PartitionKey pk : partitionKeys) { + PartitionKeyInternal pkInternal = BridgeInternal.getPartitionKeyInternal(pk); + + // PartitionKey.NONE wraps NonePartitionKey which has components = null. + // For routing purposes, treat NONE as UndefinedPartitionKey - documents ingested + // without a partition key path are stored with the undefined EPK. + PartitionKeyInternal effectivePkInternal = pkInternal.getComponents() == null + ? 
PartitionKeyInternal.UndefinedPartitionKey + : pkInternal; + + int componentCount = effectivePkInternal.getComponents().size(); + int definedPathCount = pkDefinition.getPaths().size(); + + List targetRanges; + + if (pkDefinition.getKind() == PartitionKind.MULTI_HASH && componentCount < definedPathCount) { + // Partial HPK - compute EPK prefix range and find all overlapping physical partitions + Range epkRange = PartitionKeyInternalHelper.getEPKRangeForPrefixPartitionKey( + effectivePkInternal, pkDefinition); + targetRanges = routingMap.getOverlappingRanges(epkRange); + } else { + // Full PK - maps to exactly one physical partition + String effectivePartitionKeyString = PartitionKeyInternalHelper + .getEffectivePartitionKeyString(effectivePkInternal, pkDefinition); + PartitionKeyRange range = routingMap.getRangeByEffectivePartitionKey(effectivePartitionKeyString); + targetRanges = Collections.singletonList(range); + } + + for (PartitionKeyRange range : targetRanges) { + partitionRangePkMap.computeIfAbsent(range, k -> new ArrayList<>()).add(pk); + } + } + + return partitionRangePkMap; + } + private Map getRangeQueryMap( Map> partitionRangeItemKeyMap, PartitionKeyDefinition partitionKeyDefinition) { //TODO: Optimise this to include all types of partitionkeydefinitions. 
ex: c["prop1./ab"]["key1"] Map rangeQueryMap = new HashMap<>(); - List partitionKeySelectors = createPkSelectors(partitionKeyDefinition); + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(partitionKeyDefinition); for(Map.Entry> entry: partitionRangeItemKeyMap.entrySet()) { SqlQuerySpec sqlQuerySpec; @@ -4465,15 +4780,6 @@ private SqlQuerySpec createReadManyQuerySpec( return new SqlQuerySpec(queryStringBuilder.toString(), parameters); } - private List createPkSelectors(PartitionKeyDefinition partitionKeyDefinition) { - return partitionKeyDefinition.getPaths() - .stream() - .map(pathPart -> StringUtils.substring(pathPart, 1)) // skip starting / - .map(pathPart -> StringUtils.replace(pathPart, "\"", "\\")) // escape quote - .map(part -> "[\"" + part + "\"]") - .collect(Collectors.toList()); - } - private Flux> queryForReadMany( ScopedDiagnosticsFactory diagnosticsFactory, String parentResourceLink, @@ -4995,7 +5301,7 @@ public Flux> readAllDocuments( } PartitionKeyDefinition pkDefinition = collection.getPartitionKey(); - List partitionKeySelectors = createPkSelectors(pkDefinition); + List partitionKeySelectors = ReadManyByPartitionKeyQueryHelper.createPkSelectors(pkDefinition); SqlQuerySpec querySpec = createLogicalPartitionScanQuerySpec(partitionKey, partitionKeySelectors); String resourceLink = parentResourceLinkToQueryLink(collectionLink, ResourceType.Document); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java index e62d8ed3d754..82506ae361e1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/query/DocumentQueryExecutionContextFactory.java @@ -107,7 +107,6 @@ private static Mono 
getPartitionKeyRangesAn } Instant startTime = Instant.now(); - Mono queryExecutionInfoMono; if (queryRequestOptionsAccessor() .isQueryPlanRetrievalDisallowed(cosmosQueryRequestOptions)) { @@ -122,37 +121,53 @@ private static Mono getPartitionKeyRangesAn endTime); } + return fetchQueryPlan( + diagnosticsClientContext, + client, + query, + resourceLink, + cosmosQueryRequestOptions, + queryPlanCachingEnabled, + queryPlanCache) + .flatMap( + partitionedQueryExecutionInfo -> { + + Instant endTime = Instant.now(); + + return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, + partitionedQueryExecutionInfo, startTime, endTime); + }); + } + + private static Mono fetchQueryPlan( + DiagnosticsClientContext diagnosticsClientContext, + IDocumentQueryClient client, + SqlQuerySpec query, + String resourceLink, + CosmosQueryRequestOptions cosmosQueryRequestOptions, + boolean queryPlanCachingEnabled, + Map queryPlanCache) { + if (queryPlanCachingEnabled && - isScopedToSinglePartition(cosmosQueryRequestOptions) && - queryPlanCache.containsKey(query.getQueryText())) { - Instant endTime = Instant.now(); // endTime for query plan diagnostics + isScopedToSinglePartition(cosmosQueryRequestOptions) && + queryPlanCache.containsKey(query.getQueryText())) { PartitionedQueryExecutionInfo partitionedQueryExecutionInfo = queryPlanCache.get(query.getQueryText()); if (partitionedQueryExecutionInfo != null) { logger.debug("Skipping query plan round trip by using the cached plan"); - return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, - partitionedQueryExecutionInfo, startTime, endTime); + return Mono.just(partitionedQueryExecutionInfo); } } - queryExecutionInfoMono = - QueryPlanRetriever.getQueryPlanThroughGatewayAsync( - diagnosticsClientContext, - client, - query, - resourceLink, - cosmosQueryRequestOptions); - - return queryExecutionInfoMono.flatMap( - partitionedQueryExecutionInfo -> { - - Instant endTime = 
Instant.now(); - + return QueryPlanRetriever.getQueryPlanThroughGatewayAsync( + diagnosticsClientContext, + client, + query, + resourceLink, + cosmosQueryRequestOptions) + .doOnNext(partitionedQueryExecutionInfo -> { if (queryPlanCachingEnabled && isScopedToSinglePartition(cosmosQueryRequestOptions)) { tryCacheQueryPlan(query, partitionedQueryExecutionInfo, queryPlanCache); } - - return getTargetRangesFromQueryPlan(cosmosQueryRequestOptions, collection, queryExecutionContext, - partitionedQueryExecutionInfo, startTime, endTime); }); } @@ -318,6 +333,25 @@ private static List getFeedRangeEpks(List> range return feedRanges; } + public static Mono fetchQueryPlanForValidation( + DiagnosticsClientContext diagnosticsClientContext, + IDocumentQueryClient queryClient, + SqlQuerySpec sqlQuerySpec, + String resourceLink, + CosmosQueryRequestOptions queryRequestOptions, + boolean queryPlanCachingEnabled, + Map queryPlanCache) { + + return fetchQueryPlan( + diagnosticsClientContext, + queryClient, + sqlQuerySpec, + resourceLink, + queryRequestOptions, + queryPlanCachingEnabled, + queryPlanCache); + } + public static Flux> createDocumentQueryExecutionContextAsync( DiagnosticsClientContext diagnosticsClientContext, IDocumentQueryClient client, diff --git a/sdk/cosmos/cspell.yaml b/sdk/cosmos/cspell.yaml new file mode 100644 index 000000000000..94a4002c2c9c --- /dev/null +++ b/sdk/cosmos/cspell.yaml @@ -0,0 +1,6 @@ +import: + - ../../.vscode/cspell.json +overrides: + - filename: "**/sdk/cosmos/*" + words: + - DCOUNT diff --git a/sdk/cosmos/docs/readManyByPartitionKey-design.md b/sdk/cosmos/docs/readManyByPartitionKey-design.md new file mode 100644 index 000000000000..c193630d4064 --- /dev/null +++ b/sdk/cosmos/docs/readManyByPartitionKey-design.md @@ -0,0 +1,176 @@ +# readManyByPartitionKey — Design & Implementation + +## Overview + +New `readManyByPartitionKey` methods on `CosmosAsyncContainer` / `CosmosContainer` that accept a +`List` (without item-id). 
The SDK splits the PK values by physical +partition, generates batched streaming queries per physical partition, and returns results as +`CosmosPagedFlux` / `CosmosPagedIterable`. + +An optional `SqlQuerySpec` parameter lets callers supply a custom query for projections +and additional filters. The SDK appends the auto-generated PK WHERE clause to it. + +## Decisions + +| Topic | Decision | +|---|---| +| API name | `readManyByPartitionKey` — distinct name to avoid ambiguity with existing `readMany(List)` | +| Return type | `CosmosPagedFlux` (async) / `CosmosPagedIterable` (sync) | +| Custom query format | `SqlQuerySpec` — full query with parameters; SDK ANDs the PK filter | +| Partial HPK | Supported from the start; prefix PKs fan out via `getOverlappingRanges` | +| PK deduplication | Done at Spark layer only, not in the SDK | +| Spark UDF | New `GetCosmosPartitionKeyValue` UDF | +| Custom query validation | Gateway query plan via the standard SDK query-plan retrieval path; reject aggregates/ORDER BY/DISTINCT/GROUP BY/DCount/OFFSET/LIMIT/non-streaming ORDER BY/vector/fulltext | +| PK list size | No hard upper-bound enforced; SDK batches internally per physical partition (default 100 PKs per batch, configurable via `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE`) | +| Eager validation | Null and empty PK list rejected eagerly (not lazily in reactive chain) | +| Telemetry | Separate span name `readManyByPartitionKeyItems.` (distinct from existing `readManyItems`) | +| Query construction | Table alias auto-detected from FROM clause; string literals and subqueries handled correctly | + +## Phase 1 — SDK Core (`azure-cosmos`) + +### Step 1: New public overloads in CosmosAsyncContainer + +```java + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, Class classType) + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, + CosmosReadManyRequestOptions requestOptions, + Class classType) + CosmosPagedFlux readManyByPartitionKey(List partitionKeys, + SqlQuerySpec 
customQuery, + CosmosReadManyRequestOptions requestOptions, + Class classType) +``` + +All delegate to a private `readManyByPartitionKeyInternalFunc(...)`. + +**Eager validation:** The 4-arg method validates `partitionKeys` is non-null and non-empty before constructing the reactive pipeline, throwing `IllegalArgumentException` synchronously. + +### Step 2: Sync wrappers in CosmosContainer + +Same signatures returning `CosmosPagedIterable`, delegating to the async container. + +### Step 3: Internal orchestration (RxDocumentClientImpl) + +1. Resolve collection metadata + PK definition from cache. +2. Fetch routing map from `partitionKeyRangeCache`. +3. For each `PartitionKey`: + - Compute effective partition key (EPK). + - Full PK → `getRangeByEffectivePartitionKey()` (single range). + - Partial HPK → compute EPK prefix range → `getOverlappingRanges()` (multiple ranges). + **Note:** partial HPK intentionally fans out to multiple physical partitions. +4. Group PK values by `PartitionKeyRange`. +5. Per physical partition → split PKs into batches of `maxPksPerPartitionQuery` (configurable, default 100). +6. Per batch → build `SqlQuerySpec` with PK WHERE clause (Step 5). +7. Interleave batches across physical partitions in round-robin order so that bounded concurrency prefers different physical partitions over sequential batches of the same partition. +8. Execute queries via `queryForReadMany()` with bounded concurrency (`Math.min(batchCount, cpuCount)`). +9. Return results as `CosmosPagedFlux`. + +### Step 4: Custom query validation + +One-time call per invocation using the same query-plan retrieval path and cacheability rules as regular SDK queries. + +- `QueryPlanRetriever.getQueryPlanThroughGatewayAsync()` for the user query. 
+- Reject (`IllegalArgumentException`) if: + - `queryInfo.hasGroupBy()` — checked first (takes precedence over aggregates since `hasAggregates()` also returns true for GROUP BY queries) + - `queryInfo.hasAggregates()` + - `queryInfo.hasOrderBy()` + - `queryInfo.hasDistinct()` + - `queryInfo.hasDCount()` + - `queryInfo.hasOffset()` + - `queryInfo.hasLimit()` + - `queryInfo.hasNonStreamingOrderBy()` + - `partitionedQueryExecutionInfo.hasHybridSearchQueryInfo()` + - query plan details are unavailable (`queryInfo == null`) + +### Step 5: Query construction + +Query construction is implemented in `ReadManyByPartitionKeyQueryHelper`. The helper: +- Extracts the table alias from the FROM clause (handles `FROM c`, `FROM root r`, `FROM x WHERE ...`) +- Handles string literals in queries (parens/keywords inside `'...'` are correctly skipped) +- Recognizes SQL keywords: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING +- Uses parameterized queries (`@__rmPk_` prefix) to prevent SQL injection + +**Single PK (HASH):** +```sql +{baseQuery} WHERE {alias}["{pkPath}"] IN (@__rmPk_0, @__rmPk_1, @__rmPk_2) +``` + +**Full HPK (MULTI_HASH):** +```sql +{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0 AND {alias}["{path2}"] = @__rmPk_1) + OR ({alias}["{path1}"] = @__rmPk_2 AND {alias}["{path2}"] = @__rmPk_3) +``` + +**Partial HPK (prefix-only):** +```sql +{baseQuery} WHERE ({alias}["{path1}"] = @__rmPk_0) + OR ({alias}["{path1}"] = @__rmPk_1) +``` + +If the base query already has a WHERE clause: +```sql +{selectAndFrom} WHERE ({existingWhere}) AND ({pkFilter}) +``` + +### Step 6: Interface wiring + +New method `readManyByPartitionKey` added directly to `AsyncDocumentClient` interface, implemented in `RxDocumentClientImpl`. New `fetchQueryPlanForValidation` static method added to `DocumentQueryExecutionContextFactory` for custom query validation. 
+ +### Step 7: Configuration + +New configurable batch size via system property `COSMOS.READ_MANY_BY_PK_MAX_BATCH_SIZE` or environment variable `COSMOS_READ_MANY_BY_PK_MAX_BATCH_SIZE` (default: 100, minimum: 1). Follows existing `Configs` patterns. + +## Phase 2 — Spark Connector (`azure-cosmos-spark_3`) + +### Step 8: New UDF — `GetCosmosPartitionKeyValue` + +- Input: partition key value (single value or Seq for hierarchical PKs). +- Output: serialized PK string in format `pk([...json...])`. +- **Null handling:** Null input is serialized as a JSON-null partition key component. If callers need `PartitionKey.NONE` semantics they must use the schema-matched path with `spark.cosmos.read.readManyByPk.nullHandling=None`, which is only supported for single-path partition keys. + +### Step 9: PK-only serialization helper + +`CosmosPartitionKeyHelper`: +- `getCosmosPartitionKeyValueString(pkValues: List[Object]): String` — serialize to `pk([...])` format. +- `tryParsePartitionKey(serialized: String): Option[PartitionKey]` — deserialize; returns `None` for malformed input including invalid JSON (wrapped in `scala.util.Try`). +- When `spark.cosmos.read.readManyByPk.nullHandling=None` is used, hierarchical partition keys with null components are rejected with a clear error because `PartitionKey.NONE` cannot be used with multiple paths. + +### Step 10: `CosmosItemsDataSource.readManyByPartitionKey` + +Static entry points that accept a DataFrame and Cosmos config. PK extraction supports two modes: +1. **UDF-produced column**: DataFrame contains `_partitionKeyIdentity` column (from `GetCosmosPartitionKeyValue` UDF). +2. **Schema-matched columns**: DataFrame columns match the container's PK paths. + +Nested partition key paths are not resolved automatically from DataFrame columns and must use the UDF-produced `_partitionKeyIdentity` column. + +Falls back with `IllegalArgumentException` if neither mode is possible. 
+ +### Step 11: `CosmosReadManyByPartitionKeyReader` + +Orchestrator that resolves schema, initializes and broadcasts client state to executors, then maps each Spark partition to an `ItemsPartitionReaderWithReadManyByPartitionKey`. The wrapper iterator closes the reader deterministically on exhaustion, on failures, and via Spark task-completion callbacks. + +### Step 12: `ItemsPartitionReaderWithReadManyByPartitionKey` + +Spark `PartitionReader[InternalRow]` that: +- Deduplicates PKs via `LinkedHashMap` (by PK string representation). +- Passes the pre-built `CosmosReadManyRequestOptions` (with throughput control, diagnostics, custom serializer) to the SDK. +- Uses `TransientIOErrorsRetryingIterator` for retry handling. +- Short-circuits empty PK lists to avoid SDK rejection. + +## Phase 3 — Testing + +### Unit tests +- Query construction: single PK, HPK full/partial, custom query composition, table alias detection. +- Query plan rejection: aggregates, ORDER BY, DISTINCT, GROUP BY (with and without aggregates), DCOUNT. +- String literal handling: WHERE/parentheses inside string constants. +- Keyword detection: WHERE, ORDER, GROUP, JOIN, OFFSET, LIMIT, HAVING. +- PK serialization/deserialization roundtrip (including malformed JSON handling). +- `findTopLevelWhereIndex` edge cases: subqueries, string literals, case insensitivity. + +### Integration tests +- End-to-end SDK: single PK basic, projections, filters, empty results, HPK full/partial, request options propagation. +- Batch size validation: temporarily lowered batch size to exercise batching/interleaving logic. +- Null/empty PK list rejection (eager validation). +- Spark connector: `ItemsPartitionReaderWithReadManyByPartitionKey` with known PK values and non-existent PKs. +- Spark public API: nested partition key containers require `_partitionKeyIdentity` and succeed when populated via `GetCosmosPartitionKeyValue`. +- `CosmosPartitionKeyHelper`: single/HPK roundtrip, case insensitivity, malformed input.