Skip to content

Commit ecc87f8

Browse files
committed
fix innocuous tracer warnings in onnx export
1 parent 660996a commit ecc87f8

6 files changed

Lines changed: 147 additions & 103 deletions

File tree

dltype/_lib/_core.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import itertools
77
import warnings
88
from copy import copy
9-
from functools import lru_cache, wraps
9+
from functools import wraps
1010
from types import EllipsisType
1111
from typing import (
1212
TYPE_CHECKING,
@@ -39,12 +39,11 @@
3939
if TYPE_CHECKING:
4040
from collections.abc import Callable
4141

42-
43-
_logger: Final = _log_utils.get_logger(__name__)
44-
4542
P = ParamSpec("P")
4643
R = TypeVar("R")
4744

45+
_logger: Final = _log_utils.get_logger(__name__)
46+
4847

4948
class DLTypeAnnotation(NamedTuple):
5049
"""A class representing a type annotation for a tensor."""
@@ -130,7 +129,6 @@ def from_hint( # noqa: PLR0911
130129
return (cls(tensor_type_hint=tensor_type, dltype_annotation=dltype_hint),)
131130

132131

133-
@lru_cache()
134132
def _resolve_types(
135133
annotations: tuple[DLTypeAnnotation | None, ...] | None,
136134
) -> tuple[_tensor_type_base.TensorTypeBase | None, ...] | None:
@@ -165,7 +163,6 @@ def _maybe_get_type_hints(
165163
return None
166164

167165

168-
@lru_cache()
169166
def _maybe_get_signature(
170167
existing: inspect.Signature | None,
171168
func: Callable[P, R],

dltype/_lib/_dltype_context.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,16 @@
88
from typing import Any, Final, NamedTuple, TypeAlias, cast
99

1010
from dltype._lib import _constants, _dtypes, _errors, _log_utils, _parser, _tensor_type_base
11+
from dltype._lib import _dependency_utilities as _deps
1112

13+
if _deps.is_torch_available():
14+
from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage]
15+
else:
16+
17+
class _NullWarning(Warning):
18+
pass
19+
20+
TracerWarning = _NullWarning
1221
_logger: Final = _log_utils.get_logger(__name__)
1322

1423
EvaluatedDimensionT: TypeAlias = dict[str, int]
@@ -123,43 +132,48 @@ def assert_context(self) -> None:
123132
"""Considering the current context, check if all tensors match their expected types."""
124133
__tracebackhide__ = not _constants.DEBUG_MODE
125134

126-
start_t = time.perf_counter_ns()
135+
with warnings.catch_warnings():
136+
warnings.simplefilter(category=TracerWarning, action="ignore")
127137

128-
try:
129-
while self._hinted_tensors:
130-
tensor_context = self._hinted_tensors.popleft()
131-
# first check if the tensor could possibly have the right shape
132-
tensor_context.dltype_annotation.check(
133-
tensor_context.tensor,
134-
tensor_name=tensor_context.tensor_arg_name,
135-
)
138+
start_t = time.perf_counter_ns()
136139

137-
if tensor_context.tensor_arg_name in self.registered_tensor_dtypes:
138-
raise _errors.DLTypeDuplicateError(
140+
try:
141+
while self._hinted_tensors:
142+
tensor_context = self._hinted_tensors.popleft()
143+
# first check if the tensor could possibly have the right shape
144+
tensor_context.dltype_annotation.check(
145+
tensor_context.tensor,
139146
tensor_name=tensor_context.tensor_arg_name,
140147
)
141148

142-
self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype
143-
expected_shape = tensor_context.get_expected_shape(
144-
tensor_context.tensor,
145-
)
146-
self._assert_tensor_shape(
147-
tensor_context.tensor_arg_name,
148-
expected_shape,
149-
tensor_context.tensor,
150-
)
149+
if tensor_context.tensor_arg_name in self.registered_tensor_dtypes:
150+
raise _errors.DLTypeDuplicateError(
151+
tensor_name=tensor_context.tensor_arg_name,
152+
)
151153

152-
finally:
153-
end_t = time.perf_counter_ns()
154-
runtime_ns = end_t - start_t
155-
_logger.debug("Context evaluation took %d ns", runtime_ns)
156-
if _maybe_warn_runtime(runtime_ns):
157-
max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6
158-
warnings.warn(
159-
f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms",
160-
UserWarning,
161-
stacklevel=2,
162-
)
154+
self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = (
155+
tensor_context.tensor.dtype
156+
)
157+
expected_shape = tensor_context.get_expected_shape(
158+
tensor_context.tensor,
159+
)
160+
self._assert_tensor_shape(
161+
tensor_context.tensor_arg_name,
162+
expected_shape,
163+
tensor_context.tensor,
164+
)
165+
166+
finally:
167+
end_t = time.perf_counter_ns()
168+
runtime_ns = end_t - start_t
169+
_logger.debug("Context evaluation took %d ns", runtime_ns)
170+
if _maybe_warn_runtime(runtime_ns):
171+
max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6
172+
warnings.warn(
173+
f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms",
174+
UserWarning,
175+
stacklevel=2,
176+
)
163177

164178
def _assert_tensor_shape(
165179
self,

dltype/_lib/_tensor_type_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def validate_tensor(
153153
return core_schema.with_info_after_validator_function(
154154
validate_tensor,
155155
schema=core_schema.is_instance_schema(source_type),
156-
field_name=handler.field_name,
157156
)
158157

159158
def check(
@@ -164,6 +163,7 @@ def check(
164163
"""Check if the tensor matches this type."""
165164
# Basic validation for multi-axis dimensions
166165
__tracebackhide__ = not _constants.DEBUG_MODE
166+
167167
if self.multiaxis_index is not None:
168168
# Min required dimensions = expected shape length + extra dimensions - 1 (the multi-axis placeholder)
169169
min_required_dims = len(self.expected_shape) - 1

0 commit comments

Comments
 (0)