Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions polyfile/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def should_break(
parent_match: Optional[TestResult],
result: Optional[TestResult]
) -> bool:
return self.pattern.is_contained_in(test.mimetypes())
return self.pattern.is_contained_in(test.mimetypes)

@classmethod
def parse(cls: Type[B], command: str) -> Optional[B]:
Expand Down Expand Up @@ -183,7 +183,7 @@ def should_break(
parent_match: Optional[TestResult],
result: Optional[TestResult]
) -> bool:
return self.ext in test.all_extensions()
return self.ext in test.all_extensions

@classmethod
def parse(cls: Type[B], command: str) -> Optional[B]:
Expand Down
46 changes: 28 additions & 18 deletions polyfile/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
import csv
import functools
from datetime import datetime
from enum import Enum, IntFlag
from importlib import resources
Expand Down Expand Up @@ -802,24 +803,31 @@ def ancestors(self) -> Iterator["MagicTest"]:
stack.append(test.parent)
history.add(test.parent)

def descendants(self) -> Iterator["MagicTest"]:
"""
Yields all descendants of this test.
UseTests will also include all referenced NamedTests and their descendants.

"""
def _compute_descendants(self) -> Tuple["MagicTest", ...]:
"""Compute all descendants of this test (internal, called once)."""
result: List[MagicTest] = []
stack: List[MagicTest] = [self]
history: Set[MagicTest] = set(stack)
while stack:
test = stack.pop()
if test is not self:
yield test
result.append(test)
new_tests = [child for child in test.children if child not in history]
stack.extend(reversed(new_tests))
history |= set(new_tests)
if isinstance(test, UseTest):
stack.append(test.referenced_test)
history.add(test.referenced_test)
return tuple(result)

@functools.cached_property
def descendants(self) -> Tuple["MagicTest", ...]:
"""
Returns all descendants of this test (cached).
UseTests will also include all referenced NamedTests and their descendants.

"""
return self._compute_descendants()

def referenced_tests(self) -> Set["NamedTest"]:
result: Set[NamedTest] = set()
Expand Down Expand Up @@ -853,29 +861,31 @@ def _mimetypes(self) -> Iterator[str]:
if self.mime is not None:
yielded |= set(self.mime.possibilities())
yield from yielded
for d in self.descendants():
for d in self.descendants:
if d.mime is not None:
possibilities = set(d.mime.possibilities())
new_mimes = possibilities - yielded
yield from new_mimes
yielded |= new_mimes

def mimetypes(self) -> LazyIterableSet[str]:
"""Returns the set of all possible MIME types that this test or any of its descendants could match against"""
return LazyIterableSet(self._mimetypes())
@functools.cached_property
def mimetypes(self) -> Tuple[str, ...]:
"""Returns all possible MIME types that this test or any of its descendants could match against"""
return tuple(self._mimetypes())

def _all_extensions(self) -> Iterator[str]:
"""Yields all possible extensions that this test or any of its descendants could match against"""
yield from self.extensions
yielded = set(self.extensions)
for d in self.descendants():
for d in self.descendants:
new_extensions = d.extensions - yielded
yield from new_extensions
yielded |= new_extensions

def all_extensions(self) -> LazyIterableSet[str]:
"""Returns the set of all possible extensions that this test or any of its descendants could match against"""
return LazyIterableSet(self._all_extensions())
@functools.cached_property
def all_extensions(self) -> Tuple[str, ...]:
"""Returns all possible extensions that this test or any of its descendants could match against"""
return tuple(self._all_extensions())

@abstractmethod
def test(self, data: bytes, absolute_offset: int, parent_match: Optional[TestResult]) -> TestResult:
Expand Down Expand Up @@ -2695,9 +2705,9 @@ def _reassign_test_types(self):
self._non_text_tests.add(test)
if test.can_be_indirect:
self._tests_that_can_be_indirect.add(test)
for mime in test.mimetypes():
for mime in test.mimetypes:
self._tests_by_mime[mime].add(test)
for ext in test.all_extensions():
for ext in test.all_extensions:
self._tests_by_ext[ext].add(test)

def only_match(
Expand All @@ -2716,7 +2726,7 @@ def only_match(
return self
tests: Set[MagicTest] = {
indirect_test for indirect_test in self.tests_that_can_be_indirect
if not any(True for _ in indirect_test.mimetypes())
if not any(True for _ in indirect_test.mimetypes)
}
if mimetypes is not None:
for mime in mimetypes:
Expand Down
2 changes: 1 addition & 1 deletion polyfile/polyfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def handle_mimetype(
length = len(data) - offset
extension: Optional[str] = None
try:
extension = next(iter(match_obj.test.all_extensions()))
extension = next(iter(match_obj.test.all_extensions))
except StopIteration:
pass
m = Match(
Expand Down