Skip to content

Commit f5bced2

Browse files
committed
Add a python script that auto generates async/sync stubs from executor defs
1 parent dcc9f3f commit f5bced2

221 files changed

Lines changed: 16186 additions & 9739 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ exclude =
1414
weaviate/collections/classes/orm.py,
1515
weaviate/proto/**/*.py,
1616
build/
17+
tools/stubs.py,
1718
ignore = D100, D101, D102, D103, D104, D105, D107, E203, E266, E501, E704, E731, W503, DOC301
1819
per-file-ignores =
1920
weaviate/cluster/types.py: A005

.pre-commit-config.yaml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
exclude: ^proto/
22
repos:
3-
- repo: https://github.com/psf/black-pre-commit-mirror
4-
rev: 24.10.0
5-
hooks:
6-
- id: black
7-
language_version: python3.12
3+
- repo: local
4+
hooks:
5+
- id: stubs-autogen
6+
name: stubs-autogen
7+
language: system
8+
entry: ./tools/stubs_regen.sh
9+
10+
- repo: https://github.com/psf/black-pre-commit-mirror
11+
rev: 24.10.0
12+
hooks:
13+
- id: black
14+
language_version: python3.12
815

916
- repo: https://github.com/pre-commit/pre-commit-hooks
1017
rev: v4.6.0

tools/__init__.py

Whitespace-only changes.

tools/stubs.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import ast
2+
import inspect
3+
import textwrap
4+
from collections import defaultdict
5+
from typing import Literal, cast
6+
7+
8+
class ExecutorTransformer(ast.NodeTransformer):
9+
def __init__(self, colour: Literal["async", "sync"]):
10+
self.colour = colour
11+
self.executor_names = []
12+
13+
def visit_ClassDef(self, node):
14+
self.executor_names.append(node.name)
15+
node.bases = self.__parse_generics(node)
16+
node.body = self.__parse_body(node)
17+
node.name = node.name.replace(
18+
"Executor", "" if self.colour == "sync" else self.colour.capitalize()
19+
)
20+
self.generic_visit(node)
21+
return node
22+
23+
def __is_overload(self, fn: ast.FunctionDef):
24+
return any(isinstance(d, ast.Name) and d.id == "overload" for d in fn.decorator_list)
25+
26+
def __parse_body(self, node: ast.ClassDef):
27+
funcs_by_name = defaultdict(list)
28+
for stmt in node.body:
29+
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
30+
funcs_by_name[stmt.name].append(stmt)
31+
32+
new_body: list[ast.stmt] = []
33+
for stmt in node.body:
34+
if isinstance(stmt, ast.FunctionDef) and stmt.name.startswith("__"):
35+
continue # Skip all dunder methods
36+
if isinstance(stmt, ast.FunctionDef):
37+
overloads = funcs_by_name[stmt.name]
38+
if any(self.__is_overload(f) for f in overloads):
39+
if not self.__is_overload(stmt):
40+
continue # skip the impl
41+
new_body.append(stmt)
42+
return new_body
43+
44+
def __parse_generics(self, node: ast.ClassDef):
45+
new_bases: list[ast.expr] = []
46+
for base in node.bases:
47+
if not isinstance(base, ast.Subscript):
48+
continue
49+
if isinstance(base.value, ast.Name) and base.value.id == "Generic":
50+
# This is a generic class
51+
# We need to extract the type arguments
52+
if isinstance(base.slice, ast.Tuple):
53+
# This is a tuple of types
54+
# must remove `ConnectionType` if there
55+
generics = [
56+
arg.id
57+
for arg in base.slice.elts
58+
if isinstance(arg, ast.Name)
59+
if arg.id != "ConnectionType"
60+
]
61+
new_bases.append(
62+
ast.Subscript(
63+
value=base.value,
64+
slice=ast.Tuple(
65+
elts=[ast.Name(id=arg) for arg in generics], ctx=ast.Load()
66+
),
67+
ctx=ast.Load(),
68+
)
69+
)
70+
elif isinstance(base.slice, ast.Name):
71+
# This is a single type
72+
if base.slice.id == "ConnectionType":
73+
# We don't want to include ConnectionType
74+
continue
75+
new_bases.append(base)
76+
# if isinstance(base.value, ast.Name) and base.value.id == "_BaseExecutor":
77+
# # This is class from collections/queries
78+
# return []
79+
connection_type = ast.Name(id=self.__which_connection_type(), ctx=ast.Load())
80+
if len(new_bases) == 0:
81+
# no generics, we need to add the ConnectionType
82+
slice = connection_type
83+
else:
84+
elts: list[ast.expr] = []
85+
for base in new_bases:
86+
assert isinstance(base, ast.Subscript)
87+
slice = base.slice
88+
assert isinstance(slice, ast.Tuple)
89+
elts.extend(slice.elts)
90+
slice = ast.Tuple(elts=[connection_type, *elts], ctx=ast.Load())
91+
new_bases.append(
92+
ast.Subscript(
93+
value=ast.Name(id=node.name, ctx=ast.Load()),
94+
slice=slice,
95+
ctx=ast.Load(),
96+
)
97+
)
98+
return new_bases
99+
100+
def __which_connection_type(self):
101+
return "ConnectionAsync" if self.colour == "async" else "ConnectionSync"
102+
103+
def __extract_inner_return_type(self, node: ast.expr | None) -> ast.expr | None:
104+
# Looking for executor.Result[T]
105+
if (
106+
isinstance(node, ast.Subscript)
107+
and isinstance(node.value, ast.Attribute)
108+
and isinstance(node.value.value, ast.Name)
109+
and node.value.value.id == "executor"
110+
and node.value.attr == "Result"
111+
):
112+
# This is executor.Result[...]
113+
return node.slice # Return T
114+
return node # fallback, return original if not matching
115+
116+
def visit_FunctionDef(self, node):
117+
func_def = ast.AsyncFunctionDef if self.colour == "async" else ast.FunctionDef
118+
new_node = func_def(
119+
name=node.name,
120+
args=node.args,
121+
body=[ast.Expr(value=ast.Constant(value=Ellipsis))],
122+
decorator_list=node.decorator_list,
123+
returns=self.__extract_inner_return_type(node.returns),
124+
type_comment=node.type_comment,
125+
)
126+
return ast.copy_location(new_node, node)
127+
128+
129+
from weaviate.collections.aggregations.hybrid import executor as agg_hybrid
130+
from weaviate.collections.aggregations.near_image import executor as agg_near_image
131+
from weaviate.collections.aggregations.near_object import executor as agg_near_object
132+
from weaviate.collections.aggregations.near_text import executor as agg_near_text
133+
from weaviate.collections.aggregations.near_vector import executor as agg_near_vector
134+
from weaviate.collections.aggregations.over_all import executor as agg_over_all
135+
from weaviate.collections.backups import executor as backups
136+
from weaviate.collections.cluster import executor as cluster
137+
from weaviate.collections.config import executor as config
138+
from weaviate.collections.data import executor as data
139+
from weaviate.collections.queries.bm25.generate import executor as generate_bm25
140+
from weaviate.collections.queries.bm25.query import executor as query_bm25
141+
from weaviate.collections.queries.fetch_object_by_id import executor as fetch_object_by_id
142+
from weaviate.collections.queries.fetch_objects.generate import executor as generate_fetch_objects
143+
from weaviate.collections.queries.fetch_objects.query import executor as query_fetch_objects
144+
from weaviate.collections.queries.fetch_objects_by_ids.generate import (
145+
executor as generate_fetch_objects_by_ids,
146+
)
147+
from weaviate.collections.queries.fetch_objects_by_ids.query import (
148+
executor as query_fetch_objects_by_ids,
149+
)
150+
from weaviate.collections.queries.hybrid.generate import executor as generate_hybrid
151+
from weaviate.collections.queries.hybrid.query import executor as query_hybrid
152+
from weaviate.collections.queries.near_image.generate import executor as generate_near_image
153+
from weaviate.collections.queries.near_image.query import executor as query_near_image
154+
from weaviate.collections.queries.near_media.generate import executor as generate_near_media
155+
from weaviate.collections.queries.near_media.query import executor as query_near_media
156+
from weaviate.collections.queries.near_object.generate import executor as generate_near_object
157+
from weaviate.collections.queries.near_object.query import executor as query_near_object
158+
from weaviate.collections.queries.near_text.generate import executor as generate_near_text
159+
from weaviate.collections.queries.near_text.query import executor as query_near_text
160+
from weaviate.collections.queries.near_vector.generate import executor as generate_near_vector
161+
from weaviate.collections.queries.near_vector.query import executor as query_near_vector
162+
from weaviate.debug import executor as debug
163+
from weaviate.rbac import executor as rbac
164+
from weaviate.collections.tenants import executor as tenants
165+
from weaviate.users import executor as users
166+
167+
# Regenerate the stub files next to every executor module listed below: for
# each module one `sync.pyi` and one `async_.pyi` is written into the
# module's package directory (paths are relative to the current working
# directory).
colours: list[Literal["sync", "async"]] = ["sync", "async"]

for module in [
    agg_hybrid,
    agg_near_image,
    agg_near_object,
    agg_near_text,
    agg_near_vector,
    agg_over_all,
    backups,
    cluster,
    config,
    data,
    debug,
    generate_bm25,
    generate_fetch_objects,
    generate_fetch_objects_by_ids,
    generate_hybrid,
    generate_near_image,
    generate_near_media,
    generate_near_object,
    generate_near_text,
    generate_near_vector,
    fetch_object_by_id,
    query_bm25,
    query_fetch_objects,
    query_fetch_objects_by_ids,
    query_hybrid,
    query_near_image,
    query_near_media,
    query_near_object,
    query_near_text,
    query_near_vector,
    rbac,
    tenants,
    users,
]:
    source = textwrap.dedent(inspect.getsource(module))
    package_dir = cast(str, module.__package__).replace(".", "/")

    for colour in colours:
        # Parse afresh for each colour: the transformer mutates the tree.
        tree = ast.parse(source, mode="exec", type_comments=True)

        transformer = ExecutorTransformer(colour)
        stubbed = transformer.visit(tree)

        # Keep the executor module's own imports, then add the concrete
        # connection class and the executor classes the stubs subclass.
        imports = [
            node for node in stubbed.body if isinstance(node, (ast.Import, ast.ImportFrom))
        ] + [
            ast.ImportFrom(
                module="weaviate.connect.v4",
                names=[ast.alias(name=f"Connection{colour.capitalize()}", asname=None)],
                level=0,
            ),
            # `from .executor import ...`: the leading dot is expressed through
            # `level=1`, not as part of the module name.
            ast.ImportFrom(
                module="executor",
                names=[ast.alias(name=name, asname=None) for name in transformer.executor_names],
                level=1,
            ),
        ]
        stubbed.body = imports + [node for node in stubbed.body if isinstance(node, ast.ClassDef)]
        ast.fix_missing_locations(stubbed)

        # The async stub is named `async_.pyi` (`async` is a reserved word).
        stub_path = (
            f"{package_dir}/{colour}.pyi" if colour == "sync" else f"{package_dir}/{colour}_.pyi"
        )
        with open(stub_path, "w", encoding="utf-8") as f:
            print(f"Writing {stub_path}")
            f.write(ast.unparse(stubbed))

tools/stubs_regen.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
# Regenerate the sync/async .pyi stubs from the executor definitions and then
# format the package with black.
#
# Abort on the first failing command so that a broken generator or a black
# failure makes the script (and the pre-commit hook invoking it) fail instead
# of always exiting 0 after the final echo.
set -euo pipefail

# `python3 -m tools.stubs` and `./weaviate` are resolved relative to the
# repository root, so run from there regardless of the caller's directory.
cd "$(dirname "${BASH_SOURCE[0]}")/.."

echo "Regenerating stubs..."

python3 -m tools.stubs
black ./weaviate

echo "done"
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from .aggregate import _Hybrid, _HybridAsync
1+
from .async_ import _HybridAsync
2+
from .sync import _Hybrid
23

34
__all__ = ["_Hybrid", "_HybridAsync"]

weaviate/collections/aggregations/hybrid/aggregate.pyi

Lines changed: 0 additions & 112 deletions
This file was deleted.

0 commit comments

Comments
 (0)