Skip to content

Commit 813770a

Browse files
FBruzzesi, rich-iannone, machow
authored
fix: Address slicing for pyarrow array in data_color (#741)
* fix: Address slicing pyarrow array in data_color
* rename function
* collapse pyarrow into one func

---------

Co-authored-by: Richard Iannone <riannone@me.com>
Co-authored-by: Michael Chow <mc_al_github@fastmail.com>
1 parent ff3709c commit 813770a

4 files changed

Lines changed: 219 additions & 6 deletions

File tree

great_tables/_data_color/base.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,14 @@
55
from typing_extensions import TypeAlias
66

77
from great_tables._locations import RowSelectExpr, resolve_cols_c, resolve_rows_i
8-
from great_tables._tbl_data import DataFrameLike, SelectExpr, get_column_names, is_na
8+
from great_tables._tbl_data import (
9+
DataFrameLike,
10+
SelectExpr,
11+
get_column_names,
12+
get_rows,
13+
is_na,
14+
to_list,
15+
)
916
from great_tables.loc import body
1017
from great_tables.style import fill, text
1118

@@ -227,7 +234,7 @@ def data_color(
227234
# For each column targeted, get the data values as a new list object
228235
for col in columns_resolved:
229236
# This line handles both pandas and polars dataframes
230-
column_vals = data_table[col][row_pos].to_list()
237+
column_vals = to_list(get_rows(data_table[col], indexes=row_pos))
231238

232239
# Filter out NA values from `column_vals`
233240
filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)]

great_tables/_tbl_data.py

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,10 @@
3232
PlSelectExpr = Selector
3333
PlExpr = pl.Expr
3434

35-
PdSeries = pd.Series
35+
PdSeries = pd.Series[Any]
3636
PlSeries = pl.Series
37-
PyArrowArray = pa.Array
38-
PyArrowChunkedArray = pa.ChunkedArray
37+
PyArrowArray = pa.Array[Any]
38+
PyArrowChunkedArray = pa.ChunkedArray[Any]
3939

4040
PdNA = pd.NA
4141
PlNull = pl.Null
@@ -769,7 +769,7 @@ def _(df: PyArrowTable, x: Any) -> bool:
769769
import pyarrow as pa
770770

771771
arr = pa.array([x])
772-
return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0]
772+
return arr.is_null(nan_is_null=True).to_pylist()[0]
773773

774774

775775
@singledispatch
@@ -942,3 +942,25 @@ def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowTable]) -> dict[st
942942
)
943943

944944
return {col: res.column(col)[0].as_py() for col in res.column_names}
945+
946+
947+
@singledispatch
948+
def get_rows(ser: SeriesLike, indexes: list[int]) -> SeriesLike:
949+
"""Returns the values of the series at the positions given by `indexes`."""
950+
raise NotImplementedError(f"Unsupported type: {type(ser)}")
951+
952+
953+
@get_rows.register
954+
def _(ser: PdSeries, indexes: list[int]) -> PdSeries:
955+
return ser.iloc[indexes]
956+
957+
958+
@get_rows.register
959+
def _(ser: PlSeries, indexes: list[int]) -> PlSeries:
960+
return ser[indexes]
961+
962+
963+
@get_rows.register(PyArrowArray)
964+
@get_rows.register(PyArrowChunkedArray)
965+
def _(ser: Any, indexes: list[int]) -> PyArrowArray | PyArrowChunkedArray:
966+
return ser.take(indexes)

tests/data_color/__snapshots__/test_data_color.ambr

Lines changed: 182 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,32 @@
123123
</tbody>
124124
'''
125125
# ---
126+
# name: test_data_color_autocolor_text_false[pyarrow]
127+
'''
128+
<tbody class="gt_table_body">
129+
<tr>
130+
<td class="gt_row gt_right">0.1111</td>
131+
<td class="gt_row gt_left">apricot</td>
132+
<td style="background-color: #ff0000;" class="gt_row gt_right">49.95</td>
133+
</tr>
134+
<tr>
135+
<td class="gt_row gt_right">2.222</td>
136+
<td class="gt_row gt_left">banana</td>
137+
<td style="background-color: #5c5200;" class="gt_row gt_right">17.95</td>
138+
</tr>
139+
<tr>
140+
<td class="gt_row gt_right">33.33</td>
141+
<td class="gt_row gt_left">coconut</td>
142+
<td style="background-color: #077c00;" class="gt_row gt_right">1.39</td>
143+
</tr>
144+
<tr>
145+
<td class="gt_row gt_right">444.4</td>
146+
<td class="gt_row gt_left">durian</td>
147+
<td style="background-color: #0000FF;" class="gt_row gt_right">65100</td>
148+
</tr>
149+
</tbody>
150+
'''
151+
# ---
126152
# name: test_data_color_colorbrewer_snap
127153
'''
128154
<tbody class="gt_table_body">
@@ -206,6 +232,32 @@
206232
</tbody>
207233
'''
208234
# ---
235+
# name: test_data_color_domain_na_color_reverse_snap[pyarrow]
236+
'''
237+
<tbody class="gt_table_body">
238+
<tr>
239+
<td class="gt_row gt_right">0.1111</td>
240+
<td class="gt_row gt_left">apricot</td>
241+
<td style="color: #000000; background-color: #ff0000;" class="gt_row gt_right">49.95</td>
242+
</tr>
243+
<tr>
244+
<td class="gt_row gt_right">2.222</td>
245+
<td class="gt_row gt_left">banana</td>
246+
<td style="color: #FFFFFF; background-color: #5c5200;" class="gt_row gt_right">17.95</td>
247+
</tr>
248+
<tr>
249+
<td class="gt_row gt_right">33.33</td>
250+
<td class="gt_row gt_left">coconut</td>
251+
<td style="color: #FFFFFF; background-color: #077c00;" class="gt_row gt_right">1.39</td>
252+
</tr>
253+
<tr>
254+
<td class="gt_row gt_right">444.4</td>
255+
<td class="gt_row gt_left">durian</td>
256+
<td style="color: #FFFFFF; background-color: #0000FF;" class="gt_row gt_right">65100</td>
257+
</tr>
258+
</tbody>
259+
'''
260+
# ---
209261
# name: test_data_color_domain_na_color_snap[pandas]
210262
'''
211263
<tbody class="gt_table_body">
@@ -258,6 +310,32 @@
258310
</tbody>
259311
'''
260312
# ---
313+
# name: test_data_color_domain_na_color_snap[pyarrow]
314+
'''
315+
<tbody class="gt_table_body">
316+
<tr>
317+
<td class="gt_row gt_right">0.1111</td>
318+
<td class="gt_row gt_left">apricot</td>
319+
<td style="color: #FFFFFF; background-color: #008000;" class="gt_row gt_right">49.95</td>
320+
</tr>
321+
<tr>
322+
<td class="gt_row gt_right">2.222</td>
323+
<td class="gt_row gt_left">banana</td>
324+
<td style="color: #FFFFFF; background-color: #a32e00;" class="gt_row gt_right">17.95</td>
325+
</tr>
326+
<tr>
327+
<td class="gt_row gt_right">33.33</td>
328+
<td class="gt_row gt_left">coconut</td>
329+
<td style="color: #000000; background-color: #f80400;" class="gt_row gt_right">1.39</td>
330+
</tr>
331+
<tr>
332+
<td class="gt_row gt_right">444.4</td>
333+
<td class="gt_row gt_left">durian</td>
334+
<td style="color: #FFFFFF; background-color: #0000FF;" class="gt_row gt_right">65100</td>
335+
</tr>
336+
</tbody>
337+
'''
338+
# ---
261339
# name: test_data_color_overlapping_domain[pandas]
262340
'''
263341
<tbody class="gt_table_body">
@@ -310,6 +388,32 @@
310388
</tbody>
311389
'''
312390
# ---
391+
# name: test_data_color_overlapping_domain[pyarrow]
392+
'''
393+
<tbody class="gt_table_body">
394+
<tr>
395+
<td class="gt_row gt_right">0.1111</td>
396+
<td class="gt_row gt_left">apricot</td>
397+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">49.95</td>
398+
</tr>
399+
<tr>
400+
<td class="gt_row gt_right">2.222</td>
401+
<td class="gt_row gt_left">banana</td>
402+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">17.95</td>
403+
</tr>
404+
<tr>
405+
<td class="gt_row gt_right">33.33</td>
406+
<td class="gt_row gt_left">coconut</td>
407+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">1.39</td>
408+
</tr>
409+
<tr>
410+
<td class="gt_row gt_right">444.4</td>
411+
<td class="gt_row gt_left">durian</td>
412+
<td style="color: #FFFFFF; background-color: #673498;" class="gt_row gt_right">65100</td>
413+
</tr>
414+
</tbody>
415+
'''
416+
# ---
313417
# name: test_data_color_palette_snap[pandas]
314418
'''
315419
<tbody class="gt_table_body">
@@ -362,6 +466,32 @@
362466
</tbody>
363467
'''
364468
# ---
469+
# name: test_data_color_palette_snap[pyarrow]
470+
'''
471+
<tbody class="gt_table_body">
472+
<tr>
473+
<td style="color: #000000; background-color: #ff0000;" class="gt_row gt_right">0.1111</td>
474+
<td class="gt_row gt_left">apricot</td>
475+
<td style="color: #000000; background-color: #ff0000;" class="gt_row gt_right">49.95</td>
476+
</tr>
477+
<tr>
478+
<td style="color: #000000; background-color: #fe0100;" class="gt_row gt_right">2.222</td>
479+
<td class="gt_row gt_left">banana</td>
480+
<td style="color: #000000; background-color: #ff0000;" class="gt_row gt_right">17.95</td>
481+
</tr>
482+
<tr>
483+
<td style="color: #000000; background-color: #ec0a00;" class="gt_row gt_right">33.33</td>
484+
<td class="gt_row gt_left">coconut</td>
485+
<td style="color: #000000; background-color: #ff0000;" class="gt_row gt_right">1.39</td>
486+
</tr>
487+
<tr>
488+
<td style="color: #FFFFFF; background-color: #008000;" class="gt_row gt_right">444.4</td>
489+
<td class="gt_row gt_left">durian</td>
490+
<td style="color: #FFFFFF; background-color: #008000;" class="gt_row gt_right">65100</td>
491+
</tr>
492+
</tbody>
493+
'''
494+
# ---
365495
# name: test_data_color_pd_cols_rows_snap
366496
'''
367497
<tbody class="gt_table_body">
@@ -495,6 +625,32 @@
495625
</tbody>
496626
'''
497627
# ---
628+
# name: test_data_color_simple_exibble_snap[pyarrow]
629+
'''
630+
<tbody class="gt_table_body">
631+
<tr>
632+
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_right">0.1111</td>
633+
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_left">apricot</td>
634+
<td style="color: #FFFFFF; background-color: #010001;" class="gt_row gt_right">49.95</td>
635+
</tr>
636+
<tr>
637+
<td style="color: #FFFFFF; background-color: #070304;" class="gt_row gt_right">2.222</td>
638+
<td style="color: #000000; background-color: #4cbd81;" class="gt_row gt_left">banana</td>
639+
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_right">17.95</td>
640+
</tr>
641+
<tr>
642+
<td style="color: #FFFFFF; background-color: #752b38;" class="gt_row gt_right">33.33</td>
643+
<td style="color: #FFFFFF; background-color: #9653ca;" class="gt_row gt_left">coconut</td>
644+
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_right">1.39</td>
645+
</tr>
646+
<tr>
647+
<td style="color: #000000; background-color: #9e9e9e;" class="gt_row gt_right">444.4</td>
648+
<td style="color: #000000; background-color: #9e9e9e;" class="gt_row gt_left">durian</td>
649+
<td style="color: #000000; background-color: #9e9e9e;" class="gt_row gt_right">65100</td>
650+
</tr>
651+
</tbody>
652+
'''
653+
# ---
498654
# name: test_data_color_subset_domain[pandas]
499655
'''
500656
<tbody class="gt_table_body">
@@ -547,6 +703,32 @@
547703
</tbody>
548704
'''
549705
# ---
706+
# name: test_data_color_subset_domain[pyarrow]
707+
'''
708+
<tbody class="gt_table_body">
709+
<tr>
710+
<td class="gt_row gt_right">0.1111</td>
711+
<td class="gt_row gt_left">apricot</td>
712+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">49.95</td>
713+
</tr>
714+
<tr>
715+
<td class="gt_row gt_right">2.222</td>
716+
<td class="gt_row gt_left">banana</td>
717+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">17.95</td>
718+
</tr>
719+
<tr>
720+
<td class="gt_row gt_right">33.33</td>
721+
<td class="gt_row gt_left">coconut</td>
722+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">1.39</td>
723+
</tr>
724+
<tr>
725+
<td class="gt_row gt_right">444.4</td>
726+
<td class="gt_row gt_left">durian</td>
727+
<td style="color: #000000; background-color: #FF0000;" class="gt_row gt_right">65100</td>
728+
</tr>
729+
</tbody>
730+
'''
731+
# ---
550732
# name: test_data_color_viridis_snap
551733
'''
552734
<tbody class="gt_table_body">

tests/data_color/test_data_color.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pandas as pd
55
import polars as pl
6+
import pyarrow as pa
67
import pytest
78

89
from great_tables import GT, style
@@ -16,6 +17,7 @@
1617
params_frames = [
1718
pytest.param(pd.DataFrame, id="pandas"),
1819
pytest.param(pl.DataFrame, id="polars"),
20+
pytest.param(pa.table, id="pyarrow"),
1921
]
2022

2123

0 commit comments

Comments
 (0)