Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions be/src/common/thread_safety_annotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#pragma once

#include <mutex>
#include <shared_mutex>

#ifdef BE_TEST
namespace doris {
Expand Down Expand Up @@ -93,6 +94,27 @@ class CAPABILITY("mutex") AnnotatedMutex {
std::mutex _mutex;
};

// Annotated shared mutex wrapper for use with Clang thread safety analysis.
// Wraps std::shared_mutex and provides both exclusive and shared capability
// operations so GUARDED_BY / REQUIRES_SHARED / etc. can reference it.
class CAPABILITY("mutex") AnnotatedSharedMutex {
public:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我们使用这个封装的API,跟直接使用lock的各种方法是什么区别?

void lock() ACQUIRE() { _mutex.lock(); }
void unlock() RELEASE() { _mutex.unlock(); }
bool try_lock() TRY_ACQUIRE(true) { return _mutex.try_lock(); }

void lock_shared() ACQUIRE_SHARED() { _mutex.lock_shared(); }
void unlock_shared() RELEASE_SHARED() { _mutex.unlock_shared(); }
bool try_lock_shared() TRY_ACQUIRE_SHARED(true) { return _mutex.try_lock_shared(); }

// Access the underlying std::shared_mutex (e.g., for std::condition_variable_any).
// Use with care — this bypasses thread safety annotations.
std::shared_mutex& native_handle() { return _mutex; }

private:
std::shared_mutex _mutex;
};

// RAII scoped lock guard annotated for thread safety analysis.
// In BE_TEST builds, injects a random sleep before acquiring and after
// releasing the lock to exercise concurrent code paths.
Expand All @@ -119,6 +141,32 @@ class SCOPED_CAPABILITY LockGuard {
MutexType& _mu;
};

// RAII scoped shared lock guard annotated for thread safety analysis.
// In BE_TEST builds, injects a random sleep before acquiring and after
// releasing the lock to exercise concurrent code paths.
template <typename MutexType>
class SCOPED_CAPABILITY SharedLockGuard {
public:
explicit SharedLockGuard(MutexType& mu) ACQUIRE_SHARED(mu) : _mu(mu) {
#ifdef BE_TEST
doris::mock_random_sleep();
#endif
_mu.lock_shared();
}
~SharedLockGuard() RELEASE() {
_mu.unlock_shared();
#ifdef BE_TEST
doris::mock_random_sleep();
#endif
}

SharedLockGuard(const SharedLockGuard&) = delete;
SharedLockGuard& operator=(const SharedLockGuard&) = delete;

private:
MutexType& _mu;
};

// RAII unique lock annotated for thread safety analysis.
// Supports manual lock/unlock while preserving capability tracking.
template <typename MutexType>
Expand Down
28 changes: 15 additions & 13 deletions be/src/exec/exchange/vdata_stream_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ VDataStreamMgr::~VDataStreamMgr() {
// It will core during graceful stop.
auto receivers = std::vector<std::shared_ptr<VDataStreamRecvr>>();
{
std::shared_lock l(_lock);
SharedLockGuard l(_lock);
auto receiver_iterator = _receiver_map.begin();
while (receiver_iterator != _receiver_map.end()) {
// Could not call close directly, because during close method, it will remove itself
Expand Down Expand Up @@ -76,22 +76,16 @@ std::shared_ptr<VDataStreamRecvr> VDataStreamMgr::create_recvr(
this, memory_used_counter, state, fragment_instance_id, dest_node_id, num_senders,
is_merging, profile, data_queue_capacity));
uint32_t hash_value = get_hash_value(fragment_instance_id, dest_node_id);
std::unique_lock l(_lock);
LockGuard l(_lock);
_fragment_stream_set.insert(std::make_pair(fragment_instance_id, dest_node_id));
_receiver_map.insert(std::make_pair(hash_value, recvr));
return recvr;
}

Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res, bool acquire_lock) {
Status VDataStreamMgr::_find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id,
PlanNodeId node_id, std::shared_ptr<VDataStreamRecvr>* res) {
VLOG_ROW << "looking up fragment_instance_id=" << print_id(fragment_instance_id)
<< ", node=" << node_id;
uint32_t hash_value = get_hash_value(fragment_instance_id, node_id);
// Create lock guard and not own lock currently and will lock conditionally
std::shared_lock recvr_lock(_lock, std::defer_lock);
if (acquire_lock) {
recvr_lock.lock();
}
std::pair<StreamMap::iterator, StreamMap::iterator> range =
_receiver_map.equal_range(hash_value);
while (range.first != range.second) {
Expand All @@ -107,6 +101,13 @@ Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNod
node_id, print_id(fragment_instance_id));
}

Status VDataStreamMgr::find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res) {
SharedLockGuard recvr_lock(_lock);
uint32_t hash_value = get_hash_value(fragment_instance_id, node_id);
return _find_recvr(hash_value, fragment_instance_id, node_id, res);
}

Status VDataStreamMgr::transmit_block(const PTransmitDataParams* request,
::google::protobuf::Closure** done,
const int64_t wait_for_worker) {
Expand Down Expand Up @@ -173,7 +174,7 @@ Status VDataStreamMgr::deregister_recvr(const TUniqueId& fragment_instance_id, P
<< ", node=" << node_id;
uint32_t hash_value = get_hash_value(fragment_instance_id, node_id);
{
std::unique_lock l(_lock);
LockGuard l(_lock);
auto range = _receiver_map.equal_range(hash_value);
while (range.first != range.second) {
const std::shared_ptr<VDataStreamRecvr>& recvr = range.first->second;
Expand Down Expand Up @@ -204,12 +205,13 @@ void VDataStreamMgr::cancel(const TUniqueId& fragment_instance_id, Status exec_s
VLOG_QUERY << "cancelling all streams for fragment=" << print_id(fragment_instance_id);
std::vector<std::shared_ptr<VDataStreamRecvr>> recvrs;
{
std::shared_lock l(_lock);
SharedLockGuard l(_lock);
FragmentStreamSet::iterator i =
_fragment_stream_set.lower_bound(std::make_pair(fragment_instance_id, 0));
while (i != _fragment_stream_set.end() && i->first == fragment_instance_id) {
std::shared_ptr<VDataStreamRecvr> recvr;
WARN_IF_ERROR(find_recvr(i->first, i->second, &recvr, false), "");
uint32_t hash_value = get_hash_value(i->first, i->second);
WARN_IF_ERROR(_find_recvr(hash_value, i->first, i->second, &recvr), "");
if (recvr == nullptr) {
// keep going but at least log it
std::stringstream err;
Expand Down
14 changes: 9 additions & 5 deletions be/src/exec/exchange/vdata_stream_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "common/be_mock_util.h"
#include "common/global_types.h"
#include "common/status.h"
#include "common/thread_safety_annotations.h"
#include "runtime/runtime_profile.h"

namespace google {
Expand Down Expand Up @@ -57,8 +58,7 @@ class VDataStreamMgr {
RuntimeProfile* profile, bool is_merging, size_t data_queue_capacity);

MOCK_FUNCTION Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res,
bool acquire_lock = true);
std::shared_ptr<VDataStreamRecvr>* res);

Status deregister_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id);

Expand All @@ -68,9 +68,9 @@ class VDataStreamMgr {
void cancel(const TUniqueId& fragment_instance_id, Status exec_status);

private:
std::shared_mutex _lock;
AnnotatedSharedMutex _lock;
using StreamMap = std::unordered_multimap<uint32_t, std::shared_ptr<VDataStreamRecvr>>;
StreamMap _receiver_map;
StreamMap _receiver_map GUARDED_BY(_lock);

struct ComparisonOp {
bool operator()(const std::pair<doris::TUniqueId, PlanNodeId>& a,
Expand All @@ -88,7 +88,11 @@ class VDataStreamMgr {
}
};
using FragmentStreamSet = std::set<std::pair<TUniqueId, PlanNodeId>, ComparisonOp>;
FragmentStreamSet _fragment_stream_set;
FragmentStreamSet _fragment_stream_set GUARDED_BY(_lock);

Status _find_recvr(uint32_t hash_value, const TUniqueId& fragment_instance_id,
PlanNodeId node_id, std::shared_ptr<VDataStreamRecvr>* res)
REQUIRES_SHARED(_lock);

uint32_t get_hash_value(const TUniqueId& fragment_instance_id, PlanNodeId node_id);
};
Expand Down
20 changes: 10 additions & 10 deletions be/src/exec/runtime_filter/runtime_filter_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc(
auto filter_id = runtime_filter_desc->filter_id;
GlobalMergeContext* cnt_val;
{
std::unique_lock<std::shared_mutex> guard(_filter_map_mutex);
LockGuard guard(_filter_map_mutex);
cnt_val = &_filter_map[filter_id]; // may inplace construct default object
}

Expand Down Expand Up @@ -239,20 +239,20 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptr<Quer
auto filter_id = request->filter_id();
std::map<int, GlobalMergeContext>::iterator iter;
{
std::shared_lock<std::shared_mutex> guard(_filter_map_mutex);
SharedLockGuard guard(_filter_map_mutex);
iter = _filter_map.find(filter_id);
if (iter == _filter_map.end()) {
return Status::InvalidArgument("unknown filter id {}",
std::to_string(request->filter_id()));
}
}
auto& cnt_val = iter->second;
std::unique_lock<std::mutex> l(iter->second.mtx);
std::unique_lock<std::mutex> l(cnt_val.mtx);
// Discard stale-stage runtime filter size requests from old recursive CTE rounds.
// Each round increments the stage counter; only messages matching the current stage
// should be processed. This prevents old PFC's runtime filters from corrupting
// the merge state of the new round's filters.
if (request->stage() != iter->second.stage) {
if (request->stage() != cnt_val.stage) {
return Status::OK();
}
cnt_val.source_addrs.push_back(request->source_addr());
Expand All @@ -273,7 +273,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(std::shared_ptr<Quer
}

auto sync_request = std::make_shared<PSyncFilterSizeRequest>();
sync_request->set_stage(iter->second.stage);
sync_request->set_stage(cnt_val.stage);

auto callback = HandleErrorBrpcCallback<PSyncFilterSizeResponse>::create_shared(
query_ctx->ignore_runtime_filter_error() ? std::weak_ptr<QueryContext> {}
Expand Down Expand Up @@ -343,7 +343,7 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr<QueryContext> q
auto filter_id = request->filter_id();
std::map<int, GlobalMergeContext>::iterator iter;
{
std::shared_lock<std::shared_mutex> guard(_filter_map_mutex);
SharedLockGuard guard(_filter_map_mutex);
iter = _filter_map.find(filter_id);
VLOG_ROW << "recv filter id:" << request->filter_id() << " " << request->ShortDebugString();
if (iter == _filter_map.end()) {
Expand All @@ -354,9 +354,9 @@ Status RuntimeFilterMergeControllerEntity::merge(std::shared_ptr<QueryContext> q
auto& cnt_val = iter->second;
bool is_ready = false;
{
std::lock_guard<std::mutex> l(iter->second.mtx);
std::lock_guard<std::mutex> l(cnt_val.mtx);
// Discard stale-stage merge requests from old recursive CTE rounds.
if (request->stage() != iter->second.stage) {
if (request->stage() != cnt_val.stage) {
return Status::OK();
}
if (cnt_val.merger == nullptr) {
Expand Down Expand Up @@ -508,7 +508,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf(
for (const auto& filter_id : filter_ids) {
GlobalMergeContext* cnt_val;
{
std::unique_lock<std::shared_mutex> guard(_filter_map_mutex);
LockGuard guard(_filter_map_mutex);
cnt_val = &_filter_map[filter_id]; // may inplace construct default object
}
RETURN_IF_ERROR(cnt_val->reset(query_ctx));
Expand All @@ -518,7 +518,7 @@ Status RuntimeFilterMergeControllerEntity::reset_global_rf(

std::string RuntimeFilterMergeControllerEntity::debug_string() {
std::string result = "RuntimeFilterMergeControllerEntity Info:\n";
std::shared_lock<std::shared_mutex> guard(_filter_map_mutex);
SharedLockGuard guard(_filter_map_mutex);
for (const auto& [filter_id, ctx] : _filter_map) {
result += fmt::format("filter_id: {}, stage: {}, {}\n", filter_id, ctx.stage,
ctx.merger->debug_string());
Expand Down
9 changes: 4 additions & 5 deletions be/src/exec/runtime_filter/runtime_filter_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/status.h"
#include "common/thread_safety_annotations.h"
#include "util/uid_util.h"

namespace butil {
Expand Down Expand Up @@ -174,7 +173,7 @@ class RuntimeFilterMergeControllerEntity {
std::string debug_string();

bool empty() {
std::shared_lock<std::shared_mutex> read_lock(_filter_map_mutex);
SharedLockGuard read_lock(_filter_map_mutex);
return _filter_map.empty();
}

Expand All @@ -191,9 +190,9 @@ class RuntimeFilterMergeControllerEntity {
int64_t merge_time, PUniqueId query_id, int execution_timeout);

// protect _filter_map
std::shared_mutex _filter_map_mutex;
AnnotatedSharedMutex _filter_map_mutex;
std::shared_ptr<MemTracker> _mem_tracker;

std::map<int, GlobalMergeContext> _filter_map;
std::map<int, GlobalMergeContext> _filter_map GUARDED_BY(_filter_map_mutex);
};
} // namespace doris
2 changes: 1 addition & 1 deletion be/test/exec/pipeline/vdata_stream_recvr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ TEST_F(DataStreamRecvrTest, TestRemoteLocalMultiSender) {
struct MockVDataStreamMgr : public VDataStreamMgr {
~MockVDataStreamMgr() override = default;
Status find_recvr(const TUniqueId& fragment_instance_id, PlanNodeId node_id,
std::shared_ptr<VDataStreamRecvr>* res, bool acquire_lock = true) override {
std::shared_ptr<VDataStreamRecvr>* res) override {
*res = recvr;
return Status::OK();
}
Expand Down
Loading