Skip to content

Commit 607be64

Browse files
authored
apacheGH-46939: [C++] Add support for shared memory comparison in arrow::RecordBatch (apache#47149)
### Rationale for this change
Create a fast path for comparing `arrow::RecordBatch` instances that share the same memory.
### What changes are included in this PR?
Enable fast comparison for `arrow::RecordBatch` objects backed by the same memory.
### Are these changes tested?
Yes, I ran the relevant unit tests.
### Are there any user-facing changes?
No.
* GitHub Issue: apache#46939
Authored-by: Arash Andishgar <arashandishgar1@gmail.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
1 parent 39f3722 commit 607be64

2 files changed

Lines changed: 144 additions & 13 deletions

File tree

cpp/src/arrow/record_batch.cc

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -310,19 +310,58 @@ const std::string& RecordBatch::column_name(int i) const {
310310
return schema_->field(i)->name();
311311
}
312312

313-
bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata,
314-
const EqualOptions& opts) const {
315-
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
316-
return false;
313+
namespace {
314+
315+
// Returns true if `type` is itself a floating-point type (per is_floating on
// its type id) or if any type reachable through type->fields() is, checked
// recursively; returns false otherwise.
// NOTE(review): only children exposed via fields() are visited. If some types
// (e.g. dictionary-encoded or extension types) keep a float value/storage type
// outside fields(), they would not be detected here -- TODO confirm.
// NOTE: in this diff rendering, removed lines of the previous
// RecordBatch::Equals body ("-" markers) are interleaved with this function.
bool ContainFloatType(const std::shared_ptr<DataType>& type) {
316+
if (is_floating(type->id())) {
317+
return true;
317318
}
318319

319-
if (!schema_->Equals(*other.schema(), check_metadata)) {
320-
return false;
320+
for (const auto& field : type->fields()) {
321+
if (ContainFloatType(field->type())) {
322+
return true;
323+
}
321324
}
322325

323-
if (device_type() != other.device_type()) {
326+
return false;
327+
}
328+
329+
bool ContainFloatType(const Schema& schema) {
330+
for (auto& field : schema.fields()) {
331+
if (ContainFloatType(field->type())) {
332+
return true;
333+
}
334+
}
335+
return false;
336+
}
337+
338+
bool CanIgnoreNaNInEquality(const RecordBatch& batch, const EqualOptions& opts) {
339+
if (opts.nans_equal()) {
340+
return true;
341+
} else if (!ContainFloatType(*batch.schema())) {
342+
return true;
343+
} else {
324344
return false;
325345
}
346+
}
347+
348+
} // namespace
349+
350+
bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata,
351+
const EqualOptions& opts) const {
352+
if (this == &other) {
353+
if (CanIgnoreNaNInEquality(*this, opts)) {
354+
return true;
355+
}
356+
} else {
357+
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
358+
return false;
359+
} else if (!schema_->Equals(*other.schema(), check_metadata)) {
360+
return false;
361+
} else if (device_type() != other.device_type()) {
362+
return false;
363+
}
364+
}
326365

327366
for (int i = 0; i < num_columns(); ++i) {
328367
if (!column(i)->Equals(other.column(i), opts)) {
@@ -334,12 +373,16 @@ bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata,
334373
}
335374

336375
bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opts) const {
337-
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
338-
return false;
339-
}
340-
341-
if (device_type() != other.device_type()) {
342-
return false;
376+
if (this == &other) {
377+
if (CanIgnoreNaNInEquality(*this, opts)) {
378+
return true;
379+
}
380+
} else {
381+
if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
382+
return false;
383+
} else if (device_type() != other.device_type()) {
384+
return false;
385+
}
343386
}
344387

345388
for (int i = 0; i < num_columns(); ++i) {

cpp/src/arrow/record_batch_test.cc

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,94 @@ TEST_F(TestRecordBatch, ApproxEqualOptions) {
142142
EXPECT_TRUE(b1->ApproxEquals(*b2, options));
143143
}
144144

145+
// Fixture for the tests that compare a RecordBatch against itself, i.e. where
// both operands of Equals/ApproxEquals are the same object.
class TestRecordBatchEqualsSameAddress : public TestRecordBatch {};
146+
147+
// Checks that a batch made only of integer columns compares equal to itself
// through Equals and ApproxEquals, both with default options and with NaNs
// treated as equal.
TEST_F(TestRecordBatchEqualsSameAddress, NonFloatType) {
  auto int32_field = field("f0", int32());
  auto int64_field = field("f1", int64());
  auto batch_schema = ::arrow::schema({int32_field, int64_field});

  auto int32_column = ArrayFromJSON(int32_field->type(), "[0, 1, 2]");
  auto int64_column = ArrayFromJSON(int64_field->type(), "[0, 1, 2]");

  auto batch = RecordBatch::Make(batch_schema, 3, {int32_column, int64_column});
  // Copying the handle only: both names refer to the same RecordBatch object.
  auto same_batch = batch;

  auto default_opts = EqualOptions::Defaults();
  auto nan_equal_opts = default_opts.nans_equal(true);

  ASSERT_TRUE(batch->Equals(*same_batch, /*check_metadata=*/true, default_opts));
  ASSERT_TRUE(batch->Equals(*same_batch, /*check_metadata=*/true, nan_equal_opts));

  ASSERT_TRUE(batch->ApproxEquals(*same_batch, default_opts));
  ASSERT_TRUE(batch->ApproxEquals(*same_batch, nan_equal_opts));
}
167+
168+
// Checks that a batch with an integer column and a struct column whose members
// are all integers compares equal to itself through Equals and ApproxEquals,
// both with default options and with NaNs treated as equal.
TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithoutFloatType) {
  auto int_field = field("f0", int32());
  auto struct_field = field("f1", struct_({{"f2", int64()}, {"f3", int8()}}));
  auto batch_schema = ::arrow::schema({int_field, struct_field});

  auto int_column = ArrayFromJSON(int_field->type(), "[0, 1, 2]");
  auto struct_column = ArrayFromJSON(
      struct_field->type(),
      R"([{"f2": 1, "f3": 4}, {"f2": 2, "f3": 5}, {"f2":3, "f3": 6}])");

  auto batch = RecordBatch::Make(batch_schema, 3, {int_column, struct_column});
  // Copying the handle only: both names refer to the same RecordBatch object.
  auto same_batch = batch;

  auto default_opts = EqualOptions::Defaults();
  auto nan_equal_opts = default_opts.nans_equal(true);

  ASSERT_TRUE(batch->Equals(*same_batch, /*check_metadata=*/true, default_opts));
  ASSERT_TRUE(batch->Equals(*same_batch, /*check_metadata=*/true, nan_equal_opts));

  ASSERT_TRUE(batch->ApproxEquals(*same_batch, default_opts));
  ASSERT_TRUE(batch->ApproxEquals(*same_batch, nan_equal_opts));
}
189+
190+
// Checks self-comparison of a batch that has a float64 column containing NaN:
// with default options the batch must NOT compare equal to itself, while with
// NaNs treated as equal it must. Both Equals and ApproxEquals are exercised.
TEST_F(TestRecordBatchEqualsSameAddress, FloatType) {
  auto f0 = field("f0", int32());
  auto f1 = field("f1", float64());

  auto schema = ::arrow::schema({f0, f1});

  auto a0 = ArrayFromJSON(f0->type(), "[0, 1, 2]");
  // Exactly three values (one of them NaN) so the column length agrees with
  // the other column and with the row count passed to RecordBatch::Make.
  auto a1 = ArrayFromJSON(f1->type(), "[0.0, 1.0, NaN]");

  auto b0 = RecordBatch::Make(schema, 3, {a0, a1});
  // Copying the handle only: b0 and b1 refer to the same RecordBatch object.
  auto b1 = b0;

  auto options = EqualOptions::Defaults();

  ASSERT_FALSE(b0->Equals(*b1, /*check_metadata=*/true, options));
  ASSERT_TRUE(b0->Equals(*b1, /*check_metadata=*/true, options.nans_equal(true)));

  ASSERT_FALSE(b0->ApproxEquals(*b1, options));
  ASSERT_TRUE(b0->ApproxEquals(*b1, options.nans_equal(true)));
}
210+
211+
// Checks self-comparison of a batch whose struct column has a float32 member
// containing NaN: with default options the batch must NOT compare equal to
// itself, while with NaNs treated as equal it must. Both Equals and
// ApproxEquals are exercised.
TEST_F(TestRecordBatchEqualsSameAddress, NestedTypesWithFloatType) {
  auto int_field = field("f0", int32());
  auto struct_field = field("f1", struct_({{"f2", int64()}, {"f3", float32()}}));
  auto batch_schema = ::arrow::schema({int_field, struct_field});

  auto int_column = ArrayFromJSON(int_field->type(), "[0, 1, 2]");
  auto struct_column = ArrayFromJSON(
      struct_field->type(),
      R"([{"f2": 1, "f3": 4.0}, {"f2": 2, "f3": 4.0}, {"f2":3, "f3": NaN}])");

  auto batch = RecordBatch::Make(batch_schema, 3, {int_column, struct_column});
  // Copying the handle only: both names refer to the same RecordBatch object.
  auto same_batch = batch;

  auto default_opts = EqualOptions::Defaults();
  auto nan_equal_opts = default_opts.nans_equal(true);

  ASSERT_FALSE(batch->Equals(*same_batch, /*check_metadata=*/true, default_opts));
  ASSERT_TRUE(batch->Equals(*same_batch, /*check_metadata=*/true, nan_equal_opts));

  ASSERT_FALSE(batch->ApproxEquals(*same_batch, default_opts));
  ASSERT_TRUE(batch->ApproxEquals(*same_batch, nan_equal_opts));
}
232+
145233
TEST_F(TestRecordBatch, Validate) {
146234
const int length = 10;
147235

0 commit comments

Comments
 (0)