Skip to content

Commit 61ffa65

Browse files
authored
[mypyc] Allow types to have capsule/source file deps (#20667)
An RType may need primitives/definitions from a capsule for things like runtime type checks or unboxing/boxing. Also some types might have custom incref/decref operations defined in a header that is not included by default.
1 parent c17ed50 commit 61ffa65

4 files changed

Lines changed: 189 additions & 5 deletions

File tree

mypyc/analysis/capsule_deps.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from mypyc.ir.deps import Dependency
44
from mypyc.ir.func_ir import FuncIR
5-
from mypyc.ir.ops import CallC, PrimitiveOp
5+
from mypyc.ir.ops import Assign, CallC, PrimitiveOp
6+
from mypyc.ir.rtypes import RStruct, RTuple, RType, RUnion
67

78

89
def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None:
@@ -16,14 +17,51 @@ def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None:
1617
defined in other modules.
1718
"""
1819
deps: set[Dependency] | None = None
20+
# Check function signature types for dependencies
21+
deps = find_type_dependencies(fn, deps)
22+
# Check ops for dependencies
1923
for block in fn.blocks:
2024
for op in block.ops:
21-
# TODO: Also determine implicit type object dependencies (e.g. cast targets)
25+
assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
2226
if isinstance(op, CallC) and op.dependencies is not None:
2327
for dep in op.dependencies:
2428
if deps is None:
2529
deps = set()
2630
deps.add(dep)
27-
else:
28-
assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
31+
deps = collect_type_deps(op.type, deps)
32+
if isinstance(op, Assign):
33+
deps = collect_type_deps(op.dest.type, deps)
34+
return deps
35+
36+
37+
def find_type_dependencies(fn: FuncIR, deps: set[Dependency] | None) -> set[Dependency] | None:
38+
"""Find dependencies from RTypes in function signatures.
39+
40+
Some RTypes (e.g., those for librt types) have associated dependencies
41+
that need to be imported when the type is used.
42+
"""
43+
# Check parameter types
44+
for arg in fn.decl.sig.args:
45+
deps = collect_type_deps(arg.type, deps)
46+
# Check return type
47+
deps = collect_type_deps(fn.decl.sig.ret_type, deps)
48+
return deps
49+
50+
51+
def collect_type_deps(typ: RType, deps: set[Dependency] | None) -> set[Dependency] | None:
52+
"""Collect dependencies from an RType, recursively checking compound types."""
53+
if typ.dependencies is not None:
54+
for dep in typ.dependencies:
55+
if deps is None:
56+
deps = set()
57+
deps.add(dep)
58+
if isinstance(typ, RUnion):
59+
for item in typ.items:
60+
deps = collect_type_deps(item, deps)
61+
elif isinstance(typ, RTuple):
62+
for item in typ.types:
63+
deps = collect_type_deps(item, deps)
64+
elif isinstance(typ, RStruct):
65+
for item in typ.types:
66+
deps = collect_type_deps(item, deps)
2967
return deps

mypyc/ir/rtypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class to enable the new behavior. In rare cases, adding a new
4141
from typing import TYPE_CHECKING, ClassVar, Final, Generic, TypeGuard, TypeVar, final
4242

4343
from mypyc.common import HAVE_IMMORTAL, IS_32_BIT_PLATFORM, PLATFORM_SIZE, JsonDict, short_name
44+
from mypyc.ir.deps import LIBRT_STRINGS, Dependency
4445
from mypyc.namegen import NameGenerator
4546

4647
if TYPE_CHECKING:
@@ -88,6 +89,7 @@ class RType:
8889
# we never raise an AttributeError and don't need the bitfield
8990
# entry.)
9091
error_overlap = False
92+
dependencies: tuple[Dependency, ...] | None = None
9193

9294
@abstractmethod
9395
def accept(self, visitor: RTypeVisitor[T]) -> T:
@@ -232,6 +234,7 @@ def __init__(
232234
size: int = PLATFORM_SIZE,
233235
error_overlap: bool = False,
234236
may_be_immortal: bool = True,
237+
dependencies: tuple[Dependency, ...] | None = None,
235238
) -> None:
236239
RPrimitive.primitive_map[name] = self
237240

@@ -244,6 +247,7 @@ def __init__(
244247
self.size = size
245248
self.error_overlap = error_overlap
246249
self._may_be_immortal = may_be_immortal and HAVE_IMMORTAL
250+
self.dependencies = dependencies
247251
if ctype == "CPyTagged":
248252
self.c_undefined = "CPY_INT_TAG"
249253
elif ctype in ("int16_t", "int32_t", "int64_t"):
@@ -517,7 +521,7 @@ def __hash__(self) -> int:
517521
range_rprimitive: Final = RPrimitive("builtins.range", is_unboxed=False, is_refcounted=True)
518522

519523
KNOWN_NATIVE_TYPES: Final = {
520-
name: RPrimitive(name, is_unboxed=False, is_refcounted=True)
524+
name: RPrimitive(name, is_unboxed=False, is_refcounted=True, dependencies=(LIBRT_STRINGS,))
521525
for name in [
522526
"librt.internal.WriteBuffer",
523527
"librt.internal.ReadBuffer",

mypyc/test-data/capsule-deps.test

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
[case testNoDeps]
2+
def f() -> None:
3+
pass
4+
[out]
5+
No deps
6+
7+
[case testStringsNoCapsuleDepWithoutExperimental]
8+
# Test case missing _experimental suffix -> no deps
9+
from librt.strings import StringWriter
10+
11+
def f() -> None:
12+
StringWriter()
13+
[out]
14+
No deps
15+
16+
[case testStringsCapsuleDep_experimental]
17+
from librt.strings import StringWriter
18+
19+
def f() -> None:
20+
StringWriter()
21+
[out]
22+
Capsule(name='librt.strings')
23+
24+
[case testStringsCapsuleDepFromParamType_experimental]
25+
from librt.strings import StringWriter
26+
27+
def f(s: StringWriter) -> None:
28+
pass
29+
[out]
30+
Capsule(name='librt.strings')
31+
32+
[case testStringsCapsuleDepFromReturnType_experimental]
33+
from librt.strings import StringWriter
34+
35+
def f() -> StringWriter:
36+
assert False
37+
[out]
38+
Capsule(name='librt.strings')
39+
40+
[case testStringsCapsuleDepFromUnion_experimental]
41+
from typing import Union
42+
43+
from librt.strings import StringWriter
44+
45+
def f(s: Union[StringWriter, str]) -> None:
46+
pass
47+
[out]
48+
Capsule(name='librt.strings')
49+
50+
[case testStringsCapsuleDepFromTuple_experimental]
51+
from librt.strings import StringWriter
52+
53+
def f(s: tuple[int, StringWriter]) -> None:
54+
pass
55+
[out]
56+
Capsule(name='librt.strings')
57+
58+
[case testExtraFileDep_experimental]
59+
from librt.strings import StringWriter
60+
61+
def f(s: StringWriter) -> int:
62+
return s[0]
63+
[out]
64+
Capsule(name='librt.strings')
65+
SourceDep(path='stringwriter_extra_ops.c')
66+
67+
[case testMultipleCapsuleDeps1_experimental]
68+
from librt.strings import StringWriter
69+
from librt.base64 import b64encode
70+
71+
def f() -> None:
72+
StringWriter()
73+
74+
def g() -> bytes:
75+
return b64encode(b'foo')
76+
[out]
77+
Capsule(name='librt.base64')
78+
Capsule(name='librt.strings')
79+
80+
[case testMultipleCapsuleDeps2_experimental]
81+
from librt.strings import StringWriter
82+
from librt.base64 import b64encode
83+
84+
def f() -> bytes:
85+
StringWriter()
86+
return b64encode(b'foo')
87+
[out]
88+
Capsule(name='librt.base64')
89+
Capsule(name='librt.strings')

mypyc/test/test_capsule_deps.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Test cases for capsule dependency analysis."""
2+
3+
from __future__ import annotations
4+
5+
import os.path
6+
7+
from mypy.errors import CompileError
8+
from mypy.test.config import test_temp_dir
9+
from mypy.test.data import DataDrivenTestCase
10+
from mypyc.analysis.capsule_deps import find_implicit_op_dependencies
11+
from mypyc.common import TOP_LEVEL_NAME
12+
from mypyc.options import CompilerOptions
13+
from mypyc.test.testutil import (
14+
ICODE_GEN_BUILTINS,
15+
MypycDataSuite,
16+
assert_test_output,
17+
build_ir_for_single_file,
18+
infer_ir_build_options_from_test_name,
19+
use_custom_builtins,
20+
)
21+
from mypyc.transform.lower import lower_ir
22+
23+
files = ["capsule-deps.test"]
24+
25+
26+
class TestCapsuleDeps(MypycDataSuite):
27+
files = files
28+
base_path = test_temp_dir
29+
30+
def run_case(self, testcase: DataDrivenTestCase) -> None:
31+
options = infer_ir_build_options_from_test_name(testcase.name)
32+
if options is None:
33+
# Skipped test case
34+
return
35+
with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase):
36+
try:
37+
ir = build_ir_for_single_file(testcase.input, options)
38+
except CompileError as e:
39+
actual = e.messages
40+
else:
41+
all_deps: set[str] = set()
42+
for fn in ir:
43+
if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"):
44+
continue
45+
compiler_options = CompilerOptions()
46+
lower_ir(fn, compiler_options)
47+
deps = find_implicit_op_dependencies(fn)
48+
if deps:
49+
for dep in deps:
50+
all_deps.add(repr(dep))
51+
actual = sorted(all_deps) if all_deps else ["No deps"]
52+
53+
assert_test_output(testcase, actual, "Invalid test output", testcase.output)

0 commit comments

Comments
 (0)