Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,33 @@ def function_for(tool)

def extract_tool_calls(data)
if json_delta?(data)
{ nil => ToolCall.new(id: nil, name: nil, arguments: data.dig('delta', 'partial_json')) }
# Use the content block index as the hash key so the accumulator
# can route fragments to the correct tool call during parallel
# streaming. Without this, all fragments go to @latest_tool_call_id
# and parallel tool call arguments get concatenated together.
block_index = data['index']
key = block_index ? "block_idx_#{block_index}" : nil
{ key => ToolCall.new(id: nil, name: nil, arguments: data.dig('delta', 'partial_json')) }
elsif data['type'] == 'content_block_start' && data.dig('content_block', 'type') == 'tool_use'
block = data['content_block']
build_tool_use_start(block, data['index'])
else
parse_tool_calls(data['content_block'])
end
end

def build_tool_use_start(block, block_index)
input = block['input']
args = input.is_a?(Hash) && input.empty? ? +'' : (input || +'')
tool_calls = { block['id'] => ToolCall.new(id: block['id'], name: block['name'], arguments: args) }
if block_index
tool_calls["register_idx_#{block_index}"] = ToolCall.new(
id: block['id'], name: '_register_block_index', arguments: nil
)
end
tool_calls
end

def parse_tool_calls(content_blocks)
return nil if content_blocks.nil?

Expand Down
63 changes: 39 additions & 24 deletions lib/ruby_llm/stream_accumulator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,51 @@ def tool_calls_from_stream
end
end

def accumulate_tool_calls(new_tool_calls) # rubocop:disable Metrics/PerceivedComplexity
def accumulate_tool_calls(new_tool_calls)
@block_index_to_tool_call_id ||= {}
RubyLLM.logger.debug { "Accumulating tool calls: #{new_tool_calls}" } if RubyLLM.config.log_stream_debug
new_tool_calls.each_value do |tool_call|
if tool_call.id
tool_call_id = tool_call.id.empty? ? SecureRandom.uuid : tool_call.id
tool_call_arguments = tool_call.arguments
if tool_call_arguments.nil? || (tool_call_arguments.respond_to?(:empty?) && tool_call_arguments.empty?)
tool_call_arguments = +''
end
@tool_calls[tool_call.id] = ToolCall.new(
id: tool_call_id,
name: tool_call.name,
arguments: tool_call_arguments,
thought_signature: tool_call.thought_signature
)
@latest_tool_call_id = tool_call.id
new_tool_calls.each do |key, tool_call|
if register_block_index?(key, tool_call)
block_key = key.to_s.sub('register_idx_', 'block_idx_')
@block_index_to_tool_call_id[block_key] = tool_call.id
elsif tool_call.id
register_tool_call(tool_call)
else
existing = @tool_calls[@latest_tool_call_id]
if existing
fragment = tool_call.arguments
fragment = '' if fragment.nil?
existing.arguments << fragment
if tool_call.thought_signature && existing.thought_signature.nil?
existing.thought_signature = tool_call.thought_signature
end
end
append_tool_call_fragment(key, tool_call)
end
end
end

def register_block_index?(key, tool_call)
tool_call.name == '_register_block_index' && key.to_s.start_with?('register_idx_')
end

def register_tool_call(tool_call)
tool_call_id = tool_call.id.empty? ? SecureRandom.uuid : tool_call.id
tool_call_arguments = tool_call.arguments
if tool_call_arguments.nil? || (tool_call_arguments.respond_to?(:empty?) && tool_call_arguments.empty?)
tool_call_arguments = +''
end
@tool_calls[tool_call.id] = ToolCall.new(
id: tool_call_id,
name: tool_call.name,
arguments: tool_call_arguments,
thought_signature: tool_call.thought_signature
)
@latest_tool_call_id = tool_call.id
end

def append_tool_call_fragment(key, tool_call)
target_id = @block_index_to_tool_call_id[key] || @latest_tool_call_id
existing = @tool_calls[target_id]
return unless existing

existing.arguments << (tool_call.arguments || '')
return unless tool_call.thought_signature && existing.thought_signature.nil?

existing.thought_signature = tool_call.thought_signature
end

def find_tool_call(tool_call_id)
if tool_call_id.nil?
@tool_calls[@latest_tool_call]
Expand Down
44 changes: 44 additions & 0 deletions spec/ruby_llm/providers/anthropic/tools_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,48 @@
expect(described_class.parse_tool_calls([])).to be_nil
end
end

describe '.extract_tool_calls (streaming)' do
# Build a test object that mixes in the Streaming + Tools modules
# so we can call extract_tool_calls with raw Anthropic event data.
let(:provider) do
Class.new do
include RubyLLM::Providers::Anthropic::Streaming
include RubyLLM::Providers::Anthropic::Tools

public :extract_tool_calls, :json_delta?
end.new
end

it 'emits block index key for input_json_delta events' do
data = {
'type' => 'content_block_delta',
'index' => 3,
'delta' => { 'type' => 'input_json_delta', 'partial_json' => '{"sym' }
}

result = provider.extract_tool_calls(data)

expect(result).to have_key('block_idx_3')
expect(result['block_idx_3'].arguments).to eq('{"sym')
expect(result['block_idx_3'].id).to be_nil
end

it 'emits tool call and index registration for content_block_start' do
data = {
'type' => 'content_block_start',
'index' => 2,
'content_block' => { 'type' => 'tool_use', 'id' => 'toolu_abc', 'name' => 'weather', 'input' => {} }
}

result = provider.extract_tool_calls(data)

expect(result).to have_key('toolu_abc')
expect(result['toolu_abc'].name).to eq('weather')

expect(result).to have_key('register_idx_2')
expect(result['register_idx_2'].id).to eq('toolu_abc')
expect(result['register_idx_2'].name).to eq('_register_block_index')
end
end
end
57 changes: 57 additions & 0 deletions spec/ruby_llm/stream_accumulator_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,62 @@
message = accumulator.to_message(nil)
expect(message.tool_calls['call_1'].arguments).to eq({})
end

context 'with parallel streaming tool calls' do
it 'keeps arguments separate when block index is provided' do
accumulator = described_class.new

register_tool_call(accumulator, id: 'toolu_A', name: 'get_market_data', block_index: 1)
register_tool_call(accumulator, id: 'toolu_B', name: 'web_search', block_index: 2)

add_delta(accumulator, 'block_idx_1', '{"symbol":')
add_delta(accumulator, 'block_idx_2', '{"query":')
add_delta(accumulator, 'block_idx_1', '"MNQM26"}')
add_delta(accumulator, 'block_idx_2', '"market news"}')

message = accumulator.to_message(nil)

expect(message.tool_calls['toolu_A'].arguments).to eq({ 'symbol' => 'MNQM26' })
expect(message.tool_calls['toolu_B'].arguments).to eq({ 'query' => 'market news' })
end
end

it 'falls back to latest_tool_call_id when no block index is provided' do
accumulator = described_class.new

tc = RubyLLM::ToolCall.new(id: 'call_1', name: 'weather', arguments: +'')
accumulator.add(
RubyLLM::Chunk.new(role: :assistant, content: nil, tool_calls: { 'call_1' => tc })
)

delta = RubyLLM::ToolCall.new(id: nil, name: nil, arguments: '{"city":"NYC"}')
accumulator.add(
RubyLLM::Chunk.new(role: :assistant, content: nil, tool_calls: { nil => delta })
)

message = accumulator.to_message(nil)
expect(message.tool_calls['call_1'].arguments).to eq({ 'city' => 'NYC' })
end
end

private

def register_tool_call(accumulator, id:, name:, block_index:)
tc = RubyLLM::ToolCall.new(id: id, name: name, arguments: +'')
register = RubyLLM::ToolCall.new(id: id, name: '_register_block_index', arguments: nil)
accumulator.add(
RubyLLM::Chunk.new(
role: :assistant,
content: nil,
tool_calls: { id => tc, "register_idx_#{block_index}" => register }
)
)
end

def add_delta(accumulator, block_key, json_fragment)
delta = RubyLLM::ToolCall.new(id: nil, name: nil, arguments: json_fragment)
accumulator.add(
RubyLLM::Chunk.new(role: :assistant, content: nil, tool_calls: { block_key => delta })
)
end
end