From 72b0cd4b890157bc1e864fa688a861d880219e23 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Fri, 30 Jan 2026 13:16:28 -0800 Subject: [PATCH] Fix pyright type resolution for Pipeline parameter Direct import of Pipeline from redis.asyncio.client allows pyright to correctly resolve the type instead of showing Unknown. --- aredis_om/model/model.py | 40 ++++++++++++++++++++++++---------------- make_sync.py | 13 +++++++++++++ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 07180b8b..cce9d42a 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -8,8 +8,21 @@ from copy import copy from enum import Enum from functools import reduce -from typing import (Any, Callable, Dict, List, Literal, Mapping, Optional, - Sequence, Set, Tuple, Type, TypeVar, Union) +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from typing import get_args as typing_get_args from typing import no_type_check @@ -43,6 +56,7 @@ _FromFieldInfoInputs = dict Undefined = ... UndefinedType = type(...) +from redis.asyncio.client import Pipeline from redis.commands.json.path import Path from redis.exceptions import ResponseError from typing_extensions import Protocol, Unpack, get_args, get_origin @@ -2719,9 +2733,7 @@ async def _delete(cls, db, *pks): return await db.delete(*pks) @classmethod - async def delete( - cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None - ) -> int: + async def delete(cls, pk: Any, pipeline: Optional[Pipeline] = None) -> int: """Delete data at this key.""" db = cls._get_db(pipeline) @@ -2737,7 +2749,7 @@ async def update(self, **field_values): async def save( self: "Model", - pipeline: Optional[redis.client.Pipeline] = None, + pipeline: Optional[Pipeline] = None, nx: bool = False, xx: bool = False, ) -> Optional["Model"]: @@ -2757,9 +2769,7 @@ async def save( """ raise NotImplementedError - async def expire( - self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None - ): + async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None): db = self._get_db(pipeline) # TODO: Wrap any Redis response errors in a custom exception? @@ -2905,7 +2915,7 @@ def get_annotations(cls): async def add( cls: Type["Model"], models: Sequence["Model"], - pipeline: Optional[redis.client.Pipeline] = None, + pipeline: Optional[Pipeline] = None, pipeline_verifier: Callable[..., Any] = verify_pipeline_response, ) -> Sequence["Model"]: db = cls._get_db(pipeline, bulk=True) @@ -2923,9 +2933,7 @@ async def add( return models @classmethod - def _get_db( - self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False - ): + def _get_db(self, pipeline: Optional[Pipeline] = None, bulk: bool = False): if pipeline is not None: return pipeline elif bulk: @@ -2937,7 +2945,7 @@ def _get_db( async def delete_many( cls, models: Sequence["RedisModel"], - pipeline: Optional[redis.client.Pipeline] = None, + pipeline: Optional[Pipeline] = None, ) -> int: db = cls._get_db(pipeline) @@ -3069,7 +3077,7 @@ def _get_field_expirations( async def save( self: "Model", - pipeline: Optional[redis.client.Pipeline] = None, + pipeline: Optional[Pipeline] = None, nx: bool = False, xx: bool = False, field_expirations: Optional[Dict[str, int]] = None, @@ -3479,7 +3487,7 @@ def __init__(self, *args, **kwargs): async def save( self: "Model", - pipeline: Optional[redis.client.Pipeline] = None, + pipeline: Optional[Pipeline] = None, nx: bool = False, xx: bool = False, ) -> Optional["Model"]: diff --git a/make_sync.py b/make_sync.py index b43733b6..779e37ee 100644 --- a/make_sync.py +++ b/make_sync.py @@ -111,7 +111,20 @@ def remove_run_async_call(match): with open(file_path, 'w') as f: f.write(content) + # Post-process model.py to fix async imports for sync version + model_file = Path(__file__).absolute().parent / "redis_om/model/model.py" + if model_file.exists(): + with open(model_file, 'r') as f: + content = f.read() + # Fix Pipeline import: redis.asyncio.client -> redis.client + content = content.replace( + 'from redis.asyncio.client import Pipeline', + 'from redis.client import Pipeline' + ) + + with open(model_file, 'w') as f: + f.write(content) if __name__ == "__main__":