Skip to content

Commit 75f75cf

Browse files
committed
fix bedrock models not running properly in langchain
1 parent 1d78f8a commit 75f75cf

2 files changed

Lines changed: 22 additions & 8 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_provider.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,18 @@ def map_provider(ld_provider_name: str) -> str:
134134
"""
135135
Map LaunchDarkly provider names to LangChain provider names.
136136
137-
This method enables seamless integration between LaunchDarkly's standardized
138-
provider naming and LangChain's naming conventions.
139-
140137
:param ld_provider_name: LaunchDarkly provider name
141138
:return: LangChain-compatible provider name
142139
"""
143140
lowercased_name = ld_provider_name.lower()
141+
# Bedrock is the only provider that uses "provider:model_family" (e.g. Bedrock:Anthropic).
142+
if lowercased_name.startswith('bedrock:'):
143+
return 'bedrock_converse'
144144

145145
mapping: Dict[str, str] = {
146146
'gemini': 'google-genai',
147+
'bedrock': 'bedrock_converse',
147148
}
148-
149149
return mapping.get(lowercased_name, lowercased_name)
150150

151151
@staticmethod
@@ -232,10 +232,15 @@ def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel:
232232

233233
model_name = model_dict.get('name', '')
234234
provider = provider_dict.get('name', '')
235-
parameters = model_dict.get('parameters') or {}
235+
parameters = dict(model_dict.get('parameters') or {})
236+
mapped_provider = LangChainProvider.map_provider(provider)
236237

238+
# Bedrock requires the foundation provider (e.g. Bedrock:Anthropic) passed in
239+
# parameters separately from model_provider, which is used for LangChain routing.
240+
if mapped_provider == 'bedrock_converse' and 'provider' not in parameters:
241+
parameters['provider'] = provider
237242
return init_chat_model(
238243
model_name,
239-
model_provider=LangChainProvider.map_provider(provider),
244+
model_provider=mapped_provider,
240245
**parameters,
241246
)

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ def test_maps_gemini_to_google_genai(self):
130130
assert LangChainProvider.map_provider('Gemini') == 'google-genai'
131131
assert LangChainProvider.map_provider('GEMINI') == 'google-genai'
132132

133+
def test_maps_bedrock_and_model_families_to_bedrock_converse(self):
134+
"""Should map bedrock and bedrock:model_family to bedrock_converse."""
135+
assert LangChainProvider.map_provider('bedrock') == 'bedrock_converse'
136+
assert LangChainProvider.map_provider('Bedrock:Anthropic') == 'bedrock_converse'
137+
assert LangChainProvider.map_provider('bedrock:anthropic') == 'bedrock_converse'
138+
assert LangChainProvider.map_provider('bedrock:amazon') == 'bedrock_converse'
139+
assert LangChainProvider.map_provider('bedrock:cohere') == 'bedrock_converse'
140+
133141
def test_returns_provider_name_unchanged_for_unmapped_providers(self):
134142
"""Should return provider name unchanged for unmapped providers."""
135143
assert LangChainProvider.map_provider('openai') == 'openai'
@@ -197,7 +205,8 @@ def mock_llm(self):
197205
@pytest.mark.asyncio
198206
async def test_returns_success_true_for_successful_invocation(self, mock_llm):
199207
"""Should return success=True for successful invocation."""
200-
mock_response = {'result': 'structured data'}
208+
parsed_data = {'result': 'structured data'}
209+
mock_response = {'parsed': parsed_data, 'raw': None}
201210
mock_structured_llm = MagicMock()
202211
mock_structured_llm.ainvoke = AsyncMock(return_value=mock_response)
203212
mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
@@ -208,7 +217,7 @@ async def test_returns_success_true_for_successful_invocation(self, mock_llm):
208217
result = await provider.invoke_structured_model(messages, response_structure)
209218

210219
assert result.metrics.success is True
211-
assert result.data == mock_response
220+
assert result.data == parsed_data
212221

213222
@pytest.mark.asyncio
214223
async def test_returns_success_false_when_structured_model_invocation_throws_error(self, mock_llm):

0 commit comments

Comments
 (0)