diff --git a/be/src/common/thread_safety_annotations.h b/be/src/common/thread_safety_annotations.h index 6cd8d4b0cae45c..6bbdb8ce6546ad 100644 --- a/be/src/common/thread_safety_annotations.h +++ b/be/src/common/thread_safety_annotations.h @@ -22,6 +22,7 @@ #pragma once #include +#include #ifdef BE_TEST namespace doris { @@ -93,6 +94,27 @@ class CAPABILITY("mutex") AnnotatedMutex { std::mutex _mutex; }; +// Annotated shared mutex wrapper for use with Clang thread safety analysis. +// Wraps std::shared_mutex and provides both exclusive and shared capability +// operations so GUARDED_BY / REQUIRES_SHARED / etc. can reference it. +class CAPABILITY("mutex") AnnotatedSharedMutex { +public: + void lock() ACQUIRE() { _mutex.lock(); } + void unlock() RELEASE() { _mutex.unlock(); } + bool try_lock() TRY_ACQUIRE(true) { return _mutex.try_lock(); } + + void lock_shared() ACQUIRE_SHARED() { _mutex.lock_shared(); } + void unlock_shared() RELEASE_SHARED() { _mutex.unlock_shared(); } + bool try_lock_shared() TRY_ACQUIRE_SHARED(true) { return _mutex.try_lock_shared(); } + + // Access the underlying std::shared_mutex (e.g., for std::condition_variable_any). + // Use with care — this bypasses thread safety annotations. + std::shared_mutex& native_handle() { return _mutex; } + +private: + std::shared_mutex _mutex; +}; + // RAII scoped lock guard annotated for thread safety analysis. // In BE_TEST builds, injects a random sleep before acquiring and after // releasing the lock to exercise concurrent code paths. @@ -119,6 +141,32 @@ class SCOPED_CAPABILITY LockGuard { MutexType& _mu; }; +// RAII scoped shared lock guard annotated for thread safety analysis. +// In BE_TEST builds, injects a random sleep before acquiring and after +// releasing the lock to exercise concurrent code paths. +template +class SCOPED_CAPABILITY SharedLockGuard { +public: + explicit SharedLockGuard(MutexType& mu) ACQUIRE_SHARED(mu) : _mu(mu) { +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + _mu.lock_shared(); + } + ~SharedLockGuard() RELEASE() { + _mu.unlock_shared(); +#ifdef BE_TEST + doris::mock_random_sleep(); +#endif + } + + SharedLockGuard(const SharedLockGuard&) = delete; + SharedLockGuard& operator=(const SharedLockGuard&) = delete; + +private: + MutexType& _mu; +}; + // RAII unique lock annotated for thread safety analysis. // Supports manual lock/unlock while preserving capability tracking. template diff --git a/be/src/exec/exchange/vdata_stream_mgr.cpp b/be/src/exec/exchange/vdata_stream_mgr.cpp index 70d59ea767309b..1ba53bc2b1b279 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.cpp +++ b/be/src/exec/exchange/vdata_stream_mgr.cpp @@ -43,7 +43,7 @@ VDataStreamMgr::~VDataStreamMgr() { // It will core during graceful stop. auto receivers = std::vector>(); { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); auto receiver_iterator = _receiver_map.begin(); while (receiver_iterator != _receiver_map.end()) { // Could not call close directly, because during close method, it will remove itself @@ -76,22 +76,16 @@ std::shared_ptr VDataStreamMgr::create_recvr( this, memory_used_counter, state, fragment_instance_id, dest_node_id, num_senders, is_merging, profile, data_queue_capacity)); uint32_t hash_value = get_hash_value(fragment_instance_id, dest_node_id); - std::unique_lock l(_lock); + LockGuard l(_lock); _fragment_stream_set.insert(std::make_pair(fragment_instance_id, dest_node_id)); _receiver_map.insert(std::make_pair(hash_value, recvr)); return recvr; } -Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock) { +Status VDataStreamMgr::_find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) { VLOG_ROW << "looking up fragment_instance_id=" << print_id(fragment_instance_id) << ", node=" << node_id; - uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); - // Create lock guard and not own lock currently and will lock conditionally - std::shared_lock recvr_lock(_lock, std::defer_lock); - if (acquire_lock) { - recvr_lock.lock(); - } std::pair range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { @@ -107,6 +101,13 @@ Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNod node_id, print_id(fragment_instance_id)); } +Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, + std::shared_ptr* res) { + SharedLockGuard recvr_lock(_lock); + uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); + return _find_recvr(hash_value, fragment_instance_id, node_id, res); +} + Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request, ::google::protobuf::Closure** done, const int64_t wait_for_worker) { @@ -173,7 +174,7 @@ Status VDataStreamMgr::deregister_recvr(const TUniqueId& fragment_instance_id, P << ", node=" << node_id; uint32_t hash_value = get_hash_value(fragment_instance_id, node_id); { - std::unique_lock l(_lock); + LockGuard l(_lock); auto range = _receiver_map.equal_range(hash_value); while (range.first != range.second) { const std::shared_ptr& recvr = range.first->second; @@ -204,12 +205,13 @@ void VDataStreamMgr::cancel(const TUniqueId& fragment_instance_id, Status exec_s VLOG_QUERY << "cancelling all streams for fragment=" << print_id(fragment_instance_id); std::vector> recvrs; { - std::shared_lock l(_lock); + SharedLockGuard l(_lock); FragmentStreamSet::iterator i = _fragment_stream_set.lower_bound(std::make_pair(fragment_instance_id, 0)); while (i != _fragment_stream_set.end() && i->first == fragment_instance_id) { std::shared_ptr recvr; - WARN_IF_ERROR(find_recvr(i->first, i->second, &recvr, false), ""); + uint32_t hash_value = get_hash_value(i->first, i->second); + WARN_IF_ERROR(_find_recvr(hash_value, i->first, i->second, &recvr), ""); if (recvr == nullptr) { // keep going but at least log it std::stringstream err; diff --git a/be/src/exec/exchange/vdata_stream_mgr.h b/be/src/exec/exchange/vdata_stream_mgr.h index 3825aa1f02b142..d69758c94fd834 100644 --- a/be/src/exec/exchange/vdata_stream_mgr.h +++ b/be/src/exec/exchange/vdata_stream_mgr.h @@ -30,6 +30,7 @@ #include "common/be_mock_util.h" #include "common/global_types.h" #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "runtime/runtime_profile.h" namespace google { @@ -57,8 +58,7 @@ class VDataStreamMgr { RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity); MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, - bool acquire_lock = true); + std::shared_ptr* res); Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id); @@ -68,9 +68,9 @@ class VDataStreamMgr { void cancel(const TUniqueId& fragment_instance_id, Status exec_status); private: - std::shared_mutex _lock; + AnnotatedSharedMutex _lock; using StreamMap = std::unordered_multimap>; - StreamMap _receiver_map; + StreamMap _receiver_map GUARDED_BY(_lock); struct ComparisonOp { bool operator()(const std::pair& a, @@ -88,7 +88,11 @@ class VDataStreamMgr { } }; using FragmentStreamSet = std::set, ComparisonOp>; - FragmentStreamSet _fragment_stream_set; + FragmentStreamSet _fragment_stream_set GUARDED_BY(_lock); + + Status _find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id, + PlanNodeId node_id, std::shared_ptr* res) + REQUIRES_SHARED(_lock); uint32_t get_hash_value(const TUniqueId& fragment_instance_id, PlanNodeId node_id); }; diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp index 75dbe9a130dcff..f8e687f09b7f7b 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.cpp +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.cpp @@ -189,7 +189,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc( auto filter_id = runtime_filter_desc->filter_id; GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } @@ -239,7 +239,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrfilter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); if (iter == _filter_map.end()) { return Status::InvalidArgument("unknown filter id {}", @@ -247,12 +247,12 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptrsecond; - std::unique_lock l(iter->second.mtx); + std::unique_lock l(cnt_val.mtx); // Discard stale-stage runtime filter size requests from old recursive CTE rounds. // Each round increments the stage counter; only messages matching the current stage // should be processed. This prevents old PFC's runtime filters from corrupting // the merge state of the new round's filters. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } cnt_val.source_addrs.push_back(request->source_addr()); @@ -273,7 +273,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptr(); - sync_request->set_stage(iter->second.stage); + sync_request->set_stage(cnt_val.stage); auto callback = HandleErrorBrpcCallback::create_shared( query_ctx->ignore_runtime_filter_error() ? std::weak_ptr {} @@ -343,7 +343,7 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto filter_id = request->filter_id(); std::map::iterator iter; { - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); iter = _filter_map.find(filter_id); VLOG_ROW << "recv filter id:" << request->filter_id() << " " << request->ShortDebugString(); if (iter == _filter_map.end()) { @@ -354,9 +354,9 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr q auto& cnt_val = iter->second; bool is_ready = false; { - std::lock_guard l(iter->second.mtx); + std::lock_guard l(cnt_val.mtx); // Discard stale-stage merge requests from old recursive CTE rounds. - if (request->stage() != iter->second.stage) { + if (request->stage() != cnt_val.stage) { return Status::OK(); } if (cnt_val.merger == nullptr) { @@ -508,7 +508,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( for (const auto& filter_id : filter_ids) { GlobalMergeContext* cnt_val; { - std::unique_lock guard(_filter_map_mutex); + LockGuard guard(_filter_map_mutex); cnt_val = &_filter_map[filter_id]; // may inplace construct default object } RETURN_IF_ERROR(cnt_val->reset(query_ctx)); @@ -518,7 +518,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf( std::string RuntimeFilterMergeControllerEntity::debug_string() { std::string result = "RuntimeFilterMergeControllerEntity Info:\n"; - std::shared_lock guard(_filter_map_mutex); + SharedLockGuard guard(_filter_map_mutex); for (const auto& [filter_id, ctx] : _filter_map) { result += fmt::format("filter_id: {}, stage: {}, {}\n", filter_id, ctx.stage, ctx.merger->debug_string()); diff --git a/be/src/exec/runtime_filter/runtime_filter_mgr.h b/be/src/exec/runtime_filter/runtime_filter_mgr.h index b25d9956ad8d16..536eb63e152caa 100644 --- a/be/src/exec/runtime_filter/runtime_filter_mgr.h +++ b/be/src/exec/runtime_filter/runtime_filter_mgr.h @@ -27,12 +27,11 @@ #include #include #include -#include #include -#include #include #include "common/status.h" +#include "common/thread_safety_annotations.h" #include "util/uid_util.h" namespace butil { @@ -174,7 +173,7 @@ class RuntimeFilterMergeControllerEntity { std::string debug_string(); bool empty() { - std::shared_lock read_lock(_filter_map_mutex); + SharedLockGuard read_lock(_filter_map_mutex); return _filter_map.empty(); } @@ -191,9 +190,9 @@ class RuntimeFilterMergeControllerEntity { int64_t merge_time, PUniqueId query_id, int execution_timeout); // protect _filter_map - std::shared_mutex _filter_map_mutex; + AnnotatedSharedMutex _filter_map_mutex; std::shared_ptr _mem_tracker; - std::map _filter_map; + std::map _filter_map GUARDED_BY(_filter_map_mutex); }; } // namespace doris diff --git a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp index ab6b03b13c5572..f0c4e05c7e6528 100644 --- a/be/test/exec/pipeline/vdata_stream_recvr_test.cpp +++ b/be/test/exec/pipeline/vdata_stream_recvr_test.cpp @@ -577,7 +577,7 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) { struct MockVDataStreamMgr : public VDataStreamMgr { ~MockVDataStreamMgr() override = default; Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id, - std::shared_ptr* res, bool acquire_lock = true) override { + std::shared_ptr* res) override { *res = recvr; return Status::OK(); }