Unverified commit 3bc17d8c, authored by 王翔宇, committed by GitHub
Browse files

Fix multi-client search crash in tracing module (fix #1789, fix #1832) (#1899)



* Fix multi client search crash in tracing module fix #1789 fix #1832

Signed-off-by: wxyu <xy.wang@zilliz.com>

* add lock for every context_map_ access

Signed-off-by: wxyu <xy.wang@zilliz.com>

* remove never used variable

Signed-off-by: wxyu <xy.wang@zilliz.com>
parent c8a59b27
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@ Please mark all change in change log and use the issue from GitHub
## Bug
-   \#1276 SQLite throw exception after create 50000+ partitions in a table
-   \#1762 Server does not forbid creating a new partition whose tag is `_default`
-   \#1789 Fix server crash caused by multi-client search
-   \#1832 Fix crash in tracing module
-   \#1873 Fix index file serialize to incorrect path
-   \#1881 Fix Annoy index search failure

+1 −0
Original line number Diff line number Diff line
@@ -301,6 +301,7 @@ if (DEFINED ENV{MILVUS_GRPC_URL})
    set(GRPC_SOURCE_URL "$ENV{MILVUS_GRPC_URL}")
else ()
    set(GRPC_SOURCE_URL
            "https://github.com/milvus-io/grpc-milvus/archive/${GRPC_VERSION}.zip"
            "https://github.com/youny626/grpc-milvus/archive/${GRPC_VERSION}.zip"
            "https://gitee.com/quicksilver/grpc-milvus/repository/archive/${GRPC_VERSION}.zip")
endif ()
+115 −33
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@

#include <fiu-local.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

@@ -158,6 +159,50 @@ ConstructCollectionInfo(const CollectionInfo& collection_info, ::milvus::grpc::C

}  // namespace

namespace {

#define REQ_ID ("request_id")

// Process-wide counter backing get_sequential_id(); starts at zero.
std::atomic<int64_t> _sequential_id{0};

// Hands out the current counter value and advances the counter by one.
// The read-and-increment is a single atomic operation.
int64_t
get_sequential_id() {
    return _sequential_id.fetch_add(1);
}

// Records `request_id` on `context` by calling AddInitialMetadata with the REQ_ID key.
// When `context` is null, an error is logged and nothing else happens.
void
set_request_id(::grpc::ServerContext* context, const std::string& request_id) {
    if (context != nullptr) {
        context->AddInitialMetadata(REQ_ID, request_id);
        return;
    }

    // No context to attach the id to: report and bail out.
    SERVER_LOG_ERROR << "set_request_id: grpc::ServerContext is nullptr" << std::endl;
}

// Looks up the value stored under the REQ_ID key in `context->server_metadata()`.
//
// Returns the stored request id, or the literal "INVALID_ID" (after logging an
// error) when `context` is null or the key is not present.
std::string
get_request_id(::grpc::ServerContext* context) {
    if (context == nullptr) {
        // error
        SERVER_LOG_ERROR << "get_request_id: grpc::ServerContext is nullptr" << std::endl;
        return "INVALID_ID";
    }

    // Bind by const reference: if server_metadata() returns a reference this avoids
    // copying the whole metadata map on every lookup; if it returns by value the
    // temporary's lifetime is extended to this scope, so the iterator stays valid.
    const auto& server_metadata = context->server_metadata();

    const auto request_id_kv = server_metadata.find(REQ_ID);
    if (request_id_kv == server_metadata.end()) {
        // error
        SERVER_LOG_ERROR << std::string(REQ_ID) << " not found in grpc.server_metadata" << std::endl;
        return "INVALID_ID";
    }

    // Construct with an explicit length so the result does not depend on the
    // stored value being NUL-terminated (and is not cut short at an embedded NUL).
    const auto& stored_id = request_id_kv->second;
    return std::string(stored_id.data(), stored_id.size());
}

}  // namespace

GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
    : tracer_(tracer), random_num_generator_() {
    std::random_device random_device;
@@ -187,16 +232,42 @@ GrpcRequestHandler::OnPostRecvInitialMetaData(
        return;
    }
    auto span = tracer_->StartSpan(server_rpc_info->method(), {opentracing::ChildOf(span_context_maybe->get())});

    auto server_context = server_rpc_info->server_context();
    auto client_metadata = server_context->client_metadata();
    // TODO: request id

    // if client provide request_id in metadata, milvus just use it,
    // else milvus generate a sequential id.
    std::string request_id;
    auto request_id_kv = client_metadata.find("request_id");
    if (request_id_kv != client_metadata.end()) {
        request_id = request_id_kv->second.data();
        SERVER_LOG_DEBUG << "client provide request_id: " << request_id;

        // if request_id is being used by another request,
        // convert it to request_id_n.
        std::lock_guard<std::mutex> lock(context_map_mutex_);
        if (context_map_.find(request_id) == context_map_.end()) {
            // if not found exist, mark
            context_map_[request_id] = nullptr;
        } else {
            // Finding a unused suffix
            int64_t suffix = 1;
            std::string try_request_id;
            bool exist = true;
            do {
                try_request_id = request_id + "_" + std::to_string(suffix);
                exist = context_map_.find(try_request_id) != context_map_.end();
                suffix++;
            } while (exist);
            context_map_[try_request_id] = nullptr;
        }
    } else {
        request_id = std::to_string(random_id()) + std::to_string(random_id());
        request_id = std::to_string(get_sequential_id());
        set_request_id(server_context, request_id);
        SERVER_LOG_DEBUG << "milvus generate request_id: " << request_id;
    }

    auto trace_context = std::make_shared<tracing::TraceContext>(span);
    auto context = std::make_shared<Context>(request_id);
    context->SetTraceContext(trace_context);
@@ -207,23 +278,33 @@ void
GrpcRequestHandler::OnPreSendMessage(::grpc::experimental::ServerRpcInfo* server_rpc_info,
                                     ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
    // Finishes the tracing span recorded for this RPC's request id and removes the
    // matching entry from context_map_. All map access happens under context_map_mutex_.
    std::lock_guard<std::mutex> lock(context_map_mutex_);
    auto request_id = get_request_id(server_rpc_info->server_context());

    // Single lookup: the iterator is reused for the dereference and the erase.
    auto search = context_map_.find(request_id);
    if (search == context_map_.end()) {
        // error
        SERVER_LOG_ERROR << "request_id " << request_id << " not found in context_map_";
        return;
    }

    // An entry may still hold the nullptr placeholder written when the id was
    // reserved; only finish the span when a real Context has been stored.
    if (search->second != nullptr) {
        search->second->GetTraceContext()->GetSpan()->Finish();
    }
    context_map_.erase(search);
}

// Returns the Context stored in context_map_ for the request id carried by
// `server_context`, or a reference to an empty shared_ptr (after logging an
// error) when no entry exists for that id.
const std::shared_ptr<Context>&
GrpcRequestHandler::GetContext(::grpc::ServerContext* server_context) {
    // The function returns a reference, so the "not found" result must outlive the
    // call; returning `nullptr` directly would bind the reference to a temporary.
    static const std::shared_ptr<Context> kNullContext;

    std::lock_guard<std::mutex> lock(context_map_mutex_);
    auto request_id = get_request_id(server_context);

    // find() instead of operator[] so a missing id never inserts an entry.
    auto search = context_map_.find(request_id);
    if (search == context_map_.end()) {
        SERVER_LOG_ERROR << "GetContext: request_id " << request_id << " not found in context_map_";
        return kNullContext;
    }
    return search->second;
}

// Stores `context` in context_map_ under the request id carried by
// `server_context`, replacing any existing entry for that id.
// The map is modified while holding context_map_mutex_.
void
GrpcRequestHandler::SetContext(::grpc::ServerContext* server_context, const std::shared_ptr<Context>& context) {
    std::lock_guard<std::mutex> lock(context_map_mutex_);
    // context_map_ is keyed by request id string, not by the ServerContext pointer.
    auto request_id = get_request_id(server_context);
    context_map_[request_id] = context;
}

uint64_t
@@ -244,7 +325,7 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil
    CHECK_NULLPTR_RETURN(request);

    Status status =
        request_handler_.CreateCollection(context_map_[context], request->collection_name(), request->dimension(),
        request_handler_.CreateCollection(GetContext(context), request->collection_name(), request->dimension(),
                                          request->index_file_size(), request->metric_type());
    SET_RESPONSE(response, status, context);

@@ -258,7 +339,7 @@ GrpcRequestHandler::HasCollection(::grpc::ServerContext* context, const ::milvus

    bool has_collection = false;

    Status status = request_handler_.HasCollection(context_map_[context], request->collection_name(), has_collection);
    Status status = request_handler_.HasCollection(GetContext(context), request->collection_name(), has_collection);
    response->set_bool_reply(has_collection);
    SET_RESPONSE(response->mutable_status(), status, context);

@@ -270,7 +351,7 @@ GrpcRequestHandler::DropCollection(::grpc::ServerContext* context, const ::milvu
                                   ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.DropCollection(context_map_[context], request->collection_name());
    Status status = request_handler_.DropCollection(GetContext(context), request->collection_name());

    SET_RESPONSE(response, status, context);
    return ::grpc::Status::OK;
@@ -289,8 +370,8 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::
        }
    }

    Status status = request_handler_.CreateIndex(context_map_[context], request->collection_name(),
                                                 request->index_type(), json_params);
    Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->index_type(),
                                                 json_params);

    SET_RESPONSE(response, status, context);
    return ::grpc::Status::OK;
@@ -309,7 +390,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:

    // step 2: insert vectors
    Status status =
        request_handler_.Insert(context_map_[context], request->collection_name(), vectors, request->partition_tag());
        request_handler_.Insert(GetContext(context), request->collection_name(), vectors, request->partition_tag());

    // step 3: return id array
    response->mutable_vector_id_array()->Resize(static_cast<int>(vectors.id_array_.size()), 0);
@@ -329,7 +410,7 @@ GrpcRequestHandler::GetVectorByID(::grpc::ServerContext* context, const ::milvus
    std::vector<int64_t> vector_ids = {request->id()};
    engine::VectorsData vectors;
    Status status =
        request_handler_.GetVectorByID(context_map_[context], request->collection_name(), vector_ids, vectors);
        request_handler_.GetVectorByID(GetContext(context), request->collection_name(), vector_ids, vectors);

    if (!vectors.float_data_.empty()) {
        response->mutable_vector_data()->mutable_float_data()->Resize(vectors.float_data_.size(), 0);
@@ -351,7 +432,7 @@ GrpcRequestHandler::GetVectorIDs(::grpc::ServerContext* context, const ::milvus:
    CHECK_NULLPTR_RETURN(request);

    std::vector<int64_t> vector_ids;
    Status status = request_handler_.GetVectorIDs(context_map_[context], request->collection_name(),
    Status status = request_handler_.GetVectorIDs(GetContext(context), request->collection_name(),
                                                  request->segment_name(), vector_ids);

    if (!vector_ids.empty()) {
@@ -393,7 +474,8 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
    std::vector<std::string> file_ids;
    TopKQueryResult result;
    fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));
    Status status = request_handler_.Search(context_map_[context], request->collection_name(), vectors, request->topk(),

    Status status = request_handler_.Search(GetContext(context), request->collection_name(), vectors, request->topk(),
                                            json_params, partitions, file_ids, result);

    // step 5: construct and return result
@@ -428,7 +510,7 @@ GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::g

    // step 3: search vectors
    TopKQueryResult result;
    Status status = request_handler_.SearchByID(context_map_[context], request->collection_name(), request->id(),
    Status status = request_handler_.SearchByID(GetContext(context), request->collection_name(), request->id(),
                                                request->topk(), json_params, partitions, result);

    // step 4: construct and return result
@@ -474,7 +556,7 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus

    // step 5: search vectors
    TopKQueryResult result;
    Status status = request_handler_.Search(context_map_[context], search_request->collection_name(), vectors,
    Status status = request_handler_.Search(GetContext(context), search_request->collection_name(), vectors,
                                            search_request->topk(), json_params, partitions, file_ids, result);

    // step 6: construct and return result
@@ -492,7 +574,7 @@ GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::m

    CollectionSchema collection_schema;
    Status status =
        request_handler_.DescribeCollection(context_map_[context], request->collection_name(), collection_schema);
        request_handler_.DescribeCollection(GetContext(context), request->collection_name(), collection_schema);
    response->set_collection_name(collection_schema.collection_name_);
    response->set_dimension(collection_schema.dimension_);
    response->set_index_file_size(collection_schema.index_file_size_);
@@ -508,7 +590,7 @@ GrpcRequestHandler::CountCollection(::grpc::ServerContext* context, const ::milv
    CHECK_NULLPTR_RETURN(request);

    int64_t row_count = 0;
    Status status = request_handler_.CountCollection(context_map_[context], request->collection_name(), row_count);
    Status status = request_handler_.CountCollection(GetContext(context), request->collection_name(), row_count);
    response->set_collection_row_count(row_count);
    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
@@ -520,7 +602,7 @@ GrpcRequestHandler::ShowCollections(::grpc::ServerContext* context, const ::milv
    CHECK_NULLPTR_RETURN(request);

    std::vector<std::string> collections;
    Status status = request_handler_.ShowCollections(context_map_[context], collections);
    Status status = request_handler_.ShowCollections(GetContext(context), collections);
    for (auto& collection : collections) {
        response->add_collection_names(collection);
    }
@@ -536,7 +618,7 @@ GrpcRequestHandler::ShowCollectionInfo(::grpc::ServerContext* context, const ::m

    CollectionInfo collection_info;
    Status status =
        request_handler_.ShowCollectionInfo(context_map_[context], request->collection_name(), collection_info);
        request_handler_.ShowCollectionInfo(GetContext(context), request->collection_name(), collection_info);
    ConstructCollectionInfo(collection_info, response);
    SET_RESPONSE(response->mutable_status(), status, context);

@@ -549,7 +631,7 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Co
    CHECK_NULLPTR_RETURN(request);

    std::string reply;
    Status status = request_handler_.Cmd(context_map_[context], request->cmd(), reply);
    Status status = request_handler_.Cmd(GetContext(context), request->cmd(), reply);
    response->set_string_reply(reply);
    SET_RESPONSE(response->mutable_status(), status, context);

@@ -568,7 +650,7 @@ GrpcRequestHandler::DeleteByID(::grpc::ServerContext* context, const ::milvus::g
    }

    // step 2: delete vector
    Status status = request_handler_.DeleteByID(context_map_[context], request->collection_name(), vector_ids);
    Status status = request_handler_.DeleteByID(GetContext(context), request->collection_name(), vector_ids);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -579,7 +661,7 @@ GrpcRequestHandler::PreloadCollection(::grpc::ServerContext* context, const ::mi
                                      ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.PreloadCollection(context_map_[context], request->collection_name());
    Status status = request_handler_.PreloadCollection(GetContext(context), request->collection_name());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -591,7 +673,7 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus
    CHECK_NULLPTR_RETURN(request);

    IndexParam param;
    Status status = request_handler_.DescribeIndex(context_map_[context], request->collection_name(), param);
    Status status = request_handler_.DescribeIndex(GetContext(context), request->collection_name(), param);
    response->set_collection_name(param.collection_name_);
    response->set_index_type(param.index_type_);
    ::milvus::grpc::KeyValuePair* kv = response->add_extra_params();
@@ -607,7 +689,7 @@ GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::gr
                              ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.DropIndex(context_map_[context], request->collection_name());
    Status status = request_handler_.DropIndex(GetContext(context), request->collection_name());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -618,7 +700,7 @@ GrpcRequestHandler::CreatePartition(::grpc::ServerContext* context, const ::milv
                                    ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.CreatePartition(context_map_[context], request->collection_name(), request->tag());
    Status status = request_handler_.CreatePartition(GetContext(context), request->collection_name(), request->tag());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -630,7 +712,7 @@ GrpcRequestHandler::ShowPartitions(::grpc::ServerContext* context, const ::milvu
    CHECK_NULLPTR_RETURN(request);

    std::vector<PartitionParam> partitions;
    Status status = request_handler_.ShowPartitions(context_map_[context], request->collection_name(), partitions);
    Status status = request_handler_.ShowPartitions(GetContext(context), request->collection_name(), partitions);
    for (auto& partition : partitions) {
        response->add_partition_tag_array(partition.tag_);
    }
@@ -645,7 +727,7 @@ GrpcRequestHandler::DropPartition(::grpc::ServerContext* context, const ::milvus
                                  ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.DropPartition(context_map_[context], request->collection_name(), request->tag());
    Status status = request_handler_.DropPartition(GetContext(context), request->collection_name(), request->tag());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -660,7 +742,7 @@ GrpcRequestHandler::Flush(::grpc::ServerContext* context, const ::milvus::grpc::
    for (int32_t i = 0; i < request->collection_name_array().size(); i++) {
        collection_names.push_back(request->collection_name_array(i));
    }
    Status status = request_handler_.Flush(context_map_[context], collection_names);
    Status status = request_handler_.Flush(GetContext(context), collection_names);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
@@ -671,7 +753,7 @@ GrpcRequestHandler::Compact(::grpc::ServerContext* context, const ::milvus::grpc
                            ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    Status status = request_handler_.Compact(context_map_[context], request->collection_name());
    Status status = request_handler_.Compact(GetContext(context), request->collection_name());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
+3 −1
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@

#pragma once

#include <grpcpp/server_context.h>
#include <server/context/Context.h>

#include <cstdint>
@@ -311,7 +312,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
 private:
    RequestHandler request_handler_;

    std::unordered_map<::grpc::ServerContext*, std::shared_ptr<Context>> context_map_;
    // std::unordered_map<::grpc::ServerContext*, std::shared_ptr<Context>> context_map_;
    std::unordered_map<std::string, std::shared_ptr<Context>> context_map_;
    std::shared_ptr<opentracing::Tracer> tracer_;
    //    std::unordered_map<::grpc::ServerContext*, std::unique_ptr<opentracing::Span>> span_map_;