Skip to content

Commit cabe1a9

Browse files
committed
More WIP
1 parent 62b2b20 commit cabe1a9

3 files changed

Lines changed: 53 additions & 17 deletions

File tree

spec/avram/operations/save_operation_spec.cr

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,16 +310,15 @@ describe "Avram::SaveOperation" do
310310
describe ".bulk_upsert" do
311311
context "when the records are not persisted" do
312312
it "should create the records" do
313-
record_args = (1..50).to_a.map do |i|
313+
record_args = (1..2).to_a.map do |i|
314314
{
315315
:name => "Test #{i}",
316316
:nickname => "Test Nickname #{i}",
317317
}.as(Avram::BulkUpsert::Params)
318318
end
319+
records = UserFactory.new.build_pair
319320

320-
result = SaveUser.bulk_upsert(record_args)
321-
322-
pp result
321+
SaveUser.bulk_upsert(record_args)
323322
end
324323
end
325324

src/avram/bulk_upsert.cr

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,57 @@
1-
class Avram::BulkUpsert
1+
class Avram::BulkUpsert(T)
22
alias Params = Hash(Symbol, String) | Hash(Symbol, String?) | Hash(Symbol, Nil)
33

4-
def initialize(@table : TableName,
5-
@records : Array(Params),
6-
@column_names : Array(Symbol) = [] of Symbol)
4+
def initialize(@records : Array(Params))
75
end
86

97
def statement
10-
"insert into #{@table}(#{fields}) values(#{values}) returning *"
8+
[
9+
"insert into #{table}(#{fields})",
10+
"values #{value_placeholders}",
11+
"ON CONFLICT DO UPDATE SET #{updates}",
12+
"returning #{returning}",
13+
].join(" ")
14+
end
15+
16+
private def table
17+
T.table_name
18+
end
19+
20+
private def updates
21+
conflict_updates = T.column_names.uniq.map do |column|
22+
"SET #{column}=EXCLUDED.#{column}"
23+
end
24+
25+
if T.column_names.includes?(:updated_at)
26+
conflict_updates.push("SET updated_at=NOW()").join(", ")
27+
else
28+
conflict_updates.join(", ")
29+
end
30+
end
31+
32+
private def returning
33+
"id"
1134
end
1235

1336
private def fields
14-
@column_names.join(", ")
37+
T.column_names.join(", ")
38+
end
39+
40+
def args
41+
@records.map &.values
1542
end
1643

17-
private def record_values(record)
44+
private def placeholder_values(record)
1845
values = record.values.map_with_index(1) do |_value, index|
1946
"$#{index}"
2047
end.join(", ")
2148

2249
"(#{values})"
2350
end
2451

25-
private def values
52+
private def value_placeholders
2653
@records.map do |record|
27-
record_values(record)
54+
placeholder_values(record)
2855
end.join(", ")
2956
end
3057
end

src/avram/upsert.cr

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,20 @@ module Avram::Upsert
101101
end
102102
{% end %}
103103

104-
def self.bulk_upsert(params : Array(Avram::BulkUpsert::Params))
105-
Avram::BulkUpsert
106-
.new(T.table_name, params, T.column_names)
107-
.statement
104+
def self.bulk_upsert(upserts : Array(Avram::BulkUpsert::Params))
105+
upsert_keys = upserts.map(&.keys).uniq
106+
107+
if upsert_keys.size > 1
108+
raise "All hashes passed to bulk_upsert must have the same keys."
109+
elsif upsert_keys.flatten.any? { |key| !T.column_names.includes?(key) }
110+
raise "All keys in hashes must be column names in the table."
111+
end
112+
113+
upsert = Avram::BulkUpsert(T).new(upserts)
114+
pp upsert.statement
115+
pp upsert.args
116+
117+
T.database.query(upsert.statement, args: upsert.args)
108118
end
109119
end
110120
end

0 commit comments

Comments
 (0)