Skip to content

Commit c9a297c

Browse files
vertex-sdk-bot authored and copybara-github committed
chore: migrate legacy langgraph imports to maintain compatibility
PiperOrigin-RevId: 906317678
1 parent 68f053e commit c9a297c

6 files changed

Lines changed: 220 additions & 110 deletions

File tree

tests/unit/vertex_langchain/test_agent_engine_templates_langgraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[mock.call.invoke.invoke(input={"input": "test query", "messages": [("user", "test query")]}, config=None)]
212212
)
213213

214214
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +217,7 @@ def test_stream_query(self, langchain_dump_mock):
217217
agent._tmpl_attrs["runnable"].stream.return_value = []
218218
list(agent.stream_query(input="test stream query"))
219219
agent._tmpl_attrs["runnable"].stream.assert_called_once_with(
220-
input={"input": "test stream query"},
220+
input={"input": "test stream query", "messages": [("user", "test stream query")]},
221221
config=None,
222222
)
223223

tests/unit/vertex_langchain/test_reasoning_engine_templates_langgraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[mock.call.invoke.invoke(input={"input": "test query", "messages": [("user", "test query")]}, config=None)]
212212
)
213213

214214
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +217,7 @@ def test_stream_query(self, langchain_dump_mock):
217217
agent._runnable.stream.return_value = []
218218
list(agent.stream_query(input="test stream query"))
219219
agent._runnable.stream.assert_called_once_with(
220-
input={"input": "test stream query"},
220+
input={"input": "test stream query", "messages": [("user", "test stream query")]},
221221
config=None,
222222
)
223223

vertexai/agent_engines/templates/langchain.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@
4343
RunnableSerializable = Any
4444

4545
try:
46-
from langchain_google_vertexai.functions_utils import _ToolsType
47-
48-
_ToolsType = _ToolsType
46+
from langchain_google_genai.functions_utils import _ToolsType
4947
except ImportError:
50-
_ToolsType = Any
48+
try:
49+
from langchain_google_vertexai.functions_utils import _ToolsType
50+
except ImportError:
51+
_ToolsType = Any
5152

5253
try:
5354
from opentelemetry.sdk import trace
@@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
8182

8283
def _default_output_parser():
8384
try:
84-
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
85+
from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser
8586
except (ModuleNotFoundError, ImportError):
86-
# Fallback to an older version if needed.
87-
from langchain.agents.output_parsers.openai_tools import (
88-
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
89-
)
90-
87+
try:
88+
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
89+
except (ModuleNotFoundError, ImportError):
90+
# Fallback to an older version if needed.
91+
from langchain.agents.output_parsers.openai_tools import (
92+
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
93+
)
9194
return ToolsAgentOutputParser()
9295

9396

@@ -98,17 +101,29 @@ def _default_model_builder(
98101
location: str,
99102
model_kwargs: Optional[Mapping[str, Any]] = None,
100103
) -> "BaseLanguageModel":
101-
import vertexai
102-
from google.cloud.aiplatform import initializer
103-
from langchain_google_vertexai import ChatVertexAI
104-
105104
model_kwargs = model_kwargs or {}
106-
current_project = initializer.global_config.project
107-
current_location = initializer.global_config.location
108-
vertexai.init(project=project, location=location)
109-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
110-
vertexai.init(project=current_project, location=current_location)
111-
return model
105+
try:
106+
from langchain_google_genai import ChatGoogleGenerativeAI
107+
108+
model = ChatGoogleGenerativeAI(
109+
model=model_name,
110+
project=project,
111+
location=location,
112+
vertexai=True,
113+
**model_kwargs,
114+
)
115+
return model
116+
except ImportError:
117+
import vertexai
118+
from google.cloud.aiplatform import initializer
119+
from langchain_google_vertexai import ChatVertexAI
120+
121+
current_project = initializer.global_config.project
122+
current_location = initializer.global_config.location
123+
vertexai.init(project=project, location=location)
124+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
125+
vertexai.init(project=current_project, location=current_location)
126+
return model
112127

113128

114129
def _default_runnable_builder(
@@ -124,8 +139,16 @@ def _default_runnable_builder(
124139
runnable_kwargs: Optional[Mapping[str, Any]] = None,
125140
) -> "RunnableSerializable":
126141
from langchain_core import tools as lc_tools
127-
from langchain.agents import AgentExecutor
128-
from langchain.tools.base import StructuredTool
142+
143+
try:
144+
from langchain_classic.agents import AgentExecutor
145+
except ImportError:
146+
from langchain.agents import AgentExecutor
147+
148+
try:
149+
from langchain_core.tools import StructuredTool
150+
except ImportError:
151+
from langchain.tools.base import StructuredTool
129152

130153
# The prompt template and runnable_kwargs needs to be customized depending
131154
# on whether the user intends for the agent to have history. The way the
@@ -261,12 +284,16 @@ def _default_prompt(
261284
from langchain_core import prompts
262285

263286
try:
264-
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
265-
except (ModuleNotFoundError, ImportError):
266-
# Fallback to an older version if needed.
267-
from langchain.agents.format_scratchpad.openai_tools import (
268-
format_to_openai_tool_messages as format_to_tool_messages,
287+
from langchain_classic.agents.format_scratchpad.tools import (
288+
format_to_tool_messages,
269289
)
290+
except (ModuleNotFoundError, ImportError):
291+
try:
292+
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
293+
except (ModuleNotFoundError, ImportError):
294+
from langchain.agents.format_scratchpad.openai_tools import (
295+
format_to_openai_tool_messages as format_to_tool_messages,
296+
)
270297

271298
system_instructions = []
272299
if system_instruction:
@@ -629,13 +656,18 @@ def query(
629656
Returns:
630657
The output of querying the Agent with the given input and config.
631658
"""
632-
from langchain.load import dump as langchain_load_dump
659+
try:
660+
from langchain_core.load import dumpd
661+
except ImportError:
662+
from langchain.load import dump as langchain_load_dump
663+
664+
dumpd = langchain_load_dump.dumpd
633665

634666
if isinstance(input, str):
635667
input = {"input": input}
636668
if not self._tmpl_attrs.get("runnable"):
637669
self.set_up()
638-
return langchain_load_dump.dumpd(
670+
return dumpd(
639671
self._tmpl_attrs.get("runnable").invoke(
640672
input=input, config=config, **kwargs
641673
)
@@ -662,7 +694,12 @@ def stream_query(
662694
Yields:
663695
The output of querying the Agent with the given input and config.
664696
"""
665-
from langchain.load import dump as langchain_load_dump
697+
try:
698+
from langchain_core.load import dumpd
699+
except ImportError:
700+
from langchain.load import dump as langchain_load_dump
701+
702+
dumpd = langchain_load_dump.dumpd
666703

667704
if isinstance(input, str):
668705
input = {"input": input}
@@ -673,4 +710,4 @@ def stream_query(
673710
config=config,
674711
**kwargs,
675712
):
676-
yield langchain_load_dump.dumpd(chunk)
713+
yield dumpd(chunk)

vertexai/agent_engines/templates/langgraph.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@
3434
BaseLanguageModel = Any
3535

3636
try:
37-
from langchain_google_vertexai.functions_utils import _ToolsType
38-
37+
from langchain_google_genai.functions_utils import _ToolsType
3938
_ToolLike = _ToolsType
4039
except ImportError:
41-
_ToolLike = Any
40+
try:
41+
from langchain_google_vertexai.functions_utils import _ToolsType
42+
_ToolLike = _ToolsType
43+
except ImportError:
44+
_ToolLike = Any
4245

4346
try:
4447
from opentelemetry.sdk import trace
@@ -53,12 +56,10 @@
5356

5457
try:
5558
from langgraph_checkpoint.checkpoint import base
56-
5759
BaseCheckpointSaver = base.BaseCheckpointSaver
5860
except ImportError:
5961
try:
6062
from langgraph.checkpoint import base
61-
6263
BaseCheckpointSaver = base.BaseCheckpointSaver
6364
except ImportError:
6465
BaseCheckpointSaver = Any
@@ -87,17 +88,29 @@ def _default_model_builder(
8788
Returns:
8889
BaseLanguageModel: The language model.
8990
"""
90-
import vertexai
91-
from google.cloud.aiplatform import initializer
92-
from langchain_google_vertexai import ChatVertexAI
93-
9491
model_kwargs = model_kwargs or {}
95-
current_project = initializer.global_config.project
96-
current_location = initializer.global_config.location
97-
vertexai.init(project=project, location=location)
98-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
99-
vertexai.init(project=current_project, location=current_location)
100-
return model
92+
try:
93+
from langchain_google_genai import ChatGoogleGenerativeAI
94+
95+
model = ChatGoogleGenerativeAI(
96+
model=model_name,
97+
project=project,
98+
location=location,
99+
vertexai=True,
100+
**model_kwargs,
101+
)
102+
return model
103+
except ImportError:
104+
import vertexai
105+
from google.cloud.aiplatform import initializer
106+
from langchain_google_vertexai import ChatVertexAI
107+
108+
current_project = initializer.global_config.project
109+
current_location = initializer.global_config.location
110+
vertexai.init(project=project, location=location)
111+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
112+
vertexai.init(project=current_project, location=current_location)
113+
return model
101114

102115

103116
def _default_runnable_builder(
@@ -554,13 +567,16 @@ def query(
554567
Returns:
555568
The output of querying the Agent with the given input and config.
556569
"""
557-
from langchain.load import dump as langchain_load_dump
570+
try:
571+
from langchain_core.load import dumpd
572+
except ImportError:
573+
from langchain.load.dump import dumpd
558574

559575
if isinstance(input, str):
560-
input = {"input": input}
576+
input = {"input": input, "messages": [("user", input)]}
561577
if not self._tmpl_attrs.get("runnable"):
562578
self.set_up()
563-
return langchain_load_dump.dumpd(
579+
return dumpd(
564580
self._tmpl_attrs.get("runnable").invoke(
565581
input=input, config=config, **kwargs
566582
)
@@ -587,18 +603,21 @@ def stream_query(
587603
Yields:
588604
The output of querying the Agent with the given input and config.
589605
"""
590-
from langchain.load import dump as langchain_load_dump
606+
try:
607+
from langchain_core.load import dumpd
608+
except ImportError:
609+
from langchain.load.dump import dumpd
591610

592611
if isinstance(input, str):
593-
input = {"input": input}
612+
input = {"input": input, "messages": [("user", input)]}
594613
if not self._tmpl_attrs.get("runnable"):
595614
self.set_up()
596615
for chunk in self._tmpl_attrs.get("runnable").stream(
597616
input=input,
598617
config=config,
599618
**kwargs,
600619
):
601-
yield langchain_load_dump.dumpd(chunk)
620+
yield dumpd(chunk)
602621

603622
def get_state_history(
604623
self,

0 commit comments

Comments (0)