Skip to content

Commit 5645a6d

Browse files
committed
[SPARK-52812][SQL] Make Spark Connect Catalog.createTable eager
1 parent 846376a commit 5645a6d

2 files changed

Lines changed: 22 additions & 15 deletions

File tree

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
9999
import session.implicits._
100100
val df1 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name")
101101
df1.write.parquet(table1Dir.getPath)
102-
spark.catalog.createTable(parquetTableName, table1Dir.getPath).collect()
102+
spark.catalog.createTable(parquetTableName, table1Dir.getPath)
103103
withTable(orcTableName, jsonTableName) {
104104
withTempPath { table2Dir =>
105105
val df2 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").zipWithIndex.toDF("name", "id")
106106
df2.write.orc(table2Dir.getPath)
107-
spark.catalog.createTable(orcTableName, table2Dir.getPath, "orc").collect()
107+
spark.catalog.createTable(orcTableName, table2Dir.getPath, "orc")
108108
val orcTable = spark.catalog.getTable(orcTableName)
109109
assert(!orcTable.isTemporary)
110110
assert(orcTable.name == orcTableName)
@@ -117,7 +117,6 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
117117
val schema = new StructType().add("id", LongType).add("a", DoubleType)
118118
spark.catalog
119119
.createTable(jsonTableName, "json", schema, Map.empty[String, String])
120-
.collect()
121120
val jsonTable = spark.catalog.getTable("default", jsonTableName)
122121
assert(!jsonTable.isTemporary)
123122
assert(jsonTable.name == jsonTableName)
@@ -151,6 +150,19 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
151150
assert(spark.catalog.listTables().collect().isEmpty)
152151
}
153152

153+
test("createTable should be eager") {
154+
val tableName = "eager_table"
155+
withTable(tableName) {
156+
withTempPath { dir =>
157+
val session = spark
158+
import session.implicits._
159+
Seq((1, "a")).toDF("id", "value").write.parquet(dir.getPath)
160+
spark.catalog.createTable(tableName, dir.getPath)
161+
assert(spark.catalog.tableExists(tableName))
162+
}
163+
}
164+
}
165+
154166
test("Cache Table APIs") {
155167
val parquetTableName = "parquet_table"
156168
withTable(parquetTableName) {
@@ -159,7 +171,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
159171
import session.implicits._
160172
val df1 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name")
161173
df1.write.parquet(table1Dir.getPath)
162-
spark.catalog.createTable(parquetTableName, table1Dir.getPath).collect()
174+
spark.catalog.createTable(parquetTableName, table1Dir.getPath)
163175

164176
// Test cache and uncacheTable
165177
spark.catalog.cacheTable(parquetTableName)
@@ -375,7 +387,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
375387
val session = spark
376388
import session.implicits._
377389
Seq(1).toDF("id").write.parquet(dir.getPath)
378-
spark.catalog.createTable(tbl, dir.getPath).collect()
390+
spark.catalog.createTable(tbl, dir.getPath)
379391
assert(spark.catalog.tableExists(tbl))
380392
spark.catalog.dropTable(tbl)
381393
assert(!spark.catalog.tableExists(tbl))
@@ -445,7 +457,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
445457
val session = spark
446458
import session.implicits._
447459
Seq(1).toDF("id").write.parquet(dir.getPath)
448-
spark.catalog.createTable(t, dir.getPath).collect()
460+
spark.catalog.createTable(t, dir.getPath)
449461
val ddl = spark.catalog.getCreateTableString(t)
450462
assert(ddl.nonEmpty && ddl.toLowerCase(java.util.Locale.ROOT).contains("create"))
451463
}
@@ -470,7 +482,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe
470482
val session = spark
471483
import session.implicits._
472484
Seq(1).toDF("id").write.parquet(dir.getPath)
473-
spark.catalog.createTable(t, dir.getPath).collect()
485+
spark.catalog.createTable(t, dir.getPath)
474486
spark.catalog.analyzeTable(t, noScan = true)
475487
}
476488
}

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
392392
* @since 3.5.0
393393
*/
394394
override def createTable(tableName: String, path: String): DataFrame = {
395-
sparkSession.newDataFrame { builder =>
396-
builder.getCatalogBuilder.getCreateTableBuilder
397-
.setTableName(tableName)
398-
.setSchema(DataTypeProtoConverter.toConnectProtoType(new StructType))
399-
.setDescription("")
400-
.putOptions("path", path)
401-
}
395+
createTable(tableName, path, "parquet")
402396
}
403397

404398
/**
@@ -484,7 +478,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
484478
schema: StructType,
485479
description: String,
486480
options: Map[String, String]): DataFrame = {
487-
sparkSession.newDataFrame { builder =>
481+
sparkSession.execute { builder =>
488482
val createTableBuilder = builder.getCatalogBuilder.getCreateTableBuilder
489483
.setTableName(tableName)
490484
.setSource(source)
@@ -494,6 +488,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
494488
createTableBuilder.putOptions(k, v)
495489
}
496490
}
491+
sparkSession.table(tableName)
497492
}
498493

499494
/**

0 commit comments

Comments
 (0)