Unverified Commit 3c7c9ad6 authored by BossZou's avatar BossZou Committed by GitHub
Browse files

Change url for behavior 'get_entities_by_id' (#2330) (#2336)



Signed-off-by: default avataryhz <413554850@qq.com>
parent b7f410e4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -65,6 +65,7 @@ Please mark all change in change log and use the issue from GitHub
-   \#2256 k-means clustering algorithm use only Euclidean distance metric
-   \#2300 Upgrade mishrads configuration to version 0.4
-   \#2311 Update mishards methods 
-   \#2330 Change url for behavior 'get_entities_by_id'

## Task

+3 −3
Original line number Diff line number Diff line
@@ -577,11 +577,11 @@ class WebController : public oatpp::web::server::api::ApiController {
     *
     * GetVectorByID ?id=
     */
    ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name),
             BODY_STRING(String, body), QUERIES(const QueryParams&, query_params)) {
    ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
             PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
        auto handler = WebRequestHandler();
        String response;
        auto status_dto = handler.GetVector(collection_name, body, query_params, response);
        auto status_dto = handler.GetVector(collection_name, query_params, response);

        switch (status_dto->code->getValue()) {
            case StatusCode::SUCCESS:
+8 −10
Original line number Diff line number Diff line
@@ -1693,22 +1693,20 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se
}

StatusDto::ObjectWrapper
WebRequestHandler::GetVector(const OString& collection_name, const OString& body, const OQueryParams& query_params,
                             OString& response) {
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
    auto status = Status::OK();
    try {
        auto body_json = nlohmann::json::parse(body->c_str());
        if (!body_json.contains("ids")) {
            RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'ids\' is required.")
        }
        auto ids = body_json["ids"];
        if (!ids.is_array()) {
            RETURN_STATUS_DTO(BODY_PARSE_FAIL, "Field \'ids\' must be a array.")
        auto query_ids = query_params.get("ids");
        if (query_ids == nullptr || query_ids.get() == nullptr) {
            RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param ids is required.");
        }

        std::vector<std::string> ids;
        StringHelpFunctions::SplitStringByDelimeter(query_ids->c_str(), ",", ids);

        std::vector<int64_t> vector_ids;
        for (auto& id : ids) {
            vector_ids.push_back(std::stol(id.get<std::string>()));
            vector_ids.push_back(std::stol(id));
        }
        engine::VectorsData vectors;
        nlohmann::json vectors_json;
+1 −1
Original line number Diff line number Diff line
@@ -220,7 +220,7 @@ class WebRequestHandler {
    InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);

    StatusDto::ObjectWrapper
    GetVector(const OString& collection_name, const OString& body, const OQueryParams& query_params, OString& response);
    GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);

    StatusDto::ObjectWrapper
    VectorsOp(const OString& collection_name, const OString& payload, OString& response);
+7 −5
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@
#include "server/web_impl/handler/WebRequestHandler.h"
#include "src/version.h"
#include "utils/CommonUtil.h"
#include "utils/StringHelpFunctions.h"

static const char* COLLECTION_NAME = "test_web";

@@ -325,7 +326,7 @@ class TestClient : public oatpp::web::client::ApiClient {
             PATH(String, collection_name, "collection_name"))

    API_CALL("GET", "/collections/{collection_name}/vectors", getVectors,
             PATH(String, collection_name, "collection_name"), BODY_STRING(String, body))
             PATH(String, collection_name, "collection_name"), QUERY(String, ids))

    API_CALL("POST", "/collections/{collection_name}/vectors", insert,
             PATH(String, collection_name, "collection_name"), BODY_STRING(String, body))
@@ -1302,9 +1303,10 @@ TEST_F(WebControllerTest, GET_VECTORS_BY_IDS) {
    for (size_t i = 0; i < 10; i++) {
        vector_ids.emplace_back(ids.at(i));
    }
    auto body = nlohmann::json();
    body["ids"] = vector_ids;
    auto response = client_ptr->getVectors(collection_name, body.dump().c_str(), conncetion_ptr);

    std::string query_ids;
    milvus::server::StringHelpFunctions::MergeStringWithDelimeter(vector_ids, ",", query_ids);
    auto response = client_ptr->getVectors(collection_name, query_ids.c_str(), conncetion_ptr);
    ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()) << response->readBodyToString()->c_str();

    // validate result
@@ -1329,7 +1331,7 @@ TEST_F(WebControllerTest, GET_VECTORS_BY_IDS) {
    ASSERT_EQ(64, vec.size());

    // non-existent collection
    response = client_ptr->getVectors(collection_name + "_non_existent", body.dump().c_str(), conncetion_ptr);
    response = client_ptr->getVectors(collection_name + "_non_existent", query_ids.c_str(), conncetion_ptr);
    ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode()) << response->readBodyToString()->c_str();
}