Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion .idea/sqlalchemy-bind-manager.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 19 additions & 1 deletion docs/repository/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,25 @@ async def some_async_function(repository: SQLAlchemyAsyncRepositoryInterface[MyM
...
```

Both repository and related interface are Generic, accepting the model class as a typing argument.
Both repository and related interface are Protocols, accepting the model class as a typing argument. You can also
extend the protocols with your custom methods.

```python
from typing import Protocol
from sqlalchemy_bind_manager.repository import SQLAlchemyRepositoryInterface, SQLAlchemyRepository

# SQLAlchemy model
class MyModel:
...

class MyCustomRepositoryInterface(SQLAlchemyRepositoryInterface[MyModel], Protocol):
def some_custom_method(self, model: MyModel) -> MyModel:
...

class MyCustomRepository(SQLAlchemyRepository[MyModel]):
def some_custom_method(self, model: MyModel) -> MyModel:
return model
```
///

### Maximum query limit
Expand Down
25 changes: 3 additions & 22 deletions sqlalchemy_bind_manager/_repository/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
# Software is furnished to do so, subject to the following conditions:
#
#
from abc import ABC, abstractmethod
from typing import (
Any,
Generic,
Iterable,
List,
Literal,
Mapping,
Protocol,
Tuple,
Union,
)
Expand All @@ -48,8 +47,7 @@
)


class SQLAlchemyAsyncRepositoryInterface(Generic[MODEL], ABC):
@abstractmethod
class SQLAlchemyAsyncRepositoryInterface(Protocol[MODEL]):
async def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""Get a model by primary key.

Expand All @@ -59,7 +57,6 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""
...

@abstractmethod
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""Get a list of models by primary keys.

Expand All @@ -68,7 +65,6 @@ async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""
...

@abstractmethod
async def save(self, instance: MODEL) -> MODEL:
"""Persist a model.

Expand All @@ -77,7 +73,6 @@ async def save(self, instance: MODEL) -> MODEL:
"""
...

@abstractmethod
async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""Persist many models in a single database get_session.

Expand All @@ -86,23 +81,20 @@ async def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""
...

@abstractmethod
async def delete(self, instance: MODEL) -> None:
"""Deletes a model.

:param instance: The model instance
"""
...

@abstractmethod
async def delete_many(self, instances: Iterable[MODEL]) -> None:
"""Deletes a collection of models in a single transaction.

:param instances: The model instances
"""
...

@abstractmethod
async def find(
self,
search_params: Union[None, Mapping[str, Any]] = None,
Expand Down Expand Up @@ -130,7 +122,6 @@ async def find(
"""
...

@abstractmethod
async def paginated_find(
self,
items_per_page: int,
Expand Down Expand Up @@ -169,7 +160,6 @@ async def paginated_find(
"""
...

@abstractmethod
async def cursor_paginated_find(
self,
items_per_page: int,
Expand Down Expand Up @@ -205,8 +195,7 @@ async def cursor_paginated_find(
...


class SQLAlchemyRepositoryInterface(Generic[MODEL], ABC):
@abstractmethod
class SQLAlchemyRepositoryInterface(Protocol[MODEL]):
def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""Get a model by primary key.

Expand All @@ -216,7 +205,6 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
"""
...

@abstractmethod
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""Get a list of models by primary keys.

Expand All @@ -225,7 +213,6 @@ def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
"""
...

@abstractmethod
def save(self, instance: MODEL) -> MODEL:
"""Persist a model.

Expand All @@ -234,7 +221,6 @@ def save(self, instance: MODEL) -> MODEL:
"""
...

@abstractmethod
def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""Persist many models in a single database get_session.

Expand All @@ -243,23 +229,20 @@ def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
"""
...

@abstractmethod
def delete(self, instance: MODEL) -> None:
"""Deletes a model.

:param instance: The model instance
"""
...

@abstractmethod
def delete_many(self, instances: Iterable[MODEL]) -> None:
"""Deletes a collection of models in a single transaction.

:param instances: The model instances
"""
...

@abstractmethod
def find(
self,
search_params: Union[None, Mapping[str, Any]] = None,
Expand Down Expand Up @@ -287,7 +270,6 @@ def find(
"""
...

@abstractmethod
def paginated_find(
self,
items_per_page: int,
Expand Down Expand Up @@ -326,7 +308,6 @@ def paginated_find(
"""
...

@abstractmethod
def cursor_paginated_find(
self,
items_per_page: int,
Expand Down
19 changes: 14 additions & 5 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from inspect import signature
from typing import Protocol, runtime_checkable

from sqlalchemy_bind_manager.repository import (
SQLAlchemyAsyncRepository,
Expand All @@ -8,9 +9,17 @@
)


@runtime_checkable
class RuntimeRepoProtocol(SQLAlchemyRepositoryInterface, Protocol): ...


@runtime_checkable
class RuntimeAsyncRepoProtocol(SQLAlchemyAsyncRepositoryInterface, Protocol): ...


def test_interfaces():
assert issubclass(SQLAlchemyRepository, SQLAlchemyRepositoryInterface)
assert issubclass(SQLAlchemyAsyncRepository, SQLAlchemyAsyncRepositoryInterface)
assert issubclass(SQLAlchemyRepository, RuntimeRepoProtocol)
assert issubclass(SQLAlchemyAsyncRepository, RuntimeAsyncRepoProtocol)

sync_methods = [
method
Expand All @@ -26,15 +35,15 @@ def test_interfaces():
assert sync_methods == async_methods

for method in sync_methods:
# Sync signature is the same as sync protocol
# Concrete sync signature is the same as sync protocol signature
assert signature(getattr(SQLAlchemyRepository, method)) == signature(
getattr(SQLAlchemyRepositoryInterface, method)
)
# Async signature is the same as async protocol
# Concrete async signature is the same as async protocol signature
assert signature(getattr(SQLAlchemyAsyncRepository, method)) == signature(
getattr(SQLAlchemyAsyncRepositoryInterface, method)
)
# Sync signature is the same as async signature
# Sync protocol signature is the same as async protocol signature
assert signature(
getattr(SQLAlchemyAsyncRepositoryInterface, method)
) == signature(getattr(SQLAlchemyRepositoryInterface, method))