Skip to content

Commit 650f619

Browse files
committed
add fallback for too many args and a test, clean up printing code
1 parent a057687 commit 650f619

2 files changed

Lines changed: 53 additions & 1 deletion

File tree

common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
9898
case _ => false
9999
}
100100

101+
/**
102+
* Count the number of leaf fields (including nested) in a [[DataType]]. Mirrors WSCG's
103+
* `WholeStageCodegenExec.numOfNestedFields` so the [[canHandle]] threshold check uses the same
104+
* unit as `spark.sql.codegen.maxFields`.
105+
*/
106+
private def numOfNestedFields(dataType: DataType): Int = dataType match {
107+
case st: StructType => st.fields.map(f => numOfNestedFields(f.dataType)).sum
108+
case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
109+
case a: ArrayType => numOfNestedFields(a.elementType)
110+
case _ => 1
111+
}
112+
101113
/**
102114
* Plan-time predicate: can the codegen dispatcher handle this bound expression end to end?
103115
* `None` greenlights the serde to emit the codegen proto; `Some(reason)` forces a Spark
@@ -112,6 +124,19 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
112124
if (!isSupportedDataType(boundExpr.dataType)) {
113125
return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}")
114126
}
127+
// Mirror WSCG's `spark.sql.codegen.maxFields` gate. Count nested fields in the output type
128+
// and in every `BoundReference`'s input type. Wide schemas blow the generated class's typed
129+
// input field count, the typed-getter switch, and the constant pool. Refuse here so the
130+
// operator falls back to Spark cleanly rather than tripping a Janino compile failure
131+
// mid-execution (which Comet has no way to recover from).
132+
val maxFields = SQLConf.get.wholeStageMaxNumFields
133+
val totalFields = numOfNestedFields(boundExpr.dataType) +
134+
boundExpr.collect { case b: BoundReference => numOfNestedFields(b.dataType) }.sum
135+
if (totalFields > maxFields) {
136+
return Some(
137+
s"codegen dispatch: too many nested fields ($totalFields > " +
138+
s"spark.sql.codegen.maxFields=$maxFields)")
139+
}
115140
// Reject expressions that can't be safely compiled or cached:
116141
// - AggregateFunction / Generator: non-scalar bridge shape.
117142
// - CodegenFallback: opts out of `doGenCode`, which our compile path assumes works.
@@ -192,7 +217,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
192217
case t: Throwable =>
193218
logError(
194219
s"CometBatchKernelCodegen: compile failed for ${boundExpr.getClass.getSimpleName}. " +
195-
s"Generated source follows:\n${src.body}",
220+
s"Generated source follows:\n${CodeFormatter.format(src.code)}",
196221
t)
197222
throw t
198223
}

spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,33 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
163163
s"expected no dispatcher activity under disabled config, got $after")
164164
}
165165

166+
test("schema exceeding spark.sql.codegen.maxFields falls back to Spark") {
167+
// `CometBatchKernelCodegen.canHandle` mirrors WSCG's `spark.sql.codegen.maxFields` gate by
168+
// counting nested input fields plus the output field and refusing once the total exceeds the
169+
// configured cap. Comet has no mid-execution fallback, so the gate must fire at plan time
170+
// (in the serde) rather than letting an oversized kernel reach Janino. With 5 input
171+
// BoundReferences and a 1-field output we have 6 fields total; setting `maxFields=3` ensures
172+
// the gate fires here regardless of test ordering or future schema additions.
173+
spark.udf.register(
174+
"sumFiveInts",
175+
(a: Int, b: Int, c: Int, d: Int, e: Int) => a + b + c + d + e)
176+
withTable("t") {
177+
sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT) USING parquet")
178+
sql("INSERT INTO t VALUES (1, 2, 3, 4, 5), (10, 20, 30, 40, 50)")
179+
CometScalaUDFCodegen.resetStats()
180+
withSQLConf("spark.sql.codegen.maxFields" -> "3") {
181+
// Result correctness still has to match Spark; only the dispatcher path is refused.
182+
// ScalaUDF has no Comet-native path, so this runs on the JVM Spark path under fallback,
183+
// hence `checkSparkAnswer` rather than `checkSparkAnswerAndOperator`.
184+
checkSparkAnswer(sql("SELECT sumFiveInts(a, b, c, d, e) FROM t"))
185+
}
186+
val after = CometScalaUDFCodegen.stats()
187+
assert(
188+
after.compileCount == 0 && after.cacheHitCount == 0,
189+
s"expected dispatcher fallback under maxFields=3, got $after")
190+
}
191+
}
192+
166193
test("per-batch nullability produces distinct compiles for null-present vs null-absent") {
167194
// Same ScalaUDF + same Arrow vector class + different observed nullability should hit
168195
// different cache keys, because `ArrowColumnSpec.nullable` flips when the batch has no

0 commit comments

Comments
 (0)