Skip to content

Commit a8b80de

Browse files
committed
docs: fix 36 docstring quality gate failures across 17 files
- Fix missing_param_type, missing_return_type, param_type_mismatch, return_type_mismatch, no_args, no_returns, and missing docstring issues - Add TYPE_CHECKING imports for HuggingFace types in util.py with type: ignore[union-attr] for pre-existing None-safety gaps - Add Granite3ChatCompletion import to granite32/33 input.py for correct sanitize() parent signature match - Convert reST-style docstrings to Google style in intrinsics/input.py - Document AST single-quote normalization for Literal types in CONTRIBUTING.md
1 parent 4bfcd0a commit a8b80de

17 files changed

Lines changed: 110 additions & 45 deletions

File tree

cli/alora/intrinsic_uploader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def upload_intrinsic(
4040
base_model (str): Base model ID or path (e.g.
4141
``"ibm-granite/granite-3.3-2b-instruct"``). Must contain at most
4242
one ``"/"`` separator.
43-
type (Literal["lora", "alora"]): Adapter type, used as the leaf
43+
type (Literal['lora', 'alora']): Adapter type, used as the leaf
4444
directory name in the repository layout.
4545
io_yaml (str): Path to the ``io.yaml`` configuration file for
4646
intrinsic input/output processing.

cli/alora/train.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import typer
1717
from datasets import Dataset
1818
from peft import LoraConfig, get_peft_model
19-
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
19+
from transformers import (
20+
AutoModelForCausalLM,
21+
AutoTokenizer,
22+
TrainerCallback,
23+
TrainerControl,
24+
TrainerState,
25+
TrainingArguments,
26+
)
2027
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
2128

2229
# Handle MPS with old PyTorch versions on macOS only
@@ -39,7 +46,9 @@
3946
)
4047

4148

42-
def load_dataset_from_json(json_path, tokenizer, invocation_prompt):
49+
def load_dataset_from_json(
50+
json_path: str, tokenizer: AutoTokenizer, invocation_prompt: str
51+
) -> Dataset:
4352
"""Load a JSONL dataset and format it for SFT training.
4453
4554
Reads ``item``/``label`` pairs from a JSONL file and builds a HuggingFace
@@ -73,7 +82,7 @@ def load_dataset_from_json(json_path, tokenizer, invocation_prompt):
7382
return Dataset.from_dict({"input": inputs, "target": targets})
7483

7584

76-
def formatting_prompts_func(example):
85+
def formatting_prompts_func(example: dict) -> list[str]:
7786
"""Concatenate input and target columns for SFT prompt formatting.
7887
7988
Args:
@@ -101,7 +110,13 @@ class SaveBestModelCallback(TrainerCallback):
101110
def __init__(self):
102111
self.best_eval_loss = float("inf")
103112

104-
def on_evaluate(self, args, state, control, **kwargs):
113+
def on_evaluate(
114+
self,
115+
args: TrainingArguments,
116+
state: TrainerState,
117+
control: TrainerControl,
118+
**kwargs,
119+
):
105120
"""Save the adapter weights if the current evaluation loss is a new best.
106121
107122
Called automatically by the HuggingFace Trainer after each evaluation

cli/decompose/decompose.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ class DecompVersion(StrEnum):
4949
def reorder_subtasks(
5050
subtasks: list[DecompSubtasksResult],
5151
) -> list[DecompSubtasksResult]:
52+
"""Topologically sort subtasks by their ``depends_on`` relationships.
53+
54+
Args:
55+
subtasks: List of subtask dicts, each with a ``"tag"`` and optional
56+
``"depends_on"`` field.
57+
58+
Returns:
59+
list[DecompSubtasksResult]: The subtasks reordered so that dependencies
60+
come before dependents, with numbering prefixes updated.
61+
62+
Raises:
63+
ValueError: If a circular dependency is detected.
64+
"""
5265
subtask_map = {subtask["tag"].lower(): subtask for subtask in subtasks}
5366

5467
graph = {}
@@ -78,6 +91,19 @@ def reorder_subtasks(
7891
def verify_user_variables(
7992
decomp_data: DecompPipelineResult, input_var: list[str] | None
8093
) -> DecompPipelineResult:
94+
"""Validate that all required input variables and dependencies exist.
95+
96+
Args:
97+
decomp_data: The decomposition pipeline result containing subtasks.
98+
input_var: User-provided input variable names, or ``None`` for none.
99+
100+
Returns:
101+
DecompPipelineResult: The (possibly reordered) decomposition data.
102+
103+
Raises:
104+
ValueError: If a subtask requires an input variable that was not
105+
provided, or depends on a subtask tag that does not exist.
106+
"""
81107
if input_var is None:
82108
input_var = []
83109

cli/eval/runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
self.score = score
5050
self.validation_reason = validation_reason
5151

52-
def to_dict(self):
52+
def to_dict(self) -> dict:
5353
"""Serialise the input evaluation result to a plain dictionary.
5454
5555
Returns:
@@ -84,7 +84,7 @@ def __init__(self, test_eval: TestBasedEval, input_results: list[InputEvalResult
8484
self.test_eval = test_eval
8585
self.input_results = input_results
8686

87-
def to_dict(self):
87+
def to_dict(self) -> dict:
8888
"""Serialise the test evaluation result to a plain dictionary.
8989
9090
Returns:
@@ -366,7 +366,7 @@ def execute_test_eval(
366366
return test_result
367367

368368

369-
def parse_judge_output(judge_output: str):
369+
def parse_judge_output(judge_output: str) -> tuple[int | None, str]:
370370
"""Parse score and justification from a judge model's output string.
371371
372372
Args:

docs/docs/guide/CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ in the table below, follow the fix instructions, and re-push.
391391
| `no_raises` | Function source contains `raise` but the docstring has no `Raises:` section | Add a `Raises:` section listing each exception type and the condition that triggers it |
392392
| `missing_param_type` | `Args:` section exists but one or more parameters have no Python type annotation — the type column is absent from the generated API docs | Add a type annotation to each listed parameter in the function signature (e.g. `def f(x: int)`). Only fires when `no_args` is already satisfied; `*args`/`**kwargs` are excluded. |
393393
| `missing_return_type` | `Returns:` section is documented but the function has no return type annotation — the return type is absent from the generated API docs | Add a return annotation to the function signature (e.g. `-> str`). Only fires when `no_returns` is already satisfied. |
394-
| `param_type_mismatch` | A parameter's `Args:` entry states an explicit type (e.g. `x (int): …`) that does not match the Python annotation in the function signature | Align the docstring type with the annotation, or vice versa. The check normalises common equivalents (`Optional[X]``X \| None`, `List``list`, union ordering) before comparing, so only genuine disagreements are flagged. Only fires when both the docstring and the signature have an explicit type. |
394+
| `param_type_mismatch` | A parameter's `Args:` entry states an explicit type (e.g. `x (int): …`) that does not match the Python annotation in the function signature | Align the docstring type with the annotation, or vice versa. The check normalises common equivalents (`Optional[X]``X \| None`, `List``list`, union ordering) before comparing, so only genuine disagreements are flagged. Only fires when both the docstring and the signature have an explicit type. **Note:** Python's AST normalises string literals to single quotes, so `Literal["a", "b"]` in source is read as `Literal['a', 'b']` — use single quotes in docstrings to match. |
395395
| `return_type_mismatch` | The `Returns:` section has a type prefix (e.g. `Returns: \n str: …`) that does not match the Python return annotation | Align the docstring return type with the annotation, or vice versa. Same normalisation rules as `param_type_mismatch`. Only fires when both sides have an explicit type. |
396396

397397
#### Class docstrings (Option C)

mellea/backends/tools.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def as_json_tool(self) -> dict[str, Any]:
7070
return self._as_json_tool.copy()
7171

7272
@classmethod
73-
def from_langchain(cls, tool: Any):
73+
def from_langchain(cls, tool: Any) -> "MelleaTool":
7474
"""Create a MelleaTool from a LangChain tool object.
7575
7676
Args:
@@ -117,7 +117,7 @@ def parameter_remapper(*args, **kwargs):
117117
) from e
118118

119119
@classmethod
120-
def from_smolagents(cls, tool: Any):
120+
def from_smolagents(cls, tool: Any) -> "MelleaTool":
121121
"""Create a Tool from a HuggingFace smolagents tool object.
122122
123123
Args:
@@ -172,7 +172,7 @@ def tool_call(*args, **kwargs):
172172
) from e
173173

174174
@classmethod
175-
def from_callable(cls, func: Callable, name: str | None = None):
175+
def from_callable(cls, func: Callable, name: str | None = None) -> "MelleaTool":
176176
"""Create a MelleaTool from a plain Python callable.
177177
178178
Introspects the callable's signature and docstring to build an
@@ -379,7 +379,7 @@ def json_extraction(text: str) -> Generator[dict, None, None]:
379379
index = text.find("{", index)
380380

381381

382-
def find_func(d) -> tuple[str | None, Mapping | None]:
382+
def find_func(d: object) -> tuple[str | None, Mapping | None]:
383383
"""Find the first function in a json-like dictionary.
384384
385385
Most llms output tool requests in the form ``...{"name": string, "arguments": {}}...``

mellea/core/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,5 @@ async def sample(
133133
tool_calls: True if tool calls should be used during this sampling strategy.
134134
135135
Returns:
136-
SamplingResult: A result object indicating the success or failure of the sampling process.
136+
SamplingResult[S]: A result object indicating the success or failure of the sampling process.
137137
"""

mellea/core/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ class JsonFormatter(logging.Formatter):
6767
process ID, thread ID, and (if present) exception information.
6868
"""
6969

70-
def format(self, record): # type: ignore
70+
def format(self, record: logging.LogRecord) -> dict: # type: ignore[override]
7171
"""Formats a log record as a JSON-serialisable dictionary.
7272
7373
Includes timestamp, level, message, module, function name, line number,
7474
process ID, thread ID, and exception info if present.
7575
7676
Args:
7777
record (logging.LogRecord): The log record to format.
78+
79+
Returns:
80+
dict: A dictionary containing timestamp, level, message, module, function,
81+
line number, process/thread IDs, and optional exception info.
7882
"""
7983
log_record = {
8084
"timestamp": self.formatTime(record, self.datefmt),

mellea/formatters/granite/base/util.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,23 @@
33
"""Common utility functions for the library and tests."""
44

55
# Standard
6+
from __future__ import annotations
7+
68
import contextlib
79
import itertools
810
import json
911
import logging
1012
import os
1113
import re
1214
import uuid
15+
from typing import TYPE_CHECKING
1316

1417
# Third Party
1518
import pydantic
1619

20+
if TYPE_CHECKING:
21+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
22+
1723
# First Party
1824
from .types import ChatCompletionResponse, ChatCompletionResponseChoice
1925

@@ -98,7 +104,7 @@ def random_uuid() -> str:
98104
return str(uuid.uuid4())
99105

100106

101-
def load_transformers_lora(local_or_remote_path):
107+
def load_transformers_lora(local_or_remote_path: str) -> tuple:
102108
"""Load transformers LoRA model.
103109
104110
AutoModelForCausalLM.from_pretrained() is supposed to auto-load base models if you
@@ -136,7 +142,10 @@ def load_transformers_lora(local_or_remote_path):
136142

137143

138144
def chat_completion_request_to_transformers_inputs(
139-
request, tokenizer=None, model=None, constrained_decoding_prefix=None
145+
request: dict,
146+
tokenizer: PreTrainedTokenizerBase | None = None,
147+
model: PreTrainedModel | None = None,
148+
constrained_decoding_prefix: str | None = None,
140149
) -> tuple[dict, dict]:
141150
"""Translate an OpenAI-style chat completion request.
142151
@@ -191,7 +200,7 @@ def chat_completion_request_to_transformers_inputs(
191200
):
192201
tokenizer_input["documents"] = request["extra_body"]["documents"]
193202

194-
input_tokens = tokenizer.apply_chat_template(**tokenizer_input, return_tensors="pt")
203+
input_tokens = tokenizer.apply_chat_template(**tokenizer_input, return_tensors="pt") # type: ignore[union-attr]
195204

196205
# Transformers 5 switched the return type of apply_chat_template() from Tensor to
197206
# BatchEncoding. Adjust our behavior depending on which direction the currently
@@ -208,17 +217,17 @@ def chat_completion_request_to_transformers_inputs(
208217

209218
# generate() will fail with many different creative error messages if tokens aren't
210219
# on the right device.
211-
input_tokens = input_tokens.to(model.device)
220+
input_tokens = input_tokens.to(model.device) # type: ignore[union-attr]
212221
generate_input["input_tokens"] = input_tokens
213222

214223
# The generate() method sometimes needs to know what is the integer ID
215224
# of the padding token, and for some reason this critical piece of information
216225
# isn't included in the serialized model. We get it from the tokenizer.
217226
# And of course some tokenizers don't set this parameter, in which case
218227
# we use the end of string token and hope for the best.
219-
pad_token_id = tokenizer.pad_token_id
228+
pad_token_id = tokenizer.pad_token_id # type: ignore[union-attr]
220229
if pad_token_id is None:
221-
pad_token_id = tokenizer.eos_token_id
230+
pad_token_id = tokenizer.eos_token_id # type: ignore[union-attr]
222231
if pad_token_id is None:
223232
# Raise an error here because the some branches of the generate
224233
# method won't complain about an invalid value of this parameter,
@@ -229,7 +238,7 @@ def chat_completion_request_to_transformers_inputs(
229238

230239
# Make sure you specify this parameter explicitly, or you will have
231240
# a bad time.
232-
generate_input["eos_token_id"] = tokenizer.eos_token_id
241+
generate_input["eos_token_id"] = tokenizer.eos_token_id # type: ignore[union-attr]
233242

234243
other_input = {}
235244

@@ -316,7 +325,10 @@ def chat_completion_request_to_transformers_inputs(
316325

317326

318327
def generate_with_transformers(
319-
tokenizer, model, generate_input: dict, other_input: dict
328+
tokenizer: PreTrainedTokenizerBase,
329+
model: PreTrainedModel,
330+
generate_input: dict,
331+
other_input: dict,
320332
) -> ChatCompletionResponse:
321333
"""Call Transformers generate and get usable results.
322334

mellea/formatters/granite/granite3/granite32/input.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NO_TOOLS_NO_DOCS_NO_THINKING_SYSTEM_MESSAGE_PART,
1414
)
1515
from ...granite3.input import Granite3InputProcessor
16+
from ...granite3.types import Granite3ChatCompletion
1617

1718
# Local
1819
from .constants import (
@@ -221,12 +222,14 @@ def _remove_special_tokens(cls, text: str) -> str:
221222
return new_text
222223

223224
@classmethod
224-
def sanitize(cls, chat_completion, parts="all"):
225+
def sanitize(
226+
cls, chat_completion: Granite3ChatCompletion, parts: list[str] | str = "all"
227+
) -> Granite3ChatCompletion:
225228
"""Sanitize the chat completion by removing Granite 3.2 special tokens.
226229
227230
Args:
228231
chat_completion: The chat completion request to sanitize.
229-
parts (str): Which parts of the chat completion to sanitize;
232+
parts (list[str] | str): Which parts of the chat completion to sanitize;
230233
defaults to ``"all"``.
231234
232235
Returns:

0 commit comments

Comments
 (0)