Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 202 additions & 5 deletions utils/check_forward_call_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
pipelines) match the method's docstring exactly:

* every signature argument has an entry in the ``Args:`` /
``Arguments:`` / ``Parameters:`` section, and
``Arguments:`` / ``Parameters:`` section,
* every documented argument still exists in the signature
(stale entries from removed/renamed args are flagged).
(stale entries from removed/renamed args are flagged), and
* when the method has a non-``None`` return annotation, the docstring has
a ``Returns:`` / ``Return:`` / ``Yields:`` section.

A "main" class is detected via its base classes — models inherit from
``ModelMixin`` and pipelines inherit from ``DiffusionPipeline``. Only methods
Expand All @@ -33,6 +35,11 @@
Optionally restrict to specific files:

python utils/check_forward_call_docstrings.py --paths src/diffusers/models/transformers/transformer_flux.py

Auto-fix stale (documented-but-removed) entries — missing entries are never
auto-added (no placeholders), only stale ones are removed:

python utils/check_forward_call_docstrings.py --fix
"""

from __future__ import annotations
Expand Down Expand Up @@ -93,6 +100,17 @@ def _find_method(class_def: ast.ClassDef, method_name: str) -> ast.FunctionDef |
return None


def _docstring_node(func: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.Expr | None:
if (
func.body
and isinstance(func.body[0], ast.Expr)
and isinstance(func.body[0].value, ast.Constant)
and isinstance(func.body[0].value.value, str)
):
return func.body[0]
return None


def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]:
args = func.args
collected: list[str] = []
Expand All @@ -103,6 +121,30 @@ def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[s
return collected


def _has_meaningful_return(func: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""True iff the method has a return annotation other than ``None`` or ``NoReturn``."""
ret = func.returns
if ret is None: # no annotation at all
return False
if isinstance(ret, ast.Constant) and ret.value is None: # `-> None`
return False
# `-> NoReturn` or `-> typing.NoReturn`
if isinstance(ret, ast.Name) and ret.id == "NoReturn":
return False
if isinstance(ret, ast.Attribute) and ret.attr == "NoReturn":
return False
return True


def _has_returns_section(docstring: str | None) -> bool:
if not docstring:
return False
for line in docstring.splitlines():
if line.strip() in {"Returns:", "Return:", "Yields:", "Yield:"}:
return True
return False


def _extract_documented_args(docstring: str | None) -> set[str]:
"""Extract argument names listed in an Args/Arguments/Parameters section.

Expand Down Expand Up @@ -180,10 +222,9 @@ def check_file(path: Path, kind: str) -> list[str]:
if method is None:
continue
sig_args = _signature_arg_names(method)
if not sig_args:
continue
sig_set = set(sig_args)
documented = _extract_documented_args(ast.get_docstring(method))
docstring_text = ast.get_docstring(method)
documented = _extract_documented_args(docstring_text)
missing = [a for a in sig_args if a not in documented]
stale = sorted(documented - sig_set)
if missing:
Expand All @@ -196,9 +237,137 @@ def check_file(path: Path, kind: str) -> list[str]:
f"{rel}:{method.lineno}: {node.name}.{method_name} documents "
f"argument(s) not in the signature: {', '.join(stale)}"
)
if _has_meaningful_return(method) and not _has_returns_section(docstring_text):
return_repr = ast.unparse(method.returns)
ds = _docstring_node(method)
if ds is None:
where = " (method has no docstring)"
else:
where = f' (add it just above the closing """ on line {ds.end_lineno})'
errors.append(
f"{rel}:{method.lineno}: {node.name}.{method_name} returns "
f"`{return_repr}` but the docstring has no Returns: section{where}"
)
return errors


def fix_file(path: Path, kind: str) -> list[str]:
"""Remove stale arg entries (documented but not in signature) in-place.

Missing-in-signature → docstring entries are NOT added (no placeholders).
Returns a list of ``"ClassName.method: removed name1, name2"`` strings
describing what was removed.
"""
method_name = "forward" if kind == "model" else "__call__"
base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE

source = path.read_text(encoding="utf-8")
try:
tree = ast.parse(source)
except (SyntaxError, UnicodeDecodeError):
return []

lines = source.splitlines(keepends=True)
# (start_idx, end_idx_exclusive) ranges of lines to drop.
deletions: list[tuple[int, int]] = []
summaries: list[str] = []

for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef):
continue
if base_class not in _base_class_names(node):
continue
method = _find_method(node, method_name)
if method is None:
continue
# Method must start with a string docstring expression.
if not (
method.body
and isinstance(method.body[0], ast.Expr)
and isinstance(method.body[0].value, ast.Constant)
and isinstance(method.body[0].value.value, str)
):
continue

sig_set = set(_signature_arg_names(method))
documented = _extract_documented_args(ast.get_docstring(method))
stale = documented - sig_set
if not stale:
continue

docstring_expr = method.body[0]
doc_start = docstring_expr.lineno - 1 # 0-indexed
doc_end = docstring_expr.end_lineno - 1 # 0-indexed, inclusive

# Locate the Args/Arguments/Parameters header in raw source.
args_idx: int | None = None
header_indent = 0
for i in range(doc_start, doc_end + 1):
stripped = lines[i].strip()
if stripped in {"Args:", "Arguments:", "Parameters:"}:
args_idx = i
header_indent = len(lines[i]) - len(lines[i].lstrip())
break
if args_idx is None:
continue

# First non-empty line after the header sets the per-entry indent.
entry_indent: int | None = None
for i in range(args_idx + 1, doc_end + 1):
stripped = lines[i].strip()
if not stripped:
continue
entry_indent = len(lines[i]) - len(lines[i].lstrip())
break
if entry_indent is None or entry_indent <= header_indent:
continue

# Walk entries; each entry spans from its header line up to (but not
# including) the next entry header / section header / end of docstring.
current_name: str | None = None
current_start: int = -1
end_of_args: int | None = None

for i in range(args_idx + 1, doc_end + 1):
line = lines[i]
stripped = line.strip()
if not stripped:
continue
indent = len(line) - len(line.lstrip())

if indent <= header_indent and stripped in SECTION_HEADERS:
end_of_args = i
break

if indent == entry_indent:
m = _ARG_HEADER_RE.match(stripped)
if m:
if current_name in stale:
deletions.append((current_start, i))
current_name = m.group(1)
current_start = i

if current_name in stale:
end = end_of_args if end_of_args is not None else doc_end
# Trailing blank lines belong to inter-section spacing (or the
# blank line before the closing """), not to this entry.
while end > current_start + 1 and not lines[end - 1].strip():
end -= 1
deletions.append((current_start, end))

summaries.append(f"{node.name}.{method_name}: removed {', '.join(sorted(stale))}")

if not deletions:
return []

deletions.sort()
new_lines = list(lines)
for start, end in reversed(deletions):
del new_lines[start:end]
path.write_text("".join(new_lines), encoding="utf-8")
return summaries


def _kind_for_path(path: Path) -> str | None:
parts = path.resolve().parts
if "pipelines" in parts:
Expand All @@ -224,6 +393,15 @@ def main() -> int:
"(in sorted order) from each of models/ and pipelines/."
),
)
parser.add_argument(
"--fix",
action="store_true",
help=(
"Remove stale (documented-but-not-in-signature) argument entries from "
"docstrings in-place. Missing-in-docstring entries are NOT auto-added "
"(no placeholders) and will still be reported."
),
)
args = parser.parse_args()

targets: list[tuple[Path, str]] = []
Expand Down Expand Up @@ -253,6 +431,17 @@ def main() -> int:
for p in pipeline_files:
targets.append((p, "pipeline"))

if args.fix:
fix_summaries: list[str] = []
for path, kind in targets:
for summary in fix_file(path, kind):
fix_summaries.append(f"{path.relative_to(REPO_ROOT)}: {summary}")
if fix_summaries:
print("Removed stale docstring entries:")
print("\n".join(f" {s}" for s in fix_summaries))
else:
print("No stale docstring entries to remove.")

all_errors: list[str] = []
for path, kind in targets:
all_errors.extend(check_file(path, kind))
Expand All @@ -263,6 +452,14 @@ def main() -> int:
f"\nFound {len(all_errors)} docstring/signature mismatch(es).",
file=sys.stderr,
)
if not args.fix and any("documents argument(s) not in the signature" in e for e in all_errors):
print(
"Hint: run `python utils/check_forward_call_docstrings.py --fix` "
"to remove the stale argument entries flagged above. "
"(Missing-in-docstring entries must be added manually — the tool "
"never inserts placeholders.)",
file=sys.stderr,
)
return 1

print("All forward/__call__ arguments are documented.")
Expand Down
Loading