Skip to content

Commit ce8542e

Browse files
authored
Merge pull request #94 from dreadnode/users/raja/minor-perf-improvement-get-generator
Add LRU cache for `get_generator` Functionality
2 parents 2998951 + 8384c3a commit ce8542e

1 file changed

Lines changed: 7 additions & 10 deletions

File tree

rigging/generator/base.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import typing as t
44
from dataclasses import dataclass, field
5+
from functools import lru_cache
56

67
from loguru import logger
78
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator
@@ -27,8 +28,7 @@
2728

2829
@t.runtime_checkable
2930
class LazyGenerator(t.Protocol):
30-
def __call__(self) -> type["Generator"]:
31-
...
31+
def __call__(self) -> type["Generator"]: ...
3232

3333

3434
g_providers: dict[str, type["Generator"] | LazyGenerator] = {}
@@ -483,16 +483,14 @@ def chat(
483483
self,
484484
messages: t.Sequence[MessageDict],
485485
params: GenerateParams | None = None,
486-
) -> "ChatPipeline":
487-
...
486+
) -> "ChatPipeline": ...
488487

489488
@t.overload
490489
def chat(
491490
self,
492491
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
493492
params: GenerateParams | None = None,
494-
) -> "ChatPipeline":
495-
...
493+
) -> "ChatPipeline": ...
496494

497495
def chat(
498496
self,
@@ -575,17 +573,15 @@ def chat(
575573
generator: Generator,
576574
messages: t.Sequence[MessageDict],
577575
params: GenerateParams | None = None,
578-
) -> "ChatPipeline":
579-
...
576+
) -> "ChatPipeline": ...
580577

581578

582579
@t.overload
583580
def chat(
584581
generator: Generator,
585582
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
586583
params: GenerateParams | None = None,
587-
) -> "ChatPipeline":
588-
...
584+
) -> "ChatPipeline": ...
589585

590586

591587
def chat(
@@ -663,6 +659,7 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
663659
return identifier
664660

665661

662+
@lru_cache(maxsize=128)
666663
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
667664
"""
668665
Get a generator by an identifier string. Uses LiteLLM by default.

0 commit comments

Comments
 (0)