Skip to content
22 changes: 22 additions & 0 deletions lib/ruby_llm/providers/bedrock.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module Providers
class Bedrock < Provider
include Bedrock::Auth
include Bedrock::Chat
include Bedrock::Embeddings
include Bedrock::Media
include Bedrock::Models
include Bedrock::Streaming
Expand Down Expand Up @@ -47,6 +48,27 @@ def parse_error(response)
body['message'] || body['Message'] || body['error'] || body['__type'] || super
end

def embed(text, model:, dimensions:)
texts = [text].flatten
url = embedding_url(model:)

results = texts.map do |t|
payload = render_embedding_payload(t, model:, dimensions:)
body = JSON.generate(payload)
signed_hdrs = sign_headers('POST', url, body)

@connection.post(url, payload) do |req|
req.headers.merge!(signed_hdrs)
end
end

vectors = results.map { |r| r.body['embedding'] }
input_tokens = results.sum { |r| r.body['inputTextTokenCount'] || 0 }
vectors = vectors.first unless text.is_a?(Array)

Embedding.new(vectors:, model:, input_tokens:)
end

def list_models
response = signed_get(models_api_base, models_url)
parse_list_models_response(response, slug, capabilities)
Expand Down
31 changes: 31 additions & 0 deletions lib/ruby_llm/providers/bedrock/embeddings.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# frozen_string_literal: true

module RubyLLM
module Providers
class Bedrock
# Embeddings methods for AWS Bedrock InvokeModel API.
module Embeddings
module_function

def embedding_url(model:)
"/model/#{model}/invoke"
end

def render_embedding_payload(text, model:, dimensions:) # rubocop:disable Lint/UnusedMethodArgument
payload = { inputText: text.to_s }
payload[:dimensions] = dimensions if dimensions
payload[:normalize] = true
payload
end

def parse_embedding_response(response, model:, text:) # rubocop:disable Lint/UnusedMethodArgument
data = response.body
vectors = data['embedding']
input_tokens = data['inputTextTokenCount'] || 0

Embedding.new(vectors:, model:, input_tokens:)
end
end
end
end
end
2 changes: 2 additions & 0 deletions spec/ruby_llm/embeddings_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
it "#{provider}/#{model} can handle a single text with custom dimensions" do
skip 'Mistral does not support custom dimensions' if provider == :mistral
skip 'Azure Cohere embeddings do not support custom dimensions' if provider == :azure
skip 'Bedrock Titan only supports dimensions of 256, 512, or 1024' if provider == :bedrock

embedding = RubyLLM.embed(test_text, model: model, provider: provider, dimensions: test_dimensions)
expect(embedding.vectors).to be_an(Array)
Expand All @@ -42,6 +43,7 @@
it "#{provider}/#{model} can handle multiple texts with custom dimensions" do
skip 'Mistral does not support custom dimensions' if provider == :mistral
skip 'Azure Cohere embeddings do not support custom dimensions' if provider == :azure
skip 'Bedrock Titan only supports dimensions of 256, 512, or 1024' if provider == :bedrock

embeddings = RubyLLM.embed(test_texts, model: model, provider: provider, dimensions: test_dimensions)
expect(embeddings.vectors).to be_an(Array)
Expand Down
1 change: 1 addition & 0 deletions spec/support/models_to_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def filter_local_providers(models)

EMBEDDING_MODELS = [
{ provider: :azure, model: 'Cohere-embed-v3-english' },
{ provider: :bedrock, model: 'amazon.titan-embed-text-v2:0' },
{ provider: :gemini, model: 'gemini-embedding-001' },
{ provider: :mistral, model: 'mistral-embed' },
{ provider: :openai, model: 'text-embedding-3-small' },
Expand Down