Skip to content

Commit a2428f8

Browse files
authored
make queryInit explicitly return a Query to avoid casts (#1099)
note: in order to reference `Query` from the macro i needed to pull it out to another common dependency - you probably want to move it somewhere else
1 parent 8f7d5ba commit a2428f8

4 files changed

Lines changed: 27 additions & 24 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package io.shiftleft.console
2+
3+
import io.shiftleft.codepropertygraph.generated.nodes
4+
import io.shiftleft.codepropertygraph.Cpg
5+
import overflowdb.traversal.Traversal
6+
7+
case class Query(name: String,
8+
author: String,
9+
title: String,
10+
description: String,
11+
score: Double,
12+
// intended to be filled by com.lihaoyi.sourcecode.Line
13+
docStartLine: Int = 0,
14+
traversal: Cpg => Traversal[nodes.StoredNode],
15+
// intended to be filled by com.lihaoyi.sourcecode.Line
16+
docEndLine: Int = 0,
17+
// intended to be filled by com.lihaoyi.sourcecode.FileName
18+
docFileName: String = "",
19+
traversalAsString: String = "",
20+
tags: List[String] = List())

console/src/main/scala/io/shiftleft/console/QueryDatabase.scala

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
package io.shiftleft.console
22

3-
import io.shiftleft.codepropertygraph.Cpg
4-
import io.shiftleft.codepropertygraph.generated.nodes
53
import org.reflections8.Reflections
64
import org.reflections8.util.{ClasspathHelper, ConfigurationBuilder}
75
import org.slf4j.{Logger, LoggerFactory}
8-
import overflowdb.traversal.Traversal
9-
106
import scala.annotation.{StaticAnnotation, unused}
117
import scala.jdk.CollectionConverters._
128
import scala.reflect.runtime.universe._
@@ -15,21 +11,6 @@ import scala.reflect.runtime.{universe => ru}
1511
trait QueryBundle
1612
class q() extends StaticAnnotation
1713

18-
case class Query(name: String,
19-
author: String,
20-
title: String,
21-
description: String,
22-
score: Double,
23-
// intended to be filled by com.lihaoyi.sourcecode.Line
24-
docStartLine: Int = 0,
25-
traversal: Cpg => Traversal[nodes.StoredNode],
26-
// intended to be filled by com.lihaoyi.sourcecode.Line
27-
docEndLine: Int = 0,
28-
// intended to be filled by com.lihaoyi.sourcecode.FileName
29-
docFileName: String = "",
30-
traversalAsString: String = "",
31-
tags: List[String] = List())
32-
3314
class QueryDatabase(defaultArgumentProvider: DefaultArgumentProvider = new DefaultArgumentProvider,
3415
namespace: String = "io.joern.scanners") {
3516

console/src/test/scala/io/shiftleft/console/QueryDatabaseTests.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import io.shiftleft.semanticcpg.language._
44
import io.shiftleft.codepropertygraph.Cpg
55
import org.scalatest.matchers.should
66
import org.scalatest.wordspec.AnyWordSpec
7-
import overflowdb.traversal.Traversal
87
import io.shiftleft.macros.QueryMacros
98

109
object TestBundle extends QueryBundle {
@@ -43,7 +42,9 @@ class QueryDatabaseTests extends AnyWordSpec with should.Matchers {
4342
"an-author",
4443
"a-title",
4544
"a-description",
46-
2.0, { cpg: Cpg => cpg.method }).asInstanceOf[Query]
45+
2.0,
46+
{ cpg: Cpg => cpg.method }
47+
)
4748
query.title shouldBe "a-title"
4849
query.traversalAsString shouldBe "cpg: Cpg => cpg.method"
4950
}

macros/src/main/scala/io/shiftleft/macros/Macros.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package io.shiftleft.macros
22

33
import io.shiftleft.codepropertygraph.Cpg
44
import io.shiftleft.codepropertygraph.generated.nodes
5+
import io.shiftleft.console.Query
56
import overflowdb.traversal.Traversal
67
import scala.language.experimental.macros
78
import scala.reflect.macros.whitebox
@@ -13,21 +14,21 @@ object QueryMacros {
1314
title: String,
1415
description: String,
1516
score: Double,
16-
traversal: Cpg => Traversal[nodes.StoredNode]): Any = macro queryInitImpl
17+
traversal: Cpg => Traversal[nodes.StoredNode]): Query = macro queryInitImpl
1718

1819
def queryInitImpl(c: whitebox.Context)(name: c.Tree,
1920
author: c.Tree,
2021
title: c.Tree,
2122
description: c.Tree,
2223
score: c.Tree,
23-
traversal: c.Tree): c.Expr[Any] = {
24+
traversal: c.Tree): c.Expr[Query] = {
2425
import c.universe._
2526
val fileContent = new String(traversal.pos.source.content)
2627
val start = traversal.pos.start
2728
val end = traversal.pos.end
2829
val traversalAsString: String = fileContent.slice(start, end)
2930

30-
c.Expr[Any](
31+
c.Expr(
3132
q"""
3233
Query(
3334
name = $name,

0 commit comments

Comments
 (0)