Skip to content

Commit 9d1f4d3

Browse files
mihailoale-dbHyukjinKwon
authored andcommitted
[SPARK-54557][SQL] Make CSV/JSON/XmlOptions and CSV/JSON/XmlInferSchema comparable
### What changes were proposed in this pull request? In this PR I propose to make `XmlOptions` and `XmlInferSchema` comparable. ### Why are the changes needed? In order to be able to compare them while working on the single-pass implementation (dual-runs). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #53268 from mihailoale-db/xmlequalsimplement. Authored-by: mihailoale-db <mihailo.aleksic@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 8dd478f commit 9d1f4d3

6 files changed

Lines changed: 90 additions & 9 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
6868

6969
private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType
7070

71+
override def equals(obj: Any): Boolean = obj match {
72+
case other: CSVInferSchema =>
73+
options == other.options
74+
case _ => false
75+
}
76+
77+
override def hashCode(): Int = options.hashCode()
78+
7179
/**
7280
* Similar to the JSON schema inference
7381
* 1. Infer type of each row

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import org.apache.spark.sql.types.StructType
3434
class CSVOptions(
3535
@transient val parameters: CaseInsensitiveMap[String],
3636
val columnPruning: Boolean,
37-
defaultTimeZoneId: String,
38-
defaultColumnNameOfCorruptRecord: String)
37+
private val defaultTimeZoneId: String,
38+
private val defaultColumnNameOfCorruptRecord: String)
3939
extends FileSourceOptions(parameters) with Logging {
4040

4141
import CSVOptions._
@@ -63,6 +63,24 @@ class CSVOptions(
6363
defaultColumnNameOfCorruptRecord)
6464
}
6565

66+
override def equals(obj: Any): Boolean = obj match {
67+
case other: CSVOptions =>
68+
(parameters == null && other.parameters == null ||
69+
parameters != null && parameters == other.parameters) &&
70+
columnPruning == other.columnPruning &&
71+
defaultTimeZoneId == other.defaultTimeZoneId &&
72+
defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord
73+
case _ => false
74+
}
75+
76+
override def hashCode(): Int = {
77+
var result = Option(parameters).map(_.hashCode()).getOrElse(0)
78+
result = 31 * result + (if (columnPruning) 1 else 0)
79+
result = 31 * result + defaultTimeZoneId.hashCode()
80+
result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
81+
result
82+
}
83+
6684
private def getChar(paramName: String, default: Char): Char = {
6785
val paramValue = parameters.get(paramName)
6886
paramValue match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
3636
*/
3737
class JSONOptions(
3838
@transient val parameters: CaseInsensitiveMap[String],
39-
defaultTimeZoneId: String,
40-
defaultColumnNameOfCorruptRecord: String)
39+
private val defaultTimeZoneId: String,
40+
private val defaultColumnNameOfCorruptRecord: String)
4141
extends FileSourceOptions(parameters) with Logging {
4242

4343
import JSONOptions._
@@ -156,6 +156,22 @@ class JSONOptions(
156156
protected def checkedEncoding(enc: String): String =
157157
CharsetProvider.forName(enc, caller = "JSONOptions").name()
158158

159+
override def equals(obj: Any): Boolean = obj match {
160+
case other: JSONOptions =>
161+
(parameters == null && other.parameters == null ||
162+
parameters != null && parameters == other.parameters) &&
163+
defaultTimeZoneId == other.defaultTimeZoneId &&
164+
defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord
165+
case _ => false
166+
}
167+
168+
override def hashCode(): Int = {
169+
var result = Option(parameters).map(_.hashCode()).getOrElse(0)
170+
result = 31 * result + defaultTimeZoneId.hashCode()
171+
result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
172+
result
173+
}
174+
159175
/**
160176
* Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
161177
* If the encoding is not specified (None) in read, it will be detected automatically

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import org.apache.spark.unsafe.types.UTF8String
3939
import org.apache.spark.util.ArrayImplicits._
4040
import org.apache.spark.util.Utils
4141

42-
class JsonInferSchema(options: JSONOptions) extends Serializable with Logging {
42+
class JsonInferSchema(private val options: JSONOptions) extends Serializable with Logging {
4343

4444
private val decimalParser = ExprUtils.getDecimalParser(options.locale)
4545

@@ -61,6 +61,13 @@ class JsonInferSchema(options: JSONOptions) extends Serializable with Logging {
6161
private val isDefaultNTZ = SQLConf.get.timestampType == TimestampNTZType
6262
private val legacyMode = SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY
6363

64+
override def equals(obj: Any): Boolean = obj match {
65+
case other: JsonInferSchema => options == other.options
66+
case _ => false
67+
}
68+
69+
override def hashCode(): Int = options.hashCode()
70+
6471
private def handleJsonErrorsByParseMode(parseMode: ParseMode,
6572
columnNameOfCorruptRecord: String, e: Throwable): Option[StructType] = {
6673
parseMode match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
4545
import org.apache.spark.sql.types._
4646
import org.apache.spark.util.SparkErrorUtils
4747

48-
class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
48+
class XmlInferSchema(private val options: XmlOptions, private val caseSensitive: Boolean)
4949
extends Serializable
5050
with Logging {
5151

@@ -73,6 +73,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
7373
legacyFormat = FAST_DATE_FORMAT,
7474
isParsing = true)
7575

76+
override def equals(obj: Any): Boolean = obj match {
77+
case other: XmlInferSchema =>
78+
options == other.options &&
79+
caseSensitive == other.caseSensitive
80+
case _ => false
81+
}
82+
83+
override def hashCode(): Int = {
84+
var result = options.hashCode()
85+
result = 31 * result + (if (caseSensitive) 1 else 0)
86+
result
87+
}
88+
7689
private def handleXmlErrorsByParseMode(
7790
parser: XMLEventReader,
7891
parseMode: ParseMode,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
3232
*/
3333
class XmlOptions(
3434
val parameters: CaseInsensitiveMap[String],
35-
defaultTimeZoneId: String,
36-
defaultColumnNameOfCorruptRecord: String,
37-
rowTagRequired: Boolean)
35+
private val defaultTimeZoneId: String,
36+
private val defaultColumnNameOfCorruptRecord: String,
37+
private val rowTagRequired: Boolean)
3838
extends FileSourceOptions(parameters) with Logging {
3939

4040
import XmlOptions._
@@ -51,6 +51,25 @@ class XmlOptions(
5151
rowTagRequired)
5252
}
5353

54+
55+
override def equals(obj: Any): Boolean = obj match {
56+
case other: XmlOptions =>
57+
(parameters == null && other.parameters == null ||
58+
parameters != null && parameters == other.parameters) &&
59+
defaultTimeZoneId == other.defaultTimeZoneId &&
60+
defaultColumnNameOfCorruptRecord == other.defaultColumnNameOfCorruptRecord &&
61+
rowTagRequired == other.rowTagRequired
62+
case _ => false
63+
}
64+
65+
override def hashCode(): Int = {
66+
var result = Option(parameters).map(_.hashCode()).getOrElse(0)
67+
result = 31 * result + defaultTimeZoneId.hashCode()
68+
result = 31 * result + defaultColumnNameOfCorruptRecord.hashCode()
69+
result = 31 * result + (if (rowTagRequired) 1 else 0)
70+
result
71+
}
72+
5473
private def getBool(paramName: String, default: Boolean = false): Boolean = {
5574
val param = parameters.getOrElse(paramName, default.toString)
5675
if (param == null) {

0 commit comments

Comments
 (0)