diff --git a/packages/sqlalchemy-spanner/README.rst b/packages/sqlalchemy-spanner/README.rst index fc9e5b6c0f75..0a103abea125 100644 --- a/packages/sqlalchemy-spanner/README.rst +++ b/packages/sqlalchemy-spanner/README.rst @@ -159,6 +159,38 @@ Read for row in connection.execute(select(["*"], from_obj=table)).fetchall(): print(row) +Async Support +~~~~~~~~~~~~~ + +The Spanner dialect also supports asyncio when used with SQLAlchemy 1.4 or 2.0. +To use the async client, use the ``spanner+spanner_asyncio`` prefix: + +.. code:: python + + spanner+spanner_asyncio:///projects/project-id/instances/instance-id/databases/database-id + +Example usage with ``create_async_engine``: + +.. code:: python + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy import text + import asyncio + + async def main(): + engine = create_async_engine( + "spanner+spanner_asyncio:///projects/project-id/instances/instance-id/databases/database-id" + ) + + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + print(result.fetchone()) + + await engine.dispose() + + if __name__ == "__main__": + asyncio.run(main()) + Migration --------- diff --git a/packages/sqlalchemy-spanner/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner_asyncio.py b/packages/sqlalchemy-spanner/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner_asyncio.py new file mode 100644 index 000000000000..ff3b98827f7b --- /dev/null +++ b/packages/sqlalchemy-spanner/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner_asyncio.py @@ -0,0 +1,147 @@ +import asyncio +from .sqlalchemy_spanner import SpannerDialect + +from sqlalchemy.connectors.asyncio import ( + AsyncAdapt_dbapi_connection, + AsyncAdapt_dbapi_cursor, + AsyncAdapt_dbapi_module, +) +from sqlalchemy.util.concurrency import await_only + + +class AsyncIODBAPISpannerCursor: + def __init__(self, sync_cursor): + self._sync_cursor = sync_cursor + + @property + def description(self): + return self._sync_cursor.description + + @property + def rowcount(self): + return self._sync_cursor.rowcount + + @property + def lastrowid(self): + return self._sync_cursor.lastrowid + + @property + def arraysize(self): + return self._sync_cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._sync_cursor.arraysize = value + + async def close(self): + await asyncio.to_thread(self._sync_cursor.close) + + async def execute(self, operation, parameters=None): + return await asyncio.to_thread(self._sync_cursor.execute, operation, parameters) + + async def executemany(self, operation, seq_of_parameters): + return await asyncio.to_thread( + self._sync_cursor.executemany, operation, seq_of_parameters + ) + + async def fetchone(self): + return await asyncio.to_thread(self._sync_cursor.fetchone) + + async def fetchmany(self, size=None): + return await asyncio.to_thread(self._sync_cursor.fetchmany, size) + + async def fetchall(self): + return await asyncio.to_thread(self._sync_cursor.fetchall) + + async def nextset(self): + if hasattr(self._sync_cursor, "nextset"): + return await asyncio.to_thread(self._sync_cursor.nextset) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + +class AsyncIODBAPISpannerConnection: + def __init__(self, sync_conn): + self._sync_conn = sync_conn + + async def commit(self): + await asyncio.to_thread(self._sync_conn.commit) + + async def rollback(self): + await asyncio.to_thread(self._sync_conn.rollback) + + async def close(self): + await asyncio.to_thread(self._sync_conn.close) + + def cursor(self): + return AsyncIODBAPISpannerCursor(self._sync_conn.cursor()) + + def __getattr__(self, name): + return getattr(self._sync_conn, name) + + +class AsyncAdapt_spanner_cursor(AsyncAdapt_dbapi_cursor): + @property + def connection(self): + return self._adapt_connection + + +class AsyncAdapt_spanner_connection(AsyncAdapt_dbapi_connection): + _cursor_cls = AsyncAdapt_spanner_cursor + + @property + def connection(self): + return self._connection._sync_conn + + def __getattr__(self, name): + return getattr(self._connection, name) + + +class AsyncAdapt_spanner_dbapi(AsyncAdapt_dbapi_module): + await_ = staticmethod(await_only) + + def __init__(self, spanner_dbapi): + self.spanner_dbapi = spanner_dbapi + for name in dir(spanner_dbapi): + if not name.startswith("__") and name != "connect": + setattr(self, name, getattr(spanner_dbapi, name)) + + def connect(self, *arg, **kw): + async_creator_fn = kw.pop("async_creator_fn", None) + if async_creator_fn: + connection = async_creator_fn(*arg, **kw) + else: + connection = self.spanner_dbapi.connect(*arg, **kw) + + return AsyncAdapt_spanner_connection( + self, AsyncIODBAPISpannerConnection(connection) + ) + + +class SpannerDialect_asyncio(SpannerDialect): + driver = "spanner_asyncio" + is_async = True + supports_statement_cache = True + + @classmethod + def import_dbapi(cls): + from google.cloud import spanner_dbapi + + return AsyncAdapt_spanner_dbapi(spanner_dbapi) + + @classmethod + def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def get_pool_class(cls, url): + from sqlalchemy.pool import AsyncAdaptedQueuePool + + return AsyncAdaptedQueuePool + + def get_driver_connection(self, connection): + return connection._connection diff --git a/packages/sqlalchemy-spanner/setup.py b/packages/sqlalchemy-spanner/setup.py index 9bf2183e2982..0bef0178e0e2 100644 --- a/packages/sqlalchemy-spanner/setup.py +++ b/packages/sqlalchemy-spanner/setup.py @@ -73,7 +73,8 @@ long_description=readme, entry_points={ "sqlalchemy.dialects": [ - "spanner.spanner = google.cloud.sqlalchemy_spanner:SpannerDialect" + "spanner.spanner = google.cloud.sqlalchemy_spanner:SpannerDialect", + "spanner.spanner_asyncio = google.cloud.sqlalchemy_spanner.sqlalchemy_spanner_asyncio:SpannerDialect_asyncio", ] }, install_requires=dependencies, diff --git a/packages/sqlalchemy-spanner/tests/unit/test_asyncio.py b/packages/sqlalchemy-spanner/tests/unit/test_asyncio.py new file mode 100644 index 000000000000..3a25678fe41a --- /dev/null +++ b/packages/sqlalchemy-spanner/tests/unit/test_asyncio.py @@ -0,0 +1,40 @@ +import os +import pytest +from sqlalchemy.ext.asyncio import create_async_engine +from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner_asyncio import ( + SpannerDialect_asyncio, +) +from sqlalchemy.testing.plugin.plugin_base import fixtures + +class AsyncioTest(fixtures.TestBase): + @pytest.mark.asyncio + async def test_async_engine_creation(self): + assert os.environ.get("SPANNER_EMULATOR_HOST") is not None + engine = create_async_engine("spanner+spanner_asyncio:///projects/p/instances/i/databases/d") + assert engine.dialect.is_async + assert isinstance(engine.dialect, SpannerDialect_asyncio) + + @pytest.mark.asyncio + async def test_async_connection(self, mocker): + from sqlalchemy import text + from sqlalchemy.pool import NullPool + assert os.environ.get("SPANNER_EMULATOR_HOST") is not None + engine = create_async_engine( + "spanner+spanner_asyncio:///projects/p/instances/i/databases/d", + poolclass=NullPool + ) + + # We need to mock the underlying sync connect + mock_connect = mocker.patch("google.cloud.spanner_dbapi.connect") + mock_sync_conn = mock_connect.return_value + mock_sync_cursor = mock_sync_conn.cursor.return_value + + # When we call execute, it should work through the async adapter + async with engine.connect() as conn: + assert conn.dialect == engine.dialect + # This will eventually call cursor.execute in a thread + await conn.execute(text("SELECT 1")) + + mock_connect.assert_called_once() + mock_sync_conn.close.assert_called_once() + mock_sync_cursor.execute.assert_called_once()