Skip to content

Commit fdca7c2

Browse files
committed
Solidify API for Bulk Upsert calls
1 parent cabe1a9 commit fdca7c2

4 files changed

Lines changed: 71 additions & 47 deletions

File tree

spec/avram/operations/save_operation_spec.cr

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,13 +312,12 @@ describe "Avram::SaveOperation" do
312312
it "should create the records" do
313313
record_args = (1..2).to_a.map do |i|
314314
{
315-
:name => "Test #{i}",
316-
:nickname => "Test Nickname #{i}",
317-
}.as(Avram::BulkUpsert::Params)
315+
name: "Test #{i}",
316+
nickname: "Test Nickname #{i}",
317+
}
318318
end
319-
records = UserFactory.new.build_pair
320319

321-
SaveUser.bulk_upsert(record_args)
320+
UpsertUserOperation.bulk_upsert(record_args)
322321
end
323322
end
324323

src/avram/bulk_upsert.cr

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,70 @@
11
class Avram::BulkUpsert(T)
2-
alias Params = Hash(Symbol, String) | Hash(Symbol, String?) | Hash(Symbol, Nil)
3-
4-
def initialize(@records : Array(Params))
2+
def initialize(@records : Array(T), @column_names : Array(Symbol))
3+
@records = set_timestamps(records)
54
end
65

76
def statement
87
[
9-
"insert into #{table}(#{fields})",
10-
"values #{value_placeholders}",
11-
"ON CONFLICT DO UPDATE SET #{updates}",
12-
"returning #{returning}",
8+
"INSERT INTO #{table}(#{fields})",
9+
"VALUES #{value_placeholders}",
10+
"ON CONFLICT (#{conflicts}) DO UPDATE SET #{updates}",
11+
"RETURNING #{returning}",
1312
].join(" ")
1413
end
1514

15+
private def conflicts
16+
@column_names.join(", ")
17+
end
18+
19+
private def set_timestamps(collection)
20+
collection.map do |record|
21+
record.created_at.value ||= Time.utc if record.responds_to?(:created_at)
22+
record.updated_at.value = Time.utc if record.responds_to?(:updated_at)
23+
record
24+
end
25+
end
26+
1627
private def table
17-
T.table_name
28+
@records.first.table_name
1829
end
1930

2031
private def updates
21-
conflict_updates = T.column_names.uniq.map do |column|
22-
"SET #{column}=EXCLUDED.#{column}"
23-
end
32+
update_keys = @records.first.insert_values.keys
2433

25-
if T.column_names.includes?(:updated_at)
26-
conflict_updates.push("SET updated_at=NOW()").join(", ")
27-
else
28-
conflict_updates.join(", ")
29-
end
34+
(update_keys - [:created_at]).map do |column|
35+
"#{column}=EXCLUDED.#{column}"
36+
end.join(", ")
3037
end
3138

3239
private def returning
33-
"id"
40+
T.column_names.join(", ")
3441
end
3542

3643
private def fields
37-
T.column_names.join(", ")
44+
@records.first.insert_values.keys.map do |key|
45+
<<-TEXT
46+
"#{key}"
47+
TEXT
48+
end.join(", ")
3849
end
3950

4051
def args
41-
@records.map &.values
52+
@records.flat_map do |record|
53+
record.insert_values.values
54+
end
4255
end
4356

44-
private def placeholder_values(record)
45-
values = record.values.map_with_index(1) do |_value, index|
57+
private def value_placeholders(record)
58+
record.insert_values.map_with_index(1) do |_value, index|
4659
"$#{index}"
4760
end.join(", ")
48-
49-
"(#{values})"
5061
end
5162

5263
private def value_placeholders
64+
i = 0
65+
5366
@records.map do |record|
54-
placeholder_values(record)
67+
"($#{i += 1}, $#{i += 1}, $#{i += 1}, $#{i += 1})"
5568
end.join(", ")
5669
end
5770
end

src/avram/save_operation.cr

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ abstract class Avram::SaveOperation(T)
367367
{{ T.constant(:PRIMARY_KEY_NAME).id }}.value.nil?
368368
end
369369

370+
def insert_values
371+
attributes_to_hash(column_attributes).compact
372+
end
373+
370374
private def insert_or_update
371375
if persisted?
372376
update record_id
@@ -379,6 +383,10 @@ abstract class Avram::SaveOperation(T)
379383
@record.try &.id
380384
end
381385

386+
def self.column_names
387+
T.column_names
388+
end
389+
382390
def before_save; end
383391

384392
def after_save(_record : T); end
@@ -408,7 +416,6 @@ abstract class Avram::SaveOperation(T)
408416
end
409417

410418
private def insert_sql
411-
insert_values = attributes_to_hash(column_attributes).compact
412419
Avram::Insert.new(table_name, insert_values, T.column_names)
413420
end
414421

src/avram/upsert.cr

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,27 @@ module Avram::Upsert
8181
yield operation, operation.record
8282
end
8383

84+
def self.bulk_upsert(upserts)
85+
operations = upserts.map do |upsert_args|
86+
new(**upsert_args)
87+
end
88+
89+
upsert = Avram::BulkUpsert(self).new(
90+
operations,
91+
{{ attribute_names }}.to_a
92+
)
93+
94+
pp upsert.statement
95+
pp upsert.args
96+
records = [] of T
97+
98+
new.database.query upsert.statement, args: upsert.args do |rs|
99+
records << T.from_rs(rs).first
100+
end
101+
102+
pp records
103+
end
104+
84105
def self.find_existing_unique_record(operation) : T?
85106
T::BaseQuery.new
86107
{% for attribute in attribute_names %}
@@ -92,29 +113,13 @@ module Avram::Upsert
92113

93114
# :nodoc:
94115
macro included
95-
{% for method in ["upsert", "upsert!"] %}
116+
{% for method in ["upsert", "upsert!", "bulk_upsert"] %}
96117
# Performs a create or update depending on if there is a conflicting row in the database.
97118
#
98119
# See `Avram::Upsert.upsert_lookup_columns` for full documentation and examples.
99120
def self.{{ method.id }}(*args, **named_args)
100121
\{% raise "Please use the 'upsert_lookup_columns' macro in #{@type} before using '{{ method.id }}'" %}
101122
end
102123
{% end %}
103-
104-
def self.bulk_upsert(upserts : Array(Avram::BulkUpsert::Params))
105-
upsert_keys = upserts.map(&.keys).uniq
106-
107-
if upsert_keys.size > 1
108-
raise "All hashes passed to bulk_upsert must have the same keys."
109-
elsif upsert_keys.flatten.any? { |key| !T.column_names.includes?(key) }
110-
raise "All keys in hashes must be column names in the table."
111-
end
112-
113-
upsert = Avram::BulkUpsert(T).new(upserts)
114-
pp upsert.statement
115-
pp upsert.args
116-
117-
T.database.query(upsert.statement, args: upsert.args)
118-
end
119124
end
120125
end

0 commit comments

Comments
 (0)