Skip to content

Commit 5ed3694

Browse files
committed
Unify connection error handling between async and sync
Error handling (especially logging) in the code establishing TCP & TLS connections differed too much making reasoning about it more complicated than it needs to be.
1 parent fc4811d commit 5ed3694

File tree

1 file changed

+69
-74
lines changed

1 file changed

+69
-74
lines changed

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 69 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,12 @@ def kill(self):
205205

206206
@classmethod
207207
async def _connect_secure(
208-
cls, resolved_address, timeout, deadline, keep_alive, ssl_context
208+
cls,
209+
resolved_address: Address,
210+
timeout: float | None,
211+
deadline: Deadline,
212+
keep_alive: bool,
213+
ssl_context: SSLContext | None,
209214
) -> t.Self:
210215
"""
211216
Connect to the address and return the socket.
@@ -220,28 +225,38 @@ async def _connect_secure(
220225
"""
221226
loop = asyncio.get_event_loop()
222227
s = None
223-
local_port = 0
224228

225229
# TODO: tomorrow me: fix this mess
226230
try:
227-
if len(resolved_address) == 2:
228-
s = socket(AF_INET)
229-
elif len(resolved_address) == 4:
230-
s = socket(AF_INET6)
231-
else:
232-
raise ValueError(f"Unsupported address {resolved_address!r}")
233-
s.setblocking(False) # asyncio + blocking = no-no!
234-
log.debug("[#0000] C: <OPEN> %s", resolved_address)
235-
await wait_for(loop.sock_connect(s, resolved_address), timeout)
236-
local_port = s.getsockname()[1]
231+
try:
232+
if len(resolved_address) == 2:
233+
s = socket(AF_INET)
234+
elif len(resolved_address) == 4:
235+
s = socket(AF_INET6)
236+
else:
237+
raise ValueError(
238+
f"Unsupported address {resolved_address!r}"
239+
)
240+
s.setblocking(False) # asyncio + blocking = no-no!
241+
log.debug("[#0000] C: <OPEN> %s", resolved_address)
242+
await wait_for(loop.sock_connect(s, resolved_address), timeout)
243+
local_port = s.getsockname()[1]
237244

238-
keep_alive = 1 if keep_alive else 0
239-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
245+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
246+
except asyncio.TimeoutError:
247+
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
248+
raise ServiceUnavailable(
249+
"Timed out trying to establish connection to "
250+
f"{resolved_address!r}"
251+
) from None
252+
except asyncio.CancelledError:
253+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
254+
raise
240255

241256
ssl_kwargs: dict[str, t.Any] = {}
242257

258+
hostname = resolved_address._host_name or None
243259
if ssl_context is not None:
244-
hostname = resolved_address._host_name or None
245260
sni_host = hostname if HAS_SNI and hostname else None
246261
ssl_kwargs.update(
247262
ssl=ssl_context,
@@ -255,9 +270,26 @@ async def _connect_secure(
255270
loop=loop,
256271
)
257272
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
258-
transport, _ = await loop.create_connection(
259-
lambda: protocol, sock=s, **ssl_kwargs
260-
)
273+
274+
try:
275+
transport, _ = await loop.create_connection(
276+
lambda: protocol, sock=s, **ssl_kwargs
277+
)
278+
279+
except (OSError, SSLError, CertificateError) as error:
280+
log.debug(
281+
"[#0000] S: <SECURE FAILURE> %s: %r",
282+
resolved_address,
283+
error,
284+
)
285+
raise BoltSecurityError(
286+
message="Failed to establish encrypted connection.",
287+
address=(hostname, local_port),
288+
) from error
289+
except asyncio.CancelledError:
290+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
291+
raise
292+
261293
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
262294

263295
if ssl_context is not None:
@@ -270,39 +302,8 @@ async def _connect_secure(
270302
raise BoltProtocolError(
271303
"When using an encrypted socket, the server should "
272304
"always provide a certificate",
273-
address=(resolved_address._host_name, local_port),
305+
address=(hostname, local_port),
274306
)
275-
276-
return cls(reader, protocol, writer)
277-
278-
except asyncio.TimeoutError:
279-
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
280-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
281-
if s:
282-
cls._kill_raw_socket(s)
283-
raise ServiceUnavailable(
284-
"Timed out trying to establish connection to "
285-
f"{resolved_address!r}"
286-
) from None
287-
except asyncio.CancelledError:
288-
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
289-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
290-
if s:
291-
cls._kill_raw_socket(s)
292-
raise
293-
except (SSLError, CertificateError) as error:
294-
log.debug(
295-
"[#0000] S: <SECURE FAILURE> %s: %s",
296-
resolved_address,
297-
error,
298-
)
299-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
300-
if s:
301-
cls._kill_raw_socket(s)
302-
raise BoltSecurityError(
303-
message="Failed to establish encrypted connection.",
304-
address=(resolved_address._host_name, local_port),
305-
) from error
306307
except Exception as error:
307308
log.debug(
308309
"[#0000] S: <ERROR> %s %s",
@@ -319,6 +320,8 @@ async def _connect_secure(
319320
) from error
320321
raise
321322

323+
return cls(reader, protocol, writer)
324+
322325
@abc.abstractmethod
323326
async def _handshake(
324327
self,
@@ -355,7 +358,7 @@ def _kill_raw_socket(cls, socket_):
355358
socket_.close()
356359

357360

358-
class BoltSocketBase:
361+
class BoltSocketBase(abc.ABC):
359362
Bolt: t.Final[type[Bolt]] = None # type: ignore[assignment]
360363

361364
def __init__(self, socket_: socket):
@@ -468,8 +471,13 @@ def kill(self):
468471

469472
@classmethod
470473
def _connect_secure(
471-
cls, resolved_address, timeout, deadline, keep_alive, ssl_context
472-
):
474+
cls,
475+
resolved_address: Address,
476+
timeout: float | None,
477+
deadline: Deadline,
478+
keep_alive: bool,
479+
ssl_context: SSLContext | None,
480+
) -> t.Self:
473481
"""
474482
Connect to the address and return the socket.
475483
@@ -503,26 +511,14 @@ def _connect_secure(
503511
log.debug("[#0000] C: <OPEN> %s", resolved_address)
504512
s.connect(resolved_address)
505513
s.settimeout(t)
506-
keep_alive = 1 if keep_alive else 0
507-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
514+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
508515
except TimeoutError:
509516
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
510-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
511-
if s:
512-
cls._kill_raw_socket(s)
513517
raise ServiceUnavailable(
514518
"Timed out trying to establish connection to "
515519
f"{resolved_address!r}"
516520
) from None
517521
except Exception as error:
518-
log.debug(
519-
"[#0000] S: <ERROR> %s %s",
520-
type(error).__name__,
521-
" ".join(map(repr, error.args)),
522-
)
523-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
524-
if s:
525-
cls._kill_raw_socket(s)
526522
if isinstance(error, OSError):
527523
raise ServiceUnavailable(
528524
"Failed to establish connection to "
@@ -544,13 +540,10 @@ def _connect_secure(
544540
s.settimeout(t)
545541
except (OSError, SSLError, CertificateError) as cause:
546542
log.debug(
547-
"[#0000] S: <SECURE FAILURE> %s: %s",
543+
"[#0000] S: <SECURE FAILURE> %s: %r",
548544
resolved_address,
549545
cause,
550546
)
551-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
552-
if s:
553-
cls._kill_raw_socket(s)
554547
raise BoltSecurityError(
555548
message="Failed to establish encrypted connection.",
556549
address=(hostname, local_port),
@@ -564,15 +557,17 @@ def _connect_secure(
564557
"[#0000] S: <SECURE FAILURE> %s: no certificate",
565558
resolved_address,
566559
)
567-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
568-
if s:
569-
cls._kill_raw_socket(s)
570560
raise BoltProtocolError(
571561
"When using an encrypted socket, the server should"
572562
"always provide a certificate",
573563
address=(hostname, local_port),
574564
)
575-
except Exception:
565+
except Exception as error:
566+
log.debug(
567+
"[#0000] S: <ERROR> %s %s",
568+
type(error).__name__,
569+
" ".join(map(repr, error.args)),
570+
)
576571
if s is not None:
577572
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
578573
cls._kill_raw_socket(s)

0 commit comments

Comments
 (0)