@@ -10,12 +10,23 @@ import scala.annotation.nowarn
1010import scala .concurrent .duration .DurationLong
1111import scala .util .{Failure , Success , Try }
1212
13- /* CpgPass
14- *
15- * Base class of a program which receives a CPG as input for the purpose of modifying it.
16- * */
13+ /** A single-threaded CPG pass. This is the simplest pass to implement: override [[run ]] and add desired graph
14+ * modifications to the provided [[DiffGraphBuilder ]].
15+ *
16+ * Internally implemented as a [[ForkJoinParallelCpgPass ]] with a single part and parallelism disabled.
17+ *
18+ * @param cpg
19+ * the code property graph to modify
20+ * @param outName
21+ * optional name for output
22+ */
1723abstract class CpgPass (cpg : Cpg , outName : String = " " ) extends ForkJoinParallelCpgPass [AnyRef ](cpg, outName) {
1824
25+ /** The main method to implement. Add all desired graph changes (nodes, edges, properties) to the provided builder.
26+ *
27+ * @param builder
28+ * the [[DiffGraphBuilder ]] that accumulates graph modifications
29+ */
1930 def run (builder : DiffGraphBuilder ): Unit
2031
2132 final override def generateParts (): Array [? <: AnyRef ] = Array [AnyRef ](null )
@@ -26,42 +37,126 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC
2637 override def isParallel : Boolean = false
2738}
2839
/** Deprecated alias of [[CpgPass]]. It adds no members and forwards both constructor arguments unchanged.
  *
  * @param cpg
  *   passed straight through to [[CpgPass]]
  * @param outName
  *   passed straight through to [[CpgPass]]
  */
@deprecated("Use CpgPass instead")
abstract class SimpleCpgPass(cpg: Cpg, outName: String = "") extends CpgPass(cpg, outName)
3042
31- /* ForkJoinParallelCpgPass is a possible replacement for CpgPass and ParallelCpgPass.
32- *
33- * Instead of returning an Iterator, generateParts() returns an Array. This means that the entire collection
34- * of parts must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation,
35- * e.g. when running over all METHOD nodes and deleting some of them.
36- *
37- * Instead of streaming writes as ParallelCpgPass do, all `runOnPart` invocations read the initial state
38- * of the graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one go.
39- *
40- * In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model.
41- * The effect is identical as if one were to sequentially run `runOnParts` on all output elements of `generateParts()`
42- * in sequential order, with the same builder.
43- *
44- * This simplifies semantics and makes it easy to reason about possible races.
45- *
46- * Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption when porting from ParallelCpgPass.
47- *
48- * Initialization and cleanup of external resources or large datastructures can be done in the `init()` and `finish()`
49- * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
50- * passes eagerly, and releases them only when the entire chain has run.
51- * */
52- abstract class ForkJoinParallelCpgPass [T <: AnyRef ](cpg : Cpg , @ nowarn outName : String = " " ) extends CpgPassBase {
43+ /** A parallel CPG pass using the fork/join model.
44+ *
45+ * Instead of returning an Iterator, [[generateParts ]] returns an Array. This means that the entire collection of parts
46+ * must live on the heap at the same time; on the other hand, there are no possible issues with iterator invalidation,
47+ * e.g. when running over all METHOD nodes and deleting some of them.
48+ *
49+ * Instead of streaming writes as ParallelCpgPass do, all [[runOnPart ]] invocations read the initial state of the
50+ * graph. Then all changes (accumulated in the DiffGraphBuilders) are merged into a single change, and applied in one
51+ * go.
52+ *
53+ * In other words, the parallelism follows the fork/join parallel map-reduce (java: collect, scala: aggregate) model.
54+ * The effect is identical as if one were to sequentially run [[runOnPart ]] on all output elements of [[generateParts ]]
55+ * in sequential order, with the same builder.
56+ *
57+ * This simplifies semantics and makes it easy to reason about possible races.
58+ *
59+ * Note that ForkJoinParallelCpgPass never writes intermediate results, so one must consider peak memory consumption
60+ * when porting from ParallelCpgPass.
61+ *
62+ * Initialization and cleanup of external resources or large datastructures can be done in the [[init ]] and [[finish ]]
63+ * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct passes
64+ * eagerly, and releases them only when the entire chain has run.
65+ *
66+ * This is a simplified form of [[ForkJoinParallelCpgPassWithAccumulator ]] that does not use an accumulator.
67+ *
68+ * @tparam T
69+ * the type of each part produced by [[generateParts ]]
70+ * @param cpg
71+ * the code property graph to modify
72+ * @param outname
73+ * optional output name
74+ */
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "")
    extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) {

  /** Handles one element of the array returned by [[generateParts]], writing the resulting graph changes to
    * `builder`.
    *
    * @param builder
    *   the [[DiffGraphBuilder]] that receives this part's modifications
    * @param part
    *   the element to handle
    */
  def runOnPart(builder: DiffGraphBuilder, part: T): Unit

  // This variant has no accumulator: the accumulator type is fixed to `Null`, the value is always `null`,
  // and the merge/complete hooks do nothing.
  override def createAccumulator(): Null = null

  override def mergeAccumulator(left: Null, accumulator: Null): Unit = ()

  override def onAccumulatorComplete(builder: DiffGraphBuilder, accumulator: Null): Unit = ()

  // Drops the (always-null) accumulator and delegates to the two-argument overload declared above.
  override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit =
    runOnPart(builder, part)
}
92+
93+ /** A parallel CPG pass with an accumulator for aggregating side results.
94+ *
95+ * This is the most general form of the fork/join pass framework; [[ForkJoinParallelCpgPass]] is the subclass of it
96+ * that fixes the accumulator type to `Null`. Each result container of the stream collector pairs a builder with an
97+ * `Accumulator` from [[createAccumulator]]; containers are combined via [[mergeAccumulator]], and the final merged
98+ * accumulator is passed to [[onAccumulatorComplete]], where additional graph changes can be recorded.
99+ *
100+ * @tparam T
101+ * the type of each part produced by [[generateParts ]]
102+ * @tparam Accumulator
103+ * the type of the accumulator used during parallel execution
104+ * @param cpg
105+ * the code property graph to modify
106+ * @param outName
107+ * optional output name
108+ */
109+ abstract class ForkJoinParallelCpgPassWithAccumulator [T <: AnyRef , Accumulator <: AnyRef ](
110+ cpg : Cpg ,
111+ @ nowarn outName : String = " "
112+ ) extends CpgPassBase {
53113 type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder
54- // generate Array of parts that can be processed in parallel
114+
115+ /** Generate an array of parts to be processed in parallel by [[runOnPart ]]. */
55116 def generateParts (): Array [? <: AnyRef ]
56- // setup large data structures, acquire external resources
117+
118+ /** Called once before [[generateParts ]]. Use to set up large data structures or acquire external resources. */
57119 def init (): Unit = {}
58- // release large data structures and external resources
120+
121+ /** Called once after all parts have been processed (in a `finally` block). Use to release resources acquired in
122+ * [[init ]].
123+ */
59124 def finish (): Unit = {}
60- // main function: add desired changes to builder
61- def runOnPart (builder : DiffGraphBuilder , part : T ): Unit
62- // Override this to disable parallelism of passes. Useful for debugging.
125+
126+ /** Process a single part, recording graph changes in `builder` and side results in `accumulator`.
127+ *
128+ * @param builder
129+ * the [[DiffGraphBuilder ]] that accumulates graph modifications
130+ * @param part
131+ * the part to process
132+ * @param accumulator
133+ * the accumulator paired with `builder` in the current collector container
134+ */
135+ def runOnPart (builder : DiffGraphBuilder , part : T , accumulator : Accumulator ): Unit
136+
137+ /** Override and return `false` to disable parallel execution. Useful for debugging. */
63138 def isParallel : Boolean = true
64139
140+ /** Create a fresh accumulator instance. Called once for every result container the stream collector allocates. */
141+ def createAccumulator (): Accumulator
142+
143+ /** Merge the `accumulator` (right) into `left`. Called during the combine phase of fork/join. */
144+ def mergeAccumulator (left : Accumulator , accumulator : Accumulator ): Unit
145+
146+ /** Called once after all parts are processed and accumulators are merged. Use to record additional graph changes
147+ * based on the fully merged accumulator.
148+ *
149+ * @param builder
150+ * the [[DiffGraphBuilder ]] for any additional modifications
151+ * @param accumulator
152+ * the final merged accumulator
153+ */
154+ def onAccumulatorComplete (builder : DiffGraphBuilder , accumulator : Accumulator ): Unit
155+
156+ /** Creates a new [[DiffGraphBuilder ]], runs the pass (init, generateParts, runOnPart, finish), applies all
157+ * accumulated changes to the graph, and logs timing information. Exceptions during execution are logged and
158+ * re-thrown.
159+ */
65160 override def createAndApply (): Unit = {
66161 baseLogger.info(s " Start of pass: $name" )
67162 val nanosStart = System .nanoTime()
@@ -89,41 +184,50 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
89184 }
90185 }
91186
187+ /** Runs the full pass lifecycle (init, generateParts, parallel runOnPart, accumulator merge, finish) and absorbs all
188+ * changes into `externalBuilder` without applying them to the graph. The caller is responsible for applying the
189+ * builder.
190+ *
191+ * @param externalBuilder
192+ * the builder to absorb all generated changes into
193+ * @return
194+ * the number of parts that were processed
195+ */
92196 override def runWithBuilder (externalBuilder : DiffGraphBuilder ): Int = {
93197 try {
94198 init()
199+
95200 val parts = generateParts()
96201 val nParts = parts.size
97- nParts match {
98- case 0 =>
99- case 1 =>
100- runOnPart(externalBuilder, parts(0 ).asInstanceOf [T ])
101- case _ =>
102- val stream =
103- if (! isParallel)
104- java.util.Arrays
105- .stream(parts)
106- .sequential()
107- else
108- java.util.Arrays
109- .stream(parts)
110- .parallel()
111- val diff = stream.collect(
112- new Supplier [DiffGraphBuilder ] {
113- override def get (): DiffGraphBuilder =
114- Cpg .newDiffGraphBuilder
115- },
116- new BiConsumer [DiffGraphBuilder , AnyRef ] {
117- override def accept (builder : DiffGraphBuilder , part : AnyRef ): Unit =
118- runOnPart(builder, part.asInstanceOf [T ])
119- },
120- new BiConsumer [DiffGraphBuilder , DiffGraphBuilder ] {
121- override def accept (leftBuilder : DiffGraphBuilder , rightBuilder : DiffGraphBuilder ): Unit =
122- leftBuilder.absorb(rightBuilder)
123- }
124- )
125- externalBuilder.absorb(diff)
126- }
202+ val stream =
203+ if (! isParallel) java.util.Arrays .stream(parts).sequential()
204+ else java.util.Arrays .stream(parts).parallel()
205+
206+ val (diff, acc) = stream.collect(
207+ new Supplier [(DiffGraphBuilder , Accumulator )] {
// Supplier step of `stream.collect`: builds a new result container consisting of an empty diff graph builder and
// a new accumulator.
override def get(): (DiffGraphBuilder, Accumulator) = {
  val freshBuilder     = Cpg.newDiffGraphBuilder
  val freshAccumulator = createAccumulator()
  (freshBuilder, freshAccumulator)
}
210+ },
211+ new BiConsumer [(DiffGraphBuilder , Accumulator ), AnyRef ] {
// Accumulator step of `stream.collect`: runs the pass logic for one part against the builder and accumulator held
// by the given result container.
override def accept(consumedArg: (DiffGraphBuilder, Accumulator), part: AnyRef): Unit =
  consumedArg match {
    case (containerBuilder, containerAccumulator) =>
      // `generateParts()` is typed as `Array[? <: AnyRef]`, hence the cast back to `T`.
      runOnPart(containerBuilder, part.asInstanceOf[T], containerAccumulator)
  }
216+ },
217+ new BiConsumer [(DiffGraphBuilder , Accumulator ), (DiffGraphBuilder , Accumulator )] {
// Combiner step of `stream.collect`: folds the right-hand result container into the left-hand one.
override def accept(
  leftConsumedArg: (DiffGraphBuilder, Accumulator),
  rightConsumedArg: (DiffGraphBuilder, Accumulator)
): Unit = {
  val (leftDiff, leftAcc) = leftConsumedArg
  // The right-hand pair must be taken from `rightConsumedArg`. Destructuring `leftConsumedArg` a second time would
  // make the left container absorb and merge itself, discarding everything recorded in the right container.
  val (rightDiff, rightAcc) = rightConsumedArg
  leftDiff.absorb(rightDiff)
  mergeAccumulator(leftAcc, rightAcc)
}
227+ }
228+ )
229+ onAccumulatorComplete(diff, acc)
230+ externalBuilder.absorb(diff)
127231 nParts
128232 } finally {
129233 finish()
@@ -137,6 +241,9 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
137241
138242}
139243
244+ /** Base trait for all CPG passes. Defines the lifecycle methods that every pass must implement: [[createAndApply ]] for
245+ * standalone execution, and [[runWithBuilder ]] for composing passes that share a single [[DiffGraphBuilder ]].
246+ */
140247trait CpgPassBase {
141248
142249 protected def baseLogger : Logger = LoggerFactory .getLogger(getClass)
@@ -156,8 +263,12 @@ trait CpgPassBase {
156263 */
157264 def runWithBuilder (builder : DiffGraphBuilder ): Int
158265
159- /** Wraps runWithBuilder with logging, and swallows raised exceptions. Use with caution -- API is unstable. A return
160- * value of -1 indicates failure, otherwise the return value of runWithBuilder is passed through.
266+ /** Wraps [[runWithBuilder ]] with logging and exception handling. Use with caution — API is unstable.
267+ *
268+ * @param builder
269+ * the [[DiffGraphBuilder ]] to absorb changes into
270+ * @return
271+ * the number of parts processed, or `-1` if the pass threw an exception
161272 */
162273 def runWithBuilderLogged (builder : DiffGraphBuilder ): Int = {
163274 baseLogger.info(s " Start of pass: $name" )
@@ -189,6 +300,15 @@ trait CpgPassBase {
189300 @ deprecated
190301 protected def store (overlay : GeneratedMessageV3 , name : String , serializedCpg : SerializedCpg ): Unit = {}
191302
303+ /** Executes `fun` while logging the pass start and completion time (including duration via MDC).
304+ *
305+ * @tparam A
306+ * the return type of the wrapped computation
307+ * @param fun
308+ * the computation to execute
309+ * @return
310+ * the result of `fun`
311+ */
192312 protected def withStartEndTimesLogged [A ](fun : => A ): A = {
193313 baseLogger.info(s " Running pass: $name" )
194314 val startTime = System .currentTimeMillis
0 commit comments