|
1 | 1 | package io.shiftleft.passes |
2 | 2 |
|
3 | | -import better.files.File |
4 | 3 | import flatgraph.SchemaViolationException |
5 | 4 | import io.shiftleft.codepropertygraph.generated.Cpg |
6 | | -import io.shiftleft.codepropertygraph.generated.nodes.NewFile |
7 | 5 | import io.shiftleft.codepropertygraph.generated.language.* |
| 6 | +import io.shiftleft.codepropertygraph.generated.nodes.NewFile |
8 | 7 | import org.scalatest.matchers.should.Matchers |
9 | 8 | import org.scalatest.wordspec.AnyWordSpec |
10 | 9 |
|
11 | | -import java.nio.file.Files |
| 10 | +import scala.collection.mutable.ArrayBuffer |
12 | 11 |
|
13 | 12 | class CpgPassNewTests extends AnyWordSpec with Matchers { |
14 | 13 |
|
@@ -52,6 +51,43 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { |
52 | 51 | pass.createAndApply() |
53 | 52 | } |
54 | 53 | } |
| 54 | + |
| 55 | + "call init and finish once around run" in { |
| 56 | + val cpg = Cpg.empty |
| 57 | + val events = ArrayBuffer.empty[String] |
| 58 | + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "lifecycle-pass") { |
| 59 | + override def init(): Unit = events += "init" |
| 60 | + override def generateParts(): Array[String] = Array("p1") |
| 61 | + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = events += "run" |
| 62 | + override def finish(): Unit = events += "finish" |
| 63 | + } |
| 64 | + |
| 65 | + pass.createAndApply() |
| 66 | + |
| 67 | + // all events should be in the expected order and should only occur once |
| 68 | + events.toSeq shouldBe Seq("init", "run", "finish") |
| 69 | + } |
| 70 | + |
| 71 | + "call finish once when run fails" in { |
| 72 | + val cpg = Cpg.empty |
| 73 | + val events = ArrayBuffer.empty[String] |
| 74 | + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "failing-lifecycle-pass") { |
| 75 | + override def init(): Unit = events += "init" |
| 76 | + override def generateParts(): Array[String] = Array("p1") |
| 77 | + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = { |
| 78 | + events += "run" |
| 79 | + throw new RuntimeException("run failed") |
| 80 | + } |
| 81 | + override def finish(): Unit = events += "finish" |
| 82 | + } |
| 83 | + |
| 84 | + intercept[RuntimeException] { |
| 85 | + pass.createAndApply() |
| 86 | + } |
| 87 | + |
| 88 | + // all events should be in the expected order and should only occur once even if run fails |
| 89 | + events.toSeq shouldBe Seq("init", "run", "finish") |
| 90 | + } |
55 | 91 | } |
56 | 92 |
|
57 | 93 | } |
0 commit comments