Skip to content

Commit 2a27d48

Browse files
authored
Fix pyright type resolution for Pipeline parameter (#802)
Direct import of Pipeline from redis.asyncio.client allows pyright to correctly resolve the type instead of showing Unknown.
1 parent 9d03964 commit 2a27d48

2 files changed

Lines changed: 37 additions & 16 deletions

File tree

aredis_om/model/model.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,21 @@
88
from copy import copy
99
from enum import Enum
1010
from functools import reduce
11-
from typing import (Any, Callable, Dict, List, Literal, Mapping, Optional,
12-
Sequence, Set, Tuple, Type, TypeVar, Union)
11+
from typing import (
12+
Any,
13+
Callable,
14+
Dict,
15+
List,
16+
Literal,
17+
Mapping,
18+
Optional,
19+
Sequence,
20+
Set,
21+
Tuple,
22+
Type,
23+
TypeVar,
24+
Union,
25+
)
1326
from typing import get_args as typing_get_args
1427
from typing import no_type_check
1528

@@ -43,6 +56,7 @@
4356
_FromFieldInfoInputs = dict
4457
Undefined = ...
4558
UndefinedType = type(...)
59+
from redis.asyncio.client import Pipeline
4660
from redis.commands.json.path import Path
4761
from redis.exceptions import ResponseError
4862
from typing_extensions import Protocol, Unpack, get_args, get_origin
@@ -2719,9 +2733,7 @@ async def _delete(cls, db, *pks):
27192733
return await db.delete(*pks)
27202734

27212735
@classmethod
2722-
async def delete(
2723-
cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None
2724-
) -> int:
2736+
async def delete(cls, pk: Any, pipeline: Optional[Pipeline] = None) -> int:
27252737
"""Delete data at this key."""
27262738
db = cls._get_db(pipeline)
27272739

@@ -2737,7 +2749,7 @@ async def update(self, **field_values):
27372749

27382750
async def save(
27392751
self: "Model",
2740-
pipeline: Optional[redis.client.Pipeline] = None,
2752+
pipeline: Optional[Pipeline] = None,
27412753
nx: bool = False,
27422754
xx: bool = False,
27432755
) -> Optional["Model"]:
@@ -2757,9 +2769,7 @@ async def save(
27572769
"""
27582770
raise NotImplementedError
27592771

2760-
async def expire(
2761-
self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
2762-
):
2772+
async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
27632773
db = self._get_db(pipeline)
27642774

27652775
# TODO: Wrap any Redis response errors in a custom exception?
@@ -2905,7 +2915,7 @@ def get_annotations(cls):
29052915
async def add(
29062916
cls: Type["Model"],
29072917
models: Sequence["Model"],
2908-
pipeline: Optional[redis.client.Pipeline] = None,
2918+
pipeline: Optional[Pipeline] = None,
29092919
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
29102920
) -> Sequence["Model"]:
29112921
db = cls._get_db(pipeline, bulk=True)
@@ -2923,9 +2933,7 @@ async def add(
29232933
return models
29242934

29252935
@classmethod
2926-
def _get_db(
2927-
self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False
2928-
):
2936+
def _get_db(self, pipeline: Optional[Pipeline] = None, bulk: bool = False):
29292937
if pipeline is not None:
29302938
return pipeline
29312939
elif bulk:
@@ -2937,7 +2945,7 @@ def _get_db(
29372945
async def delete_many(
29382946
cls,
29392947
models: Sequence["RedisModel"],
2940-
pipeline: Optional[redis.client.Pipeline] = None,
2948+
pipeline: Optional[Pipeline] = None,
29412949
) -> int:
29422950
db = cls._get_db(pipeline)
29432951

@@ -3069,7 +3077,7 @@ def _get_field_expirations(
30693077

30703078
async def save(
30713079
self: "Model",
3072-
pipeline: Optional[redis.client.Pipeline] = None,
3080+
pipeline: Optional[Pipeline] = None,
30733081
nx: bool = False,
30743082
xx: bool = False,
30753083
field_expirations: Optional[Dict[str, int]] = None,
@@ -3479,7 +3487,7 @@ def __init__(self, *args, **kwargs):
34793487

34803488
async def save(
34813489
self: "Model",
3482-
pipeline: Optional[redis.client.Pipeline] = None,
3490+
pipeline: Optional[Pipeline] = None,
34833491
nx: bool = False,
34843492
xx: bool = False,
34853493
) -> Optional["Model"]:

make_sync.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,20 @@ def remove_run_async_call(match):
111111
with open(file_path, 'w') as f:
112112
f.write(content)
113113

114+
# Post-process model.py to fix async imports for sync version
115+
model_file = Path(__file__).absolute().parent / "redis_om/model/model.py"
116+
if model_file.exists():
117+
with open(model_file, 'r') as f:
118+
content = f.read()
114119

120+
# Fix Pipeline import: redis.asyncio.client -> redis.client
121+
content = content.replace(
122+
'from redis.asyncio.client import Pipeline',
123+
'from redis.client import Pipeline'
124+
)
125+
126+
with open(model_file, 'w') as f:
127+
f.write(content)
115128

116129

117130
if __name__ == "__main__":

0 commit comments

Comments
 (0)