diff --git a/rigging/generator/base.py b/rigging/generator/base.py index da42c88..21a51f8 100644 --- a/rigging/generator/base.py +++ b/rigging/generator/base.py @@ -2,6 +2,7 @@ import inspect import typing as t from dataclasses import dataclass, field +from functools import lru_cache from loguru import logger from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator @@ -27,8 +28,7 @@ @t.runtime_checkable class LazyGenerator(t.Protocol): - def __call__(self) -> type["Generator"]: - ... + def __call__(self) -> type["Generator"]: ... g_providers: dict[str, type["Generator"] | LazyGenerator] = {} @@ -483,16 +483,14 @@ def chat( self, messages: t.Sequence[MessageDict], params: GenerateParams | None = None, - ) -> "ChatPipeline": - ... + ) -> "ChatPipeline": ... @t.overload def chat( self, messages: t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, - ) -> "ChatPipeline": - ... + ) -> "ChatPipeline": ... def chat( self, @@ -575,8 +573,7 @@ def chat( generator: Generator, messages: t.Sequence[MessageDict], params: GenerateParams | None = None, -) -> "ChatPipeline": - ... +) -> "ChatPipeline": ... @t.overload @@ -584,8 +581,7 @@ def chat( generator: Generator, messages: t.Sequence[Message] | MessageDict | Message | str | None = None, params: GenerateParams | None = None, -) -> "ChatPipeline": - ... +) -> "ChatPipeline": ... def chat( @@ -663,6 +659,7 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) - return identifier +@lru_cache(maxsize=128) def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator: """ Get a generator by an identifier string. Uses LiteLLM by default.