Skip to content

Commit 629f456

Browse files
authored
[mypyc] Support generic primitives and add some generic primitives for vec (#21656)
Generic primitives have `RTypeVar` types in parameter and/or return types, and these get expanded away when the primitive is added to IR. Generic primitives let us use lowering for various `vec` operations, many of which are generic. Using higher-level operations in the IR helps with various optimizations. It's also easier to verify that the generated IR is correct when the IR is less verbose. Add generic primitives for unsafe `vec` get item op as an initial use case. We can later use these for other `vec` operations as well. Used some coding agent assist (mostly for tests).
1 parent 5374fec commit 629f456

18 files changed

Lines changed: 665 additions & 325 deletions

mypyc/ir/ops.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class to enable the new behavior. Sometimes adding a new abstract
3939
RStruct,
4040
RTuple,
4141
RType,
42+
RTypeVar,
4243
RUnion,
4344
RVec,
4445
RVoid,
@@ -703,7 +704,7 @@ def __init__(
703704
self,
704705
name: str,
705706
arg_types: list[RType],
706-
return_type: RType, # TODO: What about generic?
707+
return_type: RType,
707708
var_arg_type: RType | None,
708709
truncated_type: RType | None,
709710
c_function_name: str | None,
@@ -716,6 +717,7 @@ def __init__(
716717
is_pure: bool,
717718
experimental: bool,
718719
dependencies: list[Dependency] | None,
720+
type_params: list[RTypeVar] | None,
719721
) -> None:
720722
# Each primitive much have a distinct name, but otherwise they are arbitrary.
721723
self.name: Final = name
@@ -749,6 +751,7 @@ def __init__(
749751
# If this flag is set, the primitive has native integer types and must
750752
# be matched using more complex rules.
751753
self.is_ambiguous = any(has_fixed_width_int(t) for t in arg_types)
754+
self.type_params = None if not type_params else type_params
752755

753756
def __repr__(self) -> str:
754757
return f"<PrimitiveDescription {self.name!r}: {self.arg_types}>"
@@ -776,11 +779,23 @@ class PrimitiveOp(RegisterOp):
776779
code paths for short and long representations.
777780
"""
778781

779-
def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
782+
def __init__(
783+
self,
784+
args: list[Value],
785+
desc: PrimitiveDescription,
786+
line: int = -1,
787+
*,
788+
arg_types: list[RType] | None = None,
789+
return_type: RType | None = None,
790+
type_args: list[RType] | None = None,
791+
) -> None:
780792
self.error_kind = desc.error_kind
781793
super().__init__(line)
782794
self.args = args
783-
self.type = desc.return_type
795+
self.arg_types = arg_types if arg_types is not None else desc.arg_types
796+
self.type = return_type if return_type is not None else desc.return_type
797+
self.is_borrowed = desc.is_borrowed
798+
self.type_args = type_args
784799
self.desc = desc
785800

786801
def sources(self) -> list[Value]:

mypyc/ir/pprint.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,15 @@ def visit_call_c(self, op: CallC) -> str:
231231

232232
def visit_primitive_op(self, op: PrimitiveOp) -> str:
233233
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
234+
if op.type_args:
235+
joined = ", ".join(str(arg) for arg in op.type_args)
236+
type_args = f"[{joined}]"
237+
else:
238+
type_args = ""
234239
if op.is_void:
235-
return self.format("%s %s", op.desc.name, args_str)
240+
return self.format("%s%s %s", op.desc.name, type_args, args_str)
236241
else:
237-
return self.format("%r = %s %s", op, op.desc.name, args_str)
242+
return self.format("%r = %s%s %s", op, op.desc.name, type_args, args_str)
238243

239244
def visit_truncate(self, op: Truncate) -> str:
240245
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)

mypyc/ir/rtypes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def visit_rprimitive(self, typ: RPrimitive, /) -> T:
152152
def visit_rinstance(self, typ: RInstance, /) -> T:
153153
raise NotImplementedError
154154

155+
def visit_rtypevar(self, typ: RTypeVar, /) -> T:
156+
raise RuntimeError("RTypeVar should not be encountered here")
157+
155158
@abstractmethod
156159
def visit_rvec(self, typ: RVec, /) -> T:
157160
raise NotImplementedError
@@ -747,6 +750,12 @@ def visit_rarray(self, t: RArray) -> str:
747750
def visit_rvoid(self, t: RVoid) -> str:
748751
assert False, "rvoid in tuple?"
749752

753+
def visit_rtypevar(self, typ: RTypeVar) -> str:
754+
# We need to return something to support generic RTuples, etc. Make sure
755+
# the return value is invalid C so that generic RTuples must be expanded
756+
# before they can be used in IR.
757+
return f"!RTypeVar {typ.id} invalid!"
758+
750759

751760
@final
752761
class RTuple(RType):
@@ -1013,6 +1022,50 @@ def serialize(self) -> str:
10131022
return self.name
10141023

10151024

1025+
@final
1026+
class RTypeVar(RType):
1027+
"""Type variable type used for generic primitive ops.
1028+
1029+
This allows having generic primitive operations like vec get item, which is
1030+
parametrized by the vec item type.
1031+
1032+
These types are not valid in any other context outside PrimitiveDescription,
1033+
and they will always be substituted during the construction of a PrimitiveOp.
1034+
1035+
NOTE: This is not related to mypy's TypeVarType!
1036+
"""
1037+
1038+
def __init__(self, id: int) -> None:
1039+
self.id = id
1040+
1041+
@property
1042+
def may_be_immortal(self) -> bool:
1043+
# RTypeVar must always be substituted before use, so this should never matter.
1044+
return False
1045+
1046+
def accept(self, visitor: RTypeVisitor[T]) -> T:
1047+
return visitor.visit_rtypevar(self)
1048+
1049+
def __str__(self) -> str:
1050+
return f"<RTypeVar {self.id}>"
1051+
1052+
def __repr__(self) -> str:
1053+
return f"<RTypeVar {self.id}>"
1054+
1055+
def __eq__(self, other: object) -> TypeGuard[RTypeVar]:
1056+
return isinstance(other, RTypeVar) and other.id == self.id
1057+
1058+
def __hash__(self) -> int:
1059+
return self.id ^ 12345
1060+
1061+
def serialize(self) -> JsonDict:
1062+
return {".class": "RTypeVar", "id": self.id}
1063+
1064+
@classmethod
1065+
def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RTypeVar:
1066+
return RTypeVar(data["id"])
1067+
1068+
10161069
@final
10171070
class RVec(RType):
10181071
"""librt.vecs.vec[T]"""

mypyc/irbuild/ll_builder.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@
206206
new_tuple_with_length_op,
207207
sequence_tuple_op,
208208
)
209+
from mypyc.rt_expandtype import expand_rtype
209210
from mypyc.rt_subtype import is_runtime_subtype
210211
from mypyc.sametype import is_same_type
211212
from mypyc.subtype import is_subtype
@@ -2334,6 +2335,7 @@ def primitive_op(
23342335
args: list[Value],
23352336
line: int,
23362337
result_type: RType | None = None,
2338+
type_args: list[RType] | None = None,
23372339
) -> Value:
23382340
"""Add a primitive op."""
23392341
# Does this primitive map into calling a Python C API
@@ -2363,18 +2365,37 @@ def primitive_op(
23632365
# This primitive gets transformed in a lowering pass to
23642366
# lower-level IR ops using a custom transform function.
23652367

2368+
# Evaluate argument and return types for generic primitives
2369+
return_type = None
2370+
if desc.type_params is not None:
2371+
assert type_args is not None, "Generic primitive op requires explicit type arguments"
2372+
assert len(type_args) == len(desc.type_params)
2373+
arg_types = [expand_rtype(arg_type, type_args) for arg_type in desc.arg_types]
2374+
return_type = expand_rtype(desc.return_type, type_args)
2375+
else:
2376+
arg_types = desc.arg_types
2377+
23662378
coerced = []
23672379
# Coerce fixed number arguments
2368-
for i in range(min(len(args), len(desc.arg_types))):
2369-
formal_type = desc.arg_types[i]
2380+
for i in range(min(len(args), len(arg_types))):
2381+
formal_type = arg_types[i]
23702382
arg = args[i]
23712383
assert formal_type is not None # TODO
23722384
arg = self.coerce(arg, formal_type, line)
23732385
coerced.append(arg)
23742386
assert desc.ordering is None
23752387
assert desc.var_arg_type is None
23762388
assert not desc.extra_int_constants
2377-
target = self.add(PrimitiveOp(coerced, desc, line=line))
2389+
target = self.add(
2390+
PrimitiveOp(
2391+
coerced,
2392+
desc,
2393+
line=line,
2394+
arg_types=arg_types,
2395+
return_type=return_type,
2396+
type_args=type_args,
2397+
)
2398+
)
23782399
if desc.is_borrowed:
23792400
# If the result is borrowed, force the arguments to be
23802401
# kept alive afterwards, as otherwise the result might be

mypyc/irbuild/vec.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
vec_api_by_item_type,
5353
vec_item_type_tags,
5454
)
55+
from mypyc.primitives.librt_vecs_ops import vec_get_item_unsafe_borrow_op, vec_get_item_unsafe_op
5556

5657
if TYPE_CHECKING:
5758
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
@@ -316,9 +317,10 @@ def vec_get_item(
316317
) -> Value:
317318
"""Generate inlined vec __getitem__ call.
318319
319-
We inline this, since it's simple but performance-critical.
320+
We inline the length and bounds check, since they are simple but
321+
performance-critical. The actual item load is emitted as a generic primitive
322+
op that is lowered later.
320323
"""
321-
# TODO: Support more item types
322324
# TODO: Support more index types
323325
len_val = vec_len(builder, base)
324326
index = vec_check_and_adjust_index(builder, len_val, index, line)
@@ -328,7 +330,22 @@ def vec_get_item(
328330
def vec_get_item_unsafe(
329331
builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False
330332
) -> Value:
331-
"""Get vec item, assuming index is non-negative and within bounds."""
333+
"""Get vec item, assuming index is non-negative and within bounds.
334+
335+
This emits a generic primitive op that is inlined during lowering.
336+
"""
337+
assert isinstance(base.type, RVec)
338+
if can_borrow:
339+
desc = vec_get_item_unsafe_borrow_op
340+
else:
341+
desc = vec_get_item_unsafe_op
342+
return builder.primitive_op(desc, [base, index], line, type_args=[base.type.item_type])
343+
344+
345+
def vec_get_item_unsafe_lower(
346+
builder: LowLevelIRBuilder, base: Value, index: Value, line: int, *, can_borrow: bool = False
347+
) -> Value:
348+
"""Generate the low-level IR for an unsafe vec item load."""
332349
assert isinstance(base.type, RVec)
333350
index = as_platform_int(builder, index, line)
334351
vtype = base.type

mypyc/lower/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def wrapper(f: LF) -> LF:
2626

2727

2828
# Import various modules that set up global state.
29-
from mypyc.lower import int_ops, list_ops, misc_ops # noqa: F401
29+
from mypyc.lower import int_ops, list_ops, misc_ops, vec_ops # noqa: F401

mypyc/lower/vec_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import annotations
2+
3+
from mypyc.ir.ops import Value
4+
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
5+
from mypyc.irbuild.vec import vec_get_item_unsafe_lower
6+
from mypyc.lower.registry import lower_primitive_op
7+
8+
9+
@lower_primitive_op("vec_get_item_unsafe")
10+
def vec_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
11+
base, index = args
12+
return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=False)
13+
14+
15+
@lower_primitive_op("vec_get_item_unsafe_borrow")
16+
def vec_get_item_unsafe_borrow(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
17+
base, index = args
18+
return vec_get_item_unsafe_lower(builder, base, index, line, can_borrow=True)

mypyc/primitives/librt_vecs_ops.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from mypyc.ir.deps import LIBRT_VECS, VECS_EXTRA_OPS
22
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
33
from mypyc.ir.rtypes import (
4+
RTypeVar,
45
RVec,
56
bit_rprimitive,
67
bytes_rprimitive,
8+
int64_rprimitive,
79
object_rprimitive,
810
uint8_rprimitive,
911
)
10-
from mypyc.primitives.registry import function_op
12+
from mypyc.primitives.registry import custom_primitive_op, function_op
1113

1214
# isinstance(obj, vec)
1315
isinstance_vec = function_op(
@@ -28,3 +30,22 @@
2830
error_kind=ERR_MAGIC,
2931
dependencies=[LIBRT_VECS, VECS_EXTRA_OPS],
3032
)
33+
34+
# Get vec item, assuming the index is valid (no bounds check)
35+
vec_get_item_unsafe_op = custom_primitive_op(
36+
name="vec_get_item_unsafe",
37+
arg_types=[RVec(RTypeVar(0)), int64_rprimitive],
38+
return_type=RTypeVar(0),
39+
error_kind=ERR_NEVER,
40+
type_params=[RTypeVar(0)],
41+
)
42+
43+
# Like vec_get_item_unsafe, but the result is a borrowed reference
44+
vec_get_item_unsafe_borrow_op = custom_primitive_op(
45+
name="vec_get_item_unsafe_borrow",
46+
arg_types=[RVec(RTypeVar(0)), int64_rprimitive],
47+
is_borrowed=True,
48+
return_type=RTypeVar(0),
49+
error_kind=ERR_NEVER,
50+
type_params=[RTypeVar(0)],
51+
)

mypyc/primitives/registry.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from mypyc.ir.deps import Dependency
4343
from mypyc.ir.ops import PrimitiveDescription, StealsDescription
44-
from mypyc.ir.rtypes import RType
44+
from mypyc.ir.rtypes import RType, RTypeVar
4545

4646
# Error kind for functions that return negative integer on exception. This
4747
# is only used for primitives. We translate it away during IR building.
@@ -154,6 +154,7 @@ def method_op(
154154
is_pure=is_pure,
155155
experimental=experimental,
156156
dependencies=dependencies,
157+
type_params=None,
157158
)
158159
ops.append(desc)
159160
return desc
@@ -204,6 +205,7 @@ def function_op(
204205
is_pure=False,
205206
experimental=experimental,
206207
dependencies=dependencies,
208+
type_params=None,
207209
)
208210
ops.append(desc)
209211
return desc
@@ -253,6 +255,7 @@ def binary_op(
253255
is_pure=False,
254256
experimental=False,
255257
dependencies=dependencies,
258+
type_params=None,
256259
)
257260
ops.append(desc)
258261
return desc
@@ -313,6 +316,7 @@ def custom_primitive_op(
313316
is_pure: bool = False,
314317
experimental: bool = False,
315318
dependencies: list[Dependency] | None = None,
319+
type_params: list[RTypeVar] | None = None,
316320
) -> PrimitiveDescription:
317321
"""Define a primitive op that can't be automatically generated based on the AST.
318322
@@ -336,6 +340,7 @@ def custom_primitive_op(
336340
is_pure=is_pure,
337341
experimental=experimental,
338342
dependencies=dependencies,
343+
type_params=type_params,
339344
)
340345

341346

@@ -380,6 +385,7 @@ def unary_op(
380385
is_pure=is_pure,
381386
experimental=False,
382387
dependencies=dependencies,
388+
type_params=None,
383389
)
384390
ops.append(desc)
385391
return desc

0 commit comments

Comments
 (0)