Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ dmlStatementNoWith
: insertInto (query | LEFT_PAREN query RIGHT_PAREN queryAlias=tableAlias) #singleInsertQuery
| fromClause multiInsertQueryBody+ #multiInsertQuery
| DELETE FROM identifierReference tableAlias whereClause? #deleteFromTable
| UPDATE identifierReference tableAlias setClause whereClause? #updateTable
| UPDATE identifierReference optionsClause? tableAlias setClause whereClause? #updateTable

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although this seems to follow read path, this is inconsistent with our existing INSERT syntax.

| INSERT (WITH SCHEMA EVOLUTION)? INTO TABLE? identifierReference tableAlias optionsClause? (BY NAME)?
REPLACE (WHERE | ON) replaceCondition=booleanExpression #insertIntoReplaceBooleanCond
| INSERT (WITH SCHEMA EVOLUTION)? INTO TABLE? identifierReference tableAlias optionsClause? (BY NAME)?
REPLACE USING identifierList #insertIntoReplaceUsing

UPDATE should follow INSERT style.

| MERGE (WITH SCHEMA EVOLUTION)? INTO target=identifierReference targetAlias=tableAlias

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this PR makes INSERT/DELETE/UPDATE accept WITH(...), but how about MERGE?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally did not add MERGE as the semantics are a bit different since MERGE statements have source and target relations, I'm still thinking about how options would look like there. I can file a separate JIRA and PR for MERGE later.

USING (source=identifierReference |
LEFT_PAREN sourceQuery=query RIGHT_PAREN) sourceAlias=tableAlias
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDel
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A rule that rewrites UPDATE operations using plans that operate on individual or groups of rows.
Expand All @@ -41,7 +40,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {

EliminateSubqueryAliases(aliasedTable) match {
case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) =>
val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty())
val table = buildOperationTable(tbl, UPDATE, r.options)
val updateCond = cond.getOrElse(TrueLiteral)
table.operation match {
case _: SupportsDelta =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,8 @@ class AstBuilder extends DataTypeAstBuilder

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Set(TableWritePrivilege.UPDATE))
ctx.identifierReference, Option(ctx.optionsClause()),
writePrivileges = Set(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransfo
import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, Decimal, IntegerType, LongType, StringType, StructType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampType, TimeType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.StorageLevelMapper
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

Expand Down Expand Up @@ -2253,6 +2254,50 @@ class DDLParserSuite extends AnalysisTest {
stop = 70))
}

test("update table: with options") {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use test prefix when it's possible.

- test("update table: with options") {
+ test("SPARK-57681: update table: with options") {

parseCompare(
"""
|UPDATE testcat.ns1.ns2.tbl WITH (`write.split-size` = 10)
|SET a='Robert', b=32
""".stripMargin,
UpdateTable(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl"),
new CaseInsensitiveStringMap(
java.util.Map.of("write.split-size", "10"))),
Seq(Assignment(UnresolvedAttribute("a"), Literal("Robert")),
Assignment(UnresolvedAttribute("b"), Literal(32))),
None))
}

test("update table: with options and alias") {
parseCompare(
"""
|UPDATE testcat.ns1.ns2.tbl WITH (`k` = 'v') AS t
|SET t.a='Robert', t.b=32
|WHERE t.c=2
""".stripMargin,
UpdateTable(
SubqueryAlias("t",
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl"),
new CaseInsensitiveStringMap(
java.util.Map.of("k", "v")))),
Seq(Assignment(UnresolvedAttribute("t.a"), Literal("Robert")),
Assignment(UnresolvedAttribute("t.b"), Literal(32))),
Some(EqualTo(UnresolvedAttribute("t.c"), Literal(2)))))
}

test("update table: options without values are not allowed") {
val sql = "UPDATE testcat.ns1.ns2.tbl WITH (`split-size`) SET a = 1"
checkError(
exception = parseException(sql),
condition = "_LEGACY_ERROR_TEMP_0035",
parameters = Map("message" -> "Values must be specified for key(s): [split-size]"),
context = ExpectedContext(
fragment = "testcat.ns1.ns2.tbl",
start = 7,
stop = 25))
}

test("merge into table: basic") {
parseCompare(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._

/**
* Test helper trait mixed into the in-memory row-level operations so tests can verify that
* per-statement SQL options reach the operation via [[RowLevelOperationInfo#options]].
*/
trait RowLevelOperationWithOptions {
def options: CaseInsensitiveStringMap
}

class InMemoryRowLevelOperationTable private (
name: String,
columns: Array[Column],
Expand Down Expand Up @@ -108,13 +116,14 @@ class InMemoryRowLevelOperationTable private (
override def newRowLevelOperationBuilder(
info: RowLevelOperationInfo): RowLevelOperationBuilder = {
if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") {
() => DeltaBasedOperation(info.command)
() => DeltaBasedOperation(info.command, info.options)
} else {
() => PartitionBasedOperation(info.command)
() => PartitionBasedOperation(info.command, info.options)
}
}

case class PartitionBasedOperation(command: Command) extends RowLevelOperation {
case class PartitionBasedOperation(command: Command, options: CaseInsensitiveStringMap)
extends RowLevelOperation with RowLevelOperationWithOptions {
var configuredScan: InMemoryBatchScan = _

override def requiredMetadataAttributes(): Array[NamedReference] = {
Expand Down Expand Up @@ -183,7 +192,8 @@ class InMemoryRowLevelOperationTable private (
}
}

case class DeltaBasedOperation(command: Command) extends RowLevelOperation with SupportsDelta {
case class DeltaBasedOperation(command: Command, options: CaseInsensitiveStringMap)
extends RowLevelOperation with SupportsDelta with RowLevelOperationWithOptions {
private final val PK_COLUMN_REF = FieldReference("pk")

override def requiredMetadataAttributes(): Array[NamedReference] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ package org.apache.spark.sql.connector
import org.apache.spark.SparkRuntimeException
import org.apache.spark.internal.config
import org.apache.spark.sql.{sources, AnalysisException, Row}
import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableChange, TableInfo}
import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
import org.apache.spark.sql.catalyst.plans.logical.{ReplaceData, WriteDelta}
import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, RowLevelOperationWithOptions, TableChange, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.write.UpdateSummary
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, UpdateSummary}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}

Expand Down Expand Up @@ -1232,4 +1235,40 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
Row(1, 100, "hr"),
Row(2, 200, "software")))
}

// asserts the given SQL options reached every layer that should carry them: the rewritten
// DataSourceV2Relation, the RowLevelOperationInfo passed to the operation builder, and the
// write builder's LogicalWriteInfo
protected def checkRowLevelOperationOptions(
func: => Unit,
expectedOptions: (String, String)*): Unit = {
val Seq(qe) = withQueryExecutionsCaptured(spark)(func)
val writeRelation = qe.optimizedPlan.collectFirst {
case rd: ReplaceData => rd.table
case wd: WriteDelta => wd.table
}.getOrElse(fail("couldn't find row-level operation in optimized plan"))
.asInstanceOf[DataSourceV2Relation]
val operation = writeRelation.table.asInstanceOf[RowLevelOperationTable].operation
.asInstanceOf[RowLevelOperationWithOptions]
expectedOptions.foreach { case (key, value) =>
assert(writeRelation.options.get(key) === value, s"relation option '$key'")
assert(operation.options.get(key) === value, s"row-level operation option '$key'")
assert(table.lastWriteInfo.options().get(key) === value, s"write option '$key'")
}
}

test("update with dynamic options") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
"""{ "pk": 1, "salary": 100, "dep": "hr" }
|{ "pk": 2, "salary": 200, "dep": "software" }
|""".stripMargin)

checkRowLevelOperationOptions(
sql(s"UPDATE $tableNameAsString WITH (`write.split-size` = 10) SET salary = -1 WHERE pk = 1"),
"write.split-size" -> "10")

checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Row(1, -1, "hr") :: Row(2, 200, "software") :: Nil)
}
}