Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from ldai import LDMessage
from ldai import LDMessage, log
from ldai.models import AIConfigKind
from ldai.providers import AIProvider
from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse
Expand All @@ -18,27 +18,24 @@ class LangChainProvider(AIProvider):
This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
"""

def __init__(self, llm: BaseChatModel, logger: Optional[Any] = None):
def __init__(self, llm: BaseChatModel):
"""
Initialize the LangChain provider.

:param llm: A LangChain BaseChatModel instance
:param logger: Optional logger for logging provider operations
"""
super().__init__(logger)
self._llm = llm

@staticmethod
async def create(ai_config: AIConfigKind, logger: Optional[Any] = None) -> 'LangChainProvider':
async def create(ai_config: AIConfigKind) -> 'LangChainProvider':
"""
Static factory method to create a LangChain AIProvider from an AI configuration.

:param ai_config: The LaunchDarkly AI configuration
:param logger: Optional logger for the provider
:return: Configured LangChainProvider instance
"""
llm = LangChainProvider.create_langchain_model(ai_config)
return LangChainProvider(llm, logger)
return LangChainProvider(llm)

async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
"""
Expand All @@ -56,20 +53,18 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
if isinstance(response.content, str):
content = response.content
else:
if self.logger:
self.logger.warn(
f'Multimodal response not supported, expecting a string. '
f'Content type: {type(response.content)}, Content: {response.content}'
)
log.warn(
f'Multimodal response not supported, expecting a string. '
f'Content type: {type(response.content)}, Content: {response.content}'
)
metrics = LDAIMetrics(success=False, usage=metrics.usage)

return ChatResponse(
message=LDMessage(role='assistant', content=content),
metrics=metrics,
)
except Exception as error:
if self.logger:
self.logger.warn(f'LangChain model invocation failed: {error}')
log.warn(f'LangChain model invocation failed: {error}')

return ChatResponse(
message=LDMessage(role='assistant', content=''),
Expand All @@ -94,11 +89,10 @@ async def invoke_structured_model(
response = await structured_llm.ainvoke(langchain_messages)

if not isinstance(response, dict):
if self.logger:
self.logger.warn(
f'Structured output did not return a dict. '
f'Got: {type(response)}'
)
log.warn(
f'Structured output did not return a dict. '
f'Got: {type(response)}'
)
return StructuredResponse(
data={},
raw_response='',
Expand All @@ -117,8 +111,7 @@ async def invoke_structured_model(
),
)
except Exception as error:
if self.logger:
self.logger.warn(f'LangChain structured model invocation failed: {error}')
log.warn(f'LangChain structured model invocation failed: {error}')

return StructuredResponse(
data={},
Expand Down
Loading