Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ unreleased: Version 0.18.0
- Replace string literal type annotations with postponed evaluation using
``from __future__ import annotations`` PR #191

Breaking changes:

- ``BasicBlock``, ``Bytecode``, and ``ConcreteBytecode`` now validate inserted
instructions at insertion time (``append``, ``extend``, ``insert``,
``__setitem__``) rather than during iteration. Code that relied on catching
``ValueError`` from ``list(block)`` or ``for instr in block:`` must wrap the
insertion call instead. PR #199

Bugfixes:

- Fix handling of END_ASYNC_FOR which is a backward jump PR #179
Expand Down
30 changes: 25 additions & 5 deletions src/bytecode/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,30 @@ def legalize(self) -> None:
def _check_instr(self, instr):
raise NotImplementedError()

def append(self, instr: U) -> None: # type: ignore[override]
self._check_instr(instr)
super().append(instr)

def insert(self, index: SupportsIndex, instr: U) -> None: # type: ignore[override]
self._check_instr(instr)
super().insert(index, instr)

def extend(self, instrs) -> None: # type: ignore[override]
instrs = list(instrs)
for instr in instrs:
self._check_instr(instr)
super().extend(instrs)

def __setitem__(self, index, value):
if isinstance(index, slice):
values = list(value)
for v in values:
self._check_instr(v)
super().__setitem__(index, values)
else:
self._check_instr(value)
super().__setitem__(index, value)


V = TypeVar("V")

Expand Down Expand Up @@ -236,15 +260,11 @@ def __init__(
) -> None:
BaseBytecode.__init__(self)
self.argnames: List[str] = []
for instr in instructions:
self._check_instr(instr)
self.extend(instructions)

def __iter__(self) -> Iterator[Union[Instr, Label, TryBegin, TryEnd, SetLineno]]:
instructions = super().__iter__()
seen_try_begin = False
for instr in instructions:
self._check_instr(instr)
for instr in super().__iter__():
if isinstance(instr, TryBegin):
if seen_try_begin:
raise RuntimeError("TryBegin pseudo instructions cannot be nested.")
Expand Down
45 changes: 24 additions & 21 deletions src/bytecode/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def _check_instr(instr: Any) -> None:

def append(self, instr: Union[Instr, SetLineno, TryBegin, TryEnd]) -> None:
self._check_instr(instr)
if isinstance(instr, Instr):
last = self.get_last_non_artificial_instruction()
if last is not None and last.has_jump():
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)
super().append(instr)

def insert(
Expand All @@ -68,6 +74,17 @@ def extend(
instrs = list(instrs)
for instr in instrs:
self._check_instr(instr)
existing_last = self.get_last_non_artificial_instruction()
last_new_instr: Optional[Instr] = None
for instr in instrs:
if isinstance(instr, Instr):
if (existing_last is not None and existing_last.has_jump()) or (
last_new_instr is not None and last_new_instr.has_jump()
):
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)
last_new_instr = instr
super().extend(instrs)

def __setitem__(self, index, value):
Expand All @@ -81,32 +98,19 @@ def __setitem__(self, index, value):
super().__setitem__(index, value)

def __iter__(self) -> Iterator[Union[Instr, SetLineno, TryBegin, TryEnd]]:
index = 0
while index < len(self):
instr = self[index]
index += 1

for instr in super().__iter__():
if isinstance(instr, Instr) and instr.has_jump():
if index < len(self) and any(
isinstance(self[i], Instr) for i in range(index, len(self))
):
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)

if not isinstance(instr.arg, BasicBlock):
raise ValueError(
"Jump target must a BasicBlock, got %s",
type(instr.arg).__name__,
"Jump target must a BasicBlock, got %s"
% type(instr.arg).__name__
)

if isinstance(instr, TryBegin):
elif isinstance(instr, TryBegin):
if not isinstance(instr.target, BasicBlock):
raise ValueError(
"TryBegin target must a BasicBlock, got %s",
type(instr.target).__name__,
"TryBegin target must a BasicBlock, got %s"
% type(instr.target).__name__
)

yield instr

@overload
Expand All @@ -126,10 +130,9 @@ def __getitem__(self, index):
return value

def get_last_non_artificial_instruction(self) -> Optional[Instr]:
for instr in reversed(self):
for instr in super().__reversed__():
if isinstance(instr, Instr):
return instr

return None

def copy(self: T) -> T:
Expand Down
48 changes: 25 additions & 23 deletions src/bytecode/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,8 @@ def __init__(
self.names = list(names)
self.varnames = list(varnames)
self.exception_table = exception_table or []
for instr in instructions:
self._check_instr(instr)
self.extend(instructions)

def __iter__(self) -> Iterator[Union[ConcreteInstr, SetLineno]]:
instructions = super().__iter__()
for instr in instructions:
self._check_instr(instr)
yield instr

def _check_instr(self, instr: Any) -> None:
if not isinstance(instr, (ConcreteInstr, SetLineno)):
raise ValueError(
Expand Down Expand Up @@ -1036,7 +1028,9 @@ def __init__(self, code: _bytecode.Bytecode) -> None:
self.consts_indices: dict[bytes | Tuple[type, int], int] = {}
self.consts_list: list[Any] = []
self.names: list[str] = []
self.names_map: dict[str, int] = {}
self.varnames: list[str] = []
self.varnames_map: dict[str, int] = {}

def add_const(self, value: Any) -> int:
key = const_key(value)
Expand All @@ -1047,13 +1041,20 @@ def add_const(self, value: Any) -> int:
self.consts_list.append(value)
return index

@staticmethod
def add(names: list[str], name: str) -> int:
try:
index = names.index(name)
except ValueError:
index = len(names)
names.append(name)
def add_name(self, name: str) -> int:
index = self.names_map.get(name)
if index is None:
index = len(self.names)
self.names_map[name] = index
self.names.append(name)
return index

def add_varname(self, name: str) -> int:
index = self.varnames_map.get(name)
if index is None:
index = len(self.varnames)
self.varnames_map[name] = index
self.varnames.append(name)
return index

def concrete_instructions(self) -> None:
Expand All @@ -1074,7 +1075,7 @@ def concrete_instructions(self) -> None:
assert isinstance(binstr.arg, tuple)
for parg in binstr.arg:
assert isinstance(parg, str)
self.add(self.varnames, parg)
self.add_varname(parg)

# We use None as a sentinel to ensure caches for the last instruction are
# properly generated.
Expand Down Expand Up @@ -1158,8 +1159,8 @@ def concrete_instructions(self) -> None:
elif opcode in HAS_LOCAL:
if opcode in DUAL_ARG_OPCODES:
_arg2 = cast(Tuple[str, str], arg)
arg1_index = self.add(self.varnames, _arg2[0])
arg2_index = self.add(self.varnames, _arg2[1])
arg1_index = self.add_varname(_arg2[0])
arg2_index = self.add_varname(_arg2[1])
if arg1_index > 16 or arg2_index > 16:
n1, n2 = DUAL_ARG_OPCODES_SINGLE_OPS[opcode]
c_instr = ConcreteInstr(n1, arg1_index, location=location)
Expand All @@ -1176,7 +1177,7 @@ def concrete_instructions(self) -> None:
c_arg = self.bytecode.freevars.index(arg.name)
else:
assert isinstance(arg, str)
c_arg = self.add(self.varnames, arg)
c_arg = self.add_varname(arg)
elif opcode in HAS_NAME:
if opcode in BITFLAG_OPCODES:
assert (
Expand All @@ -1185,19 +1186,19 @@ def concrete_instructions(self) -> None:
and isinstance(arg[0], bool)
), arg
if isinstance(arg[1], str):
index = self.add(self.names, arg[1])
index = self.add_name(arg[1])
elif isinstance(arg, FormatValue):
index = int(arg)
else:
assert False, arg # noqa
c_arg = int(arg[0]) + (index << 1)
elif opcode in BITFLAG2_OPCODES:
_arg3 = cast(tuple[bool, bool, str], arg)
index = self.add(self.names, _arg3[2])
index = self.add_name(_arg3[2])
c_arg = int(_arg3[0]) + 2 * int(_arg3[1]) + (index << 2)
else:
assert isinstance(arg, str), f"Got {arg}, expected a str"
c_arg = self.add(self.names, arg)
c_arg = self.add_name(arg)
elif opcode in HAS_FREE:
if isinstance(arg, CellVar):
cell_instrs.append(len(self.instructions))
Expand Down Expand Up @@ -1342,7 +1343,8 @@ def to_concrete_bytecode(
if first_const is not UNSET:
self.add_const(first_const)

self.varnames.extend(self.bytecode.argnames)
for name in self.bytecode.argnames:
self.add_varname(name)

self.concrete_instructions()
for _ in range(0, compute_jumps_passes):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ def test_constructor(self):

def test_invalid_types(self):
code = Bytecode()
code.append(123)
with self.assertRaises(ValueError):
list(code)
code.append(123)
with self.assertRaises(ValueError):
code.legalize()
code.extend([123])
with self.assertRaises(ValueError):
code.insert(0, 123)
with self.assertRaises(ValueError):
Bytecode([123])

Expand Down
33 changes: 24 additions & 9 deletions tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Instr,
Label,
SetLineno,
TryBegin,
dump_bytecode,
)
from bytecode.utils import PY312, PY313, PY314
Expand Down Expand Up @@ -54,7 +55,7 @@ def disassemble(


class BlockTests(unittest.TestCase):
def test_iter_invalid_types(self):
def test_inserting_invalid_types(self):
# Labels are not allowed in basic blocks — caught at insertion time
block = BasicBlock()
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -85,16 +86,24 @@ def test_iter_invalid_types(self):
# Only one jump allowed and only at the end
block = BasicBlock()
block2 = BasicBlock()
block.extend(
[
Instr("JUMP_FORWARD", block2),
Instr("NOP"),
]
)
# caught at extend time (within batch)
with self.assertRaises(ValueError):
list(block)
block.extend(
[
Instr("JUMP_FORWARD", block2),
Instr("NOP"),
]
)
# caught at append time (cross-boundary)
block = BasicBlock()
block.append(Instr("JUMP_FORWARD", block2))
with self.assertRaises(ValueError):
block.legalize(1)
block.append(Instr("NOP"))
# caught at extend time (cross-boundary)
block = BasicBlock()
block.append(Instr("JUMP_FORWARD", block2))
with self.assertRaises(ValueError):
block.extend([Instr("NOP")])

# jump target must be a BasicBlock
block = BasicBlock()
Expand All @@ -105,6 +114,12 @@ def test_iter_invalid_types(self):
with self.assertRaises(ValueError):
block.legalize(1)

# TryBegin target must be a BasicBlock
block = BasicBlock()
block.extend([TryBegin(label, push_lasti=False)])
with self.assertRaises(ValueError):
list(block)

def test_slice(self):
block = BasicBlock([Instr("NOP")])
next_block = BasicBlock()
Expand Down
7 changes: 4 additions & 3 deletions tests/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,12 @@ def test_attr(self):

def test_invalid_types(self):
code = ConcreteBytecode()
code.append(Label())
with self.assertRaises(ValueError):
list(code)
code.append(Label())
with self.assertRaises(ValueError):
code.legalize()
code.extend([Label()])
with self.assertRaises(ValueError):
code.insert(0, Label())
with self.assertRaises(ValueError):
ConcreteBytecode([Label()])

Expand Down