Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ unreleased: Version 0.18.0
- Replace string literal type annotations with postponed evaluation using
``from __future__ import annotations`` PR #191

Breaking changes:

- ``BasicBlock``, ``Bytecode``, and ``ConcreteBytecode`` now validate inserted
instructions at insertion time (``append``, ``extend``, ``insert``,
``__setitem__``) rather than during iteration. Code that relied on catching
``ValueError`` from ``list(block)`` or ``for instr in block:`` must wrap the
insertion call instead. PR #199

Bugfixes:

- Fix handling of END_ASYNC_FOR which is a backward jump PR #179
Expand Down
30 changes: 25 additions & 5 deletions src/bytecode/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,30 @@ def legalize(self) -> None:
def _check_instr(self, instr):
raise NotImplementedError()

def append(self, instr: U) -> None: # type: ignore[override]
self._check_instr(instr)
super().append(instr)

def insert(self, index: SupportsIndex, instr: U) -> None: # type: ignore[override]
self._check_instr(instr)
super().insert(index, instr)

def extend(self, instrs) -> None: # type: ignore[override]
instrs = list(instrs)
for instr in instrs:
self._check_instr(instr)
super().extend(instrs)

def __setitem__(self, index, value):
if isinstance(index, slice):
values = list(value)
for v in values:
self._check_instr(v)
super().__setitem__(index, values)
else:
self._check_instr(value)
super().__setitem__(index, value)


V = TypeVar("V")

Expand Down Expand Up @@ -236,15 +260,11 @@ def __init__(
) -> None:
BaseBytecode.__init__(self)
self.argnames: List[str] = []
for instr in instructions:
self._check_instr(instr)
self.extend(instructions)

def __iter__(self) -> Iterator[Union[Instr, Label, TryBegin, TryEnd, SetLineno]]:
instructions = super().__iter__()
seen_try_begin = False
for instr in instructions:
self._check_instr(instr)
for instr in super().__iter__():
if isinstance(instr, TryBegin):
if seen_try_begin:
raise RuntimeError("TryBegin pseudo instructions cannot be nested.")
Expand Down
45 changes: 24 additions & 21 deletions src/bytecode/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def _check_instr(instr: Any) -> None:

def append(self, instr: Union[Instr, SetLineno, TryBegin, TryEnd]) -> None:
self._check_instr(instr)
if isinstance(instr, Instr):
last = self.get_last_non_artificial_instruction()
if last is not None and last.has_jump():
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)
super().append(instr)

def insert(
Expand All @@ -68,6 +74,17 @@ def extend(
instrs = list(instrs)
for instr in instrs:
self._check_instr(instr)
existing_last = self.get_last_non_artificial_instruction()
last_new_instr: Optional[Instr] = None
for instr in instrs:
if isinstance(instr, Instr):
if (existing_last is not None and existing_last.has_jump()) or (
last_new_instr is not None and last_new_instr.has_jump()
):
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)
last_new_instr = instr
super().extend(instrs)

def __setitem__(self, index, value):
Expand All @@ -81,32 +98,19 @@ def __setitem__(self, index, value):
super().__setitem__(index, value)

def __iter__(self) -> Iterator[Union[Instr, SetLineno, TryBegin, TryEnd]]:
index = 0
while index < len(self):
instr = self[index]
index += 1

for instr in super().__iter__():
if isinstance(instr, Instr) and instr.has_jump():
if index < len(self) and any(
isinstance(self[i], Instr) for i in range(index, len(self))
):
raise ValueError(
"Only the last instruction of a basic block can be a jump"
)

if not isinstance(instr.arg, BasicBlock):
raise ValueError(
"Jump target must a BasicBlock, got %s",
type(instr.arg).__name__,
"Jump target must a BasicBlock, got %s"
% type(instr.arg).__name__
)

if isinstance(instr, TryBegin):
elif isinstance(instr, TryBegin):
if not isinstance(instr.target, BasicBlock):
raise ValueError(
"TryBegin target must a BasicBlock, got %s",
type(instr.target).__name__,
"TryBegin target must a BasicBlock, got %s"
% type(instr.target).__name__
)

yield instr

@overload
Expand All @@ -126,10 +130,9 @@ def __getitem__(self, index):
return value

def get_last_non_artificial_instruction(self) -> Optional[Instr]:
for instr in reversed(self):
for instr in super().__reversed__():
if isinstance(instr, Instr):
return instr

return None

def copy(self: T) -> T:
Expand Down
48 changes: 25 additions & 23 deletions src/bytecode/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,8 @@ def __init__(
self.names = list(names)
self.varnames = list(varnames)
self.exception_table = exception_table or []
for instr in instructions:
self._check_instr(instr)
self.extend(instructions)

def __iter__(self) -> Iterator[Union[ConcreteInstr, SetLineno]]:
instructions = super().__iter__()
for instr in instructions:
self._check_instr(instr)
yield instr

def _check_instr(self, instr: Any) -> None:
if not isinstance(instr, (ConcreteInstr, SetLineno)):
raise ValueError(
Expand Down Expand Up @@ -1036,7 +1028,9 @@ def __init__(self, code: _bytecode.Bytecode) -> None:
self.consts_indices: dict[bytes | Tuple[type, int], int] = {}
self.consts_list: list[Any] = []
self.names: list[str] = []
self.names_map: dict[str, int] = {}
self.varnames: list[str] = []
self.varnames_map: dict[str, int] = {}

def add_const(self, value: Any) -> int:
key = const_key(value)
Expand All @@ -1047,13 +1041,20 @@ def add_const(self, value: Any) -> int:
self.consts_list.append(value)
return index

@staticmethod
def add(names: list[str], name: str) -> int:
try:
index = names.index(name)
except ValueError:
index = len(names)
names.append(name)
def add_name(self, name: str) -> int:
index = self.names_map.get(name)
if index is None:
index = len(self.names)
self.names_map[name] = index
self.names.append(name)
return index

def add_varname(self, name: str) -> int:
index = self.varnames_map.get(name)
if index is None:
index = len(self.varnames)
self.varnames_map[name] = index
self.varnames.append(name)
return index

def concrete_instructions(self) -> None:
Expand All @@ -1074,7 +1075,7 @@ def concrete_instructions(self) -> None:
assert isinstance(binstr.arg, tuple)
for parg in binstr.arg:
assert isinstance(parg, str)
self.add(self.varnames, parg)
self.add_varname(parg)

# We use None as a sentinel to ensure caches for the last instruction are
# properly generated.
Expand Down Expand Up @@ -1158,8 +1159,8 @@ def concrete_instructions(self) -> None:
elif opcode in HAS_LOCAL:
if opcode in DUAL_ARG_OPCODES:
_arg2 = cast(Tuple[str, str], arg)
arg1_index = self.add(self.varnames, _arg2[0])
arg2_index = self.add(self.varnames, _arg2[1])
arg1_index = self.add_varname(_arg2[0])
arg2_index = self.add_varname(_arg2[1])
if arg1_index > 16 or arg2_index > 16:
n1, n2 = DUAL_ARG_OPCODES_SINGLE_OPS[opcode]
c_instr = ConcreteInstr(n1, arg1_index, location=location)
Expand All @@ -1176,7 +1177,7 @@ def concrete_instructions(self) -> None:
c_arg = self.bytecode.freevars.index(arg.name)
else:
assert isinstance(arg, str)
c_arg = self.add(self.varnames, arg)
c_arg = self.add_varname(arg)
elif opcode in HAS_NAME:
if opcode in BITFLAG_OPCODES:
assert (
Expand All @@ -1185,19 +1186,19 @@ def concrete_instructions(self) -> None:
and isinstance(arg[0], bool)
), arg
if isinstance(arg[1], str):
index = self.add(self.names, arg[1])
index = self.add_name(arg[1])
elif isinstance(arg, FormatValue):
index = int(arg)
else:
assert False, arg # noqa
c_arg = int(arg[0]) + (index << 1)
elif opcode in BITFLAG2_OPCODES:
_arg3 = cast(tuple[bool, bool, str], arg)
index = self.add(self.names, _arg3[2])
index = self.add_name(_arg3[2])
c_arg = int(_arg3[0]) + 2 * int(_arg3[1]) + (index << 2)
else:
assert isinstance(arg, str), f"Got {arg}, expected a str"
c_arg = self.add(self.names, arg)
c_arg = self.add_name(arg)
elif opcode in HAS_FREE:
if isinstance(arg, CellVar):
cell_instrs.append(len(self.instructions))
Expand Down Expand Up @@ -1342,7 +1343,8 @@ def to_concrete_bytecode(
if first_const is not UNSET:
self.add_const(first_const)

self.varnames.extend(self.bytecode.argnames)
for name in self.bytecode.argnames:
self.add_varname(name)

self.concrete_instructions()
for _ in range(0, compute_jumps_passes):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ def test_constructor(self):

def test_invalid_types(self):
code = Bytecode()
code.append(123)
with self.assertRaises(ValueError):
list(code)
code.append(123)
with self.assertRaises(ValueError):
code.legalize()
code.extend([123])
with self.assertRaises(ValueError):
code.insert(0, 123)
with self.assertRaises(ValueError):
Bytecode([123])

Expand Down
24 changes: 16 additions & 8 deletions tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,24 @@ def test_iter_invalid_types(self):
# Only one jump allowed and only at the end
block = BasicBlock()
block2 = BasicBlock()
block.extend(
[
Instr("JUMP_FORWARD", block2),
Instr("NOP"),
]
)
# caught at extend time (within batch)
with self.assertRaises(ValueError):
list(block)
block.extend(
[
Instr("JUMP_FORWARD", block2),
Instr("NOP"),
]
)
# caught at append time (cross-boundary)
block = BasicBlock()
block.append(Instr("JUMP_FORWARD", block2))
with self.assertRaises(ValueError):
block.legalize(1)
block.append(Instr("NOP"))
# caught at extend time (cross-boundary)
block = BasicBlock()
block.append(Instr("JUMP_FORWARD", block2))
with self.assertRaises(ValueError):
block.extend([Instr("NOP")])

# jump target must be a BasicBlock
block = BasicBlock()
Expand Down
7 changes: 4 additions & 3 deletions tests/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,12 @@ def test_attr(self):

def test_invalid_types(self):
code = ConcreteBytecode()
code.append(Label())
with self.assertRaises(ValueError):
list(code)
code.append(Label())
with self.assertRaises(ValueError):
code.legalize()
code.extend([Label()])
with self.assertRaises(ValueError):
code.insert(0, Label())
with self.assertRaises(ValueError):
ConcreteBytecode([Label()])

Expand Down