Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ trait AstCreatorHelper { this: AstCreator =>
): (String, String, String) = {
nodeInfo.node match {
case StarExpr =>
// TODO: Need to handle pointer to pointer use case.
val (fullName, typeNameForcode) = internalArrayTypeHandler(createParserNodeInfo(nodeInfo.json(ParserKeys.X)))
val innerNode = createParserNodeInfo(nodeInfo.json(ParserKeys.X))
val (fullName, typeNameForcode, _) = internalStarExpHandler(innerNode)
(s"*$fullName", s"*$typeNameForcode", EvaluationStrategies.BY_SHARING)
case _ =>
val (fullName, typeNameForcode) = internalArrayTypeHandler(nodeInfo, genericTypeMethodMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, Prope
import ujson.Value

import scala.collection.immutable.Seq
import scala.util.{Success, Try}
trait AstForExpressionCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
def astsForExpression(expr: ParserNodeInfo): Seq[Ast] = {
expr.node match {
case BinaryExpr => astForBinaryExpr(expr)
case StarExpr => astForStarExpr(expr)
case UnaryExpr => astForUnaryExpr(expr)
case ParenExpr => astsForExpression(createParserNodeInfo(expr.json(ParserKeys.X)))
case TypeAssertExpr => astForNode(expr.json(ParserKeys.X))
case TypeAssertExpr => Seq(astForTypeAssertExpr(expr))
case CallExpr => astForCallExpression(expr)
case SelectorExpr => astForFieldAccess(expr)
case KeyValueExpr => astForNode(createParserNodeInfo(expr.json(ParserKeys.Value)))
Expand All @@ -25,6 +26,22 @@ trait AstForExpressionCreator(implicit withSchemaValidation: ValidationMode) { t
}
}

private def astForTypeAssertExpr(expr: ParserNodeInfo): Ast = {
Try(createParserNodeInfo(expr.json(ParserKeys.Type))) match {
case Success(typeNode) =>
val operandAst = astForNode(expr.json(ParserKeys.X))
val (typeFullName, _, _, _) = processTypeInfo(typeNode)
val castCall = callNode(
expr, expr.code, Operators.cast, Operators.cast,
DispatchTypes.STATIC_DISPATCH, None, Some(typeFullName)
)
callAst(castCall, operandAst)
case _ =>
// For type switch expressions x.(type), Type may be absent
astForNode(expr.json(ParserKeys.X)).headOption.getOrElse(Ast())
}
}

private def astForBinaryExpr(binaryExpr: ParserNodeInfo): Seq[Ast] = {
val arguments = astForNode(binaryExpr.json(ParserKeys.X)) ++: astForNode(binaryExpr.json(ParserKeys.Y))
// Randomly taking first element of the LHS.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ trait AstForMethodCallExpressionCreator(implicit withSchemaValidation: Validatio
val callMethodFullName = s"$receiverTypeFullName.$methodName"
val MethodCacheMetaData(returnTypeFullNameCache, signatureCache) = goGlobal
.getMethodMetadata(receiverTypeFullName, methodName)
.orElse {
// Fallback: check if receiverTypeFullName is an interface with this method
goGlobal.getInterfaceMethods(receiverTypeFullName).flatMap { methods =>
if (methods.contains(methodName)) {
goGlobal.getMethodMetadata(receiverTypeFullName, methodName)
} else None
}
}
.getOrElse(
MethodCacheMetaData(
s"$receiverTypeFullName.$methodName.${Defines.ReturnType}.${XDefines.Unknown}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,11 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t
}

protected def getTypeOfToken(basicLit: ParserNodeInfo): String = {
// TODO need to add more primitive types
Try(basicLit.json(ParserKeys.Kind).str match {
case "INT" => "int"
case "FLOAT" => "float32"
case "IMAG" => "imag"
case "CHAR" => "char"
case "FLOAT" => "float64"
case "IMAG" => "complex128"
case "CHAR" => "int32"
case "STRING" => "string"
case _ => Defines.anyTypeName
}).toOption.getOrElse(Defines.anyTypeName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo}
import io.joern.x2cpg.{Ast, ValidationMode}
import ujson.Value

import scala.util.Try
import scala.util.{Success, Try}

trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down Expand Up @@ -57,8 +57,7 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this:
val fullName = fullyQualifiedPackage + Defines.dot + name
val typeNode = createParserNodeInfo(typeSepc.json(ParserKeys.Type))
val ast = typeNode.node match {
// As of don't see any use case where InterfaceType needs to be handled.
case InterfaceType => Seq.empty
case InterfaceType => processInterfaceType(typeNode, fullName)
// astForStructType() function will record the member types
case StructType => astForStructType(typeNode, fullName)
// Process lambda function types to record lambda function signature mapped to TypeFullName
Expand Down Expand Up @@ -99,6 +98,31 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this:
MethodMetadata()
}

private def processInterfaceType(typeNode: ParserNodeInfo, fullName: String): Seq[Ast] = {
val methodFields = Try(typeNode.json(ParserKeys.Methods)(ParserKeys.List))
.orElse(Try(typeNode.json(ParserKeys.Fields)(ParserKeys.List)))
methodFields.toOption.foreach { fields =>
fields.arr.foreach { field =>
Try(field(ParserKeys.Names).arr.head(ParserKeys.Name).str).toOption.foreach { name =>
val methodTypeNode = createParserNodeInfo(field(ParserKeys.Type))
val returnTypes = getReturnType(methodTypeNode.json, Map.empty)
val returnTypeStr = returnTypes match {
case Seq() => Defines.voidTypeName
case Seq(one) => one._1
case multiple => s"(${multiple.map(_._1).mkString(", ")})"
}
val params = Try(methodTypeNode.json(ParserKeys.Params)(ParserKeys.List)).getOrElse(ujson.Arr())
val sig = parameterSignature(params, Map.empty)
val methodFullName = s"$fullName.$name"
val signature = s"$methodFullName($sig)${if (returnTypeStr == Defines.voidTypeName) "" else returnTypeStr}"
goGlobal.recordMethodMetadata(fullName, name, MethodCacheMetaData(returnTypeStr, signature))
goGlobal.recordInterfaceMethods(fullName, name)
}
}
}
Seq.empty
}

protected def processImports(importDecl: Value): (String, String) = {
val importedEntity = importDecl(ParserKeys.Path).obj(ParserKeys.Value).str.replaceAll("\"", "")
val importedAsOption =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ class GoGlobal {
}
}

val interfaceMethodsMap: java.util.concurrent.ConcurrentHashMap[String, java.util.Set[String]] = new java.util.concurrent.ConcurrentHashMap()

def recordInterfaceMethods(interfaceFullName: String, methodName: String): Unit = {
interfaceMethodsMap.computeIfAbsent(interfaceFullName, _ => java.util.concurrent.ConcurrentHashMap.newKeySet[String]()).add(methodName)
}

def getInterfaceMethods(interfaceFullName: String): Option[java.util.Set[String]] =
Option(interfaceMethodsMap.get(interfaceFullName))

def splitNamespaceFromMember(fullName: String): (String, String) = {
if (fullName.contains('.')) {
val lastDotIndex = fullName.lastIndexOf('.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,5 @@ object ParserKeys {
val Args = "Args"
val Recv = "Recv"
val Index = "Index"
val Methods = "Methods"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.joern.go2cpg.passes.ast

import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.semanticcpg.language.*

class E2EGoValidationTests extends GoCodeToCpgSuite {

"Type assertion expressions" should {
val cpg = code("""
|package main
|func foo(x interface{}) {
| y := x.(int)
| _ = y
|}
|""".stripMargin)

"produce a cast operator call node" in {
val List(castCall) = cpg.call(Operators.cast).l
castCall.typeFullName shouldBe "int"
}
}

"Literal type inference per Go spec" should {
val cpg = code("""
|package main
|func main() {
| a := 42
| b := 3.14
| c := "hello"
|}
|""".stripMargin)

"infer int for integer literals" in {
val List(lit) = cpg.literal.code("42").l
lit.typeFullName shouldBe "int"
}

"infer float64 for float literals" in {
val List(lit) = cpg.literal.code("3.14").l
lit.typeFullName shouldBe "float64"
}

"infer string for string literals" in {
val List(lit) = cpg.literal.code("\"hello\"").l
lit.typeFullName shouldBe "string"
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.joern.go2cpg.passes.ast

import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.semanticcpg.language.*

import java.io.File

class InterfaceTypeResolutionTests extends GoCodeToCpgSuite {

"Interface method metadata should be recorded" should {
val cpg = code(
"""
|module joern.io/sample
|go 1.18
|""".stripMargin,
"go.mod"
).moreCode(
"""
|package lib
|type Speaker interface {
| Speak() string
|}
|""".stripMargin,
Seq("lib", "iface.go").mkString(File.separator)
).moreCode(
"""
|package main
|import "joern.io/sample/lib"
|func greet(s lib.Speaker) string {
| return s.Speak()
|}
|""".stripMargin,
"main.go"
)

"resolve interface method call return type" in {
val List(speakCall) = cpg.call("Speak").l
speakCall.typeFullName shouldBe "string"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite {
"check for LITERAL nodes types" in {
val List(a, b, c, d, e, _, _) = cpg.literal.l
a.typeFullName shouldBe "int"
b.typeFullName shouldBe "float32"
b.typeFullName shouldBe "float64"
c.typeFullName shouldBe "string"
d.typeFullName shouldBe "bool"
e.typeFullName shouldBe "bool"
Expand All @@ -66,7 +66,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite {
"Check for local nodes" in {
val List(a, b, c, d, e, f) = cpg.local.l
a.typeFullName shouldBe "int"
b.typeFullName shouldBe "float32"
b.typeFullName shouldBe "float64"
c.typeFullName shouldBe "string"
d.typeFullName shouldBe "bool"
e.typeFullName shouldBe "bool"
Expand All @@ -76,7 +76,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite {
"check for identifier nodes" in {
val List(a, b, c, d, e, f) = cpg.identifier.l
a.typeFullName shouldBe "int"
b.typeFullName shouldBe "float32"
b.typeFullName shouldBe "float64"
c.typeFullName shouldBe "string"
d.typeFullName shouldBe "bool"
e.typeFullName shouldBe "bool"
Expand Down Expand Up @@ -527,19 +527,19 @@ class TypeFullNameTests extends GoCodeToCpgSuite {

"Check for local nodes" in {
val List(typefullname) = cpg.local.typeFullName.dedup.l
typefullname shouldBe "float32"
typefullname shouldBe "float64"
}

"check for identifier nodes" in {
val List(typefullname) = cpg.identifier.typeFullName.dedup.l
typefullname shouldBe "float32"
typefullname shouldBe "float64"
}

"Operator += call node type check" in {
val List(a, b, c) = cpg.call(Operators.assignmentPlus).l
a.typeFullName shouldBe "float32"
b.typeFullName shouldBe "float32"
c.typeFullName shouldBe "float32"
a.typeFullName shouldBe "float64"
b.typeFullName shouldBe "float64"
c.typeFullName shouldBe "float64"
}
}

Expand Down