Skip to content

Commit 70fed24

Browse files
Add ForkJoinParallelCpgPassWithAccumulator (#1847)
Introduce a general ForkJoinParallelCpgPassWithAccumulator for fork/join CPG passes that need per-worker accumulators.
1 parent 3af4f20 commit 70fed24

File tree

2 files changed

+290
-64
lines changed

2 files changed

+290
-64
lines changed

codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala

Lines changed: 184 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,23 @@ import scala.annotation.nowarn
1010
import scala.concurrent.duration.DurationLong
1111
import scala.util.{Failure, Success, Try}
1212

13-
/* CpgPass
14-
*
15-
* Base class of a program which receives a CPG as input for the purpose of modifying it.
16-
* */
13+
/** A single-threaded CPG pass. This is the simplest pass to implement: override [[run]] and add desired graph
14+
* modifications to the provided [[DiffGraphBuilder]].
15+
*
16+
* Internally implemented as a [[ForkJoinParallelCpgPass]] with a single part and parallelism disabled.
17+
*
18+
* @param cpg
19+
* the code property graph to modify
20+
* @param outName
21+
* optional name for output
22+
*/
1723
abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelCpgPass[AnyRef](cpg, outName) {
1824

25+
/** The main method to implement. Add all desired graph changes (nodes, edges, properties) to the provided builder.
26+
*
27+
* @param builder
28+
* the [[DiffGraphBuilder]] that accumulates graph modifications
29+
*/
1930
def run(builder: DiffGraphBuilder): Unit
2031

2132
final override def generateParts(): Array[? <: AnyRef] = Array[AnyRef](null)
@@ -26,42 +37,126 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC
2637
override def isParallel: Boolean = false
2738
}
2839

40+
/** @deprecated Use [[CpgPass]] instead. */
2941
@deprecated abstract class SimpleCpgPass(cpg: Cpg, outName: String = "") extends CpgPass(cpg, outName)
3042

31-
/* ForkJoinParallelCpgPass is a possible replacement for CpgPass and ParallelCpgPass.
32-
*
33-
* Instead of returning an Iterator, generateParts() returns an Array. This means that the entire collection
34-
* of parts must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation,
35-
* e.g. when running over all METHOD nodes and deleting some of them.
36-
*
37-
* Instead of streaming writes as ParallelCpgPass do, all `runOnPart` invocations read the initial state
38-
* of the graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one go.
39-
*
40-
* In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model.
41-
* The effect is identical as if one were to sequentially run `runOnParts` on all output elements of `generateParts()`
42-
* in sequential order, with the same builder.
43-
*
44-
* This simplifies semantics and makes it easy to reason about possible races.
45-
*
46-
* Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption when porting from ParallelCpgPass.
47-
*
48-
* Initialization and cleanup of external resources or large datastructures can be done in the `init()` and `finish()`
49-
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
50-
* passes eagerly, and releases them only when the entire chain has run.
51-
* */
52-
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: String = "") extends CpgPassBase {
43+
/** A parallel CPG pass using the fork/join model.
44+
*
45+
* Instead of returning an Iterator, [[generateParts]] returns an Array. This means that the entire collection of parts
46+
* must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation,
47+
* e.g. when running over all METHOD nodes and deleting some of them.
48+
*
49+
* Instead of streaming writes as ParallelCpgPass do, all [[runOnPart]] invocations read the initial state of the
50+
* graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one
51+
* go.
52+
*
53+
* In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model.
54+
* The effect is identical as if one were to sequentially run [[runOnPart]] on all output elements of [[generateParts]]
55+
* in sequential order, with the same builder.
56+
*
57+
* This simplifies semantics and makes it easy to reason about possible races.
58+
*
59+
* Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption
60+
* when porting from ParallelCpgPass.
61+
*
62+
* Initialization and cleanup of external resources or large datastructures can be done in the [[init]] and [[finish]]
63+
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct passes
64+
* eagerly, and releases them only when the entire chain has run.
65+
*
66+
* This is a simplified form of [[ForkJoinParallelCpgPassWithAccumulator]] that does not use an accumulator.
67+
*
68+
* @tparam T
69+
* the type of each part produced by [[generateParts]]
70+
* @param cpg
71+
* the code property graph to modify
72+
* @param outname
73+
* optional output name
74+
*/
75+
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "")
76+
extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) {
77+
78+
/** Process a single part and record graph modifications in the provided builder.
79+
*
80+
* @param builder
81+
* the [[DiffGraphBuilder]] that accumulates graph modifications
82+
* @param part
83+
* the part to process, as produced by [[generateParts]]
84+
*/
85+
def runOnPart(builder: DiffGraphBuilder, part: T): Unit
86+
87+
override def createAccumulator(): Null = null
88+
override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit = runOnPart(builder, part)
89+
override def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Null): Unit = {}
90+
override def mergeAccumulator(left: Null, accumulator: Null): Unit = {}
91+
}
92+
93+
/** A parallel CPG pass with an accumulator for aggregating side results.
94+
*
95+
* This is the most general form of the fork/join pass framework. It extends [[ForkJoinParallelCpgPass]] with an
96+
* accumulator of type [[Accumulator]] that each parallel worker maintains locally. After all parts are processed,
97+
* worker accumulators are merged via [[mergeAccumulator]], and the final merged accumulator is passed to
98+
* [[onAccumulatorComplete]] where additional graph changes can be recorded.
99+
*
100+
* @tparam T
101+
* the type of each part produced by [[generateParts]]
102+
* @tparam Accumulator
103+
* the type of the accumulator used during parallel execution
104+
* @param cpg
105+
* the code property graph to modify
106+
* @param outName
107+
* optional output name
108+
*/
109+
abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, Accumulator <: AnyRef](
110+
cpg: Cpg,
111+
@nowarn outName: String = ""
112+
) extends CpgPassBase {
53113
type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder
54-
// generate Array of parts that can be processed in parallel
114+
115+
/** Generate an array of parts to be processed in parallel by [[runOnPart]]. */
55116
def generateParts(): Array[? <: AnyRef]
56-
// setup large data structures, acquire external resources
117+
118+
/** Called once before [[generateParts]]. Use to set up large data structures or acquire external resources. */
57119
def init(): Unit = {}
58-
// release large data structures and external resources
120+
121+
/** Called once after all parts have been processed (in a `finally` block). Use to release resources acquired in
122+
* [[init]].
123+
*/
59124
def finish(): Unit = {}
60-
// main function: add desired changes to builder
61-
def runOnPart(builder: DiffGraphBuilder, part: T): Unit
62-
// Override this to disable parallelism of passes. Useful for debugging.
125+
126+
/** Process a single part, recording graph changes in `builder` and side results in `accumulator`.
127+
*
128+
* @param builder
129+
* the [[DiffGraphBuilder]] that accumulates graph modifications
130+
* @param part
131+
* the part to process
132+
* @param accumulator
133+
* the thread-local accumulator for this worker
134+
*/
135+
def runOnPart(builder: DiffGraphBuilder, part: T, accumulator: Accumulator): Unit
136+
137+
/** Override and return `false` to disable parallel execution. Useful for debugging. */
63138
def isParallel: Boolean = true
64139

140+
/** Create a fresh accumulator instance. Called once per parallel worker thread. */
141+
def createAccumulator(): Accumulator
142+
143+
/** Merge the `accumulator` (right) into `left`. Called during the combine phase of fork/join. */
144+
def mergeAccumulator(left: Accumulator, accumulator: Accumulator): Unit
145+
146+
/** Called once after all parts are processed and accumulators are merged. Use to record additional graph changes
147+
* based on the fully merged accumulator.
148+
*
149+
* @param builder
150+
* the [[DiffGraphBuilder]] for any additional modifications
151+
* @param accumulator
152+
* the final merged accumulator
153+
*/
154+
def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Accumulator): Unit
155+
156+
/** Creates a new [[DiffGraphBuilder]], runs the pass (init, generateParts, runOnPart, finish), applies all
157+
* accumulated changes to the graph, and logs timing information. Exceptions during execution are logged and
158+
* re-thrown.
159+
*/
65160
override def createAndApply(): Unit = {
66161
baseLogger.info(s"Start of pass: $name")
67162
val nanosStart = System.nanoTime()
@@ -89,41 +184,50 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
89184
}
90185
}
91186

187+
/** Runs the full pass lifecycle (init, generateParts, parallel runOnPart, accumulator merge, finish) and absorbs all
188+
* changes into `externalBuilder` without applying them to the graph. The caller is responsible for applying the
189+
* builder.
190+
*
191+
* @param externalBuilder
192+
* the builder to absorb all generated changes into
193+
* @return
194+
* the number of parts that were processed
195+
*/
92196
override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = {
93197
try {
94198
init()
199+
95200
val parts = generateParts()
96201
val nParts = parts.size
97-
nParts match {
98-
case 0 =>
99-
case 1 =>
100-
runOnPart(externalBuilder, parts(0).asInstanceOf[T])
101-
case _ =>
102-
val stream =
103-
if (!isParallel)
104-
java.util.Arrays
105-
.stream(parts)
106-
.sequential()
107-
else
108-
java.util.Arrays
109-
.stream(parts)
110-
.parallel()
111-
val diff = stream.collect(
112-
new Supplier[DiffGraphBuilder] {
113-
override def get(): DiffGraphBuilder =
114-
Cpg.newDiffGraphBuilder
115-
},
116-
new BiConsumer[DiffGraphBuilder, AnyRef] {
117-
override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
118-
runOnPart(builder, part.asInstanceOf[T])
119-
},
120-
new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
121-
override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
122-
leftBuilder.absorb(rightBuilder)
123-
}
124-
)
125-
externalBuilder.absorb(diff)
126-
}
202+
val stream =
203+
if (!isParallel) java.util.Arrays.stream(parts).sequential()
204+
else java.util.Arrays.stream(parts).parallel()
205+
206+
val (diff, acc) = stream.collect(
207+
new Supplier[(DiffGraphBuilder, Accumulator)] {
208+
override def get(): (DiffGraphBuilder, Accumulator) =
209+
(Cpg.newDiffGraphBuilder, createAccumulator())
210+
},
211+
new BiConsumer[(DiffGraphBuilder, Accumulator), AnyRef] {
212+
override def accept(consumedArg: (DiffGraphBuilder, Accumulator), part: AnyRef): Unit = {
213+
val (diff, acc) = consumedArg
214+
runOnPart(diff, part.asInstanceOf[T], acc)
215+
}
216+
},
217+
new BiConsumer[(DiffGraphBuilder, Accumulator), (DiffGraphBuilder, Accumulator)] {
218+
override def accept(
219+
leftConsumedArg: (DiffGraphBuilder, Accumulator),
220+
rightConsumedArg: (DiffGraphBuilder, Accumulator)
221+
): Unit = {
222+
val (leftDiff, leftAcc) = leftConsumedArg
223+
val (rightDiff, rightAcc) = leftConsumedArg
224+
leftDiff.absorb(rightDiff)
225+
mergeAccumulator(leftAcc, rightAcc)
226+
}
227+
}
228+
)
229+
onAccumulatorComplete(diff, acc)
230+
externalBuilder.absorb(diff)
127231
nParts
128232
} finally {
129233
finish()
@@ -137,6 +241,9 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
137241

138242
}
139243

244+
/** Base trait for all CPG passes. Defines the lifecycle methods that every pass must implement: [[createAndApply]] for
245+
* standalone execution, and [[runWithBuilder]] for composing passes that share a single [[DiffGraphBuilder]].
246+
*/
140247
trait CpgPassBase {
141248

142249
protected def baseLogger: Logger = LoggerFactory.getLogger(getClass)
@@ -156,8 +263,12 @@ trait CpgPassBase {
156263
*/
157264
def runWithBuilder(builder: DiffGraphBuilder): Int
158265

159-
/** Wraps runWithBuilder with logging, and swallows raised exceptions. Use with caution -- API is unstable. A return
160-
* value of -1 indicates failure, otherwise the return value of runWithBuilder is passed through.
266+
/** Wraps [[runWithBuilder]] with logging and exception handling. Use with caution — API is unstable.
267+
*
268+
* @param builder
269+
* the [[DiffGraphBuilder]] to absorb changes into
270+
* @return
271+
* the number of parts processed, or `-1` if the pass threw an exception
161272
*/
162273
def runWithBuilderLogged(builder: DiffGraphBuilder): Int = {
163274
baseLogger.info(s"Start of pass: $name")
@@ -189,6 +300,15 @@ trait CpgPassBase {
189300
@deprecated
190301
protected def store(overlay: GeneratedMessageV3, name: String, serializedCpg: SerializedCpg): Unit = {}
191302

303+
/** Executes `fun` while logging the pass start and completion time (including duration via MDC).
304+
*
305+
* @tparam A
306+
* the return type of the wrapped computation
307+
* @param fun
308+
* the computation to execute
309+
* @return
310+
* the result of `fun`
311+
*/
192312
protected def withStartEndTimesLogged[A](fun: => A): A = {
193313
baseLogger.info(s"Running pass: $name")
194314
val startTime = System.currentTimeMillis

0 commit comments

Comments
 (0)