diff --git a/lib/hex/application.ex b/lib/hex/application.ex index b982e7f4..18412b06 100644 --- a/lib/hex/application.ex +++ b/lib/hex/application.ex @@ -50,6 +50,7 @@ defmodule Hex.Application do [ Hex.Netrc.Cache, Hex.OAuth, + Hex.Repo, Hex.State, Hex.Server, {Hex.Parallel, [:hex_fetcher]} @@ -60,6 +61,7 @@ defmodule Hex.Application do [ Hex.Netrc.Cache, Hex.OAuth, + Hex.Repo, Hex.State, Hex.Server, {Hex.Parallel, [:hex_fetcher]}, diff --git a/lib/hex/once_cache.ex b/lib/hex/once_cache.ex index a8f71011..d82e0f63 100644 --- a/lib/hex/once_cache.ex +++ b/lib/hex/once_cache.ex @@ -1,10 +1,11 @@ defmodule Hex.OnceCache do @moduledoc """ - A cache that computes a value once on first access and caches it. + A cache that computes values at most once and caches them. - Uses Agent.get_and_update/2 to ensure only one process computes the value, - even when multiple processes access it concurrently. All other processes - will wait for the Agent call to complete and receive the computed value. + Supports both single-value caching via `fetch/3` and keyed caching via + `fetch_key/4`. Computations run in the caller's process, allowing concurrent + computations for different keys. Multiple callers requesting the same key + will wait for the first caller's computation to complete. ## Example @@ -26,10 +27,10 @@ defmodule Hex.OnceCache do # => :expensive_result (no "Computing..." output) """ - use Agent + use GenServer @doc """ - Starts a new OnceCache Agent. + Starts a new OnceCache. ## Options @@ -37,7 +38,7 @@ defmodule Hex.OnceCache do """ def start_link(opts) do name = Keyword.fetch!(opts, :name) - Agent.start_link(fn -> :not_cached end, name: name) + GenServer.start_link(__MODULE__, :ok, name: name) end @doc """ @@ -52,33 +53,141 @@ defmodule Hex.OnceCache do Use `:infinity` for operations that may take a long time (e.g., user interaction). """ def fetch(name, compute_fun, opts \\ []) do + fetch_key(name, :__single__, compute_fun, opts) + end + + @doc """ + Fetches a keyed cached value or computes it if not yet cached. + + Like `fetch/3`, but supports multiple independent cached values identified by key. + The compute function is only called once per key, even with concurrent access. + Computations for different keys run concurrently in their respective caller processes. + + Should not be mixed with `fetch/3` or `put/2` on the same cache. + """ + def fetch_key(name, key, compute_fun, opts \\ []) do timeout = Keyword.get(opts, :timeout, 5000) - Agent.get_and_update( - name, - fn - :not_cached -> - value = compute_fun.() - {value, {:cached, value}} + case GenServer.call(name, {:fetch, key}, timeout) do + {:ok, value} -> + value - {:cached, cached} -> - {cached, {:cached, cached}} - end, - timeout - ) + :compute -> + try do + value = compute_fun.() + :ok = GenServer.call(name, {:computed, key, value}, timeout) + value + catch + kind, reason -> + GenServer.cast(name, {:failed, key}) + :erlang.raise(kind, reason, __STACKTRACE__) + end + end end @doc """ Stores a value in the cache without computing it. """ def put(name, value) do - Agent.update(name, fn _ -> {:cached, value} end) + GenServer.call(name, {:put, :__single__, value}) end @doc """ Clears the cache. """ def clear(name) do - Agent.update(name, fn _ -> :not_cached end) + GenServer.call(name, :clear) + end + + # GenServer callbacks + + @impl true + def init(:ok) do + {:ok, %{}} + end + + @impl true + def handle_call({:fetch, key}, {pid, _} = from, state) do + case Map.get(state, key) do + {:cached, value} -> + {:reply, {:ok, value}, state} + + {:computing, _mon_ref, _waiters} -> + {:noreply, update_waiters(state, key, from)} + + nil -> + mon_ref = Process.monitor(pid) + {:reply, :compute, Map.put(state, key, {:computing, mon_ref, []})} + end + end + + def handle_call({:computed, key, value}, _from, state) do + case Map.get(state, key) do + {:computing, mon_ref, waiters} -> + Process.demonitor(mon_ref, [:flush]) + + for waiter <- waiters do + GenServer.reply(waiter, {:ok, value}) + end + + {:reply, :ok, Map.put(state, key, {:cached, value})} + + _ -> + {:reply, :ok, Map.put(state, key, {:cached, value})} + end + end + + def handle_call({:put, key, value}, _from, state) do + {:reply, :ok, Map.put(state, key, {:cached, value})} + end + + def handle_call(:clear, _from, _state) do + {:reply, :ok, %{}} + end + + @impl true + def handle_cast({:failed, key}, state) do + case Map.get(state, key) do + {:computing, mon_ref, waiters} -> + Process.demonitor(mon_ref, [:flush]) + {:noreply, hand_off_or_remove(state, key, waiters)} + + _ -> + {:noreply, state} + end + end + + @impl true + def handle_info({:DOWN, mon_ref, :process, _pid, _reason}, state) do + case find_computing_key(state, mon_ref) do + {key, waiters} -> + {:noreply, hand_off_or_remove(state, key, waiters)} + + nil -> + {:noreply, state} + end + end + + defp update_waiters(state, key, from) do + Map.update!(state, key, fn {:computing, mon_ref, waiters} -> + {:computing, mon_ref, [from | waiters]} + end) + end + + defp hand_off_or_remove(state, key, [{pid, _} = next | rest]) do + new_mon_ref = Process.monitor(pid) + GenServer.reply(next, :compute) + Map.put(state, key, {:computing, new_mon_ref, rest}) + end + + defp hand_off_or_remove(state, key, []) do + Map.delete(state, key) + end + + defp find_computing_key(state, mon_ref) do + Enum.find_value(state, fn + {key, {:computing, ^mon_ref, waiters}} -> {key, waiters} + _ -> nil + end) end end diff --git a/lib/hex/repo.ex b/lib/hex/repo.ex index 70f9cf4b..2f4a6808 100644 --- a/lib/hex/repo.ex +++ b/lib/hex/repo.ex @@ -1,6 +1,7 @@ defmodule Hex.Repo do @moduledoc false + @exchange_cache __MODULE__.ExchangeCache @hexpm_url "https://repo.hex.pm" @hexpm_public_key """ -----BEGIN PUBLIC KEY----- @@ -14,6 +15,21 @@ defmodule Hex.Repo do -----END PUBLIC KEY----- """ + def start_link(_args) do + Hex.OnceCache.start_link(name: @exchange_cache) + end + + def child_spec(arg) do + %{ + id: __MODULE__, + start: {__MODULE__, :start_link, [arg]} + } + end + + def clear_exchange_cache do + Hex.OnceCache.clear(@exchange_cache) + end + def fetch_repo(repo) do repo = repo || "hexpm" repos = Hex.State.fetch!(:repos) @@ -335,11 +351,12 @@ defmodule Hex.Repo do {:ok, access_token} -> {:ok, access_token} - :expired -> - do_exchange_api_key(repo_config, repo_name) - - :not_found -> - do_exchange_api_key(repo_config, repo_name) + _expired_or_not_found -> + Hex.OnceCache.fetch_key( + @exchange_cache, + {repo_name, repo_config.auth_key}, + fn -> do_exchange_api_key(repo_config, repo_name) end + ) end end diff --git a/test/hex/once_cache_test.exs b/test/hex/once_cache_test.exs index b5a43333..cb117a68 100644 --- a/test/hex/once_cache_test.exs +++ b/test/hex/once_cache_test.exs @@ -169,6 +169,119 @@ defmodule Hex.OnceCacheTest do end end + describe "fetch_key/4" do + test "computes value on first call for a key", %{cache: cache} do + result = + Hex.OnceCache.fetch_key(cache, :key1, fn -> + :value1 + end) + + assert result == :value1 + end + + test "returns cached value on subsequent calls for same key", %{cache: cache} do + counter = :counters.new(1, []) + + compute_fn = fn -> + :counters.add(counter, 1, 1) + :value + end + + assert Hex.OnceCache.fetch_key(cache, :key1, compute_fn) == :value + assert :counters.get(counter, 1) == 1 + + assert Hex.OnceCache.fetch_key(cache, :key1, compute_fn) == :value + assert :counters.get(counter, 1) == 1 + end + + test "computes independently for different keys", %{cache: cache} do + assert Hex.OnceCache.fetch_key(cache, :key1, fn -> :value1 end) == :value1 + assert Hex.OnceCache.fetch_key(cache, :key2, fn -> :value2 end) == :value2 + + # Both are cached independently + assert Hex.OnceCache.fetch_key(cache, :key1, fn -> :should_not_compute end) == :value1 + assert Hex.OnceCache.fetch_key(cache, :key2, fn -> :should_not_compute end) == :value2 + end + + test "handles concurrent calls for the same key", %{cache: cache} do + counter = :counters.new(1, []) + + compute_fn = fn -> + :counters.add(counter, 1, 1) + Process.sleep(50) + :result + end + + tasks = + for _ <- 1..10 do + Task.async(fn -> + Hex.OnceCache.fetch_key(cache, :key1, compute_fn) + end) + end + + results = Task.await_many(tasks) + assert Enum.all?(results, &(&1 == :result)) + assert :counters.get(counter, 1) == 1 + end + + test "computes different keys concurrently", %{cache: cache} do + # Both keys start computing at the same time. + # If serialized, total time would be >= 200ms. + # If concurrent, total time should be ~100ms. + compute_fn = fn -> + Process.sleep(100) + :result + end + + task1 = Task.async(fn -> Hex.OnceCache.fetch_key(cache, :key1, compute_fn) end) + task2 = Task.async(fn -> Hex.OnceCache.fetch_key(cache, :key2, compute_fn) end) + + {elapsed, results} = :timer.tc(fn -> Task.await_many([task1, task2]) end) + + assert results == [:result, :result] + # Should complete in roughly 100ms, not 200ms + assert elapsed < 180_000 + end + + test "hands off to next waiter when computing process crashes", %{cache: cache} do + caller = self() + + spawn(fn -> + Hex.OnceCache.fetch_key(cache, :key1, fn -> + send(caller, :started) + Process.sleep(100) + raise "crash" + end) + end) + + assert_receive :started + + task2 = + Task.async(fn -> + Hex.OnceCache.fetch_key(cache, :key1, fn -> :recovered end) + end) + + assert Task.await(task2, 5000) == :recovered + end + + test "clear resets all keys", %{cache: cache} do + counter = :counters.new(1, []) + + Hex.OnceCache.fetch_key(cache, :key1, fn -> :value1 end) + Hex.OnceCache.fetch_key(cache, :key2, fn -> :value2 end) + + Hex.OnceCache.clear(cache) + + compute_fn = fn -> + :counters.add(counter, 1, 1) + :recomputed + end + + assert Hex.OnceCache.fetch_key(cache, :key1, compute_fn) == :recomputed + assert :counters.get(counter, 1) == 1 + end + end + describe "fetch/3 with timeout" do test "respects custom timeout for long operations", %{cache: cache} do compute_fn = fn -> @@ -180,13 +293,20 @@ defmodule Hex.OnceCacheTest do assert result == :long_operation end - test "times out if computation exceeds timeout", %{cache: cache} do - compute_fn = fn -> - Process.sleep(200) - :long_operation - end + test "waiter times out if computation exceeds timeout", %{cache: cache} do + # Start a slow computation in another process + Task.async(fn -> + Hex.OnceCache.fetch(cache, fn -> + Process.sleep(200) + :slow_result + end) + end) + + # Give the task time to start computing + Process.sleep(10) - assert catch_exit(Hex.OnceCache.fetch(cache, compute_fn, timeout: 50)) + # A waiter with a short timeout should time out + assert catch_exit(Hex.OnceCache.fetch(cache, fn -> :unused end, timeout: 50)) end test "accepts :infinity timeout", %{cache: cache} do diff --git a/test/support/case.ex b/test/support/case.ex index d3c27722..724288fb 100644 --- a/test/support/case.ex +++ b/test/support/case.ex @@ -274,6 +274,7 @@ defmodule HexTest.Case do def reset_state do Hex.State.put_all(Application.get_env(:hex, :reset_state)) Hex.OAuth.clear_tokens() + Hex.Repo.clear_exchange_cache() end def set_home_cwd() do