Skip to content

Commit b2bc66d

Browse files
authored
Merge pull request #1646 from weaviate/add-stubs-autogen-tool
Add a Python script that auto-generates async/sync stubs from executor definitions
2 parents e5f92d0 + 41ba685 commit b2bc66d

225 files changed

Lines changed: 16261 additions & 9742 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ exclude =
1414
weaviate/collections/classes/orm.py,
1515
weaviate/proto/**/*.py,
1616
build/
17+
tools/stubs.py,
1718
ignore = D100, D101, D102, D103, D104, D105, D107, E203, E266, E501, E704, E731, W503, DOC301
1819
per-file-ignores =
1920
weaviate/cluster/types.py: A005
2021
weaviate/collections/classes/types.py: A005
2122
weaviate/collections/collections/__init__.py: A005
2223
weaviate/collections/__init__.py: A005
2324
weaviate/debug/types.py: A005
25+
weaviate/collections/tenants/types.py: A005
2426
weaviate/types.py: A005
2527
weaviate/warnings.py: A005
2628
test/*: D100, D101, D102, D103, D104, D105, D107, PYD001

.pre-commit-config.yaml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
exclude: ^proto/
22
repos:
3-
- repo: https://github.com/psf/black-pre-commit-mirror
4-
rev: 24.10.0
5-
hooks:
6-
- id: black
7-
language_version: python3.12
3+
- repo: local
4+
hooks:
5+
- id: stubs-autogen
6+
name: stubs-autogen
7+
language: system
8+
entry: ./tools/stubs_regen.sh
9+
10+
- repo: https://github.com/psf/black-pre-commit-mirror
11+
rev: 24.10.0
12+
hooks:
13+
- id: black
14+
language_version: python3.12
815

916
- repo: https://github.com/pre-commit/pre-commit-hooks
1017
rev: v4.6.0

tools/__init__.py

Whitespace-only changes.

tools/stubs.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import ast
2+
import importlib
3+
import inspect
4+
import os
5+
import textwrap
6+
from collections import defaultdict
7+
from typing import Literal, cast
8+
9+
10+
class ExecutorTransformer(ast.NodeTransformer):
11+
def __init__(self, colour: Literal["async", "sync"]):
12+
self.colour = colour
13+
self.executor_names = []
14+
15+
def visit_ClassDef(self, node):
16+
self.executor_names.append(node.name)
17+
node.bases = self.__parse_generics(node)
18+
node.body = self.__parse_body(node)
19+
node.name = node.name.replace(
20+
"Executor", "" if self.colour == "sync" else self.colour.capitalize()
21+
)
22+
self.generic_visit(node)
23+
return node
24+
25+
def __is_overload(self, fn: ast.FunctionDef):
26+
return any(isinstance(d, ast.Name) and d.id == "overload" for d in fn.decorator_list)
27+
28+
def __parse_body(self, node: ast.ClassDef):
29+
funcs_by_name = defaultdict(list)
30+
for stmt in node.body:
31+
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
32+
funcs_by_name[stmt.name].append(stmt)
33+
34+
new_body: list[ast.stmt] = []
35+
for stmt in node.body:
36+
if isinstance(stmt, ast.FunctionDef) and stmt.name.startswith("__"):
37+
continue # Skip all dunder methods
38+
if isinstance(stmt, ast.FunctionDef):
39+
overloads = funcs_by_name[stmt.name]
40+
if any(self.__is_overload(f) for f in overloads):
41+
if not self.__is_overload(stmt):
42+
continue # skip the impl
43+
new_body.append(stmt)
44+
return new_body
45+
46+
def __parse_generics(self, node: ast.ClassDef):
47+
new_bases: list[ast.expr] = []
48+
for base in node.bases:
49+
if not isinstance(base, ast.Subscript):
50+
continue
51+
if isinstance(base.value, ast.Name) and base.value.id == "Generic":
52+
# This is a generic class
53+
# We need to extract the type arguments
54+
if isinstance(base.slice, ast.Tuple):
55+
# This is a tuple of types
56+
# must remove `ConnectionType` if there
57+
generics = [
58+
arg.id
59+
for arg in base.slice.elts
60+
if isinstance(arg, ast.Name)
61+
if arg.id != "ConnectionType"
62+
]
63+
new_bases.append(
64+
ast.Subscript(
65+
value=base.value,
66+
slice=ast.Tuple(
67+
elts=[ast.Name(id=arg) for arg in generics], ctx=ast.Load()
68+
),
69+
ctx=ast.Load(),
70+
)
71+
)
72+
elif isinstance(base.slice, ast.Name):
73+
# This is a single type
74+
if base.slice.id == "ConnectionType":
75+
# We don't want to include ConnectionType
76+
continue
77+
new_bases.append(base)
78+
connection_type = ast.Name(id=self.__which_connection_type(), ctx=ast.Load())
79+
if len(new_bases) == 0:
80+
# no generics, we need to add the ConnectionType
81+
slice = connection_type
82+
else:
83+
elts: list[ast.expr] = []
84+
for base in new_bases:
85+
assert isinstance(base, ast.Subscript)
86+
slice = base.slice
87+
assert isinstance(slice, ast.Tuple)
88+
elts.extend(slice.elts)
89+
slice = ast.Tuple(elts=[connection_type, *elts], ctx=ast.Load())
90+
new_bases.append(
91+
ast.Subscript(
92+
value=ast.Name(id=node.name, ctx=ast.Load()),
93+
slice=slice,
94+
ctx=ast.Load(),
95+
)
96+
)
97+
return new_bases
98+
99+
def __which_connection_type(self):
100+
return "ConnectionAsync" if self.colour == "async" else "ConnectionSync"
101+
102+
def __extract_inner_return_type(self, node: ast.expr | None) -> ast.expr | None:
103+
# Looking for executor.Result[T]
104+
if (
105+
isinstance(node, ast.Subscript)
106+
and isinstance(node.value, ast.Attribute)
107+
and isinstance(node.value.value, ast.Name)
108+
and node.value.value.id == "executor"
109+
and node.value.attr == "Result"
110+
):
111+
# This is executor.Result[...]
112+
return node.slice # Return T
113+
return node # fallback, return original if not matching
114+
115+
def visit_FunctionDef(self, node):
116+
func_def = ast.AsyncFunctionDef if self.colour == "async" else ast.FunctionDef
117+
new_node = func_def(
118+
name=node.name,
119+
args=node.args,
120+
body=[ast.Expr(value=ast.Constant(value=Ellipsis))],
121+
decorator_list=node.decorator_list,
122+
returns=self.__extract_inner_return_type(node.returns),
123+
type_comment=node.type_comment,
124+
)
125+
return ast.copy_location(new_node, node)
126+
127+
128+
# Generate `sync.pyi` and `async_.pyi` next to every `executor.py` found under
# ./weaviate, except for the connection executor and the
# collections/collections package.
colours: list[Literal["sync", "async"]] = ["sync", "async"]
for subdir, dirs, files in os.walk("./weaviate"):
    for file in files:
        if file != "executor.py":
            continue
        if "connect" in subdir:
            # ignore weaviate/connect/executor.py file
            continue
        if "collections/collections" in subdir:
            # ignore weaviate/collections/collections directory
            continue

        mod = os.path.join(subdir, file)
        mod = mod[2:]  # remove the leading dot and slash
        mod = mod[:-3]  # remove the .py
        mod = mod.replace("/", ".")  # convert into pythonic import

        module = importlib.import_module(mod)
        source = textwrap.dedent(inspect.getsource(module))

        for colour in colours:
            # Parse afresh for each colour: the transformer mutates the tree in place.
            tree = ast.parse(source, mode="exec", type_comments=True)

            transformer = ExecutorTransformer(colour)
            stubbed = transformer.visit(tree)

            # Keep the executor module's own imports, then add the concrete
            # connection class and the executor classes the stubs derive from.
            imports = [
                node for node in stubbed.body if isinstance(node, (ast.Import, ast.ImportFrom))
            ] + [
                ast.ImportFrom(
                    module="weaviate.connect.v4",
                    names=[ast.alias(name=f"Connection{colour.capitalize()}", asname=None)],
                    level=0,
                ),
                # level=1 is the AST form of `from .executor import ...`.
                ast.ImportFrom(
                    module="executor",
                    names=[
                        ast.alias(name=name, asname=None) for name in transformer.executor_names
                    ],
                    level=1,
                ),
            ]
            stubbed.body = imports + [
                node for node in stubbed.body if isinstance(node, ast.ClassDef)
            ]
            ast.fix_missing_locations(stubbed)

            package_dir = cast(str, module.__package__).replace(".", "/")
            # `async` is a keyword, so the async stub module is named `async_`.
            stub_path = (
                f"{package_dir}/{colour}.pyi"
                if colour == "sync"
                else f"{package_dir}/{colour}_.pyi"
            )
            # Explicit encoding keeps the generated file independent of the locale.
            with open(stub_path, "w", encoding="utf-8") as f:
                print(f"Writing {stub_path}")
                f.write(ast.unparse(stubbed))

tools/stubs_regen.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Regenerate the sync/async .pyi stubs and format the package with black.
# Abort on the first failing command so that a broken stub generation makes
# the calling pre-commit hook fail instead of being masked by later commands.
set -euo pipefail

echo "Regenerating stubs..."

python3 -m tools.stubs
black ./weaviate

echo "done"

weaviate/backup/async_.pyi

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
1-
from typing import Optional, Union, List
2-
3-
from weaviate.backup.base import _BackupBase
4-
from weaviate.backup.executor import (
1+
import asyncio
2+
import time
3+
from typing import Generic, Optional, Union, List, Tuple, Dict
4+
from httpx import Response
5+
from weaviate.backup.backup import (
56
BackupStorage,
67
BackupReturn,
78
BackupStatusReturn,
9+
STORAGE_NAMES,
810
BackupConfigCreate,
11+
BackupStatus,
912
BackupConfigRestore,
10-
_BackupExecutor,
1113
)
1214
from weaviate.backup.backup_location import BackupLocationType
15+
from weaviate.connect import executor
16+
from weaviate.connect.v4 import _ExpectedStatusCodes, Connection, ConnectionAsync, ConnectionType
17+
from weaviate.exceptions import (
18+
WeaviateInvalidInputError,
19+
WeaviateUnsupportedFeatureError,
20+
BackupFailedException,
21+
EmptyResponseException,
22+
BackupCanceledError,
23+
)
24+
from weaviate.util import _capitalize_first_letter, _decode_json_response_dict
1325
from weaviate.connect.v4 import ConnectionAsync
26+
from .executor import _BackupExecutor
1427

1528
class _BackupAsync(_BackupExecutor[ConnectionAsync]):
16-
"""Backup class used to schedule and/or check the status of a backup process of Weaviate objects."""
17-
18-
async def cancel(
19-
self,
20-
backup_id: str,
21-
backend: BackupStorage,
22-
backup_location: Optional[BackupLocationType] = None,
23-
) -> bool: ...
2429
async def create(
2530
self,
2631
backup_id: str,
2732
backend: BackupStorage,
28-
include_collections: Optional[Union[List[str], str]] = None,
29-
exclude_collections: Optional[Union[List[str], str]] = None,
33+
include_collections: Union[List[str], str, None] = None,
34+
exclude_collections: Union[List[str], str, None] = None,
3035
wait_for_completion: bool = False,
3136
config: Optional[BackupConfigCreate] = None,
3237
backup_location: Optional[BackupLocationType] = None,
@@ -53,3 +58,9 @@ class _BackupAsync(_BackupExecutor[ConnectionAsync]):
5358
backend: BackupStorage,
5459
backup_location: Optional[BackupLocationType] = None,
5560
) -> BackupStatusReturn: ...
61+
async def cancel(
62+
self,
63+
backup_id: str,
64+
backend: BackupStorage,
65+
backup_location: Optional[BackupLocationType] = None,
66+
) -> bool: ...

weaviate/backup/sync.pyi

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
1-
from typing import Optional, Union, List
2-
3-
from weaviate.backup.base import _BackupBase
4-
from weaviate.backup.executor import (
1+
import asyncio
2+
import time
3+
from typing import Generic, Optional, Union, List, Tuple, Dict
4+
from httpx import Response
5+
from weaviate.backup.backup import (
56
BackupStorage,
67
BackupReturn,
78
BackupStatusReturn,
9+
STORAGE_NAMES,
810
BackupConfigCreate,
11+
BackupStatus,
912
BackupConfigRestore,
10-
_BackupExecutor,
1113
)
1214
from weaviate.backup.backup_location import BackupLocationType
15+
from weaviate.connect import executor
16+
from weaviate.connect.v4 import _ExpectedStatusCodes, Connection, ConnectionAsync, ConnectionType
17+
from weaviate.exceptions import (
18+
WeaviateInvalidInputError,
19+
WeaviateUnsupportedFeatureError,
20+
BackupFailedException,
21+
EmptyResponseException,
22+
BackupCanceledError,
23+
)
24+
from weaviate.util import _capitalize_first_letter, _decode_json_response_dict
1325
from weaviate.connect.v4 import ConnectionSync
26+
from .executor import _BackupExecutor
1427

1528
class _Backup(_BackupExecutor[ConnectionSync]):
16-
"""Backup class used to schedule and/or check the status of a backup process of Weaviate objects."""
17-
18-
def cancel(
19-
self,
20-
backup_id: str,
21-
backend: BackupStorage,
22-
backup_location: Optional[BackupLocationType] = None,
23-
) -> bool: ...
2429
def create(
2530
self,
2631
backup_id: str,
2732
backend: BackupStorage,
28-
include_collections: Optional[Union[List[str], str]] = None,
29-
exclude_collections: Optional[Union[List[str], str]] = None,
33+
include_collections: Union[List[str], str, None] = None,
34+
exclude_collections: Union[List[str], str, None] = None,
3035
wait_for_completion: bool = False,
3136
config: Optional[BackupConfigCreate] = None,
3237
backup_location: Optional[BackupLocationType] = None,
@@ -53,3 +58,9 @@ class _Backup(_BackupExecutor[ConnectionSync]):
5358
backend: BackupStorage,
5459
backup_location: Optional[BackupLocationType] = None,
5560
) -> BackupStatusReturn: ...
61+
def cancel(
62+
self,
63+
backup_id: str,
64+
backend: BackupStorage,
65+
backup_location: Optional[BackupLocationType] = None,
66+
) -> bool: ...
File renamed without changes.
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from .aggregate import _Hybrid, _HybridAsync
1+
from .async_ import _HybridAsync
2+
from .sync import _Hybrid
23

34
__all__ = ["_Hybrid", "_HybridAsync"]

0 commit comments

Comments
 (0)