Skip to content

Commit 49423f8

Browse files
authored
GH-33459: [C++][Python] Support step >= 1 in list_slice kernel (#48769)
### Rationale for this change Closes ARROW-18281, which has been open since 2022. The `list_slice` kernel currently rejects `start == stop`, but should return empty lists instead (following Python slicing semantics). The implementation already handles this case correctly. When ARROW-18282 added step support, `bit_util::CeilDiv(stop - start, step)` naturally returns 0 for `start == stop`, producing empty lists. The only issue was the validation check (`start >= stop`) that prevented this from working. ### What changes are included in this PR? - Changed validation from `start >= stop` to `start > stop` - Updated error message - Added test cases ### Are these changes tested? Yes, tests were added. ### Are there any user-facing changes? Yes. ```python import pyarrow.compute as pc pc.list_slice([[1,2,3]], 0, 0) ``` Before: ``` pyarrow.lib.ArrowInvalid: `start`(0) should be greater than 0 and smaller than `stop`(0) ``` After: ``` <pyarrow.lib.ListArray object at 0x1a01b8b20> [ [] ] ``` * GitHub Issue: #33459 Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: AlenkaF <frim.alenka@gmail.com>
1 parent 6a2d09b commit 49423f8

3 files changed

Lines changed: 29 additions & 26 deletions

File tree

cpp/src/arrow/compute/kernels/scalar_nested.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ Result<TypeHolder> ListSliceOutputType(const ListSliceOptions& opts,
162162
"`stop` being set.");
163163
}
164164
if (opts.step < 1) {
165-
return Status::Invalid("`step` must be >= 1, got: ", opts.step);
165+
return Status::Invalid("`step` must be greater than or equal to 1, got: ",
166+
opts.step);
166167
}
167168
const auto length = ListSliceLength(opts.start, opts.step, *stop);
168169
return fixed_size_list(value_type, static_cast<int32_t>(length));
@@ -183,14 +184,15 @@ struct ListSlice {
183184
const auto* list_type = checked_cast<const BaseListType*>(list_array.type);
184185

185186
// Pre-conditions
186-
if (opts.start < 0 || (opts.stop.has_value() && opts.start >= opts.stop.value())) {
187-
// TODO(ARROW-18281): support start == stop which should give empty lists
188-
return Status::Invalid("`start`(", opts.start,
189-
") should be greater than 0 and smaller than `stop`(",
190-
ToString(opts.stop), ")");
187+
if (opts.start < 0 || (opts.stop.has_value() && opts.start > opts.stop.value())) {
188+
return Status::Invalid(
189+
"`start`(", opts.start,
190+
") should be greater than or equal to 0 and not greater than `stop`(",
191+
ToString(opts.stop), ")");
191192
}
192193
if (opts.step < 1) {
193-
return Status::Invalid("`step` must be >= 1, got: ", opts.step);
194+
return Status::Invalid("`step` must be greater than or equal to 1, got: ",
195+
opts.step);
194196
}
195197

196198
auto* pool = ctx->memory_pool();

cpp/src/arrow/compute/kernels/scalar_nested_test.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ TEST(TestScalarNested, ListSliceVariableOutput) {
176176
auto input = ArrayFromJSON(fixed_size_list(int32(), 1), "[[1]]");
177177
auto expected = ArrayFromJSON(list(int32()), "[[1]]");
178178
CheckScalarUnary("list_slice", input, expected, &args);
179+
180+
args.start = 0;
181+
args.stop = 0;
182+
auto input_empty = ArrayFromJSON(list(int32()), "[[1, 2, 3], [4, 5], null]");
183+
auto expected_empty = ArrayFromJSON(list(int32()), "[[], [], null]");
184+
CheckScalarUnary("list_slice", input_empty, expected_empty, &args);
179185
}
180186

181187
TEST(TestScalarNested, ListSliceFixedOutput) {
@@ -315,22 +321,17 @@ TEST(TestScalarNested, ListSliceBadParameters) {
315321
EXPECT_RAISES_WITH_MESSAGE_THAT(
316322
Invalid,
317323
::testing::HasSubstr(
318-
"`start`(-1) should be greater than 0 and smaller than `stop`(1)"),
324+
"`start`(-1) should be greater than or equal to 0 and not greater than "
325+
"`stop`(1)"),
319326
CallFunction("list_slice", {input}, &args));
320327
// start greater than stop
321328
args.start = 1;
322329
args.stop = 0;
323330
EXPECT_RAISES_WITH_MESSAGE_THAT(
324331
Invalid,
325332
::testing::HasSubstr(
326-
"`start`(1) should be greater than 0 and smaller than `stop`(0)"),
327-
CallFunction("list_slice", {input}, &args));
328-
// start same as stop
329-
args.stop = args.start;
330-
EXPECT_RAISES_WITH_MESSAGE_THAT(
331-
Invalid,
332-
::testing::HasSubstr(
333-
"`start`(1) should be greater than 0 and smaller than `stop`(1)"),
333+
"`start`(1) should be greater than or equal to 0 and not greater than "
334+
"`stop`(0)"),
334335
CallFunction("list_slice", {input}, &args));
335336
// stop not set and FixedSizeList requested with variable sized input
336337
args.stop = std::nullopt;
@@ -343,9 +344,9 @@ TEST(TestScalarNested, ListSliceBadParameters) {
343344
args.start = 0;
344345
args.stop = 2;
345346
args.step = 0;
346-
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
347-
::testing::HasSubstr("`step` must be >= 1, got: 0"),
348-
CallFunction("list_slice", {input}, &args));
347+
EXPECT_RAISES_WITH_MESSAGE_THAT(
348+
Invalid, ::testing::HasSubstr("`step` must be greater than or equal to 1, got: 0"),
349+
CallFunction("list_slice", {input}, &args));
349350
}
350351

351352
TEST(TestScalarNested, StructField) {

python/pyarrow/tests/test_compute.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3930,7 +3930,8 @@ def test_list_slice_output_fixed(start, stop, step, expected, value_type,
39303930
(0, 1,),
39313931
(0, 2,),
39323932
(1, 2,),
3933-
(2, 4,)
3933+
(2, 4,),
3934+
(0, 0,)
39343935
))
39353936
@pytest.mark.parametrize("step", (1, 2))
39363937
@pytest.mark.parametrize("value_type", (pa.string, pa.int16, pa.float64))
@@ -3978,18 +3979,17 @@ def test_list_slice_field_names_retained(return_fixed_size, type):
39783979

39793980
def test_list_slice_bad_parameters():
39803981
arr = pa.array([[1]], pa.list_(pa.int8(), 1))
3981-
msg = r"`start`(.*) should be greater than 0 and smaller than `stop`(.*)"
3982+
msg = (
3983+
r"`start`(.*) should be greater than or equal to 0 "
3984+
r"and not greater than `stop`(.*)"
3985+
)
39823986
with pytest.raises(pa.ArrowInvalid, match=msg):
39833987
pc.list_slice(arr, -1, 1) # negative start?
39843988
with pytest.raises(pa.ArrowInvalid, match=msg):
39853989
pc.list_slice(arr, 2, 1) # start > stop?
39863990

3987-
# TODO(ARROW-18281): start==stop -> empty lists
3988-
with pytest.raises(pa.ArrowInvalid, match=msg):
3989-
pc.list_slice(arr, 0, 0) # start == stop?
3990-
39913991
# Step not >= 1
3992-
msg = "`step` must be >= 1, got: "
3992+
msg = "`step` must be greater than or equal to 1, got: "
39933993
with pytest.raises(pa.ArrowInvalid, match=msg + "0"):
39943994
pc.list_slice(arr, 0, 1, step=0)
39953995
with pytest.raises(pa.ArrowInvalid, match=msg + "-1"):

0 commit comments

Comments
 (0)