|
| 1 | +# Copyright The OpenTelemetry Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""LLM async-compat wrapper for Cognee ``GenericAPIAdapter``. |
| 16 | +
|
| 17 | +Background |
| 18 | +---------- |
| 19 | +Cognee v1.2.1 constructs its LLM client via:: |
| 20 | +
|
| 21 | + self.aclient = instructor.from_litellm(litellm.acompletion, mode=...) |
| 22 | +
|
| 23 | +When ``loongsuite-instrumentation-litellm`` is loaded (which the Cognee |
| 24 | +README mandates), it patches ``litellm.acompletion`` to an instance of |
| 25 | +``opentelemetry.instrumentation.litellm._wrapper.AsyncCompletionWrapper`` |
| 26 | +— a class instance with an ``async def __call__``. Instructor 1.14.5's |
| 27 | +``instructor.utils.core.is_async`` uses ``inspect.iscoroutinefunction``, |
| 28 | +which returns ``False`` for callable class instances regardless of |
| 29 | +whether ``__call__`` is ``async def``. Instructor therefore picks the |
| 30 | +sync ``Instructor`` + ``new_create_sync`` retry path; ``retry_sync`` |
| 31 | +calls ``litellm.acompletion(...)`` *without* ``await`` and the resulting |
| 32 | +coroutine is mistaken for the API response: |
| 33 | +
|
| 34 | + instructor.core.exceptions.InstructorRetryException: |
| 35 | + 'coroutine' object has no attribute 'choices' |
| 36 | +
|
| 37 | +Fix |
| 38 | +--- |
| 39 | +After ``GenericAPIAdapter.__init__`` runs, this wrapper rebuilds |
| 40 | +``self.aclient`` as an explicit ``AsyncInstructor`` by routing |
| 41 | +``litellm.acompletion`` through a real ``async def`` (so |
| 42 | +``iscoroutinefunction`` returns ``True`` and instructor picks |
| 43 | +``new_create_async``). All retry / fallback / content-policy logic in |
| 44 | +``GenericAPIAdapter.acreate_structured_output`` is preserved unchanged. |
| 45 | +
|
| 46 | +Semantics |
| 47 | +--------- |
| 48 | +The wrap is **transparent** to telemetry: |
| 49 | +
|
| 50 | +* It does not change ``response_model`` / ``mode`` / ``max_retries`` / |
| 51 | + ``llm_args`` handling — those still flow through Cognee's original |
| 52 | + ``acreate_structured_output``. |
| 53 | +* It does not patch the global ``litellm.acompletion``; the async shim calls |
| 54 | + the LiteLLM-instrumented callable at request time, so LLM spans and token |
| 55 | + usage metrics are still produced by ``loongsuite-instrumentation-litellm``. |
| 56 | +* It does not create any new span / metric — the Cognee instrumentor's |
| 57 | + span coverage (ENTRY/AGENT/TOOL/STEP/EMBEDDING) is unchanged. |
| 58 | +""" |
| 59 | + |
| 60 | +from __future__ import annotations |
| 61 | + |
| 62 | +import logging |
| 63 | +from typing import Any, Callable |
| 64 | + |
| 65 | +from wrapt import wrap_function_wrapper |
| 66 | + |
# Module-scoped logger; every message emitted by this file is at DEBUG level.
logger = logging.getLogger(__name__)


# Dotted import path of the Cognee module that defines the adapter class
# whose ``__init__`` gets wrapped.
_ADAPTER_MODULE = (
    "cognee.infrastructure.llm"
    ".structured_output_framework"
    ".litellm_instructor.llm.generic_llm_api.adapter"
)
# Attribute name of the adapter class inside ``_ADAPTER_MODULE``.
_ADAPTER_CLASS = "GenericAPIAdapter"
| 75 | + |
| 76 | + |
def _build_async_acompletion(
    original_acompletion: Callable[..., Any],
) -> Callable[..., Any]:
    """Return a genuine coroutine function that forwards to ``litellm.acompletion``.

    Instructor detects async callables via ``inspect.iscoroutinefunction``,
    so the callable handed to it must be a real ``async def`` rather than a
    callable object.

    Args:
        original_acompletion: The ``litellm.acompletion`` callable seen at
            build time. Only its ``__doc__`` is read; calls are never routed
            through this object.

    Returns:
        An ``async def`` named ``acompletion`` that looks up
        ``litellm.acompletion`` on every call and awaits it with the
        positional and keyword arguments it received.
    """

    async def _async_acompletion(*args: Any, **kwargs: Any) -> Any:
        # Import at call time so that whatever is bound to
        # ``litellm.acompletion`` right now (e.g. after the LiteLLM
        # instrumentor was re-installed) is the callable that gets awaited.
        import litellm as litellm_module

        current_acompletion = litellm_module.acompletion
        return await current_acompletion(*args, **kwargs)

    # Cosmetic metadata for debugging: only the name and the docstring are
    # carried over. A source object without a readable ``__doc__`` leaves the
    # shim's docstring as ``None``.
    _async_acompletion.__name__ = "acompletion"
    _async_acompletion.__doc__ = getattr(original_acompletion, "__doc__", None)
    return _async_acompletion
| 99 | + |
| 100 | + |
def _rebuild_aclient_as_async(instance: Any) -> bool:
    """Swap ``instance.aclient`` for an instructor client built on an ``async def``.

    Args:
        instance: Object expected to carry an ``aclient`` attribute.

    Returns:
        ``True`` when ``instance.aclient`` was replaced. ``False`` when the
        attribute is missing or ``None``, when the client's class is already
        named ``AsyncInstructor`` / ``_AsyncInstructor``, when ``instructor``
        or ``litellm`` cannot be imported, or when building the new client
        raised.
    """
    current_client = getattr(instance, "aclient", None)
    if current_client is None:
        return False

    # An already-async client is recognised by class name, which avoids
    # importing ``instructor`` just to run the check.
    client_type_name = type(current_client).__name__
    if (
        client_type_name == "AsyncInstructor"
        or client_type_name == "_AsyncInstructor"
    ):
        return False

    try:
        import instructor  # type: ignore
        import litellm  # type: ignore
    except ImportError:
        logger.debug(
            "Cannot rebuild aclient: instructor or litellm not importable"
        )
        return False

    # Reuse the mode of the client being replaced; a missing or falsy mode
    # falls back to ``instructor.Mode("json_mode")``.
    selected_mode = getattr(current_client, "mode", None)
    if not selected_mode:
        selected_mode = instructor.Mode("json_mode")

    async_shim = _build_async_acompletion(litellm.acompletion)
    try:
        instance.aclient = instructor.from_litellm(async_shim, mode=selected_mode)
        logger.debug(
            "Rebuilt GenericAPIAdapter.aclient as AsyncInstructor "
            "(was %s, mode=%s)",
            client_type_name,
            selected_mode,
        )
        return True
    except Exception as exc:  # pragma: no cover - defensive
        # Best effort: a failed rebuild must not break adapter construction.
        logger.debug("Failed to rebuild aclient as AsyncInstructor: %s", exc)
        return False
| 137 | + |
| 138 | + |
def _make_init_wrapper() -> Callable[..., Any]:
    """Build the ``wrapt``-style wrapper applied to the adapter's ``__init__``.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable that runs the
        original ``__init__`` first, then attempts the async client rebuild
        on ``instance``, and finally hands back whatever the original
        ``__init__`` returned.
    """

    def _init_wrapper(wrapped, instance, args, kwargs):  # type: ignore[no-untyped-def]
        # The original constructor runs outside the guard, so its own
        # exceptions propagate to the caller untouched.
        init_outcome = wrapped(*args, **kwargs)
        try:
            _rebuild_aclient_as_async(instance)
        except Exception as exc:  # pragma: no cover - defensive
            # A failed rebuild is logged and ignored rather than breaking
            # adapter construction.
            logger.debug(
                "GenericAPIAdapter.aclient async rebuild failed: %s", exc
            )
        return init_outcome

    return _init_wrapper
| 151 | + |
| 152 | + |
def install_llm_compat_wrapper() -> None:
    """Wrap ``GenericAPIAdapter.__init__`` so ``self.aclient`` is async-safe.

    The adapter module is imported by its dotted path, the adapter class is
    looked up on it, and its ``__init__`` is wrapped with the callable from
    ``_make_init_wrapper``. Any exception raised along the way is logged at
    DEBUG level and suppressed, so this function never raises for those
    failures.
    """
    try:
        from importlib import import_module

        adapter_module = import_module(_ADAPTER_MODULE)
        adapter_cls = getattr(adapter_module, _ADAPTER_CLASS)
        wrap_function_wrapper(adapter_cls, "__init__", _make_init_wrapper())
    except Exception as exc:  # pragma: no cover - defensive
        logger.debug(
            "Failed to wrap %s.%s.__init__: %s",
            _ADAPTER_MODULE,
            _ADAPTER_CLASS,
            exc,
        )
| 168 | + |
| 169 | + |
def uninstall_llm_compat_wrapper() -> None:
    """Unwrap ``GenericAPIAdapter.__init__`` using OpenTelemetry's ``unwrap``.

    The ``unwrap`` helper is imported before the guarded section, so an
    import failure there propagates to the caller. Any exception raised
    while importing the adapter module, looking up the class, or unwrapping
    is logged at DEBUG level and suppressed.
    """
    from opentelemetry.instrumentation.utils import unwrap

    try:
        from importlib import import_module

        adapter_module = import_module(_ADAPTER_MODULE)
        adapter_cls = getattr(adapter_module, _ADAPTER_CLASS)
        unwrap(adapter_cls, "__init__")
    except Exception as exc:  # pragma: no cover - defensive
        logger.debug(
            "Failed to unwrap %s.%s.__init__: %s",
            _ADAPTER_MODULE,
            _ADAPTER_CLASS,
            exc,
        )
0 commit comments