Skip to content

Commit 4f4bdd4

Browse files
committed
Unify connection error handling between async and sync
Error handling (especially logging) in the code establishing TCP & TLS connections differed too much making reasoning about it more complicated than it needs to be.
1 parent 8c7c19b commit 4f4bdd4

File tree

3 files changed

+89
-76
lines changed

3 files changed

+89
-76
lines changed

src/neo4j/_async/io/_bolt_socket.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,11 @@ async def connect(
327327
s = None
328328
try:
329329
s = await cls._connect_secure(
330-
resolved_address, tcp_timeout, keep_alive, ssl_context
330+
resolved_address,
331+
tcp_timeout,
332+
deadline,
333+
keep_alive,
334+
ssl_context,
331335
)
332336
agreed_version = await s._handshake(resolved_address, deadline)
333337
return s, agreed_version

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 79 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -205,54 +205,91 @@ def kill(self):
205205

206206
@classmethod
207207
async def _connect_secure(
208-
cls, resolved_address, timeout, keep_alive, ssl_context
208+
cls,
209+
resolved_address: Address,
210+
timeout: float | None,
211+
deadline: Deadline,
212+
keep_alive: bool,
213+
ssl_context: SSLContext | None,
209214
) -> t.Self:
210215
"""
211216
Connect to the address and return the socket.
212217
213218
:param resolved_address:
214219
:param timeout: seconds
220+
:param deadline: deadline for the whole operation
215221
:param keep_alive: True or False
216222
:param ssl_context: SSLContext or None
217223
218224
:returns: AsyncBoltSocket object
219225
"""
220226
loop = asyncio.get_event_loop()
221227
s = None
222-
local_port = 0
223228

224229
# TODO: tomorrow me: fix this mess
225230
try:
226-
if len(resolved_address) == 2:
227-
s = socket(AF_INET)
228-
elif len(resolved_address) == 4:
229-
s = socket(AF_INET6)
230-
else:
231-
raise ValueError(f"Unsupported address {resolved_address!r}")
232-
s.setblocking(False) # asyncio + blocking = no-no!
233-
log.debug("[#0000] C: <OPEN> %s", resolved_address)
234-
await wait_for(loop.sock_connect(s, resolved_address), timeout)
235-
local_port = s.getsockname()[1]
231+
try:
232+
if len(resolved_address) == 2:
233+
s = socket(AF_INET)
234+
elif len(resolved_address) == 4:
235+
s = socket(AF_INET6)
236+
else:
237+
raise ValueError(
238+
f"Unsupported address {resolved_address!r}"
239+
)
240+
s.setblocking(False) # asyncio + blocking = no-no!
241+
log.debug("[#0000] C: <OPEN> %s", resolved_address)
242+
await wait_for(loop.sock_connect(s, resolved_address), timeout)
243+
local_port = s.getsockname()[1]
236244

237-
keep_alive = 1 if keep_alive else 0
238-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
245+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
246+
except asyncio.TimeoutError:
247+
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
248+
raise ServiceUnavailable(
249+
"Timed out trying to establish connection to "
250+
f"{resolved_address!r}"
251+
) from None
252+
except asyncio.CancelledError:
253+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
254+
raise
239255

240256
ssl_kwargs: dict[str, t.Any] = {}
241257

258+
hostname = resolved_address._host_name or None
242259
if ssl_context is not None:
243-
hostname = resolved_address._host_name or None
244260
sni_host = hostname if HAS_SNI and hostname else None
245-
ssl_kwargs.update(ssl=ssl_context, server_hostname=sni_host)
261+
ssl_kwargs.update(
262+
ssl=ssl_context,
263+
server_hostname=sni_host,
264+
ssl_handshake_timeout=deadline.to_timeout(),
265+
)
246266
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
247267

248268
reader = asyncio.StreamReader(
249269
limit=2**16, # 64 KiB,
250270
loop=loop,
251271
)
252272
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
253-
transport, _ = await loop.create_connection(
254-
lambda: protocol, sock=s, **ssl_kwargs
255-
)
273+
274+
try:
275+
transport, _ = await loop.create_connection(
276+
lambda: protocol, sock=s, **ssl_kwargs
277+
)
278+
279+
except (OSError, SSLError, CertificateError) as error:
280+
log.debug(
281+
"[#0000] S: <SECURE FAILURE> %s: %r",
282+
resolved_address,
283+
error,
284+
)
285+
raise BoltSecurityError(
286+
message="Failed to establish encrypted connection.",
287+
address=(hostname, local_port),
288+
) from error
289+
except asyncio.CancelledError:
290+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
291+
raise
292+
256293
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
257294

258295
if ssl_context is not None:
@@ -265,39 +302,8 @@ async def _connect_secure(
265302
raise BoltProtocolError(
266303
"When using an encrypted socket, the server should "
267304
"always provide a certificate",
268-
address=(resolved_address._host_name, local_port),
305+
address=(hostname, local_port),
269306
)
270-
271-
return cls(reader, protocol, writer)
272-
273-
except asyncio.TimeoutError:
274-
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
275-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
276-
if s:
277-
cls._kill_raw_socket(s)
278-
raise ServiceUnavailable(
279-
"Timed out trying to establish connection to "
280-
f"{resolved_address!r}"
281-
) from None
282-
except asyncio.CancelledError:
283-
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
284-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
285-
if s:
286-
cls._kill_raw_socket(s)
287-
raise
288-
except (SSLError, CertificateError) as error:
289-
log.debug(
290-
"[#0000] S: <SECURE FAILURE> %s: %s",
291-
resolved_address,
292-
error,
293-
)
294-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
295-
if s:
296-
cls._kill_raw_socket(s)
297-
raise BoltSecurityError(
298-
message="Failed to establish encrypted connection.",
299-
address=(resolved_address._host_name, local_port),
300-
) from error
301307
except Exception as error:
302308
log.debug(
303309
"[#0000] S: <ERROR> %s %s",
@@ -314,6 +320,8 @@ async def _connect_secure(
314320
) from error
315321
raise
316322

323+
return cls(reader, protocol, writer)
324+
317325
@abc.abstractmethod
318326
async def _handshake(
319327
self,
@@ -463,13 +471,19 @@ def kill(self):
463471

464472
@classmethod
465473
def _connect_secure(
466-
cls, resolved_address, timeout, keep_alive, ssl_context
467-
):
474+
cls,
475+
resolved_address: Address,
476+
timeout: float | None,
477+
deadline: Deadline,
478+
keep_alive: bool,
479+
ssl_context: SSLContext | None,
480+
) -> t.Self:
468481
"""
469482
Connect to the address and return the socket.
470483
471484
:param resolved_address:
472485
:param timeout: seconds
486+
:param deadline: deadline for the whole operation
473487
:param keep_alive: True or False
474488
:returns: socket object
475489
"""
@@ -497,26 +511,14 @@ def _connect_secure(
497511
log.debug("[#0000] C: <OPEN> %s", resolved_address)
498512
s.connect(resolved_address)
499513
s.settimeout(t)
500-
keep_alive = 1 if keep_alive else 0
501-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
514+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
502515
except TimeoutError:
503516
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
504-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
505-
if s:
506-
cls._kill_raw_socket(s)
507517
raise ServiceUnavailable(
508518
"Timed out trying to establish connection to "
509519
f"{resolved_address!r}"
510520
) from None
511521
except Exception as error:
512-
log.debug(
513-
"[#0000] S: <ERROR> %s %s",
514-
type(error).__name__,
515-
" ".join(map(repr, error.args)),
516-
)
517-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
518-
if s:
519-
cls._kill_raw_socket(s)
520522
if isinstance(error, OSError):
521523
raise ServiceUnavailable(
522524
"Failed to establish connection to "
@@ -531,16 +533,17 @@ def _connect_secure(
531533
sni_host = hostname if HAS_SNI and hostname else None
532534
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
533535
try:
536+
t = s.gettimeout()
537+
if timeout:
538+
s.settimeout(deadline.to_timeout())
534539
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
540+
s.settimeout(t)
535541
except (OSError, SSLError, CertificateError) as cause:
536542
log.debug(
537-
"[#0000] S: <SECURE FAILURE> %s: %s",
543+
"[#0000] S: <SECURE FAILURE> %s: %r",
538544
resolved_address,
539545
cause,
540546
)
541-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
542-
if s:
543-
cls._kill_raw_socket(s)
544547
raise BoltSecurityError(
545548
message="Failed to establish encrypted connection.",
546549
address=(hostname, local_port),
@@ -554,15 +557,17 @@ def _connect_secure(
554557
"[#0000] S: <SECURE FAILURE> %s: no certificate",
555558
resolved_address,
556559
)
557-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
558-
if s:
559-
cls._kill_raw_socket(s)
560560
raise BoltProtocolError(
561561
"When using an encrypted socket, the server should"
562562
"always provide a certificate",
563563
address=(hostname, local_port),
564564
)
565-
except Exception:
565+
except Exception as error:
566+
log.debug(
567+
"[#0000] S: <ERROR> %s %s",
568+
type(error).__name__,
569+
" ".join(map(repr, error.args)),
570+
)
566571
if s is not None:
567572
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
568573
cls._kill_raw_socket(s)

src/neo4j/_sync/io/_bolt_socket.py

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)