Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions packages/sqlalchemy-spanner/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,38 @@ Read
for row in connection.execute(select(["*"], from_obj=table)).fetchall():
print(row)

Async Support
~~~~~~~~~~~~~

The Spanner dialect also supports asyncio when used with SQLAlchemy 1.4 or 2.0.
To use the async client, use the ``spanner+spanner_asyncio`` prefix:

.. code:: python

spanner+spanner_asyncio:///projects/project-id/instances/instance-id/databases/database-id

Example usage with ``create_async_engine``:

.. code:: python

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text
import asyncio

async def main():
engine = create_async_engine(
"spanner+spanner_asyncio:///projects/project-id/instances/instance-id/databases/database-id"
)

async with engine.connect() as conn:
result = await conn.execute(text("SELECT 1"))
print(result.fetchone())

await engine.dispose()

if __name__ == "__main__":
asyncio.run(main())

Migration
---------

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import asyncio
from .sqlalchemy_spanner import SpannerDialect

from sqlalchemy.connectors.asyncio import (
AsyncAdapt_dbapi_connection,
AsyncAdapt_dbapi_cursor,
AsyncAdapt_dbapi_module,
)
from sqlalchemy.util.concurrency import await_only


class AsyncIODBAPISpannerCursor:
def __init__(self, sync_cursor):
self._sync_cursor = sync_cursor

@property
def description(self):
return self._sync_cursor.description

@property
def rowcount(self):
return self._sync_cursor.rowcount

@property
def lastrowid(self):
return self._sync_cursor.lastrowid

@property
def arraysize(self):
return self._sync_cursor.arraysize

@arraysize.setter
def arraysize(self, value):
self._sync_cursor.arraysize = value

async def close(self):
await asyncio.to_thread(self._sync_cursor.close)

async def execute(self, operation, parameters=None):
return await asyncio.to_thread(self._sync_cursor.execute, operation, parameters)

async def executemany(self, operation, seq_of_parameters):
return await asyncio.to_thread(
self._sync_cursor.executemany, operation, seq_of_parameters
)

async def fetchone(self):
return await asyncio.to_thread(self._sync_cursor.fetchone)

async def fetchmany(self, size=None):
return await asyncio.to_thread(self._sync_cursor.fetchmany, size)

async def fetchall(self):
return await asyncio.to_thread(self._sync_cursor.fetchall)

async def nextset(self):
if hasattr(self._sync_cursor, "nextset"):
return await asyncio.to_thread(self._sync_cursor.nextset)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()


class AsyncIODBAPISpannerConnection:
def __init__(self, sync_conn):
self._sync_conn = sync_conn

async def commit(self):
await asyncio.to_thread(self._sync_conn.commit)

async def rollback(self):
await asyncio.to_thread(self._sync_conn.rollback)

async def close(self):
await asyncio.to_thread(self._sync_conn.close)

def cursor(self):
return AsyncIODBAPISpannerCursor(self._sync_conn.cursor())

def __getattr__(self, name):
return getattr(self._sync_conn, name)


class AsyncAdapt_spanner_cursor(AsyncAdapt_dbapi_cursor):
@property
def connection(self):
return self._adapt_connection


class AsyncAdapt_spanner_connection(AsyncAdapt_dbapi_connection):
_cursor_cls = AsyncAdapt_spanner_cursor

@property
def connection(self):
return self._connection._sync_conn

def __getattr__(self, name):
return getattr(self._connection, name)


class AsyncAdapt_spanner_dbapi(AsyncAdapt_dbapi_module):
await_ = staticmethod(await_only)

def __init__(self, spanner_dbapi):
self.spanner_dbapi = spanner_dbapi
for name in dir(spanner_dbapi):
if not name.startswith("__") and name != "connect":
setattr(self, name, getattr(spanner_dbapi, name))

def connect(self, *arg, **kw):
async_creator_fn = kw.pop("async_creator_fn", None)
if async_creator_fn:
connection = async_creator_fn(*arg, **kw)
else:
connection = self.spanner_dbapi.connect(*arg, **kw)

return AsyncAdapt_spanner_connection(
self, AsyncIODBAPISpannerConnection(connection)
)


class SpannerDialect_asyncio(SpannerDialect):
driver = "spanner_asyncio"
is_async = True
supports_statement_cache = True

@classmethod
def import_dbapi(cls):
from google.cloud import spanner_dbapi

return AsyncAdapt_spanner_dbapi(spanner_dbapi)

@classmethod
def dbapi(cls):
return cls.import_dbapi()

@classmethod
def get_pool_class(cls, url):
from sqlalchemy.pool import AsyncAdaptedQueuePool

return AsyncAdaptedQueuePool

def get_driver_connection(self, connection):
return connection._connection
3 changes: 2 additions & 1 deletion packages/sqlalchemy-spanner/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
long_description=readme,
entry_points={
"sqlalchemy.dialects": [
"spanner.spanner = google.cloud.sqlalchemy_spanner:SpannerDialect"
"spanner.spanner = google.cloud.sqlalchemy_spanner:SpannerDialect",
"spanner.spanner_asyncio = google.cloud.sqlalchemy_spanner.sqlalchemy_spanner_asyncio:SpannerDialect_asyncio",
]
},
install_requires=dependencies,
Expand Down
40 changes: 40 additions & 0 deletions packages/sqlalchemy-spanner/tests/unit/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner_asyncio import (
SpannerDialect_asyncio,
)
from sqlalchemy.testing.plugin.plugin_base import fixtures

class AsyncioTest(fixtures.TestBase):
@pytest.mark.asyncio
async def test_async_engine_creation(self):
assert os.environ.get("SPANNER_EMULATOR_HOST") is not None
engine = create_async_engine("spanner+spanner_asyncio:///projects/p/instances/i/databases/d")
assert engine.dialect.is_async
assert isinstance(engine.dialect, SpannerDialect_asyncio)

@pytest.mark.asyncio
async def test_async_connection(self, mocker):
from sqlalchemy import text
from sqlalchemy.pool import NullPool
assert os.environ.get("SPANNER_EMULATOR_HOST") is not None
engine = create_async_engine(
"spanner+spanner_asyncio:///projects/p/instances/i/databases/d",
poolclass=NullPool
)

# We need to mock the underlying sync connect
mock_connect = mocker.patch("google.cloud.spanner_dbapi.connect")
mock_sync_conn = mock_connect.return_value
mock_sync_cursor = mock_sync_conn.cursor.return_value

# When we call execute, it should work through the async adapter
async with engine.connect() as conn:
assert conn.dialect == engine.dialect
# This will eventually call cursor.execute in a thread
await conn.execute(text("SELECT 1"))

mock_connect.assert_called_once()
mock_sync_conn.close.assert_called_once()
mock_sync_cursor.execute.assert_called_once()
Loading