diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 502d96a205..273eff22f9 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -355,6 +355,7 @@ jobs: org.apache.spark.sql.comet.CometTaskMetricsSuite org.apache.spark.sql.comet.CometDppFallbackRepro3949Suite org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite + org.apache.spark.sql.comet.PlanDataInjectorSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.spark.sql.comet.util.UtilsSuite org.apache.comet.objectstore.NativeConfigSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index f2a3e84fa1..075a228c64 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -171,6 +171,7 @@ jobs: org.apache.spark.sql.comet.CometTaskMetricsSuite org.apache.spark.sql.comet.CometDppFallbackRepro3949Suite org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite + org.apache.spark.sql.comet.PlanDataInjectorSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.spark.sql.comet.util.UtilsSuite org.apache.comet.objectstore.NativeConfigSuite diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index ebb22d2361..af9e1df8a3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -68,6 +68,12 @@ import org.apache.comet.serde.operator.CometSink */ private[comet] trait PlanDataInjector { + /** + * Which `OpStructCase` this injector handles. Used by `injectPlanData` for an O(1) pre-filter + * so we don't run every injector's `canInject` against every operator in the tree. + */ + def opStructCase: Operator.OpStructCase + /** Check if this injector can handle the given operator. */ def canInject(op: Operator): Boolean @@ -90,6 +96,13 @@ private[comet] object PlanDataInjector { // Future: DeltaPlanDataInjector, HudiPlanDataInjector, etc. ) + // O(1) lookup by op kind: most operators in any tree don't match any injector, so the per-op + // `for (injector <- injectors if injector.canInject(op))` walk was paying N*M canInject calls + // (N operators, M injectors) just to find no match. Keying by OpStructCase lets us skip the + // iteration entirely for non-scan operators. + private val injectorsByKind: Map[Operator.OpStructCase, PlanDataInjector] = + injectors.map(i => i.opStructCase -> i).toMap + /** * Injects planning data into an Operator tree by finding nodes that need injection and applying * the appropriate injector. @@ -103,21 +116,24 @@ private[comet] object PlanDataInjector { partitionByKey: Map[String, Array[Byte]]): Operator = { val builder = op.toBuilder - // Try each injector to see if it can handle this operator - for (injector <- injectors if injector.canInject(op)) { - injector.getKey(op) match { - case Some(key) => - (commonByKey.get(key), partitionByKey.get(key)) match { - case (Some(commonBytes), Some(partitionBytes)) => - val injectedOp = injector.inject(op, commonBytes, partitionBytes) - // Copy the injected operator's fields to our builder - builder.clear() - builder.mergeFrom(injectedOp) - case _ => - throw new CometRuntimeException(s"Missing planning data for key: $key") - } - case None => - } + // O(1) by op kind, then a canInject confirm (which may inspect detail fields like `hasCommon` + // / `!hasFilePartition`). Most operators in any tree are non-scan and skip the lookup body. + injectorsByKind.get(op.getOpStructCase) match { + case Some(injector) if injector.canInject(op) => + injector.getKey(op) match { + case Some(key) => + (commonByKey.get(key), partitionByKey.get(key)) match { + case (Some(commonBytes), Some(partitionBytes)) => + val injectedOp = injector.inject(op, commonBytes, partitionBytes) + // Copy the injected operator's fields to our builder + builder.clear() + builder.mergeFrom(injectedOp) + case _ => + throw new CometRuntimeException(s"Missing planning data for key: $key") + } + case None => + } + case _ => } // Recursively process children @@ -161,6 +177,8 @@ private[comet] object IcebergPlanDataInjector extends PlanDataInjector { } }) + override val opStructCase: Operator.OpStructCase = Operator.OpStructCase.ICEBERG_SCAN + override def canInject(op: Operator): Boolean = op.hasIcebergScan && op.getIcebergScan.getFileScanTasksCount == 0 && @@ -200,6 +218,8 @@ private[comet] object IcebergPlanDataInjector extends PlanDataInjector { */ private[comet] object NativeScanPlanDataInjector extends PlanDataInjector { + override val opStructCase: Operator.OpStructCase = Operator.OpStructCase.NATIVE_SCAN + override def canInject(op: Operator): Boolean = op.hasNativeScan && op.getNativeScan.hasCommon && diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/PlanDataInjectorSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/PlanDataInjectorSuite.scala new file mode 100644 index 0000000000..601ce9a7e7 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/PlanDataInjectorSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.comet.serde.OperatorOuterClass.Operator + +class PlanDataInjectorSuite extends AnyFunSuite { + + test("injectPlanData leaves a non-scan operator tree unchanged") { + // An operator with no injectable scan (here, an empty op_struct, but the same holds for + // Filter/Projection/etc.) must pass through untouched. This exercises the O(1) + // injectorsByKind miss path (`case _ =>`) that replaced the per-injector canInject walk. + val child = Operator.newBuilder().setPlanId(2).build() + val root = Operator.newBuilder().setPlanId(1).addChildren(child).build() + + val result = PlanDataInjector.injectPlanData(root, Map.empty, Map.empty) + + assert(result == root, "non-scan operator tree should be returned unchanged") + } + + test("each registered injector is reachable by its opStructCase") { + // The O(1) lookup keys injectors by opStructCase, so two injectors sharing a kind would + // silently shadow one another in the map. Guard that every registered injector resolves back + // to itself via its declared opStructCase (i.e. the kinds are distinct and the map is complete). + val injectors = Seq(IcebergPlanDataInjector, NativeScanPlanDataInjector) + val byKind = injectors.map(i => i.opStructCase -> i).toMap + assert(byKind.size == injectors.size, "injectors must have distinct opStructCase keys") + injectors.foreach { i => + assert(byKind(i.opStructCase) eq i) + } + assert(IcebergPlanDataInjector.opStructCase == Operator.OpStructCase.ICEBERG_SCAN) + assert(NativeScanPlanDataInjector.opStructCase == Operator.OpStructCase.NATIVE_SCAN) + } +}