diff --git a/lib/graphql/dataloader/active_record_association_source.rb b/lib/graphql/dataloader/active_record_association_source.rb index 9f87793e6fc..6a3745b46e8 100644 --- a/lib/graphql/dataloader/active_record_association_source.rb +++ b/lib/graphql/dataloader/active_record_association_source.rb @@ -12,6 +12,14 @@ def initialize(association, scope = nil) @scope = scope end + def self.batch_key_for(association, scope = nil) + if scope + [association, scope.to_sql] + else + [association] + end + end + def load(record) if (assoc = record.association(@association)).loaded? assoc.target @@ -41,17 +49,17 @@ def fetch(records) ::ActiveRecord::Associations::Preloader.new(records: records, associations: @association, available_records: available_records, scope: @scope).call loaded_associated_records = records.map { |r| r.public_send(@association) } - records_by_model = {} - loaded_associated_records.each do |record| - if record - updates = records_by_model[record.class] ||= {} - updates[record.id] = record - end - end if @scope.nil? # Don't cache records loaded via scope because they might have reduced `SELECT`s # Could check .select_values here? + records_by_model = {} + loaded_associated_records.flatten.each do |record| + if record + updates = records_by_model[record.class] ||= {} + updates[record.id] = record + end + end records_by_model.each do |model_class, updates| dataloader.with(RECORD_SOURCE_CLASS, model_class).merge(updates) end diff --git a/spec/graphql/dataloader/active_record_association_source_spec.rb b/spec/graphql/dataloader/active_record_association_source_spec.rb index 1ee764702a6..cdf761183ba 100644 --- a/spec/graphql/dataloader/active_record_association_source_spec.rb +++ b/spec/graphql/dataloader/active_record_association_source_spec.rb @@ -106,6 +106,56 @@ assert_equal ::Band.find(1), vulfpeck end + it_dataloads "works with collection associations" do |d| + wilco = ::Band.find(4) + chon = ::Band.find(3) + albums_by_band = nil + log = with_active_record_log(colorize: false) do + albums_by_band = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :albums).load_all([wilco, chon]) + end + + assert_equal [[6], [4, 5]], albums_by_band.map { |al| al.map(&:id) } + assert_includes log, 'SELECT "albums".* FROM "albums" WHERE "albums"."band_id" IN (?, ?) [["band_id", 4], ["band_id", 3]]' + + albums = nil + log = with_active_record_log(colorize: false) do + albums = d.with(GraphQL::Dataloader::ActiveRecordSource, Album).load_all([3,4,5,6]) + end + + assert_equal [3,4,5,6], albums.map(&:id) + assert_includes log, 'WHERE "albums"."id" = ? [["id", 3]]' + end + + it_dataloads "works with collection associations with scope" do |d| + wilco = ::Band.find(4) + chon = ::Band.find(3) + albums_by_band = nil + one_month_ago = nil + log = with_active_record_log(colorize: false) do + one_month_ago = 1.month.ago.end_of_day + albums_by_band_1 = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :albums, Album.where("created_at >= ?", one_month_ago)).request(wilco) + albums_by_band_2 = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :albums, Album.where("created_at >= ?", one_month_ago)).request(chon) + albums_by_band = [albums_by_band_1.load, albums_by_band_2.load] + end + + assert_equal [[6], [4, 5]], albums_by_band.map { |al| al.map(&:id) } + expected_log = if Rails::VERSION::STRING > "8" + 'SELECT "albums".* FROM "albums" WHERE (created_at >= ?) AND "albums"."band_id" IN (?, ?)' + else + 'SELECT "albums".* FROM "albums" WHERE (created_at >= ' + one_month_ago.utc.strftime("'%Y-%m-%d %H:%M:%S.%6N'") + ') AND "albums"."band_id" IN (?, ?)' + end + + assert_includes log, expected_log + + albums = nil + log = with_active_record_log(colorize: false) do + albums = d.with(GraphQL::Dataloader::ActiveRecordSource, Album).load_all([3,4,5,6]) + end + + assert_equal [3,4,5,6], albums.map(&:id) + assert_includes log, 'WHERE "albums"."id" IN (?, ?, ?, ?) [["id", 3], ["id", 4], ["id", 5], ["id", 6]]' + end + if Rails::VERSION::STRING > "7.1" # not supported in <7.1 it_dataloads "loads with composite primary keys and warms the cache" do |d| my_first_car = ::Album.find(2) diff --git a/spec/support/active_record_setup.rb b/spec/support/active_record_setup.rb index 1ddd752f843..db8ff54d673 100644 --- a/spec/support/active_record_setup.rb +++ b/spec/support/active_record_setup.rb @@ -70,6 +70,7 @@ t.integer :band_id t.string :band_name t.integer :band_genre + t.timestamps end create_table :books do |t|