Skip to content

Commit 9e8ba9e

Browse files
authored
Return only sample flows in reachableByFlows + optimizations (#1000)
* Also expand non-partial results. * Minor optimizations * Increase readability slightly * A bit of cleanup * Replace list of path elements with vector of path elements * Use Vector instead of List for ReachableByResults * Deduplicate results at the end of `reaches` calls. * Do not store partial results for last analysis stage * Tabulate only sample path * Return a sample flow only: addresses performance issue
1 parent 2ff0775 commit 9e8ba9e

6 files changed

Lines changed: 107 additions & 102 deletions

File tree

dataflowengineoss/src/main/scala/io/shiftleft/dataflowengineoss/language/TrackingPoint.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ class TrackingPoint(val traversal: Traversal[nodes.TrackingPoint]) extends AnyVa
2626
traversal.map(_.cfgNode)
2727

2828
def ddgIn(implicit semantics: Semantics): Traversal[nodes.TrackingPoint] = {
29-
val cache = mutable.HashMap[nodes.TrackingPoint, List[PathElement]]()
30-
val result = traversal.flatMap(x => x.ddgIn(List(PathElement(x)), withInvisible = false, cache))
29+
val cache = mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]]()
30+
val result = traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache))
3131
cache.clear
3232
result
3333
}
3434

3535
def ddgInPathElem(implicit semantics: Semantics): Traversal[PathElement] = {
36-
val cache = mutable.HashMap[nodes.TrackingPoint, List[PathElement]]()
37-
val result = traversal.flatMap(x => x.ddgInPathElem(List(PathElement(x)), withInvisible = false, cache))
36+
val cache = mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]]()
37+
val result = traversal.flatMap(x => x.ddgInPathElem(Vector(PathElement(x)), withInvisible = false, cache))
3838
cache.clear
3939
result
4040
}
@@ -66,7 +66,7 @@ class TrackingPoint(val traversal: Traversal[nodes.TrackingPoint]) extends AnyVa
6666
paths.to(Traversal)
6767
}
6868

69-
private def removeConsecutiveDuplicates[T](l: List[T]): List[T] = {
69+
private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = {
7070
l.headOption.map(x => x :: l.sliding(2).collect { case Seq(a, b) if a != b => b }.toList).getOrElse(List())
7171
}
7272

dataflowengineoss/src/main/scala/io/shiftleft/dataflowengineoss/language/nodemethods/TrackingPointMethods.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,21 @@ class TrackingPointMethods[NodeType <: nodes.TrackingPoint](val node: NodeType)
3333
node.start.reachableBy(sourceTravs: _*)
3434

3535
def ddgIn(implicit semantics: Semantics): Traversal[TrackingPoint] = {
36-
val cache = mutable.HashMap[nodes.TrackingPoint, List[PathElement]]()
37-
val result = ddgIn(List(PathElement(node)), withInvisible = false, cache)
36+
val cache = mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]]()
37+
val result = ddgIn(Vector(PathElement(node)), withInvisible = false, cache)
3838
cache.clear()
3939
result
4040
}
4141

4242
def ddgInPathElem(withInvisible: Boolean,
43-
cache: mutable.HashMap[nodes.TrackingPoint, List[PathElement]] =
44-
mutable.HashMap[nodes.TrackingPoint, List[PathElement]]())(
43+
cache: mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]] =
44+
mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]]())(
4545
implicit semantics: Semantics): Traversal[PathElement] =
46-
ddgInPathElem(List(PathElement(node)), withInvisible, cache)
46+
ddgInPathElem(Vector(PathElement(node)), withInvisible, cache)
4747

4848
def ddgInPathElem(implicit semantics: Semantics): Traversal[PathElement] = {
49-
val cache = mutable.HashMap[nodes.TrackingPoint, List[PathElement]]()
50-
val result = ddgInPathElem(List(PathElement(node)), withInvisible = false, cache)
49+
val cache = mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]]()
50+
val result = ddgInPathElem(Vector(PathElement(node)), withInvisible = false, cache)
5151
cache.clear()
5252
result
5353
}
@@ -56,9 +56,9 @@ class TrackingPointMethods[NodeType <: nodes.TrackingPoint](val node: NodeType)
5656
* Traverse back in the data dependence graph by one step, taking into account semantics
5757
* @param path optional list of path elements that have been expanded already
5858
* */
59-
def ddgIn(path: List[PathElement],
59+
def ddgIn(path: Vector[PathElement],
6060
withInvisible: Boolean,
61-
cache: mutable.HashMap[nodes.TrackingPoint, List[PathElement]])(
61+
cache: mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]])(
6262
implicit semantics: Semantics): Traversal[TrackingPoint] = {
6363
ddgInPathElem(path, withInvisible, cache).map(_.node)
6464
}
@@ -68,18 +68,18 @@ class TrackingPointMethods[NodeType <: nodes.TrackingPoint](val node: NodeType)
6868
* taking into account semantics
6969
* @param path optional list of path elements that have been expanded already
7070
* */
71-
def ddgInPathElem(path: List[PathElement],
71+
def ddgInPathElem(path: Vector[PathElement],
7272
withInvisible: Boolean,
73-
cache: mutable.HashMap[nodes.TrackingPoint, List[PathElement]])(
73+
cache: mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]])(
7474
implicit semantics: Semantics): Traversal[PathElement] = {
7575
val result = ddgInPathElemInternal(path, withInvisible, cache).to(Traversal)
7676
result
7777
}
7878

79-
private def ddgInPathElemInternal(path: List[PathElement],
79+
private def ddgInPathElemInternal(path: Vector[PathElement],
8080
withInvisible: Boolean,
81-
cache: mutable.HashMap[nodes.TrackingPoint, List[PathElement]])(
82-
implicit semantics: Semantics): List[PathElement] = {
81+
cache: mutable.HashMap[nodes.TrackingPoint, Vector[PathElement]])(
82+
implicit semantics: Semantics): Vector[PathElement] = {
8383

8484
if (cache.contains(node)) {
8585
return cache(node)
@@ -91,7 +91,7 @@ class TrackingPointMethods[NodeType <: nodes.TrackingPoint](val node: NodeType)
9191
} else {
9292
(elems.filter(_.visible) ++ elems
9393
.filterNot(_.visible)
94-
.flatMap(x => x.node.ddgInPathElem(x :: path, withInvisible = false, cache))).distinct
94+
.flatMap(x => x.node.ddgInPathElem(x +: path, withInvisible = false, cache))).distinct
9595
}
9696
cache.put(node, result)
9797
result

dataflowengineoss/src/main/scala/io/shiftleft/dataflowengineoss/queryengine/Engine.scala

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import scala.util.{Failure, Success, Try}
1616
private case class ReachableByTask(sink: nodes.TrackingPoint,
1717
sources: Set[nodes.TrackingPoint],
1818
table: ResultTable,
19-
initialPath: List[PathElement] = List(),
19+
initialPath: Vector[PathElement] = Vector(),
2020
callDepth: Int = 0)
2121

2222
class Engine(context: EngineContext) {
@@ -31,9 +31,9 @@ class Engine(context: EngineContext) {
3131
* Initialize a pool of workers and return a "completion service" that
3232
* we can query (in a blocking manner) for completed tasks.
3333
* */
34-
private def initializeWorkerPool(): ExecutorCompletionService[List[ReachableByResult]] = {
34+
private def initializeWorkerPool(): ExecutorCompletionService[Vector[ReachableByResult]] = {
3535
val executorService: ExecutorService = Executors.newWorkStealingPool()
36-
new ExecutorCompletionService[List[ReachableByResult]](executorService)
36+
new ExecutorCompletionService[Vector[ReachableByResult]](executorService)
3737
}
3838

3939
/**
@@ -67,20 +67,19 @@ class Engine(context: EngineContext) {
6767
result
6868
}
6969

70-
private def newTasksFromResults(resultsOfTask: List[ReachableByResult],
71-
sources: Set[nodes.TrackingPoint]): List[ReachableByTask] = {
72-
tasksForPartialResults(resultsOfTask, sources) ++ tasksForUnresolvedOutArgs(resultsOfTask, sources)
70+
private def newTasksFromResults(resultsOfTask: Vector[ReachableByResult],
71+
sources: Set[nodes.TrackingPoint]): Vector[ReachableByTask] = {
72+
tasksForParams(resultsOfTask, sources) ++ tasksForUnresolvedOutArgs(resultsOfTask, sources)
7373
}
7474

7575
private def submitTask(task: ReachableByTask): Unit = {
7676
numberOfTasksRunning += 1
7777
completionService.submit(new ReachableByCallable(task, context))
7878
}
7979

80-
private def tasksForPartialResults(resultsOfTask: List[ReachableByResult],
81-
sources: Set[nodes.TrackingPoint]): List[ReachableByTask] = {
82-
val partialResults = resultsOfTask.filter(_.partial)
83-
val pathsFromParams = partialResults.map(x => (x.path, x.callDepth))
80+
private def tasksForParams(resultsOfTask: Vector[ReachableByResult],
81+
sources: Set[nodes.TrackingPoint]): Vector[ReachableByTask] = {
82+
val pathsFromParams = resultsOfTask.map(x => (x.path, x.callDepth))
8483
pathsFromParams.flatMap {
8584
case (path, callDepth) =>
8685
val param = path.head.node
@@ -91,12 +90,12 @@ class Engine(context: EngineContext) {
9190
ReachableByTask(arg, sources, new ResultTable, path, callDepth + 1)
9291
}
9392
}
94-
.getOrElse(List())
93+
.getOrElse(Vector())
9594
}
9695
}
9796

98-
private def tasksForUnresolvedOutArgs(resultsOfTask: List[ReachableByResult],
99-
sources: Set[nodes.TrackingPoint]): List[ReachableByTask] = {
97+
private def tasksForUnresolvedOutArgs(resultsOfTask: Vector[ReachableByResult],
98+
sources: Set[nodes.TrackingPoint]): Vector[ReachableByTask] = {
10099

101100
val outArgsAndCalls = resultsOfTask
102101
.map(x => (x.unresolvedArgs.collect { case e: nodes.Expression => e }, x.path, x.callDepth))
@@ -129,8 +128,8 @@ class Engine(context: EngineContext) {
129128

130129
object Engine {
131130

132-
def expandIn(curNode: nodes.TrackingPoint, path: List[PathElement])(
133-
implicit semantics: Semantics): List[PathElement] = {
131+
def expandIn(curNode: nodes.TrackingPoint, path: Vector[PathElement])(
132+
implicit semantics: Semantics): Vector[PathElement] = {
134133
curNode match {
135134
case argument: nodes.Expression =>
136135
val (arguments, nonArguments) = ddgInE(curNode, path).partition(_.outNode().isInstanceOf[nodes.Expression])
@@ -150,13 +149,13 @@ object Engine {
150149
PathElement(parentNode, outEdgeLabel = outLabel)
151150
}
152151

153-
private def ddgInE(dstNode: nodes.TrackingPoint, path: List[PathElement]): List[Edge] = {
152+
private def ddgInE(dstNode: nodes.TrackingPoint, path: Vector[PathElement]): Vector[Edge] = {
154153
dstNode
155154
.inE(EdgeTypes.REACHING_DEF)
156155
.asScala
157156
.filter(e => e.outNode().isInstanceOf[nodes.TrackingPoint])
158157
.filter(e => !path.map(_.node).contains(e.outNode().asInstanceOf[nodes.TrackingPoint]))
159-
.toList
158+
.toVector
160159
}
161160

162161
/**
@@ -230,19 +229,19 @@ case class EngineConfig(var maxCallDepth: Int = 4)
230229
* @param context state of the data flow engine
231230
* */
232231
private class ReachableByCallable(task: ReachableByTask, context: EngineContext)
233-
extends Callable[List[ReachableByResult]] {
232+
extends Callable[Vector[ReachableByResult]] {
234233

235234
import Engine._
236235

237236
/**
238237
* Entry point of callable.
239238
* */
240-
override def call(): List[ReachableByResult] = {
239+
override def call(): Vector[ReachableByResult] = {
241240
if (task.callDepth > context.config.maxCallDepth) {
242-
List()
241+
Vector()
243242
} else {
244243
implicit val sem: Semantics = context.semantics
245-
results(List(PathElement(task.sink)) ++ task.initialPath, task.sources, task.table)
244+
results(PathElement(task.sink) +: task.initialPath, task.sources, task.table)
246245
task.table.get(task.sink).get.map { r =>
247246
r.copy(callDepth = task.callDepth)
248247
}
@@ -261,33 +260,39 @@ private class ReachableByCallable(task: ReachableByTask, context: EngineContext)
261260
* @param path This is a path from a node to the sink. The first node
262261
* of the path is expanded by this method
263262
* */
264-
private def results[NodeType <: nodes.TrackingPoint](path: List[PathElement],
265-
sources: Set[NodeType],
266-
table: ResultTable)(implicit semantics: Semantics): Unit = {
263+
private def results[NodeType <: nodes.TrackingPoint](
264+
path: Vector[PathElement],
265+
sources: Set[NodeType],
266+
table: ResultTable)(implicit semantics: Semantics): Vector[ReachableByResult] = {
267267
val curNode = path.head.node
268268

269-
val resultsForParents: List[ReachableByResult] = {
270-
expandIn(curNode, path).flatMap { parent =>
271-
table.createFromTable(parent :: path).getOrElse {
272-
results(parent :: path, sources, table)
273-
table.get(parent.node).get
269+
val resultsForParents: Vector[ReachableByResult] = {
270+
expandIn(curNode, path).iterator.flatMap { parent =>
271+
val cachedResult = table.createFromTable(parent, path)
272+
if (cachedResult.isDefined) {
273+
cachedResult.get
274+
} else {
275+
results(parent +: path, sources, table)
274276
}
275-
}
277+
}.toVector
276278
}
277279

278280
val resultsForCurNode = {
279281
val endStates = if (sources.contains(curNode.asInstanceOf[NodeType])) {
280282
List(ReachableByResult(path))
281-
} else if (curNode.isInstanceOf[nodes.MethodParameterIn]) {
283+
} else if ((task.callDepth != context.config.maxCallDepth) && curNode.isInstanceOf[nodes.MethodParameterIn]) {
282284
List(ReachableByResult(path, partial = true))
283285
} else {
284286
List()
285287
}
286288

287289
val retsToResolve = curNode match {
288290
case call: nodes.Call =>
289-
if (methodsForCall(call).to(Traversal).internal.nonEmpty && semanticsForCall(call).isEmpty) {
290-
List(ReachableByResult(PathElement(path.head.node, resolved = false) :: path.tail, partial = true))
291+
if ((task.callDepth != context.config.maxCallDepth) && methodsForCall(call)
292+
.to(Traversal)
293+
.internal
294+
.nonEmpty && semanticsForCall(call).isEmpty) {
295+
List(ReachableByResult(PathElement(path.head.node, resolved = false) +: path.tail, partial = true))
291296
} else {
292297
List()
293298
}
@@ -296,7 +301,30 @@ private class ReachableByCallable(task: ReachableByTask, context: EngineContext)
296301
endStates ++ retsToResolve
297302
}
298303

299-
table.add(curNode, resultsForParents ++ resultsForCurNode)
304+
val res = (resultsForParents ++ resultsForCurNode)
305+
.groupBy { x =>
306+
(x.path.headOption ++ x.path.lastOption, x.partial, x.callDepth)
307+
}
308+
.map {
309+
case (_, list) =>
310+
val lenIdPathPairs = list.map(x => (x.path.length, x)).toList
311+
val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match {
312+
case Nil => Nil
313+
case h :: t => h :: t.takeWhile(y => y._1 == h._1)
314+
}).map(_._2)
315+
316+
if (withMaxLength.length == 1) {
317+
withMaxLength.head
318+
} else {
319+
withMaxLength.minBy { x =>
320+
x.path.map(_.node.id()).mkString("-")
321+
}
322+
}
323+
}
324+
.toVector
325+
326+
table.add(curNode, res)
327+
res
300328
}
301329

302330
private def semanticsForCall(call: nodes.Call)(implicit semantics: Semantics): List[FlowSemantic] = {

dataflowengineoss/src/main/scala/io/shiftleft/dataflowengineoss/queryengine/ResultTable.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ import scala.jdk.CollectionConverters._
55

66
class ResultTable {
77

8-
private val table = new java.util.concurrent.ConcurrentHashMap[nodes.StoredNode, List[ReachableByResult]].asScala
8+
private val table = new java.util.concurrent.ConcurrentHashMap[nodes.StoredNode, Vector[ReachableByResult]].asScala
99

1010
/**
1111
* Add all results in `value` to table entry at `key`, appending to existing
1212
* results.
1313
*/
14-
def add(key: nodes.StoredNode, value: List[ReachableByResult]): Unit = {
14+
def add(key: nodes.StoredNode, value: Vector[ReachableByResult]): Unit = {
1515
table.asJava.compute(key, { (_, existingValue) =>
16-
Option(existingValue).toList.flatten ++ value
16+
Option(existingValue).toVector.flatten ++ value
1717
})
1818
}
1919

@@ -22,11 +22,11 @@ class ResultTable {
2222
* table, and if so, for each result, determine the path up to `first` and prepend it to
2323
* `path`, giving us new results via table lookup.
2424
*/
25-
def createFromTable(path: List[PathElement]): Option[List[ReachableByResult]] = {
26-
val first = path.head
25+
def createFromTable(first: PathElement, remainder: Vector[PathElement]): Option[Vector[ReachableByResult]] = {
2726
table.get(first.node).map { res =>
2827
res.map { r =>
29-
val completePath = r.path.slice(0, r.path.map(_.node).indexOf(first.node)) ++ path
28+
val pathToFirstNode = r.path.slice(0, r.path.map(_.node).indexOf(first.node))
29+
val completePath = pathToFirstNode ++ (first +: remainder)
3030
r.copy(path = completePath)
3131
}
3232
}
@@ -36,7 +36,7 @@ class ResultTable {
3636
* Retrieve list of results for `node` or None if they are not
3737
* available in the table.
3838
*/
39-
def get(node: nodes.StoredNode): Option[List[ReachableByResult]] = {
39+
def get(node: nodes.StoredNode): Option[Vector[ReachableByResult]] = {
4040
table.get(node)
4141
}
4242

@@ -50,10 +50,10 @@ class ResultTable {
5050
* @param partial indicate whether this result stands on its own or requires further analysis,
5151
* e.g., by expanding output arguments backwards into method output parameters.
5252
* */
53-
case class ReachableByResult(path: List[PathElement], callDepth: Int = 0, partial: Boolean = false) {
53+
case class ReachableByResult(path: Vector[PathElement], callDepth: Int = 0, partial: Boolean = false) {
5454
def source: nodes.TrackingPoint = path.head.node
5555

56-
def unresolvedArgs: List[nodes.TrackingPoint] =
56+
def unresolvedArgs: Vector[nodes.TrackingPoint] =
5757
path.collect {
5858
case elem if !elem.resolved =>
5959
elem.node

0 commit comments

Comments
 (0)