|
| 1 | +"""Round-trip Gemini 3 thought signatures through llama-stack's Vertex AI path. |
| 2 | +
|
| 3 | +Gemini 3.x models (for example ``gemini-3-flash`` and ``gemini-3.5-flash``) |
| 4 | +attach a ``thought_signature`` to the first ``functionCall`` part of a |
| 5 | +tool-calling turn. The signature MUST be replayed verbatim on the following |
| 6 | +turn or Gemini rejects the request with HTTP 400. |
| 7 | +
|
| 8 | +llama-stack converts Gemini responses into the OpenAI chat-completion shape |
| 9 | +before they re-enter its own history, and that shape has no field for a |
| 10 | +thought signature, so the signature is dropped and every multi-turn tool call |
| 11 | +against a Gemini 3 model fails. This module monkeypatches llama-stack's |
| 12 | +``vertexai`` converter so the signature survives the round trip. |
| 13 | +
|
| 14 | +Strategy: both patched functions are thin wrappers around the upstream |
| 15 | +originals. We copy none of llama-stack's conversion logic; we only smuggle the |
| 16 | +signature in and out through the opaque tool-call ``id`` (which llama-stack |
| 17 | +round-trips untouched and only ever compares for equality). |
| 18 | +
|
| 19 | +- On the way out (Gemini -> OpenAI): ``_extract_candidate_parts`` produces a |
| 20 | + random tool-call id per ``functionCall`` part. We call the original, then |
| 21 | + re-walk the candidate's parts in the same deterministic order, pair each |
| 22 | + ``functionCall`` part with the tool call the original emitted, and rewrite |
| 23 | + that tool call's id to embed the base64-encoded signature. |
| 24 | +
|
| 25 | +- On the way back (OpenAI -> Gemini): ``_convert_assistant_message`` builds |
| 26 | + the Gemini ``parts``. We call the original, then re-pair each |
| 27 | + ``function_call`` part with its source tool call (same order) and attach the |
| 28 | + decoded signature. |
| 29 | +
|
| 30 | +This file shadows behaviour tied to a specific llama-stack release. Remove it |
| 31 | +once the upstream Vertex AI provider carries thought signatures natively. |
| 32 | +""" |
| 33 | + |
| 34 | +# This module is a monkeypatch shim for llama-stack's Vertex AI converter. It |
| 35 | +# necessarily reaches into the provider's protected (underscore-prefixed) |
| 36 | +# converter functions, so protected-access is disabled file-wide here rather |
| 37 | +# than annotated on every line. |
| 38 | +# pylint: disable=protected-access |
| 39 | +import base64 |
| 40 | +from typing import Any |
| 41 | + |
| 42 | +from log import get_logger |
| 43 | + |
| 44 | +logger = get_logger(__name__) |
| 45 | + |
| 46 | +# Sentinel separating the real tool-call id from a smuggled Gemini |
| 47 | +# thought_signature. Chosen to be vanishingly unlikely in a normal id. |
| 48 | +_THOUGHT_SIG_SEP = "::gts::" |
| 49 | + |
| 50 | +# Set once the patch has been applied so repeated startup calls are no-ops. |
| 51 | +_PATCH_APPLIED = False |
| 52 | + |
| 53 | + |
def _encode_thought_signature_into_id(call_id: str, signature: Any) -> str:
    """Append a base64-encoded Gemini thought_signature to a tool-call id.

    The result has the form ``<id><separator><base64>`` and stays a plain
    string so it can travel through llama-stack history.

    Args:
        call_id: The tool-call id to extend.
        signature: The thought signature: ``str`` (encoded as UTF-8) or any
            value ``bytes()`` can convert, other than an integer.

    Returns:
        ``call_id`` unchanged when there is no signature, when the signature
        is an integer, or when it cannot be encoded; otherwise the id with
        exactly one encoded signature suffix.
    """
    if not signature:
        return call_id
    # bytes(5) is b"\x00" * 5: an integer would "encode" successfully into a
    # bogus zero-filled signature, so reject integers (and bools) up front.
    if isinstance(signature, int):
        return call_id
    try:
        raw = (
            signature.encode("utf-8")
            if isinstance(signature, str)
            else bytes(signature)
        )
        encoded = base64.b64encode(raw).decode("ascii")
    except (TypeError, ValueError):
        return call_id
    # Drop any suffix that is already present so encoding is idempotent: a
    # stacked "<id>::gts::A::gts::B" id cannot be decoded back to a signature.
    base_id = call_id.partition(_THOUGHT_SIG_SEP)[0]
    return f"{base_id}{_THOUGHT_SIG_SEP}{encoded}"
| 73 | + |
| 74 | + |
| 75 | +def _decode_thought_signature_from_id(call_id: str) -> bytes | None: |
| 76 | + """Recover the thought_signature bytes smuggled into a tool-call id.""" |
| 77 | + if not call_id or _THOUGHT_SIG_SEP not in call_id: |
| 78 | + return None |
| 79 | + _, _, encoded = call_id.partition(_THOUGHT_SIG_SEP) |
| 80 | + try: |
| 81 | + return base64.b64decode(encoded) |
| 82 | + except (ValueError, TypeError): |
| 83 | + return None |
| 84 | + |
| 85 | + |
def _tag_wrapper(wrapper: Any, original: Any) -> None:
    """Stamp ``wrapper`` with the markers identifying it as an LCS patch.

    Two attributes are set on ``wrapper``: ``__wrapped_by_lcs__`` (always
    ``True``), which ``apply_patch`` checks to avoid double-wrapping, and
    ``__lcs_original__``, which points at ``original`` so the pristine
    converter stays reachable.
    """
    markers = (
        ("__wrapped_by_lcs__", True),
        ("__lcs_original__", original),
    )
    for attribute, value in markers:
        setattr(wrapper, attribute, value)
| 95 | + |
| 96 | + |
def _is_lcs_wrapped(func: Any) -> bool:
    """Tell whether ``func`` carries a truthy ``__wrapped_by_lcs__`` marker."""
    marker = getattr(func, "__wrapped_by_lcs__", False)
    return True if marker else False
| 100 | + |
| 101 | + |
def _unwrap_to_original(func: Any) -> Any:
    """Peel LCS wrappers off ``func`` and return the innermost callable.

    Each wrapper stores the callable it wraps in ``__lcs_original__``; the
    chain is followed until an object without the LCS marker is reached.
    """
    current = func
    while True:
        if not _is_lcs_wrapped(current):
            return current
        current = current.__lcs_original__
| 107 | + |
| 108 | + |
def _iter_function_call_parts(candidate: Any) -> list[Any]:
    """Collect a Gemini candidate's ``functionCall`` parts, preserving order.

    The filter is meant to mirror the order in which llama-stack's
    ``_extract_candidate_parts`` classifies parts, so the returned list lines
    up one-to-one with the tool calls that function produces. A missing
    ``content`` or ``parts`` attribute yields an empty list.
    """
    content = getattr(candidate, "content", None)
    all_parts = getattr(content, "parts", None) or []
    # The three tests run in the same order as the upstream branches: a
    # thinking part is dropped first, then a text part, and only what remains
    # is considered as a function call.
    return [
        candidate_part
        for candidate_part in all_parts
        if not getattr(candidate_part, "thought", None)
        and getattr(candidate_part, "text", None) is None
        and getattr(candidate_part, "function_call", None) is not None
    ]
| 128 | + |
| 129 | + |
def _build_extract_wrapper(original_extract: Any) -> Any:
    """Wrap ``_extract_candidate_parts`` so tool-call ids carry signatures."""

    def patched_extract_candidate_parts(candidate: Any) -> Any:
        text_parts, thinking_parts, tool_calls = original_extract(candidate)
        if not tool_calls:
            return text_parts, thinking_parts, tool_calls
        fc_parts = _iter_function_call_parts(candidate)
        # The pairing below is positional and zip() truncates silently, so a
        # count mismatch (for example after an upstream change to how parts
        # are classified) would otherwise go unnoticed.
        if len(fc_parts) != len(tool_calls):
            logger.warning(
                "Gemini thought-signature patch: %d tool calls but %d "
                "function-call parts; signatures may be mispaired",
                len(tool_calls),
                len(fc_parts),
            )
        # The original is expected to emit one tool call per function-call
        # part, in the same order. Pair them and embed any signature in the id.
        for tool_call, part in zip(tool_calls, fc_parts):
            signature = getattr(part, "thought_signature", None)
            if not signature:
                continue
            tool_call.id = _encode_thought_signature_into_id(
                tool_call.id or "", signature
            )
        return text_parts, thinking_parts, tool_calls

    _tag_wrapper(patched_extract_candidate_parts, original_extract)
    return patched_extract_candidate_parts


def _build_assistant_wrapper(original_convert_assistant: Any, converters: Any) -> Any:
    """Wrap ``_convert_assistant_message`` so signatures return to Gemini parts."""

    def patched_convert_assistant_message(msg: dict[str, Any]) -> dict[str, Any] | None:
        result = original_convert_assistant(msg)
        if result is None:
            return None
        tool_calls = msg.get("tool_calls") or []
        if not tool_calls:
            return result
        # Re-pair each Gemini function_call part with its source tool call, in
        # order, and attach the decoded signature.
        fc_parts = [p for p in result.get("parts", []) if "function_call" in p]
        if len(fc_parts) != len(tool_calls):
            logger.warning(
                "Gemini thought-signature patch: %d tool calls but %d "
                "function_call parts; signatures may be mispaired",
                len(tool_calls),
                len(fc_parts),
            )
        for part, tool_call in zip(fc_parts, tool_calls):
            # ``_to_dict`` is looked up on the module at call time, exactly as
            # the unpatched code path would resolve it.
            call_id = converters._to_dict(tool_call).get("id", "")
            signature = _decode_thought_signature_from_id(call_id)
            if signature is not None:
                part["thought_signature"] = signature
        return result

    _tag_wrapper(patched_convert_assistant_message, original_convert_assistant)
    return patched_convert_assistant_message


def apply_patch() -> bool:
    """Monkeypatch the Vertex AI converter to carry Gemini thought signatures.

    Idempotent: a converter function already tagged as an LCS wrapper is never
    wrapped again, and each of the two converter functions is checked on its
    own, so a half-patched module is completed rather than left as it is.

    Returns:
        ``True`` if the patch is in effect after the call. ``False`` if the
        converter module cannot be imported (for example when the Vertex AI
        provider is not installed) or does not expose the functions this shim
        wraps; nothing is changed in either case.
    """
    global _PATCH_APPLIED  # pylint: disable=global-statement
    if _PATCH_APPLIED:
        return True

    try:
        # Imported lazily: the provider is optional, so a top-level import
        # would break environments where the Vertex AI provider is absent.
        from llama_stack.providers.remote.inference.vertexai import (  # pylint: disable=import-outside-toplevel
            converters,
        )
    except ImportError:
        logger.info(
            "Vertex AI converter not importable; skipping Gemini thought-signature patch"
        )
        return False

    # These are protected upstream names and may be renamed between releases.
    # Check them up front so a rename makes the patch a logged no-op instead
    # of an AttributeError at startup (or, for ``_to_dict``, on every request).
    required = ("_extract_candidate_parts", "_convert_assistant_message", "_to_dict")
    missing = [name for name in required if not callable(getattr(converters, name, None))]
    if missing:
        logger.warning(
            "Vertex AI converter is missing %s; skipping Gemini thought-signature patch",
            ", ".join(missing),
        )
        return False

    # Guard against re-wrapping. The module-level flag is the fast path, but a
    # second importer (or a test that clears the flag) must not double-wrap, so
    # the wrapped state is read from the marker on each function itself.
    newly_wrapped = False
    current_extract = converters._extract_candidate_parts
    if not _is_lcs_wrapped(current_extract):
        converters._extract_candidate_parts = _build_extract_wrapper(current_extract)
        newly_wrapped = True

    current_convert_assistant = converters._convert_assistant_message
    if not _is_lcs_wrapped(current_convert_assistant):
        converters._convert_assistant_message = _build_assistant_wrapper(
            current_convert_assistant, converters
        )
        newly_wrapped = True

    _PATCH_APPLIED = True
    if newly_wrapped:
        logger.info("Applied Gemini 3 thought-signature patch to Vertex AI converter")
    return True
0 commit comments