Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3778.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`Group.tree()` no longer requires the `rich` dependency. Tree rendering now uses built-in ANSI bold for terminals and HTML bold for Jupyter. New parameters: `plain=True` for unstyled output, and `max_nodes` (default 500) to truncate large hierarchies with early bailout.
2 changes: 0 additions & 2 deletions docs/user-guide/groups.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,3 @@ Groups also have the [`zarr.Group.tree`][] method, e.g.:
print(root.tree())
```

!!! note
[`zarr.Group.tree`][] requires the optional [rich](https://rich.readthedocs.io/en/stable/) dependency. It can be installed with the `[tree]` extra.
2 changes: 1 addition & 1 deletion docs/user-guide/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ These can be installed using `pip install "zarr[<extra>]"`, e.g. `pip install "z
- `gpu`: support for GPUs
- `remote`: support for reading/writing to remote data stores

Additional optional dependencies include `rich`, `universal_pathlib`. These must be installed separately.
An additional optional dependency is `universal_pathlib`. It must be installed separately.

## conda

Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ gpu = [
"cupy-cuda12x",
]
cli = ["typer"]
optional = ["rich", "universal-pathlib"]
optional = ["universal-pathlib"]

[project.scripts]
zarr = "zarr._cli.cli:app"
Expand Down Expand Up @@ -122,7 +122,6 @@ docs = [
"towncrier",
# Optional dependencies to run examples
"numcodecs[msgpack]",
"rich",
"s3fs>=2023.10.0",
"astroid<4",
"pytest",
Expand All @@ -131,7 +130,6 @@ dev = [
{include-group = "test"},
{include-group = "remote-tests"},
{include-group = "docs"},
"rich",
"universal-pathlib",
"mypy",
]
Expand Down
1 change: 0 additions & 1 deletion src/zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def print_packages(packages: list[str]) -> None:
"s3fs",
"gcsfs",
"universal-pathlib",
"rich",
"obstore",
]

Expand Down
130 changes: 99 additions & 31 deletions src/zarr/core/_tree.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import io
import os
import sys
from collections import deque
from collections.abc import Sequence
from html import escape as html_escape
from typing import Any

from zarr.core.group import AsyncGroup

try:
import rich
import rich.console
import rich.tree
except ImportError as e:
raise ImportError("'rich' is required for Group.tree") from e


class TreeRepr:
"""
Expand All @@ -21,45 +15,119 @@ class TreeRepr:
of Zarr's public API.
"""

def __init__(self, tree: rich.tree.Tree) -> None:
self._tree = tree
def __init__(self, text: str, html: str, truncated: str = "") -> None:
self._text = text
self._html = html
self._truncated = truncated

def __repr__(self) -> str:
color_system = os.environ.get("OVERRIDE_COLOR_SYSTEM", rich.get_console().color_system)
console = rich.console.Console(file=io.StringIO(), color_system=color_system)
console.print(self._tree)
return str(console.file.getvalue())
if self._truncated:
return self._truncated + self._text
return self._text

def _repr_mimebundle_(
self,
include: Sequence[str],
exclude: Sequence[str],
include: Sequence[str] | None = None,
exclude: Sequence[str] | None = None,
**kwargs: Any,
) -> dict[str, str]:
text = self._truncated + self._text if self._truncated else self._text
# For jupyter support.
# Unsure why mypy infers the return type to be Any
return self._tree._repr_mimebundle_(include=include, exclude=exclude, **kwargs) # type: ignore[no-any-return]
html_body = self._truncated + self._html if self._truncated else self._html
html = (
'<pre style="white-space:pre;overflow-x:auto;line-height:normal;'
"font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">"
f"{html_body}</pre>\n"
)
return {"text/plain": text, "text/html": html}


async def group_tree_async(
group: AsyncGroup,
max_depth: int | None = None,
max_nodes: int = 500,
plain: bool = False,
) -> TreeRepr:
members: list[tuple[str, Any]] = []
truncated = False
async for item in group.members(max_depth=max_depth):
if len(members) == max_nodes:
truncated = True
break
members.append(item)
members.sort(key=lambda key_node: key_node[0])

async def group_tree_async(group: AsyncGroup, max_depth: int | None = None) -> TreeRepr:
tree = rich.tree.Tree(label=f"[bold]{group.name}[/bold]")
nodes = {"": tree}
members = sorted([x async for x in group.members(max_depth=max_depth)])
# Set up styling tokens: ANSI bold for terminals, HTML <b> for Jupyter,
# or empty strings when plain=True (useful for LLMs, logging, files).
if plain:
ansi_open = ansi_close = html_open = html_close = ""
else:
# Avoid emitting ANSI escape codes when output is piped or in CI.
use_ansi = sys.stdout.isatty()
ansi_open = "\x1b[1m" if use_ansi else ""
ansi_close = "\x1b[0m" if use_ansi else ""
html_open = "<b>"
html_close = "</b>"

# Group members by parent key so we can render the tree level by level.
nodes: dict[str, list[tuple[str, Any]]] = {}
for key, node in members:
if key.count("/") == 0:
parent_key = ""
else:
parent_key = key.rsplit("/", 1)[0]
parent = nodes[parent_key]
nodes.setdefault(parent_key, []).append((key, node))

# We want what the spec calls the node "name", the part excluding all leading
# /'s and path segments. But node.name includes all that, so we build it here.
# Render the tree iteratively (not recursively) to avoid hitting
# Python's recursion limit on deeply nested hierarchies.
# Each stack frame is (prefix_string, remaining_children_at_this_level).
text_lines = [f"{ansi_open}{group.name}{ansi_close}"]
html_lines = [f"{html_open}{html_escape(group.name)}{html_close}"]
stack = [("", deque(nodes.get("", [])))]
while stack:
prefix, remaining = stack[-1]
if not remaining:
stack.pop()
continue
key, node = remaining.popleft()
name = key.rsplit("/")[-1]
escaped_name = html_escape(name)
# `remaining` was non-empty at the `if not remaining` check above (that is
# how we got here), but popleft() may have just removed its final item.
# If it is empty now, the node we popped is the last child at this level.
is_last = not remaining
connector = "└── " if is_last else "├── "
if isinstance(node, AsyncGroup):
label = f"[bold]{name}[/bold]"
text_lines.append(f"{prefix}{connector}{ansi_open}{name}{ansi_close}")
html_lines.append(f"{prefix}{connector}{html_open}{escaped_name}{html_close}")
else:
label = f"[bold]{name}[/bold] {node.shape} {node.dtype}"
nodes[key] = parent.add(label)

return TreeRepr(tree)
text_lines.append(
f"{prefix}{connector}{ansi_open}{name}{ansi_close} {node.shape} {node.dtype}"
)
html_lines.append(
f"{prefix}{connector}{html_open}{escaped_name}{html_close}"
f" {html_escape(str(node.shape))} {html_escape(str(node.dtype))}"
)
# Descend into children with an accumulated prefix:
# Example showing how prefix accumulates:
# /
# ├── a prefix = ""
# │ ├── b prefix = "" + "│ "
# │ │ └── x prefix = "" + "│ " + "│ "
# │ └── c prefix = "" + "│ "
# └── d prefix = ""
# └── e prefix = "" + " "
if children := nodes.get(key, []):
if is_last:
child_prefix = prefix + " "
else:
child_prefix = prefix + "│ "
stack.append((child_prefix, deque(children)))
text = "\n".join(text_lines) + "\n"
html = "\n".join(html_lines) + "\n"
note = (
f"Truncated at max_nodes={max_nodes}, some nodes and their children may be missing\n"
if truncated
else ""
)
return TreeRepr(text, html, truncated=note)
38 changes: 30 additions & 8 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,19 +1588,29 @@ async def array_values(
async for _, array in self.arrays():
yield array

async def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
async def tree(
self,
expand: bool | None = None,
level: int | None = None,
max_nodes: int = 500,
plain: bool = False,
Comment thread
maxrjones marked this conversation as resolved.
) -> Any:
"""
Return a tree-like representation of a hierarchy.

This requires the optional ``rich`` dependency.

Parameters
----------
expand : bool, optional
This keyword is not yet supported. A NotImplementedError is raised if
it's used.
level : int, optional
The maximum depth below this Group to display in the tree.
max_nodes : int
Maximum number of nodes to display before truncating. Default is 500.
plain : bool, optional
If True, return a plain-text tree without ANSI styling. This is
useful when the output will be consumed by an LLM or written to a
file. Default is False.

Returns
-------
Expand All @@ -1611,7 +1621,7 @@ async def tree(self, expand: bool | None = None, level: int | None = None) -> An

if expand is not None:
raise NotImplementedError("'expand' is not yet implemented.")
return await group_tree_async(self, max_depth=level)
return await group_tree_async(self, max_depth=level, max_nodes=max_nodes, plain=plain)

async def empty(self, *, name: str, shape: tuple[int, ...], **kwargs: Any) -> AnyAsyncArray:
"""Create an empty array with the specified shape in this Group. The contents will
Expand Down Expand Up @@ -2371,26 +2381,38 @@ def array_values(self) -> Generator[AnyArray, None]:
for _, array in self.arrays():
yield array

def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
def tree(
self,
expand: bool | None = None,
level: int | None = None,
max_nodes: int = 500,
Comment thread
ianhi marked this conversation as resolved.
plain: bool = False,
) -> Any:
"""
Return a tree-like representation of a hierarchy.

This requires the optional ``rich`` dependency.

Parameters
----------
expand : bool, optional
This keyword is not yet supported. A NotImplementedError is raised if
it's used.
level : int, optional
The maximum depth below this Group to display in the tree.
max_nodes : int
Maximum number of nodes to display before truncating. Default is 500.
plain : bool, optional
If True, return a plain-text tree without ANSI styling. This is
useful when the output will be consumed by an LLM or written to a
file. Default is False.

Returns
-------
TreeRepr
A pretty-printable object displaying the hierarchy.
"""
return self._sync(self._async_group.tree(expand=expand, level=level))
return self._sync(
self._async_group.tree(expand=expand, level=level, max_nodes=max_nodes, plain=plain)
)

def create_group(self, name: str, **kwargs: Any) -> Group:
"""Create a sub-group.
Expand Down
1 change: 0 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,6 @@ def test_load_local(tmp_path: Path, path: str | None, load_read_only: bool) -> N


def test_tree() -> None:
pytest.importorskip("rich")
g1 = zarr.group()
g1.create_group("foo")
g3 = g1.create_group("bar")
Expand Down
Loading
Loading