@@ -263,20 +263,23 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
263263 test(" per-task cache isolates UDF state across sequential task runs in one session" ) {
264264 // Regression guard for the cache-scoping invariant on CometUdfBridge: instances live for
265265 // exactly one Spark task and are dropped on task completion, so a stateful kernel sees a
266- // fresh instance per task. Running the same `monotonically_increasing_id()`-carrying query
267- // twice in one session must produce identical results each run. Under a cache that outlived
268- // a task and got reused by the next one, the counter would continue from the previous run's
269- // final value and the second run's IDs would diverge. Under a cache that was keyed by Tokio
270- // worker thread rather than task attempt ID, worker reuse across tasks would cause the same
271- // leak whenever the second task happened to be polled by the same worker.
266+ // fresh instance per task. The query has to actually route through the dispatcher for this
267+ // to test anything, so wrap `monotonically_increasing_id()` in a ScalaUDF identity. Running
268+ // it twice in one session must produce results matching Spark each time. Under a cache that
269+ // outlived a task and got reused by the next one, the counter would continue from the
270+ // previous run's final value and the second run's IDs would diverge from Spark. Under a
271+ // cache that was keyed by Tokio worker thread rather than task attempt ID, worker reuse
272+ // across tasks would cause the same leak whenever the second task happened to be polled by
273+ // the same worker. Two `checkSparkAnswerAndOperator` calls are stronger than asserting
274+ // first == second: equality alone could pass if both runs are wrong-but-consistent (e.g.
275+ // `init(partitionIndex)` never fires); matching Spark on both runs rules that out and
276+ // implies cross-run equality because Spark is deterministic on the same query.
277+ spark.udf.register(" idPassthrough" , (id : Long ) => id)
272278 val rows = (0 until 2048 ).map(i => s " row_ $i" )
273279 withSubjects(rows : _* ) {
274- val q = " SELECT s, monotonically_increasing_id() AS mid FROM t"
275- val first = sql(q).collect().map(r => (r.getString(0 ), r.getLong(1 ))).toSeq
276- val second = sql(q).collect().map(r => (r.getString(0 ), r.getLong(1 ))).toSeq
277- assert(
278- first == second,
279- s " per-task cache leaked state across runs: first= ${first.take(5 )} second= ${second.take(5 )}" )
280+ val q = " SELECT s, idPassthrough(monotonically_increasing_id()) AS mid FROM t"
281+ checkSparkAnswerAndOperator(sql(q))
282+ checkSparkAnswerAndOperator(sql(q))
280283 }
281284 }
282285
0 commit comments