Skip to content

Commit 1b05842

Browse files
shukladivyanshcopybara-github
authored andcommitted
feat: add configurable resource limits for subprocesses in BashTool
PiperOrigin-RevId: 894338571
1 parent 37973da commit 1b05842

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

src/google/adk/tools/bash_tool.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import os
2323
import pathlib
24+
import resource
2425
import shlex
2526
import signal
2627
from typing import Any
@@ -32,16 +33,25 @@
3233
from .base_tool import BaseTool
3334
from .tool_context import ToolContext
3435

36+
logger = logging.getLogger("google_adk." + __name__)
37+
3538

3639
@dataclasses.dataclass(frozen=True)
3740
class BashToolPolicy:
38-
"""Configuration for allowed bash commands based on prefix matching.
41+
"""Configuration for allowed bash commands and resource limits.
3942
4043
Set allowed_command_prefixes to ("*",) to allow all commands (default),
4144
or explicitly list allowed prefixes.
45+
46+
Values for max_memory_bytes, max_file_size_bytes, and max_child_processes
47+
will be enforced upon the spawned subprocess.
4248
"""
4349

4450
allowed_command_prefixes: tuple[str, ...] = ("*",)
51+
timeout_seconds: Optional[int] = 30
52+
max_memory_bytes: Optional[int] = None
53+
max_file_size_bytes: Optional[int] = None
54+
max_child_processes: Optional[int] = None
4555

4656

4757
def _validate_command(command: str, policy: BashToolPolicy) -> Optional[str]:
@@ -61,6 +71,29 @@ def _validate_command(command: str, policy: BashToolPolicy) -> Optional[str]:
6171
return f"Command blocked. Permitted prefixes are: {allowed}"
6272

6373

74+
def _set_resource_limits(policy: BashToolPolicy) -> None:
75+
"""Sets resource limits for the subprocess based on the provided policy."""
76+
try:
77+
resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
78+
if policy.max_memory_bytes:
79+
resource.setrlimit(
80+
resource.RLIMIT_AS,
81+
(policy.max_memory_bytes, policy.max_memory_bytes),
82+
)
83+
if policy.max_file_size_bytes:
84+
resource.setrlimit(
85+
resource.RLIMIT_FSIZE,
86+
(policy.max_file_size_bytes, policy.max_file_size_bytes),
87+
)
88+
if policy.max_child_processes:
89+
resource.setrlimit(
90+
resource.RLIMIT_NPROC,
91+
(policy.max_child_processes, policy.max_child_processes),
92+
)
93+
except (ValueError, OSError) as e:
94+
logger.warning("Failed to set resource limits: %s", e)
95+
96+
6497
@features.experimental(features.FeatureName.SKILL_TOOLSET)
6598
class ExecuteBashTool(BaseTool):
6699
"""Tool to execute a validated bash command within a workspace directory."""
@@ -144,20 +177,25 @@ async def run_async(
144177
stdout=asyncio.subprocess.PIPE,
145178
stderr=asyncio.subprocess.PIPE,
146179
start_new_session=True,
180+
preexec_fn=lambda: _set_resource_limits(self._policy),
147181
)
148182

149183
try:
150184
stdout, stderr = await asyncio.wait_for(
151-
process.communicate(), timeout=30
185+
process.communicate(), timeout=self._policy.timeout_seconds
152186
)
153187
except asyncio.TimeoutError:
154188
try:
155-
os.killpg(process.pid, signal.SIGKILL)
189+
if process.pid:
190+
os.killpg(process.pid, signal.SIGKILL)
156191
except ProcessLookupError:
157192
pass
158193
stdout, stderr = await process.communicate()
159194
return {
160-
"error": "Command timed out after 30 seconds.",
195+
"error": (
196+
f"Command timed out after {self._policy.timeout_seconds}"
197+
" seconds."
198+
),
161199
"stdout": (
162200
stdout.decode(errors="replace")
163201
if stdout
@@ -176,7 +214,6 @@ async def run_async(
176214
os.killpg(process.pid, signal.SIGKILL)
177215
except ProcessLookupError:
178216
pass
179-
180217
return {
181218
"stdout": (
182219
stdout.decode(errors="replace")
@@ -191,7 +228,6 @@ async def run_async(
191228
"returncode": process.returncode,
192229
}
193230
except Exception as e: # pylint: disable=broad-except
194-
logger = logging.getLogger("google_adk." + __name__)
195231
logger.exception("ExecuteBashTool execution failed")
196232

197233
stdout_res = (

tests/unittests/tools/test_bash_tool.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import resource
1617
import signal
1718
from unittest import mock
1819

@@ -241,3 +242,34 @@ async def test_no_command(self, workspace, tool_context_confirmed):
241242
result = await tool.run_async(args={}, tool_context=tool_context_confirmed)
242243
assert "error" in result
243244
assert "required" in result["error"].lower()
245+
246+
@pytest.mark.asyncio
247+
async def test_resource_limits_set(self, workspace, tool_context_confirmed):
248+
policy = bash_tool.BashToolPolicy(
249+
max_memory_bytes=100 * 1024 * 1024,
250+
max_file_size_bytes=50 * 1024 * 1024,
251+
max_child_processes=10,
252+
)
253+
tool = bash_tool.ExecuteBashTool(workspace=workspace, policy=policy)
254+
mock_process = mock.AsyncMock()
255+
mock_process.communicate.return_value = (b"", b"")
256+
mock_exec = mock.AsyncMock(return_value=mock_process)
257+
258+
with mock.patch("asyncio.create_subprocess_exec", mock_exec):
259+
await tool.run_async(
260+
args={"command": "ls"},
261+
tool_context=tool_context_confirmed,
262+
)
263+
assert "preexec_fn" in mock_exec.call_args.kwargs
264+
preexec_fn = mock_exec.call_args.kwargs["preexec_fn"]
265+
266+
mock_setrlimit = mock.create_autospec(resource.setrlimit, instance=True)
267+
with mock.patch("resource.setrlimit", mock_setrlimit):
268+
preexec_fn()
269+
mock_setrlimit.assert_any_call(resource.RLIMIT_CORE, (0, 0))
270+
mock_setrlimit.assert_any_call(
271+
resource.RLIMIT_AS, (100 * 1024 * 1024, 100 * 1024 * 1024)
272+
)
273+
mock_setrlimit.assert_any_call(
274+
resource.RLIMIT_FSIZE, (50 * 1024 * 1024, 50 * 1024 * 1024)
275+
)

0 commit comments

Comments
 (0)