Skip to content

Commit 35d6149

Browse files
fix: adds missing cache parameter handling
1 parent 57b887d commit 35d6149

1 file changed

Lines changed: 46 additions & 4 deletions

File tree

src/askui/tools/caching_tools.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import logging
22
import time
33
from pathlib import Path
4+
from typing import Any
45

56
from pydantic import validate_call
67
from typing_extensions import override
78

89
from ..models.shared.settings import CacheExecutionSettings
910
from ..models.shared.tools import Tool, ToolCollection
1011
from ..utils.caching.cache_manager import CacheManager
12+
from ..utils.caching.cache_parameter_handler import CacheParameterHandler
1113

1214
logger = logging.getLogger()
1315

@@ -72,7 +74,12 @@ def __init__(self, settings: CacheExecutionSettings | None = None) -> None:
7274
"trajectory files are available\n"
7375
"2. Select the appropriate trajectory file path from the "
7476
"returned list\n"
75-
"3. Pass the full file path to this tool\n\n"
77+
"3. If the trajectory contains parameters (e.g., {{target_url}}), "
78+
"provide values for them in the parameter_values parameter\n"
79+
"4. Pass the full file path to this tool\n\n"
80+
"Cache parameters allow dynamic values to be injected during "
81+
"execution. For example, if a trajectory types '{{target_url}}', "
82+
"you must provide parameter_values={'target_url': 'https://...'}.\n\n"
7683
"The trajectory will be executed step-by-step, and you should "
7784
"verify the results afterward. Note: Trajectories may fail if "
7885
"the UI state has changed since they were recorded."
@@ -88,6 +95,16 @@ def __init__(self, settings: CacheExecutionSettings | None = None) -> None:
8895
"available files)"
8996
),
9097
},
98+
"parameter_values": {
99+
"type": "object",
100+
"description": (
101+
"Optional dictionary mapping parameter names to "
102+
"their values. Required if the trajectory contains "
103+
"parameters like {{variable}}. Example: "
104+
"{'target_url': 'https://example.com'}"
105+
),
106+
"additionalProperties": {"type": "string"},
107+
},
91108
},
92109
"required": ["trajectory_file"],
93110
},
@@ -102,7 +119,11 @@ def set_toolbox(self, toolbox: ToolCollection) -> None:
102119

103120
@override
104121
@validate_call
105-
def __call__(self, trajectory_file: str) -> str:
122+
def __call__(
123+
self,
124+
trajectory_file: str,
125+
parameter_values: dict[str, Any] | None = None,
126+
) -> str:
106127
if not hasattr(self, "_toolbox"):
107128
error_msg = "Toolbox not set. Call set_toolbox() first."
108129
logger.error(error_msg)
@@ -116,9 +137,24 @@ def __call__(self, trajectory_file: str) -> str:
116137
logger.error(error_msg)
117138
raise FileNotFoundError(error_msg)
118139

119-
# Load and execute trajectory
140+
# Load trajectory
120141
cache_file = CacheManager.read_cache_file(Path(trajectory_file))
121142
trajectory = cache_file.trajectory
143+
parameter_values = parameter_values or {}
144+
145+
# Validate parameters
146+
is_valid, missing_params = CacheParameterHandler.validate_parameters(
147+
trajectory, parameter_values
148+
)
149+
if not is_valid:
150+
error_msg = (
151+
f"Missing required parameter values: {missing_params}. "
152+
f"The cache file expects these parameters. "
153+
f"Available parameters in cache: {cache_file.cache_parameters}"
154+
)
155+
logger.error(error_msg)
156+
return error_msg
157+
122158
info_msg = f"Executing cached trajectory from {trajectory_file}"
123159
logger.info(info_msg)
124160
for step in trajectory:
@@ -129,8 +165,14 @@ def __call__(self, trajectory_file: str) -> str:
129165
or step.name.startswith("execute_cached_executions_tool")
130166
):
131167
continue
168+
169+
# Substitute parameters in the step before execution
170+
substituted_step = CacheParameterHandler.substitute_parameters(
171+
step, parameter_values
172+
)
173+
132174
try:
133-
results = self._toolbox.run([step])
175+
results = self._toolbox.run([substituted_step])
134176
# Check for tool execution errors
135177
if results and hasattr(results[0], "is_error") and results[0].is_error:
136178
error_content = getattr(results[0], "content", "Unknown error")

0 commit comments

Comments
 (0)