Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
import warnings
from copy import copy
from functools import lru_cache, wraps
from functools import wraps
from types import EllipsisType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -39,12 +39,11 @@
if TYPE_CHECKING:
from collections.abc import Callable


_logger: Final = _log_utils.get_logger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

_logger: Final = _log_utils.get_logger(__name__)


class DLTypeAnnotation(NamedTuple):
"""A class representing a type annotation for a tensor."""
Expand Down Expand Up @@ -130,7 +129,6 @@ def from_hint( # noqa: PLR0911
return (cls(tensor_type_hint=tensor_type, dltype_annotation=dltype_hint),)


@lru_cache()
def _resolve_types(
annotations: tuple[DLTypeAnnotation | None, ...] | None,
) -> tuple[_tensor_type_base.TensorTypeBase | None, ...] | None:
Expand Down Expand Up @@ -165,7 +163,6 @@ def _maybe_get_type_hints(
return None


@lru_cache()
def _maybe_get_signature(
existing: inspect.Signature | None,
func: Callable[P, R],
Expand Down
76 changes: 45 additions & 31 deletions dltype/_lib/_dltype_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
from typing import Any, Final, NamedTuple, TypeAlias, cast

from dltype._lib import _constants, _dtypes, _errors, _log_utils, _parser, _tensor_type_base
from dltype._lib import _dependency_utilities as _deps

if _deps.is_torch_available():
from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage]
else:

class _NullWarning(Warning):
pass

TracerWarning = _NullWarning
_logger: Final = _log_utils.get_logger(__name__)

EvaluatedDimensionT: TypeAlias = dict[str, int]
Expand Down Expand Up @@ -123,43 +132,48 @@ def assert_context(self) -> None:
"""Considering the current context, check if all tensors match their expected types."""
__tracebackhide__ = not _constants.DEBUG_MODE

start_t = time.perf_counter_ns()
with warnings.catch_warnings():
warnings.simplefilter(category=TracerWarning, action="ignore")

try:
while self._hinted_tensors:
tensor_context = self._hinted_tensors.popleft()
# first check if the tensor could possibly have the right shape
tensor_context.dltype_annotation.check(
tensor_context.tensor,
tensor_name=tensor_context.tensor_arg_name,
)
start_t = time.perf_counter_ns()

if tensor_context.tensor_arg_name in self.registered_tensor_dtypes:
raise _errors.DLTypeDuplicateError(
try:
while self._hinted_tensors:
tensor_context = self._hinted_tensors.popleft()
# first check if the tensor could possibly have the right shape
tensor_context.dltype_annotation.check(
tensor_context.tensor,
tensor_name=tensor_context.tensor_arg_name,
)

self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype
expected_shape = tensor_context.get_expected_shape(
tensor_context.tensor,
)
self._assert_tensor_shape(
tensor_context.tensor_arg_name,
expected_shape,
tensor_context.tensor,
)
if tensor_context.tensor_arg_name in self.registered_tensor_dtypes:
raise _errors.DLTypeDuplicateError(
tensor_name=tensor_context.tensor_arg_name,
)

finally:
end_t = time.perf_counter_ns()
runtime_ns = end_t - start_t
_logger.debug("Context evaluation took %d ns", runtime_ns)
if _maybe_warn_runtime(runtime_ns):
max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6
warnings.warn(
f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms",
UserWarning,
stacklevel=2,
)
self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = (
tensor_context.tensor.dtype
)
expected_shape = tensor_context.get_expected_shape(
tensor_context.tensor,
)
self._assert_tensor_shape(
tensor_context.tensor_arg_name,
expected_shape,
tensor_context.tensor,
)

finally:
end_t = time.perf_counter_ns()
runtime_ns = end_t - start_t
_logger.debug("Context evaluation took %d ns", runtime_ns)
if _maybe_warn_runtime(runtime_ns):
max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6
warnings.warn(
f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms",
UserWarning,
stacklevel=2,
)

def _assert_tensor_shape(
self,
Expand Down
2 changes: 1 addition & 1 deletion dltype/_lib/_tensor_type_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def validate_tensor(
return core_schema.with_info_after_validator_function(
validate_tensor,
schema=core_schema.is_instance_schema(source_type),
field_name=handler.field_name,
)

def check(
Expand All @@ -164,6 +163,7 @@ def check(
"""Check if the tensor matches this type."""
# Basic validation for multi-axis dimensions
__tracebackhide__ = not _constants.DEBUG_MODE

if self.multiaxis_index is not None:
# Min required dimensions = expected shape length + extra dimensions - 1 (the multi-axis placeholder)
min_required_dims = len(self.expected_shape) - 1
Expand Down
Loading
Loading