Skip to content

Commit 81a5935

Browse files
committed
test: Add unit tests for runtime and Temporal integration
1 parent 59d946d commit 81a5935

2 files changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for Temporal integration helpers."""
16+
17+
import unittest
18+
from unittest.mock import MagicMock, patch, AsyncMock
19+
import sys
20+
import asyncio
21+
from typing import Any
22+
23+
from google.genai import types
24+
25+
# Configure Mocks globally
26+
# We create fresh mocks here.
27+
mock_workflow = MagicMock()
28+
mock_activity = MagicMock()
29+
mock_worker = MagicMock()
30+
mock_client = MagicMock()
31+
32+
# Important: execute_activity must be awaitable
33+
mock_workflow.execute_activity = AsyncMock(return_value="mock_result")
34+
35+
# Mock the parent package
36+
mock_temporalio = MagicMock()
37+
mock_temporalio.workflow = mock_workflow
38+
mock_temporalio.activity = mock_activity
39+
mock_temporalio.worker = mock_worker
40+
mock_temporalio.client = mock_client
41+
42+
# Mock sys.modules
43+
with patch.dict(sys.modules, {
44+
"temporalio": mock_temporalio,
45+
"temporalio.workflow": mock_workflow,
46+
"temporalio.activity": mock_activity,
47+
"temporalio.worker": mock_worker,
48+
"temporalio.client": mock_client,
49+
}):
50+
from google.adk.integrations import temporal
51+
from google.adk.models import LlmRequest, LlmResponse
52+
53+
54+
class TestTemporalIntegration(unittest.TestCase):
55+
56+
def test_activity_as_tool_wrapper(self):
57+
# Reset mocks
58+
mock_workflow.reset_mock()
59+
mock_workflow.execute_activity = AsyncMock(return_value="mock_result")
60+
61+
# Verify mock setup
62+
# If this fails, then 'temporal.workflow' is NOT our 'mock_workflow'
63+
assert temporal.workflow.execute_activity is mock_workflow.execute_activity
64+
65+
# Define a fake activity
66+
async def fake_activity(arg: str) -> str:
67+
"""My Docstring."""
68+
return f"Hello {arg}"
69+
70+
fake_activity.name = "fake_activity_name"
71+
72+
# Create tool
73+
tool = temporal.activity_as_tool(
74+
fake_activity,
75+
start_to_close_timeout=100
76+
)
77+
78+
# Check metadata
79+
self.assertEqual(tool.__name__, "fake_activity_name")
80+
self.assertEqual(tool.__doc__, "My Docstring.")
81+
82+
# Run tool (wrapper)
83+
loop = asyncio.new_event_loop()
84+
asyncio.set_event_loop(loop)
85+
86+
try:
87+
result = loop.run_until_complete(tool("World"))
88+
finally:
89+
loop.close()
90+
91+
# Verify call
92+
mock_workflow.execute_activity.assert_called_once()
93+
args, kwargs = mock_workflow.execute_activity.call_args
94+
self.assertEqual(kwargs['args'], ['World'])
95+
self.assertEqual(kwargs['start_to_close_timeout'], 100)
96+
97+
def test_temporal_model_generate_content(self):
98+
# Reset mocks
99+
mock_workflow.reset_mock()
100+
101+
# Prepare valid LlmResponse with content
102+
response_content = types.Content(parts=[types.Part(text="test_resp")])
103+
llm_response = LlmResponse(content=response_content)
104+
105+
# generate_content_async expects execute_activity to return response list (iterator)
106+
mock_workflow.execute_activity = AsyncMock(return_value=[llm_response])
107+
108+
# Mock an activity def
109+
mock_activity_def = MagicMock()
110+
111+
# Create model
112+
model = temporal.TemporalModel(
113+
model_name="test-model",
114+
activity_def=mock_activity_def,
115+
schedule_to_close_timeout=50
116+
)
117+
118+
# Create request
119+
req = LlmRequest(model="test-model", prompt="hi")
120+
121+
# Run generate_content_async (it is an async generator)
122+
async def run_gen():
123+
results = []
124+
async for r in model.generate_content_async(req):
125+
results.append(r)
126+
return results
127+
128+
loop = asyncio.new_event_loop()
129+
asyncio.set_event_loop(loop)
130+
131+
try:
132+
results = loop.run_until_complete(run_gen())
133+
finally:
134+
loop.close()
135+
136+
# Verify execute_activity called
137+
mock_workflow.execute_activity.assert_called_once()
138+
args, kwargs = mock_workflow.execute_activity.call_args
139+
self.assertEqual(kwargs['args'], [req])
140+
self.assertEqual(kwargs['schedule_to_close_timeout'], 50)
141+
self.assertEqual(len(results), 1)
142+
self.assertEqual(results[0].content.parts[0].text, "test_resp")
143+

tests/unittests/test_runtime.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for the runtime module."""
16+
17+
import time
18+
import uuid
19+
import unittest
20+
from unittest.mock import MagicMock, patch
21+
22+
from google.adk import runtime
23+
24+
25+
class TestRuntime(unittest.TestCase):
26+
27+
def tearDown(self):
28+
# Reset providers to default after each test
29+
runtime.set_time_provider(time.time)
30+
runtime.set_id_provider(lambda: str(uuid.uuid4()))
31+
32+
def test_default_time_provider(self):
33+
# Verify it returns a float that is close to now
34+
now = time.time()
35+
rt_time = runtime.get_time()
36+
self.assertIsInstance(rt_time, float)
37+
self.assertAlmostEqual(rt_time, now, delta=1.0)
38+
39+
def test_default_id_provider(self):
40+
# Verify it returns a string uuid
41+
uid = runtime.new_uuid()
42+
self.assertIsInstance(uid, str)
43+
# Should be parseable as uuid
44+
uuid.UUID(uid)
45+
46+
def test_custom_time_provider(self):
47+
# Test override
48+
mock_time = 123456789.0
49+
runtime.set_time_provider(lambda: mock_time)
50+
self.assertEqual(runtime.get_time(), mock_time)
51+
52+
def test_custom_id_provider(self):
53+
# Test override
54+
mock_id = "test-id-123"
55+
runtime.set_id_provider(lambda: mock_id)
56+
self.assertEqual(runtime.new_uuid(), mock_id)

0 commit comments

Comments
 (0)