diff --git a/pyproject.toml b/pyproject.toml index 90aa6e8..5ce2fd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,8 @@ license = "MIT" authors = [{ name = "Flowdacity" }] requires-python = ">=3.12" keywords = [ + "flowdacity", + "tailback", "queue", "job queue", "task queue", @@ -20,7 +22,7 @@ keywords = [ "leaky bucket", ] classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", @@ -31,6 +33,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: System :: Distributed Computing", "Topic :: Utilities", + "Typing :: Typed", ] dependencies = ["msgpack>=1.1.2", "redis[hiredis]>=7.1.0"] diff --git a/src/tailback/base.py b/src/tailback/base.py index 92a1aaa..aa9d868 100644 --- a/src/tailback/base.py +++ b/src/tailback/base.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- # Copyright (c) 2025 Flowdacity Development Team. See LICENSE.txt for details. +from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass +from typing import Any, cast from tailback.config import TailbackConfig from tailback.exceptions import BadArgumentException from tailback.keys import RedisKeys from tailback.responses import ( + RedisValue, decode_redis_value, format_dequeue_response, format_metrics_counts, @@ -24,6 +27,9 @@ validate_metrics_arguments, ) +RedisCall = tuple[list[str], list[Any]] +StatusResponse = dict[str, str] + @dataclass(frozen=True) class ClearQueuePlan: @@ -35,40 +41,40 @@ class ClearQueuePlan: queue_type: str queue_id: str - def payload_member(self, job_id): + def payload_member(self, job_id: str) -> str: return "%s:%s:%s" % (self.queue_type, self.queue_id, job_id) class BaseTailback(object): """Shared non-I/O behavior for async and sync Tailback clients.""" - def __init__(self, config): - self._r = None - self._scripts = None - self.config = TailbackConfig.from_mapping(config) - self._keys = RedisKeys(self.config.queue.key_prefix) + def __init__(self, config: Mapping[str, Any]) -> None: + self._r: Any = None + self._scripts: Any = None + self.config: TailbackConfig = TailbackConfig.from_mapping(config) + self._keys: RedisKeys = RedisKeys(self.config.queue.key_prefix) - self._key_prefix = self.config.queue.key_prefix - self._job_expire_interval = int(self.config.queue.job_expire_interval) - self._default_job_requeue_limit = int( + self._key_prefix: str = self.config.queue.key_prefix + self._job_expire_interval: int = int(self.config.queue.job_expire_interval) + self._default_job_requeue_limit: int = int( self.config.queue.default_job_requeue_limit ) - def redis_client(self): + def redis_client(self) -> Any | None: return self._r - def _current_timestamp(self): + def _current_timestamp(self) -> str: return str(generate_epoch()) def _build_enqueue_call( self, - payload, - interval, - job_id, - queue_id, - queue_type, - requeue_limit, - ): + payload: Any, + interval: int, + job_id: str, + queue_id: str, + queue_type: str, + requeue_limit: int | None, + ) -> RedisCall: enqueue_args = validate_enqueue_arguments( payload, interval, @@ -89,18 +95,28 @@ def _build_enqueue_call( ] return keys, args - def _build_dequeue_call(self, queue_type): + def _build_dequeue_call(self, queue_type: str) -> RedisCall: validate_dequeue_arguments(queue_type) return [self._key_prefix, queue_type], [ self._current_timestamp(), self._job_expire_interval, ] - def _build_finish_call(self, job_id, queue_id, 
queue_type): + def _build_finish_call( + self, + job_id: str, + queue_id: str, + queue_type: str, + ) -> RedisCall: validate_finish_arguments(job_id, queue_id, queue_type) return [self._key_prefix, queue_type], [queue_id, job_id] - def _build_interval_call(self, interval, queue_id, queue_type): + def _build_interval_call( + self, + interval: int, + queue_id: str, + queue_type: str, + ) -> RedisCall: validate_interval_arguments(interval, queue_id, queue_type) keys = [ self._keys.interval_hash, @@ -108,35 +124,45 @@ def _build_interval_call(self, interval, queue_id, queue_type): ] return keys, [interval] - def _build_requeue_call(self, queue_type, timestamp): + def _build_requeue_call(self, queue_type: RedisValue, timestamp: str) -> RedisCall: queue_type = decode_redis_value(queue_type) return [self._key_prefix, queue_type], [timestamp] - def _build_global_metrics_call(self): + def _build_global_metrics_call(self) -> RedisCall: return [self._key_prefix], [self._current_timestamp()] - def _build_queue_metrics_call(self, queue_type, queue_id): + def _build_queue_metrics_call(self, queue_type: str, queue_id: str) -> RedisCall: return [self._keys.job_queue(queue_type, queue_id)], [self._current_timestamp()] - def _validate_metrics_call(self, queue_type, queue_id): + def _validate_metrics_call( + self, + queue_type: str | None, + queue_id: str | None, + ) -> None: validate_metrics_arguments(queue_type, queue_id) if not queue_type and queue_id: raise BadArgumentException( "`queue_id` should be accompanied by `queue_type`." ) - def _queue_type_metrics_keys(self, queue_type): + def _queue_type_metrics_keys(self, queue_type: str) -> tuple[str, str]: return ( self._keys.ready_queue_set(queue_type), self._keys.active_queue_set(queue_type), ) - def _queue_length_key(self, queue_type, queue_id): + def _queue_length_key(self, queue_type: str, queue_id: str) -> str: validate_get_queue_length_arguments(queue_type, queue_id) return self._keys.job_queue(queue_type, queue_id) - def _clear_queue_plan(self, queue_type, queue_id): + def _clear_queue_plan( + self, + queue_type: str | None, + queue_id: str | None, + ) -> ClearQueuePlan: validate_clear_queue_arguments(queue_type, queue_id) + queue_type = cast(str, queue_type) + queue_id = cast(str, queue_id) return ClearQueuePlan( primary_set=self._keys.ready_queue_set(queue_type), job_queue=self._keys.job_queue(queue_type, queue_id), @@ -147,26 +173,26 @@ def _clear_queue_plan(self, queue_type, queue_id): queue_id=queue_id, ) - def _finish_response(self, finish_response): + def _finish_response(self, finish_response: int) -> StatusResponse: if finish_response == 0: return {"status": "failure"} return {"status": "success"} - def _interval_response(self, interval_response): + def _interval_response(self, interval_response: int) -> StatusResponse: if interval_response == 0: return {"status": "failure"} return {"status": "success"} - def _dequeue_response(self, dequeue_response): + def _dequeue_response(self, dequeue_response: Sequence[Any]) -> dict[str, Any]: return format_dequeue_response(dequeue_response) def _global_metrics_response( self, - active_queue_types, - ready_queue_types, - enqueue_details, - dequeue_details, - ): + active_queue_types: Iterable[RedisValue], + ready_queue_types: Iterable[RedisValue], + enqueue_details: Sequence[Any], + dequeue_details: Sequence[Any], + ) -> dict[str, Any]: enqueue_counts, dequeue_counts = format_metrics_counts( enqueue_details, dequeue_details, @@ -178,7 +204,11 @@ def _global_metrics_response( "dequeue_counts": 
dequeue_counts, } - def _queue_type_metrics_response(self, ready_queues, active_queues): + def _queue_type_metrics_response( + self, + ready_queues: Iterable[RedisValue], + active_queues: Iterable[RedisValue], + ) -> dict[str, Any]: return { "status": "success", "queue_ids": format_queue_ids(ready_queues, active_queues), @@ -186,10 +216,10 @@ def _queue_type_metrics_response(self, ready_queues, active_queues): def _queue_metrics_response( self, - queue_length, - enqueue_details, - dequeue_details, - ): + queue_length: int | str | bytes, + enqueue_details: Sequence[Any], + dequeue_details: Sequence[Any], + ) -> dict[str, Any]: enqueue_counts, dequeue_counts = format_metrics_counts( enqueue_details, dequeue_details, @@ -201,23 +231,23 @@ def _queue_metrics_response( "dequeue_counts": dequeue_counts, } - def _decode_redis_value(self, value): + def _decode_redis_value(self, value: RedisValue) -> str: return decode_redis_value(value) - def _decode_requeue_job(self, job): + def _decode_requeue_job(self, job: RedisValue) -> tuple[str, str]: queue_id, job_id = decode_redis_value(job).split(":") return queue_id, job_id - def _clear_queue_empty_response(self): + def _clear_queue_empty_response(self) -> StatusResponse: return {"status": "Failure", "message": "No queued calls found"} - def _clear_queue_removed_response(self): + def _clear_queue_removed_response(self) -> StatusResponse: return { "status": "Success", "message": "Successfully removed all queued calls", } - def _clear_queue_purged_response(self): + def _clear_queue_purged_response(self) -> StatusResponse: return { "status": "Success", "message": "Successfully removed all queued calls and purged related resources", diff --git a/src/tailback/config.py b/src/tailback/config.py index 0055589..ff1efd4 100644 --- a/src/tailback/config.py +++ b/src/tailback/config.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from dataclasses import dataclass +from typing import Any, Self from tailback.exceptions import TailbackException from tailback.utils import is_valid_interval, is_valid_requeue_limit @@ -22,7 +23,7 @@ class RedisConfig: password: str | None = None @classmethod - def from_mapping(cls, config): + def from_mapping(cls, config: Mapping[str, Any]) -> Self: cls._validate_required(config) cls._validate_connection(config) cls._validate_optional(config) @@ -38,7 +39,7 @@ def from_mapping(cls, config): ) @classmethod - def _validate_required(cls, config): + def _validate_required(cls, config: Mapping[str, Any]) -> None: conn_type = cls._require_value(config, "conn_type") if conn_type not in REDIS_CONN_TYPES: raise TailbackException( @@ -50,7 +51,7 @@ def _validate_required(cls, config): raise TailbackException("Invalid config: redis.db must be an integer") @classmethod - def _validate_connection(cls, config): + def _validate_connection(cls, config: Mapping[str, Any]) -> None: cls._validate_clustered(config) if config["conn_type"] == "unix_sock": @@ -60,12 +61,12 @@ def _validate_connection(cls, config): cls._validate_tcp_socket(config) @classmethod - def _validate_clustered(cls, config): + def _validate_clustered(cls, config: Mapping[str, Any]) -> None: if "clustered" in config and not isinstance(config["clustered"], bool): raise TailbackException("Invalid config: redis.clustered must be a boolean") @classmethod - def _validate_unix_socket(cls, config): + def _validate_unix_socket(cls, config: Mapping[str, Any]) -> None: unix_socket_path = cls._require_value(config, "unix_socket_path") if not cls._is_non_empty_string(unix_socket_path): raise 
TailbackException( @@ -73,10 +74,12 @@ def _validate_unix_socket(cls, config): ) @classmethod - def _validate_tcp_socket(cls, config): + def _validate_tcp_socket(cls, config: Mapping[str, Any]) -> None: host = cls._require_value(config, "host") if not cls._is_non_empty_string(host): - raise TailbackException("Invalid config: redis.host must be a non-empty string") + raise TailbackException( + "Invalid config: redis.host must be a non-empty string" + ) port = cls._require_value(config, "port") if not cls._is_int_not_bool(port): @@ -88,24 +91,26 @@ def _validate_tcp_socket(cls, config): ) @classmethod - def _validate_optional(cls, config): + def _validate_optional(cls, config: Mapping[str, Any]) -> None: if "password" in config and config["password"] is not None: if not isinstance(config["password"], str): - raise TailbackException("Invalid config: redis.password must be a string") + raise TailbackException( + "Invalid config: redis.password must be a string" + ) @staticmethod - def _require_value(config, option_name): + def _require_value(config: Mapping[str, Any], option_name: str) -> Any: if option_name not in config: raise TailbackException("Missing config: redis.%s" % option_name) return config[option_name] @staticmethod - def _is_non_empty_string(value): + def _is_non_empty_string(value: object) -> bool: return isinstance(value, str) and bool(value) @staticmethod - def _is_int_not_bool(value): + def _is_int_not_bool(value: object) -> bool: return isinstance(value, int) and not isinstance(value, bool) @@ -117,7 +122,7 @@ class QueueConfig: default_job_requeue_limit: int @classmethod - def from_mapping(cls, config): + def from_mapping(cls, config: Mapping[str, Any]) -> Self: cls._validate_required(config) return cls( @@ -128,7 +133,7 @@ def from_mapping(cls, config): ) @classmethod - def _validate_required(cls, config): + def _validate_required(cls, config: Mapping[str, Any]) -> None: key_prefix = cls._require_value(config, "key_prefix") if not cls._is_non_empty_string(key_prefix): raise TailbackException( @@ -151,14 +156,14 @@ def _validate_required(cls, config): ) @staticmethod - def _require_value(config, option_name): + def _require_value(config: Mapping[str, Any], option_name: str) -> Any: if option_name not in config: raise TailbackException("Missing config: queue.%s" % option_name) return config[option_name] @staticmethod - def _is_non_empty_string(value): + def _is_non_empty_string(value: object) -> bool: return isinstance(value, str) and bool(value) @@ -168,7 +173,7 @@ class TailbackConfig: queue: QueueConfig @classmethod - def from_mapping(cls, config): + def from_mapping(cls, config: Mapping[str, Any]) -> Self: normalized = cls._normalize_sections(config) cls._require_sections(normalized) @@ -178,11 +183,13 @@ def from_mapping(cls, config): ) @staticmethod - def _normalize_sections(config): + def _normalize_sections(config: Mapping[str, Any]) -> dict[str, dict[str, Any]]: if not isinstance(config, Mapping): - raise TailbackException("Config must be a mapping with redis and queue sections") + raise TailbackException( + "Config must be a mapping with redis and queue sections" + ) - normalized = {} + normalized: dict[str, dict[str, Any]] = {} for section_name, section_values in config.items(): if not isinstance(section_values, Mapping): raise TailbackException( @@ -196,6 +203,6 @@ def _normalize_sections(config): return normalized @staticmethod - def _require_sections(config): + def _require_sections(config: Mapping[str, Any]) -> None: if "redis" not in config or "queue" not in 
config: raise TailbackException("Config missing required sections: redis, queue") diff --git a/src/tailback/keys.py b/src/tailback/keys.py index ddf3279..ead21df 100644 --- a/src/tailback/keys.py +++ b/src/tailback/keys.py @@ -9,36 +9,36 @@ class RedisKeys: key_prefix: str @property - def active_queue_types(self): + def active_queue_types(self) -> str: return "%s:active:queue_type" % self.key_prefix @property - def ready_queue_types(self): + def ready_queue_types(self) -> str: return "%s:ready:queue_type" % self.key_prefix @property - def interval_hash(self): + def interval_hash(self) -> str: return "%s:interval" % self.key_prefix @property - def payload_hash(self): + def payload_hash(self) -> str: return "%s:payload" % self.key_prefix @property - def deep_status(self): + def deep_status(self) -> str: return "fq:deep_status:%s" % self.key_prefix - def ready_queue_set(self, queue_type): + def ready_queue_set(self, queue_type: str) -> str: return "%s:%s" % (self.key_prefix, queue_type) - def active_queue_set(self, queue_type): + def active_queue_set(self, queue_type: str) -> str: return "%s:%s:active" % (self.key_prefix, queue_type) - def job_queue(self, queue_type, queue_id): + def job_queue(self, queue_type: str, queue_id: str) -> str: return "%s:%s:%s" % (self.key_prefix, queue_type, queue_id) - def interval_member(self, queue_type, queue_id): + def interval_member(self, queue_type: str, queue_id: str) -> str: return "%s:%s" % (queue_type, queue_id) - def payload_member(self, queue_type, queue_id, job_id): + def payload_member(self, queue_type: str, queue_id: str, job_id: str) -> str: return "%s:%s:%s" % (queue_type, queue_id, job_id) diff --git a/src/tailback/lua.py b/src/tailback/lua.py index cd9affa..d9eda4b 100644 --- a/src/tailback/lua.py +++ b/src/tailback/lua.py @@ -3,20 +3,20 @@ from dataclasses import dataclass, fields from pathlib import Path -from typing import Any +from typing import Any, Callable, Self @dataclass(frozen=True) class LuaScripts: - enqueue: Any - dequeue: Any - finish: Any - interval: Any - requeue: Any - metrics: Any + enqueue: Callable[..., Any] + dequeue: Callable[..., Any] + finish: Callable[..., Any] + interval: Callable[..., Any] + requeue: Callable[..., Any] + metrics: Callable[..., Any] @classmethod - def register(cls, redis_client): + def register(cls, redis_client: Any) -> Self: registered_scripts = { script_field.name: redis_client.register_script( cls._read_script(script_field.name) @@ -26,7 +26,7 @@ def register(cls, redis_client): return cls(**registered_scripts) @staticmethod - def _read_script(script_name): + def _read_script(script_name: str) -> str: script_path = ( Path(__file__).with_name("scripts") / "lua" / ("%s.lua" % script_name) ) diff --git a/src/tailback/py.typed b/src/tailback/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/tailback/queue.py b/src/tailback/queue.py index d006b0e..94a8735 100644 --- a/src/tailback/queue.py +++ b/src/tailback/queue.py @@ -3,6 +3,7 @@ # Copyright (c) 2025 Flowdacity Development Team. See LICENSE.txt for details. 
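As a quick orientation for reviewers, here is a sketch of the config mapping these typed `from_mapping` constructors accept and the key layout `RedisKeys` produces. The concrete host, port, and prefix values are placeholders, not library defaults; only the structure follows the validators shown in `config.py` and the formatting in `keys.py`:

```python
from tailback.config import TailbackConfig
from tailback.keys import RedisKeys

# Hypothetical values; only the shape mirrors the validators in config.py.
config = {
    "redis": {
        "conn_type": "tcp_sock",  # or "unix_sock" together with "unix_socket_path"
        "host": "127.0.0.1",
        "port": 6379,
        "db": 0,
    },
    "queue": {
        "key_prefix": "tailback",
        "job_expire_interval": 30000,     # positive int (is_valid_interval)
        "default_job_requeue_limit": -1,  # any int >= -1 (is_valid_requeue_limit)
    },
}

cfg = TailbackConfig.from_mapping(config)

# Every Redis key is derived from the configured prefix.
keys = RedisKeys(cfg.queue.key_prefix)
assert keys.ready_queue_set("sms") == "tailback:sms"
assert keys.active_queue_set("sms") == "tailback:sms:active"
assert keys.job_queue("sms", "user-1") == "tailback:sms:user-1"
```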
import asyncio +from typing import Any from tailback.base import BaseTailback from tailback.lua import LuaScripts @@ -12,28 +13,28 @@ class Tailback(BaseTailback): """Async Tailback API.""" - async def initialize(self): + async def initialize(self) -> None: """Set up the async Redis client and register Lua scripts.""" self._r = create_async_redis_client(self.config.redis) await validate_async_redis_connection(self._r) self._register_lua_scripts() - def _register_lua_scripts(self): + def _register_lua_scripts(self) -> None: self._scripts = LuaScripts.register(self._r) - def reload_lua_scripts(self): + def reload_lua_scripts(self) -> None: """Lets user reload the Lua scripts at run time.""" self._register_lua_scripts() async def enqueue( self, - payload, - interval, - job_id, - queue_id, - queue_type="default", - requeue_limit=None, - ): + payload: Any, + interval: int, + job_id: str, + queue_id: str, + queue_type: str = "default", + requeue_limit: int | None = None, + ) -> dict[str, str]: """Enqueue a job into the specified queue_id and queue_type.""" keys, args = self._build_enqueue_call( payload, @@ -46,25 +47,35 @@ async def enqueue( await self._scripts.enqueue(keys=keys, args=args) return {"status": "queued"} - async def dequeue(self, queue_type="default"): + async def dequeue(self, queue_type: str = "default") -> dict[str, Any]: """Dequeue a ready job for queue_type, or return failure.""" keys, args = self._build_dequeue_call(queue_type) dequeue_response = await self._scripts.dequeue(keys=keys, args=args) return self._dequeue_response(dequeue_response) - async def finish(self, job_id, queue_id, queue_type="default"): + async def finish( + self, + job_id: str, + queue_id: str, + queue_type: str = "default", + ) -> dict[str, str]: """Mark a dequeued job as completed successfully.""" keys, args = self._build_finish_call(job_id, queue_id, queue_type) finish_response = await self._scripts.finish(keys=keys, args=args) return self._finish_response(finish_response) - async def interval(self, interval, queue_id, queue_type="default"): + async def interval( + self, + interval: int, + queue_id: str, + queue_type: str = "default", + ) -> dict[str, str]: """Update the interval for a queue_id and queue_type.""" keys, args = self._build_interval_call(interval, queue_id, queue_type) interval_response = await self._scripts.interval(keys=keys, args=args) return self._interval_response(interval_response) - async def requeue(self): + async def requeue(self) -> None: """Re-queue expired active jobs back into their ready queues.""" timestamp = self._current_timestamp() active_queue_type_list = await self._r.smembers(self._keys.active_queue_types) @@ -80,7 +91,11 @@ async def requeue(self): queue_type=queue_type, ) - async def metrics(self, queue_type=None, queue_id=None): + async def metrics( + self, + queue_type: str | None = None, + queue_id: str | None = None, + ) -> dict[str, Any]: """Return global, queue-type, or queue-specific metrics.""" self._validate_metrics_call(queue_type, queue_id) @@ -127,14 +142,19 @@ async def metrics(self, queue_type=None, queue_id=None): return {"status": "failure"} - async def deep_status(self): + async def deep_status(self) -> Any: """ Check Redis availability. If Redis is down, set() will raise. 
:return: value or None """ return await self._r.set(self._keys.deep_status, "sharq_deep_status") - async def clear_queue(self, queue_type=None, queue_id=None, purge_all=False): + async def clear_queue( + self, + queue_type: str | None = None, + queue_id: str | None = None, + purge_all: bool = False, + ) -> dict[str, str]: """Clear entries in a queue and optionally purge related resources.""" plan = self._clear_queue_plan(queue_type, queue_id) @@ -160,14 +180,14 @@ async def clear_queue(self, queue_type=None, queue_id=None, purge_all=False): await self._r.delete(plan.job_queue) return response - async def get_queue_length(self, queue_type, queue_id): + async def get_queue_length(self, queue_type: str, queue_id: str) -> int: """ Return the current Redis list length for key_prefix:queue_type:queue_id. """ redis_key = self._queue_length_key(queue_type, queue_id) return await self._r.llen(redis_key) - async def close(self): + async def close(self) -> None: """Cleanly close the underlying Redis client or connection pool.""" if self._r is None: return diff --git a/src/tailback/redis.py b/src/tailback/redis.py index f962f12..73607a6 100644 --- a/src/tailback/redis.py +++ b/src/tailback/redis.py @@ -1,16 +1,21 @@ # -*- coding: utf-8 -*- # Copyright (c) 2025 Flowdacity Development Team. See LICENSE.txt for details. +from typing import Any, cast + from redis import Redis as SyncRedis from redis import RedisCluster as SyncRedisCluster from redis.asyncio import Redis as AsyncRedis from redis.asyncio.cluster import ClusterNode as AsyncClusterNode from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from tailback.config import RedisConfig from tailback.exceptions import TailbackException -def create_async_redis_client(redis_config): +def create_async_redis_client( + redis_config: RedisConfig, +) -> AsyncRedis | AsyncRedisCluster: if redis_config.conn_type == "unix_sock": return AsyncRedis( db=redis_config.db, @@ -19,9 +24,11 @@ def create_async_redis_client(redis_config): ) if redis_config.conn_type == "tcp_sock": + host = cast(str, redis_config.host) + port = int(cast(int, redis_config.port)) if redis_config.clustered: startup_nodes = [ - AsyncClusterNode(redis_config.host, int(redis_config.port)), + AsyncClusterNode(host, port), ] return AsyncRedisCluster( startup_nodes=startup_nodes, @@ -32,15 +39,15 @@ def create_async_redis_client(redis_config): return AsyncRedis( db=redis_config.db, - host=redis_config.host, - port=int(redis_config.port), + host=host, + port=port, password=redis_config.password, ) raise TailbackException("Unknown redis conn_type: %s" % redis_config.conn_type) -def create_sync_redis_client(redis_config): +def create_sync_redis_client(redis_config: RedisConfig) -> SyncRedis | SyncRedisCluster: if redis_config.conn_type == "unix_sock": return SyncRedis( db=redis_config.db, @@ -49,10 +56,12 @@ def create_sync_redis_client(redis_config): ) if redis_config.conn_type == "tcp_sock": + host = cast(str, redis_config.host) + port = int(cast(int, redis_config.port)) if redis_config.clustered: return SyncRedisCluster( - host=redis_config.host, - port=int(redis_config.port), + host=host, + port=port, decode_responses=False, password=redis_config.password, socket_timeout=5, @@ -60,15 +69,15 @@ def create_sync_redis_client(redis_config): return SyncRedis( db=redis_config.db, - host=redis_config.host, - port=int(redis_config.port), + host=host, + port=port, password=redis_config.password, ) raise TailbackException("Unknown redis conn_type: %s" % redis_config.conn_type) -async def 
validate_async_redis_connection(redis_client): +async def validate_async_redis_connection(redis_client: Any) -> None: if redis_client is None: raise TailbackException("Redis client is not initialized") @@ -85,7 +94,7 @@ async def validate_async_redis_connection(redis_client): raise TailbackException("Failed to connect to Redis: ping returned False") -def validate_sync_redis_connection(redis_client): +def validate_sync_redis_connection(redis_client: Any) -> None: if redis_client is None: raise TailbackException("Redis client is not initialized") diff --git a/src/tailback/responses.py b/src/tailback/responses.py index 6cdc378..129094b 100644 --- a/src/tailback/responses.py +++ b/src/tailback/responses.py @@ -1,16 +1,21 @@ # -*- coding: utf-8 -*- # Copyright (c) 2025 Flowdacity Development Team. See LICENSE.txt for details. +from collections.abc import Iterable, Sequence +from typing import Any + from tailback.utils import convert_to_str, deserialize_payload +RedisValue = str | bytes + -def decode_redis_value(value): +def decode_redis_value(value: RedisValue) -> str: if isinstance(value, bytes): return value.decode("utf-8") return value -def format_dequeue_response(dequeue_response): +def format_dequeue_response(dequeue_response: Sequence[Any]) -> dict[str, Any]: if len(dequeue_response) < 4: return {"status": "failure"} @@ -28,9 +33,12 @@ def format_dequeue_response(dequeue_response): } -def format_metrics_counts(enqueue_details, dequeue_details): - enqueue_counts = {} - dequeue_counts = {} +def format_metrics_counts( + enqueue_details: Sequence[Any], + dequeue_details: Sequence[Any], +) -> tuple[dict[str, int], dict[str, int]]: + enqueue_counts: dict[str, int] = {} + dequeue_counts: dict[str, int] = {} for i in range(0, len(enqueue_details), 2): enqueue_counts[str(decode_redis_value(enqueue_details[i]))] = int( enqueue_details[i + 1] or 0 @@ -41,11 +49,17 @@ def format_metrics_counts(enqueue_details, dequeue_details): return enqueue_counts, dequeue_counts -def format_queue_types(active_queue_types, ready_queue_types): +def format_queue_types( + active_queue_types: Iterable[RedisValue], + ready_queue_types: Iterable[RedisValue], +) -> list[str]: return convert_to_str(set(active_queue_types) | set(ready_queue_types)) -def format_queue_ids(ready_queues, active_queues): +def format_queue_ids( + ready_queues: Iterable[RedisValue], + active_queues: Iterable[RedisValue], +) -> list[str]: ready_queue_ids = {decode_redis_value(queue) for queue in ready_queues} active_queue_ids = { decode_redis_value(queue).split(":")[0] for queue in active_queues diff --git a/src/tailback/sync/queue.py b/src/tailback/sync/queue.py index 2ad7ac0..4d307d9 100644 --- a/src/tailback/sync/queue.py +++ b/src/tailback/sync/queue.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright (c) 2025 Flowdacity Development Team. See LICENSE.txt for details. 
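The `RedisValue = str | bytes` alias introduced in `responses.py` reflects that Redis replies arrive as bytes unless `decode_responses` is enabled. A small sketch of how the response helpers normalize them; the metric values are invented, and only the flat `[key, count, key, count, ...]` shape is taken from `format_metrics_counts`:

```python
from tailback.responses import decode_redis_value, format_metrics_counts

# Both raw bytes and already-decoded strings are accepted.
assert decode_redis_value(b"sms") == "sms"
assert decode_redis_value("sms") == "sms"

# format_metrics_counts walks flat [key, count, ...] sequences,
# decoding keys and treating missing counts as 0.
enqueue_details = [b"1700000000000", b"3", b"1700000060000", None]
dequeue_details = [b"1700000000000", b"2"]
enq, deq = format_metrics_counts(enqueue_details, dequeue_details)
assert enq == {"1700000000000": 3, "1700000060000": 0}
assert deq == {"1700000000000": 2}
```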
+from typing import Any + from tailback.base import BaseTailback from tailback.lua import LuaScripts from tailback.redis import create_sync_redis_client, validate_sync_redis_connection @@ -9,28 +11,28 @@ class Tailback(BaseTailback): """Synchronous Tailback API.""" - def initialize(self): + def initialize(self) -> None: """Set up the synchronous Redis client and register Lua scripts.""" self._r = create_sync_redis_client(self.config.redis) validate_sync_redis_connection(self._r) self._register_lua_scripts() - def _register_lua_scripts(self): + def _register_lua_scripts(self) -> None: self._scripts = LuaScripts.register(self._r) - def reload_lua_scripts(self): + def reload_lua_scripts(self) -> None: """Lets user reload the Lua scripts at run time.""" self._register_lua_scripts() def enqueue( self, - payload, - interval, - job_id, - queue_id, - queue_type="default", - requeue_limit=None, - ): + payload: Any, + interval: int, + job_id: str, + queue_id: str, + queue_type: str = "default", + requeue_limit: int | None = None, + ) -> dict[str, str]: """Enqueue a job into the specified queue_id and queue_type.""" keys, args = self._build_enqueue_call( payload, @@ -43,25 +45,35 @@ def enqueue( self._scripts.enqueue(keys=keys, args=args) return {"status": "queued"} - def dequeue(self, queue_type="default"): + def dequeue(self, queue_type: str = "default") -> dict[str, Any]: """Dequeue a ready job for queue_type, or return failure.""" keys, args = self._build_dequeue_call(queue_type) dequeue_response = self._scripts.dequeue(keys=keys, args=args) return self._dequeue_response(dequeue_response) - def finish(self, job_id, queue_id, queue_type="default"): + def finish( + self, + job_id: str, + queue_id: str, + queue_type: str = "default", + ) -> dict[str, str]: """Mark a dequeued job as completed successfully.""" keys, args = self._build_finish_call(job_id, queue_id, queue_type) finish_response = self._scripts.finish(keys=keys, args=args) return self._finish_response(finish_response) - def interval(self, interval, queue_id, queue_type="default"): + def interval( + self, + interval: int, + queue_id: str, + queue_type: str = "default", + ) -> dict[str, str]: """Update the interval for a queue_id and queue_type.""" keys, args = self._build_interval_call(interval, queue_id, queue_type) interval_response = self._scripts.interval(keys=keys, args=args) return self._interval_response(interval_response) - def requeue(self): + def requeue(self) -> None: """Re-queue expired active jobs back into their ready queues.""" timestamp = self._current_timestamp() active_queue_type_list = self._r.smembers(self._keys.active_queue_types) @@ -73,7 +85,11 @@ def requeue(self): queue_id, job_id = self._decode_requeue_job(job) self.finish(job_id=job_id, queue_id=queue_id, queue_type=queue_type) - def metrics(self, queue_type=None, queue_id=None): + def metrics( + self, + queue_type: str | None = None, + queue_id: str | None = None, + ) -> dict[str, Any]: """Return global, queue-type, or queue-specific metrics.""" self._validate_metrics_call(queue_type, queue_id) @@ -118,14 +134,19 @@ def metrics(self, queue_type=None, queue_id=None): return {"status": "failure"} - def deep_status(self): + def deep_status(self) -> Any: """ Check Redis availability. If Redis is down, set() will raise. 
:return: value or None """ return self._r.set(self._keys.deep_status, "sharq_deep_status") - def clear_queue(self, queue_type=None, queue_id=None, purge_all=False): + def clear_queue( + self, + queue_type: str | None = None, + queue_id: str | None = None, + purge_all: bool = False, + ) -> dict[str, str]: """Clear entries in a queue and optionally purge related resources.""" plan = self._clear_queue_plan(queue_type, queue_id) @@ -151,14 +172,14 @@ def clear_queue(self, queue_type=None, queue_id=None, purge_all=False): self._r.delete(plan.job_queue) return response - def get_queue_length(self, queue_type, queue_id): + def get_queue_length(self, queue_type: str, queue_id: str) -> int: """ Return the current Redis list length for key_prefix:queue_type:queue_id. """ redis_key = self._queue_length_key(queue_type, queue_id) return self._r.llen(redis_key) - def close(self): + def close(self) -> None: """Close the underlying synchronous Redis client.""" if self._r is None: return diff --git a/src/tailback/utils.py b/src/tailback/utils.py index 71c0482..a7aeafe 100644 --- a/src/tailback/utils.py +++ b/src/tailback/utils.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- # Copyright (c) 2014 Plivo Team. See LICENSE.txt for details. import time +from collections.abc import Iterable +from typing import Any + import msgpack -VALID_IDENTIFIER_SET = set(list("abcdefghijklmnopqrstuvwxyz0123456789_-")) +VALID_IDENTIFIER_SET: set[str] = set(list("abcdefghijklmnopqrstuvwxyz0123456789_-")) -def is_valid_identifier(identifier): +def is_valid_identifier(identifier: object) -> bool: """Checks if the given identifier is valid or not. A valid identifier may consists of the following characters with a maximum length of 100 characters, minimum of 1 character. @@ -28,14 +31,14 @@ def is_valid_identifier(identifier): return condensed_form.issubset(VALID_IDENTIFIER_SET) -def is_valid_interval(interval): +def is_valid_interval(interval: object) -> bool: """Checks if the given interval is valid. A valid interval is always a positive, non-zero integer value. """ return isinstance(interval, int) and interval > 0 -def is_valid_requeue_limit(requeue_limit): +def is_valid_requeue_limit(requeue_limit: object) -> bool: """Checks if the given requeue limit is valid. A valid requeue limit is always greater than or equal to -1. @@ -49,14 +52,14 @@ def is_valid_requeue_limit(requeue_limit): return True -def serialize_payload(payload): +def serialize_payload(payload: Any) -> bytes: """Tries to serialize the payload using msgpack. If it is not serializable, raises a TypeError. """ return msgpack.packb(payload, use_bin_type=True) -def deserialize_payload(payload): +def deserialize_payload(payload: bytes) -> Any: """Tries to deserialize the payload using msgpack.""" # Handle older Tailback payloads as well (before py3 migration) if payload.startswith(b'"') and payload.endswith(b'"'): @@ -65,12 +68,12 @@ def deserialize_payload(payload): return msgpack.unpackb(payload, raw=False) -def generate_epoch(): +def generate_epoch() -> int: """Generates an unix epoch in ms.""" return int(time.time() * 1000) -def convert_to_str(queue_set): +def convert_to_str(queue_set: Iterable[str | bytes | bytearray]) -> list[str]: """Takes set and decodes bytes to string""" queue_list = [] for queue in list(queue_set): diff --git a/src/tailback/validators.py b/src/tailback/validators.py index e6d5f7c..f537b84 100644 --- a/src/tailback/validators.py +++ b/src/tailback/validators.py @@ -2,6 +2,7 @@ # Copyright (c) 2025 Flowdacity Development Team. 
See LICENSE.txt for details. from dataclasses import dataclass +from typing import Any, cast from tailback.exceptions import BadArgumentException from tailback.utils import ( @@ -26,14 +27,14 @@ class EnqueueArguments: def validate_enqueue_arguments( - payload, - interval, - job_id, - queue_id, - queue_type, - requeue_limit, - default_requeue_limit, -): + payload: Any, + interval: object, + job_id: object, + queue_id: object, + queue_type: object, + requeue_limit: object | None, + default_requeue_limit: int, +) -> EnqueueArguments: if not is_valid_interval(interval): raise BadArgumentException(INVALID_INTERVAL) @@ -46,6 +47,7 @@ def validate_enqueue_arguments( if not is_valid_requeue_limit(requeue_limit): raise BadArgumentException(INVALID_REQUEUE_LIMIT) + requeue_limit = cast(int, requeue_limit) try: serialized_payload = serialize_payload(payload) @@ -58,17 +60,25 @@ def validate_enqueue_arguments( ) -def validate_dequeue_arguments(queue_type): +def validate_dequeue_arguments(queue_type: object) -> None: _validate_identifier(queue_type, INVALID_QUEUE_TYPE) -def validate_finish_arguments(job_id, queue_id, queue_type): +def validate_finish_arguments( + job_id: object, + queue_id: object, + queue_type: object, +) -> None: _validate_identifier(job_id, INVALID_JOB_ID) _validate_identifier(queue_id, INVALID_QUEUE_ID) _validate_identifier(queue_type, INVALID_QUEUE_TYPE) -def validate_interval_arguments(interval, queue_id, queue_type): +def validate_interval_arguments( + interval: object, + queue_id: object, + queue_type: object, +) -> None: if not is_valid_interval(interval): raise BadArgumentException(INVALID_INTERVAL) @@ -76,7 +86,10 @@ def validate_interval_arguments(interval, queue_id, queue_type): _validate_identifier(queue_type, INVALID_QUEUE_TYPE) -def validate_metrics_arguments(queue_type, queue_id): +def validate_metrics_arguments( + queue_type: object | None, + queue_id: object | None, +) -> None: if queue_id is not None and not is_valid_identifier(queue_id): raise BadArgumentException(INVALID_QUEUE_ID) @@ -84,7 +97,10 @@ def validate_metrics_arguments(queue_type, queue_id): raise BadArgumentException(INVALID_QUEUE_TYPE) -def validate_clear_queue_arguments(queue_type, queue_id): +def validate_clear_queue_arguments( + queue_type: object | None, + queue_id: object | None, +) -> None: if queue_id is None or not is_valid_identifier(queue_id): raise BadArgumentException(INVALID_QUEUE_ID) @@ -92,11 +108,11 @@ def validate_clear_queue_arguments(queue_type, queue_id): raise BadArgumentException(INVALID_QUEUE_TYPE) -def validate_get_queue_length_arguments(queue_type, queue_id): +def validate_get_queue_length_arguments(queue_type: object, queue_id: object) -> None: _validate_identifier(queue_type, INVALID_QUEUE_TYPE) _validate_identifier(queue_id, INVALID_QUEUE_ID) -def _validate_identifier(identifier, message): +def _validate_identifier(identifier: object, message: str) -> None: if not is_valid_identifier(identifier): raise BadArgumentException(message)
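For completeness, an end-to-end sketch of the synchronous API with the new annotations. The Redis settings, the millisecond interval unit, and the success-path fields of the dequeue response are assumptions inferred from the helpers above rather than verified against the Lua scripts:

```python
from tailback.sync.queue import Tailback

queue = Tailback({
    "redis": {"conn_type": "tcp_sock", "host": "127.0.0.1", "port": 6379, "db": 0},
    "queue": {
        "key_prefix": "tailback",
        "job_expire_interval": 30000,
        "default_job_requeue_limit": -1,
    },
})
queue.initialize()  # creates the Redis client and registers the Lua scripts

# Identifiers must satisfy is_valid_identifier (a-z, 0-9, "_", "-", 1-100 chars).
queue.enqueue(
    payload={"to": "+15550001111", "body": "hello"},  # anything msgpack can pack
    interval=1000,  # assumed milliseconds, matching generate_epoch()
    job_id="job-1",
    queue_id="user-1",
    queue_type="sms",
)

job = queue.dequeue(queue_type="sms")
if job["status"] == "success":
    # Assumed response fields: job_id, queue_id, payload (see format_dequeue_response).
    queue.finish(job_id=job["job_id"], queue_id=job["queue_id"], queue_type="sms")

queue.close()
```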