Skip to content

Commit 6aadcf1

Browse files
authored
agent: lock rows for concurrent queries in PostgreSQL (#1688)
* agent: lock rows for concurrent queries in PostgreSQL * fix race conditions in workers * refactor
1 parent 07604a1 commit 6aadcf1

8 files changed

Lines changed: 296 additions & 136 deletions

File tree

src/Simplex/FileTransfer/Agent.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ runXFTPRcvWorker c srv Worker {doWork} = do
223223
agentXFTPDownloadChunk c userId digest replica chunkSpec
224224
liftIO $ waitUntilForeground c
225225
(entityId, complete, progress) <- withStore c $ \db -> runExceptT $ do
226+
liftIO $ lockRcvFileForUpdate db rcvFileId
226227
liftIO $ updateRcvFileChunkReceived db (rcvChunkReplicaId replica) rcvChunkId relChunkPath
227228
RcvFile {size = FileSize currentSize, chunks, redirect} <- ExceptT $ getRcvFile db rcvFileId
228229
let rcvd = receivedSize chunks
@@ -413,6 +414,7 @@ runXFTPSndPrepareWorker c Worker {doWork} = do
413414
withStore' c $ \db -> updateSndFileStatus db sndFileId SFSEncrypting
414415
(digest, chunkSpecsDigests) <- encryptFileForUpload sndFile fsEncPath
415416
withStore c $ \db -> do
417+
lockSndFileForUpdate db sndFileId
416418
updateSndFileEncrypted db sndFileId digest chunkSpecsDigests
417419
getSndFile db sndFileId
418420
else pure sndFile
@@ -530,6 +532,7 @@ runXFTPSndWorker c srv Worker {doWork} = do
530532
agentXFTPUploadChunk c userId chunkDigest replica' chunkSpec'
531533
liftIO $ waitUntilForeground c
532534
sf@SndFile {sndFileEntityId, prefixPath, chunks} <- withStore c $ \db -> do
535+
lockSndFileForUpdate db sndFileId
533536
updateSndChunkReplicaStatus db sndChunkReplicaId SFRSUploaded
534537
getSndFile db sndFileId
535538
let uploaded = uploadedSize chunks

src/Simplex/Messaging/Agent.hs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,8 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
11451145
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, lastExternalSndId = 0, deleted = False, ratchetSyncState = RSOk, pqSupport}
11461146
case sq_ of
11471147
Just sq@SndQueue {e2ePubKey = Just _k} -> do
1148-
e2eSndParams <- withStore c $ \db ->
1148+
e2eSndParams <- withStore c $ \db -> do
1149+
lockConnForUpdate db connId
11491150
getSndRatchet db connId v >>= \case
11501151
Right r -> pure $ Right $ snd r
11511152
Left e -> do
@@ -1159,6 +1160,7 @@ startJoinInvitation c userId connId sq_ enableNtfs cReqUri pqSup =
11591160
sndKey_ = snd <$> invLink_
11601161
(q, _) <- lift $ newSndQueue userId "" qInfo sndKey_
11611162
withStore c $ \db -> runExceptT $ do
1163+
liftIO $ lockConnForUpdate db connId
11621164
e2eSndParams <- createRatchet_ db g maxSupported pqSupport e2eRcvParams
11631165
sq' <- maybe (ExceptT $ updateNewConnSnd db connId q) pure sq_
11641166
pure (cData, sq', e2eSndParams, lnkId_)
@@ -1237,7 +1239,8 @@ joinConnSrv c nm userId connId enableNtfs cReqUri@CRContactUri {} cInfo pqSup su
12371239
AgentConfig {smpClientVRange = vr, smpAgentVRange, e2eEncryptVRange = e2eVR} <- asks config
12381240
let qUri = SMPQueueUri vr $ (rcvSMPQueueAddress rq) {queueMode = Just QMMessaging}
12391241
crData = ConnReqUriData SSSimplex smpAgentVRange [qUri] Nothing
1240-
e2eRcvParams <- withStore' c $ \db ->
1242+
e2eRcvParams <- withStore' c $ \db -> do
1243+
lockConnForUpdate db connId
12411244
getRatchetX3dhKeys db connId >>= \case
12421245
Right keys -> pure $ CR.mkRcvE2ERatchetParams (maxVersion e2eVR) keys
12431246
Left e -> do
@@ -1957,7 +1960,7 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
19571960
withRetryLock2 ri' qLock $ \riState loop -> do
19581961
liftIO $ waitWhileSuspended c
19591962
liftIO $ waitForUserNetwork c
1960-
resp <- tryError $ case msgType of
1963+
resp <- tryAllErrors $ case msgType of
19611964
AM_CONN_INFO -> sendConfirmation c NRMBackground sq msgBody
19621965
AM_CONN_INFO_REPLY -> sendConfirmation c NRMBackground sq msgBody
19631966
_ -> case pendingMsgPrepData_ of
@@ -2097,10 +2100,12 @@ runSmpQueueMsgDelivery c@AgentClient {subQ} sq@SndQueue {userId, connId, server,
20972100
notifyDelMsgs :: InternalId -> AgentErrorType -> UTCTime -> AM ()
20982101
notifyDelMsgs msgId err expireTs = do
20992102
notifyDel msgId $ MERR (unId msgId) err
2100-
msgIds_ <- withStore' c $ \db -> getExpiredSndMessages db connId sq expireTs
2103+
msgIds_ <- withStore' c $ \db -> do
2104+
msgIds_ <- getExpiredSndMessages db connId sq expireTs
2105+
forM_ msgIds_ $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
2106+
pure msgIds_
21012107
forM_ (L.nonEmpty msgIds_) $ \msgIds -> do
21022108
notify $ MERRS (L.map unId msgIds) err
2103-
withStore' c $ \db -> forM_ msgIds $ \msgId' -> deleteSndMsgDelivery db connId sq msgId' False `catchAll_` pure ()
21042109
atomically $ incSMPServerStat' c userId server sentExpiredErrs (length msgIds_ + 1)
21052110
delMsg :: InternalId -> AM ()
21062111
delMsg = delMsgKeep False
@@ -3025,7 +3030,8 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
30253030
throwE e
30263031
agentClientMsg :: TVar ChaChaDRG -> ByteString -> AM (Maybe (InternalId, MsgMeta, AMessage, CR.RatchetX448))
30273032
agentClientMsg g encryptedMsgHash = withStore c $ \db -> runExceptT $ do
3028-
rc <- ExceptT $ getRatchet db connId -- ratchet state pre-decryption - required for processing EREADY
3033+
liftIO $ lockConnForUpdate db connId
3034+
rc <- ExceptT $ getRatchetForUpdate db connId -- ratchet state pre-decryption - required for processing EREADY
30293035
(agentMsgBody, pqEncryption) <- agentRatchetDecrypt' g db connId rc encAgentMessage
30303036
liftEither (parse smpP (SEAgentError $ AGENT A_MESSAGE) agentMsgBody) >>= \case
30313037
agentMsg@(AgentMessage APrivHeader {sndMsgId, prevMsgHash} aMessage) -> do
@@ -3260,6 +3266,7 @@ processSMPTransmissions c@AgentClient {subQ} (tSess@(userId, srv, _), _v, sessId
32603266
Just sqs' -> do
32613267
(sq_@SndQueue {sndPrivateKey}, dhPublicKey) <- lift $ newSndQueue userId connId qInfo Nothing
32623268
sq2 <- withStore c $ \db -> do
3269+
lockConnForUpdate db connId
32633270
liftIO $ mapM_ (deleteConnSndQueue db connId) delSqs
32643271
addConnSndQueue db connId (sq_ :: NewSndQueue) {primary = True, dbReplaceQueueId = Just dbQueueId}
32653272
logServer "<--" c srv rId $ "MSG <QADD>:" <> logSecret' srvMsgId <> " " <> logSecret (senderId queueAddress)
@@ -3564,7 +3571,7 @@ agentRatchetEncrypt db cData msg getPaddedLen pqEnc_ currentE2EVersion = do
35643571

35653572
agentRatchetEncryptHeader :: DB.Connection -> ConnData -> (VersionSMPA -> PQSupport -> Int) -> Maybe PQEncryption -> CR.VersionE2E -> ExceptT StoreError IO (CR.MsgEncryptKeyX448, Int, PQEncryption)
35663573
agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport} getPaddedLen pqEnc_ currentE2EVersion = do
3567-
rc <- ExceptT $ getRatchet db connId
3574+
rc <- ExceptT $ getRatchetForUpdate db connId
35683575
let paddedLen = getPaddedLen v pqSupport
35693576
(mek, rc') <- withExceptT (SEAgentError . cryptoError) $ CR.rcEncryptHeader rc pqEnc_ currentE2EVersion
35703577
liftIO $ updateRatchet db connId rc' CR.SMDNoChange
@@ -3573,7 +3580,7 @@ agentRatchetEncryptHeader db ConnData {connId, connAgentVersion = v, pqSupport}
35733580
-- encoded EncAgentMessage -> encoded AgentMessage
35743581
agentRatchetDecrypt :: TVar ChaChaDRG -> DB.Connection -> ConnId -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)
35753582
agentRatchetDecrypt g db connId encAgentMsg = do
3576-
rc <- ExceptT $ getRatchet db connId
3583+
rc <- ExceptT $ getRatchetForUpdate db connId
35773584
agentRatchetDecrypt' g db connId rc encAgentMsg
35783585

35793586
agentRatchetDecrypt' :: TVar ChaChaDRG -> DB.Connection -> ConnId -> CR.RatchetX448 -> ByteString -> ExceptT StoreError IO (ByteString, PQEncryption)

src/Simplex/Messaging/Agent/Client.hs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,39 +2114,42 @@ withWork :: AgentClient -> TMVar () -> (DB.Connection -> IO (Either StoreError (
21142114
withWork c doWork = withWork_ c doWork . withStore' c
21152115
{-# INLINE withWork #-}
21162116

2117+
-- setting doWork flag to "no work" before getWork rather than after prevents race condition when flag is set to "has work" by another thread after getWork call.
21172118
withWork_ :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' (Maybe a)) -> (a -> ExceptT e m ()) -> ExceptT e m ()
21182119
withWork_ c doWork getWork action =
2119-
getWork >>= \case
2120-
Right (Just r) -> action r
2121-
Right Nothing -> noWork
2122-
-- worker is stopped here (noWork) because the next iteration is likely to produce the same result
2120+
noWork >> getWork >>= \case
2121+
Right (Just r) -> hasWork >> action r
2122+
Right Nothing -> pure ()
21232123
Left e
2124-
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
2125-
| otherwise -> notifyErr INTERNAL e
2124+
| isWorkItemError e -> notifyErr (CRITICAL False) e -- worker remains stopped here because the next iteration is likely to produce the same result
2125+
| otherwise -> hasWork >> notifyErr INTERNAL e
21262126
where
2127+
hasWork = atomically $ hasWorkToDo' doWork
21272128
noWork = liftIO $ noWorkToDo doWork
21282129
notifyErr err e = do
21292130
logError $ "withWork_ error: " <> tshow e
21302131
atomically $ writeTBQueue (subQ c) ("", "", AEvt SAEConn $ ERR $ err $ show e)
21312132

21322133
withWorkItems :: (AnyStoreError e', MonadIO m) => AgentClient -> TMVar () -> ExceptT e m (Either e' [Either e' a]) -> (NonEmpty a -> ExceptT e m ()) -> ExceptT e m ()
21332134
withWorkItems c doWork getWork action = do
2134-
getWork >>= \case
2135-
Right [] -> noWork
2135+
noWork >> getWork >>= \case
2136+
Right [] -> pure ()
21362137
Right rs -> do
21372138
let (errs, items) = partitionEithers rs
21382139
case L.nonEmpty items of
2139-
Just items' -> action items'
2140+
Just items' -> hasWork >> action items'
21402141
Nothing -> do
2141-
let criticalErr = find isWorkItemError errs
2142-
forM_ criticalErr $ \err -> do
2143-
notifyErr (CRITICAL False) err
2144-
when (all isWorkItemError errs) noWork
2142+
case find isWorkItemError errs of
2143+
Nothing -> hasWork
2144+
Just err -> do
2145+
notifyErr (CRITICAL False) err
2146+
unless (all isWorkItemError errs) hasWork
21452147
forM_ (L.nonEmpty errs) $ notifySub c . ERRS . L.map (\e -> ("", INTERNAL $ show e))
21462148
Left e
2147-
| isWorkItemError e -> noWork >> notifyErr (CRITICAL False) e
2148-
| otherwise -> notifyErr INTERNAL e
2149+
| isWorkItemError e -> notifyErr (CRITICAL False) e
2150+
| otherwise -> hasWork >> notifyErr INTERNAL e
21492151
where
2152+
hasWork = atomically $ hasWorkToDo' doWork
21502153
noWork = liftIO $ noWorkToDo doWork
21512154
notifyErr err e = do
21522155
logError $ "withWorkItems error: " <> tshow e

0 commit comments

Comments
 (0)