Skip to content

Commit a3b70a8

Browse files
committed
Add each_record_batch
1 parent 77a2319 commit a3b70a8

2 files changed

Lines changed: 126 additions & 73 deletions

File tree

lib/arrow-activerecord/arrowable.rb

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,15 @@
33
module ArrowActiveRecord
44
module Arrowable
55
# Materializes the relation as a single Arrow::Table.
#
# The rows are gathered through #each_record_batch and every yielded
# record batch becomes part of the returned table.
#
# @param batch_size [Integer] forwarded to #each_record_batch
# @return [Arrow::Table] table made of all yielded record batches
def to_arrow(batch_size: 10000)
  record_batches = each_record_batch(batch_size:).to_a
  # A relation without rows yields no batches, so there is no first batch
  # to borrow a schema from (nil.schema would raise NoMethodError). Fall
  # back to the schema derived from the column metadata in that case.
  schema = record_batches.first&.schema || build_arrow_schema
  Arrow::Table.new(schema, record_batches)
end
89

9-
fields = []
10-
target_column_names.each do |name|
11-
name = name.to_s
12-
target_column = columns.find do |column|
13-
column.name == name
14-
end
15-
fields << {name: name, data_type: extract_arrow_data_type(target_column)}
16-
end
17-
schema = Arrow::Schema.new(fields)
10+
def each_record_batch(batch_size: 10000, &block)
11+
return to_enum(__method__, batch_size:) unless block_given?
1812

19-
record_batches = []
13+
schema = build_arrow_schema
14+
target_column_names = schema.fields.collect(&:name)
2015
record_batch_builder = Arrow::RecordBatchBuilder.new(schema)
2116
in_batches(of: batch_size).each do |relation|
2217
records = relation.pluck(*target_column_names)
@@ -27,12 +22,26 @@ def to_arrow(batch_size: 10000)
2722
else
2823
record_batch_builder.append(records)
2924
end
30-
record_batches << record_batch_builder.flush
25+
yield(record_batch_builder.flush)
3126
end
32-
Arrow::Table.new(schema, record_batches)
3327
end
3428

3529
private
30+
# Derives the Arrow schema describing this relation.
#
# The explicitly selected values serve as field names when there are any;
# otherwise every column name is used. Each field's data type comes from
# #extract_arrow_data_type, which receives the entry of +columns+ whose
# name matches (or nil when no column has that name).
#
# @return [Arrow::Schema]
def build_arrow_schema
  selected = select_values
  names = selected.empty? ? column_names : selected

  field_specs = names.map do |raw_name|
    field_name = raw_name.to_s
    matching_column = columns.find { |candidate| candidate.name == field_name }
    {name: field_name, data_type: extract_arrow_data_type(matching_column)}
  end

  Arrow::Schema.new(field_specs)
end
44+
3645
def extract_arrow_data_type(column)
3746
type = nil
3847
if column

test/arrowable_test.rb

Lines changed: 103 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,72 +11,116 @@ class Data < ActiveRecord::Base
1111
FileUtils.rm_rf(DB_PATH.dirname)
1212
end
1313

14+
# Fixture lifecycle: create the `data` table with two rows before each
# test and drop the table again afterwards.
setup do
  @date_value = Date.new(2018, 1, 10)
  @datetime_value = Time.iso8601("2018-01-10T18:05:01.1Z")
  # Largest and smallest values of a signed 64-bit integer.
  @max_bigint_value = 2 ** 63 - 1
  @min_bigint_value = -(2 ** 63)

  connection = ActiveRecord::Base.connection
  connection.create_table(:data) do |t|
    t.string :string_column
    t.date :date_column
    t.datetime :datetime_column
    t.boolean :boolean_column
    t.bigint :bigint_column
  end

  fixture_rows = [
    {
      string_column: "Hello",
      date_column: @date_value,
      datetime_column: @datetime_value,
      boolean_column: false,
      bigint_column: @max_bigint_value,
    },
    {
      string_column: "Hello2",
      date_column: @date_value + 1,
      datetime_column: @datetime_value + 1,
      boolean_column: true,
      bigint_column: @min_bigint_value,
    },
  ]
  fixture_rows.each { |attributes| Data.create(attributes) }
end

teardown do
  connection = ActiveRecord::Base.connection
  connection.drop_table(:data)
end
41+
1442
# #to_arrow: with batch_size 1, each of the two fixture rows is expected
# to show up as its own record batch in the resulting table.
sub_test_case("#to_arrow") do
  test "one column" do
    table = Data.all.select(:id).to_arrow(batch_size: 1)
    expected = [1, 2].map do |id|
      Arrow::RecordBatch.new(id: Arrow::Int32Array.new([id]))
    end
    assert_equal(expected, table.each_record_batch.to_a)
  end

  test "all columns" do
    table = Data.all.to_arrow(batch_size: 1)
    # One entry per fixture row, in column order.
    rows = [
      [1, "Hello", @date_value, @datetime_value, false, @max_bigint_value],
      [2, "Hello2", @date_value + 1, @datetime_value + 1, true, @min_bigint_value],
    ]
    expected = rows.map do |id, string, date, datetime, boolean, bigint|
      Arrow::RecordBatch.new(
        id: Arrow::Int32Array.new([id]),
        string_column: Arrow::StringArray.new([string]),
        date_column: Arrow::Date32Array.new([date]),
        datetime_column: Arrow::TimestampArray.new(:nano, [datetime]),
        boolean_column: Arrow::BooleanArray.new([boolean]),
        bigint_column: Arrow::Int64Array.new([bigint]),
      )
    end
    assert_equal(expected, table.each_record_batch.to_a)
  end
end
4183

42-
test "one" do
43-
arrow = Data.all.select(:id).to_arrow
44-
assert_equal(<<-RECORD_BATCH, arrow.each_record_batch.first.to_s)
45-
id: [
46-
1,
47-
2
48-
]
49-
RECORD_BATCH
84+
# #each_record_batch: called without a block and with batch_size 1, the
# returned enumerator is expected to produce one record batch per
# fixture row.
sub_test_case("#each_record_batch") do
  test "one column" do
    record_batches =
      Data.all.select(:id).each_record_batch(batch_size: 1).to_a
    expected = [1, 2].map do |id|
      Arrow::RecordBatch.new(id: Arrow::Int32Array.new([id]))
    end
    assert_equal(expected, record_batches)
  end

  test "all columns" do
    record_batches = Data.all.each_record_batch(batch_size: 1).to_a
    # One entry per fixture row, in column order.
    rows = [
      [1, "Hello", @date_value, @datetime_value, false, @max_bigint_value],
      [2, "Hello2", @date_value + 1, @datetime_value + 1, true, @min_bigint_value],
    ]
    expected = rows.map do |id, string, date, datetime, boolean, bigint|
      Arrow::RecordBatch.new(
        id: Arrow::Int32Array.new([id]),
        string_column: Arrow::StringArray.new([string]),
        date_column: Arrow::Date32Array.new([date]),
        datetime_column: Arrow::TimestampArray.new(:nano, [datetime]),
        boolean_column: Arrow::BooleanArray.new([boolean]),
        bigint_column: Arrow::Int64Array.new([bigint]),
      )
    end
    assert_equal(expected, record_batches)
  end
end
82126
end

0 commit comments

Comments
 (0)