Skip to content

Commit 339e76d

Browse files
giles17CopilotCopilot
authored
Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks (microsoft#5013)
* Fix GitHubCopilotAgent not calling context provider hooks (microsoft#3984) GitHubCopilotAgent accepted context_providers in its constructor but never called before_run()/after_run() on them in _run_impl() or _stream_updates(), silently ignoring all context providers. Add _run_before_providers() helper to create SessionContext and invoke before_run on each provider. Both _run_impl() and _stream_updates() now run the full provider lifecycle: before_run before sending the prompt (with provider instructions prepended) and after_run after receiving the response. This follows the same pattern used by A2AAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks Fixes microsoft#3984 * fix(microsoft#3984): address review feedback for context provider integration - Build prompt from session_context.get_messages(include_input=True) so provider-injected context_messages are included in both non-streaming and streaming paths (review comments #1, #2) - Preserve timeout in opts (use get instead of pop) so providers can observe it via context.options (review comment #3) - Eliminate streaming double-buffer: move after_run invocation to a ResponseStream result_hook (matching Agent class pattern) instead of maintaining a separate updates list in the generator (review comment #4) - Improve _run_before_providers docstring Add tests for: - Context messages included in prompt (non-streaming + streaming) - Error path: after_run NOT called when send_and_wait/streaming raises - Multiple providers: forward before_run, reverse after_run ordering - BaseHistoryProvider with load_messages=False is skipped - Streaming after_run response contains aggregated updates - Streaming with no updates still sets empty response - Timeout preserved in session context options for providers Note: _run_before_providers remains on GitHubCopilotAgent for now. A follow-up PR should extract it to BaseAgent so subclasses can reuse it without duplicating the provider iteration logic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for microsoft#3984: Python: [Bug]: GitHubCopilotAgent Memory Example * refactor(microsoft#3984): promote _run_before_providers to BaseAgent Move _run_before_providers from GitHubCopilotAgent into BaseAgent, mirroring the existing _run_after_providers helper. Agent's _prepare_session_and_messages now delegates to the shared base method, eliminating the near-duplicate provider iteration logic that could drift as the provider contract evolves. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for microsoft#3984: Python: [Bug]: GitHubCopilotAgent Memory Example * revert: keep _run_before_providers in GitHubCopilotAgent only Undo the promotion of _run_before_providers to BaseAgent. The method stays in GitHubCopilotAgent where it is needed, and _agents.py retains its original inline provider iteration in RawAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: replace deprecated BaseContextProvider/BaseHistoryProvider with ContextProvider/HistoryProvider Update imports and usages in GitHubCopilotAgent and its tests to use the new non-deprecated class names from the core package. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: address review feedback - reorder providers before session, wrap streaming after_run in try/except, assert after_run on skipped HistoryProvider - Move _run_before_providers before _get_or_create_session so provider contributions can affect session configuration - Wrap _run_after_providers in try/except in streaming _after_run_hook to prevent provider errors from replacing successful responses - Add after_run assertion to test_history_provider_skip_when_load_messages_false Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 62595b2 commit 339e76d

2 files changed

Lines changed: 624 additions & 7 deletions

File tree

python/packages/github_copilot/agent_framework_github_copilot/_agent.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
BaseAgent,
1818
Content,
1919
ContextProvider,
20+
HistoryProvider,
2021
Message,
2122
ResponseStream,
23+
SessionContext,
2224
normalize_messages,
2325
)
2426
from agent_framework._settings import load_settings
@@ -352,13 +354,25 @@ def run(
352354
AgentException: If the request fails.
353355
"""
354356
if stream:
357+
ctx_holder: dict[str, Any] = {}
358+
359+
async def _after_run_hook(response: AgentResponse) -> None:
360+
session_context = ctx_holder.get("session_context")
361+
sess = ctx_holder.get("session")
362+
if session_context is not None and sess is not None:
363+
session_context._response = response
364+
try:
365+
await self._run_after_providers(session=sess, context=session_context)
366+
except Exception:
367+
logger.exception("Error running after_run providers in streaming result hook")
355368

356369
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
357370
return AgentResponse.from_updates(updates)
358371

359372
return ResponseStream(
360-
self._stream_updates(messages=messages, session=session, options=options),
373+
self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder),
361374
finalizer=_finalize,
375+
result_hooks=[_after_run_hook],
362376
)
363377
return self._run_impl(messages=messages, session=session, options=options)
364378

@@ -377,11 +391,22 @@ async def _run_impl(
377391
session = self.create_session()
378392

379393
opts: dict[str, Any] = dict(options) if options else {}
380-
timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
394+
timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
381395

382-
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
383396
input_messages = normalize_messages(messages)
384-
prompt = "\n".join([message.text for message in input_messages])
397+
398+
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
399+
400+
# NOTE: session is created after providers run so that future provider-contributed
401+
# tools/config could be folded into runtime_options before session creation.
402+
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
403+
404+
# Build the prompt from the full set of messages in the session context,
405+
# so that any context/history provider-injected messages are included.
406+
context_messages = session_context.get_messages(include_input=True)
407+
prompt = "\n".join([message.text for message in context_messages])
408+
if session_context.instructions:
409+
prompt = "\n".join(session_context.instructions) + "\n" + prompt
385410
message_options = cast(MessageOptions, {"prompt": prompt})
386411

387412
try:
@@ -408,14 +433,18 @@ async def _run_impl(
408433
)
409434
response_id = message_id
410435

411-
return AgentResponse(messages=response_messages, response_id=response_id)
436+
response = AgentResponse(messages=response_messages, response_id=response_id)
437+
session_context._response = response # type: ignore[assignment]
438+
await self._run_after_providers(session=session, context=session_context)
439+
return response
412440

413441
async def _stream_updates(
414442
self,
415443
messages: AgentRunInputs | None = None,
416444
*,
417445
session: AgentSession | None = None,
418446
options: OptionsT | None = None,
447+
_ctx_holder: dict[str, Any] | None = None,
419448
) -> AsyncIterable[AgentResponseUpdate]:
420449
"""Internal method to stream updates from GitHub Copilot.
421450
@@ -425,6 +454,9 @@ async def _stream_updates(
425454
Keyword Args:
426455
session: The conversation session associated with the message(s).
427456
options: Runtime options (model, timeout, etc.).
457+
_ctx_holder: Internal dict populated with session_context and session
458+
so that the caller (via a ResponseStream result_hook) can run
459+
after_run providers without duplicating the updates buffer.
428460
429461
Yields:
430462
AgentResponseUpdate items.
@@ -440,9 +472,23 @@ async def _stream_updates(
440472

441473
opts: dict[str, Any] = dict(options) if options else {}
442474

443-
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
444475
input_messages = normalize_messages(messages)
445-
prompt = "\n".join([message.text for message in input_messages])
476+
477+
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
478+
479+
# NOTE: session is created after providers run so that future provider-contributed
480+
# tools/config could be folded into runtime_options before session creation.
481+
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
482+
483+
if _ctx_holder is not None:
484+
_ctx_holder["session_context"] = session_context
485+
_ctx_holder["session"] = session
486+
487+
# Build the prompt from the full session context so provider-injected messages are included.
488+
context_messages = session_context.get_messages(include_input=True)
489+
prompt = "\n".join([message.text for message in context_messages])
490+
if session_context.instructions:
491+
prompt = "\n".join(session_context.instructions) + "\n" + prompt
446492
message_options = cast(MessageOptions, {"prompt": prompt})
447493

448494
queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue()
@@ -513,6 +559,46 @@ def event_handler(event: SessionEvent) -> None:
513559
finally:
514560
unsubscribe()
515561

562+
async def _run_before_providers(
563+
self,
564+
*,
565+
session: AgentSession,
566+
input_messages: list[Message],
567+
options: dict[str, Any],
568+
) -> SessionContext:
569+
"""Run before_run on all context providers and return the session context.
570+
571+
Creates a SessionContext and invokes ``before_run`` on each provider in
572+
forward order. ``HistoryProvider`` instances with
573+
``load_messages=False`` are skipped.
574+
575+
Keyword Args:
576+
session: The conversation session.
577+
input_messages: The normalized input messages.
578+
options: Runtime options dict.
579+
580+
Returns:
581+
The SessionContext with provider context populated.
582+
"""
583+
session_context = SessionContext(
584+
session_id=session.session_id,
585+
service_session_id=session.service_session_id,
586+
input_messages=input_messages,
587+
options=options,
588+
)
589+
590+
for provider in self.context_providers:
591+
if isinstance(provider, HistoryProvider) and not provider.load_messages:
592+
continue
593+
await provider.before_run(
594+
agent=self, # type: ignore[arg-type]
595+
session=session,
596+
context=session_context,
597+
state=session.state.setdefault(provider.source_id, {}),
598+
)
599+
600+
return session_context
601+
516602
@staticmethod
517603
def _prepare_system_message(
518604
instructions: str | None,

0 commit comments

Comments
 (0)