Skip to content

Commit b4fff18

Browse files
committed
Linting fixes
1 parent 82879c7 commit b4fff18

21 files changed

Lines changed: 115 additions & 141 deletions

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ strict = true
104104
target-version = "py310"
105105
line-length = 100
106106
extend-exclude = [
107-
"*.ipynb", # jupyter notebooks
107+
"*.ipynb", # jupyter notebooks
108+
"examples/*", # example files
109+
".github/*", # github files
108110
]
109111

110112
[tool.ruff.lint]

rigging/__init__.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,52 +42,52 @@
4242
__version__ = VERSION
4343

4444
__all__ = [
45-
"get_generator",
46-
"Message",
47-
"MessageDict",
48-
"Messages",
49-
"ContentText",
50-
"ContentImageUrl",
51-
"ContentAudioInput",
52-
"Tool",
53-
"Model",
54-
"attr",
55-
"element",
56-
"wrapped",
5745
"Chat",
5846
"ChatPipeline",
59-
"Generator",
60-
"GenerateParams",
61-
"GeneratedMessage",
62-
"GeneratedText",
63-
"chat",
64-
"complete",
6547
"Completion",
6648
"CompletionPipeline",
67-
"register_generator",
68-
"prompt",
69-
"Prompt",
49+
"ContentAudioInput",
50+
"ContentImageUrl",
51+
"ContentText",
7052
"Ctx",
71-
"data",
72-
"watchers",
73-
"model",
74-
"error",
75-
"parsing",
76-
"tool",
77-
"tool_method",
78-
"logging",
79-
"await_",
80-
"interact",
81-
"ThenChatCallback",
53+
"GenerateParams",
54+
"GeneratedMessage",
55+
"GeneratedText",
56+
"Generator",
8257
"MapChatCallback",
83-
"ThenCompletionCallback",
8458
"MapCompletionCallback",
59+
"Message",
60+
"MessageDict",
61+
"Messages",
62+
"Model",
8563
"PipelineStep",
86-
"PipelineStepGenerator",
8764
"PipelineStepContextManager",
65+
"PipelineStepGenerator",
66+
"Prompt",
67+
"ThenChatCallback",
68+
"ThenCompletionCallback",
69+
"Tool",
70+
"attr",
71+
"await_",
72+
"chat",
73+
"complete",
74+
"data",
75+
"element",
76+
"error",
8877
"generator",
78+
"get_generator",
79+
"interact",
80+
"logging",
8981
"mcp",
82+
"model",
83+
"parsing",
84+
"prompt",
85+
"register_generator",
9086
"robopages",
87+
"tool",
88+
"tool_method",
89+
"watchers",
90+
"wrapped",
9191
]
9292

9393
from loguru import logger

rigging/chat.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def message_dicts(self) -> list[MessageDict]:
217217
"""
218218
return [
219219
t.cast(
220-
MessageDict,
220+
"MessageDict",
221221
m.model_dump(include={"role", "content_parts"}, exclude_none=True),
222222
)
223223
for m in self.all
@@ -409,7 +409,7 @@ def inject_tool_prompt(
409409

410410
tool_system_prompt = tool_description_prompt_part(
411411
definitions,
412-
t.cast(t.Literal["xml", "json-in-xml"], mode),
412+
t.cast("t.Literal['xml', 'json-in-xml']", mode),
413413
)
414414
return self.inject_system_content(tool_system_prompt)
415415

@@ -524,8 +524,7 @@ def __call__(
524524
self,
525525
chat: Chat,
526526
/,
527-
) -> t.Awaitable[Chat | None]:
528-
...
527+
) -> t.Awaitable[Chat | None]: ...
529528

530529

531530
@runtime_checkable
@@ -534,8 +533,7 @@ def __call__(
534533
self,
535534
chat: Chat,
536535
/,
537-
) -> "PipelineStepGenerator | PipelineStepContextManager | t.Awaitable[PipelineStepGenerator | PipelineStepContextManager | None]":
538-
...
536+
) -> "PipelineStepGenerator | PipelineStepContextManager | t.Awaitable[PipelineStepGenerator | PipelineStepContextManager | None]": ...
539537

540538

541539
ThenChatCallback = _ThenChatCallback | _ThenChatStepCallback
@@ -550,8 +548,7 @@ def __call__(
550548
self,
551549
chats: list[Chat],
552550
/,
553-
) -> t.Awaitable[list[Chat]]:
554-
...
551+
) -> t.Awaitable[list[Chat]]: ...
555552

556553

557554
@runtime_checkable
@@ -560,8 +557,7 @@ def __call__(
560557
self,
561558
chats: list[Chat],
562559
/,
563-
) -> "PipelineStepGenerator | PipelineStepContextManager | t.Awaitable[PipelineStepGenerator | PipelineStepContextManager]":
564-
...
560+
) -> "PipelineStepGenerator | PipelineStepContextManager | t.Awaitable[PipelineStepGenerator | PipelineStepContextManager]": ...
565561

566562

567563
MapChatCallback = _MapChatCallback | _MapChatStepCallback
@@ -658,9 +654,10 @@ def depth(self) -> int:
658654
This is useful for setting constraints on recursion depth.
659655
"""
660656
depth = 0
661-
while self.parent is not None:
657+
current = self
658+
while current.parent is not None:
662659
depth += 1
663-
self = self.parent
660+
current = current.parent
664661
return depth
665662

666663

@@ -852,8 +849,7 @@ def add(
852849
and self.chat.all[-1].role == message_list[0].role
853850
and (
854851
merge_strategy == "all"
855-
or merge_strategy == "only-user-role"
856-
and self.chat.all[-1].role == "user"
852+
or (merge_strategy == "only-user-role" and self.chat.all[-1].role == "user")
857853
)
858854
):
859855
self.chat.all[-1].content_parts += message_list[0].content_parts
@@ -1263,7 +1259,7 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
12631259

12641260
# Parse the actual tool calls
12651261

1266-
tool_calls: (list[ApiToolCall] | list[XmlToolCall] | list[JsonInXmlToolCall] | None) = None
1262+
tool_calls: list[ApiToolCall] | list[XmlToolCall] | list[JsonInXmlToolCall] | None = None
12671263
if self.tool_mode == "api":
12681264
tool_calls = chat.last.tool_calls
12691265
if self.tool_mode == "xml":
@@ -1435,7 +1431,7 @@ async def complete() -> None:
14351431
)
14361432

14371433
generator = t.cast(
1438-
PipelineStepGenerator,
1434+
"PipelineStepGenerator",
14391435
await exit_stack.enter_async_context(aclosing(result)),
14401436
)
14411437
async for step in generator:
@@ -1675,7 +1671,7 @@ async def _step( # noqa: PLR0915, PLR0912
16751671

16761672
if inspect.isasyncgen(chats_or_generator):
16771673
generator = t.cast(
1678-
PipelineStepGenerator,
1674+
"PipelineStepGenerator",
16791675
await exit_stack.enter_async_context(
16801676
aclosing(chats_or_generator),
16811677
),

rigging/completion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def __init__(
263263
264264
ExhuastedMaxRounds is implicitly included.
265265
"""
266-
self.on_failed: "FailMode" = "raise"
266+
self.on_failed: FailMode = "raise"
267267
"""How to handle failures in the pipeline unless overriden in calls."""
268268

269269
# (callback, all_text, max_rounds)
@@ -747,7 +747,7 @@ async def _run( # noqa: PLR0912
747747
raise inbound # noqa: TRY301
748748

749749
inbounds = [inbound for inbound in _inbounds if isinstance(inbound, GeneratedText)]
750-
except Exception as e: # noqa: BLE001
750+
except Exception as e:
751751
if on_failed == "raise" or not any(
752752
isinstance(e, t) for t in self.errors_to_fail_on
753753
):
@@ -768,14 +768,14 @@ async def _run( # noqa: PLR0912
768768
state.processor.send(inbound.text)
769769
continue
770770
except StopIteration as stop:
771-
output = t.cast(str, stop.value)
771+
output = t.cast("str", stop.value)
772772
except CompletionExhaustedMaxRoundsError as exhausted:
773773
if on_failed == "raise":
774774
raise
775775
output = exhausted.completion
776776
failed = True
777777
error = exhausted
778-
except Exception as e: # noqa: BLE001
778+
except Exception as e:
779779
if on_failed == "raise" or not any(
780780
isinstance(e, t) for t in self.errors_to_fail_on
781781
):

rigging/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def elastic_data_to_chats(
297297
while all(hasattr(data, attr) for attr in ("keys", "__getitem__")) and "hits" in data:
298298
data = data["hits"]
299299

300-
objects = t.cast(t.Sequence[t.Mapping[str, t.Any]], data)
300+
objects = t.cast("t.Sequence[t.Mapping[str, t.Any]]", data)
301301
if not isinstance(objects, t.Sequence):
302302
raise TypeError(
303303
f"Expected to find a sequence of objects (optionally under hits), found: {type(data)}",
@@ -341,7 +341,7 @@ async def elastic_to_chats(
341341
A pandas DataFrame containing the chat data.
342342
"""
343343
data = await client.search(index=index, query=query, size=max_results, **kwargs)
344-
return elastic_data_to_chats(t.cast(dict[str, t.Any], data))
344+
return elastic_data_to_chats(t.cast("dict[str, t.Any]", data))
345345

346346

347347
async def s3_bucket_exists(client: S3Client, bucket: str) -> bool:

rigging/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _raise_as(func: t.Callable[P, R]) -> t.Callable[P, R]:
122122
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
123123
try:
124124
return func(*args, **kwargs)
125-
except Exception as e: # noqa: BLE001
125+
except Exception as e:
126126
error = error_type(message)
127127
raise error from e
128128

rigging/generator/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,19 @@ def get_transformers_lazy() -> type[Generator]:
5454
register_generator("transformers", get_transformers_lazy)
5555

5656
__all__ = [
57-
"get_generator",
58-
"Generator",
5957
"GenerateParams",
6058
"GeneratedMessage",
6159
"GeneratedText",
60+
"Generator",
61+
"HTTPGenerator",
62+
"LiteLLMGenerator",
6263
"StopReason",
6364
"Usage",
6465
"chat",
6566
"complete",
6667
"get_generator",
67-
"register_generator",
68+
"get_generator",
6869
"get_identifier",
69-
"LiteLLMGenerator",
70-
"HTTPGenerator",
70+
"register_generator",
7171
# TODO: We can't add VLLM and Transformers here because they are lazy loaded
7272
]

rigging/generator/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def get_identifier(generator: Generator, params: GenerateParams | None = None) -
663663
return identifier
664664

665665

666-
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator: # noqa: PLR0912
666+
def get_generator(identifier: str, *, params: GenerateParams | None = None) -> Generator:
667667
"""
668668
Get a generator by an identifier string. Uses LiteLLM by default.
669669
@@ -708,24 +708,24 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
708708
if "!" in identifier:
709709
try:
710710
provider, model = identifier.split("!")
711-
except Exception as e: # noqa: BLE001
711+
except Exception as e:
712712
raise InvalidModelSpecifiedError(identifier) from e
713713

714714
if provider not in g_providers:
715715
raise InvalidModelSpecifiedError(identifier)
716716

717717
if not isinstance(g_providers[provider], type):
718-
lazy_generator = t.cast(LazyGenerator, g_providers[provider])
718+
lazy_generator = t.cast("LazyGenerator", g_providers[provider])
719719
g_providers[provider] = lazy_generator()
720720

721-
generator_cls = t.cast(type[Generator], g_providers[provider])
721+
generator_cls = t.cast("type[Generator]", g_providers[provider])
722722

723723
kwargs = {}
724724
if "," in model:
725725
try:
726726
model, kwargs_str = model.split(",", 1)
727727
kwargs = dict(arg.split("=", 1) for arg in kwargs_str.split(","))
728-
except Exception as e: # noqa: BLE001
728+
except Exception as e:
729729
raise InvalidModelSpecifiedError(identifier) from e
730730

731731
# See if any of the kwargs would apply to the cls constructor directly
@@ -753,7 +753,7 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
753753

754754
try:
755755
merged_params = GenerateParams(**kwargs).merge_with(params)
756-
except Exception as e: # noqa: BLE001
756+
except Exception as e:
757757
raise InvalidModelSpecifiedError(identifier) from e
758758

759759
return generator_cls(model=model, params=merged_params, **init_kwargs)

rigging/generator/http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def _to_str(v: str | dict[str, t.Any]) -> str:
3333

3434

3535
def _to_dict(v: str | dict[str, t.Any]) -> dict[str, t.Any]:
36-
return t.cast(dict[str, t.Any], json.loads(v)) if isinstance(v, str) else v
36+
return t.cast("dict[str, t.Any]", json.loads(v)) if isinstance(v, str) else v
3737

3838

3939
def _to_dict_or_str(v: str) -> dict[str, t.Any] | str:
4040
try:
41-
return t.cast(dict[str, t.Any], json.loads(v))
41+
return t.cast("dict[str, t.Any]", json.loads(v))
4242
except json.JSONDecodeError:
4343
return v
4444

rigging/interact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def interact(
8787
else:
8888
pipeline.add(user_input)
8989

90-
print("")
90+
print()
9191

9292
animation_task = asyncio.create_task(_animate())
9393
chat = await pipeline.run()

0 commit comments

Comments
 (0)