Skip to content

Commit c74f01d

Browse files
committed
Perform runtime type checks
1 parent 84b6a70 commit c74f01d

2 files changed

Lines changed: 31 additions & 2 deletions

File tree

mypyc/codegen/emit.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,13 +705,25 @@ def emit_cast(
705705
self.emit_lines(f" {dest} = {src};", "else {")
706706
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
707707
self.emit_line("}")
708-
elif is_object_rprimitive(typ) or is_native_rprimitive(typ):
708+
elif is_object_rprimitive(typ):
709709
if declare_dest:
710710
self.emit_line(f"PyObject *{dest};")
711711
self.emit_arg_check(src, dest, typ, "", optional)
712712
self.emit_line(f"{dest} = {src};")
713713
if optional:
714714
self.emit_line("}")
715+
elif is_native_rprimitive(typ):
716+
# Native primitive types have type check functions of form "CPy<Name>_Check(...)".
717+
if declare_dest:
718+
self.emit_line(f"PyObject *{dest};")
719+
short_name = typ.name.rsplit(".", 1)[-1]
720+
check = f"(CPy{short_name}_Check({src}))"
721+
if likely:
722+
check = f"(likely{check})"
723+
self.emit_arg_check(src, dest, typ, check, optional)
724+
self.emit_lines(f" {dest} = {src};", "else {")
725+
self.emit_cast_error_handler(error, src, dest, typ, raise_exception)
726+
self.emit_line("}")
715727
elif isinstance(typ, RUnion):
716728
self.emit_union_cast(
717729
src, dest, typ, declare_dest, error, optional, src_type, raise_exception

mypyc/test-data/run-classes.test

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2713,7 +2713,7 @@ Player.MIN = <Player.MIN: 1>
27132713
[case testBufferRoundTrip_librt_internal]
27142714
from __future__ import annotations
27152715

2716-
from typing import Final
2716+
from typing import Final, Any
27172717
from mypy_extensions import u8
27182718
from librt.internal import (
27192719
ReadBuffer, WriteBuffer, write_bool, read_bool, write_str, read_str, write_float, read_float,
@@ -2746,6 +2746,23 @@ def test_buffer_grow() -> None:
27462746
with assertRaises(ValueError):
27472747
read_int(r)
27482748

2749+
def test_buffer_primitive_types() -> None:
2750+
a1: Any = WriteBuffer()
2751+
w: WriteBuffer = a1
2752+
write_str(w, "foo")
2753+
data = w.getvalue()
2754+
assert read_str(ReadBuffer(data)) == "foo"
2755+
a2: Any = ReadBuffer(b"foo")
2756+
with assertRaises(TypeError):
2757+
w2: WriteBuffer = a2
2758+
2759+
a3: Any = ReadBuffer(data)
2760+
r: ReadBuffer = a3
2761+
assert read_str(r) == "foo"
2762+
a4: Any = WriteBuffer()
2763+
with assertRaises(TypeError):
2764+
r2: ReadBuffer = a4
2765+
27492766
def test_buffer_roundtrip() -> None:
27502767
b: WriteBuffer | ReadBuffer
27512768
b = WriteBuffer()

0 commit comments

Comments
 (0)