11package io .shiftleft .dataflowengineoss .passes .reachingdef
22
33import io .shiftleft .codepropertygraph .generated .{EdgeTypes , nodes }
4- import io .shiftleft .semanticcpg .accesspath .{AccessPath , MatchResult , TrackedBase }
5- import io .shiftleft .semanticcpg .language ._
6- import io .shiftleft .semanticcpg .language .nodemethods .TrackingPointMethodsBase .ImplicitsAPI
74import org .slf4j .{Logger , LoggerFactory }
85import overflowdb .traversal ._
9-
10- import scala .jdk .CollectionConverters ._
11-
6+ import io .shiftleft .semanticcpg .language ._
7+ import io .shiftleft .semanticcpg .utils .MemberAccess .isGenericMemberAccessName
8+
9+ /**
10+ * The variables defined/used in the reaching definitions problem can
11+ * all be represented via nodes in the graph. However, doing so directly
12+ * is confusing because it then becomes unclear that variables
13+ * and nodes are actually two separate domains. To make the
14+ * definition domain visible, we wrap nodes in instances of the
15+ * `Definition` class. From a computational standpoint, this is not
16+ * necessary, but it greatly improves readability.
17+ * */
1218object Definition {
13-
1419 def fromNode (node : nodes.StoredNode ): Definition = {
1520 new Definition (node)
1621 }
17-
1822}
1923
20- class Definition private ( val node : nodes.StoredNode ) extends AnyVal {}
24+ case class Definition ( node : nodes.StoredNode ) {}
2125
2226object ReachingDefProblem {
2327
@@ -33,6 +37,9 @@ object ReachingDefProblem {
3337
3438}
3539
40+ /**
41+ * The control flow graph as viewed by the data flow solver.
42+ * */
3643class ReachingDefFlowGraph (method : nodes.Method ) extends FlowGraph {
3744
3845 private val logger : Logger = LoggerFactory .getLogger(this .getClass)
@@ -49,7 +56,7 @@ class ReachingDefFlowGraph(method: nodes.Method) extends FlowGraph {
4956 * */
5057 private def initSucc (ns : List [nodes.StoredNode ]): Map [nodes.StoredNode , List [nodes.StoredNode ]] = {
5158 ns.map {
52- case n @ (ret : nodes.Return ) => n -> List (ret.method.methodReturn )
59+ case n @ (_ : nodes.Return ) => n -> List (exitNode )
5360 case n @ (cfgNode : nodes.CfgNode ) =>
5461 n ->
5562 // `.cfgNext` would be wrong here because it filters `METHOD_RETURN`
@@ -79,6 +86,10 @@ class ReachingDefFlowGraph(method: nodes.Method) extends FlowGraph {
7986
8087}
8188
89+ /**
90+ * For each node of the graph, this transfer function defines how it affects
91+ * the propagation of definitions.
92+ * */
8293class ReachingDefTransferFunction (method : nodes.Method ) extends TransferFunction [Set [Definition ]] {
8394
8495 val gen : Map [nodes.StoredNode , Set [Definition ]] = initGen(method).withDefaultValue(Set .empty[Definition ])
@@ -100,90 +111,96 @@ class ReachingDefTransferFunction(method: nodes.Method) extends TransferFunction
100111 * */
101112 def initGen (method : nodes.Method ): Map [nodes.StoredNode , Set [Definition ]] = {
102113
103- def defsMadeByCall (call : nodes.Call ): Set [Definition ] = {
104- (Set (call) ++ call.start.argument
105- .filterNot(_.isInstanceOf [nodes.Literal ])
106- .filterNot(_.isInstanceOf [nodes.FieldIdentifier ]))
107- .map(x => Definition .fromNode(x.asInstanceOf [nodes.StoredNode ]))
108- }
109-
110114 val defsForParams = method.parameter.l.map { param =>
111115 param -> Set (Definition .fromNode(param.asInstanceOf [nodes.StoredNode ]))
112116 }
113117
114- val defsForCalls = method.call.l.map { call =>
115- call -> defsMadeByCall(call)
116- }
118+ // We filter out field accesses to ensure that they propagate
119+ // taint unharmed.
120+
121+ val defsForCalls = method.call
122+ .filterNot(x => isGenericMemberAccessName(x.name))
123+ .l
124+ .map { call =>
125+ call -> {
126+ val retVal = Set (call)
127+ val args = call.argument.filter(hasValidGenType)
128+ (retVal ++ args)
129+ .map(x => Definition .fromNode(x.asInstanceOf [nodes.StoredNode ]))
130+ }
131+ }
117132 (defsForParams ++ defsForCalls).toMap
118133 }
119134
135+ /**
136+ * Restricts the types of nodes that represent definitions.
137+ * */
138+ private def hasValidGenType (node : nodes.Expression ): Boolean = {
139+ node match {
140+ case _ : nodes.Call => true
141+ case _ : nodes.Identifier => true
142+ case _ => false
143+ }
144+ }
145+
120146 /**
121147 * Initialize the map `kill`, a map that contains killed
122148 * definitions for each flow graph node.
149+ *
150+ * All operations in our graph are represented by calls. Non-operations
151+ * such as identifiers or field-identifiers have empty gen and kill sets,
152+ * meaning that they just pass on definitions unaltered.
123153 * */
124- def initKill (method : nodes.Method ,
125- gen : Map [nodes.StoredNode , Set [Definition ]]): Map [nodes.StoredNode , Set [Definition ]] = {
154+ private def initKill (method : nodes.Method ,
155+ gen : Map [nodes.StoredNode , Set [Definition ]]): Map [nodes.StoredNode , Set [Definition ]] = {
126156
127- val baseToCalls : Map [TrackedBase , List [(nodes.Call , AccessPath )]] = method.call.l
157+ // We filter out field accesses to ensure that they propagate
158+ // taint unharmed.
159+
160+ method.call
161+ .filterNot(x => isGenericMemberAccessName(x.name))
128162 .map { call =>
129- val (base, path) = call.trackedBaseAndAccessPath
130- (base, (call, path))
131- }
132- .groupBy(_._1)
133- .map { case (k, v) => (k, v.map(_._2)) }
134-
135- def allOtherInstancesOf (node : nodes.StoredNode ): Set [nodes.StoredNode ] = {
136- node match {
137- case call : nodes.Call =>
138- val (base, accessPath) = call.trackedBaseAndAccessPath
139- baseToCalls
140- .getOrElse(base, Nil )
141- .collect {
142- case (otherCall, otherPath) if node.id != otherCall.id && {
143- val m = otherPath.matchAndDiff(accessPath.elements)
144- m._1 == MatchResult .EXACT_MATCH && m._2.elements.length == 0
145- } =>
146- otherCall
147- }
148- .toSet
149- case _ =>
150- declaration(node).toList
151- .flatMap(instances)
152- .filter(_.id != node.id)
153- .toSet
163+ call -> killsForGens(gen(call))
154164 }
155- }
156-
157- // We are also adding nodes here that may not even be definitions, but that's
158- // fine since `kill` is only subtracted
159- method.call.map { call =>
160- val killedDefs = gen(call)
161- .map { d =>
162- allOtherInstancesOf(d.node)
163- .filter(d => call.id != d.id)
164- .map(x => Definition .fromNode(x))
165- }
166- .fold(Set ())((v1, v2) => v1.union(v2))
167- call -> killedDefs
168- }.toMap
165+ .toMap
169166 }
170167
171- private def instances (decl : nodes.StoredNode ): List [nodes.StoredNode ] = {
172- decl._refIn().asScala.toList ++ {
173- if (decl.isInstanceOf [nodes.MethodParameterIn ]) {
174- List (decl)
175- } else {
176- List ()
177- }
168+ /**
169+ * The only way in which a call can kill another definition is by
170+ * generating a new definition for the same variable. Given the
171+ * set of generated definitions `genOfCall`, we collect, for each of them,
172+ * the other definitions of the same variable; that is, we calculate
173+ * kill(call) based on gen(call).
174+ * */
175+ private def killsForGens (genOfCall : Set [Definition ]): Set [Definition ] = {
176+ genOfCall.flatMap { definition =>
177+ definitionsOfSameVariable(definition)
178178 }
179179 }
180180
181- private def declaration (node : nodes.StoredNode ): Option [nodes.StoredNode ] = {
182- node match {
183- case param : nodes.MethodParameterIn => Some (param)
184- case _ : nodes.Identifier => node._refOut().nextOption
185- case _ => None
181+ private def definitionsOfSameVariable (definition : Definition ): Set [Definition ] = {
182+ val definedNodes = definition.node match {
183+ case param : nodes.MethodParameterIn =>
184+ method.cfgNode
185+ .filter(x => x.id != param.id)
186+ .isIdentifier
187+ .nameExact(param.name)
188+ .toSet
189+ case identifier : nodes.Identifier =>
190+ method.cfgNode
191+ .filter(x => x.id != identifier.id)
192+ .isIdentifier
193+ .nameExact(identifier.name)
194+ .toSet
195+ case call : nodes.Call =>
196+ method.cfgNode
197+ .filter(x => x.id != call.id)
198+ .isCall
199+ .codeExact(call.code)
200+ .toSet
201+ case _ => Set ()
186202 }
203+ definedNodes.map(x => Definition .fromNode(x))
187204 }
188205
189206}
0 commit comments