diff --git a/lib/sprites.ex b/lib/sprites.ex index 9a11d16..24b51bc 100644 --- a/lib/sprites.ex +++ b/lib/sprites.ex @@ -42,11 +42,13 @@ defmodule Sprites do * `:base_url` - API base URL (default: "https://api.sprites.dev") * `:timeout` - HTTP timeout in milliseconds (default: 30_000) + * `:control_mode` - Enable control mode for multiplexed exec over a single WebSocket per sprite (default: false) ## Examples client = Sprites.new("my-token") client = Sprites.new("my-token", base_url: "https://custom.api.dev") + client = Sprites.new("my-token", control_mode: true) """ @spec new(String.t(), keyword()) :: client() def new(token, opts \\ []) do diff --git a/lib/sprites/client.ex b/lib/sprites/client.ex index 2e2d3a1..6038cad 100644 --- a/lib/sprites/client.ex +++ b/lib/sprites/client.ex @@ -7,13 +7,14 @@ defmodule Sprites.Client do @default_timeout 30_000 @create_timeout 120_000 - defstruct [:token, :base_url, :timeout, :req] + defstruct [:token, :base_url, :timeout, :req, control_mode: false] @type t :: %__MODULE__{ token: String.t(), base_url: String.t(), timeout: non_neg_integer(), - req: Req.Request.t() + req: Req.Request.t(), + control_mode: boolean() } @doc """ @@ -23,11 +24,13 @@ defmodule Sprites.Client do * `:base_url` - API base URL (default: "https://api.sprites.dev") * `:timeout` - HTTP timeout in milliseconds (default: 30_000) + * `:control_mode` - Enable control mode for multiplexed exec over a single WebSocket (default: false) """ @spec new(String.t(), keyword()) :: t() def new(token, opts \\ []) do base_url = Keyword.get(opts, :base_url, @default_base_url) |> normalize_url() timeout = Keyword.get(opts, :timeout, @default_timeout) + control_mode = Keyword.get(opts, :control_mode, false) req = Req.new( @@ -40,7 +43,8 @@ defmodule Sprites.Client do token: token, base_url: base_url, timeout: timeout, - req: req + req: req, + control_mode: control_mode } end diff --git a/lib/sprites/command.ex b/lib/sprites/command.ex index a1e4ed0..2127854 100644 --- 
a/lib/sprites/command.ex +++ b/lib/sprites/command.ex @@ -9,12 +9,17 @@ defmodule Sprites.Command do * `{:stderr, command, data}` - stderr data received * `{:exit, command, exit_code}` - command completed * `{:error, command, reason}` - error occurred + + Supports two execution modes: + + * **Direct mode** (default) — opens a new WebSocket per command to `/exec` + * **Control mode** — multiplexes over a persistent WebSocket to `/control` """ use GenServer require Logger - alias Sprites.{Sprite, Protocol, Error} + alias Sprites.{Sprite, Protocol, Error, Control, ControlConn} defstruct [:ref, :pid, :sprite, :owner, :tty_mode] @@ -122,11 +127,11 @@ defmodule Sprites.Command do @impl true def init(%{sprite: sprite, command: command, args: args, opts: opts, owner: owner, ref: ref}) do - url = Sprite.exec_url(sprite, command, args, opts) tty_mode = Keyword.get(opts, :tty, false) token = Sprite.token(sprite) + stdin = Keyword.get(opts, :stdin, false) - state = %{ + base_state = %{ owner: owner, ref: ref, tty_mode: tty_mode, @@ -134,10 +139,86 @@ defmodule Sprites.Command do stream_ref: nil, exit_code: nil, token: token, - url: url + sprite: sprite, + using_control: false, + control_conn: nil } - # Connect asynchronously but wait for connection in init + if Sprite.control_mode?(sprite) and Control.control_supported?(sprite) do + case try_control_connect(sprite, command, args, opts, stdin, base_state) do + {:ok, state} -> + {:ok, state} + + {:fallback, _reason} -> + do_direct_init(sprite, command, args, opts, token, base_state) + end + else + do_direct_init(sprite, command, args, opts, token, base_state) + end + end + + defp try_control_connect(sprite, command, args, opts, stdin, state) do + case Control.checkout(sprite) do + {:ok, conn_pid} -> + op_args = build_control_args(command, args, opts, stdin) + + case ControlConn.start_op(conn_pid, self(), "exec", op_args) do + :ok -> + {:ok, %{state | using_control: true, control_conn: conn_pid}} + + {:error, reason} -> + 
Control.checkin(sprite, conn_pid) + {:fallback, reason} + end + + {:error, {:control_not_supported, _}} -> + Control.mark_unsupported(sprite) + {:fallback, :control_not_supported} + + {:error, reason} -> + {:fallback, reason} + end + end + + defp build_control_args(command, args, opts, stdin) do + control_args = %{"cmd" => [command | args]} + + control_args = + case Keyword.get(opts, :dir) do + nil -> control_args + dir -> Map.put(control_args, "dir", dir) + end + + control_args = + case Keyword.get(opts, :env, []) do + [] -> + control_args + + env_list -> + env_strs = Enum.map(env_list, fn {k, v} -> "#{k}=#{v}" end) + Map.put(control_args, "env", env_strs) + end + + control_args = + if Keyword.get(opts, :tty, false) do + rows = Keyword.get(opts, :tty_rows, 24) + cols = Keyword.get(opts, :tty_cols, 80) + + control_args + |> Map.put("tty", "true") + |> Map.put("rows", to_string(rows)) + |> Map.put("cols", to_string(cols)) + else + control_args + end + + Map.put(control_args, "stdin", if(stdin, do: "true", else: "false")) + end + + defp do_direct_init(sprite, command, args, opts, token, state) do + url = Sprite.exec_url(sprite, command, args, opts) + state = Map.put(state, :url, url) + case do_connect(url, token) do {:ok, conn, stream_ref} -> {:ok, %{state | conn: conn, stream_ref: stream_ref}} @@ -216,7 +297,50 @@ defmodule Sprites.Command do end end + # --- Control mode handle_info clauses --- + @impl true + def handle_info( + {:control_data, :binary, data}, + %{using_control: true} = state + ) do + handle_binary_frame(data, state) + end + + def handle_info( + {:control_data, :text, text}, + %{using_control: true} = state + ) do + handle_text_frame(text, state) + end + + def handle_info( + {:control_op_complete, exit_code}, + %{using_control: true, owner: owner, ref: ref, sprite: sprite, control_conn: conn_pid} = + state + ) do + # In control mode, op.complete is the authoritative signal. + # If we haven't already sent an exit message (from exit stream), send now. 
+ if state.exit_code == nil do + send(owner, {:exit, %{ref: ref}, exit_code}) + end + + Control.checkin(sprite, conn_pid) + {:stop, :normal, %{state | exit_code: exit_code, control_conn: nil}} + end + + def handle_info( + {:control_op_error, reason}, + %{using_control: true, owner: owner, ref: ref, sprite: sprite, control_conn: conn_pid} = + state + ) do + send(owner, {:error, %{ref: ref}, reason}) + Control.checkin(sprite, conn_pid) + {:stop, :normal, %{state | control_conn: nil}} + end + + # --- Direct mode handle_info clauses --- + def handle_info({:gun_ws, conn, _stream_ref, {:binary, data}}, %{conn: conn} = state) do handle_binary_frame(data, state) end @@ -231,7 +355,15 @@ defmodule Sprites.Command do def handle_info({:gun_down, conn, _protocol, reason, _killed_streams}, %{conn: conn} = state) do if state.exit_code == nil do - send(state.owner, {:error, %{ref: state.ref}, reason}) + # For normal closes (:closed, :normal), drain any pending WS frames + # from the mailbox before reporting an error. The exit frame may have + # been delivered but not yet processed. 
+ state = drain_pending_frames(state) + + if state.exit_code == nil do + # Still no exit code after draining — report as error + send(state.owner, {:error, %{ref: state.ref}, reason}) + end end {:stop, :normal, state} @@ -251,7 +383,20 @@ defmodule Sprites.Command do {:noreply, state} end + # --- Write stdin --- + @impl true + def handle_call( + {:write_stdin, data}, + _from, + %{using_control: true, control_conn: conn_pid, tty_mode: tty_mode} = state + ) + when conn_pid != nil do + frame_data = Protocol.encode_stdin(data, tty_mode) + ControlConn.send_data(conn_pid, frame_data) + {:reply, :ok, state} + end + def handle_call( {:write_stdin, data}, _from, @@ -267,7 +412,19 @@ defmodule Sprites.Command do {:reply, {:error, :not_connected}, state} end + # --- Close stdin --- + @impl true + def handle_cast( + :close_stdin, + %{using_control: true, control_conn: conn_pid, tty_mode: false} = state + ) + when conn_pid != nil do + frame_data = Protocol.encode_stdin_eof() + ControlConn.send_data(conn_pid, frame_data) + {:noreply, state} + end + def handle_cast(:close_stdin, %{conn: conn, stream_ref: stream_ref, tty_mode: false} = state) when conn != nil do frame_data = Protocol.encode_stdin_eof() @@ -277,6 +434,18 @@ defmodule Sprites.Command do def handle_cast(:close_stdin, state), do: {:noreply, state} + # --- Resize --- + + def handle_cast( + {:resize, rows, cols}, + %{using_control: true, control_conn: conn_pid, tty_mode: true} = state + ) + when conn_pid != nil do + message = Jason.encode!(%{type: "resize", rows: rows, cols: cols}) + ControlConn.send_text(conn_pid, message) + {:noreply, state} + end + def handle_cast( {:resize, rows, cols}, %{conn: conn, stream_ref: stream_ref, tty_mode: true} = state @@ -289,7 +458,15 @@ defmodule Sprites.Command do def handle_cast({:resize, _, _}, state), do: {:noreply, state} + # --- Terminate --- + @impl true + def terminate(_reason, %{using_control: true, control_conn: conn_pid, sprite: sprite}) + when conn_pid != nil do + 
Control.checkin(sprite, conn_pid) + :ok + end + def terminate(_reason, %{conn: conn}) when conn != nil do :gun.close(conn) :ok @@ -315,14 +492,21 @@ defmodule Sprites.Command do {:noreply, state} {:exit, code} -> - send(owner, {:exit, %{ref: ref}, code}) - # Send close frame and stop - if state.conn do - :gun.ws_send(state.conn, state.stream_ref, :close) + if state.using_control do + # In control mode, store exit code but DON'T stop. + # Wait for op.complete to ensure proper sequencing. + send(owner, {:exit, %{ref: ref}, code}) + {:noreply, %{state | exit_code: code}} + else + send(owner, {:exit, %{ref: ref}, code}) + + if state.conn do + :gun.ws_send(state.conn, state.stream_ref, :close) + end + + {:stop, :normal, %{state | exit_code: code}} end - {:stop, :normal, %{state | exit_code: code}} - {:stdin_eof, _} -> {:noreply, state} @@ -338,13 +522,18 @@ defmodule Sprites.Command do {:noreply, state} {:ok, %{"type" => "exit", "code" => code}} -> - send(owner, {:exit, %{ref: ref}, code}) + if state.using_control do + send(owner, {:exit, %{ref: ref}, code}) + {:noreply, %{state | exit_code: code}} + else + send(owner, {:exit, %{ref: ref}, code}) - if state.conn do - :gun.ws_send(state.conn, state.stream_ref, :close) - end + if state.conn do + :gun.ws_send(state.conn, state.stream_ref, :close) + end - {:stop, :normal, %{state | exit_code: code}} + {:stop, :normal, %{state | exit_code: code}} + end _ -> {:noreply, state} @@ -384,6 +573,27 @@ defmodule Sprites.Command do end end + # Drain any pending WebSocket frames from the mailbox. + # Called on gun_down to pick up exit frames that may have arrived + # but not yet been processed (race between data frames and connection close). 
+  defp drain_pending_frames(%{conn: conn} = state) do
+    receive do
+      {:gun_ws, ^conn, _stream_ref, {:binary, data}} ->
+        data |> handle_binary_frame(state) |> continue_draining()
+
+      {:gun_ws, ^conn, _stream_ref, {:text, json}} ->
+        json |> handle_text_frame(state) |> continue_draining()
+    after
+      0 -> state
+    end
+  end
+
+  defp drain_pending_frames(state), do: state
+
+  # An exit frame makes the handler return :stop; that ends the drain instead of crashing.
+  defp continue_draining({:noreply, state}), do: drain_pending_frames(state)
+  defp continue_draining({:stop, _reason, state}), do: state
+
   # Read the response body from a failed HTTP response
   defp read_response_body(_conn, _stream_ref, :fin), do: ""
 
diff --git a/lib/sprites/control.ex b/lib/sprites/control.ex
new file mode 100644
index 0000000..9fc4199
--- /dev/null
+++ b/lib/sprites/control.ex
@@ -0,0 +1,196 @@
+defmodule Sprites.Control do
+  @moduledoc """
+  Module-level pool management for control connections.
+
+  Uses ETS tables to store per-sprite pools and control support flags.
+  Pools are created lazily on first checkout and cached for reuse.
+  """
+
+  alias Sprites.{ControlPool, Sprite}
+
+  @pools_table :sprites_control_pools
+  @support_table :sprites_control_support
+
+  @doc """
+  Ensures ETS tables exist. Safe to call multiple times.
+  """
+  @spec ensure_tables() :: :ok
+  def ensure_tables do
+    if tables_missing?() do
+      # Tables must be owned by a long-lived process so they survive
+      # across short-lived Command GenServers.
+      ensure_table_owner()
+    end
+
+    :ok
+  end
+
+  # Starts the table-owner process if none is registered, then waits until
+  # both tables exist. Concurrent callers may each spawn a candidate, but only
+  # the one that wins the name registration creates tables; the rest exit.
+  defp ensure_table_owner do
+    name = :sprites_control_table_owner
+
+    if Process.whereis(name) == nil do
+      spawn(fn -> run_table_owner(name) end)
+    end
+
+    await_tables(5_000)
+  end
+
+  # Body of the owner process. The name is claimed *before* any table is
+  # created so that a candidate losing the race never owns a table.
+  defp run_table_owner(name) do
+    claimed? =
+      try do
+        Process.register(self(), name)
+      rescue
+        ArgumentError -> false
+      end
+
+    if claimed? do
+      for table <- [@pools_table, @support_table], :ets.whereis(table) == :undefined do
+        :ets.new(table, [:named_table, :public, :set])
+      end
+
+      # Stay alive forever — tables die when their owner dies.
+      Process.sleep(:infinity)
+    end
+  end
+
+  # Polls until both tables are visible. Gives up quietly once `remaining_ms`
+  # is used up (best effort); a later ETS access would then raise.
+  defp await_tables(remaining_ms) when remaining_ms <= 0, do: :ok
+
+  defp await_tables(remaining_ms) do
+    if tables_missing?() do
+      Process.sleep(10)
+      await_tables(remaining_ms - 10)
+    else
+      :ok
+    end
+  end
+
+  defp tables_missing? do
+    :ets.whereis(@pools_table) == :undefined or :ets.whereis(@support_table) == :undefined
+  end
+
+  @doc """
+  Checks out a control connection for the given sprite.
+
+  Creates a pool if one doesn't exist yet.
+  """
+  @spec checkout(Sprite.t()) :: {:ok, pid()} | {:error, term()}
+  def checkout(%Sprite{} = sprite) do
+    ensure_tables()
+    pool = get_or_create_pool(sprite)
+    ControlPool.checkout(pool)
+  end
+
+  @doc """
+  Returns a control connection to the pool for the given sprite.
+  """
+  @spec checkin(Sprite.t(), pid()) :: :ok
+  def checkin(%Sprite{} = sprite, conn_pid) do
+    ensure_tables()
+    key = sprite_key(sprite)
+
+    case :ets.lookup(@pools_table, key) do
+      [{^key, pool}] ->
+        ControlPool.checkin(pool, conn_pid)
+
+      [] ->
+        :ok
+    end
+  end
+
+  @doc """
+  Returns whether control mode is believed to be supported for the given sprite.
+
+  Returns `true` by default (until `mark_unsupported/1` is called).
+  """
+  @spec control_supported?(Sprite.t()) :: boolean()
+  def control_supported?(%Sprite{} = sprite) do
+    ensure_tables()
+    key = sprite_key(sprite)
+
+    case :ets.lookup(@support_table, key) do
+      [{^key, false}] -> false
+      _ -> true
+    end
+  end
+
+  @doc """
+  Marks a sprite as not supporting control mode.
+
+  Prevents future checkout attempts from trying the control endpoint.
+  """
+  @spec mark_unsupported(Sprite.t()) :: :ok
+  def mark_unsupported(%Sprite{} = sprite) do
+    ensure_tables()
+    key = sprite_key(sprite)
+    :ets.insert(@support_table, {key, false})
+    :ok
+  end
+
+  @doc """
+  Closes all control connections for the given sprite and removes the pool.
+  """
+  @spec close(Sprite.t()) :: :ok
+  def close(%Sprite{} = sprite) do
+    ensure_tables()
+    key = sprite_key(sprite)
+
+    case :ets.lookup(@pools_table, key) do
+      [{^key, pool}] ->
+        ControlPool.close(pool)
+        :ets.delete(@pools_table, key)
+
+      [] ->
+        :ok
+    end
+
+    :ok
+  end
+
+  # Private helpers
+
+  defp get_or_create_pool(sprite) do
+    key = sprite_key(sprite)
+
+    case :ets.lookup(@pools_table, key) do
+      [{^key, pool}] ->
+        if Process.alive?(pool) do
+          pool
+        else
+          # Remove only this stale entry; a fresh pool inserted concurrently is kept.
+          :ets.delete_object(@pools_table, {key, pool})
+          create_pool(sprite, key)
+        end
+
+      [] ->
+        create_pool(sprite, key)
+    end
+  end
+
+  # Starts a pool and publishes it with `insert_new` so that, when two callers
+  # race, exactly one pool is stored. The loser shuts its pool down and uses the
+  # stored one; otherwise its connections could never be checked back in.
+  defp create_pool(sprite, key) do
+    url = Sprite.control_url(sprite)
+    token = Sprite.token(sprite)
+
+    {:ok, pool} = ControlPool.start(url: url, token: token)
+
+    if :ets.insert_new(@pools_table, {key, pool}) do
+      pool
+    else
+      ControlPool.close(pool)
+      get_or_create_pool(sprite)
+    end
+  end
+
+  defp sprite_key(%Sprite{name: name, client: client}) do
+    {client.base_url, name}
+  end
+end
diff --git a/lib/sprites/control_conn.ex b/lib/sprites/control_conn.ex
new file mode 100644
index 0000000..9b1fe76
--- /dev/null
+++ b/lib/sprites/control_conn.ex
@@ -0,0 +1,336 @@
+defmodule Sprites.ControlConn do
+  @moduledoc """
+  A persistent WebSocket connection to a sprite's control endpoint.
+
+  Multiplexes exec operations over a single connection at `/v1/sprites/{name}/control`.
+  Each connection handles one operation at a time — the pool manages concurrency.
+ + Messages sent to the owner process during an active operation: + + * `{:control_data, :binary, data}` — binary frame (stdout/stderr/exit in protocol encoding) + * `{:control_data, :text, text}` — text frame (JSON messages like port, resize) + * `{:control_op_complete, exit_code}` — operation completed + * `{:control_op_error, message}` — operation errored + """ + + use GenServer + require Logger + + @control_prefix "control:" + + # Client API + + @doc """ + Starts a control connection to the given sprite. + + Returns `{:error, {:control_not_supported, status}}` if the server returns 404. + """ + @spec start_link(keyword()) :: GenServer.on_start() + def start_link(opts) do + GenServer.start_link(__MODULE__, opts) + end + + @doc """ + Starts a control connection (unlinked) to the given sprite. + + Returns `{:error, {:control_not_supported, status}}` if the server returns 404. + """ + @spec start(keyword()) :: GenServer.on_start() + def start(opts) do + GenServer.start(__MODULE__, opts) + end + + @doc """ + Starts an operation on this control connection. + + Sends an `op.start` control message. The owner process will receive + data frames and completion messages. + """ + @spec start_op(pid(), pid(), String.t(), map()) :: :ok | {:error, term()} + def start_op(pid, owner, op, args) do + GenServer.call(pid, {:start_op, owner, op, args}) + end + + @doc """ + Sends binary data (stdin) through the control connection. + """ + @spec send_data(pid(), binary()) :: :ok + def send_data(pid, data) do + GenServer.cast(pid, {:send_data, data}) + end + + @doc """ + Sends a text frame through the control connection. + """ + @spec send_text(pid(), String.t()) :: :ok + def send_text(pid, text) do + GenServer.cast(pid, {:send_text, text}) + end + + @doc """ + Releases the connection, clearing the current owner. + """ + @spec release(pid()) :: :ok + def release(pid) do + GenServer.cast(pid, :release) + end + + @doc """ + Closes the control connection. 
+  """
+  @spec close(pid()) :: :ok
+  def close(pid) do
+    GenServer.cast(pid, :close)
+  end
+
+  # GenServer callbacks
+
+  @impl true
+  def init(opts) do
+    url = Keyword.fetch!(opts, :url)
+    token = Keyword.fetch!(opts, :token)
+
+    case do_connect(url, token) do
+      {:ok, conn, stream_ref} ->
+        {:ok,
+         %{
+           conn: conn,
+           stream_ref: stream_ref,
+           owner: nil,
+           op_active: false
+         }}
+
+      {:error, {:upgrade_failed, 404}} ->
+        {:stop, {:control_not_supported, 404}}
+
+      {:error, %Sprites.Error.APIError{status: 404}} ->
+        {:stop, {:control_not_supported, 404}}
+
+      {:error, reason} ->
+        {:stop, reason}
+    end
+  end
+
+  @impl true
+  def handle_call({:start_op, owner, op, args}, _from, state) do
+    if state.op_active do
+      {:reply, {:error, :operation_in_progress}, state}
+    else
+      msg =
+        Jason.encode!(%{
+          type: "op.start",
+          op: op,
+          args: args
+        })
+
+      frame = @control_prefix <> msg
+      :gun.ws_send(state.conn, state.stream_ref, {:text, frame})
+      {:reply, :ok, %{state | owner: owner, op_active: true}}
+    end
+  end
+
+  @impl true
+  def handle_cast({:send_data, data}, %{conn: conn, stream_ref: stream_ref} = state) do
+    :gun.ws_send(conn, stream_ref, {:binary, data})
+    {:noreply, state}
+  end
+
+  def handle_cast({:send_text, text}, %{conn: conn, stream_ref: stream_ref} = state) do
+    :gun.ws_send(conn, stream_ref, {:text, text})
+    {:noreply, state}
+  end
+
+  def handle_cast(:release, state) do
+    {:noreply, %{state | owner: nil, op_active: false}}
+  end
+
+  def handle_cast(:close, %{conn: conn} = state) do
+    :gun.close(conn)
+    {:stop, :normal, %{state | conn: nil}}
+  end
+
+  @impl true
+  def handle_info({:gun_ws, conn, _stream_ref, {:binary, data}}, %{conn: conn} = state) do
+    if state.owner do
+      send(state.owner, {:control_data, :binary, data})
+    end
+
+    {:noreply, state}
+  end
+
+  def handle_info({:gun_ws, conn, _stream_ref, {:text, text}}, %{conn: conn} = state) do
+    # Binary pattern matching strips the prefix by bytes; no grapheme traversal needed.
+    case text do
+      @control_prefix <> payload ->
+        handle_control_message(payload, state)
+
+      _unprefixed ->
+        if state.owner, do: send(state.owner, {:control_data, :text, text})
+        {:noreply, state}
+    end
+  end
+
+  def handle_info({:gun_ws, conn, _stream_ref, {:close, _code, _reason}}, %{conn: conn} = state) do
+    if state.owner && state.op_active do
+      send(state.owner, {:control_op_error, "connection closed"})
+    end
+
+    {:stop, :normal, state}
+  end
+
+  def handle_info({:gun_down, conn, _protocol, _reason, _killed}, %{conn: conn} = state) do
+    if state.owner && state.op_active do
+      send(state.owner, {:control_op_error, "connection down"})
+    end
+
+    {:stop, :normal, state}
+  end
+
+  def handle_info({:gun_error, conn, _stream_ref, reason}, %{conn: conn} = state) do
+    if state.owner && state.op_active do
+      send(state.owner, {:control_op_error, "gun error: #{inspect(reason)}"})
+    end
+
+    {:stop, :normal, state}
+  end
+
+  def handle_info({:gun_error, conn, reason}, %{conn: conn} = state) do
+    if state.owner && state.op_active do
+      send(state.owner, {:control_op_error, "gun error: #{inspect(reason)}"})
+    end
+
+    {:stop, :normal, state}
+  end
+
+  def handle_info(_message, state) do
+    {:noreply, state}
+  end
+
+  @impl true
+  def terminate(_reason, %{conn: conn}) when conn != nil do
+    :gun.close(conn)
+    :ok
+  end
+
+  def terminate(_reason, _state), do: :ok
+
+  # Private helpers
+
+  defp handle_control_message(payload, state) do
+    case Jason.decode(payload) do
+      {:ok, %{"type" => "op.complete", "args" => %{"exitCode" => exit_code}}} ->
+        if state.owner do
+          send(state.owner, {:control_op_complete, exit_code})
+        end
+
+        {:noreply, %{state | op_active: false}}
+
+      {:ok, %{"type" => "op.complete"}} ->
+        if state.owner do
+          send(state.owner, {:control_op_complete, 0})
+        end
+
+        {:noreply, %{state | op_active: false}}
+
+      {:ok, %{"type" => "op.error", "args" => %{"error" => error}}} ->
+        if state.owner do
+          send(state.owner, {:control_op_error, error})
+        end
+
+        {:noreply, %{state | op_active: false}}
+
+      {:ok, %{"type" => "op.error"}} ->
+        if state.owner do
+          send(state.owner, {:control_op_error, "unknown error"})
+        end
+
+        {:noreply, %{state | op_active: false}}
+
+      _ ->
+        {:noreply, state}
+    end
+  end
+
+  defp do_connect(url, token) do
+    uri = URI.parse(url)
+    host = String.to_charlist(uri.host)
+    port = uri.port || if(uri.scheme == "wss", do: 443, else: 80)
+
+    transport = if uri.scheme == "wss", do: :tls, else: :tcp
+
+    gun_opts = %{
+      protocols: [:http],
+      transport: transport,
+      tls_opts: [
+        verify: :verify_peer,
+        cacerts: :public_key.cacerts_get(),
+        depth: 3,
+        customize_hostname_check: [
+          match_fun: :public_key.pkix_verify_hostname_match_fun(:https)
+        ]
+      ]
+    }
+
+    case :gun.open(host, port, gun_opts) do
+      {:ok, conn} ->
+        case :gun.await_up(conn, 10_000) do
+          {:ok, _protocol} ->
+            # Append "?query" only when the URL actually carries a query string.
+            path = if uri.query, do: "#{uri.path}?#{uri.query}", else: uri.path || "/"
+            headers = [{"authorization", "Bearer #{token}"}]
+            stream_ref = :gun.ws_upgrade(conn, path, headers)
+
+            receive do
+              {:gun_upgrade, ^conn, ^stream_ref, ["websocket"], _headers} ->
+                {:ok, conn, stream_ref}
+
+              {:gun_response, ^conn, ^stream_ref, is_fin, status, headers} ->
+                body = read_response_body(conn, stream_ref, is_fin)
+                :gun.close(conn)
+
+                case Sprites.Error.parse_api_error(status, body, headers) do
+                  {:ok, %Sprites.Error.APIError{} = api_error} ->
+                    {:error, api_error}
+
+                  {:ok, nil} ->
+                    {:error, {:upgrade_failed, status}}
+                end
+
+              {:gun_error, ^conn, ^stream_ref, reason} ->
+                :gun.close(conn)
+                {:error, reason}
+
+              {:gun_error, ^conn, reason} ->
+                :gun.close(conn)
+                {:error, reason}
+            after
+              10_000 ->
+                :gun.close(conn)
+                {:error, :upgrade_timeout}
+            end
+
+          {:error, reason} ->
+            :gun.close(conn)
+            {:error, reason}
+        end
+
+      {:error, reason} ->
+        {:error, reason}
+    end
+  end
+
+  defp read_response_body(_conn, _stream_ref, :fin), do: ""
+
+  defp read_response_body(conn, stream_ref, :nofin) do
+    receive do
+      {:gun_data, ^conn, ^stream_ref, :fin, data} ->
+        data
+
+      {:gun_data, ^conn, ^stream_ref, :nofin, data} ->
+        data <> read_response_body(conn, stream_ref, :nofin)
+    after
+      5_000 ->
+        ""
+    end
+  end
+end
diff --git a/lib/sprites/control_pool.ex b/lib/sprites/control_pool.ex
new file mode 100644
index 0000000..f8fbef3
--- /dev/null
+++ b/lib/sprites/control_pool.ex
@@ -0,0 +1,175 @@
+defmodule Sprites.ControlPool do
+  @moduledoc """
+  Pool of `Sprites.ControlConn` processes for a single sprite.
+
+  Manages connection lifecycle with checkout/checkin semantics and
+  automatic draining when the pool grows too large.
+  """
+
+  use GenServer
+  require Logger
+
+  alias Sprites.ControlConn
+
+  @max_pool_size 100
+  @drain_threshold 20
+  @drain_target 10
+
+  # Client API
+
+  @doc """
+  Starts a control pool for the given sprite.
+  """
+  @spec start_link(keyword()) :: GenServer.on_start()
+  def start_link(opts) do
+    GenServer.start_link(__MODULE__, opts)
+  end
+
+  @doc """
+  Starts a control pool (unlinked) for the given sprite.
+  """
+  @spec start(keyword()) :: GenServer.on_start()
+  def start(opts) do
+    GenServer.start(__MODULE__, opts)
+  end
+
+  @doc """
+  Checks out a control connection from the pool.
+
+  Returns an idle connection or creates a new one.
+  Returns `{:error, {:control_not_supported, 404}}` if the server doesn't support control mode.
+  """
+  @spec checkout(pid()) :: {:ok, pid()} | {:error, term()}
+  def checkout(pid) do
+    GenServer.call(pid, :checkout, 30_000)
+  end
+
+  @doc """
+  Returns a connection to the pool.
+  """
+  @spec checkin(pid(), pid()) :: :ok
+  def checkin(pid, conn_pid) do
+    GenServer.cast(pid, {:checkin, conn_pid})
+  end
+
+  @doc """
+  Closes all connections in the pool and stops the GenServer.
+ """ + @spec close(pid()) :: :ok + def close(pid) do + GenServer.cast(pid, :close) + end + + # GenServer callbacks + + @impl true + def init(opts) do + {:ok, + %{ + url: Keyword.fetch!(opts, :url), + token: Keyword.fetch!(opts, :token), + conns: %{} + }} + end + + @impl true + def handle_call(:checkout, _from, state) do + # Try to find an idle connection + case find_idle(state.conns) do + {:ok, conn_pid} -> + {_status, ref} = Map.get(state.conns, conn_pid) + conns = Map.put(state.conns, conn_pid, {:busy, ref}) + {:reply, {:ok, conn_pid}, %{state | conns: conns}} + + :none -> + if map_size(state.conns) >= @max_pool_size do + {:reply, {:error, :pool_full}, state} + else + case create_conn(state.url, state.token) do + {:ok, conn_pid} -> + ref = Process.monitor(conn_pid) + conns = Map.put(state.conns, conn_pid, {:busy, ref}) + {:reply, {:ok, conn_pid}, %{state | conns: conns}} + + {:error, reason} -> + {:reply, {:error, reason}, state} + end + end + end + end + + @impl true + def handle_cast({:checkin, conn_pid}, state) do + case Map.get(state.conns, conn_pid) do + nil -> + {:noreply, state} + + {_status, ref} -> + ControlConn.release(conn_pid) + conns = Map.put(state.conns, conn_pid, {:idle, ref}) + state = %{state | conns: conns} + {:noreply, maybe_drain(state)} + end + end + + def handle_cast(:close, state) do + Enum.each(state.conns, fn {conn_pid, _status} -> + ControlConn.close(conn_pid) + end) + + {:stop, :normal, %{state | conns: %{}}} + end + + @impl true + def handle_info({:DOWN, _ref, :process, conn_pid, _reason}, state) do + conns = Map.delete(state.conns, conn_pid) + {:noreply, %{state | conns: conns}} + end + + def handle_info(_message, state) do + {:noreply, state} + end + + # Private helpers + + defp find_idle(conns) do + case Enum.find(conns, fn {_pid, status} -> match?({:idle, _}, status) end) do + {pid, _} -> {:ok, pid} + nil -> :none + end + end + + defp create_conn(url, token) do + case ControlConn.start(url: url, token: token) do + {:ok, pid} -> 
{:ok, pid} + {:error, reason} -> {:error, reason} + end + end + + defp maybe_drain(state) do + if map_size(state.conns) > @drain_threshold do + do_drain(state) + else + state + end + end + + defp do_drain(state) do + to_close = map_size(state.conns) - @drain_target + + if to_close > 0 do + idle_pids = + state.conns + |> Enum.filter(fn {_pid, status} -> match?({:idle, _}, status) end) + |> Enum.map(fn {pid, _} -> pid end) + |> Enum.take(to_close) + + Enum.each(idle_pids, &ControlConn.close/1) + + conns = Map.drop(state.conns, idle_pids) + %{state | conns: conns} + else + state + end + end +end diff --git a/lib/sprites/sprite.ex b/lib/sprites/sprite.ex index 612c038..511c889 100644 --- a/lib/sprites/sprite.ex +++ b/lib/sprites/sprite.ex @@ -52,6 +52,26 @@ defmodule Sprites.Sprite do "#{base}#{path}?#{URI.encode_query(query_params)}" end + @doc """ + Builds the WebSocket URL for the control endpoint. + """ + @spec control_url(t()) :: String.t() + def control_url(%__MODULE__{client: client, name: name}) do + base = + client.base_url + |> String.replace(~r/^http/, "ws") + + "#{base}/v1/sprites/#{URI.encode(name)}/control" + end + + @doc """ + Returns whether control mode is enabled for this sprite's client. + """ + @spec control_mode?(t()) :: boolean() + def control_mode?(%__MODULE__{client: client}) do + client.control_mode + end + @doc """ Returns the authorization token for this sprite's client. 
""" diff --git a/test/control_test.exs b/test/control_test.exs new file mode 100644 index 0000000..b5373c5 --- /dev/null +++ b/test/control_test.exs @@ -0,0 +1,81 @@ +defmodule Sprites.ControlTest do + use ExUnit.Case, async: true + + alias Sprites.{Client, Sprite, Control} + + describe "Client control_mode" do + test "defaults to false" do + client = Client.new("test-token") + assert client.control_mode == false + end + + test "can be set to true" do + client = Client.new("test-token", control_mode: true) + assert client.control_mode == true + end + + test "can be explicitly set to false" do + client = Client.new("test-token", control_mode: false) + assert client.control_mode == false + end + end + + describe "Sprite.control_mode?/1" do + test "reflects client control_mode setting" do + client = Client.new("test-token", control_mode: true) + sprite = Sprite.new(client, "my-sprite") + assert Sprite.control_mode?(sprite) == true + end + + test "returns false when client has control_mode disabled" do + client = Client.new("test-token") + sprite = Sprite.new(client, "my-sprite") + assert Sprite.control_mode?(sprite) == false + end + end + + describe "Sprite.control_url/1" do + test "builds correct URL with https base" do + client = Client.new("test-token", base_url: "https://api.sprites.dev") + sprite = Sprite.new(client, "my-sprite") + assert Sprite.control_url(sprite) == "wss://api.sprites.dev/v1/sprites/my-sprite/control" + end + + test "builds correct URL with http base" do + client = Client.new("test-token", base_url: "http://localhost:8080") + sprite = Sprite.new(client, "my-sprite") + assert Sprite.control_url(sprite) == "ws://localhost:8080/v1/sprites/my-sprite/control" + end + + test "encodes sprite name in URL" do + client = Client.new("test-token", base_url: "https://api.sprites.dev") + sprite = Sprite.new(client, "my sprite") + + assert Sprite.control_url(sprite) == + "wss://api.sprites.dev/v1/sprites/my%20sprite/control" + end + end + + describe 
"Control.control_supported?/1" do + test "returns true by default for new sprites" do + client = Client.new("test-token") + sprite = Sprite.new(client, "supported-sprite-#{System.unique_integer([:positive])}") + assert Control.control_supported?(sprite) == true + end + + test "returns false after mark_unsupported" do + client = Client.new("test-token") + sprite = Sprite.new(client, "unsupported-sprite-#{System.unique_integer([:positive])}") + + Control.mark_unsupported(sprite) + assert Control.control_supported?(sprite) == false + end + end + + describe "Sprites.new/2" do + test "passes control_mode to client" do + client = Sprites.new("test-token", control_mode: true) + assert client.control_mode == true + end + end +end