Skip to content

Commit 3311444

Browse files
authored
Merge pull request #3440 from trailofbits/perf/cache-test-types
perf: cache descendants, mimetypes, and extensions in MagicTest
2 parents: 537c10e + 0b00346; commit 3311444

3 files changed

Lines changed: 31 additions & 21 deletions

File tree

polyfile/debugger.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ def should_break(
147147
parent_match: Optional[TestResult],
148148
result: Optional[TestResult]
149149
) -> bool:
150-
return self.pattern.is_contained_in(test.mimetypes())
150+
return self.pattern.is_contained_in(test.mimetypes)
151151

152152
@classmethod
153153
def parse(cls: Type[B], command: str) -> Optional[B]:
@@ -183,7 +183,7 @@ def should_break(
183183
parent_match: Optional[TestResult],
184184
result: Optional[TestResult]
185185
) -> bool:
186-
return self.ext in test.all_extensions()
186+
return self.ext in test.all_extensions
187187

188188
@classmethod
189189
def parse(cls: Type[B], command: str) -> Optional[B]:

polyfile/magic.py

Lines changed: 28 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from abc import ABC, abstractmethod
1212
from collections import defaultdict
1313
import csv
14+
import functools
1415
from datetime import datetime
1516
from enum import Enum, IntFlag
1617
from importlib import resources
@@ -802,24 +803,31 @@ def ancestors(self) -> Iterator["MagicTest"]:
802803
stack.append(test.parent)
803804
history.add(test.parent)
804805

805-
def descendants(self) -> Iterator["MagicTest"]:
806-
"""
807-
Yields all descendants of this test.
808-
UseTests will also include all referenced NamedTests and their descendants.
809-
810-
"""
806+
def _compute_descendants(self) -> Tuple["MagicTest", ...]:
807+
"""Compute all descendants of this test (internal, called once)."""
808+
result: List[MagicTest] = []
811809
stack: List[MagicTest] = [self]
812810
history: Set[MagicTest] = set(stack)
813811
while stack:
814812
test = stack.pop()
815813
if test is not self:
816-
yield test
814+
result.append(test)
817815
new_tests = [child for child in test.children if child not in history]
818816
stack.extend(reversed(new_tests))
819817
history |= set(new_tests)
820818
if isinstance(test, UseTest):
821819
stack.append(test.referenced_test)
822820
history.add(test.referenced_test)
821+
return tuple(result)
822+
823+
@functools.cached_property
824+
def descendants(self) -> Tuple["MagicTest", ...]:
825+
"""
826+
Returns all descendants of this test (cached).
827+
UseTests will also include all referenced NamedTests and their descendants.
828+
829+
"""
830+
return self._compute_descendants()
823831

824832
def referenced_tests(self) -> Set["NamedTest"]:
825833
result: Set[NamedTest] = set()
@@ -853,29 +861,31 @@ def _mimetypes(self) -> Iterator[str]:
853861
if self.mime is not None:
854862
yielded |= set(self.mime.possibilities())
855863
yield from yielded
856-
for d in self.descendants():
864+
for d in self.descendants:
857865
if d.mime is not None:
858866
possibilities = set(d.mime.possibilities())
859867
new_mimes = possibilities - yielded
860868
yield from new_mimes
861869
yielded |= new_mimes
862870

863-
def mimetypes(self) -> LazyIterableSet[str]:
864-
"""Returns the set of all possible MIME types that this test or any of its descendants could match against"""
865-
return LazyIterableSet(self._mimetypes())
871+
@functools.cached_property
872+
def mimetypes(self) -> Tuple[str, ...]:
873+
"""Returns all possible MIME types that this test or any of its descendants could match against"""
874+
return tuple(self._mimetypes())
866875

867876
def _all_extensions(self) -> Iterator[str]:
868877
"""Yields all possible extensions that this test or any of its descendants could match against"""
869878
yield from self.extensions
870879
yielded = set(self.extensions)
871-
for d in self.descendants():
880+
for d in self.descendants:
872881
new_extensions = d.extensions - yielded
873882
yield from new_extensions
874883
yielded |= new_extensions
875884

876-
def all_extensions(self) -> LazyIterableSet[str]:
877-
"""Returns the set of all possible extensions that this test or any of its descendants could match against"""
878-
return LazyIterableSet(self._all_extensions())
885+
@functools.cached_property
886+
def all_extensions(self) -> Tuple[str, ...]:
887+
"""Returns all possible extensions that this test or any of its descendants could match against"""
888+
return tuple(self._all_extensions())
879889

880890
@abstractmethod
881891
def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult:
@@ -2695,9 +2705,9 @@ def _reassign_test_types(self):
26952705
self._non_text_tests.add(test)
26962706
if test.can_be_indirect:
26972707
self._tests_that_can_be_indirect.add(test)
2698-
for mime in test.mimetypes():
2708+
for mime in test.mimetypes:
26992709
self._tests_by_mime[mime].add(test)
2700-
for ext in test.all_extensions():
2710+
for ext in test.all_extensions:
27012711
self._tests_by_ext[ext].add(test)
27022712

27032713
def only_match(
@@ -2716,7 +2726,7 @@ def only_match(
27162726
return self
27172727
tests: Set[MagicTest] = {
27182728
indirect_test for indirect_test in self.tests_that_can_be_indirect
2719-
if not any(True for _ in indirect_test.mimetypes())
2729+
if not any(True for _ in indirect_test.mimetypes)
27202730
}
27212731
if mimetypes is not None:
27222732
for mime in mimetypes:

polyfile/polyfile.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -230,7 +230,7 @@ def handle_mimetype(
230230
length = len(data) - offset
231231
extension: Optional[str] = None
232232
try:
233-
extension = next(iter(match_obj.test.all_extensions()))
233+
extension = next(iter(match_obj.test.all_extensions))
234234
except StopIteration:
235235
pass
236236
m = Match(

0 commit comments

Comments
 (0)