diff --git a/redis/commands/core.py b/redis/commands/core.py index 525b31c99d..3b82699129 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -56,7 +56,9 @@ if TYPE_CHECKING: import redis.asyncio.client + import redis.asyncio.cluster import redis.client + import redis.cluster class ACLCommands(CommandsProtocol): @@ -5914,7 +5916,11 @@ class Script: An executable Lua script object returned by ``register_script`` """ - def __init__(self, registered_client: "redis.client.Redis", script: ScriptTextT): + def __init__( + self, + registered_client: Union["redis.client.Redis", "redis.cluster.RedisCluster"], + script: ScriptTextT, + ): self.registered_client = registered_client self.script = script # Precalculate and store the SHA1 hex digest of the script. @@ -5930,7 +5936,7 @@ def __call__( self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.client.Redis", None] = None, + client: Union["redis.client.Redis", "redis.cluster.RedisCluster", None] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -5979,7 +5985,9 @@ class AsyncScript: def __init__( self, - registered_client: "redis.asyncio.client.Redis", + registered_client: Union[ + "redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster" + ], script: ScriptTextT, ): self.registered_client = registered_client @@ -6001,7 +6009,9 @@ async def __call__( self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.asyncio.client.Redis", None] = None, + client: Union[ + "redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster", None + ] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -6234,7 +6244,10 @@ def script_load(self, script: ScriptTextT) -> ResponseT: """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self: "redis.client.Redis", script: ScriptTextT) -> Script: + def register_script( + self: Union["redis.client.Redis", "redis.cluster.RedisCluster"], + script: ScriptTextT, + ) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -6249,7 +6262,7 @@ async def script_debug(self, *args) -> None: return super().script_debug() def register_script( - self: "redis.asyncio.client.Redis", + self: Union["redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster"], script: ScriptTextT, ) -> AsyncScript: """ diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index b8e100c04a..76592e7a93 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -1,6 +1,7 @@ import pytest import pytest_asyncio from redis import exceptions +from redis.commands.core import AsyncScript from tests.conftest import skip_if_server_version_lt multiply_script = """ @@ -150,3 +151,27 @@ async def test_eval_msgpack_pipeline_error_in_lua(self, r): with pytest.raises(exceptions.ResponseError) as excinfo: await pipe.execute() assert excinfo.type == exceptions.ResponseError + + +@pytest.mark.onlycluster +class TestAsyncScriptWithCluster: + """Tests for AsyncScript with RedisCluster support.""" + + @pytest_asyncio.fixture + async def r(self, create_redis): + redis = await create_redis() + yield redis + await redis.script_flush() + + @pytest.mark.asyncio() + async def test_register_script_with_cluster_client(self, r): + """Test that register_script works with async RedisCluster client. + + This verifies the type hints fix for register_script to support RedisCluster. + """ + await r.set("a", 2) + multiply = r.register_script(multiply_script) + assert isinstance(multiply, AsyncScript) + assert multiply.registered_client is r + # Verify the script actually works + assert await multiply(keys=["a"], args=[3]) == 6 diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 899dc69482..9587285c0c 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -96,6 +96,19 @@ def test_eval_same_slot(self, r): result = r.eval(script, 2, "A{foo}", "B{foo}") assert result == 8 + @pytest.mark.onlycluster + def test_register_script_with_cluster_client(self, r): + """Test that register_script works with RedisCluster client. + + This verifies the type hints fix for register_script to support RedisCluster. + """ + r.set("a", 2) + multiply = r.register_script(multiply_script) + assert isinstance(multiply, Script) + assert multiply.registered_client is r + # Verify the script actually works + assert multiply(keys=["a"], args=[3]) == 6 + @pytest.mark.onlycluster def test_eval_crossslot(self, r): """