Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 58 additions & 40 deletions rigging/generator/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import functools
import inspect
import typing as t
from dataclasses import dataclass, field
from functools import lru_cache

from loguru import logger
Expand Down Expand Up @@ -42,7 +42,7 @@ class Fixup(abc.ABC):
"""

@abc.abstractmethod
def can_fix(self, exception: Exception) -> bool:
def can_fix(self, exception: Exception) -> bool | t.Literal["once"]:
"""
Check if the fixup can resolve the given exception if made active.

Expand All @@ -68,10 +68,62 @@ def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
...


@dataclass
class Fixups:
available: list[Fixup] = field(default_factory=list)
active: list[Fixup] = field(default_factory=list)
FixupCompatibleFunc = t.Callable[
t.Concatenate[t.Any, t.Sequence[Message], P],
t.Awaitable[R],
]


def with_fixups(
    *fixups: Fixup,
) -> t.Callable[[FixupCompatibleFunc[P, R]], FixupCompatibleFunc[P, R]]:
    """
    Decorator that adds fixup retry logic with persistent state.

    Before every call, each active fixup (plus any pending "once" fixup) is
    applied to the messages. If the wrapped function raises, the first
    available fixup whose `can_fix()` does not return `False` is enabled and
    the call is retried:

    - a `"once"` result enables the fixup only until the current call
      finishes (successfully or not), after which it becomes available again
    - any other result enables the fixup for all subsequent calls

    Note:
        The fixup state lives in the closure created by this call, so it is
        shared by every invocation of the decorated function.

    Args:
        fixups: Sequence of fixups to try

    Returns:
        A decorator for async callables shaped like `(self, messages, ...)`.
    """
    available_fixups: list[Fixup] = list(fixups)
    active_fixups: list[Fixup] = []
    once_fixups: list[Fixup] = []

    def release_once_fixups() -> None:
        # "once" fixups only cover the call that triggered them, so hand
        # them back to the available pool when that call is over.
        available_fixups.extend(once_fixups)
        once_fixups.clear()

    def decorator(func: FixupCompatibleFunc[P, R]) -> FixupCompatibleFunc[P, R]:
        @functools.wraps(func)
        async def wrapper(
            self: t.Any,
            messages: t.Sequence[Message],
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> R:
            for fixup in [*active_fixups, *once_fixups]:
                messages = fixup.fix(messages)

            try:
                result = await func(self, messages, *args, **kwargs)
            except Exception as e:
                # Iterate over a snapshot as a match mutates the list.
                for fixup in list(available_fixups):
                    if (can_fix := fixup.can_fix(e)) is False:
                        continue

                    if can_fix == "once":
                        once_fixups.append(fixup)
                    else:
                        active_fixups.append(fixup)
                    available_fixups.remove(fixup)

                    # Retry with the newly enabled fixup applied.
                    return await wrapper(self, messages, *args, **kwargs)

                # No fixup can handle this error. Release any "once" fixups
                # before propagating, otherwise they would leak into the
                # next (unrelated) call and stay unavailable until a call
                # eventually succeeds.
                release_once_fixups()
                raise

            release_once_fixups()
            return result

        return wrapper  # type: ignore[return-value]

    return decorator


# TODO: We also would like to support N-style
Expand Down Expand Up @@ -305,8 +357,6 @@ class Generator(BaseModel):
_watch_callbacks: list["WatchChatCallback | WatchCompletionCallback"] = []
_wrap: t.Callable[[CallableT], CallableT] | None = None

_fixups: Fixups = Fixups()

def to_identifier(self, params: GenerateParams | None = None) -> str:
"""
Converts the generator instance back into a rigging identifier string.
Expand Down Expand Up @@ -393,38 +443,6 @@ async def supports_function_calling(self) -> bool | None:
"""
return None

def _check_fixups(self, error: Exception) -> bool:
    """
    Activate the first available fixup that can handle the error.

    Args:
        error: The error to be checked.

    Returns:
        True if a fixup was moved from the available list to the
        active list, False if no available fixup accepts the error.
    """
    state = self._fixups
    # Walk a snapshot because a match removes an entry from the list.
    for candidate in list(state.available):
        if not candidate.can_fix(error):
            continue
        state.active.append(candidate)
        state.available.remove(candidate)
        return True
    return False

async def _apply_fixups(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
    """
    Run the messages through every active fixup, in order.

    Args:
        messages: The messages to be fixed.

    Returns:
        The messages after each active fixup has been applied to the
        output of the previous one (unchanged if none are active).
    """
    return functools.reduce(
        lambda fixed, fixup: fixup.fix(fixed),
        self._fixups.active,
        messages,
    )

async def generate_messages(
self,
messages: t.Sequence[t.Sequence[Message]],
Expand Down
60 changes: 40 additions & 20 deletions rigging/generator/litellm_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

from rigging.generator.base import (
Fixup,
Fixups,
GeneratedMessage,
GeneratedText,
GenerateParams,
Generator,
trace_messages,
trace_str,
with_fixups,
)
from rigging.message import ContentAudioInput, ContentImageUrl, ContentText, Message
from rigging.tool.api import ApiFunctionDefinition, ApiToolDefinition
Expand Down Expand Up @@ -74,11 +74,40 @@ class CacheTooSmallFixup(Fixup):
# are below a certain threshold can result in a 400
# error from APIs (Vertex/Gemini).

def can_fix(self, exception: Exception) -> bool | t.Literal["once"]:
return "once" if "Cached content is too small." in str(exception) else False

def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
return [message.cache(False) for message in messages]


class GroqAssistantContentFixup(Fixup):
# Groq can complain if we try to send fully
# structured content parts when working with
# the assistant role.
#
# Compatibility flags are a poor workaround for the
# fact that we don't have direct control over the
# conversion to the OpenAI spec.

def can_fix(self, exception: Exception) -> bool:
return "Cached content is too small." in str(exception)
return "Groq" in str(exception) and "content' : value must be a string" in str(exception)

def fix(self, messages: t.Sequence[Message]) -> t.Sequence[Message]:
updated_messages: list[Message] = []
for message in messages:
if message.role == "assistant":
message = message.clone() # noqa: PLW2901
message._compability_flags.add("content_as_str") # noqa: SLF001
updated_messages.append(message)
return updated_messages

def fix(self, items: t.Sequence[Message]) -> t.Sequence[Message]:
return [message.cache(False) for message in items]

# Fixup instances handed to `with_fixups` on `_generate_message`.
# Each one is asked via `can_fix()` whether it can resolve a raised
# exception, and if so rewrites the outgoing messages via `fix()`.
g_fixups: list[Fixup] = [
OpenAIToolsWithImageURLsFixup(),
CacheTooSmallFixup(),
GroqAssistantContentFixup(),
]


class LiteLLMGenerator(Generator):
Expand Down Expand Up @@ -123,8 +152,6 @@ class LiteLLMGenerator(Generator):
_last_request_time: datetime.datetime | None = None
_supports_function_calling: bool | None = None

_fixups = Fixups(available=[OpenAIToolsWithImageURLsFixup(), CacheTooSmallFixup()])

@property
def semaphore(self) -> asyncio.Semaphore:
if self._semaphore is None:
Expand Down Expand Up @@ -299,6 +326,7 @@ def _parse_text_completion_response(
extra={"response_id": response.id},
)

@with_fixups(*g_fixups)
async def _generate_message(
self,
messages: t.Sequence[Message],
Expand All @@ -313,20 +341,12 @@ async def _generate_message(
if self._wrap is not None:
acompletion = self._wrap(acompletion)

# Prepare messages for specific providers
messages = await self._apply_fixups(messages)

try:
response = await acompletion(
model=self.model,
messages=[message.to_openai_spec() for message in messages],
api_key=self.api_key,
**self.params.merge_with(params).to_dict(),
)
except Exception as e:
if self._check_fixups(e):
return await self._generate_message(messages, params)
raise
response = await acompletion(
model=self.model,
messages=[message.to_openai_spec() for message in messages],
api_key=self.api_key,
**self.params.merge_with(params).to_dict(),
)

self._last_request_time = datetime.datetime.now(tz=datetime.timezone.utc)
return self._parse_model_response(response)
Expand Down
17 changes: 16 additions & 1 deletion rigging/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def save(self, path: Path | str) -> None:
"""The types of content that can be included in a message."""
ContentTypes = (ContentText, ContentImageUrl, ContentAudioInput)

CompatibilityFlag = t.Literal["content_as_str"]


class Message(BaseModel):
"""
Expand Down Expand Up @@ -364,6 +366,8 @@ class Message(BaseModel):
tool_call_id: str | None = Field(None)
"""Associated call id if this message is a response to a tool call."""

_compability_flags: set[CompatibilityFlag] = set()

def __init__(
self,
role: Role,
Expand Down Expand Up @@ -546,7 +550,7 @@ def to_openai_spec(self) -> dict[str, t.Any]:
isinstance(current, dict)
and current.get("type") == "text"
and next_.get("type") == "text"
and not current.get("text", "").endswith("\n")
and not str(current.get("text", "")).endswith("\n")
):
current["text"] += "\n"

Expand All @@ -556,6 +560,17 @@ def to_openai_spec(self) -> dict[str, t.Any]:
if isinstance(part, dict) and part.get("type") == "input_audio":
part.get("input_audio", {}).pop("transcript", None)

# If enabled, we need to convert our content to a flat
# string for API compatibility. Groq is an example of an API
# which will complain for some roles if we send a list of content parts.

if "content_as_str" in self._compability_flags:
obj["content"] = "".join(
part["text"]
for part in obj["content"]
if isinstance(part, dict) and part.get("type") == "text"
)

return obj

# TODO: In general the add/remove/sync_part methods are
Expand Down
Loading