Skip to content

Commit f0ae1a7

Browse files
committed
Modify the locations parameter of opt_all_caps() to accept Loc objects
1 parent c42c1ad commit f0ae1a7

3 files changed

Lines changed: 120 additions & 16 deletions

File tree

great_tables/_options.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,8 @@ def opt_all_caps(
893893
894894
locations
895895
Which locations should undergo this text transformation? By default it includes all of
896-
the `"column_labels"`, the `"stub"`, and the `"row_group"` locations. However, we could
897-
just choose one or two of those.
896+
the `loc.column_labels`, the `loc.stub"`, and the `loc.row_group` locations. However, we
897+
could just choose one or two of those.
898898
899899
Returns
900900
-------
@@ -909,7 +909,7 @@ def opt_all_caps(
909909
in all row groups is transformed to all caps using the `opt_all_caps()` method.
910910
911911
```{python}
912-
from great_tables import GT, exibble, md
912+
from great_tables import GT, exibble, loc, md
913913
914914
(
915915
GT(
@@ -927,16 +927,49 @@ def opt_all_caps(
927927
.opt_all_caps()
928928
)
929929
```
930+
`opt_all_caps()` accepts a `locations` parameter that allows us to specify which components
931+
should be transformed. For example, if we only want to ensure that all text in the stub and all
932+
row groups is converted to all caps:
933+
```{python}
934+
(
935+
GT(
936+
exibble[["num", "char", "currency", "row", "group"]],
937+
rowname_col="row",
938+
groupname_col="group"
939+
)
940+
.tab_header(
941+
title=md("Data listing from **exibble**"),
942+
subtitle=md("`exibble` is a **Great Tables** dataset.")
943+
)
944+
.fmt_number(columns="num")
945+
.fmt_currency(columns="currency")
946+
.tab_source_note(source_note="This is only a subset of the dataset.")
947+
.opt_all_caps(locations=[loc.stub, loc.row_group])
948+
)
949+
```
930950
"""
951+
# Importing `great_tables._locations` at the top will cause a circular import error.
952+
# The type annotation for `locations` should be:
953+
# `Loc | list[Loc] = [LocColumnLabels, LocStub, LocRowGroups]`
954+
from great_tables._locations import Loc, LocColumnLabels, LocStub, LocRowGroups
931955

932-
# If providing a scalar string value, normalize it to be in a list
933-
if not isinstance(locations, list):
934-
locations = _utils._str_scalar_to_list(cast(str, locations))
956+
if not locations:
957+
locations = [LocColumnLabels, LocStub, LocRowGroups]
935958

936-
# Ensure that the `locations` value is a list of strings
937-
_utils._assert_str_list(locations)
938-
939-
# TODO: Ensure that all values within `locations` are valid
959+
# If providing a Loc object, normalize it to be in a list
960+
if not isinstance(locations, list):
961+
locations = [locations]
962+
963+
# Ensure that all values within `locations` are valid
964+
# A `try-except` block is needed here because the first argument of `issubclass()` must be a
965+
# class.
966+
for location in locations:
967+
try:
968+
issubclass(location, Loc)
969+
except TypeError as exc:
970+
raise AssertionError(
971+
"Only `loc.column_labels`, `loc.stub` and `loc.row_group` are allowed in the locations."
972+
) from exc
940973

941974
# if `all_caps` is False, reset options to default, or, set new options
942975
# for `locations` selected
@@ -956,23 +989,23 @@ def opt_all_caps(
956989

957990
info = [
958991
(
959-
"column_labels",
992+
LocColumnLabels,
960993
{
961994
"column_labels_font_size": "80%",
962995
"column_labels_font_weight": "bolder",
963996
"column_labels_text_transform": "uppercase",
964997
},
965998
),
966999
(
967-
"stub",
1000+
LocStub,
9681001
{
9691002
"stub_font_size": "80%",
9701003
"stub_font_weight": "bolder",
9711004
"stub_text_transform": "uppercase",
9721005
},
9731006
),
9741007
(
975-
"row_group",
1008+
LocRowGroups,
9761009
{
9771010
"row_group_font_size": "80%",
9781011
"row_group_font_weight": "bolder",

great_tables/loc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
LocBody as body,
55
LocStub as stub,
66
LocColumnLabels as column_labels,
7+
LocRowGroups as row_group,
78
)
89

9-
__all__ = ("body", "stub", "column_labels")
10+
__all__ = ("body", "stub", "column_labels", "row_group")

tests/test_options.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pandas as pd
22
import pytest
3-
from great_tables import GT, exibble, md
3+
from great_tables import GT, exibble, loc, md
44
from great_tables._scss import compile_scss
55
from great_tables._gt_data import default_fonts_list
66

@@ -328,7 +328,6 @@ def test_scss_from_opt_table_outline(gt_tbl: GT, snapshot):
328328

329329

330330
def test_opt_table_font_add_font():
331-
332331
gt_tbl = GT(exibble).opt_table_font(font="Arial", weight="bold", style="italic")
333332

334333
assert gt_tbl._options.table_font_names.value == ["Arial"] + default_fonts_list
@@ -369,3 +368,74 @@ def test_opt_table_font_raises():
369368
GT(exibble).opt_table_font(font=None, stack=None)
370369

371370
assert "Either `font=` or `stack=` must be provided." in exc_info.value.args[0]
371+
372+
373+
def test_opt_all_caps(gt_tbl: GT):
374+
tbl = gt_tbl.opt_all_caps(locations=loc.column_labels)
375+
376+
assert tbl._options.column_labels_font_size.value == "80%"
377+
assert tbl._options.column_labels_font_weight.value == "bolder"
378+
assert tbl._options.column_labels_text_transform.value == "uppercase"
379+
380+
tbl = gt_tbl.opt_all_caps(locations=[loc.column_labels, loc.stub])
381+
382+
assert tbl._options.column_labels_font_size.value == "80%"
383+
assert tbl._options.column_labels_font_weight.value == "bolder"
384+
assert tbl._options.column_labels_text_transform.value == "uppercase"
385+
386+
assert tbl._options.stub_font_size.value == "80%"
387+
assert tbl._options.stub_font_weight.value == "bolder"
388+
assert tbl._options.stub_text_transform.value == "uppercase"
389+
390+
tbl = gt_tbl.opt_all_caps(locations=[loc.column_labels, loc.stub, loc.row_group])
391+
392+
assert tbl._options.column_labels_font_size.value == "80%"
393+
assert tbl._options.column_labels_font_weight.value == "bolder"
394+
assert tbl._options.column_labels_text_transform.value == "uppercase"
395+
396+
assert tbl._options.stub_font_size.value == "80%"
397+
assert tbl._options.stub_font_weight.value == "bolder"
398+
assert tbl._options.stub_text_transform.value == "uppercase"
399+
400+
assert tbl._options.row_group_font_size.value == "80%"
401+
assert tbl._options.row_group_font_weight.value == "bolder"
402+
assert tbl._options.row_group_text_transform.value == "uppercase"
403+
404+
# Activate the following tests once the circular import issue is resolved.
405+
# tbl = gt_tbl.opt_all_caps()
406+
407+
# assert tbl._options.column_labels_font_size.value == "80%"
408+
# assert tbl._options.column_labels_font_weight.value == "bolder"
409+
# assert tbl._options.column_labels_text_transform.value == "uppercase"
410+
411+
# assert tbl._options.stub_font_size.value == "80%"
412+
# assert tbl._options.stub_font_weight.value == "bolder"
413+
# assert tbl._options.stub_text_transform.value == "uppercase"
414+
415+
# assert tbl._options.row_group_font_size.value == "80%"
416+
# assert tbl._options.row_group_font_weight.value == "bolder"
417+
# assert tbl._options.row_group_text_transform.value == "uppercase"
418+
419+
# tbl = gt_tbl.opt_all_caps(all_caps=False)
420+
421+
# assert tbl._options.column_labels_font_size.value == "100%"
422+
# assert tbl._options.column_labels_font_weight.value == "normal"
423+
# assert tbl._options.column_labels_text_transform.value == "inherit"
424+
425+
# assert tbl._options.stub_font_size.value == "100%"
426+
# assert tbl._options.stub_font_weight.value == "initial"
427+
# assert tbl._options.stub_text_transform.value == "inherit"
428+
429+
# assert tbl._options.row_group_font_size.value == "100%"
430+
# assert tbl._options.row_group_font_weight.value == "initial"
431+
# assert tbl._options.row_group_text_transform.value == "inherit"
432+
433+
434+
def test_opt_all_caps_raises(gt_tbl: GT):
435+
with pytest.raises(AssertionError) as exc_info:
436+
gt_tbl.opt_all_caps(locations="column_labels")
437+
438+
assert (
439+
"Only `loc.column_labels`, `loc.stub` and `loc.row_group` are allowed in the locations."
440+
in exc_info.value.args[0]
441+
)

0 commit comments

Comments
 (0)