Skip to content

Commit cea6d3d

Browse files
committed
fixes
Signed-off-by: Ezequiel Lanza <ezequiel.lanza@gmail.com>
1 parent 91aaf4b commit cea6d3d

1 file changed

Lines changed: 26 additions & 21 deletions

File tree

python/beeai_framework/tools/scratchpad/scratchpad.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ def __init__(self) -> None:
8080
"""Initialize scratchpad tool."""
8181
super().__init__()
8282
self.middlewares = []
83-
# Store the session_id once it's determined from context
84-
# This ensures the same session is used across all calls
85-
self._cached_session_id: str | None = None
8683

8784
@classmethod
8885
def _ensure_session(cls, session_id: str) -> None:
@@ -91,28 +88,32 @@ def _ensure_session(cls, session_id: str) -> None:
9188
cls._scratchpads[session_id] = []
9289

9390
def _get_session_id(self) -> str:
94-
"""Extract session ID from context.
91+
"""Extract a stable session ID from RunContext.
9592
96-
Caches the session ID on first call to ensure the same session
97-
is used across all tool calls for this tool instance.
93+
Order of preference:
94+
1) RunContext.context["session_id"] (set by the API layer)
95+
2) RunContext.group_id (stable for a run group)
96+
3) RunContext.run_id (per-request fallback)
97+
"""
98+
context = RunContext.get()
99+
if not context:
100+
raise ToolInputValidationError("RunContext missing; cannot determine session.")
98101

99-
Returns:
100-
Session ID string for data isolation.
102+
context_data = getattr(context, "context", None)
103+
if isinstance(context_data, dict):
104+
session_id = context_data.get("session_id")
105+
if session_id:
106+
return str(session_id)
101107

102-
Raises:
103-
ToolInputValidationError: If no valid session ID can be extracted from context.
104-
"""
105-
# Return cached session ID if we already determined it
106-
if self._cached_session_id:
107-
return self._cached_session_id
108+
group_id = getattr(context, "group_id", None)
109+
if group_id:
110+
return str(group_id)
108111

109-
# Get run_id from RunContext as session identifier
110-
session_id = RunContext.get().run_id
112+
run_id = getattr(context, "run_id", None)
113+
if run_id:
114+
return str(run_id)
111115

112-
# Cache the session ID for future calls
113-
self._cached_session_id = session_id
114-
logger.info(f"Scratchpad session initialized: {session_id}")
115-
return session_id
116+
raise ToolInputValidationError("No valid session id found in RunContext.")
116117

117118
@property
118119
def name(self) -> str:
@@ -135,10 +136,14 @@ def input_schema(self) -> type[ScratchpadInput]:
135136
"""Input schema for the tool."""
136137
return ScratchpadInput
137138

139+
def _create_emitter(self) -> Emitter:
140+
"""Create emitter for the tool."""
141+
return Emitter.root()
142+
138143
@property
139144
def emitter(self) -> Emitter:
140145
"""Emitter for the tool."""
141-
return Emitter.root.child(
146+
return Emitter.root().child(
142147
namespace=["tool", "scratchpad"],
143148
creator=self,
144149
)

0 commit comments

Comments
 (0)