Skip to content

Commit 5ee462c

Browse files
authored
feat: doc annotations (#4)
* feat: doc annotations * build: typing_extensions Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent a563e96 commit 5ee462c

2 files changed

Lines changed: 18 additions & 26 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ classifiers = [
3232
dynamic = ["version"]
3333
dependencies = [
3434
"jax",
35+
"typing_extensions >= 4.8",
3536
]
3637

3738
[project.optional-dependencies]
@@ -140,6 +141,9 @@ ignore = [
140141
"ISC001", # Conflicts with formatter
141142
]
142143

144+
[tool.ruff.lint.isort]
145+
extra-standard-library = ["typing_extensions"]
146+
143147
[tool.ruff.lint.per-file-ignores]
144148
"tests/**" = ["T20"]
145149
"noxfile.py" = ["T20"]

src/immutable_map_jax/_core.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
__all__: list[str] = []
88

99
from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView
10-
from typing import Any, TypeVar, overload
10+
from typing import Annotated, Any, TypeVar, overload
11+
from typing_extensions import Doc
1112

1213
from jax.tree_util import register_pytree_node_class
1314

@@ -108,33 +109,20 @@ def __repr__(self) -> str:
108109
# ===========================================
109110
# JAX PyTree
110111

111-
def tree_flatten(self) -> tuple[tuple[V, ...], tuple[K, ...]]:
112-
"""Flatten dict to the values (and keys).
113-
114-
Returns
115-
-------
116-
tuple[V, ...] tuple[str, ...]
117-
A pair of an iterable with the values to be flattened recursively,
118-
and the keys to pass back to the unflattening recipe.
119-
"""
120-
return (tuple(self._data.values()), tuple(self._data.keys()))
112+
def tree_flatten(
113+
self,
114+
) -> tuple[
115+
Annotated[tuple[V, ...], Doc("The values.")],
116+
Annotated[tuple[K, ...], Doc("The keys as auxiliary data.")],
117+
]:
118+
"""Flatten dict to the values (and keys)."""
119+
return tuple(self._data.values()), tuple(self._data.keys())
121120

122121
@classmethod
123122
def tree_unflatten(
124123
cls,
125-
aux_data: tuple[K, ...],
126-
children: tuple[V, ...],
127-
) -> "ImmutableMap": # type: ignore[type-arg] # TODO: upstream beartype fix for ImmutableMap[V]
128-
"""Unflatten.
129-
130-
Params:
131-
aux_data: the opaque data that was specified during flattening of the
132-
current treedef.
133-
children: the unflattened children
134-
135-
Returns
136-
-------
137-
a re-constructed object of the registered type, using the specified
138-
children and auxiliary data.
139-
"""
124+
aux_data: Annotated[tuple[K, ...], Doc("The keys.")],
125+
children: Annotated[tuple[V, ...], Doc("The values.")],
126+
) -> "ImmutableMap[K, V]":
127+
"""Unflatten into an ImmutableMap from the keys and values."""
140128
return cls(tuple(zip(aux_data, children, strict=True)))

0 commit comments

Comments
 (0)