Tensor duck#40
Conversation
…se the protocol instead of torch.Tensor directly.
| if is_torchtyping_annotation: | ||
| base_cls, *all_metadata = get_args(expected_type) | ||
| if not issubclass(base_cls, torch.Tensor): | ||
| if not isinstance(base_cls(), TensorLike): |
There was a problem hiding this comment.
I'm not sure about this last change. As mentioned, the protocol class only supports isintance() because it has properties. This means I had to require default construction.
But, I think this test may be unnecessary - after all the other tests I think we know this is a TensorLike element?
I think it might be better to just get rid of this test. @patrick-kidger
There was a problem hiding this comment.
In fact, it does seem I have a strong motivation to remove this. The case where I want to apply it is to check shape signatures on an abstract base class so default construction may not be an option.
There was a problem hiding this comment.
I have updated the PR accordingly.
|
Thanks for the PR! Unfortunately, this isn't quite the direction I had in mind. Following on from the discussion in #39, perhaps it's worth making clear that I don't intend to make TorchTyping depend on JAX. Rather, that the plan is to simply copy over the non-JAX parts of the code. (Which is most of it.) The idea would be to end up with annotations that look like At a technical level this should be essentially simple. The main hurdle - and the reason I've been putting off doing this is - is writing up documentation that makes this transition clear. |
|
Thanks for the clarification! I can totally see why you want to pull over the jaxtyping code and have a single code base. I understand that this PR is perhaps not what you were looking for, but I think it could actually represent a very important step in generalizing what you have and maybe even merging the two code bases. Let's take an example snippet from jaxtyping where the dtype is extracted (array_types.py: 129-) I think you would agree that it's a bit awkward and somewhat hard to extend since the supported classes have to be coded in advance. Then, to type-check a concrete class like numpy.array or torch.Tensor we just use the adapter pattern to map the specialized methods to the interface. (As an example, a simple name remapper: Adapter Method – Python Design Patterns). This would make it easy for folks like me to extend your library to array-type objects such as LinearOperator by just writing an adapter to the interface specified by the library. In addition, I think it could also let you merge these two libraries and make your life easier. You wrote that: |
|
Hmm. I suppose the practical implementation of such an adaptor would be via a registry: import functools as ft
@ft.singledispatch
def get_dtype(obj):
# Note that this default implementation does not explicitly
# depend on any of PyTorch/etc; thus the singledispatch
# hook is made available just for the sake of user-defined
# custom types.
if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
# JAX, numpy
dtype = obj.dtype.type.__name__
elif hasattr(obj.dtype, "as_numpy_dtype"):
# TensorFlow
dtype = obj.dtype.as_numpy_dtype.__name__
else:
# PyTorch
repr_dtype = repr(obj.dtype).split(".")
if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
dtype = repr_dtype[1]
else:
raise RuntimeError(
"Unrecognised array/tensor type to extract dtype from"
)
class _MetaAbstractArray(type):
def __instancecheck__(cls, obj):
...
dtype = get_dtype(obj)
...and then in your user code, you could add a custom overload for your type. I'd be willing to accept a PR for this over in jaxtyping. |
As promised, here is the PR to upgrade the library to define a 'torch-like' protocol and use that for the base type rather than using torch.Tensor directly. This lets users perform dimension checking on classes that support a Tensor interface but do not directly inherit from torch.Tensor. I think the change is fairly clear-cut, I have added a test case to demonstrate and verify that dimensions are actually checked.
The only question I have is about the change to line 304 in typechecker.py (the last change below).
Is this test really necessary?
I had to change it to use default construction because protocols don't support isinstance if they have properties.