From 61ed57ae98ec3808ebbe5efe4b97f69d21f22a1c Mon Sep 17 00:00:00 2001
From: Federico Busetti <729029+febus982@users.noreply.github.com>
Date: Mon, 3 Mar 2025 09:40:18 +0000
Subject: [PATCH] Switch back from ABC to protocols
---
.idea/misc.xml | 2 +-
.idea/sqlalchemy-bind-manager.iml | 3 ++-
docs/repository/usage.md | 20 ++++++++++++++-
.../_repository/abstract.py | 25 +++----------------
tests/test_interfaces.py | 19 ++++++++++----
5 files changed, 39 insertions(+), 30 deletions(-)
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 31a9dc3..820854d 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -4,7 +4,7 @@
-
+
diff --git a/.idea/sqlalchemy-bind-manager.iml b/.idea/sqlalchemy-bind-manager.iml
index 739e159..e89aaa4 100644
--- a/.idea/sqlalchemy-bind-manager.iml
+++ b/.idea/sqlalchemy-bind-manager.iml
@@ -4,8 +4,9 @@
+
-
+
diff --git a/docs/repository/usage.md b/docs/repository/usage.md
index 4d82e10..ab19b2b 100644
--- a/docs/repository/usage.md
+++ b/docs/repository/usage.md
@@ -57,7 +57,25 @@ async def some_async_function(repository: SQLAlchemyAsyncRepositoryInterface[MyM
...
```
-Both repository and related interface are Generic, accepting the model class as a typing argument.
+Both repository and related interface are Protocols, accepting the model class as a typing argument. You can also
+extend the protocols with your custom methods.
+
+```python
+from typing import Protocol
+from sqlalchemy_bind_manager.repository import SQLAlchemyRepositoryInterface, SQLAlchemyRepository
+
+# SQLAlchemy model
+class MyModel:
+ ...
+
+class MyCustomRepositoryInterface(SQLAlchemyRepositoryInterface[MyModel], Protocol):
+ def some_custom_method(self, model: MyModel) -> MyModel:
+ ...
+
+class MyCustomRepository(SQLAlchemyRepository[MyModel]):
+ def some_custom_method(self, model: MyModel) -> MyModel:
+ return model
+```
///
### Maximum query limit
diff --git a/sqlalchemy_bind_manager/_repository/abstract.py b/sqlalchemy_bind_manager/_repository/abstract.py
index 94b5f34..1c1a5b2 100644
--- a/sqlalchemy_bind_manager/_repository/abstract.py
+++ b/sqlalchemy_bind_manager/_repository/abstract.py
@@ -27,14 +27,13 @@
# Software is furnished to do so, subject to the following conditions:
#
#
-from abc import ABC, abstractmethod
from typing import (
Any,
- Generic,
Iterable,
List,
Literal,
Mapping,
+ Protocol,
Tuple,
Union,
)
@@ -48,8 +47,7 @@
)
-class SQLAlchemyAsyncRepositoryInterface(Generic[MODEL], ABC):
- @abstractmethod
+class SQLAlchemyAsyncRepositoryInterface(Protocol[MODEL]):
async def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""Get a model by primary key.
@@ -59,7 +57,6 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""
...
- @abstractmethod
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""Get a list of models by primary keys.
@@ -68,7 +65,6 @@ async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""
...
- @abstractmethod
async def save(self, instance: MODEL) -> MODEL:
"""Persist a model.
@@ -77,7 +73,6 @@ async def save(self, instance: MODEL) -> MODEL:
"""
...
- @abstractmethod
async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""Persist many models in a single database get_session.
@@ -86,7 +81,6 @@ async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""
...
- @abstractmethod
async def delete(self, instance: MODEL) -> None:
"""Deletes a model.
@@ -94,7 +88,6 @@ async def delete(self, instance: MODEL) -> None:
"""
...
- @abstractmethod
async def delete_many(self, instances: Iterable[MODEL]) -> None:
"""Deletes a collection of models in a single transaction.
@@ -102,7 +95,6 @@ async def delete_many(self, instances: Iterable[MODEL]) -> None:
"""
...
- @abstractmethod
async def find(
self,
search_params: Union[None, Mapping[str, Any]] = None,
@@ -130,7 +122,6 @@ async def find(
"""
...
- @abstractmethod
async def paginated_find(
self,
items_per_page: int,
@@ -169,7 +160,6 @@ async def paginated_find(
"""
...
- @abstractmethod
async def cursor_paginated_find(
self,
items_per_page: int,
@@ -205,8 +195,7 @@ async def cursor_paginated_find(
...
-class SQLAlchemyRepositoryInterface(Generic[MODEL], ABC):
- @abstractmethod
+class SQLAlchemyRepositoryInterface(Protocol[MODEL]):
def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""Get a model by primary key.
@@ -216,7 +205,6 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""
...
- @abstractmethod
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""Get a list of models by primary keys.
@@ -225,7 +213,6 @@ def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""
...
- @abstractmethod
def save(self, instance: MODEL) -> MODEL:
"""Persist a model.
@@ -234,7 +221,6 @@ def save(self, instance: MODEL) -> MODEL:
"""
...
- @abstractmethod
def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""Persist many models in a single database get_session.
@@ -243,7 +229,6 @@ def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""
...
- @abstractmethod
def delete(self, instance: MODEL) -> None:
"""Deletes a model.
@@ -251,7 +236,6 @@ def delete(self, instance: MODEL) -> None:
"""
...
- @abstractmethod
def delete_many(self, instances: Iterable[MODEL]) -> None:
"""Deletes a collection of models in a single transaction.
@@ -259,7 +243,6 @@ def delete_many(self, instances: Iterable[MODEL]) -> None:
"""
...
- @abstractmethod
def find(
self,
search_params: Union[None, Mapping[str, Any]] = None,
@@ -287,7 +270,6 @@ def find(
"""
...
- @abstractmethod
def paginated_find(
self,
items_per_page: int,
@@ -326,7 +308,6 @@ def paginated_find(
"""
...
- @abstractmethod
def cursor_paginated_find(
self,
items_per_page: int,
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index b05509b..07f35fe 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -1,4 +1,5 @@
from inspect import signature
+from typing import Protocol, runtime_checkable
from sqlalchemy_bind_manager.repository import (
SQLAlchemyAsyncRepository,
@@ -8,9 +9,17 @@
)
+@runtime_checkable
+class RuntimeRepoProtocol(SQLAlchemyRepositoryInterface, Protocol): ...
+
+
+@runtime_checkable
+class RuntimeAsyncRepoProtocol(SQLAlchemyAsyncRepositoryInterface, Protocol): ...
+
+
def test_interfaces():
- assert issubclass(SQLAlchemyRepository, SQLAlchemyRepositoryInterface)
- assert issubclass(SQLAlchemyAsyncRepository, SQLAlchemyAsyncRepositoryInterface)
+ assert issubclass(SQLAlchemyRepository, RuntimeRepoProtocol)
+ assert issubclass(SQLAlchemyAsyncRepository, RuntimeAsyncRepoProtocol)
sync_methods = [
method
@@ -26,15 +35,15 @@ def test_interfaces():
assert sync_methods == async_methods
for method in sync_methods:
- # Sync signature is the same as sync protocol
+ # Concrete sync signature is the same as sync protocol signature
assert signature(getattr(SQLAlchemyRepository, method)) == signature(
getattr(SQLAlchemyRepositoryInterface, method)
)
- # Async signature is the same as async protocol
+ # Concrete async signature is the same as async protocol signature
assert signature(getattr(SQLAlchemyAsyncRepository, method)) == signature(
getattr(SQLAlchemyAsyncRepositoryInterface, method)
)
- # Sync signature is the same as async signature
+ # Sync protocol signature is the same as async protocol signature
assert signature(
getattr(SQLAlchemyAsyncRepositoryInterface, method)
) == signature(getattr(SQLAlchemyRepositoryInterface, method))