diff --git a/lib/ruby_llm/providers/bedrock.rb b/lib/ruby_llm/providers/bedrock.rb
index 86fa25e60..ec34b4df1 100644
--- a/lib/ruby_llm/providers/bedrock.rb
+++ b/lib/ruby_llm/providers/bedrock.rb
@@ -6,6 +6,7 @@ module Providers
     class Bedrock < Provider
       include Bedrock::Auth
       include Bedrock::Chat
+      include Bedrock::Embeddings
       include Bedrock::Media
       include Bedrock::Models
       include Bedrock::Streaming
@@ -47,6 +48,34 @@ def parse_error(response)
         body['message'] || body['Message'] || body['error'] || body['__type'] || super
       end
 
+      # Embed one text or an array of texts with a Bedrock embedding model.
+      # Titan embedding models accept a single input per InvokeModel call, so
+      # arrays are embedded one request at a time and the per-request
+      # inputTextTokenCount values are summed. Returns a RubyLLM::Embedding.
+      def embed(text, model:, dimensions:)
+        texts = [text].flatten
+        url = embedding_url(model:)
+
+        results = texts.map do |t|
+          payload = render_embedding_payload(t, model:, dimensions:)
+          body = JSON.generate(payload)
+          signed_hdrs = sign_headers('POST', url, body)
+
+          # Post the exact JSON string that was signed — letting request
+          # middleware re-serialize the payload hash risks byte differences
+          # that would invalidate the SigV4 signature.
+          @connection.post(url, body) do |req|
+            req.headers.merge!(signed_hdrs)
+          end
+        end
+
+        vectors = results.map { |r| r.body['embedding'] }
+        input_tokens = results.sum { |r| r.body['inputTextTokenCount'] || 0 }
+        vectors = vectors.first unless text.is_a?(Array)
+
+        Embedding.new(vectors:, model:, input_tokens:)
+      end
+
       def list_models
         response = signed_get(models_api_base, models_url)
         parse_list_models_response(response, slug, capabilities)
diff --git a/lib/ruby_llm/providers/bedrock/embeddings.rb b/lib/ruby_llm/providers/bedrock/embeddings.rb
new file mode 100644
index 000000000..22203e59f
--- /dev/null
+++ b/lib/ruby_llm/providers/bedrock/embeddings.rb
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    class Bedrock
+      # Embeddings methods for AWS Bedrock InvokeModel API.
+      module Embeddings
+        module_function
+
+        def embedding_url(model:)
+          "/model/#{model}/invoke"
+        end
+
+        def render_embedding_payload(text, model:, dimensions:) # rubocop:disable Lint/UnusedMethodArgument
+          payload = { inputText: text.to_s }
+          payload[:dimensions] = dimensions if dimensions
+          payload[:normalize] = true
+          payload
+        end
+
+        def parse_embedding_response(response, model:, text:) # rubocop:disable Lint/UnusedMethodArgument
+          data = response.body
+          vectors = data['embedding']
+          input_tokens = data['inputTextTokenCount'] || 0
+
+          Embedding.new(vectors:, model:, input_tokens:)
+        end
+      end
+    end
+  end
+end
diff --git a/spec/ruby_llm/embeddings_spec.rb b/spec/ruby_llm/embeddings_spec.rb
index 8cf78d442..bae5b265e 100644
--- a/spec/ruby_llm/embeddings_spec.rb
+++ b/spec/ruby_llm/embeddings_spec.rb
@@ -24,6 +24,7 @@
     it "#{provider}/#{model} can handle a single text with custom dimensions" do
       skip 'Mistral does not support custom dimensions' if provider == :mistral
       skip 'Azure Cohere embeddings do not support custom dimensions' if provider == :azure
+      skip 'Bedrock Titan only supports dimensions of 256, 512, or 1024' if provider == :bedrock
 
       embedding = RubyLLM.embed(test_text, model: model, provider: provider, dimensions: test_dimensions)
       expect(embedding.vectors).to be_an(Array)
@@ -42,6 +43,7 @@
     it "#{provider}/#{model} can handle multiple texts with custom dimensions" do
       skip 'Mistral does not support custom dimensions' if provider == :mistral
       skip 'Azure Cohere embeddings do not support custom dimensions' if provider == :azure
+      skip 'Bedrock Titan only supports dimensions of 256, 512, or 1024' if provider == :bedrock
 
       embeddings = RubyLLM.embed(test_texts, model: model, provider: provider, dimensions: test_dimensions)
       expect(embeddings.vectors).to be_an(Array)
diff --git a/spec/support/models_to_test.rb b/spec/support/models_to_test.rb
index 29e90a007..1f92ba6fe 100644
@@ -88,6 +88,7 
@@ def filter_local_providers(models)
   EMBEDDING_MODELS = [
     { provider: :azure, model: 'Cohere-embed-v3-english' },
+    { provider: :bedrock, model: 'amazon.titan-embed-text-v2:0' },
     { provider: :gemini, model: 'gemini-embedding-001' },
     { provider: :mistral, model: 'mistral-embed' },
     { provider: :openai, model: 'text-embedding-3-small' },