|
7 | 7 | __all__: list[str] = [] |
8 | 8 |
|
9 | 9 | from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView |
10 | | -from typing import Any, TypeVar, overload |
| 10 | +from typing import Annotated, Any, TypeVar, overload |
| 11 | +from typing_extensions import Doc |
11 | 12 |
|
12 | 13 | from jax.tree_util import register_pytree_node_class |
13 | 14 |
|
@@ -108,33 +109,20 @@ def __repr__(self) -> str: |
108 | 109 | # =========================================== |
109 | 110 | # JAX PyTree |
110 | 111 |
|
111 | | - def tree_flatten(self) -> tuple[tuple[V, ...], tuple[K, ...]]: |
112 | | - """Flatten dict to the values (and keys). |
113 | | -
|
114 | | - Returns |
115 | | - ------- |
116 | | - tuple[V, ...] tuple[str, ...] |
117 | | - A pair of an iterable with the values to be flattened recursively, |
118 | | - and the keys to pass back to the unflattening recipe. |
119 | | - """ |
120 | | - return (tuple(self._data.values()), tuple(self._data.keys())) |
| 112 | + def tree_flatten( |
| 113 | + self, |
| 114 | + ) -> tuple[ |
| 115 | + Annotated[tuple[V, ...], Doc("The values.")], |
| 116 | + Annotated[tuple[K, ...], Doc("The keys as auxiliary data.")], |
| 117 | + ]: |
| 118 | + """Flatten dict to the values (and keys).""" |
| 119 | + return tuple(self._data.values()), tuple(self._data.keys()) |
121 | 120 |
|
122 | 121 | @classmethod |
123 | 122 | def tree_unflatten( |
124 | 123 | cls, |
125 | | - aux_data: tuple[K, ...], |
126 | | - children: tuple[V, ...], |
127 | | - ) -> "ImmutableMap": # type: ignore[type-arg] # TODO: upstream beartype fix for ImmutableMap[V] |
128 | | - """Unflatten. |
129 | | -
|
130 | | - Params: |
131 | | - aux_data: the opaque data that was specified during flattening of the |
132 | | - current treedef. |
133 | | - children: the unflattened children |
134 | | -
|
135 | | - Returns |
136 | | - ------- |
137 | | - a re-constructed object of the registered type, using the specified |
138 | | - children and auxiliary data. |
139 | | - """ |
| 124 | + aux_data: Annotated[tuple[K, ...], Doc("The keys.")], |
| 125 | + children: Annotated[tuple[V, ...], Doc("The values.")], |
| 126 | + ) -> "ImmutableMap[K, V]": |
| 127 | + """Unflatten into an ImmutableMap from the keys and values.""" |
140 | 128 | return cls(tuple(zip(aux_data, children, strict=True))) |
0 commit comments