Skip to content

Commit 65e1375

Browse files
ianhimaxrjones
andauthored
Remove rich as required for tree() (#3778)
* remove rich dependency for tree * lint + pr number * keyword only * Apply suggestion from @maxrjones Co-authored-by: Max Jones <14077947+maxrjones@users.noreply.github.com> --------- Co-authored-by: Max Jones <14077947+maxrjones@users.noreply.github.com>
1 parent 03355b8 commit 65e1375

File tree

9 files changed

+188
-58
lines changed

9 files changed

+188
-58
lines changed

changes/3778.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`Group.tree()` no longer requires the `rich` dependency. Tree rendering now uses built-in ANSI bold for terminals and HTML bold for Jupyter. New parameters: `plain=True` for unstyled output, and `max_nodes` (default 500) to truncate large hierarchies with early bailout.

docs/user-guide/groups.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,3 @@ Groups also have the [`zarr.Group.tree`][] method, e.g.:
133133
print(root.tree())
134134
```
135135

136-
!!! note
137-
[`zarr.Group.tree`][] requires the optional [rich](https://rich.readthedocs.io/en/stable/) dependency. It can be installed with the `[tree]` extra.

docs/user-guide/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ These can be installed using `pip install "zarr[<extra>]"`, e.g. `pip install "z
2626
- `gpu`: support for GPUs
2727
- `remote`: support for reading/writing to remote data stores
2828

29-
Additional optional dependencies include `rich`, `universal_pathlib`. These must be installed separately.
29+
Additional optional dependencies include `universal_pathlib`. These must be installed separately.
3030

3131
## conda
3232

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ gpu = [
7070
"cupy-cuda12x",
7171
]
7272
cli = ["typer"]
73-
optional = ["rich", "universal-pathlib"]
73+
optional = ["universal-pathlib"]
7474

7575
[project.scripts]
7676
zarr = "zarr._cli.cli:app"
@@ -122,7 +122,6 @@ docs = [
122122
"towncrier",
123123
# Optional dependencies to run examples
124124
"numcodecs[msgpack]",
125-
"rich",
126125
"s3fs>=2023.10.0",
127126
"astroid<4",
128127
"pytest",
@@ -131,7 +130,6 @@ dev = [
131130
{include-group = "test"},
132131
{include-group = "remote-tests"},
133132
{include-group = "docs"},
134-
"rich",
135133
"universal-pathlib",
136134
"mypy",
137135
]

src/zarr/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def print_packages(packages: list[str]) -> None:
7878
"s3fs",
7979
"gcsfs",
8080
"universal-pathlib",
81-
"rich",
8281
"obstore",
8382
]
8483

src/zarr/core/_tree.py

Lines changed: 100 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
import io
2-
import os
1+
import sys
2+
from collections import deque
33
from collections.abc import Sequence
4+
from html import escape as html_escape
45
from typing import Any
56

67
from zarr.core.group import AsyncGroup
78

8-
try:
9-
import rich
10-
import rich.console
11-
import rich.tree
12-
except ImportError as e:
13-
raise ImportError("'rich' is required for Group.tree") from e
14-
159

1610
class TreeRepr:
1711
"""
@@ -21,45 +15,120 @@ class TreeRepr:
2115
of Zarr's public API.
2216
"""
2317

24-
def __init__(self, tree: rich.tree.Tree) -> None:
25-
self._tree = tree
18+
def __init__(self, text: str, html: str, truncated: str = "") -> None:
19+
self._text = text
20+
self._html = html
21+
self._truncated = truncated
2622

2723
def __repr__(self) -> str:
28-
color_system = os.environ.get("OVERRIDE_COLOR_SYSTEM", rich.get_console().color_system)
29-
console = rich.console.Console(file=io.StringIO(), color_system=color_system)
30-
console.print(self._tree)
31-
return str(console.file.getvalue())
24+
if self._truncated:
25+
return self._truncated + self._text
26+
return self._text
3227

3328
def _repr_mimebundle_(
3429
self,
35-
include: Sequence[str],
36-
exclude: Sequence[str],
30+
include: Sequence[str] | None = None,
31+
exclude: Sequence[str] | None = None,
3732
**kwargs: Any,
3833
) -> dict[str, str]:
34+
text = self._truncated + self._text if self._truncated else self._text
3935
# For jupyter support.
40-
# Unsure why mypy infers the return type to by Any
41-
return self._tree._repr_mimebundle_(include=include, exclude=exclude, **kwargs) # type: ignore[no-any-return]
36+
html_body = self._truncated + self._html if self._truncated else self._html
37+
html = (
38+
'<pre style="white-space:pre;overflow-x:auto;line-height:normal;'
39+
"font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">"
40+
f"{html_body}</pre>\n"
41+
)
42+
return {"text/plain": text, "text/html": html}
43+
4244

45+
async def group_tree_async(
46+
group: AsyncGroup,
47+
max_depth: int | None = None,
48+
*,
49+
max_nodes: int = 500,
50+
plain: bool = False,
51+
) -> TreeRepr:
52+
members: list[tuple[str, Any]] = []
53+
truncated = False
54+
async for item in group.members(max_depth=max_depth):
55+
if len(members) == max_nodes:
56+
truncated = True
57+
break
58+
members.append(item)
59+
members.sort(key=lambda key_node: key_node[0])
4360

44-
async def group_tree_async(group: AsyncGroup, max_depth: int | None = None) -> TreeRepr:
45-
tree = rich.tree.Tree(label=f"[bold]{group.name}[/bold]")
46-
nodes = {"": tree}
47-
members = sorted([x async for x in group.members(max_depth=max_depth)])
61+
# Set up styling tokens: ANSI bold for terminals, HTML <b> for Jupyter,
62+
# or empty strings when plain=True (useful for LLMs, logging, files).
63+
if plain:
64+
ansi_open = ansi_close = html_open = html_close = ""
65+
else:
66+
# Avoid emitting ANSI escape codes when output is piped or in CI.
67+
use_ansi = sys.stdout.isatty()
68+
ansi_open = "\x1b[1m" if use_ansi else ""
69+
ansi_close = "\x1b[0m" if use_ansi else ""
70+
html_open = "<b>"
71+
html_close = "</b>"
4872

73+
# Group members by parent key so we can render the tree level by level.
74+
nodes: dict[str, list[tuple[str, Any]]] = {}
4975
for key, node in members:
5076
if key.count("/") == 0:
5177
parent_key = ""
5278
else:
5379
parent_key = key.rsplit("/", 1)[0]
54-
parent = nodes[parent_key]
80+
nodes.setdefault(parent_key, []).append((key, node))
5581

56-
# We want what the spec calls the node "name", the part excluding all leading
57-
# /'s and path segments. But node.name includes all that, so we build it here.
82+
# Render the tree iteratively (not recursively) to avoid hitting
83+
# Python's recursion limit on deeply nested hierarchies.
84+
# Each stack frame is (prefix_string, remaining_children_at_this_level).
85+
text_lines = [f"{ansi_open}{group.name}{ansi_close}"]
86+
html_lines = [f"{html_open}{html_escape(group.name)}{html_close}"]
87+
stack = [("", deque(nodes.get("", [])))]
88+
while stack:
89+
prefix, remaining = stack[-1]
90+
if not remaining:
91+
stack.pop()
92+
continue
93+
key, node = remaining.popleft()
5894
name = key.rsplit("/")[-1]
95+
escaped_name = html_escape(name)
96+
# if we popped the last item then remaining will
97+
# now be empty - that's how we got past the if not remaining
98+
# above, but this can still be true.
99+
is_last = not remaining
100+
connector = "└── " if is_last else "├── "
59101
if isinstance(node, AsyncGroup):
60-
label = f"[bold]{name}[/bold]"
102+
text_lines.append(f"{prefix}{connector}{ansi_open}{name}{ansi_close}")
103+
html_lines.append(f"{prefix}{connector}{html_open}{escaped_name}{html_close}")
61104
else:
62-
label = f"[bold]{name}[/bold] {node.shape} {node.dtype}"
63-
nodes[key] = parent.add(label)
64-
65-
return TreeRepr(tree)
105+
text_lines.append(
106+
f"{prefix}{connector}{ansi_open}{name}{ansi_close} {node.shape} {node.dtype}"
107+
)
108+
html_lines.append(
109+
f"{prefix}{connector}{html_open}{escaped_name}{html_close}"
110+
f" {html_escape(str(node.shape))} {html_escape(str(node.dtype))}"
111+
)
112+
# Descend into children with an accumulated prefix:
113+
# Example showing how prefix accumulates:
114+
# /
115+
# ├── a prefix = ""
116+
# │ ├── b prefix = "" + "│ "
117+
# │ │ └── x prefix = "" + "│ " + "│ "
118+
# │ └── c prefix = "" + "│ "
119+
# └── d prefix = ""
120+
# └── e prefix = "" + " "
121+
if children := nodes.get(key, []):
122+
if is_last:
123+
child_prefix = prefix + " "
124+
else:
125+
child_prefix = prefix + "│ "
126+
stack.append((child_prefix, deque(children)))
127+
text = "\n".join(text_lines) + "\n"
128+
html = "\n".join(html_lines) + "\n"
129+
note = (
130+
f"Truncated at max_nodes={max_nodes}, some nodes and their children may be missing\n"
131+
if truncated
132+
else ""
133+
)
134+
return TreeRepr(text, html, truncated=note)

src/zarr/core/group.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,19 +1588,30 @@ async def array_values(
15881588
async for _, array in self.arrays():
15891589
yield array
15901590

1591-
async def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
1591+
async def tree(
1592+
self,
1593+
expand: bool | None = None,
1594+
level: int | None = None,
1595+
*,
1596+
max_nodes: int = 500,
1597+
plain: bool = False,
1598+
) -> Any:
15921599
"""
15931600
Return a tree-like representation of a hierarchy.
15941601
1595-
This requires the optional ``rich`` dependency.
1596-
15971602
Parameters
15981603
----------
15991604
expand : bool, optional
16001605
This keyword is not yet supported. A NotImplementedError is raised if
16011606
it's used.
16021607
level : int, optional
16031608
The maximum depth below this Group to display in the tree.
1609+
max_nodes : int
1610+
Maximum number of nodes to display before truncating. Default is 500.
1611+
plain : bool, optional
1612+
If True, return a plain-text tree without ANSI styling. This is
1613+
useful when the output will be consumed by an LLM or written to a
1614+
file. Default is False.
16041615
16051616
Returns
16061617
-------
@@ -1611,7 +1622,7 @@ async def tree(self, expand: bool | None = None, level: int | None = None) -> An
16111622

16121623
if expand is not None:
16131624
raise NotImplementedError("'expand' is not yet implemented.")
1614-
return await group_tree_async(self, max_depth=level)
1625+
return await group_tree_async(self, max_depth=level, max_nodes=max_nodes, plain=plain)
16151626

16161627
async def empty(self, *, name: str, shape: tuple[int, ...], **kwargs: Any) -> AnyAsyncArray:
16171628
"""Create an empty array with the specified shape in this Group. The contents will
@@ -2371,26 +2382,39 @@ def array_values(self) -> Generator[AnyArray, None]:
23712382
for _, array in self.arrays():
23722383
yield array
23732384

2374-
def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
2385+
def tree(
2386+
self,
2387+
expand: bool | None = None,
2388+
level: int | None = None,
2389+
*,
2390+
max_nodes: int = 500,
2391+
plain: bool = False,
2392+
) -> Any:
23752393
"""
23762394
Return a tree-like representation of a hierarchy.
23772395
2378-
This requires the optional ``rich`` dependency.
2379-
23802396
Parameters
23812397
----------
23822398
expand : bool, optional
23832399
This keyword is not yet supported. A NotImplementedError is raised if
23842400
it's used.
23852401
level : int, optional
23862402
The maximum depth below this Group to display in the tree.
2403+
max_nodes : int
2404+
Maximum number of nodes to display before truncating. Default is 500.
2405+
plain : bool, optional
2406+
If True, return a plain-text tree without ANSI styling. This is
2407+
useful when the output will be consumed by an LLM or written to a
2408+
file. Default is False.
23872409
23882410
Returns
23892411
-------
23902412
TreeRepr
23912413
A pretty-printable object displaying the hierarchy.
23922414
"""
2393-
return self._sync(self._async_group.tree(expand=expand, level=level))
2415+
return self._sync(
2416+
self._async_group.tree(expand=expand, level=level, max_nodes=max_nodes, plain=plain)
2417+
)
23942418

23952419
def create_group(self, name: str, **kwargs: Any) -> Group:
23962420
"""Create a sub-group.

tests/test_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,6 @@ def test_load_local(tmp_path: Path, path: str | None, load_read_only: bool) -> N
599599

600600

601601
def test_tree() -> None:
602-
pytest.importorskip("rich")
603602
g1 = zarr.group()
604603
g1.create_group("foo")
605604
g3 = g1.create_group("bar")

0 commit comments

Comments
 (0)