From f6abd810326ea57c17888b6cf4533ea5fc935242 Mon Sep 17 00:00:00 2001 From: allsmog Date: Mon, 30 Mar 2026 10:01:51 -0700 Subject: [PATCH] [gosrc2cpg] Fix literal types, pointer-to-pointer, type assertions, and interface resolution - Fix float literals to float64, imaginary to complex128, char to int32 per Go spec - Fix pointer-to-pointer (**T) via recursive type resolution - Fix TypeAssertExpr to produce Operators.cast call with correct result type - Extract interface method sets and add interface method lookup fallback --- .../astcreation/AstCreatorHelper.scala | 4 +- .../astcreation/AstForExpressionCreator.scala | 19 ++++++- .../AstForMethodCallExpressionCreator.scala | 8 +++ .../astcreation/AstForPrimitivesCreator.scala | 7 ++- .../astcreation/CommonCacheBuilder.scala | 30 ++++++++++-- .../gosrc2cpg/datastructures/GoGlobal.scala | 9 ++++ .../io/joern/gosrc2cpg/parser/ParserAst.scala | 1 + .../passes/ast/E2EGoValidationTests.scala | 49 +++++++++++++++++++ .../ast/InterfaceTypeResolutionTests.scala | 41 ++++++++++++++++ .../go2cpg/passes/ast/TypeFullNameTests.scala | 16 +++--- 10 files changed, 166 insertions(+), 18 deletions(-) create mode 100644 joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/E2EGoValidationTests.scala create mode 100644 joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/InterfaceTypeResolutionTests.scala diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreatorHelper.scala index 246a894a27f5..a8a35f78af53 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstCreatorHelper.scala @@ -243,8 +243,8 @@ trait AstCreatorHelper { this: AstCreator => ): (String, String, String) = { nodeInfo.node match { case StarExpr => - // TODO: Need to handle pointer to pointer use case. - val (fullName, typeNameForcode) = internalArrayTypeHandler(createParserNodeInfo(nodeInfo.json(ParserKeys.X))) + val innerNode = createParserNodeInfo(nodeInfo.json(ParserKeys.X)) + val (fullName, typeNameForcode, _) = internalStarExpHandler(innerNode) (s"*$fullName", s"*$typeNameForcode", EvaluationStrategies.BY_SHARING) case _ => val (fullName, typeNameForcode) = internalArrayTypeHandler(nodeInfo, genericTypeMethodMap) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForExpressionCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForExpressionCreator.scala index d5210a9b0201..96b101c9b6f7 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForExpressionCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForExpressionCreator.scala @@ -9,6 +9,7 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators, Prope import ujson.Value import scala.collection.immutable.Seq +import scala.util.{Success, Try} trait AstForExpressionCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator => def astsForExpression(expr: ParserNodeInfo): Seq[Ast] = { expr.node match { @@ -16,7 +17,7 @@ trait AstForExpressionCreator(implicit withSchemaValidation: ValidationMode) { t case StarExpr => astForStarExpr(expr) case UnaryExpr => astForUnaryExpr(expr) case ParenExpr => astsForExpression(createParserNodeInfo(expr.json(ParserKeys.X))) - case TypeAssertExpr => astForNode(expr.json(ParserKeys.X)) + case TypeAssertExpr => Seq(astForTypeAssertExpr(expr)) case CallExpr => astForCallExpression(expr) case SelectorExpr => astForFieldAccess(expr) case KeyValueExpr => astForNode(createParserNodeInfo(expr.json(ParserKeys.Value))) @@ -25,6 +26,22 @@ trait AstForExpressionCreator(implicit withSchemaValidation: ValidationMode) { t } } + private def astForTypeAssertExpr(expr: ParserNodeInfo): Ast = { + Try(createParserNodeInfo(expr.json(ParserKeys.Type))) match { + case Success(typeNode) => + val operandAst = astForNode(expr.json(ParserKeys.X)) + val (typeFullName, _, _, _) = processTypeInfo(typeNode) + val castCall = callNode( + expr, expr.code, Operators.cast, Operators.cast, + DispatchTypes.STATIC_DISPATCH, None, Some(typeFullName) + ) + callAst(castCall, operandAst) + case _ => + // For type switch expressions x.(type), Type may be absent + astForNode(expr.json(ParserKeys.X)).headOption.getOrElse(Ast()) + } + } + private def astForBinaryExpr(binaryExpr: ParserNodeInfo): Seq[Ast] = { val arguments = astForNode(binaryExpr.json(ParserKeys.X)) ++: astForNode(binaryExpr.json(ParserKeys.Y)) // Randomly taking first element of the LHS. diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForMethodCallExpressionCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForMethodCallExpressionCreator.scala index 60cd993e1bd1..3f06cad11fe0 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForMethodCallExpressionCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForMethodCallExpressionCreator.scala @@ -167,6 +167,14 @@ trait AstForMethodCallExpressionCreator(implicit withSchemaValidation: Validatio val callMethodFullName = s"$receiverTypeFullName.$methodName" val MethodCacheMetaData(returnTypeFullNameCache, signatureCache) = goGlobal .getMethodMetadata(receiverTypeFullName, methodName) + .orElse { + // Fallback: check if receiverTypeFullName is an interface with this method + goGlobal.getInterfaceMethods(receiverTypeFullName).flatMap { methods => + if (methods.contains(methodName)) { + goGlobal.getMethodMetadata(receiverTypeFullName, methodName) + } else None + } + } .getOrElse( MethodCacheMetaData( s"$receiverTypeFullName.$methodName.${Defines.ReturnType}.${XDefines.Unknown}", diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPrimitivesCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPrimitivesCreator.scala index d47a36e65cc2..e81d11755eb9 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForPrimitivesCreator.scala @@ -107,12 +107,11 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t } protected def getTypeOfToken(basicLit: ParserNodeInfo): String = { - // TODO need to add more primitive types Try(basicLit.json(ParserKeys.Kind).str match { case "INT" => "int" - case "FLOAT" => "float32" - case "IMAG" => "imag" - case "CHAR" => "char" + case "FLOAT" => "float64" + case "IMAG" => "complex128" + case "CHAR" => "int32" case "STRING" => "string" case _ => Defines.anyTypeName }).toOption.getOrElse(Defines.anyTypeName) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala index aa2abec8212a..3a095ff5e220 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala @@ -6,7 +6,7 @@ import io.joern.gosrc2cpg.parser.{ParserKeys, ParserNodeInfo} import io.joern.x2cpg.{Ast, ValidationMode} import ujson.Value -import scala.util.Try +import scala.util.{Success, Try} trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator => @@ -57,8 +57,7 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this: val fullName = fullyQualifiedPackage + Defines.dot + name val typeNode = createParserNodeInfo(typeSepc.json(ParserKeys.Type)) val ast = typeNode.node match { - // As of don't see any use case where InterfaceType needs to be handled. - case InterfaceType => Seq.empty + case InterfaceType => processInterfaceType(typeNode, fullName) // astForStructType() function will record the member types case StructType => astForStructType(typeNode, fullName) // Process lambda function types to record lambda function signature mapped to TypeFullName @@ -99,6 +98,31 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this: MethodMetadata() } + private def processInterfaceType(typeNode: ParserNodeInfo, fullName: String): Seq[Ast] = { + val methodFields = Try(typeNode.json(ParserKeys.Methods)(ParserKeys.List)) + .orElse(Try(typeNode.json(ParserKeys.Fields)(ParserKeys.List))) + methodFields.toOption.foreach { fields => + fields.arr.foreach { field => + Try(field(ParserKeys.Names).arr.head(ParserKeys.Name).str).toOption.foreach { name => + val methodTypeNode = createParserNodeInfo(field(ParserKeys.Type)) + val returnTypes = getReturnType(methodTypeNode.json, Map.empty) + val returnTypeStr = returnTypes match { + case Seq() => Defines.voidTypeName + case Seq(one) => one._1 + case multiple => s"(${multiple.map(_._1).mkString(", ")})" + } + val params = Try(methodTypeNode.json(ParserKeys.Params)(ParserKeys.List)).getOrElse(ujson.Arr()) + val sig = parameterSignature(params, Map.empty) + val methodFullName = s"$fullName.$name" + val signature = s"$methodFullName($sig)${if (returnTypeStr == Defines.voidTypeName) "" else returnTypeStr}" + goGlobal.recordMethodMetadata(fullName, name, MethodCacheMetaData(returnTypeStr, signature)) + goGlobal.recordInterfaceMethods(fullName, name) + } + } + } + Seq.empty + } + protected def processImports(importDecl: Value): (String, String) = { val importedEntity = importDecl(ParserKeys.Path).obj(ParserKeys.Value).str.replaceAll("\"", "") val importedAsOption = diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/datastructures/GoGlobal.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/datastructures/GoGlobal.scala index 6dd08a5c4293..261d0f10ed28 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/datastructures/GoGlobal.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/datastructures/GoGlobal.scala @@ -122,6 +122,15 @@ class GoGlobal { } } + val interfaceMethodsMap: java.util.concurrent.ConcurrentHashMap[String, java.util.Set[String]] = new java.util.concurrent.ConcurrentHashMap() + + def recordInterfaceMethods(interfaceFullName: String, methodName: String): Unit = { + interfaceMethodsMap.computeIfAbsent(interfaceFullName, _ => java.util.concurrent.ConcurrentHashMap.newKeySet[String]()).add(methodName) + } + + def getInterfaceMethods(interfaceFullName: String): Option[java.util.Set[String]] = + Option(interfaceMethodsMap.get(interfaceFullName)) + def splitNamespaceFromMember(fullName: String): (String, String) = { if (fullName.contains('.')) { val lastDotIndex = fullName.lastIndexOf('.') diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala index 6aaa955028ac..72ed66dc2655 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala @@ -118,4 +118,5 @@ object ParserKeys { val Args = "Args" val Recv = "Recv" val Index = "Index" + val Methods = "Methods" } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/E2EGoValidationTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/E2EGoValidationTests.scala new file mode 100644 index 000000000000..e5b253116db0 --- /dev/null +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/E2EGoValidationTests.scala @@ -0,0 +1,49 @@ +package io.joern.go2cpg.passes.ast + +import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite +import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.semanticcpg.language.* + +class E2EGoValidationTests extends GoCodeToCpgSuite { + + "Type assertion expressions" should { + val cpg = code(""" + |package main + |func foo(x interface{}) { + | y := x.(int) + | _ = y + |} + |""".stripMargin) + + "produce a cast operator call node" in { + val List(castCall) = cpg.call(Operators.cast).l + castCall.typeFullName shouldBe "int" + } + } + + "Literal type inference per Go spec" should { + val cpg = code(""" + |package main + |func main() { + | a := 42 + | b := 3.14 + | c := "hello" + |} + |""".stripMargin) + + "infer int for integer literals" in { + val List(lit) = cpg.literal.code("42").l + lit.typeFullName shouldBe "int" + } + + "infer float64 for float literals" in { + val List(lit) = cpg.literal.code("3.14").l + lit.typeFullName shouldBe "float64" + } + + "infer string for string literals" in { + val List(lit) = cpg.literal.code("\"hello\"").l + lit.typeFullName shouldBe "string" + } + } +} diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/InterfaceTypeResolutionTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/InterfaceTypeResolutionTests.scala new file mode 100644 index 000000000000..8fd3c14dd558 --- /dev/null +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/InterfaceTypeResolutionTests.scala @@ -0,0 +1,41 @@ +package io.joern.go2cpg.passes.ast + +import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite +import io.shiftleft.semanticcpg.language.* + +import java.io.File + +class InterfaceTypeResolutionTests extends GoCodeToCpgSuite { + + "Interface method metadata should be recorded" should { + val cpg = code( + """ + |module joern.io/sample + |go 1.18 + |""".stripMargin, + "go.mod" + ).moreCode( + """ + |package lib + |type Speaker interface { + | Speak() string + |} + |""".stripMargin, + Seq("lib", "iface.go").mkString(File.separator) + ).moreCode( + """ + |package main + |import "joern.io/sample/lib" + |func greet(s lib.Speaker) string { + | return s.Speak() + |} + |""".stripMargin, + "main.go" + ) + + "resolve interface method call return type" in { + val List(speakCall) = cpg.call("Speak").l + speakCall.typeFullName shouldBe "string" + } + } +} diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeFullNameTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeFullNameTests.scala index 5d803f4310e5..a1d018c5ab60 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeFullNameTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/TypeFullNameTests.scala @@ -52,7 +52,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite { "check for LITERAL nodes types" in { val List(a, b, c, d, e, _, _) = cpg.literal.l a.typeFullName shouldBe "int" - b.typeFullName shouldBe "float32" + b.typeFullName shouldBe "float64" c.typeFullName shouldBe "string" d.typeFullName shouldBe "bool" e.typeFullName shouldBe "bool" @@ -66,7 +66,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite { "Check for local nodes" in { val List(a, b, c, d, e, f) = cpg.local.l a.typeFullName shouldBe "int" - b.typeFullName shouldBe "float32" + b.typeFullName shouldBe "float64" c.typeFullName shouldBe "string" d.typeFullName shouldBe "bool" e.typeFullName shouldBe "bool" @@ -76,7 +76,7 @@ class TypeFullNameTests extends GoCodeToCpgSuite { "check for identifier nodes" in { val List(a, b, c, d, e, f) = cpg.identifier.l a.typeFullName shouldBe "int" - b.typeFullName shouldBe "float32" + b.typeFullName shouldBe "float64" c.typeFullName shouldBe "string" d.typeFullName shouldBe "bool" e.typeFullName shouldBe "bool" @@ -527,19 +527,19 @@ class TypeFullNameTests extends GoCodeToCpgSuite { "Check for local nodes" in { val List(typefullname) = cpg.local.typeFullName.dedup.l - typefullname shouldBe "float32" + typefullname shouldBe "float64" } "check for identifier nodes" in { val List(typefullname) = cpg.identifier.typeFullName.dedup.l - typefullname shouldBe "float32" + typefullname shouldBe "float64" } "Operator += call node type check" in { val List(a, b, c) = cpg.call(Operators.assignmentPlus).l - a.typeFullName shouldBe "float32" - b.typeFullName shouldBe "float32" - c.typeFullName shouldBe "float32" + a.typeFullName shouldBe "float64" + b.typeFullName shouldBe "float64" + c.typeFullName shouldBe "float64" } }