|
2 | 2 | import inspect |
3 | 3 | import typing as t |
4 | 4 | from dataclasses import dataclass, field |
| 5 | +from functools import lru_cache |
5 | 6 |
|
6 | 7 | from loguru import logger |
7 | 8 | from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator |
|
27 | 28 |
|
28 | 29 | @t.runtime_checkable |
29 | 30 | class LazyGenerator(t.Protocol): |
30 | | - def __call__(self) -> type["Generator"]: |
31 | | - ... |
| 31 | + def __call__(self) -> type["Generator"]: ... |
32 | 32 |
|
33 | 33 |
|
34 | 34 | g_providers: dict[str, type["Generator"] | LazyGenerator] = {} |
@@ -483,16 +483,14 @@ def chat( |
483 | 483 | self, |
484 | 484 | messages: t.Sequence[MessageDict], |
485 | 485 | params: GenerateParams | None = None, |
486 | | - ) -> "ChatPipeline": |
487 | | - ... |
| 486 | + ) -> "ChatPipeline": ... |
488 | 487 |
|
489 | 488 | @t.overload |
490 | 489 | def chat( |
491 | 490 | self, |
492 | 491 | messages: t.Sequence[Message] | MessageDict | Message | str | None = None, |
493 | 492 | params: GenerateParams | None = None, |
494 | | - ) -> "ChatPipeline": |
495 | | - ... |
| 493 | + ) -> "ChatPipeline": ... |
496 | 494 |
|
497 | 495 | def chat( |
498 | 496 | self, |
@@ -575,17 +573,15 @@ def chat( |
575 | 573 | generator: Generator, |
576 | 574 | messages: t.Sequence[MessageDict], |
577 | 575 | params: GenerateParams | None = None, |
578 | | -) -> "ChatPipeline": |
579 | | - ... |
| 576 | +) -> "ChatPipeline": ... |
580 | 577 |
|
581 | 578 |
|
582 | 579 | @t.overload |
583 | 580 | def chat( |
584 | 581 | generator: Generator, |
585 | 582 | messages: t.Sequence[Message] | MessageDict | Message | str | None = None, |
586 | 583 | params: GenerateParams | None = None, |
587 | | -) -> "ChatPipeline": |
588 | | - ... |
| 584 | +) -> "ChatPipeline": ... |
589 | 585 |
|
590 | 586 |
|
591 | 587 | def chat( |
@@ -663,6 +659,7 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) - |
663 | 659 | return identifier |
664 | 660 |
|
665 | 661 |
|
| 662 | +@lru_cache(maxsize=128) |
666 | 663 | def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator: |
667 | 664 | """ |
668 | 665 | Get a generator by an identifier string. Uses LiteLLM by default. |
|
0 commit comments