enhance: unable to compile C++ tests (#29616)

The tests need to call a private method, Milvus uses `#define` to
replace private with public, the hack trick works but would be broken if
the including order changed.

This uses friend to make all things work well

Signed-off-by: yah01 <yang.cen@zilliz.com>
Signed-off-by: yah01 <yah2er0ne@outlook.com>
pull/29659/head
yah01 2024-01-04 13:20:46 +08:00 committed by GitHub
parent 336fce0582
commit 99e0f1e65a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 25 deletions

View File

@ -52,8 +52,8 @@ IndexFactory::CreateScalarIndex(
//
template <>
inline ScalarIndexPtr<std::string>
IndexFactory::CreateScalarIndex(
ScalarIndexPtr<std::string>
IndexFactory::CreateScalarIndex<std::string>(
const IndexType& index_type,
const storage::FileManagerContext& file_manager_context,
DataType d_type) {
@ -88,7 +88,7 @@ IndexFactory::CreateScalarIndex(
template <>
ScalarIndexPtr<std::string>
IndexFactory::CreateScalarIndex(
IndexFactory::CreateScalarIndex<std::string>(
const IndexType& index_type,
const storage::FileManagerContext& file_manager_context,
std::shared_ptr<milvus_storage::Space> space,

View File

@ -83,6 +83,8 @@ class IndexFactory {
// IndexBasePtr
// CreateIndex(DataType dtype, const IndexType& index_type);
private:
FRIEND_TEST(StringIndexMarisaTest, Reverse);
template <typename T>
ScalarIndexPtr<T>
CreateScalarIndex(const IndexType& index_type,
@ -98,12 +100,12 @@ class IndexFactory {
DataType d_type = DataType::NONE);
};
template <>
ScalarIndexPtr<std::string>
IndexFactory::CreateScalarIndex<std::string>(
const IndexType& index_type,
const storage::FileManagerContext& file_manager_context,
DataType d_type);
// template <>
// ScalarIndexPtr<std::string>
// IndexFactory::CreateScalarIndex<std::string>(
// const IndexType& index_type,
// const storage::FileManagerContext& file_manager_context,
// DataType d_type);
template <>
ScalarIndexPtr<std::string>

View File

@ -15,8 +15,6 @@
#include "index/Index.h"
#include "index/ScalarIndex.h"
#define private public
#include "index/StringIndexMarisa.h"
#include "index/IndexFactory.h"
@ -28,6 +26,8 @@
constexpr int64_t nb = 100;
namespace schemapb = milvus::proto::schema;
namespace milvus {
namespace index {
class StringIndexBaseTest : public ::testing::Test {
protected:
void
@ -431,7 +431,7 @@ class StringIndexMarisaTestV2 : public StringIndexBaseTest {
auto vec_size = DIM * 4;
auto vec_field_data_type = milvus::DataType::VECTOR_FLOAT;
auto dataset = GenDataset(nb, knowhere::metric::L2, false);
auto dataset = ::GenDataset(nb, knowhere::metric::L2, false);
space = TestSpace(vec_size, dataset, strs);
}
@ -460,3 +460,6 @@ TEST_F(StringIndexMarisaTestV2, Base) {
new_index->LoadV2();
ASSERT_EQ(strs.size(), index->Count());
}
} // namespace index
} // namespace milvus

View File

@ -897,15 +897,14 @@ func TestProxy(t *testing.T) {
t.Run("show collections", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{
Base: nil,
DbName: dbName,
TimeStamp: 0,
Type: milvuspb.ShowType_All,
CollectionNames: nil,
Base: nil,
DbName: dbName,
TimeStamp: 0,
Type: milvuspb.ShowType_All,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 1, len(resp.CollectionNames), resp.CollectionNames)
assert.True(t, merr.Ok(resp.GetStatus()))
assert.Contains(t, resp.CollectionNames, collectionName, "collections: %v", resp.CollectionNames)
})
wg.Add(1)
@ -2386,7 +2385,7 @@ func TestProxy(t *testing.T) {
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
assert.Equal(t, 0, len(resp.CollectionNames))
assert.NotContains(t, resp.CollectionNames, collectionName)
})
username := "test_username_" + funcutil.RandomString(15)

View File

@ -52,11 +52,11 @@ func (t *showCollectionTask) Execute(ctx context.Context) error {
t.Rsp.Status = merr.Status(err)
return err
}
for _, meta := range colls {
t.Rsp.CollectionNames = append(t.Rsp.CollectionNames, meta.Name)
t.Rsp.CollectionIds = append(t.Rsp.CollectionIds, meta.CollectionID)
t.Rsp.CreatedTimestamps = append(t.Rsp.CreatedTimestamps, meta.CreateTime)
physical, _ := tsoutil.ParseHybridTs(meta.CreateTime)
for _, coll := range colls {
t.Rsp.CollectionNames = append(t.Rsp.CollectionNames, coll.Name)
t.Rsp.CollectionIds = append(t.Rsp.CollectionIds, coll.CollectionID)
t.Rsp.CreatedTimestamps = append(t.Rsp.CreatedTimestamps, coll.CreateTime)
physical, _ := tsoutil.ParseHybridTs(coll.CreateTime)
t.Rsp.CreatedUtcTimestamps = append(t.Rsp.CreatedUtcTimestamps, uint64(physical))
}
return nil