Skip to content

Commit 1ae52d1

Browse files
πŸ› Handle cross-file factory functions (#120)
1 parent 937c2a8 commit 1ae52d1

File tree

9 files changed

+154
-1
lines changed

9 files changed

+154
-1
lines changed

β€Žsrc/core/analyzer.tsβ€Ž

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
collectRecognizedNames,
99
collectStringVariables,
1010
decoratorExtractor,
11+
factoryCallExtractor,
1112
getNodesByType,
1213
importExtractor,
1314
includeRouterExtractor,
@@ -46,6 +47,10 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
4647
// Get all router assignments
4748
const assignments = nodesByType.get("assignment") ?? []
4849
const { fastAPINames, apiRouterNames } = collectRecognizedNames(nodesByType)
50+
const knownConstructors = new Set([...fastAPINames, ...apiRouterNames])
51+
const factoryCalls = assignments
52+
.map((node) => factoryCallExtractor(node, knownConstructors))
53+
.filter(notNull)
4954
const routers = assignments
5055
.map((node) => routerExtractor(node, apiRouterNames, fastAPINames))
5156
.filter(notNull)
@@ -84,6 +89,7 @@ export function analyzeTree(tree: Tree, filePath: string): FileAnalysis {
8489
includeRouters,
8590
mounts,
8691
imports,
92+
factoryCalls,
8793
}
8894
}
8995

β€Žsrc/core/extractors.tsβ€Ž

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import type { Node } from "web-tree-sitter"
66
import type {
7+
FactoryCallInfo,
78
ImportedName,
89
ImportInfo,
910
IncludeRouterInfo,
@@ -582,3 +583,41 @@ export function mountExtractor(node: Node): MountInfo | null {
582583
app: appNode?.text ?? "",
583584
}
584585
}
586+
587+
export function factoryCallExtractor(
588+
node: Node,
589+
knownConstructors: Set<string>,
590+
): FactoryCallInfo | null {
591+
if (node.type !== "assignment") {
592+
return null
593+
}
594+
595+
const variableNameNode = node.childForFieldName("left")
596+
const valueNode = node.childForFieldName("right")
597+
if (!variableNameNode || valueNode?.type !== "call") {
598+
return null
599+
}
600+
601+
const functionNode = valueNode.childForFieldName("function")
602+
if (functionNode?.type !== "identifier") {
603+
return null
604+
}
605+
606+
const functionName = functionNode.text
607+
if (knownConstructors.has(functionName)) {
608+
return null
609+
}
610+
611+
// Skip function and class-local variables to avoid false positives
612+
if (
613+
hasAncestor(node, "function_definition") ||
614+
hasAncestor(node, "class_definition")
615+
) {
616+
return null
617+
}
618+
619+
return {
620+
variableName: variableNameNode.text,
621+
functionName: functionName,
622+
}
623+
}

β€Žsrc/core/internal.tsβ€Ž

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,19 @@ export interface MountInfo {
7777
app: string
7878
}
7979

80+
export interface FactoryCallInfo {
81+
variableName: string
82+
functionName: string
83+
}
84+
8085
export interface FileAnalysis {
8186
filePath: string
8287
routes: RouteInfo[]
8388
routers: RouterInfo[]
8489
includeRouters: IncludeRouterInfo[]
8590
mounts: MountInfo[]
8691
imports: ImportInfo[]
92+
factoryCalls: FactoryCallInfo[]
8793
}
8894

8995
export interface RouterNode {

β€Žsrc/core/routerResolver.tsβ€Ž

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,52 @@ async function buildRouterGraphInternal(
197197
}
198198
}
199199

200-
if (!appRouter || !analysis) {
200+
// Factory function in another module: if the entrypoint variable is assigned via
201+
// `app = create_app()` where `create_app` is imported, follow the import to the
202+
// factory file and build the router graph from there. This works because
203+
// routerExtractor and includeRouterExtractor recurse into function bodies, so
204+
// `app = FastAPI()` and `app.include_router(...)` inside `create_app` are visible
205+
// when analyzing the factory file directly.
206+
if (!appRouter && targetVariable) {
207+
const factoryCall = analysis.factoryCalls.find(
208+
(fc) => fc.variableName === targetVariable,
209+
)
210+
if (factoryCall) {
211+
const matchingImport = analysis.imports.find((imp) =>
212+
imp.names.includes(factoryCall.functionName),
213+
)
214+
if (matchingImport) {
215+
const namedImport = matchingImport.namedImports.find(
216+
(ni) => (ni.alias ?? ni.name) === factoryCall.functionName,
217+
)
218+
const originalName = namedImport?.name ?? factoryCall.functionName
219+
const factoryFileUri = await resolveNamedImport(
220+
{
221+
modulePath: matchingImport.modulePath,
222+
names: [originalName],
223+
isRelative: matchingImport.isRelative,
224+
relativeDots: matchingImport.relativeDots,
225+
},
226+
entryFileUri,
227+
projectRootUri,
228+
fs,
229+
analyzeFileFn,
230+
)
231+
if (factoryFileUri && !visited.has(factoryFileUri)) {
232+
const factoryGraph = await buildRouterGraphInternal(
233+
factoryFileUri,
234+
ctx,
235+
)
236+
if (factoryGraph) {
237+
factoryGraph.variableName = targetVariable
238+
return factoryGraph
239+
}
240+
}
241+
}
242+
}
243+
}
244+
245+
if (!appRouter) {
201246
return null
202247
}
203248

β€Žsrc/test/core/routerResolver.test.tsβ€Ž

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,39 @@ suite("routerResolver", () => {
571571
assert.strictEqual(result, null)
572572
})
573573

574+
test("follows imported factory function to resolve include_router calls", async () => {
575+
const result = await buildRouterGraph(
576+
fixtures.factoryFunc.factoryMainPy,
577+
parser,
578+
fixtures.factoryFunc.root,
579+
nodeFileSystem,
580+
"app",
581+
)
582+
583+
assert.ok(result, "Should find app via imported factory function")
584+
assert.strictEqual(result.type, "FastAPI")
585+
assert.strictEqual(result.variableName, "app")
586+
assert.strictEqual(
587+
result.children.length,
588+
1,
589+
"Should have one included router",
590+
)
591+
assert.ok(
592+
result.children[0].router.routes.length >= 2,
593+
"Should have routes from routers.py",
594+
)
595+
})
596+
597+
test("returns null without targetVariable when factory function has no local routes", async () => {
598+
const result = await buildRouterGraph(
599+
fixtures.factoryFunc.factoryMainPy,
600+
parser,
601+
fixtures.factoryFunc.root,
602+
nodeFileSystem,
603+
)
604+
assert.strictEqual(result, null)
605+
})
606+
574607
test("resolves custom APIRouter subclass as child router", async () => {
575608
const result = await buildRouterGraph(
576609
fixtures.customSubclass.mainPy,
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from fastapi import FastAPI
2+
from routers import router
23

34

45
def get_fastapi_app() -> FastAPI:
56
return FastAPI()
7+
8+
9+
def create_app() -> FastAPI:
10+
app = FastAPI()
11+
app.include_router(router, prefix="/users")
12+
return app
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from app import create_app
2+
3+
app = create_app()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from fastapi import APIRouter
2+
3+
router = APIRouter()
4+
5+
6+
@router.get("/")
7+
def list_users():
8+
return []
9+
10+
11+
@router.get("/{user_id}")
12+
def get_user(user_id: int):
13+
return {"id": user_id}

β€Žsrc/test/testUtils.tsβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export const fixtures = {
5050
factoryFunc: {
5151
root: uri(join(fixturesPath, "factory-func")),
5252
mainPy: uri(join(fixturesPath, "factory-func", "main.py")),
53+
factoryMainPy: uri(join(fixturesPath, "factory-func", "factory_main.py")),
5354
},
5455
flat: {
5556
root: uri(join(fixturesPath, "flat")),

0 commit comments

Comments
Β (0)