Skip to content

Commit 7c64223

Browse files
♻️ Traverse AST in a single pass during file analysis (#109)
1 parent 9fbe792 commit 7c64223

File tree

4 files changed

+189
-174
lines changed

4 files changed

+189
-174
lines changed

src/core/analyzer.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {
88
collectRecognizedNames,
99
collectStringVariables,
1010
decoratorExtractor,
11-
findNodesByType,
11+
getNodesByType,
1212
importExtractor,
1313
includeRouterExtractor,
1414
mountExtractor,
@@ -37,32 +37,32 @@ function resolveVariables(
3737

3838
/** Analyze a syntax tree and extract FastAPI-related information */
3939
export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
40-
const rootNode = tree.rootNode
40+
const nodesByType = getNodesByType(tree.rootNode)
4141

4242
// Get all decorated definitions (functions and classes with decorators)
43-
const decoratedDefs = findNodesByType(rootNode, "decorated_definition")
43+
const decoratedDefs = nodesByType.get("decorated_definition") ?? []
4444
const routes = decoratedDefs.map(decoratorExtractor).filter(notNull)
4545

4646
// Get all router assignments
47-
const assignments = findNodesByType(rootNode, "assignment")
48-
const { fastAPINames, apiRouterNames } = collectRecognizedNames(rootNode)
47+
const assignments = nodesByType.get("assignment") ?? []
48+
const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType)
4949
const routers = assignments
5050
.map((node) => routerExtractor(node, apiRouterNames, fastAPINames))
5151
.filter(notNull)
5252

5353
// Get all include_router and mount calls
54-
const callNodes = findNodesByType(rootNode, "call")
54+
const callNodes = nodesByType.get("call") ?? []
5555
const includeRouters = callNodes.map(includeRouterExtractor).filter(notNull)
5656
const mounts = callNodes.map(mountExtractor).filter(notNull)
5757

5858
// Get all import statements
59-
const importNodes = findNodesByType(rootNode, "import_statement")
60-
const importFromNodes = findNodesByType(rootNode, "import_from_statement")
59+
const importNodes = nodesByType.get("import_statement") ?? []
60+
const importFromNodes = nodesByType.get("import_from_statement") ?? []
6161
const imports = [...importNodes, ...importFromNodes]
6262
.map(importExtractor)
6363
.filter(notNull)
6464

65-
const stringVariables = collectStringVariables(rootNode)
65+
const stringVariables = collectStringVariables(nodesByType)
6666

6767
for (const route of routes) {
6868
route.path = resolveVariables(route.path, stringVariables)

src/core/extractors.ts

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@ import type {
1414
} from "./internal"
1515
import { ROUTE_METHODS } from "./internal"
1616

17-
/** Recursively finds all nodes of a given type within a subtree */
18-
export function findNodesByType(node: Node, type: string): Node[] {
19-
const results: Node[] = []
20-
collectNodesByType(node, type, results)
21-
return results
22-
}
23-
2417
function stripDocstring(raw: string): string {
2518
let content: string
2619
if (
@@ -49,16 +42,24 @@ function stripDocstring(raw: string): string {
4942
return dedented.join("\n").trim()
5043
}
5144

52-
function collectNodesByType(node: Node, type: string, results: Node[]): void {
53-
if (node.type === type) {
54-
results.push(node)
55-
}
56-
for (let i = 0; i < node.childCount; i++) {
57-
const child = node.child(i)
58-
if (child) {
59-
collectNodesByType(child, type, results)
45+
export function getNodesByType(root: Node): Map<string, Node[]> {
46+
const results = new Map<string, Node[]>()
47+
48+
function collectNodesByType(node: Node, results: Map<string, Node[]>): void {
49+
if (!results.has(node.type)) {
50+
results.set(node.type, [])
51+
}
52+
results.get(node.type)!.push(node)
53+
54+
for (let i = 0; i < node.childCount; i++) {
55+
const child = node.child(i)
56+
if (child) {
57+
collectNodesByType(child, results)
58+
}
6059
}
6160
}
61+
collectNodesByType(root, results)
62+
return results
6263
}
6364

6465
/**
@@ -74,9 +75,11 @@ function collectNodesByType(node: Node, type: string, results: Node[]): void {
7475
* settings.PREFIX = "/api" -> (skipped, not a simple identifier)
7576
* def f(): BASE = "/local" -> (skipped, inside function)
7677
*/
77-
export function collectStringVariables(rootNode: Node): Map<string, string> {
78+
export function collectStringVariables(
79+
nodesByType: Map<string, Node[]>,
80+
): Map<string, string> {
7881
const variables = new Map<string, string>()
79-
const assignmentNodes = findNodesByType(rootNode, "assignment")
82+
const assignmentNodes = nodesByType.get("assignment") ?? []
8083

8184
for (const assign of assignmentNodes) {
8285
if (
@@ -172,7 +175,11 @@ export function decoratorExtractor(node: Node): RouteInfo | null {
172175
// Grammar guarantees: decorated_definition always has a first child (the decorator)
173176
const decoratorNode = node.firstNamedChild!
174177

175-
const callNode = findNodesByType(decoratorNode, "call")[0]
178+
const callNode =
179+
decoratorNode.firstNamedChild?.type === "call"
180+
? decoratorNode.firstNamedChild
181+
: null
182+
176183
const functionNode = callNode?.childForFieldName("function")
177184
const argumentsNode = callNode?.childForFieldName("arguments")
178185
const objectNode = functionNode?.childForFieldName("object")
@@ -351,7 +358,8 @@ export function importExtractor(node: Node): ImportInfo | null {
351358
if (node.type === "import_statement") {
352359
let modulePath = ""
353360
// Handle aliased imports: "import fastapi as f"
354-
for (const aliased of findNodesByType(node, "aliased_import")) {
361+
const aliasedImports = getNodesByType(node).get("aliased_import") ?? []
362+
for (const aliased of aliasedImports) {
355363
const nameNode = aliased.childForFieldName("name")
356364
const aliasNode = aliased.childForFieldName("alias")
357365
if (nameNode) {
@@ -362,7 +370,7 @@ export function importExtractor(node: Node): ImportInfo | null {
362370
}
363371
}
364372
// Non-aliased: "import fastapi" or "import fastapi.routing"
365-
const nameNodes = findNodesByType(node, "dotted_name")
373+
const nameNodes = getNodesByType(node).get("dotted_name") ?? []
366374
for (const nameNode of nameNodes) {
367375
if (!hasAncestor(nameNode, "aliased_import")) {
368376
if (!modulePath) modulePath = nameNode.text // preserve full dotted path
@@ -387,7 +395,8 @@ export function importExtractor(node: Node): ImportInfo | null {
387395
)
388396

389397
// Aliased imports (e.g., "router as users_router")
390-
for (const aliased of findNodesByType(node, "aliased_import")) {
398+
const aliasedImports = getNodesByType(node).get("aliased_import") ?? []
399+
for (const aliased of aliasedImports) {
391400
const nameNode = aliased.childForFieldName("name")
392401
const aliasNode = aliased.childForFieldName("alias")
393402
if (nameNode) {
@@ -398,7 +407,7 @@ export function importExtractor(node: Node): ImportInfo | null {
398407
}
399408

400409
// Non-aliased imports (skip first dotted_name which is the module path)
401-
const nameNodes = findNodesByType(node, "dotted_name")
410+
const nameNodes = getNodesByType(node).get("dotted_name") ?? []
402411
for (let i = 1; i < nameNodes.length; i++) {
403412
const nameNode = nameNodes[i]
404413
if (!hasAncestor(nameNode, "aliased_import")) {
@@ -423,15 +432,15 @@ export function importExtractor(node: Node): ImportInfo | null {
423432
* fastAPINames = Set { "FastAPI", "fastapi.FastAPI", "MyApp" }
424433
* apiRouterNames = Set { "APIRouter", "fastapi.APIRouter", "MyRouter", "CustomRouter" }
425434
*/
426-
export function collectRecognizedNames(rootNode: Node): {
435+
export function collectRecognizedNames(nodesByType: Map<string, Node[]>): {
427436
fastAPINames: Set<string>
428437
apiRouterNames: Set<string>
429438
} {
430439
const fastAPINames = new Set<string>(["FastAPI", "fastapi.FastAPI"])
431440
const apiRouterNames = new Set<string>(["APIRouter", "fastapi.APIRouter"])
432441

433442
// Add aliases from "from fastapi import X as Y" imports
434-
for (const node of findNodesByType(rootNode, "import_from_statement")) {
443+
for (const node of nodesByType.get("import_from_statement") ?? []) {
435444
const info = importExtractor(node)
436445
if (!info || info.modulePath !== "fastapi") continue
437446
for (const named of info.namedImports) {
@@ -442,7 +451,7 @@ export function collectRecognizedNames(rootNode: Node): {
442451
}
443452

444453
// Add module aliases from "import fastapi as f" → recognizes f.FastAPI, f.APIRouter
445-
for (const node of findNodesByType(rootNode, "import_statement")) {
454+
for (const node of nodesByType.get("import_statement") ?? []) {
446455
const info = importExtractor(node)
447456
if (!info) continue
448457
for (const named of info.namedImports) {
@@ -456,7 +465,7 @@ export function collectRecognizedNames(rootNode: Node): {
456465

457466
// Add subclasses, checking against the already-accumulated alias sets so
458467
// "class MyRouter(AR)" works when AR is an alias for APIRouter
459-
for (const cls of findNodesByType(rootNode, "class_definition")) {
468+
for (const cls of nodesByType.get("class_definition") ?? []) {
460469
const nameNode = cls.childForFieldName("name")
461470
const superclassesNode = cls.childForFieldName("superclasses")
462471
if (!nameNode || !superclassesNode) continue

0 commit comments

Comments
 (0)