diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala index 61c4502b256d8..e8ccc9f083c63 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/CatalogSuite.scala @@ -99,12 +99,12 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe import session.implicits._ val df1 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name") df1.write.parquet(table1Dir.getPath) - spark.catalog.createTable(parquetTableName, table1Dir.getPath).collect() + spark.catalog.createTable(parquetTableName, table1Dir.getPath) withTable(orcTableName, jsonTableName) { withTempPath { table2Dir => val df2 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").zipWithIndex.toDF("name", "id") df2.write.orc(table2Dir.getPath) - spark.catalog.createTable(orcTableName, table2Dir.getPath, "orc").collect() + spark.catalog.createTable(orcTableName, table2Dir.getPath, "orc") val orcTable = spark.catalog.getTable(orcTableName) assert(!orcTable.isTemporary) assert(orcTable.name == orcTableName) @@ -117,7 +117,6 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val schema = new StructType().add("id", LongType).add("a", DoubleType) spark.catalog .createTable(jsonTableName, "json", schema, Map.empty[String, String]) - .collect() val jsonTable = spark.catalog.getTable("default", jsonTableName) assert(!jsonTable.isTemporary) assert(jsonTable.name == jsonTableName) @@ -151,6 +150,19 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe assert(spark.catalog.listTables().collect().isEmpty) } + test("createTable should be eager") { + val tableName = "eager_table" + withTable(tableName) { + withTempPath { dir => + val session = spark + import session.implicits._ + Seq((1, "a")).toDF("id", "value").write.parquet(dir.getPath) + spark.catalog.createTable(tableName, dir.getPath) + assert(spark.catalog.tableExists(tableName)) + } + } + } + test("Cache Table APIs") { val parquetTableName = "parquet_table" withTable(parquetTableName) { @@ -159,7 +171,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe import session.implicits._ val df1 = Seq("Bob", "Alice", "Nico", "Bob", "Alice").toDF("name") df1.write.parquet(table1Dir.getPath) - spark.catalog.createTable(parquetTableName, table1Dir.getPath).collect() + spark.catalog.createTable(parquetTableName, table1Dir.getPath) // Test cache and uncacheTable spark.catalog.cacheTable(parquetTableName) @@ -375,7 +387,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val session = spark import session.implicits._ Seq(1).toDF("id").write.parquet(dir.getPath) - spark.catalog.createTable(tbl, dir.getPath).collect() + spark.catalog.createTable(tbl, dir.getPath) assert(spark.catalog.tableExists(tbl)) spark.catalog.dropTable(tbl) assert(!spark.catalog.tableExists(tbl)) @@ -445,7 +457,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val session = spark import session.implicits._ Seq(1).toDF("id").write.parquet(dir.getPath) - spark.catalog.createTable(t, dir.getPath).collect() + spark.catalog.createTable(t, dir.getPath) val ddl = spark.catalog.getCreateTableString(t) assert(ddl.nonEmpty && ddl.toLowerCase(java.util.Locale.ROOT).contains("create")) } @@ -470,7 +482,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val session = spark import session.implicits._ Seq(1).toDF("id").write.parquet(dir.getPath) - spark.catalog.createTable(t, dir.getPath).collect() + spark.catalog.createTable(t, dir.getPath) spark.catalog.analyzeTable(t, noScan = true) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala index ea4bc3e7ad604..2324ca05d7b7f 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Catalog.scala @@ -392,13 +392,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { * @since 3.5.0 */ override def createTable(tableName: String, path: String): DataFrame = { - sparkSession.newDataFrame { builder => - builder.getCatalogBuilder.getCreateTableBuilder - .setTableName(tableName) - .setSchema(DataTypeProtoConverter.toConnectProtoType(new StructType)) - .setDescription("") - .putOptions("path", path) - } + createTable(tableName, path, "parquet") } /** @@ -484,7 +478,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { schema: StructType, description: String, options: Map[String, String]): DataFrame = { - sparkSession.newDataFrame { builder => + sparkSession.execute { builder => val createTableBuilder = builder.getCatalogBuilder.getCreateTableBuilder .setTableName(tableName) .setSource(source) @@ -494,6 +488,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog { createTableBuilder.putOptions(k, v) } } + sparkSession.table(tableName) } /**