Skip to content

Commit 759e66f

Browse files
authored
Merge pull request #6478 from Textualize/update-classes
Update classes
2 parents 634ca6e + 3a3a76c commit 759e66f

4 files changed

Lines changed: 53 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/)
66
and this project adheres to [Semantic Versioning](http://semver.org/).
77

8+
## Unreleased
9+
10+
### Added
11+
12+
- Added `DOM.update_classes` https://github.com/Textualize/textual/pull/6478
13+
814
## [8.2.3] - 2026-04-05
915

1016
### Changed

src/textual/dom.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Callable,
1717
ClassVar,
1818
Iterable,
19+
Mapping,
1920
Sequence,
2021
Type,
2122
TypeVar,
@@ -1762,6 +1763,34 @@ def set_class(self, add: bool, *class_names: str, update: bool = True) -> Self:
17621763
self.remove_class(*class_names, update=update)
17631764
return self
17641765

1766+
def update_classes(
1767+
self, classes: Mapping[str, bool], update: bool = True, animate: bool = True
1768+
) -> Self:
1769+
"""Update classes in an atomic batch.
1770+
1771+
Args:
1772+
classes: A mapping of class name on to a boolean where `True` adds
1773+
to the current classes, and `False` removes.
1774+
update: Also update styles.
1775+
animate: Enable any CSS animation?
1776+
1777+
Returns:
1778+
Self
1779+
"""
1780+
1781+
add_classes: set[str] = set()
1782+
remove_classes: set[str] = set()
1783+
adds = (remove_classes.add, add_classes.add)
1784+
for class_name, add in classes.items():
1785+
adds[add](class_name)
1786+
1787+
new_classes = (self._classes | add_classes) - remove_classes
1788+
if self._classes != new_classes:
1789+
self._classes = new_classes
1790+
if update:
1791+
self.update_node_styles(animate=animate)
1792+
return self
1793+
17651794
def set_classes(self, classes: str | Iterable[str]) -> Self:
17661795
"""Replace all classes.
17671796

src/textual/layouts/grid.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ def apply_height_limits(widget: Widget, height: int) -> int:
318318
max_column = len(columns) - 1
319319
max_row = len(rows) - 1
320320

321+
stretch_height = self.stretch_height and len(children) > 1
322+
321323
for widget, (column, row, column_span, row_span) in cell_size_map.items():
322324
x = columns[column][0]
323325
if row > max_row:
@@ -336,9 +338,8 @@ def apply_height_limits(widget: Widget, height: int) -> int:
336338
greedy=greedy,
337339
)
338340

339-
if self.stretch_height and len(children) > 1:
340-
if box_height <= cell_size.height:
341-
box_height = Fraction(cell_size.height)
341+
if stretch_height and box_height <= cell_size.height:
342+
box_height = Fraction(cell_size.height)
342343

343344
region = (
344345
Region(

src/textual/screen.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,36 +1527,24 @@ async def _on_resize(self, event: events.Resize) -> None:
15271527
else self.VERTICAL_BREAKPOINTS
15281528
) or []
15291529

1530-
width, height = event.size
1531-
15321530
if horizontal_breakpoints or vertical_breakpoints:
1531+
width, height = event.size
1532+
breakpoints = {
1533+
breakpoint: False
1534+
for _, breakpoint in (horizontal_breakpoints + vertical_breakpoints)
1535+
}
15331536

1534-
remove_breakpoint_classes = {
1535-
breakpoint for _, breakpoint in horizontal_breakpoints
1536-
} | {breakpoint for _, breakpoint in vertical_breakpoints}
1537-
remove_breakpoint_classes = self._classes.intersection(
1538-
remove_breakpoint_classes
1539-
)
1540-
1541-
breakpoint_classes: set[str] = set()
1537+
for breakpoint in self._get_breakpoint_classes(
1538+
width, horizontal_breakpoints
1539+
):
1540+
breakpoints[breakpoint] = True
15421541

1543-
if horizontal_breakpoints:
1544-
breakpoint_classes |= self._get_breakpoint_classes(
1545-
width, horizontal_breakpoints
1546-
)
1547-
if vertical_breakpoints:
1548-
breakpoint_classes |= self._get_breakpoint_classes(
1549-
height, vertical_breakpoints
1550-
)
1542+
for breakpoint in self._get_breakpoint_classes(
1543+
height, vertical_breakpoints
1544+
):
1545+
breakpoints[breakpoint] = True
15511546

1552-
remove_breakpoint_classes -= breakpoint_classes
1553-
classes = self._classes.copy()
1554-
if remove_breakpoint_classes:
1555-
self._classes.difference_update(remove_breakpoint_classes)
1556-
if breakpoint_classes:
1557-
self._classes.update(breakpoint_classes)
1558-
if self._classes != classes:
1559-
self.update_node_styles(animate=False)
1547+
self.update_classes(breakpoints, animate=False)
15601548

15611549
def _get_breakpoint_classes(
15621550
self, dimension: int, breakpoints: list[tuple[int, str]]

0 commit comments

Comments
 (0)