Skip to content

Commit d9d6c8a

Browse files
vertex-sdk-bot authored and copybara-github committed
chore: migrate legacy langgraph imports to maintain compatibility
PiperOrigin-RevId: 906317678
1 parent 68f053e commit d9d6c8a

6 files changed

Lines changed: 238 additions & 106 deletions

File tree

tests/unit/vertex_langchain/test_agent_engine_templates_langgraph.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[
212+
mock.call.invoke.invoke(
213+
input={"input": "test query", "messages": [("user", "test query")]},
214+
config=None,
215+
)
216+
]
212217
)
213218

214219
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
217222
agent._tmpl_attrs["runnable"].stream.return_value = []
218223
list(agent.stream_query(input="test stream query"))
219224
agent._tmpl_attrs["runnable"].stream.assert_called_once_with(
220-
input={"input": "test stream query"},
225+
input={
226+
"input": "test stream query",
227+
"messages": [("user", "test stream query")],
228+
},
221229
config=None,
222230
)
223231

tests/unit/vertex_langchain/test_reasoning_engine_templates_langgraph.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[
212+
mock.call.invoke.invoke(
213+
input={"input": "test query", "messages": [("user", "test query")]},
214+
config=None,
215+
)
216+
]
212217
)
213218

214219
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
217222
agent._runnable.stream.return_value = []
218223
list(agent.stream_query(input="test stream query"))
219224
agent._runnable.stream.assert_called_once_with(
220-
input={"input": "test stream query"},
225+
input={
226+
"input": "test stream query",
227+
"messages": [("user", "test stream query")],
228+
},
221229
config=None,
222230
)
223231

vertexai/agent_engines/templates/langchain.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@
4343
RunnableSerializable = Any
4444

4545
try:
46-
from langchain_google_vertexai.functions_utils import _ToolsType
47-
48-
_ToolsType = _ToolsType
46+
from langchain_google_genai.functions_utils import _ToolsType
4947
except ImportError:
50-
_ToolsType = Any
48+
try:
49+
from langchain_google_vertexai.functions_utils import _ToolsType
50+
except ImportError:
51+
_ToolsType = Any
5152

5253
try:
5354
from opentelemetry.sdk import trace
@@ -81,13 +82,15 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
8182

8283
def _default_output_parser():
8384
try:
84-
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
85+
from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser
8586
except (ModuleNotFoundError, ImportError):
86-
# Fallback to an older version if needed.
87-
from langchain.agents.output_parsers.openai_tools import (
88-
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
89-
)
90-
87+
try:
88+
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
89+
except (ModuleNotFoundError, ImportError):
90+
# Fallback to an older version if needed.
91+
from langchain.agents.output_parsers.openai_tools import (
92+
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
93+
)
9194
return ToolsAgentOutputParser()
9295

9396

@@ -98,17 +101,29 @@ def _default_model_builder(
98101
location: str,
99102
model_kwargs: Optional[Mapping[str, Any]] = None,
100103
) -> "BaseLanguageModel":
101-
import vertexai
102-
from google.cloud.aiplatform import initializer
103-
from langchain_google_vertexai import ChatVertexAI
104-
105104
model_kwargs = model_kwargs or {}
106-
current_project = initializer.global_config.project
107-
current_location = initializer.global_config.location
108-
vertexai.init(project=project, location=location)
109-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
110-
vertexai.init(project=current_project, location=current_location)
111-
return model
105+
try:
106+
from langchain_google_genai import ChatGoogleGenerativeAI
107+
108+
model = ChatGoogleGenerativeAI(
109+
model=model_name,
110+
project=project,
111+
location=location,
112+
vertexai=True,
113+
**model_kwargs,
114+
)
115+
return model
116+
except ImportError:
117+
import vertexai
118+
from google.cloud.aiplatform import initializer
119+
from langchain_google_vertexai import ChatVertexAI
120+
121+
current_project = initializer.global_config.project
122+
current_location = initializer.global_config.location
123+
vertexai.init(project=project, location=location)
124+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
125+
vertexai.init(project=current_project, location=current_location)
126+
return model
112127

113128

114129
def _default_runnable_builder(
@@ -124,8 +139,16 @@ def _default_runnable_builder(
124139
runnable_kwargs: Optional[Mapping[str, Any]] = None,
125140
) -> "RunnableSerializable":
126141
from langchain_core import tools as lc_tools
127-
from langchain.agents import AgentExecutor
128-
from langchain.tools.base import StructuredTool
142+
143+
try:
144+
from langchain_classic.agents import AgentExecutor
145+
except ImportError:
146+
from langchain.agents import AgentExecutor
147+
148+
try:
149+
from langchain_core.tools import StructuredTool
150+
except ImportError:
151+
from langchain.tools.base import StructuredTool
129152

130153
# The prompt template and runnable_kwargs needs to be customized depending
131154
# on whether the user intends for the agent to have history. The way the
@@ -261,12 +284,16 @@ def _default_prompt(
261284
from langchain_core import prompts
262285

263286
try:
264-
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
265-
except (ModuleNotFoundError, ImportError):
266-
# Fallback to an older version if needed.
267-
from langchain.agents.format_scratchpad.openai_tools import (
268-
format_to_openai_tool_messages as format_to_tool_messages,
287+
from langchain_classic.agents.format_scratchpad.tools import (
288+
format_to_tool_messages,
269289
)
290+
except (ModuleNotFoundError, ImportError):
291+
try:
292+
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
293+
except (ModuleNotFoundError, ImportError):
294+
from langchain.agents.format_scratchpad.openai_tools import (
295+
format_to_openai_tool_messages as format_to_tool_messages,
296+
)
270297

271298
system_instructions = []
272299
if system_instruction:
@@ -629,13 +656,18 @@ def query(
629656
Returns:
630657
The output of querying the Agent with the given input and config.
631658
"""
632-
from langchain.load import dump as langchain_load_dump
659+
try:
660+
from langchain_core.load import dumpd
661+
except ImportError:
662+
from langchain.load import dump as langchain_load_dump
663+
664+
dumpd = langchain_load_dump.dumpd
633665

634666
if isinstance(input, str):
635667
input = {"input": input}
636668
if not self._tmpl_attrs.get("runnable"):
637669
self.set_up()
638-
return langchain_load_dump.dumpd(
670+
return dumpd(
639671
self._tmpl_attrs.get("runnable").invoke(
640672
input=input, config=config, **kwargs
641673
)
@@ -662,7 +694,12 @@ def stream_query(
662694
Yields:
663695
The output of querying the Agent with the given input and config.
664696
"""
665-
from langchain.load import dump as langchain_load_dump
697+
try:
698+
from langchain_core.load import dumpd
699+
except ImportError:
700+
from langchain.load import dump as langchain_load_dump
701+
702+
dumpd = langchain_load_dump.dumpd
666703

667704
if isinstance(input, str):
668705
input = {"input": input}
@@ -673,4 +710,4 @@ def stream_query(
673710
config=config,
674711
**kwargs,
675712
):
676-
yield langchain_load_dump.dumpd(chunk)
713+
yield dumpd(chunk)

vertexai/agent_engines/templates/langgraph.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,16 @@
3434
BaseLanguageModel = Any
3535

3636
try:
37-
from langchain_google_vertexai.functions_utils import _ToolsType
37+
from langchain_google_genai.functions_utils import _ToolsType
3838

3939
_ToolLike = _ToolsType
4040
except ImportError:
41-
_ToolLike = Any
41+
try:
42+
from langchain_google_vertexai.functions_utils import _ToolsType
43+
44+
_ToolLike = _ToolsType
45+
except ImportError:
46+
_ToolLike = Any
4247

4348
try:
4449
from opentelemetry.sdk import trace
@@ -87,17 +92,29 @@ def _default_model_builder(
8792
Returns:
8893
BaseLanguageModel: The language model.
8994
"""
90-
import vertexai
91-
from google.cloud.aiplatform import initializer
92-
from langchain_google_vertexai import ChatVertexAI
93-
9495
model_kwargs = model_kwargs or {}
95-
current_project = initializer.global_config.project
96-
current_location = initializer.global_config.location
97-
vertexai.init(project=project, location=location)
98-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
99-
vertexai.init(project=current_project, location=current_location)
100-
return model
96+
try:
97+
from langchain_google_genai import ChatGoogleGenerativeAI
98+
99+
model = ChatGoogleGenerativeAI(
100+
model=model_name,
101+
project=project,
102+
location=location,
103+
vertexai=True,
104+
**model_kwargs,
105+
)
106+
return model
107+
except ImportError:
108+
import vertexai
109+
from google.cloud.aiplatform import initializer
110+
from langchain_google_vertexai import ChatVertexAI
111+
112+
current_project = initializer.global_config.project
113+
current_location = initializer.global_config.location
114+
vertexai.init(project=project, location=location)
115+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
116+
vertexai.init(project=current_project, location=current_location)
117+
return model
101118

102119

103120
def _default_runnable_builder(
@@ -554,13 +571,16 @@ def query(
554571
Returns:
555572
The output of querying the Agent with the given input and config.
556573
"""
557-
from langchain.load import dump as langchain_load_dump
574+
try:
575+
from langchain_core.load import dumpd
576+
except ImportError:
577+
from langchain.load.dump import dumpd
558578

559579
if isinstance(input, str):
560-
input = {"input": input}
580+
input = {"input": input, "messages": [("user", input)]}
561581
if not self._tmpl_attrs.get("runnable"):
562582
self.set_up()
563-
return langchain_load_dump.dumpd(
583+
return dumpd(
564584
self._tmpl_attrs.get("runnable").invoke(
565585
input=input, config=config, **kwargs
566586
)
@@ -587,18 +607,21 @@ def stream_query(
587607
Yields:
588608
The output of querying the Agent with the given input and config.
589609
"""
590-
from langchain.load import dump as langchain_load_dump
610+
try:
611+
from langchain_core.load import dumpd
612+
except ImportError:
613+
from langchain.load.dump import dumpd
591614

592615
if isinstance(input, str):
593-
input = {"input": input}
616+
input = {"input": input, "messages": [("user", input)]}
594617
if not self._tmpl_attrs.get("runnable"):
595618
self.set_up()
596619
for chunk in self._tmpl_attrs.get("runnable").stream(
597620
input=input,
598621
config=config,
599622
**kwargs,
600623
):
601-
yield langchain_load_dump.dumpd(chunk)
624+
yield dumpd(chunk)
602625

603626
def get_state_history(
604627
self,

0 commit comments

Comments (0)