diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 1a52e55f..7e8793fe 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -205,7 +205,12 @@ def kill(self): @classmethod async def _connect_secure( - cls, resolved_address, timeout, deadline, keep_alive, ssl_context + cls, + resolved_address: Address, + timeout: float | None, + deadline: Deadline, + keep_alive: bool, + ssl_context: SSLContext | None, ) -> t.Self: """ Connect to the address and return the socket. @@ -220,28 +225,38 @@ async def _connect_secure( """ loop = asyncio.get_event_loop() s = None - local_port = 0 # TODO: tomorrow me: fix this mess try: - if len(resolved_address) == 2: - s = socket(AF_INET) - elif len(resolved_address) == 4: - s = socket(AF_INET6) - else: - raise ValueError(f"Unsupported address {resolved_address!r}") - s.setblocking(False) # asyncio + blocking = no-no! - log.debug("[#0000] C: %s", resolved_address) - await wait_for(loop.sock_connect(s, resolved_address), timeout) - local_port = s.getsockname()[1] + try: + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + f"Unsupported address {resolved_address!r}" + ) + s.setblocking(False) # asyncio + blocking = no-no! + log.debug("[#0000] C: %s", resolved_address) + await wait_for(loop.sock_connect(s, resolved_address), timeout) + local_port = s.getsockname()[1] - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0) + except asyncio.TimeoutError: + log.debug("[#0000] S: %s", resolved_address) + raise ServiceUnavailable( + "Timed out trying to establish connection to " + f"{resolved_address!r}" + ) from None + except asyncio.CancelledError: + log.debug("[#0000] S: %s", resolved_address) + raise ssl_kwargs: dict[str, t.Any] = {} + hostname = resolved_address._host_name or None if ssl_context is not None: - hostname = resolved_address._host_name or None sni_host = hostname if HAS_SNI and hostname else None ssl_kwargs.update( ssl=ssl_context, @@ -255,9 +270,26 @@ async def _connect_secure( loop=loop, ) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) - transport, _ = await loop.create_connection( - lambda: protocol, sock=s, **ssl_kwargs - ) + + try: + transport, _ = await loop.create_connection( + lambda: protocol, sock=s, **ssl_kwargs + ) + + except (OSError, SSLError, CertificateError) as error: + log.debug( + "[#0000] S: %s: %r", + resolved_address, + error, + ) + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(hostname, local_port), + ) from error + except asyncio.CancelledError: + log.debug("[#0000] S: %s", resolved_address) + raise + writer = asyncio.StreamWriter(transport, protocol, reader, loop) if ssl_context is not None: @@ -270,39 +302,8 @@ async def _connect_secure( raise BoltProtocolError( "When using an encrypted socket, the server should " "always provide a certificate", - address=(resolved_address._host_name, local_port), + address=(hostname, local_port), ) - - return cls(reader, protocol, writer) - - except asyncio.TimeoutError: - log.debug("[#0000] S: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) - raise ServiceUnavailable( - "Timed out trying to establish connection to " - f"{resolved_address!r}" - ) from None - except asyncio.CancelledError: - log.debug("[#0000] S: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) - raise - except (SSLError, CertificateError) as error: - log.debug( - "[#0000] S: %s: %s", - resolved_address, - error, - ) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(resolved_address._host_name, local_port), - ) from error except Exception as error: log.debug( "[#0000] S: %s %s", @@ -319,6 +320,8 @@ async def _connect_secure( ) from error raise + return cls(reader, protocol, writer) + @abc.abstractmethod async def _handshake( self, @@ -355,7 +358,7 @@ def _kill_raw_socket(cls, socket_): socket_.close() -class BoltSocketBase: +class BoltSocketBase(abc.ABC): Bolt: t.Final[type[Bolt]] = None # type: ignore[assignment] def __init__(self, socket_: socket): @@ -468,8 +471,13 @@ def kill(self): @classmethod def _connect_secure( - cls, resolved_address, timeout, deadline, keep_alive, ssl_context - ): + cls, + resolved_address: Address, + timeout: float | None, + deadline: Deadline, + keep_alive: bool, + ssl_context: SSLContext | None, + ) -> t.Self: """ Connect to the address and return the socket. @@ -503,26 +511,14 @@ def _connect_secure( log.debug("[#0000] C: %s", resolved_address) s.connect(resolved_address) s.settimeout(t) - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0) except TimeoutError: log.debug("[#0000] S: %s", resolved_address) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) raise ServiceUnavailable( "Timed out trying to establish connection to " f"{resolved_address!r}" ) from None except Exception as error: - log.debug( - "[#0000] S: %s %s", - type(error).__name__, - " ".join(map(repr, error.args)), - ) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) if isinstance(error, OSError): raise ServiceUnavailable( "Failed to establish connection to " @@ -544,13 +540,10 @@ def _connect_secure( s.settimeout(t) except (OSError, SSLError, CertificateError) as cause: log.debug( - "[#0000] S: %s: %s", + "[#0000] S: %s: %r", resolved_address, cause, ) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) raise BoltSecurityError( message="Failed to establish encrypted connection.", address=(hostname, local_port), @@ -564,15 +557,17 @@ def _connect_secure( "[#0000] S: %s: no certificate", resolved_address, ) - log.debug("[#0000] C: %s", resolved_address) - if s: - cls._kill_raw_socket(s) raise BoltProtocolError( "When using an encrypted socket, the server should" "always provide a certificate", address=(hostname, local_port), ) - except Exception: + except Exception as error: + log.debug( + "[#0000] S: %s %s", + type(error).__name__, + " ".join(map(repr, error.args)), + ) if s is not None: log.debug("[#0000] C: %s", resolved_address) cls._kill_raw_socket(s)