Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 69 additions & 74 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def kill(self):

@classmethod
async def _connect_secure(
cls, resolved_address, timeout, deadline, keep_alive, ssl_context
cls,
resolved_address: Address,
timeout: float | None,
deadline: Deadline,
keep_alive: bool,
ssl_context: SSLContext | None,
) -> t.Self:
"""
Connect to the address and return the socket.
Expand All @@ -220,28 +225,38 @@ async def _connect_secure(
"""
loop = asyncio.get_event_loop()
s = None
local_port = 0

# TODO: tomorrow me: fix this mess
try:
if len(resolved_address) == 2:
s = socket(AF_INET)
elif len(resolved_address) == 4:
s = socket(AF_INET6)
else:
raise ValueError(f"Unsupported address {resolved_address!r}")
s.setblocking(False) # asyncio + blocking = no-no!
log.debug("[#0000] C: <OPEN> %s", resolved_address)
await wait_for(loop.sock_connect(s, resolved_address), timeout)
local_port = s.getsockname()[1]
try:
if len(resolved_address) == 2:
s = socket(AF_INET)
elif len(resolved_address) == 4:
s = socket(AF_INET6)
else:
raise ValueError(
f"Unsupported address {resolved_address!r}"
)
s.setblocking(False) # asyncio + blocking = no-no!
log.debug("[#0000] C: <OPEN> %s", resolved_address)
await wait_for(loop.sock_connect(s, resolved_address), timeout)
local_port = s.getsockname()[1]

keep_alive = 1 if keep_alive else 0
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
except asyncio.TimeoutError:
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
raise ServiceUnavailable(
"Timed out trying to establish connection to "
f"{resolved_address!r}"
) from None
except asyncio.CancelledError:
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
raise

ssl_kwargs: dict[str, t.Any] = {}

hostname = resolved_address._host_name or None
if ssl_context is not None:
hostname = resolved_address._host_name or None
sni_host = hostname if HAS_SNI and hostname else None
ssl_kwargs.update(
ssl=ssl_context,
Expand All @@ -255,9 +270,26 @@ async def _connect_secure(
loop=loop,
)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(
lambda: protocol, sock=s, **ssl_kwargs
)

try:
transport, _ = await loop.create_connection(
lambda: protocol, sock=s, **ssl_kwargs
)

except (OSError, SSLError, CertificateError) as error:
log.debug(
"[#0000] S: <SECURE FAILURE> %s: %r",
resolved_address,
error,
)
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(hostname, local_port),
) from error
except asyncio.CancelledError:
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
raise

writer = asyncio.StreamWriter(transport, protocol, reader, loop)

if ssl_context is not None:
Expand All @@ -270,39 +302,8 @@ async def _connect_secure(
raise BoltProtocolError(
"When using an encrypted socket, the server should "
"always provide a certificate",
address=(resolved_address._host_name, local_port),
address=(hostname, local_port),
)

return cls(reader, protocol, writer)

except asyncio.TimeoutError:
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise ServiceUnavailable(
"Timed out trying to establish connection to "
f"{resolved_address!r}"
) from None
except asyncio.CancelledError:
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise
except (SSLError, CertificateError) as error:
log.debug(
"[#0000] S: <SECURE FAILURE> %s: %s",
resolved_address,
error,
)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(resolved_address._host_name, local_port),
) from error
except Exception as error:
log.debug(
"[#0000] S: <ERROR> %s %s",
Expand All @@ -319,6 +320,8 @@ async def _connect_secure(
) from error
raise

return cls(reader, protocol, writer)

@abc.abstractmethod
async def _handshake(
self,
Expand Down Expand Up @@ -355,7 +358,7 @@ def _kill_raw_socket(cls, socket_):
socket_.close()


class BoltSocketBase:
class BoltSocketBase(abc.ABC):
Bolt: t.Final[type[Bolt]] = None # type: ignore[assignment]

def __init__(self, socket_: socket):
Expand Down Expand Up @@ -468,8 +471,13 @@ def kill(self):

@classmethod
def _connect_secure(
cls, resolved_address, timeout, deadline, keep_alive, ssl_context
):
cls,
resolved_address: Address,
timeout: float | None,
deadline: Deadline,
keep_alive: bool,
ssl_context: SSLContext | None,
) -> t.Self:
"""
Connect to the address and return the socket.

Expand Down Expand Up @@ -503,26 +511,14 @@ def _connect_secure(
log.debug("[#0000] C: <OPEN> %s", resolved_address)
s.connect(resolved_address)
s.settimeout(t)
keep_alive = 1 if keep_alive else 0
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
except TimeoutError:
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise ServiceUnavailable(
"Timed out trying to establish connection to "
f"{resolved_address!r}"
) from None
except Exception as error:
log.debug(
"[#0000] S: <ERROR> %s %s",
type(error).__name__,
" ".join(map(repr, error.args)),
)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
if isinstance(error, OSError):
raise ServiceUnavailable(
"Failed to establish connection to "
Expand All @@ -544,13 +540,10 @@ def _connect_secure(
s.settimeout(t)
except (OSError, SSLError, CertificateError) as cause:
log.debug(
"[#0000] S: <SECURE FAILURE> %s: %s",
"[#0000] S: <SECURE FAILURE> %s: %r",
resolved_address,
cause,
)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(hostname, local_port),
Expand All @@ -564,15 +557,17 @@ def _connect_secure(
"[#0000] S: <SECURE FAILURE> %s: no certificate",
resolved_address,
)
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
if s:
cls._kill_raw_socket(s)
raise BoltProtocolError(
"When using an encrypted socket, the server should"
"always provide a certificate",
address=(hostname, local_port),
)
except Exception:
except Exception as error:
log.debug(
"[#0000] S: <ERROR> %s %s",
type(error).__name__,
" ".join(map(repr, error.args)),
)
if s is not None:
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
cls._kill_raw_socket(s)
Expand Down