Skip to content

Commit 23dc35d

Browse files
feat(langchain): enhance conditional imports and type hinting for LangChain integration
- Implemented conditional import handling for the `langchain` library, allowing for graceful degradation when the library is not installed.
- Improved type hints using forward references for `langchain` types to enhance code clarity and maintainability.
- Introduced an informative error message when the `langchain` library is missing, guiding users on how to install it.
- This update ensures better compatibility and user experience when working with optional dependencies in the LangChain integration.
1 parent d521c4b commit 23dc35d

1 file changed

Lines changed: 31 additions & 13 deletions

File tree

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22

33
# pylint: disable=unused-argument
44
import time
5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
66
from uuid import UUID
77

8-
from langchain import schema as langchain_schema
9-
from langchain.callbacks.base import BaseCallbackHandler
8+
try:
9+
from langchain import schema as langchain_schema
10+
from langchain.callbacks.base import BaseCallbackHandler
11+
HAVE_LANGCHAIN = True
12+
except ImportError:
13+
HAVE_LANGCHAIN = False
14+
15+
if TYPE_CHECKING:
16+
from langchain import schema as langchain_schema
17+
from langchain.callbacks.base import BaseCallbackHandler
1018

1119
from ..tracing import tracer, steps, traces, enums
1220
from .. import utils
@@ -18,10 +26,20 @@
1826
}
1927

2028

21-
class OpenlayerHandler(BaseCallbackHandler):
29+
if HAVE_LANGCHAIN:
30+
BaseCallbackHandlerClass = BaseCallbackHandler
31+
else:
32+
BaseCallbackHandlerClass = object
33+
34+
35+
class OpenlayerHandler(BaseCallbackHandlerClass): # type: ignore[misc]
2236
"""LangChain callback handler that logs to Openlayer."""
2337

2438
def __init__(self, **kwargs: Any) -> None:
39+
if not HAVE_LANGCHAIN:
40+
raise ImportError(
41+
"LangChain library is not installed. Please install it with: pip install langchain"
42+
)
2543
super().__init__()
2644
self.metadata: Dict[str, Any] = kwargs or {}
2745
self.steps: Dict[UUID, steps.Step] = {}
@@ -197,7 +215,7 @@ def _convert_step_objects_recursively(self, step: steps.Step) -> None:
197215
def _convert_langchain_objects(self, obj: Any) -> Any:
198216
"""Recursively convert LangChain objects to JSON-serializable format."""
199217
# Explicit check for LangChain BaseMessage and its subclasses
200-
if isinstance(obj, langchain_schema.BaseMessage):
218+
if HAVE_LANGCHAIN and isinstance(obj, langchain_schema.BaseMessage):
201219
return self._message_to_dict(obj)
202220

203221
# Handle ChatPromptValue objects which contain messages
@@ -249,7 +267,7 @@ def _convert_langchain_objects(self, obj: Any) -> Any:
249267
# For everything else, convert to string
250268
return str(obj)
251269

252-
def _message_to_dict(self, message: langchain_schema.BaseMessage) -> Dict[str, str]:
270+
def _message_to_dict(self, message: "langchain_schema.BaseMessage") -> Dict[str, str]:
253271
"""Convert a LangChain message to a JSON-serializable dictionary."""
254272
message_type = getattr(message, "type", "user")
255273

@@ -262,7 +280,7 @@ def _message_to_dict(self, message: langchain_schema.BaseMessage) -> Dict[str, s
262280
return {"role": role, "content": str(message.content)}
263281

264282
def _messages_to_prompt_format(
265-
self, messages: List[List[langchain_schema.BaseMessage]]
283+
self, messages: List[List["langchain_schema.BaseMessage"]]
266284
) -> List[Dict[str, str]]:
267285
"""Convert LangChain messages to Openlayer prompt format using
268286
unified conversion."""
@@ -302,7 +320,7 @@ def _extract_model_info(
302320
}
303321

304322
def _extract_token_info(
305-
self, response: langchain_schema.LLMResult
323+
self, response: "langchain_schema.LLMResult"
306324
) -> Dict[str, Any]:
307325
"""Extract token information generically from LLM response."""
308326
llm_output = response.llm_output or {}
@@ -340,7 +358,7 @@ def _extract_token_info(
340358
"tokens": token_usage.get("total_tokens", 0),
341359
}
342360

343-
def _extract_output(self, response: langchain_schema.LLMResult) -> str:
361+
def _extract_output(self, response: "langchain_schema.LLMResult") -> str:
344362
"""Extract output text from LLM response."""
345363
output = ""
346364
for generations in response.generations:
@@ -384,7 +402,7 @@ def on_llm_start(
384402
def on_chat_model_start(
385403
self,
386404
serialized: Dict[str, Any],
387-
messages: List[List[langchain_schema.BaseMessage]],
405+
messages: List[List["langchain_schema.BaseMessage"]],
388406
*,
389407
run_id: UUID,
390408
parent_run_id: Optional[UUID] = None,
@@ -414,7 +432,7 @@ def on_chat_model_start(
414432

415433
def on_llm_end(
416434
self,
417-
response: langchain_schema.LLMResult,
435+
response: "langchain_schema.LLMResult",
418436
*,
419437
run_id: UUID,
420438
parent_run_id: Optional[UUID] = None,
@@ -590,7 +608,7 @@ def on_text(self, text: str, **kwargs: Any) -> Any:
590608

591609
def on_agent_action(
592610
self,
593-
action: langchain_schema.AgentAction,
611+
action: "langchain_schema.AgentAction",
594612
*,
595613
run_id: UUID,
596614
parent_run_id: Optional[UUID] = None,
@@ -612,7 +630,7 @@ def on_agent_action(
612630

613631
def on_agent_finish(
614632
self,
615-
finish: langchain_schema.AgentFinish,
633+
finish: "langchain_schema.AgentFinish",
616634
*,
617635
run_id: UUID,
618636
parent_run_id: Optional[UUID] = None,

0 commit comments

Comments (0)