Skip to content

Commit c50ca25

Browse files
Ensure finish() is called exactly once in ForkJoinParallelCpgPass lifecycle (#1841)
1 parent 0a2cf00 commit c50ca25

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,11 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
8181
baseLogger.error(s"Pass ${name} failed", exc)
8282
throw exc
8383
} finally {
84-
try {
85-
finish()
86-
} finally {
87-
// the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish()
88-
// in the reported timings, and we must have our final log message if finish() throws
89-
val nanosStop = System.nanoTime()
90-
val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
91-
baseLogger.info(
92-
f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts."
93-
)
94-
}
84+
val nanosStop = System.nanoTime()
85+
val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1)
86+
baseLogger.info(
87+
f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts."
88+
)
9589
}
9690
}
9791

codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
package io.shiftleft.passes
22

3-
import better.files.File
43
import flatgraph.SchemaViolationException
54
import io.shiftleft.codepropertygraph.generated.Cpg
6-
import io.shiftleft.codepropertygraph.generated.nodes.NewFile
75
import io.shiftleft.codepropertygraph.generated.language.*
6+
import io.shiftleft.codepropertygraph.generated.nodes.NewFile
87
import org.scalatest.matchers.should.Matchers
98
import org.scalatest.wordspec.AnyWordSpec
109

11-
import java.nio.file.Files
10+
import scala.collection.mutable.ArrayBuffer
1211

1312
class CpgPassNewTests extends AnyWordSpec with Matchers {
1413

@@ -52,6 +51,43 @@ class CpgPassNewTests extends AnyWordSpec with Matchers {
5251
pass.createAndApply()
5352
}
5453
}
54+
55+
"call init and finish once around run" in {
56+
val cpg = Cpg.empty
57+
val events = ArrayBuffer.empty[String]
58+
val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "lifecycle-pass") {
59+
override def init(): Unit = events += "init"
60+
override def generateParts(): Array[String] = Array("p1")
61+
override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = events += "run"
62+
override def finish(): Unit = events += "finish"
63+
}
64+
65+
pass.createAndApply()
66+
67+
// all events should be in the expected order and should only occur once
68+
events.toSeq shouldBe Seq("init", "run", "finish")
69+
}
70+
71+
"call finish once when run fails" in {
72+
val cpg = Cpg.empty
73+
val events = ArrayBuffer.empty[String]
74+
val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "failing-lifecycle-pass") {
75+
override def init(): Unit = events += "init"
76+
override def generateParts(): Array[String] = Array("p1")
77+
override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = {
78+
events += "run"
79+
throw new RuntimeException("run failed")
80+
}
81+
override def finish(): Unit = events += "finish"
82+
}
83+
84+
intercept[RuntimeException] {
85+
pass.createAndApply()
86+
}
87+
88+
// all events should be in the expected order and should only occur once even if run fails
89+
events.toSeq shouldBe Seq("init", "run", "finish")
90+
}
5591
}
5692

5793
}

0 commit comments

Comments
 (0)